Example #1
0
    def sample(self, batch_size: int):
        """Sample goal contexts and render a goal image for each.

        Draws ``batch_size`` goal contexts from the base distribution, then
        for each goal temporarily sets the wrapped env to that goal, renders
        it, and restores the env's previous state.

        :param batch_size: number of goals to sample.
        :return: the sampled context dict, with the rendered images stacked
            into a single ndarray under ``self._image_goal_key``.
        """
        if batch_size > 1 and not self._suppress_warning:
            warnings.warn("Sampling many goals is slow. Consider using "
                          "PresampledImageAndStateDistribution")
        contexts = self._base_distribution.sample(batch_size)
        images = []
        for i in range(batch_size):
            # Pull the i-th goal out of every ndarray leaf of the
            # (possibly nested) context structure.
            goal = ppp.treemap(lambda x: x[i],
                               contexts,
                               atomic_type=np.ndarray)
            env_state = self._env.get_env_state()
            try:
                self._env.set_to_goal(goal)
                img_goal = self._renderer(self._env)
            finally:
                # Restore the env even if set_to_goal/rendering raises, so a
                # failed sample cannot leave the caller's env in goal state.
                self._env.set_env_state(env_state)
            images.append(img_goal)

        contexts[self._image_goal_key] = np.array(images)
        return contexts
Example #2
0
    def random_batch(self, batch_size):
        """Sample a relabeled batch of ``batch_size`` transitions.

        The batch's context rows are partitioned (in order) into: contexts
        taken from the sampled rollouts, contexts sampled fresh from the
        context distribution, contexts drawn from the replay buffer,
        "next-observation" contexts, and hindsight "future" contexts.  The
        per-category counts come from the configured fractions; whatever is
        left over uses the rollout's own contexts.

        :param batch_size: number of transitions to sample.
        :return: dict with 'observations', 'actions', 'rewards', 'terminals',
            'next_observations', 'indices', and one entry per context key,
            after passing through ``self._post_process_batch_fn``.
        """
        num_future_contexts = int(batch_size * self._fraction_future_context)
        num_next_contexts = int(batch_size * self._fraction_next_context)
        num_replay_buffer_contexts = int(
            batch_size * self._fraction_replay_buffer_context
        )
        num_distrib_contexts = int(
            batch_size * self._fraction_distribution_context)
        # Remainder of the batch keeps the contexts recorded at rollout time.
        num_rollout_contexts = (
                batch_size - num_future_contexts - num_distrib_contexts
                - num_next_contexts - num_replay_buffer_contexts
        )
        indices = self._sample_indices(batch_size)
        obs_dict = self._batch_obs_dict(indices)
        next_obs_dict = self._batch_next_obs_dict(indices)
        contexts = [{
            k: next_obs_dict[k][:num_rollout_contexts]
            for k in self._context_keys
        }]

        if num_distrib_contexts > 0:
            sampled_contexts = self._context_distribution.sample(
                num_distrib_contexts)
            sampled_contexts = {
                k: sampled_contexts[k] for k in self._context_keys}
            contexts.append(sampled_contexts)

        if num_replay_buffer_contexts > 0:
            replay_buffer_contexts = self._get_replay_buffer_contexts(
                num_replay_buffer_contexts,
            )
            replay_buffer_contexts = {
                k: replay_buffer_contexts[k] for k in self._context_keys}
            contexts.append(replay_buffer_contexts)

        if num_next_contexts > 0:
            # Use an explicit positive end index.  The previous
            # ``indices[start_idx:-num_future_contexts]`` produced an EMPTY
            # slice whenever num_future_contexts == 0 (slice end -0 == 0),
            # silently dropping the next-context rows from the batch.
            end_idx = len(indices) - num_future_contexts
            start_state_indices = indices[end_idx - num_next_contexts:end_idx]
            next_contexts = self._get_next_contexts(start_state_indices)
            contexts.append(next_contexts)

        if num_future_contexts > 0:
            # Hindsight relabeling for the trailing block of indices.
            start_state_indices = indices[-num_future_contexts:]
            future_contexts = self._get_future_contexts(start_state_indices)
            future_contexts = {
                k: future_contexts[k] for k in self._context_keys}
            contexts.append(future_contexts)

        actions = self._actions[indices]

        def concat(*x):
            return np.concatenate(x, axis=0)

        # Concatenate the per-category context dicts leaf-wise back into one
        # dict of (batch_size, ...) arrays.
        new_contexts = ppp.treemap(concat, *tuple(contexts),
                                   atomic_type=np.ndarray)

        batch = {
            'observations': obs_dict[self.observation_key],
            'actions': actions,
            'rewards': self._rewards[indices],
            'terminals': self._terminals[indices],
            'next_observations': next_obs_dict[self.observation_key],
            'indices': np.array(indices).reshape(-1, 1),
            **new_contexts
        }
        new_batch = self._post_process_batch_fn(
            batch,
            self,
            obs_dict, next_obs_dict, new_contexts
        )
        return new_batch
Example #3
0
def batchify(x):
    """Add a leading batch axis of size 1 to every ndarray leaf of *x*.

    :param x: an ndarray or a nested structure containing ndarrays.
    :return: the same structure with each ndarray replaced by ``arr[None]``.
    """
    # Use a distinct lambda parameter name: the original ``lambda x: x[None]``
    # shadowed the function's own ``x`` argument.
    return ppp.treemap(lambda leaf: leaf[None], x, atomic_type=np.ndarray)