Source code for rlang.agents.RLangPolicyAgentClass

import torch
from torch import nn


class RLangPolicyAgent(nn.Module):
    """Implementation of an agent that acts according to an RLang policy."""
    def __init__(self, rlang_policy, epsilon=1e-8, n_actions=2, obs_normalizer=None):
        """
        Args:
            rlang_policy (Policy): an RLang policy.
            epsilon (float): Exploration term; a small probability floor so that no
                action ends up with exactly zero probability.
            n_actions (int): Number of actions.
            obs_normalizer: Optional observation normalizer providing an ``inverse``
                method; see ``set_obs_normalizer``.
        """
        # Initialize nn.Module first so attribute assignment is safe even if
        # obs_normalizer is itself a Module.
        super().__init__()
        self.rlang_policy = rlang_policy
        self.eps = epsilon
        self.n_actions = n_actions
        self.obs_normalizer = obs_normalizer
    def _dict_to_tensor(self, state):
        """Convert the policy's action->probability dict for ``state`` into log probabilities."""
        actions_prob = torch.zeros(self.n_actions, device=state.device)
        actions_mask = torch.zeros(self.n_actions, device=state.device)

        # Fill in the probabilities the RLang policy explicitly specifies.
        policy_probs = self.rlang_policy(state=state)
        for action, prob in policy_probs.items():
            actions_prob[action] = prob
            actions_mask[action] = 1

        n_specified = len(policy_probs)
        if n_specified < self.n_actions:
            # Redistribute the remaining probability uniformly over the
            # actions the policy did not mention.
            n_unspecified = self.n_actions - n_specified
            remaining_prob = (1. - actions_prob.sum()) / n_unspecified
            if remaining_prob == 0:
                # The specified actions already sum to 1: take epsilon mass from
                # them so unspecified actions keep a non-zero probability.
                actions_prob -= self.eps * n_unspecified * actions_mask / n_specified
                remaining_prob = self.eps
            actions_prob += (1 - actions_mask) * remaining_prob
        elif (actions_prob == 0).any():
            # Avoid log(0) for actions the policy explicitly assigns zero probability.
            actions_prob += self.eps
        # Normalize and return log probabilities with a leading batch dimension.
        return torch.log(actions_prob / actions_prob.sum()).unsqueeze(0)
    def forward(self, state):
        """Return log action probabilities for a batch of states."""
        # Map normalized observations back to the original observation space
        # before querying the RLang policy.
        state = self.obs_normalizer.inverse(state) if self.obs_normalizer is not None else state
        batch = state.size()[0]
        if batch == 1:
            return self._dict_to_tensor(state[0])
        acts = [self._dict_to_tensor(state[i]) for i in range(batch)]
        return torch.cat(acts, dim=0)
    def set_obs_normalizer(self, obs_normalizer):
        self.obs_normalizer = obs_normalizer
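

# A minimal usage sketch (not part of the original module): ``toy_policy`` is a
# hypothetical stand-in for an RLang policy. Any callable that accepts a
# ``state=`` keyword and returns a dict mapping action indices to probabilities
# works here, since that is how ``rlang_policy`` is queried above.
if __name__ == "__main__":
    def toy_policy(state=None):
        # Deterministic toy distribution over two actions.
        return {0: 0.7, 1: 0.3}

    agent = RLangPolicyAgent(toy_policy, n_actions=2)
    obs = torch.zeros(1, 4)       # batch of one 4-dimensional observation
    log_probs = agent(obs)        # shape (1, 2): log action probabilities
    print(log_probs.exp())        # approximately tensor([[0.7, 0.3]])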