Source code for rlang.agents.RLangQLearningAgentClass

from collections import defaultdict

from simple_rl.agents import QLearningAgent

from ..grounding.utils.primitives import VectorState


class RLangQLearningAgent(QLearningAgent):
    """Q-learning agent whose Q-table is pre-populated from RLang knowledge."""

    def __init__(
        self,
        actions,
        states,
        knowledge,
        name="RLang-Q-learning",
        use_transition=False,
        alpha=0.1,
        gamma=0.99,
        epsilon=0.1,
        explore="uniform",
        anneal=False,
        default_q=0,
    ):
        """Build the initial Q-table from ``knowledge`` and set up the agent.

        Args:
            actions (list): Contains strings denoting the actions. The
                *index* of an action in this list (not the string) is what
                gets passed to the knowledge functions.
            states (list): A list of all possible states. Each entry is
                wrapped in ``VectorState`` before being handed to the
                knowledge functions and is used as-is as a Q-table key.
            knowledge: An RLangKnowledge object; only its
                ``reward_function`` and ``transition_function`` attributes
                are read here.
            name (str): Denotes the name of the agent.
            use_transition (bool): When true, and both the reward and the
                transition function are truthy, every Q-value additionally
                receives a one-step lookahead bonus.
            alpha (float): Learning rate. Also scales the lookahead bonus.
            gamma (float): Discount factor. Also discounts the lookahead
                value.
            epsilon (float): Exploration term.
            explore (str): One of {softmax, uniform}. Denotes explore policy.
            anneal (bool): Forwarded unchanged to ``QLearningAgent``.
            default_q (float): The default value to initialize every entry
                in the q-table with [by default, set to 0].
        """

        def expected_reward(reward_fn, successors):
            # Weighted total of the state-only reward over all successors.
            # NOTE(review): weights are presumably transition probabilities.
            total = 0
            for next_state, weight in successors.items():
                total += reward_fn(state=VectorState(next_state)) * weight
            return total

        def expected_best_value(table, successors):
            # Weighted total of the greedy (max over actions) Q-value of
            # each successor. Reading the table may create default entries.
            total = 0
            for next_state, weight in successors.items():
                row = table[next_state]
                best = max(row[action] for action in actions)
                total += best * weight
            return total

        # Two-level table: state -> action -> value, defaulting to default_q.
        q_table = defaultdict(lambda: defaultdict(lambda: default_q))

        # Seed every (state, action) entry with the advised reward.
        reward_fn = knowledge.reward_function
        if reward_fn:
            for state in states:
                for index, action in enumerate(actions):
                    q_table[state][action] = reward_fn(
                        state=VectorState(state), action=index
                    )

        # Optionally add a one-step lookahead through the advised transitions.
        transition_fn = knowledge.transition_function
        if use_transition and transition_fn and reward_fn:
            for state in states:
                for index, action in enumerate(actions):
                    successors = transition_fn(
                        state=VectorState(state), action=index
                    )
                    if not successors:
                        # No transition advice for this pair; keep the seed.
                        continue
                    # Q learning Update
                    lookahead_reward = expected_reward(reward_fn, successors)
                    lookahead_value = expected_best_value(q_table, successors)
                    q_table[state][action] += alpha * (
                        lookahead_reward + gamma * lookahead_value
                    )

        super().__init__(
            actions,
            name=name,
            alpha=alpha,
            gamma=gamma,
            epsilon=epsilon,
            explore=explore,
            anneal=anneal,
            custom_q_init=q_table,
            default_q=default_q,
        )