Example #1
    # Module-level imports required by this method:
    #   from typing import Tuple
    #   import torch
    #   from torch import Tensor
    #   from torch.distributions import Normal
    def _sample_trajectory(
            self, initial_states: Tensor, means: Tensor,
            stds: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Randomly samples T actions and computes the resulting trajectories.

        :returns: (sequence of states, sequence of actions,
                   total discounted cost per rollout, cost of the first step)
        """

        # Draw one T-step action sequence per rollout.
        actions = Normal(means,
                         stds).sample(sample_shape=(self._num_rollouts, ))
        if self.max_action is not None:
            # Rejection sampling: redraw any action outside
            # [-max_action, max_action] until all actions are in bounds.
            indices = torch.abs(actions) > self.max_action
            while indices.sum() > 0:
                actions[indices] = Normal(
                    means,
                    stds).sample(sample_shape=(self._num_rollouts, ))[indices]
                indices = torch.abs(actions) > self.max_action
            # Redundant after the rejection loop above; kept as a safeguard.
            actions = actions.clip(-self.max_action, self.max_action)

        # One more state than the time horizon because of the initial state.
        trajectories = torch.empty(
            (self.no_models, self._num_rollouts, self._time_horizon + 1,
             self._state_dimen),
            device=initial_states.device)
        trajectories[:, :, 0, :] = initial_states
        objective_costs = torch.zeros(
            (self.no_models, self._time_horizon, self._num_rollouts),
            device=initial_states.device)
        dones = torch.zeros((self.no_models, self._num_rollouts),
                            device=initial_states.device)

        for t in range(self._time_horizon):
            for d, dynamic in enumerate(self._dynamics):
                # Step each dynamics model from the ensemble-mean state.
                next_states, costs, done = dynamic.step(
                    trajectories[:, :, t, :].mean(0), actions[:, t, :])
                trajectories[d, :, t + 1, :] = next_states
                # A rollout stays terminated once any step has flagged it done.
                dones[d, :] = torch.maximum(done, dones[d, :])
                # Discounted cost; finished rollouts accrue no further cost.
                # `gamma` (the discount factor) is assumed to be defined
                # elsewhere in the module.
                objective_costs[d, t, :] = gamma**t * costs * (1 - dones[d, :])

        # Average over the model ensemble, keep the first-step cost
        # separately, then sum over the time horizon.
        objective_costs = torch.mean(objective_costs, 0)
        next_cost = objective_costs[0, :].clone()
        objective_costs = torch.sum(objective_costs, 0)

        return trajectories[0, :, :, :], actions, objective_costs, next_cost
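
For orientation, a minimal usage sketch follows. The planner instance `planner`, the action dimension `action_dimen`, and the elite count `num_elites` are hypothetical names, not part of the original example; the sketch assumes only the attributes the method itself uses (`_state_dimen`, `_num_rollouts`, `_time_horizon`).

import torch

# Hypothetical planner instance exposing the method above.
state = torch.zeros(planner._state_dimen)                 # current state
initial_states = state.expand(planner._num_rollouts, -1)  # one copy per rollout

# One Gaussian (mean, std) per timestep of the horizon.
means = torch.zeros(planner._time_horizon, action_dimen)
stds = torch.ones(planner._time_horizon, action_dimen)

states, actions, costs, next_cost = planner._sample_trajectory(
    initial_states, means, stds)

# Keep the lowest-cost rollouts, e.g. to refit means/stds CEM-style.
elite_idx = torch.argsort(costs)[:num_elites]
elite_actions = actions[elite_idx]

A caller would typically repeat this sample-and-refit loop for a few iterations and then execute the first action of the best sequence, which is what the separately returned first-step cost supports.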