Exemplo n.º 1
0
    def add_path(self, path):
        """Store one rollout in the buffer slot at ``self._top``.

        Reads the ``"observations"``, ``"actions"``, ``"next_observations"``,
        ``"rewards"`` and ``"terminals"`` entries of ``path``, writes the
        first ``len(path["terminals"])`` steps of each into the current slot,
        records the episode length for that slot, and then calls
        ``self._advance()``.

        Args:
            path: Mapping holding the rollout data under the keys listed
                above.

        Raises:
            AssertionError: If a terminal flag is set and the first one is
                not on the last step of the path.
        """
        obs = path["observations"]
        actions = path["actions"]
        next_obs = path["next_observations"]
        rewards = path["rewards"]
        terminals = path["terminals"]
        path_len = len(terminals)

        # No length check here: only the first ``path_len`` entries of the
        # slot are overwritten below.

        actions = flatten_n(actions)
        obs = flatten_dict(obs, self.ob_keys_to_save)
        next_obs = flatten_dict(next_obs, self.ob_keys_to_save)

        self._actions[self._top][:path_len] = actions

        # Drop singleton axes, then restore a trailing axis of size one.
        rewards = np.expand_dims(np.squeeze(rewards), -1)
        self._environment_rewards[self._top][:path_len] = rewards

        self._terminals[self._top][:path_len] = terminals

        # Locate set terminal flags once; column 0 holds the step index.
        terminal_steps = np.argwhere(terminals)
        if len(terminal_steps):
            episode_length = terminal_steps[0, 0] + 1
            assert episode_length == path_len
        else:
            # NOTE(review): with no terminal flag the length is recorded as
            # ``self.max_path_length`` even if ``path_len`` is smaller --
            # confirm this is intended.
            episode_length = self.max_path_length
        self._episode_lengths[self._top] = episode_length

        for key in self.ob_keys_to_save:
            self._obs[key][self._top][:path_len] = obs[key]
            self._next_obs[key][self._top][:path_len] = next_obs[key]

        self._advance()
Exemplo n.º 2
0
    def add_path(self, path):
        """Store one full-length rollout in the buffer slot at ``self._top``.

        Reads the ``"observations"``, ``"actions"``, ``"next_observations"``
        and ``"terminals"`` entries of ``path``, writes them into the current
        slot, and then calls ``self._advance()``.

        Args:
            path: Mapping holding the rollout data under the keys listed
                above.

        Raises:
            AssertionError: If ``len(path["terminals"])`` differs from
                ``self.max_path_length``.
        """
        obs = path["observations"]
        actions = path["actions"]
        next_obs = path["next_observations"]
        terminals = path["terminals"]
        path_len = len(terminals)
        assert path_len == self.max_path_length

        actions = flatten_n(actions)
        obs = flatten_dict(obs, self.ob_keys_to_save)
        next_obs = flatten_dict(next_obs, self.ob_keys_to_save)

        self._actions[self._top] = actions
        self._terminals[self._top] = terminals
        for key in self.ob_keys_to_save:
            self._obs[key][self._top] = obs[key]
            self._next_obs[key][self._top] = next_obs[key]

        self._advance()
Exemplo n.º 3
0
 def add_decoded_vae_goals_to_path(self, path):
     """Attach a decoded desired goal to every step of ``path``.

     The goals stored under ``self.desired_goal_key`` in
     ``path["observations"]`` are gathered with ``flatten_dict``, decoded
     with ``self.env._decode`` and reshaped to one flat row per goal. Row
     ``i`` is then written, under ``self.decoded_desired_goal_key``, into
     step ``i`` of both ``path["observations"]`` and
     ``path["next_observations"]``.

     Args:
         path: Mapping with ``"observations"`` and ``"next_observations"``
             sequences of per-step observation mappings; modified in place.
     """
     # Decode all goals of the path in a single batched call here, instead
     # of one by one inside the env, for efficiency.
     goal_key = self.desired_goal_key
     observations = path["observations"]
     goals = flatten_dict(observations, [goal_key])[goal_key]

     decoded = self.env._decode(goals)
     decoded = decoded.reshape(len(decoded), -1)

     for step, _ in enumerate(observations):
         goal = decoded[step]
         observations[step][self.decoded_desired_goal_key] = goal
         path["next_observations"][step][self.decoded_desired_goal_key] = goal
Exemplo n.º 4
0
 def add_decoded_vae_goals_to_path(self, path):
     """Attach a decoded desired goal to every step of ``path``, if present.

     Does nothing unless the first observation of the path contains the
     ``'latent_desired_goal'`` key. Otherwise the goals stored under that
     key are gathered with ``flatten_dict``, decoded with
     ``self.env._decode`` and reshaped to one flat row per goal. Row ``i``
     is then written, under ``self.decoded_desired_goal_key``, into step
     ``i`` of both ``path['observations']`` and
     ``path['next_observations']``.

     Args:
         path: Mapping with ``'observations'`` and ``'next_observations'``
             sequences of per-step observation mappings; modified in place.
     """
     # Decode all goals of the path in a single batched call here, instead
     # of one by one inside the env, for efficiency.
     # Use the same key for the guard and the lookup; previously the lookup
     # read 'latent_desired_goals' (plural), a key the guard never checked.
     latent_goal_key = 'latent_desired_goal'
     if latent_goal_key not in path['observations'][0]:
         return

     desired_encoded_goals = flatten_dict(
         path['observations'],
         [latent_goal_key]
     )[latent_goal_key]
     desired_decoded_goals = self.env._decode(desired_encoded_goals)
     desired_decoded_goals = desired_decoded_goals.reshape(
         len(desired_decoded_goals),
         -1
     )
     for idx, _ in enumerate(path['observations']):
         path['observations'][idx][self.decoded_desired_goal_key] = \
             desired_decoded_goals[idx]
         path['next_observations'][idx][self.decoded_desired_goal_key] = \
             desired_decoded_goals[idx]