Example #1
    def forward(self, xs, hs, h_masks):
        # xs is a (time_seq, batch_size, ...) stack of flattened observations
        time_seq, batch_size, *_ = xs.shape

        # reshape the (h, c) LSTM state to (batch_size, cell_size)
        hs = (hs[0].reshape(batch_size, self.cell_size),
              hs[1].reshape(batch_size, self.cell_size))

        # split the flat observation back into its dict components
        dict_xs = flatten_to_dict(xs, self.ob_space)
        xs = torch.relu(self.input_layer(dict_xs['angle']))
        ang_vels = dict_xs['angular_velocity']

        hiddens = []
        for x, ang_vel, mask in zip(xs, ang_vels, h_masks):
            # zero the recurrent state wherever an episode boundary is masked
            hs = (hs[0] * (1 - mask), hs[1] * (1 - mask))
            hs = self.cell(
                torch.cat([x, ang_vel], dim=1), hs)
            hiddens.append(hs[0])
        hiddens = torch.stack(hiddens, dim=0)

        if not self.discrete:
            # continuous actions: tanh-squashed mean plus a state-independent log-std
            means = torch.tanh(self.mean_layer(hiddens))
            log_std = self.log_std_param.expand_as(means)
            return means, log_std, hs
        else:
            if self.multi:
                # one softmax head per action dimension, stacked on a new axis
                return torch.cat(
                    [torch.softmax(ol(hiddens), dim=-1).unsqueeze(-2)
                     for ol in self.output_layers], dim=-2), hs
            else:
                return torch.softmax(self.output_layer(hiddens), dim=-1), hs
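
The `hs * (1 - mask)` multiplication above resets the LSTM state at episode boundaries. A minimal sketch of how such masks might be built from per-step done flags; the `dones` name and its layout are assumptions for illustration, not taken from the source:

import torch

# dones: (time_seq, batch_size) episode-termination flags (assumed layout)
dones = torch.tensor([[0., 0.], [1., 0.], [0., 1.]])
# h_masks: (time_seq, batch_size, 1) so each step broadcasts against (batch, cell_size)
h_masks = dones.unsqueeze(-1)
# inside the loop, hs[0] * (1 - mask) zeroes rows whose episode just ended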
Example #2
def test_flatten2dict():
    # build a dict-observation Pendulum env and wrap it
    dict_env = gym.make('PendulumDictEnv-v0')
    dict_env = GymEnv(dict_env)
    dict_ob = dict_env.observation_space.sample()
    dict_observation_space = dict_env.observation_space
    dict_keys = dict_env.observation_space.spaces.keys()
    # flatten the observation, then recover the dict form
    env = _make_flat(dict_env, dict_keys)
    flatten_ob = env.observation(dict_ob)
    recovered_dict_ob = flatten_to_dict(flatten_ob, dict_observation_space,
                                        dict_keys)
    # the round trip must preserve both keys and values
    checks = []
    for (a_key, a_val), (b_key, b_val) in zip(dict_ob.items(),
                                              recovered_dict_ob.items()):
        checks.append(a_key == b_key)
        checks.append(all(a_val == b_val))
    assert all(checks)
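
For context, here is a minimal sketch of what a helper like `flatten_to_dict` could look like, assuming the Dict space's entries are Box spaces whose flat sizes are concatenated in key order. This is an illustration under those assumptions, not the project's actual implementation:

import numpy as np

def flatten_to_dict_sketch(flat_ob, ob_space, keys=None):
    """Split a flat observation back into a dict keyed by ob_space entries."""
    keys = keys if keys is not None else ob_space.spaces.keys()
    out, offset = {}, 0
    for key in keys:
        # number of flat elements this sub-space occupies
        size = int(np.prod(ob_space.spaces[key].shape))
        out[key] = flat_ob[..., offset:offset + size]
        offset += size
    return out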
Example #3
    def forward(self, ob):
        # recover the dict observation from its flat form
        dict_ob = flatten_to_dict(ob, self.ob_space)
        h = F.relu(self.fc1(dict_ob['angle']))
        h = F.relu(
            self.fc2(torch.cat([h, dict_ob['angular_velocity']], dim=1)))
        if not self.discrete:
            mean = torch.tanh(self.mean_layer(h))
            if not self.deterministic:
                # stochastic policy: return mean and state-independent log-std
                log_std = self.log_std_param.expand_as(mean)
                return mean, log_std
            else:
                return mean
        else:
            if self.multi:
                # one softmax head per action dimension, stacked on a new axis
                return torch.cat(
                    [torch.softmax(ol(h), dim=-1).unsqueeze(-2)
                     for ol in self.output_layers], dim=-2)
            else:
                return torch.softmax(self.output_layer(h), dim=-1)
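
A hedged usage sketch for the continuous branch: the `(mean, log_std)` pair parameterizes a diagonal Gaussian, so a caller might sample actions like this. The tensors below are stand-ins for a real `forward()` result, and consuming the outputs this way is an assumption, not shown in the source:

import torch

mean = torch.zeros(4, 1)            # placeholder for a forward() mean
log_std = torch.full((4, 1), -0.5)  # placeholder for the expanded log-std

dist = torch.distributions.Normal(mean, log_std.exp())
action = dist.sample()                         # one action per batch element
log_prob = dist.log_prob(action).sum(dim=-1)   # per-sample log-likelihood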