Source code for rlang.grounding.groundings

"""Module containing all RLang groundings."""

from __future__ import annotations

from collections.abc import MutableMapping
from collections import defaultdict
from typing import Callable, Any, Union, List

import numpy as np
from numpy.random import default_rng
from .utils.utils import Domain
from .utils.primitives import MDPObject, VectorState, ObjectOrientedState, Action, Primitive
from .utils.grounding_exceptions import RLangGroundingError


class Grounding(object):
    """Parent class for all groundings.

    For all intents and purposes, this is an abstract class. It only stores
    an optional name and derives equality-by-name, hashing and the textual
    representation from it.
    """

    def __init__(self, name=None):
        """Create a grounding.

        Args:
            name (optional): the name of the grounding; may be None.
        """
        self._name = name

    @property
    def name(self):
        """The name of the grounding (None when unnamed)."""
        return self._name

    @name.setter
    def name(self, name: str):
        self._name = name

    def equals(self, other):
        """Return True when both groundings carry the same name."""
        return self.name == other.name

    def __hash__(self):
        return hash(self._name)

    def __repr__(self):
        # __repr__ must return a str. Returning the raw name made repr()/str()
        # raise TypeError for unnamed groundings, so fall back to a placeholder.
        if self._name is None:
            return f"<{type(self).__name__}>"
        return str(self._name)
class GroundingFunction(Grounding):
    """Parent class for groundings that are callable.

    In general, only the children of this class should be used. All
    GroundingFunctions have a specified domain and codomain. They are invoked
    using keyword arguments that correspond to their domain::

        from rlang import Domain

        def can_move_fun(*args, **kwargs):
            return not kwargs['state'] in pit_states and kwargs['action'] in move_actions

        can_move = GroundingFunction(domain=Domain.STATE_ACTION, codomain=Domain.BOOLEAN, function=can_move_fun)
        can_move(state=0, action=1)
        >> True

    Comparison operators build a new Proposition and arithmetic operators
    build a new Feature (or a PrimitiveGrounding) that evaluates its operands
    lazily when called.
    """

    # (ufunc, method used when self is the LEFT operand,
    #         method used when self is the RIGHT operand)
    _UFUNC_DISPATCH = (
        (np.multiply, '__mul__', '__rmul__'),
        (np.true_divide, '__truediv__', '__rtruediv__'),
        (np.add, '__add__', '__radd__'),
        (np.subtract, '__sub__', '__rsub__'),
        (np.less, '__lt__', '__gt__'),
        (np.less_equal, '__le__', '__ge__'),
        (np.greater, '__gt__', '__lt__'),
        (np.greater_equal, '__ge__', '__le__'),
        (np.equal, '__eq__', '__eq__'),
        (np.not_equal, '__ne__', '__ne__'),
    )

    def __init__(self, domain: Union[str, Domain], codomain: Union[str, Domain], function: Callable,
                 name: str = None):
        """Initialize a GroundingFunction.

        Args:
            domain: Domain of the function (a Domain or its name).
            codomain: Codomain of the function (a Domain or its name).
            function: the function.
            name: the name of the Grounding.
        """
        if isinstance(domain, str):
            domain = Domain.from_name(domain)
        if isinstance(codomain, str):
            codomain = Domain.from_name(codomain)
        super().__init__(name)
        self._domain = domain
        self._codomain = codomain
        self._function = function

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Route NumPy binary ufuncs (e.g. ``ndarray * grounding``) to the operator methods.

        Returns NotImplemented for anything that is not a plain binary call of
        a supported ufunc, so NumPy raises a TypeError instead of the
        expression silently evaluating to None.
        """
        if method != "__call__" or len(inputs) != 2:
            return NotImplemented
        left, right = inputs
        self_is_left = left is self
        other = right if self_is_left else left
        for candidate, forward, reflected in self._UFUNC_DISPATCH:
            if ufunc is candidate:
                return getattr(self, forward if self_is_left else reflected)(other)
        return NotImplemented

    @property
    def domain(self):
        return self._domain

    @domain.setter
    def domain(self, domain: Domain):
        self._domain = domain

    @property
    def codomain(self):
        return self._codomain

    @codomain.setter
    def codomain(self, codomain: Domain):
        self._codomain = codomain

    @property
    def function(self):
        return self._function

    @function.setter
    def function(self, function: Callable):
        self._function = function

    def __contains__(self, item):
        # TODO: Contains should really only be used in sets. We should make a formal
        #  distinction between lists and sets in RLang.
        # NOTE: Python coerces the result of `item in grounding` to bool, so the
        #  Proposition built here is only visible to callers that invoke
        #  `grounding.__contains__(item)` directly.
        def contains(*args, **kwargs):
            return item(*args, **kwargs) in self(*args, **kwargs)

        return Proposition(function=contains, domain=self.domain + item.domain)

    def __call__(self, *args, **kwargs):
        """Evaluate the wrapped function.

        The 'state', 'action' and 'next_state' keyword arguments are wrapped
        into VectorState / Action primitives when they are not already
        primitives of the expected kind.
        """
        if 'state' in kwargs and not isinstance(kwargs['state'], (VectorState, ObjectOrientedState)):
            kwargs['state'] = VectorState(kwargs['state'])
        if 'action' in kwargs and not isinstance(kwargs['action'], Action):
            kwargs['action'] = Action(kwargs['action'])
        if 'next_state' in kwargs and not isinstance(kwargs['next_state'], (VectorState, ObjectOrientedState)):
            kwargs['next_state'] = VectorState(kwargs['next_state'])
        return self._function(*args, **kwargs)

    # ------------------------------------------------------------------ helpers

    def _compare(self, other, op, symbol):
        """Build the Proposition for ``self <symbol> other``."""
        if isinstance(other, GroundingFunction):
            return Proposition(function=lambda *args, **kwargs: op(self(*args, **kwargs), other(*args, **kwargs)),
                               domain=self.domain + other.domain)
        if isinstance(other, (np.ndarray, int, float)):
            return Proposition(function=lambda *args, **kwargs: op(self(*args, **kwargs), other),
                               domain=self.domain)
        raise RLangGroundingError(message=f"Cannot '{symbol}' a {type(self)} and a {type(other)}")

    def _arithmetic(self, other, op, symbol):
        """Build the grounding for ``self <symbol> other``."""
        if isinstance(other, GroundingFunction):
            new_domain = self.domain + other.domain
            # NOTE(review): this compares `.value` with the `Domain.ANY` member itself;
            #  kept exactly as written - confirm against the Domain definition.
            if new_domain.value == Domain.ANY:
                # Neither operand needs arguments, so the result is computed eagerly.
                return PrimitiveGrounding(codomain=Domain.REAL_VALUE, value=op(self(), other()))
            return Feature(function=lambda *args, **kwargs: op(self(*args, **kwargs), other(*args, **kwargs)),
                           domain=new_domain)
        if isinstance(other, (np.ndarray, int, float)):
            return Feature(function=lambda *args, **kwargs: op(self(*args, **kwargs), other),
                           domain=self.domain)
        raise RLangGroundingError(message=f"Cannot '{symbol}' a {type(self)} and a {type(other)}")

    def _reflected_arithmetic(self, other, op, symbol):
        """Build the Feature for ``other <symbol> self`` where other is a plain number/array."""
        if isinstance(other, (np.ndarray, int, float)):
            return Feature(function=lambda *args, **kwargs: op(other, self(*args, **kwargs)),
                           domain=self.domain)
        raise RLangGroundingError(message=f"Cannot '{symbol}' a {type(other)} and a {type(self)}")

    # -------------------------------------------------------------- comparisons

    def __lt__(self, other):
        return self._compare(other, lambda a, b: a < b, '<')

    def __le__(self, other):
        return self._compare(other, lambda a, b: a <= b, '<=')

    def __gt__(self, other):
        return self._compare(other, lambda a, b: a > b, '>')

    def __ge__(self, other):
        return self._compare(other, lambda a, b: a >= b, '>=')

    def __eq__(self, other):
        return self._compare(other, lambda a, b: a == b, '==')

    def __ne__(self, other):
        return self._compare(other, lambda a, b: a != b, '!=')

    # --------------------------------------------------------------- arithmetic

    def __mul__(self, other):
        return self._arithmetic(other, lambda a, b: a * b, '*')

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        return self._arithmetic(other, lambda a, b: a / b, '/')

    def __rtruediv__(self, other):
        return self._reflected_arithmetic(other, lambda a, b: a / b, '/')

    def __sub__(self, other):
        return self._arithmetic(other, lambda a, b: a - b, '-')

    def __rsub__(self, other):
        return self._reflected_arithmetic(other, lambda a, b: a - b, '-')

    def __add__(self, other):
        return self._arithmetic(other, lambda a, b: a + b, '+')

    def __radd__(self, other):
        return self.__add__(other)

    def __hash__(self):
        # __eq__ is overridden to build Propositions, so hashing must be defined explicitly.
        return hash((str(self), self.function, self.domain, self.codomain))
class PrimitiveGrounding(GroundingFunction):
    """GroundingFunction which requires no arguments, i.e. domain=Domain.ANY"""

    def __init__(self, codomain: Domain, value: Any, name: str = None):
        """Wrap a fixed value as a grounding.

        Args:
            codomain: the codomain of the grounding.
            value: the value returned on every call; plain ints and floats are
                converted with ``np.array``.
            name (optional): the name of the grounding.
        """
        # TODO: What about lists? Should lists be cast? Only non-jagged ones?
        # TODO: Probably need to cast `value` to a Primitive
        self.value = np.array(value) if isinstance(value, (int, float)) else value

        def constant(*args, **kwargs):
            # Reads self.value at call time, so later reassignment is reflected.
            return self.value

        super().__init__(domain=Domain.ANY, codomain=codomain, function=constant, name=name)

    def __repr__(self):
        label = f" \"{self.name}\"" if self.name else ""
        return f"<PrimitiveGrounding{label}: {self()}>"

    def __hash__(self):
        return hash((str(self), str(self.value)))
class ConstantGrounding(PrimitiveGrounding):
    """GroundingFunction for defined RLang Constants."""

    def __repr__(self):
        constant_name = self.name
        constant_value = self()
        return f"<Constant \"{constant_name}\" = {constant_value}>"
class ParameterizedAction:
    """A named wrapper around a callable that produces an action from arguments."""

    def __init__(self, function, name=None):
        """
        Args:
            function: the callable to wrap.
            name (optional): the name of the action. Defaults to the callable's
                ``__name__``; callables without one (e.g. ``functools.partial``
                objects or callable instances) fall back to their type name
                instead of raising AttributeError.
        """
        self.function = function
        if name:
            self.name = name
        else:
            self.name = getattr(function, "__name__", type(function).__name__)

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped callable."""
        return self.function(*args, **kwargs)
class ParameterizedActionExecution(GroundingFunction):
    """Grounding that evaluates its arguments and feeds them to a parameterized action."""

    def __init__(self, parameterized_action, arguments: List[GroundingFunction]):
        """
        Args:
            parameterized_action: the callable (with a ``name``) to execute.
            arguments: groundings evaluated at call time; their results are
                passed positionally to ``parameterized_action``.
        """
        self.parameterized_action = parameterized_action
        self.arguments = arguments

        # The execution depends on everything any of its arguments depends on.
        combined_domain = Domain.ANY
        for argument in arguments:
            combined_domain = combined_domain + argument.domain

        labels = ["unk" if argument.name is None else argument.name for argument in arguments]

        def execute(*args, **kwargs):
            evaluated = [argument(*args, **kwargs) for argument in self.arguments]
            return parameterized_action(*evaluated, **kwargs)

        super().__init__(domain=combined_domain, codomain=Domain.ACTION, function=execute,
                         name=parameterized_action.name + "(" + ", ".join(labels) + ")")
class ActionReference(GroundingFunction):
    """Represents a reference to a specified action."""

    def __init__(self, action: Any, name=None):
        """
        Args:
            action: the action; an int, float or list (wrapped into an Action),
                or a GroundingFunction that is called to produce the action.
            name (optional): name of the action.

        Raises:
            RLangGroundingError: if ``action`` is of any other type.
        """
        if isinstance(action, GroundingFunction):
            function, domain = action.__call__, action.domain
        elif isinstance(action, (int, float, list)):
            def function(*sargs, **skwargs):
                return Action(np.array(action))

            domain = Domain.ANY
        else:
            raise RLangGroundingError(f"Actions cannot be of type {type(action)}")

        super().__init__(domain=domain, codomain=Domain.ACTION, function=function, name=name)

    def __hash__(self):
        return hash(self.function)

    def __repr__(self):
        if not self.name:
            return "<ActionReference>"
        return f"<ActionReference \"{self.name}\">"
class IdentityGrounding(GroundingFunction):
    """Grounding for representing S, A, and S'."""

    def __init__(self, domain: Union[str, Domain]):
        """Initialize a new IdentityGrounding.

        Args:
            domain: a Domain or a domain name. The (lower-cased, for Domain
                members) name is also the keyword argument that the grounding
                returns when called.
        """
        # Does this work properly?
        keyword = domain if isinstance(domain, str) else domain.name.lower()

        def identity(*args, **kwargs):
            return kwargs[keyword]

        super().__init__(domain=keyword, codomain=keyword, function=identity)

    def __repr__(self):
        codomain_name = self.codomain.name
        return f"<IdentityGrounding {codomain_name}>"
class MDPClassGrounding(GroundingFunction):
    """Grounding that returns a stored class, whatever arguments it is called with."""

    def __init__(self, cls):
        """
        Args:
            cls: the class to expose; its ``__name__`` is used to build the grounding name.
        """
        self.cls = cls

        def stored_class(*args, **kwargs):
            return self.cls

        grounding_name = cls.__name__ + "_class_grounding"
        super().__init__(domain=Domain.ANY, codomain=Domain.ANY, function=stored_class, name=grounding_name)
class MDPObjectGrounding(GroundingFunction):
    """For representing objects, which may have properties that are functions of state."""

    def __init__(self, obj: MDPObject, name: str = None, domain=Domain.ANY):
        """Initialize an abstract object grounding.

        Args:
            obj: the MDPObject.
            name (optional): the name of the object; defaults to
                ``obj.name + "_grounding"``.
            domain (optional): the domain of the grounding.
        """
        self.obj = obj
        self.true_obj = None  # concrete object produced by the latest call
        self.calculated = False  # True once calculate_true_obj has run
        super().__init__(function=self.calculate_true_obj, codomain=Domain.OBJECT_VALUE, domain=domain,
                         name=obj.name + "_grounding" if name is None else name)

    def calculate_true_obj(self, *args, **kwargs):
        """Build a concrete object by evaluating every attribute that is itself a grounding.

        The result is cached on ``self.true_obj`` and also returned.
        """
        def resolve(attr):
            return attr(*args, **kwargs) if isinstance(attr, GroundingFunction) else attr

        raw_attrs = [getattr(self.obj, attr_name) for attr_name in self.obj.attr_list]
        self.true_obj = type(self.obj)(*[resolve(attr) for attr in raw_attrs])
        self.calculated = True
        return self.true_obj

    def __getattr__(self, item):
        # Only reached when normal lookup fails. Go through __dict__ so that an
        # instance whose __init__ has not run yet (copy/pickle probing for
        # special methods) raises AttributeError instead of recursing forever.
        state = self.__dict__
        if "obj" not in state:
            raise AttributeError(item)
        target = state["true_obj"] if state.get("calculated", False) else state["obj"]
        return getattr(target, item)

    def __eq__(self, other):
        if isinstance(other, MDPObject):
            return Proposition(function=lambda *args, **kwargs: self(*args, **kwargs) == other,
                               domain=self.domain)
        return super().__eq__(other)

    def __hash__(self):
        return self.obj.__hash__()

    def __repr__(self):
        return f"<MDPObjectGrounding({self.name})[{self.obj.__repr__()}]>"
class MDPObjectAttributeGrounding(GroundingFunction):
    """For referencing attributes of abstract objects that are *not* in the state."""

    def __init__(self, grounding: GroundingFunction, attribute_chain: List):
        """Initialize a grounding for referencing abstract object attributes.

        Args:
            grounding: the MDPObjectGrounding whose attribute you are referencing.
            attribute_chain: a list of attribute/sub-attributes (e.g. `["color", "red_value"]`)
        """
        self.attribute_chain = attribute_chain
        self.grounding = grounding

        def follow_chain(target, chain):
            # Walk one attribute at a time; stop after the last element of the chain.
            depth = 0
            while True:
                attr_name = chain[depth]
                if not hasattr(target, attr_name):
                    raise RLangGroundingError(f"Object {target} does not have attribute {attr_name}")
                target = getattr(target, attr_name)
                depth += 1
                if depth == len(chain):
                    return target

        def evaluate(*args, **kwargs):
            return follow_chain(grounding(*args, **kwargs), self.attribute_chain)

        dotted_chain = '.'.join(self.attribute_chain)
        super().__init__(function=evaluate, codomain=Domain.OBJECT_VALUE, domain=grounding.domain,
                         name=self.grounding.name + '.' + dotted_chain)

    def equals(self, other):
        """Return True when ``other`` references the same attribute chain of an equal grounding."""
        if not isinstance(other, MDPObjectAttributeGrounding):
            return False
        same_grounding = self.grounding.equals(other.grounding)
        same_chain = self.attribute_chain == other.attribute_chain
        return same_grounding and same_chain

    def __hash__(self):
        return hash((str(self), self.grounding, str(self.attribute_chain)))
class Predicate:
    """A named wrapper around a callable that is evaluated as a predicate."""

    def __init__(self, function, name=None):
        """
        Args:
            function: the callable to wrap.
            name (optional): the name of the predicate. Defaults to the
                callable's ``__name__``; callables without one (e.g.
                ``functools.partial`` objects or callable instances) fall back
                to their type name instead of raising AttributeError.
        """
        self.function = function
        if name:
            self.name = name
        else:
            self.name = getattr(function, "__name__", type(function).__name__)

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped callable."""
        return self.function(*args, **kwargs)
class PredicateEvaluation(GroundingFunction):
    """Grounding that evaluates its arguments and feeds them to a predicate."""

    def __init__(self, predicate, arguments: List[GroundingFunction]):
        """
        Args:
            predicate: the callable (with a ``name``) to evaluate.
            arguments: the predicate arguments. Callables are evaluated at call
                time; anything else is passed to the predicate as a literal.
        """
        self.predicate = predicate
        # Historical attribute name, kept for code that still reads it.
        self.parameterized_action = predicate
        self.arguments = arguments

        domain = Domain.ANY
        labels = []
        for arg in arguments:
            # Only groundings contribute to the domain.
            if isinstance(arg, GroundingFunction):
                domain = domain + arg.domain
            # The domain loop already tolerated non-grounding arguments, but naming
            # them used to raise AttributeError; label them with their repr instead.
            if hasattr(arg, "name"):
                labels.append(arg.name if arg.name is not None else "unk")
            else:
                labels.append(repr(arg))

        def evaluate(*args, **kwargs):
            values = [arg(*args, **kwargs) if callable(arg) else arg for arg in self.arguments]
            return predicate(*values, **kwargs)

        super().__init__(domain=domain, codomain=Domain.REAL_VALUE + Domain.BOOLEAN, function=evaluate,
                         name=predicate.name + "(" + ", ".join(labels) + ")")
class StateObjectAttributeGrounding(GroundingFunction):
    """For referencing attributes of objects in the state when the state is object-oriented."""

    def __init__(self, attribute_chain: List, domain: Union[str, Domain] = Domain.STATE):
        """Initialize a grounding for referencing object attributes, when those objects are in the state.

        Args:
            attribute_chain: a list of attribute/sub-attributes (e.g. `["ball", "color", "red_value"]`)
            domain: either "state" or "next_state".

        Raises:
            RLangGroundingError: if ``domain`` is neither a str nor a Domain, or
                does not resolve to Domain.STATE / Domain.NEXT_STATE.
        """
        self.attribute_chain = attribute_chain

        if isinstance(domain, Domain):
            keyword = domain.name.lower()
        elif isinstance(domain, str):
            keyword, domain = domain, Domain.from_name(domain)
        else:
            raise RLangGroundingError(
                f"Invalid domain argument for StateObjectAttributeGrounding: {type(domain)}")

        if not (domain is Domain.STATE or domain is Domain.NEXT_STATE):
            raise RLangGroundingError(
                f"StateObjectAttributeGrounding cannot have domain of type {domain.name}")

        def follow_chain(target, chain):
            # Walk one attribute at a time; stop after the last element of the chain.
            depth = 0
            while True:
                attr_name = chain[depth]
                if not hasattr(target, attr_name):
                    raise RLangGroundingError(f"Object {target} does not have attribute {attr_name}")
                target = getattr(target, attr_name)
                depth += 1
                if depth == len(chain):
                    return target

        def evaluate(*args, **kwargs):
            # `keyword` selects which keyword argument holds the state to inspect.
            return follow_chain(kwargs[keyword], self.attribute_chain)

        super().__init__(function=evaluate, codomain=Domain.OBJECT_VALUE, domain=domain,
                         name="S." + '.'.join(self.attribute_chain))

    def __hash__(self):
        return hash(self.__repr__())

    def __repr__(self):
        dotted_chain = '.'.join(self.attribute_chain)
        return f"<StateObjectAttributeGrounding [S.{dotted_chain}]>"
class Factor(GroundingFunction):
    """Represents a factor of the state space."""

    def __init__(self, state_indexer: Any, name: str = None, domain: Union[str, Domain] = Domain.STATE):
        """
        Args:
            state_indexer: the indices or slice of the state space. A single
                int is wrapped into a one-element list.
            name (optional): the name of the grounding.
            domain (optional [str]): the domain of the Factor.

        Raises:
            RLangGroundingError: if ``domain`` is neither a str nor a Domain, or
                does not resolve to Domain.STATE / Domain.NEXT_STATE.
        """
        if isinstance(domain, Domain):
            domain_arg = domain.name.lower()
        elif isinstance(domain, str):
            domain_arg = domain
            domain = Domain.from_name(domain)
        else:
            raise RLangGroundingError(f"Invalid domain argument for Factor: {type(domain)}")

        if domain is not Domain.STATE and domain is not Domain.NEXT_STATE:
            raise RLangGroundingError(f"Factor cannot have domain of type {domain.name}")

        if isinstance(state_indexer, int):
            state_indexer = [state_indexer]
        self.state_indexer = state_indexer

        # The indexer is read from self at call time so that later changes to
        # `indexer` are honoured.
        super().__init__(function=lambda *args, **kwargs: kwargs[domain_arg].__getitem__(self.state_indexer),
                         codomain=Domain.REAL_VALUE, domain=domain, name=name)

    @property
    def indexer(self):
        return self.state_indexer

    @indexer.setter
    def indexer(self, new_indexer):
        self.state_indexer = new_indexer

    def __getitem__(self, item):
        """Return a new, unnamed Factor over a subset of this factor's indices.

        Args:
            item: an int, an iterable of ints, or a slice, interpreted relative
                to this factor's own indices.

        Raises:
            RLangGroundingError: if this factor is an open-ended slice, or the
                indexer/item type is not supported (previously these cases
                failed with a raw TypeError or silently returned None).
        """
        if isinstance(item, (int, np.integer)):
            item = [item]

        if isinstance(self.state_indexer, slice):
            if self.state_indexer.stop is None:
                raise RLangGroundingError("We don't know enough about the state space")
            base_indices = list(range(*self.state_indexer.indices(self.state_indexer.stop)))
        elif isinstance(self.state_indexer, list):
            base_indices = self.state_indexer
        else:
            raise RLangGroundingError(
                f"Cannot index a Factor whose indexer is of type {type(self.state_indexer)}")

        if isinstance(item, slice):
            new_indexer = base_indices[item]
        else:
            try:
                new_indexer = [base_indices[i] for i in item]
            except TypeError:
                raise RLangGroundingError(f"Cannot index a Factor with a {type(item)}") from None

        return Factor(state_indexer=new_indexer, domain=self.domain)

    def __hash__(self):
        return hash((str(self), str(self.state_indexer), self.name))

    def __repr__(self):
        additional_info = ""
        if self.name:
            additional_info += f" \"{self.name}\" ="
        # Lists are shown without their surrounding brackets.
        if isinstance(self.state_indexer, list):
            indexer_text = str(self.state_indexer)[1:-1]
        else:
            indexer_text = str(self.state_indexer)
        return f"<Factor [{self.domain.name}]->[{self.codomain.name}]:{additional_info} S[{indexer_text}]>"
class Feature(GroundingFunction):
    """Represents a feature of the state space.

    Can represent any function of the state space.
    """

    def __init__(self, function: Callable, name: str = None, domain: Union[str, Domain] = Domain.STATE):
        """
        Args:
            function: a function of state.
            name (optional): the name of the grounding.
            domain (optional [str]): the domain of the Feature.
        """
        super().__init__(domain=domain, codomain=Domain.REAL_VALUE, function=function, name=name)

    @classmethod
    def from_Factor(cls, factor: Factor, name: str = None):
        """Build a Feature that evaluates ``factor`` and shares its domain."""
        evaluate_factor = factor.__call__
        return cls(function=evaluate_factor, name=name, domain=factor.domain)

    def __hash__(self):
        identity = (str(self), self.function, self.domain, self.codomain)
        return hash(identity)

    def __repr__(self):
        source, target = self.domain.name, self.codomain.name
        return f"<Feature [{source}]->[{target}] \"{self.name}\">"
class MarkovFeature(GroundingFunction):
    """Represents a Grounding that is a function of (state, action, next_state)"""

    def __init__(self, function: Callable, name: str):
        """
        Args:
            function: a function of (state, action, next_state)
            name: the name of the grounding.
        """
        super().__init__(function=function, name=name,
                         domain=Domain.STATE_ACTION_NEXT_STATE, codomain=Domain.REAL_VALUE)

    @classmethod
    def from_Factor(cls, factor: Factor, name: str = None):
        """Build a MarkovFeature that evaluates ``factor``."""
        evaluate_factor = factor.__call__
        return cls(function=evaluate_factor, name=name)

    def __repr__(self):
        source, target = self.domain.name, self.codomain.name
        return f"<MarkovFeature [{source}]->[{target}] \"{self.name}\">"
class QuantifierSpecification:
    """Describes a quantified reference to a class, e.g. "all Ball" or "any Ball.color"."""

    def __init__(self, cls, quantifier, dot_exp=None):
        """
        Args:
            cls: the class being quantified over; its ``__name__`` is used in the name.
            quantifier: the quantifier keyword (e.g. 'all' or 'any').
            dot_exp (optional): attribute chain (list of str) applied to each object.
        """
        self.cls = cls
        self.quantifier = quantifier
        self.dot_exp = dot_exp
        self.name = f"{self.quantifier} {self.cls.__name__}"

    def __repr__(self):
        # Separate the class name from the attribute chain with a dot; the two
        # used to be glued together ("all Ballcolor.red").
        suffix = "." + ".".join(self.dot_exp) if self.dot_exp else ""
        return f"<QuantifierSpecification {self.name}{suffix}>"
class Proposition(GroundingFunction):
    """Represents a function which has a truth value.

    A Proposition is a feature with a codomain restricted to True or False.
    """

    def __init__(self, function: Callable, name: str = None, domain: Union[str, Domain] = Domain.STATE):
        """
        Args:
            function: a function of state that evaluates to a bool.
            name (optional): the name of the grounding.
            domain (optional [str]): the domain of the Proposition.
        """
        super().__init__(function=function, codomain=Domain.BOOLEAN, domain=domain, name=name)

    @classmethod
    def from_PrimitiveGrounding(cls, primitive_grounding: PrimitiveGrounding):
        """Cast a boolean PrimitiveGrounding to a Proposition.

        Raises:
            RLangGroundingError: if the grounding's codomain is not Domain.BOOLEAN.
        """
        if primitive_grounding.codomain != Domain.BOOLEAN:
            raise RLangGroundingError(
                f"Cannot cast PrimitiveGrounding with codomain {primitive_grounding.codomain} to Proposition")
        return cls(function=lambda *args, **kwargs: primitive_grounding(), domain=Domain.ANY)

    # TODO: Eventually just work this logic into the Proposition class
    @classmethod
    def from_QuantifierSpecification(cls, quantifier_specification: QuantifierSpecification,
                                     grounding: GroundingFunction, operation):
        """Build a Proposition that applies ``operation`` between ``grounding`` and every quantified object.

        The objects are looked up at call time through
        ``kwargs['knowledge'].objects_of_type(...)``. 'all' requires the
        operation to hold for every object, 'any' for at least one.

        Raises:
            RLangGroundingError: at call time, if the quantifier is neither 'all' nor 'any'.
        """
        def unwrap_and_quantify(*args, **kwargs):
            items = list(kwargs['knowledge'].objects_of_type(quantifier_specification.cls).values())
            if quantifier_specification.dot_exp is not None:
                items = [MDPObjectAttributeGrounding(g, quantifier_specification.dot_exp) for g in items]

            outcomes = (operation(grounding(*args, **kwargs), item(*args, **kwargs)) for item in items)
            if quantifier_specification.quantifier == 'all':
                return all(outcomes)
            if quantifier_specification.quantifier == 'any':
                return any(outcomes)
            raise RLangGroundingError(f"Unknown quantifier: {quantifier_specification.quantifier}")

        return cls(function=unwrap_and_quantify, domain=Domain.STATE_KNOWLEDGE)

    @classmethod
    def TRUE(cls):
        """Return a Proposition that is always True."""
        return cls(function=lambda *args, **kwargs: True, domain=Domain.ANY)

    @classmethod
    def FALSE(cls):
        """Return a Proposition that is always False."""
        return cls(function=lambda *args, **kwargs: False, domain=Domain.ANY)

    def __and__(self, other) -> Proposition:
        # Any GroundingFunction (not only Propositions) knows its domain, so track it.
        # Previously non-Proposition groundings fell into the generic-callable branch
        # below and the combined domain was lost.
        if isinstance(other, GroundingFunction):
            return Proposition(function=lambda *args, **kwargs: self(*args, **kwargs) & other(*args, **kwargs),
                               domain=self.domain + other.domain)
        if callable(other):
            # TODO: We must know the domain of Callable to properly track the domain
            return Proposition(function=lambda *args, **kwargs: self(*args, **kwargs) & other(*args, **kwargs))
        if isinstance(other, bool):
            return self if other else Proposition(function=lambda *args, **kwargs: False, domain=Domain.ANY)
        raise RLangGroundingError(message=f"Cannot & a Proposition with a {type(other)}")

    def __rand__(self, other):
        return self.__and__(other)

    def __or__(self, other) -> Proposition:
        # See __and__ for why every GroundingFunction is handled here.
        if isinstance(other, GroundingFunction):
            return Proposition(function=lambda *args, **kwargs: self(*args, **kwargs) | other(*args, **kwargs),
                               domain=self.domain + other.domain)
        if callable(other):
            # TODO: We must know the domain of Callable to properly track the domain
            return Proposition(function=lambda *args, **kwargs: self(*args, **kwargs) | other(*args, **kwargs))
        if isinstance(other, bool):
            return self if not other else Proposition(function=lambda *args, **kwargs: True, domain=Domain.ANY)
        raise RLangGroundingError(message=f"Cannot | a Proposition with a {type(other)}")

    def __ror__(self, other):
        return self.__or__(other)

    def __invert__(self) -> Proposition:
        return Proposition(function=lambda *args, **kwargs: bool(not self(*args, **kwargs)), domain=self.domain)

    def __hash__(self):
        return hash((str(self), self.function))

    def __repr__(self):
        return f"<Proposition [{self.domain.name}]->[{self.codomain.name}] \"{self.name}\">"
class Goal(Proposition):
    """A Proposition that is displayed as a Goal."""

    def __repr__(self):
        source, target = self.domain.name, self.codomain.name
        return f"<Goal [{source}]->[{target}] \"{self.name}\">"
class ValueFunction(GroundingFunction):
    """Represents a value function."""

    def __init__(self, function: Callable):
        """
        Args:
            function: the function to wrap; the grounding is declared with
                domain Domain.STATE and codomain Domain.STATE_VALUE.
        """
        state_domain = Domain.STATE
        value_codomain = Domain.STATE_VALUE
        super().__init__(function=function, domain=state_domain, codomain=value_codomain)
class ProbabilisticFunction(GroundingFunction):
    """Represents a function that provides stochastic output."""

    def __init__(self, probability: float = 1.0, *args, **kwargs):
        """
        Args:
            probability: the probability attached to this function.
            *args, **kwargs: forwarded unchanged to GroundingFunction.
        """
        self._probability = probability
        super().__init__(*args, **kwargs)

    @property
    def probability(self):
        return self._probability

    @probability.setter
    def probability(self, probability: float):
        self._probability = probability

    def compose_probability(self, probability: float):
        """Scale the stored probability by ``probability``."""
        scaled = self._probability * probability
        self._probability = scaled
class ProbabilityDistribution(MutableMapping):
    """Mapping from groundings to probabilities.

    Two dictionaries are kept: ``distribution`` (as specified) and
    ``true_distribution`` (filled in by subclasses when the distribution is
    called). Once ``calculated`` is True the mapping interface reads from
    ``true_distribution``; writes always go to ``distribution``.
    """

    def __init__(self, distribution=None):
        """
        Args:
            distribution (optional): a dictionary of the form {grounding: probability}.
        """
        self.domain = Domain.ANY
        self.distribution = dict() if distribution is None else distribution
        self.rng = default_rng()
        self.update_metadata()
        self.arg_store = list()
        self.kwarg_store = dict()
        self.true_distribution = dict()
        self.calculated = False

    def _active(self):
        # The dictionary currently exposed through the mapping interface.
        return self.true_distribution if self.calculated else self.distribution

    def calculate_true_distribution(self):
        """Hook for subclasses; the base class does nothing."""
        pass

    @classmethod
    def from_single(cls, k, *args, **kwargs):
        """Build a distribution that gives ``k`` weight 1.0."""
        return cls({k: 1.0})

    @classmethod
    def from_list_eq(cls, ks, *args, **kwargs):
        """Build a distribution that gives every element of ``ks`` weight 1.0."""
        return cls({k: 1.0 for k in ks})

    def update_metadata(self):
        self.calculate_domain()

    def calculate_domain(self):
        """Recompute ``domain`` as the combination of the domains of all keys."""
        self.domain = Domain.ANY
        for key, _ in self.distribution.items():
            self.domain += key.domain

    def sample(self):
        """Draw a key from ``true_distribution``.

        Returns None when the uniform draw is not below the total weight
        (including when ``true_distribution`` is empty).
        """
        draw = self.rng.uniform()
        cumulative = 0.0
        for key, weight in self.true_distribution.items():
            cumulative += weight
            if draw < cumulative:
                return key

    def join(self, new_distribution):
        """Add the weights of ``new_distribution`` into this distribution."""
        for key, weight in new_distribution.items():
            if key not in self.distribution:
                self.distribution[key] = weight
            else:
                self.distribution[key] += weight
            self.domain += key.domain

    def compose_probabilities(self, probability):
        """Scale every weight in ``distribution`` by ``probability``."""
        for key, weight in self.distribution.items():
            self.distribution[key] = weight * probability

    def __call__(self, *args, **kwargs):
        # Remember the arguments so that calculate_true_distribution can use them.
        self.arg_store = args
        self.kwarg_store = kwargs
        self.calculate_true_distribution()
        return self

    def __getitem__(self, key: Grounding):
        return self._active()[key]

    def __setitem__(self, key: Grounding, value: float):
        self.distribution[key] = value
        self.update_metadata()

    def __delitem__(self, key: Grounding):
        del self.distribution[key]
        self.update_metadata()

    def __iter__(self):
        return iter(self._active())

    def __repr__(self):
        return str(self._active())

    def __len__(self):
        return len(self._active())

    def __hash__(self):
        return hash(frozenset(self))
class ActionDistribution(ProbabilityDistribution):
    """Represents a distribution of possible next actions, options, or policies

    Args:
        distribution: a dictionary of the form {Action/Option/Policy: probability,}
    """

    def calculate_true_distribution(self):
        """Flatten ``distribution`` into ``true_distribution``, keyed by Action.

        Keys that are not Actions are called with the stored arguments; nested
        distributions are expanded recursively with their weights multiplied.
        """
        # TODO: Might be able to change all isinstance(k__, Action) calls to isinstance(k__, PrimitiveGrounding)
        resolved = dict()

        def accumulate(entry, weight):
            if isinstance(entry, (dict, ProbabilityDistribution)):
                for sub_entry, sub_weight in entry.items():
                    if not isinstance(sub_entry, Action):
                        sub_entry = sub_entry(*self.arg_store, **self.kwarg_store)
                    accumulate(sub_entry, weight * sub_weight)
                return
            if entry is None:
                return
            action = entry if isinstance(entry, Action) else Action(entry)
            if action in resolved:
                resolved[action] += weight
            else:
                resolved[action] = weight

        accumulate(self.distribution, 1.0)
        self.true_distribution = resolved
        self.calculated = True
class StateDistribution(ProbabilityDistribution):
    """Represents a distribution over states."""

    def __init__(self, distribution=None):
        """
        Args:
            distribution (optional): a dictionary of the form {state or function of state: probability}.
        """
        # TODO: ensure that everything is a State or function of state or something
        super().__init__(distribution=distribution)

    def calculate_true_distribution(self):
        """Flatten ``distribution`` into ``true_distribution``, keyed by state.

        Keys that are not states are called with the stored arguments; nested
        distributions are expanded recursively with their weights multiplied.
        Results that are not already states are wrapped in a VectorState.
        """
        state_types = (VectorState, ObjectOrientedState)
        resolved = dict()

        def accumulate(entry, weight):
            if isinstance(entry, (dict, ProbabilityDistribution)):
                for sub_entry, sub_weight in entry.items():
                    if not isinstance(sub_entry, state_types):
                        sub_entry = sub_entry(*self.arg_store, **self.kwarg_store)
                    accumulate(sub_entry, weight * sub_weight)
                return
            if entry is None:
                return
            state = entry if isinstance(entry, state_types) else VectorState(entry)
            if state in resolved:
                resolved[state] += weight
            else:
                resolved[state] = weight

        accumulate(self.distribution, 1.0)
        self.true_distribution = resolved
        self.calculated = True
class RewardDistribution(ProbabilityDistribution):
    """Represents a distribution over rewards; calling it returns the expected reward."""

    def __init__(self, distribution=None):
        """
        Args:
            distribution (optional): a dictionary of the form {reward or function: probability}.
        """
        # TODO: ensure that everything is a Reward or something
        super().__init__(distribution=distribution)

    def calculate_true_distribution(self):
        """Flatten ``distribution`` into ``true_distribution``, keyed by Primitive.

        Keys that are not Primitives are called with the stored arguments;
        nested distributions are expanded recursively with their weights
        multiplied. Results that are not Primitives are wrapped in one.
        """
        resolved = dict()

        def accumulate(entry, weight):
            if isinstance(entry, (dict, ProbabilityDistribution)):
                for sub_entry, sub_weight in entry.items():
                    if not isinstance(sub_entry, Primitive):
                        sub_entry = sub_entry(*self.arg_store, **self.kwarg_store)
                    accumulate(sub_entry, weight * sub_weight)
                return
            if entry is None:
                return
            reward = entry if isinstance(entry, Primitive) else Primitive(entry)
            if reward in resolved:
                resolved[reward] += weight
            else:
                resolved[reward] = weight

        accumulate(self.distribution, 1.0)
        self.true_distribution = resolved
        self.calculated = True

    def expected(self):
        """Return the probability-weighted sum of the rewards in ``true_distribution``."""
        total = 0.0
        for reward, weight in self.true_distribution.items():
            total += reward * weight
        return total

    @classmethod
    def from_list_eq(cls, ks, *args, **kwargs):
        """Build a distribution that counts each non-int element of ``ks`` with weight 1.0.

        NOTE(review): int elements are skipped and therefore never reach the
        distribution; the line that re-inserted their sum is disabled upstream:
        ``sd_dict[RewardFunction(lambda *args, **kwargs: numeric_reward, domain=Domain.ANY)] = 1.0``.
        Confirm whether numeric rewards should be restored.
        """
        counts = dict()
        for k in ks:
            if isinstance(k, int):
                continue
            if k in counts:
                counts[k] += 1.0
            else:
                counts[k] = 1.0
        return cls(counts)

    def __call__(self, *args, **kwargs):
        # Resolve the distribution for these arguments, then collapse it to its expectation.
        super().__call__(*args, **kwargs)
        return self.expected()
class GroundingDistribution(ProbabilityDistribution):
    """Probability distribution over the values of a particular grounding.

    Flattened outcomes are ``Primitive`` or ``MDPObject`` instances.
    """

    def __init__(self, grounding: Grounding, distribution=None, complete=False):
        """
        Args:
            grounding: the grounding this distribution refers to; stored as
                ``self.grounding``.
            distribution: optional mapping forwarded unchanged to the parent
                constructor.
            complete: flag stored as ``self.complete``.
        """
        if distribution:
            # TODO: ensure that everything is a groundingfunction or something
            pass
        self.grounding = grounding
        self.complete = complete
        super().__init__(distribution=distribution)

    def calculate_true_distribution(self):
        """Flatten ``self.distribution`` into a plain ``{outcome: weight}`` dict.

        Keys that are not already a ``Primitive`` / ``MDPObject`` are called
        with the instance's ``arg_store`` / ``kwarg_store``. The returned value
        may be another mapping (flattened recursively with the weights
        multiplied), ``None`` (ignored), a ``Primitive`` / ``MDPObject``, or
        any other value (wrapped in ``Primitive``). Weights of identical
        outcomes are summed.

        The result is stored in ``self.true_distribution`` and
        ``self.calculated`` is set to ``True``.
        """
        outcome_types = (Primitive, MDPObject)
        flattened = dict()

        def accumulate(node, weight):
            # Mapping: resolve each key, then recurse with the combined weight.
            if isinstance(node, (dict, ProbabilityDistribution)):
                for key, probability in node.items():
                    if isinstance(key, outcome_types):
                        resolved = key
                    else:
                        resolved = key(*self.arg_store, **self.kwarg_store)
                    accumulate(resolved, weight * probability)
                return

            # A ``None`` outcome contributes nothing to the distribution.
            if node is None:
                return

            # Leaf: make sure it is a known outcome type, then add up its weight.
            outcome = node if isinstance(node, outcome_types) else Primitive(node)
            if outcome in flattened:
                flattened[outcome] += weight
            else:
                flattened[outcome] = weight

        accumulate(self.distribution, 1.0)
        self.true_distribution = flattened
        self.calculated = True

    @classmethod
    def from_list_eq(cls, ks, *args, **kwargs):
        """Build a distribution in which every item of ``ks`` has weight 1.0.

        Args:
            ks: iterable of keys for the distribution.
            *args: the first extra positional argument is used as the
                grounding; an ``IndexError`` is raised if it is missing.
        """
        uniform_weights = dict.fromkeys(ks, 1.0)
        return cls(args[0], uniform_weights)
class Policy(ProbabilisticFunction):
    """Represents a closed-loop policy function"""

    def __init__(self, function: Callable, domain: Domain = Domain.STATE, *args, **kwargs):
        """
        Args:
            function: a function from states to action distributions.
            domain: domain of the policy; the codomain is always
                ``Domain.ACTION``.
        """
        super().__init__(function=function, domain=domain, codomain=Domain.ACTION, *args, **kwargs)

    @classmethod
    def from_action_distribution(cls, k):
        """Wrap an ``ActionDistribution`` in a Policy.

        The policy's function is the distribution's ``__call__`` and its domain
        is the distribution's domain.

        Raises:
            RLangGroundingError: if ``k`` is not an ``ActionDistribution``.
        """
        if isinstance(k, ActionDistribution):
            return cls(function=k.__call__, domain=k.domain)
        raise RLangGroundingError(f"Expecting an ActionDistribution, got {type(k)}")

    def __repr__(self):
        # The quoted name is only shown when the policy has one.
        name_part = f" \"{self.name}\"" if self.name else ""
        return f"<Policy [{self.domain.name}]->[{self.codomain.name}]{name_part}>"
class Plan(Grounding):
    """Represents an open-loop policy"""

    def __init__(self, function: Callable = None, name: str = None):
        """
        Args:
            function: callable invoked when the plan is called; may be None,
                in which case calling the plan raises.
            name: the name of the grounding.
        """
        self.function = function
        super().__init__(name=name)

    def reset(self):
        """Reset hook; a plain Plan keeps no progress state, so this does nothing."""
        return None

    def __call__(self, *args, **kwargs):
        """Forward the call to ``self.function``.

        Raises:
            RLangGroundingError: if no function was provided.
        """
        if self.function is not None:
            return self.function(*args, **kwargs)
        raise RLangGroundingError("Plan function is not defined")

    def __repr__(self):
        label = f"\"{self.name}\"" if self.name else "unnamed"
        return f"<Plan {label}>"
class IteratedPlan(Plan):
    """One kind of plan implementation"""

    def __init__(self, plan_steps, name: str = None):
        """
        Args:
            plan_steps: sequence of callables; entries may be ``PlanExecution``
                objects (sub-plans) or any other callable step.
            name: the name of the grounding.
        """
        self.plan_steps = plan_steps
        super().__init__(function=self.__call__, name=name)
        # Index of the step that the next call will execute.
        self.i = 0

    def reset(self):
        """Rewind to the first step and reset nested plans.

        ``PlanExecution`` steps have their wrapped plan reset; ``IteratedPlan``
        steps are reset directly.
        """
        self.i = 0
        for step in self.plan_steps:
            if isinstance(step, PlanExecution):
                step.plan.reset()
            elif isinstance(step, IteratedPlan):
                step.reset()

    def __call__(self, *args, **kwargs):
        """Execute the current step and return its result.

        Returns ``None`` once all steps have been consumed. A ``PlanExecution``
        step stays current until it returns ``None``; at that point its plan is
        reset, the index advances, and the call is retried on the next step.
        Any other step is called once and the index advances immediately.
        """
        if self.i >= len(self.plan_steps):
            return None

        step = self.plan_steps[self.i]

        if not isinstance(step, PlanExecution):
            # Simple step: consume it and hand back whatever it produces.
            self.i += 1
            return step(*args, **kwargs)

        outcome = step(*args, **kwargs)
        if outcome is not None:
            return outcome

        # The sub-plan is exhausted: rewind it and move on to the next step.
        step.plan.reset()
        self.i += 1
        return self(*args, **kwargs)

    def __repr__(self):
        label = f"\"{self.name}\"" if self.name else "unnamed"
        return f"<IteratedPlan {label}>"
class PlanExecution(GroundingFunction):
    """A plan together with the argument groundings it is invoked with.

    Calling a PlanExecution evaluates every argument grounding with the
    incoming ``*args`` / ``**kwargs`` and passes the results positionally
    (plus the original ``**kwargs``) to the wrapped plan.
    """

    def __init__(self, plan, arguments: List[GroundingFunction] = None):
        """
        Args:
            plan: the plan to execute; must be callable and expose ``name``.
            arguments: grounding functions whose values are passed to the plan.
                Defaults to no arguments.
        """
        self.plan = plan
        if arguments is None:
            arguments = []
        self.arguments = arguments

        # The execution's domain is the combination of its arguments' domains.
        domain = Domain.ANY
        for arg in arguments:
            domain = domain + arg.domain

        argnames = ", ".join(arg.name if arg.name is not None else "unk" for arg in arguments)
        # Plans may be unnamed (name=None); fall back to the same placeholder
        # used for unnamed arguments instead of failing on ``None + str``.
        plan_name = plan.name if plan.name is not None else "unk"

        super().__init__(
            domain=domain,
            codomain=Domain.ACTION,
            function=lambda *args, **kwargs: self.plan(
                *[arg(*args, **kwargs) for arg in self.arguments], **kwargs),
            name=plan_name + "(" + argnames + ")",
        )

    def __repr__(self):
        return f"<PlanExecution of {self.plan} with {self.arguments}>"
# class Plan(ProbabilisticFunction): # """THIS DOES NOT WORK YET # # Represents an open-loop policy # # Args: # distribution_list: a list of ActionDistributions # # # """ # # def __init__(self, distribution_list: [ActionDistribution]): # domain = Domain.ANY # length = None # for d in distribution_list: # domain += d.domain # if length: # if len(d) != length: # length = 0 # break # else: # length = len(d) # # self.i = 0 # self.plan = distribution_list # self.length = length # super().__init__(function=lambda *args, **kwargs: self, domain=domain) # # def append(self, distribution): # if not isinstance(distribution, ActionDistribution): # raise RLangGroundingError(f"Expecting {str(ActionDistribution)}, got {type(distribution)}") # self.plan.append(distribution) # self.domain += distribution.domain # if self.length != 0 and len(distribution) != 0: # self.length += len(distribution) # else: # self.length = 0 # # def extend(self, distribution_list): # domain = Domain.ANY # for d in distribution_list: # if not isinstance(d, ActionDistribution): # raise RLangGroundingError(f"Expecting {str(ActionDistribution)}, got {type(d)}") # domain += d.domain # if self.length != 0 and len(d) != 0: # self.length += len(d) # else: # self.length = 0 # self.plan.extend(distribution_list) # self.domain += domain # # def reset(self): # self.i = 0 # # def __iter__(self): # self.i = 0 # return self # # def __next__(self): # if self.i >= len(self.plan): # raise StopIteration # else: # i = self.i # self.i += 1 # return self.plan[i]
class OptionTermination:
    """Marker value indicating that an option has terminated.

    All instances compare equal to each other and share the same hash, so they
    can be used interchangeably as set members or dictionary keys.
    """

    def __repr__(self):
        return "<OptionTermination>"

    def __eq__(self, other):
        return isinstance(other, OptionTermination)

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None and make instances
        # unhashable. Every instance is equal to every other, so they must all
        # hash alike; hashing the base class also keeps subclasses consistent.
        return hash(OptionTermination)
class Option(Grounding):
    """Grounding object for an option."""

    def __init__(self, initiation: Proposition, policy: Policy, termination: Proposition, name: str = None):
        """
        Args:
            initiation: A Proposition capturing the initiation set of the option.
            policy: A Policy capturing the policy of the option.
            termination: A Proposition capturing the termination set of the option.
            name (optional): the name of the grounding.
        """
        super().__init__(name)
        self.policy = policy
        self.initiation = initiation
        self.termination = termination

    def __call__(self, *args, **kwargs):
        """Return ``OptionTermination()`` if the termination condition holds
        for the given arguments; otherwise return the policy's output."""
        if not self.termination(*args, **kwargs):
            return self.policy(*args, **kwargs)
        return OptionTermination()

    def can_initiate(self, *args, **kwargs) -> bool:
        """Determines whether the option can be executed.

        Args:
            *args: forwarded unchanged to the initiation proposition.
            **kwargs: forwarded unchanged to the initiation proposition.

        Returns:
            The result of evaluating the initiation proposition.
        """
        may_start = self.initiation(*args, **kwargs)
        return may_start

    def __hash__(self):
        # Hash on the three components rather than on the name.
        components = (self.initiation, self.policy, self.termination)
        return hash(components)

    def __repr__(self):
        return f"<Option \"{self.name}\">"
class TransitionFunction(ProbabilisticFunction):
    """Represents a transition function."""

    def __init__(self, function: Callable = None, domain: Domain = Domain.STATE_ACTION, *args, **kwargs):
        """
        Args:
            function: the transition function; when None, the ``__call__`` of a
                fresh, empty ``StateDistribution`` is used.
            domain: domain of the function; the codomain is always
                ``Domain.STATE``.
        """
        function = StateDistribution().__call__ if function is None else function
        super().__init__(function=function, domain=domain, codomain=Domain.STATE, *args, **kwargs)

    @classmethod
    def from_state_distribution(cls, k):
        """Wrap a ``StateDistribution`` in a TransitionFunction.

        The function is the distribution's ``__call__`` and the domain is the
        distribution's domain.

        Raises:
            RLangGroundingError: if ``k`` is not a ``StateDistribution``.
        """
        if isinstance(k, StateDistribution):
            return cls(function=k.__call__, domain=k.domain)
        raise RLangGroundingError(f"Expecting a StateDistribution, got {type(k)}")

    def __repr__(self):
        # The quoted name is only shown when the function has one.
        name_part = f" \"{self.name}\"" if self.name else ""
        return f"<TransitionFunction [{self.domain.name}]->[{self.codomain.name}]{name_part}>"
class RewardFunction(ProbabilisticFunction):
    """Represents function of expected reward."""

    def __init__(self, function: Callable = None, domain: Domain = Domain.ANY, *args, **kwargs):
        """
        Args:
            function: the reward function; when None, the ``__call__`` of a
                fresh, empty ``RewardDistribution`` is used.
            domain: domain of the function; the codomain is always
                ``Domain.REWARD``.
        """
        function = RewardDistribution().__call__ if function is None else function
        super().__init__(function=function, domain=domain, codomain=Domain.REWARD, *args, **kwargs)

    @classmethod
    def from_reward_distribution(cls, k):
        """Wrap a ``RewardDistribution`` in a RewardFunction.

        The function is the distribution's ``__call__`` and the domain is the
        distribution's domain.

        Raises:
            RLangGroundingError: if ``k`` is not a ``RewardDistribution``.
        """
        if isinstance(k, RewardDistribution):
            return cls(function=k.__call__, domain=k.domain)
        raise RLangGroundingError(f"Expecting a RewardDistribution, got {type(k)}")

    def __repr__(self):
        # The quoted name is only shown when the function has one.
        name_part = f" \"{self.name}\"" if self.name else ""
        return f"<RewardFunction [{self.domain.name}]->[{self.codomain.name}]{name_part}>"
class Prediction(ProbabilisticFunction):
    """GroundingFunction for an RLang Prediction object.

    Used to express the predicted value of another RLang object. Limited to
    GroundingFunctions with a domain of (S) or (S, A).
    """

    def __init__(self, grounding: Grounding, function: Callable = None, domain: Domain = Domain.STATE_ACTION,
                 complete=False, *args, **kwargs):
        """
        Args:
            grounding (Grounding): the grounding whose value we are predicting
            function (:obj:`Callable`, optional): a function that predicts the
                value of grounding; can use a GroundingFunction. When None, the
                ``__call__`` of a fresh ``GroundingDistribution`` for
                ``grounding`` is used.
            domain: domain of the prediction; the codomain is always
                ``Domain.REAL_VALUE``.
            complete: flag stored as ``self.complete``.
        """
        self.grounding = grounding
        self.complete = complete
        predictor = GroundingDistribution(grounding).__call__ if function is None else function
        super().__init__(function=predictor, domain=domain, codomain=Domain.REAL_VALUE, *args, **kwargs)

    @classmethod
    def from_grounding_distribution(cls, grounding: Grounding, function: GroundingDistribution, complete=False):
        """Wrap a ``GroundingDistribution`` in a Prediction.

        Args:
            grounding: The grounding that is predicted
            function: The prediction function; its ``__call__`` and domain are
                used for the new Prediction.
            complete: forwarded to the constructor.

        Raises:
            RLangGroundingError: if ``function`` is not a ``GroundingDistribution``.
        """
        if isinstance(function, GroundingDistribution):
            return cls(grounding=grounding, function=function.__call__, domain=function.domain,
                       complete=complete)
        raise RLangGroundingError(f"Expecting a GroundingDistribution, got {type(function)}")

    def __repr__(self):
        # The quoted name is only shown when the prediction has one.
        name_part = f" \"{self.name}\"" if self.name else ""
        return f"<Prediction [{self.domain.name}]->[{self.codomain.name}]{name_part} for Grounding: {self.grounding.name}>"
class Effect(Grounding):
    """GroundingFunction for an RLang Effect object.

    Contains an optional RewardFunction, TransitionFunction, and list of
    Predictions.
    """

    def __init__(self, reward_function: RewardFunction = None, transition_function: TransitionFunction = None,
                 predictions: List[Prediction] = None, name: str = None, probability: float = 1.0):
        """
        Args:
            reward_function: a RewardFunction
            transition_function: a TransitionFunction
            predictions: a list of Predictions; defaults to a new empty list
            name: name of the Effect
            probability (Optional[float]): probability of this effect occurring; default: 1
        """
        self.reward_function = reward_function
        self.transition_function = transition_function
        self.predictions = list() if predictions is None else predictions
        self.probability = probability
        super().__init__(name=name)

    def shallow_copy(self):
        """Return a new Effect sharing this one's reward function, transition
        function and predictions list.

        Only those three attributes are passed on: the copy is unnamed and its
        probability is the constructor default of 1.0.
        """
        return Effect(
            reward_function=self.reward_function,
            transition_function=self.transition_function,
            predictions=self.predictions,
        )

    def compose_probabilities(self, probability: float):
        """Scale this effect in place by ``probability``.

        ``self.probability`` is multiplied by ``probability``. The reward
        function, the transition function (each only if truthy) and every
        prediction are replaced by a new function built from a single-entry
        distribution that maps the old function to ``probability``.
        """
        self.probability = self.probability * probability

        if self.reward_function:
            scaled_reward = RewardDistribution({self.reward_function: probability})
            self.reward_function = RewardFunction.from_reward_distribution(scaled_reward)

        if self.transition_function:
            scaled_transition = StateDistribution({self.transition_function: probability})
            self.transition_function = TransitionFunction.from_state_distribution(scaled_transition)

        self.predictions = [
            Prediction.from_grounding_distribution(
                prediction.grounding,
                GroundingDistribution(prediction.grounding, {prediction: probability}),
                complete=prediction.complete)
            for prediction in self.predictions
        ]

    @property
    def prediction_dict(self):
        """Predictions grouped by the name of the grounding they predict.

        Returns:
            A plain dict mapping each grounding name to the list of
            predictions for it, in their original order.
        """
        grouped = dict()
        for prediction in self.predictions:
            grouped.setdefault(prediction.grounding.name, []).append(prediction)
        return grouped

    def __repr__(self):
        # The quoted name is only shown when the effect has one.
        name_part = f" \"{self.name}\"" if self.name else ""
        return f"<Effect{name_part} with P({self.probability})>"