def __init__(self, observation_space, action_space, net, rnn=False, normalize_ac=True):
    """
    Set up the policy and move it to the active device.

    All arguments are forwarded positionally, in the order given, to
    ``BasePol.__init__``.

    Parameters
    ----------
    observation_space : gym.Space
    action_space : gym.Space
    net : torch.nn.Module
    rnn : bool
    normalize_ac : bool
    """
    BasePol.__init__(
        self,
        observation_space,
        action_space,
        net,
        rnn,
        normalize_ac,
    )
    # Distribution object attached to this policy as ``self.pd``.
    self.pd = MultiCategoricalPd()
    # Move the policy to whatever device ``get_device()`` reports.
    target_device = get_device()
    self.to(target_device)
def __init__(self, observation_space, action_space, net, rnn=False, normalize_ac=True,
             data_parallel=False, parallel_dim=0):
    """
    Set up the policy and move it to the active device.

    All arguments are forwarded positionally, in the order given, to
    ``BasePol.__init__``.

    Parameters
    ----------
    observation_space : gym.Space
    action_space : gym.Space
    net : torch.nn.Module
    rnn : bool
    normalize_ac : bool
    data_parallel : bool
        Passed through to ``BasePol.__init__``.
        NOTE(review): presumably enables data-parallel execution of the
        network - confirm against ``BasePol``.
    parallel_dim : int
        Passed through to ``BasePol.__init__``.
        NOTE(review): presumably the dimension that is split across
        devices - confirm against ``BasePol``.
    """
    BasePol.__init__(
        self,
        observation_space,
        action_space,
        net,
        rnn,
        normalize_ac,
        data_parallel,
        parallel_dim,
    )
    # Distribution object attached to this policy as ``self.pd``.
    self.pd = MultiCategoricalPd()
    # Move the policy to whatever device ``get_device()`` reports.
    target_device = get_device()
    self.to(target_device)
class MultiCategoricalPol(BasePol):
    """
    Policy with MultiCategorical distribution.

    Parameters
    ----------
    observation_space : gym.Space
        observation's space
    action_space : gym.Space
        action's space.
        This should be gym.spaces.MultiDiscrete
    net : torch.nn.Module
    rnn : bool
    normalize_ac : bool
        If True, the output of network is spread over action_space.
        In this situation the output of network is expected to be in -1~1.
    """

    def __init__(self, observation_space, action_space, net, rnn=False, normalize_ac=True):
        BasePol.__init__(self, observation_space, action_space, net, rnn, normalize_ac)
        self.pd = MultiCategoricalPd()
        self.to(get_device())

    def _compute_pis(self, obs, hs=None, h_masks=None):
        """
        Run the network on ``obs`` and return ``(pis, hs)``.

        Shared by ``forward`` and ``deterministic_ac_real`` so that the
        observation check and the recurrent-state handling live in one place.

        Parameters
        ----------
        obs : torch.Tensor
            Observation; passed through ``self._check_obs_shape`` first.
            When ``self.rnn`` is True its first two dimensions are read as
            ``(time_seq, batch_size)``.
        hs : hidden state or None
            Only used when ``self.rnn`` is True. If None, ``self.hs`` is used,
            and ``self.hs`` is created with ``self.net.init_hs(batch_size)``
            when it is also None.
        h_masks : torch.Tensor or None
            Only used when ``self.rnn`` is True. If None, an all-zero mask is
            built. It is reshaped to ``(time_seq, batch_size, 1)``.

        Returns
        -------
        pis : output of ``self.net``
        hs : hidden state returned by the network when ``self.rnn`` is True,
            otherwise the ``hs`` argument unchanged.
        """
        obs = self._check_obs_shape(obs)

        if not self.rnn:
            # Feed-forward network: no hidden state is consumed or produced.
            return self.net(obs), hs

        time_seq, batch_size, *_ = obs.shape
        if hs is None:
            # Fall back to the state cached on the policy, creating it lazily.
            if self.hs is None:
                self.hs = self.net.init_hs(batch_size)
            hs = self.hs
        if h_masks is None:
            # Zero mask with the same dtype and device as the hidden state.
            h_masks = hs[0].new_zeros((time_seq, batch_size, 1))
        h_masks = h_masks.reshape(time_seq, batch_size, 1)

        pis, hs = self.net(obs, hs, h_masks)
        # Cache the new state so the next call can continue from it.
        self.hs = hs
        return pis, hs

    def forward(self, obs, hs=None, h_masks=None):
        """
        Sample an action from the distribution given by the network output.

        Parameters
        ----------
        obs : torch.Tensor
        hs : hidden state or None
        h_masks : torch.Tensor or None

        Returns
        -------
        ac_real : result of ``self.convert_ac_for_real`` applied to the
            sampled action as a numpy array
        ac : action sampled by ``self.pd.sample``
        dict
            ``pis`` (network output) and ``hs`` (hidden state).
        """
        pis, hs = self._compute_pis(obs, hs, h_masks)
        ac = self.pd.sample(dict(pis=pis))
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(pis=pis, hs=hs)

    def deterministic_ac_real(self, obs, hs=None, h_masks=None):
        """
        action for deployment

        Instead of sampling, takes the index of the largest entry of the
        network output along the last dimension.

        Parameters
        ----------
        obs : torch.Tensor
        hs : hidden state or None
        h_masks : torch.Tensor or None

        Returns
        -------
        ac_real : result of ``self.convert_ac_for_real`` applied to the
            chosen action as a numpy array
        ac : indices of the maximum of ``pis`` over the last dimension
        dict
            ``pis`` (network output) and ``hs`` (hidden state).
        """
        pis, hs = self._compute_pis(obs, hs, h_masks)
        # torch.max over a dim returns (values, indices); only indices are used.
        _, ac = torch.max(pis, dim=-1)
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(pis=pis, hs=hs)