"""Every RLang program (including any vocabulary files) grounds to an :py:class:`.RLangKnowledge` object."""
from __future__ import annotations
from typing import Dict, Any
from collections.abc import MutableMapping
import functools
from .grounding.utils.utils import Domain
from .grounding.utils.primitives import MDPObject
from .grounding import *
class RLangKnowledge(MutableMapping):
"""Provides an interface for accessing stored RLang information. Behaves similarly to a Python dictionary.
.. note::
In typical usage, an :py:class:`.RLangKnowledge` object is not instantiated by the user
but is instead returned from a call to :py:func:`.parse_file` or :py:func:`.parse`.
Examples::
base = RLangKnowledge()
base['x_location'] = Factor([1])
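
        # More typical usage (a sketch; the file name is hypothetical):
        knowledge = parse_file('example.rlang')
        knowledge['x_location']  # retrieve a grounding by name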
"""
    def __init__(self):
        self.rlang_variables = dict()
        self.policy = None
        """A :py:class:`.Policy` object"""
        self.reward_function = None
        """A :py:class:`.RewardFunction` object"""
        self.transition_function = None
        """A :py:class:`.TransitionFunction` object"""
        self.plan = None
        """A :py:class:`.Plan` object"""
        self.proto_predictions = list()
        """A list of prediction groundings consumed by :py:meth:`predictions`"""
        self.mdp_metadata = None
        """Metadata about the underlying MDP, if provided"""
    def predictions(self, *args, **kwargs) -> Dict[Grounding, Any]:
"""Get a dictionary of :py:class:`.Grounding` objects whose value for the next state
can be predicted using the keyword arguments provided.
Args:
state (Optional[State]): a given current state
action (Optional[Action]): a given action
"""
# TODO: This breaks after migrating to probabilistic functions. Fix this somehow.
        domain = Domain.ANY
        if 'state' in kwargs:
            domain += Domain.STATE
        if 'action' in kwargs:
            domain += Domain.ACTION
        if 'next_state' in kwargs:
            domain += Domain.NEXT_STATE
# else:
# next_state = self.get_next_state(*args, **kwargs)
# if next_state:
# domain += Domain.NEXT_STATE
# kwargs['next_state'] = next_state
        predictables = [p for p in self.proto_predictions if p.domain <= domain]
        return {p.grounding: p(*args, **kwargs) for p in predictables}
    def get_next_state(self, *args, **kwargs):
        """Query :py:attr:`transition_function` if one exists; otherwise return an empty dict."""
        if self.transition_function:
            return self.transition_function(*args, **kwargs)
        return {}
def __getitem__(self, key: str):
return self.rlang_variables[key]
def __setitem__(self, key: str, value: Grounding):
self.rlang_variables[key] = value
def __delitem__(self, key: str):
del self.rlang_variables[key]
def __iter__(self):
return iter(self.rlang_variables)
def __len__(self):
return len(self.rlang_variables)
def rlang_variables_of_type(self, grounding_type):
""":meta private:"""
return {k: v for (k, v) in self.rlang_variables.items() if isinstance(v, grounding_type)}
    def factors(self):
        """All stored :py:class:`.Factor` groundings, keyed by name."""
        return self.rlang_variables_of_type(Factor)

    def features(self):
        """All stored :py:class:`.Feature` groundings, keyed by name."""
        return self.rlang_variables_of_type(Feature)

    def propositions(self):
        """All stored :py:class:`.Proposition` groundings, keyed by name."""
        return self.rlang_variables_of_type(Proposition)

    def policies(self):
        """All stored :py:class:`.Policy` groundings, keyed by name."""
        return self.rlang_variables_of_type(Policy)

    def effects(self):
        """All stored :py:class:`.Effect` groundings, keyed by name."""
        return self.rlang_variables_of_type(Effect)

    def classes(self):
        """All stored :py:class:`.MDPObject` subclasses, keyed by name."""
        return {k: v for (k, v) in self.rlang_variables.items() if isinstance(v, type) and issubclass(v, MDPObject)}

    def objects(self):
        """All stored :py:class:`.MDPObjectGrounding` objects, keyed by name."""
        return self.rlang_variables_of_type(MDPObjectGrounding)
    def objects_of_type(self, cls):
        """Object groundings whose underlying object (``.obj``) is an instance of ``cls``."""
        objs = self.objects()
        return {k: v for (k, v) in objs.items() if isinstance(v.obj, cls)}
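
    # Example (a sketch; the 'Taxi' class is hypothetical and would need to be
    # declared in the parsed RLang program):
    #
    #     knowledge.classes()                           # e.g. {'Taxi': <class Taxi>}
    #     knowledge.objects_of_type(knowledge['Taxi'])  # object groundings of that class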
    @functools.lru_cache(maxsize=None)
    def memoized_reward_function(self, state, action):
        """A cached wrapper around :py:attr:`reward_function`; ``state`` and ``action`` must be hashable."""
        return self.reward_function(state=state, action=action)

    @functools.lru_cache(maxsize=None)
    def memoized_transition_function(self, state, action):
        """A cached wrapper around :py:attr:`transition_function`; ``state`` and ``action`` must be hashable."""
        return self.transition_function(state=state, action=action)
    def __hash__(self):
        # Hashing over the stored variables keeps instances usable as lru_cache
        # keys for the memoized_* methods above, which include `self` in the key.
        return hash(tuple(sorted(self.rlang_variables.items())))
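
# A minimal end-to-end sketch (assumes an RLang program at "example.rlang" and
# that `parse_file` is exported at the package root, as the class docstring suggests):
#
#     from rlang import parse_file
#
#     knowledge = parse_file("example.rlang")
#     knowledge.propositions()        # {'prop_name': <Proposition>, ...}
#     if knowledge.reward_function is not None:
#         knowledge.memoized_reward_function(state, action)  # cached; args must be hashable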