Example #1
    def _train_once(self, itr, paths):
        """Perform one step of policy optimization given one batch of samples.

        Args:
            itr (int): Iteration number.
            paths (list[dict]): A list of collected paths.

        Returns:
            numpy.float64: Average return.

        """
        # -- Stage: Calculate baseline
        paths = [
            dict(
                observations=path['observations'],
                actions=(
                    self._env_spec.action_space.flatten_n(  # noqa: E126
                        path['actions'])),
                rewards=path['rewards'],
                env_infos=path['env_infos'],
                agent_infos=path['agent_infos'],
                dones=np.array([
                    step_type == StepType.TERMINAL
                    for step_type in path['step_types']
                ])) for path in paths
        ]

        if hasattr(self._baseline, 'predict_n'):
            baseline_predictions = self._baseline.predict_n(paths)
        else:
            baseline_predictions = [
                self._baseline.predict(path) for path in paths
            ]

        # -- Stage: Pre-process samples based on collected paths
        samples_data = paths_to_tensors(paths, self.max_episode_length,
                                        baseline_predictions, self._discount,
                                        self._gae_lambda)

        # -- Stage: Run and calculate performance of the algorithm
        undiscounted_returns = log_performance(
            itr,
            EpisodeBatch.from_list(self._env_spec, paths),
            discount=self._discount)
        self._episode_reward_mean.extend(undiscounted_returns)
        tabular.record('Extras/EpisodeRewardMean',
                       np.mean(self._episode_reward_mean))

        samples_data['average_return'] = np.mean(undiscounted_returns)

        logger.log('Optimizing policy...')
        self._optimize_policy(samples_data)

        return samples_data['average_return']
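
The two steps in _train_once that are easy to misread are the dones flags derived from step_types and the baseline call that prefers a batched predict_n when the baseline provides one. Below is a minimal, self-contained sketch of both; ToyStepType and ToyBaseline are illustrative stand-ins, not the framework's real StepType or baseline classes.

import enum

import numpy as np


class ToyStepType(enum.IntEnum):
    """Stand-in for the real StepType enum (values here are assumptions)."""
    FIRST = 0
    MID = 1
    TERMINAL = 2


class ToyBaseline:
    """Per-path baseline that only offers predict, not predict_n."""

    def predict(self, path):
        # A constant zero baseline, one value per time step.
        return np.zeros(len(path['rewards']))


paths = [{
    'rewards': np.array([1.0, 1.0, 0.0]),
    'step_types': [ToyStepType.FIRST, ToyStepType.MID, ToyStepType.TERMINAL],
}]

# Same pattern as in _train_once: a step counts as done only on TERMINAL.
dones = [
    np.array([st == ToyStepType.TERMINAL for st in p['step_types']])
    for p in paths
]
print(dones)  # [array([False, False,  True])]

# Batched prediction when available, otherwise one call per path.
baseline = ToyBaseline()
if hasattr(baseline, 'predict_n'):
    baseline_predictions = baseline.predict_n(paths)
else:
    baseline_predictions = [baseline.predict(p) for p in paths]
print(baseline_predictions)  # [array([0., 0., 0.])]
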
Example #2
    def _paths_to_tensors(self, paths):
        # pylint: disable=too-many-statements
        """Return processed sample data based on the collected paths.

        Args:
            paths (list[dict]): A list of collected paths.

        Returns:
            dict: Processed sample data, with keys
                * observations: (numpy.ndarray)
                * tasks: (numpy.ndarray)
                * actions: (numpy.ndarray)
                * trajectories: (numpy.ndarray)
                * rewards: (numpy.ndarray)
                * baselines: (numpy.ndarray)
                * returns: (numpy.ndarray)
                * valids: (numpy.ndarray)
                * agent_infos: (dict)
                * latent_infos: (dict)
                * env_infos: (dict)
                * trajectory_infos: (dict)
                * paths: (list[dict])

        """
        max_episode_length = self.max_episode_length

        def _extract_latent_infos(infos):
            """Extract and pack latent infos from dict.

            Args:
                infos (dict): A dict that contains latent infos with key
                    prefixed by 'latent_'.

            Returns:
                dict: A dict of latent infos.

            """
            latent_infos = dict()
            for k, v in infos.items():
                if k.startswith('latent_'):
                    latent_infos[k[7:]] = v
            return latent_infos

        for path in paths:
            path['actions'] = (self._env_spec.action_space.flatten_n(
                path['actions']))
            path['tasks'] = self.policy.task_space.flatten_n(
                path['env_infos']['task_onehot'])
            path['latents'] = path['agent_infos']['latent']
            path['latent_infos'] = _extract_latent_infos(path['agent_infos'])

            # - Calculate a forward-looking sliding window.
            # - If step_space has shape (n, d), then trajs will have shape
            #   (n, window, d)
            # - The length of the sliding window is determined by the
            #   trajectory inference spec. We smear the last few elements to
            #   preserve the time dimension.
            # - Only the observation is used for a single step.
            #   Alternatively, stacked [observation, action] can be used in
            #   harder tasks. A standalone sketch of this windowing step
            #   follows the example.
            obs = pad_tensor(path['observations'], max_episode_length)
            obs_flat = self._env_spec.observation_space.flatten_n(obs)
            steps = obs_flat
            window = self._inference.spec.input_space.shape[0]
            traj = sliding_window(steps, window, smear=True)
            traj_flat = self._inference.spec.input_space.flatten_n(traj)
            path['trajectories'] = traj_flat

            _, traj_info = self._inference.get_latents(traj_flat)
            path['trajectory_infos'] = traj_info

        all_path_baselines = [self._baseline.predict(path) for path in paths]

        tasks = [path['tasks'] for path in paths]
        tasks = pad_tensor_n(tasks, max_episode_length)

        trajectories = np.stack([path['trajectories'] for path in paths])

        latents = [path['latents'] for path in paths]
        latents = pad_tensor_n(latents, max_episode_length)

        latent_infos = [path['latent_infos'] for path in paths]
        latent_infos = stack_tensor_dict_list(
            [pad_tensor_dict(p, max_episode_length) for p in latent_infos])

        trajectory_infos = [path['trajectory_infos'] for path in paths]
        trajectory_infos = stack_tensor_dict_list(
            [pad_tensor_dict(p, max_episode_length) for p in trajectory_infos])

        samples_data = paths_to_tensors(paths, max_episode_length,
                                        all_path_baselines, self._discount,
                                        self._gae_lambda)
        samples_data['tasks'] = tasks
        samples_data['latents'] = latents
        samples_data['latent_infos'] = latent_infos
        samples_data['trajectories'] = trajectories
        samples_data['trajectory_infos'] = trajectory_infos

        return samples_data
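
The least obvious part of _paths_to_tensors is the forward-looking sliding window that turns per-step observations into fixed-length trajectory snippets. The sketch below re-implements that transformation under the behaviour the comment describes (windows that would run past the end of the episode repeat, or "smear", the last step); toy_sliding_window is an illustrative stand-in, not the actual sliding_window helper used above.

import numpy as np


def toy_sliding_window(steps, window):
    """Return an (n, window, d) array of forward-looking windows.

    Windows that would extend past the last step repeat ("smear") the
    final valid step, so every window has the same length.
    """
    n = len(steps)
    out = np.empty((n, window) + steps.shape[1:], dtype=steps.dtype)
    for t in range(n):
        for k in range(window):
            out[t, k] = steps[min(t + k, n - 1)]
    return out


# Four steps with 2-dimensional observations, window of length 3.
steps = np.arange(8, dtype=float).reshape(4, 2)
trajs = toy_sliding_window(steps, window=3)
print(trajs.shape)  # (4, 3, 2)
print(trajs[-1])    # the final window repeats the last step three times
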
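
Calls such as pad_tensor_n(tasks, max_episode_length) and stack_tensor_dict_list(...) batch variable-length paths into fixed-shape arrays before they are merged into samples_data. A rough sketch of what that padding and stacking amounts to, using hypothetical toy helpers rather than the framework's own functions:

import numpy as np


def toy_pad_tensor(x, max_len):
    """Zero-pad a (t, ...) array along the time axis to length max_len."""
    pad = np.zeros((max_len - len(x),) + x.shape[1:], dtype=x.dtype)
    return np.concatenate([x, pad])


def toy_pad_tensor_n(xs, max_len):
    """Pad a list of per-path arrays and stack them into one batch."""
    return np.stack([toy_pad_tensor(x, max_len) for x in xs])


def toy_stack_tensor_dict_list(dicts):
    """Turn a list of dicts of arrays into one dict of stacked arrays."""
    return {k: np.stack([d[k] for d in dicts]) for k in dicts[0]}


# Two paths of different lengths, each with a 2-dim latent per step.
latents = [np.ones((3, 2)), np.ones((5, 2))]
batched = toy_pad_tensor_n(latents, max_len=6)
print(batched.shape)  # (2, 6, 2); shorter paths are zero-padded

# Per-path info dicts (already padded to equal length) stacked into a batch.
latent_infos = [{'mean': np.zeros(6)}, {'mean': np.ones(6)}]
stacked = toy_stack_tensor_dict_list(latent_infos)
print(stacked['mean'].shape)  # (2, 6)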