def forward(self, xs, hs, h_masks):
    """Run the recurrent policy over a time-major sequence of observations.

    Args:
        xs: flattened observations; assumed shape (time, batch, obs_dim) —
            time-major, per the unpack below. TODO confirm against caller.
        hs: pair of (h, c) recurrent states; reshaped to
            (batch, cell_size) each.
        h_masks: per-step reset masks; a mask of 1 zeroes the recurrent
            state (episode boundary) before that step.

    Returns:
        Continuous case: (means, log_std, hs).
        Discrete case: (action probabilities, hs); with ``self.multi`` the
        probabilities of each head are stacked along dim -2.
    """
    time_seq, batch_size, *_ = xs.shape
    hs = (hs[0].reshape(batch_size, self.cell_size),
          hs[1].reshape(batch_size, self.cell_size))
    dict_xs = flatten_to_dict(xs, self.ob_space)
    xs = torch.relu(self.input_layer(dict_xs['angle']))
    ang_vels = dict_xs['angular_velocity']
    hiddens = []
    for x, ang_vel, mask in zip(xs, ang_vels, h_masks):
        # Reset the recurrent state where the mask marks an episode boundary.
        hs = (hs[0] * (1 - mask), hs[1] * (1 - mask))
        hs = self.cell(torch.cat([x, ang_vel], dim=1), hs)
        hiddens.append(hs[0])
    # (time, batch, cell_size); torch.stack replaces the manual
    # cat-of-unsqueeze. (Also removed a leftover debug print of xs.shape.)
    hiddens = torch.stack(hiddens, dim=0)
    if not self.discrete:
        means = torch.tanh(self.mean_layer(hiddens))
        log_std = self.log_std_param.expand_as(means)
        return means, log_std, hs
    if self.multi:
        # One softmax head per action dimension, stacked along dim -2.
        return torch.cat(
            [torch.softmax(ol(hiddens), dim=-1).unsqueeze(-2)
             for ol in self.output_layers], dim=-2), hs
    return torch.softmax(self.output_layer(hiddens), dim=-1), hs
def test_flatten2dict():
    """Round-trip a sampled dict observation through flatten and back.

    Samples an observation from the dict env, flattens it via the flat
    wrapper, recovers it with ``flatten_to_dict``, and checks that every
    key and value survived the round trip unchanged.
    """
    dict_env = GymEnv(gym.make('PendulumDictEnv-v0'))
    dict_ob = dict_env.observation_space.sample()
    dict_observation_space = dict_env.observation_space
    dict_keys = dict_env.observation_space.spaces.keys()
    flat_env = _make_flat(dict_env, dict_keys)
    flatten_ob = flat_env.observation(dict_ob)
    recovered_dict_ob = flatten_to_dict(
        flatten_ob, dict_observation_space, dict_keys)
    checks = []
    pairs = zip(dict_ob.items(), recovered_dict_ob.items())
    for (orig_key, orig_val), (rec_key, rec_val) in pairs:
        checks.append(orig_key == rec_key)
        # Element-wise compare; assumes values are array-like — TODO confirm.
        checks.append(all(orig_val == rec_val))
    assert all(checks)
def forward(self, ob):
    """Feed-forward policy pass on a flattened dict observation.

    Splits ``ob`` back into its named parts, runs the two-layer MLP
    (angle first, angular velocity concatenated at the second layer),
    and emits the head appropriate for the action space:

    - discrete + multi: per-head softmax probabilities stacked on dim -2
    - discrete: a single softmax distribution
    - continuous, deterministic: tanh-squashed mean only
    - continuous, stochastic: (mean, log_std)
    """
    parts = flatten_to_dict(ob, self.ob_space)
    hidden = F.relu(self.fc1(parts['angle']))
    hidden = F.relu(
        self.fc2(torch.cat([hidden, parts['angular_velocity']], dim=1)))
    if self.discrete:
        if self.multi:
            head_probs = [
                torch.softmax(layer(hidden), dim=-1).unsqueeze(-2)
                for layer in self.output_layers
            ]
            return torch.cat(head_probs, dim=-2)
        return torch.softmax(self.output_layer(hidden), dim=-1)
    mean = torch.tanh(self.mean_layer(hidden))
    if self.deterministic:
        return mean
    log_std = self.log_std_param.expand_as(mean)
    return mean, log_std