Example #1
    def _calc_cost_for_action_sequence(self, time_step: TimeStep, state,
                                       ac_seqs):
        """
        Args:
            time_step (TimeStep): input data for next step prediction
            state (MbrlState): input state for next step prediction
            ac_seqs (Tensor): action sequences of shape [batch_size,
                    population_size, solution_dim], where
                    solution_dim = planning_horizon * num_actions
        Returns:
            cost (Tensor) with shape [batch_size, population_size]
        """
        obs = time_step.observation
        batch_size = obs.shape[0]

        ac_seqs = torch.reshape(
            ac_seqs,
            [batch_size, self._population_size, self._planning_horizon, -1])

        ac_seqs = ac_seqs.permute(2, 0, 1, 3)
        ac_seqs = torch.reshape(
            ac_seqs, (self._planning_horizon, -1, self._num_actions))

        state = state._replace(dynamics=state.dynamics._replace(feature=obs))
        init_obs = self._expand_to_population(obs)
        state = nest.map_structure(self._expand_to_population, state)

        obs = init_obs
        cost = 0
        for i in range(ac_seqs.shape[0]):
            action = ac_seqs[i]
            time_step = time_step._replace(prev_action=action)
            time_step, state = self._dynamics_func(time_step, state)
            next_obs = time_step.observation
            # Note: currently using (next_obs, action), might need to
            # consider (obs, action) in order to be more compatible
            # with the conventional definition of the reward function
            reward_step, state = self._reward_func(next_obs, action, state)
            cost = cost - reward_step
            obs = next_obs

        # reshape cost back to [batch_size, population_size]
        cost = torch.reshape(cost, [batch_size, -1])
        return cost
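
The shape bookkeeping above can be hard to follow. Below is a minimal standalone sketch (made-up sizes, not part of the method) of the same reshape/permute sequence applied to ``ac_seqs``:

import torch

# Hypothetical sizes for illustration only.
batch_size, population_size = 4, 32
planning_horizon, num_actions = 10, 2
ac_seqs = torch.randn(batch_size, population_size,
                      planning_horizon * num_actions)

# [B, P, H * A] -> [B, P, H, A]
x = torch.reshape(ac_seqs,
                  [batch_size, population_size, planning_horizon, -1])
# [B, P, H, A] -> [H, B, P, A]: planning step first, so the loop can index x[i]
x = x.permute(2, 0, 1, 3)
# [H, B, P, A] -> [H, B * P, A]: merge batch and population so the dynamics
# model sees one large batch per planning step
x = torch.reshape(x, (planning_horizon, -1, num_actions))
assert x.shape == (planning_horizon, batch_size * population_size, num_actions)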
Example #2
    def train_step(self, exp: Experience, state: SacState):
        # We detach exp.observation here so that if exp.observation is calculated
        # by some other trainable module, the training of that module will not be
        # affected by the gradient back-propagated from the actor. However, the
        # gradient from the critic will still affect the training of that module.
        (action_distribution, action, critics,
         action_state) = self._predict_action(common.detach(exp.observation),
                                              state=state.action)

        log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
                                    action_distribution, action)

        if self._act_type == ActionType.Mixed:
            # For mixed type, add log_pi separately
            log_pi = type(self._action_spec)(
                (sum(nest.flatten(log_pi[0])), sum(nest.flatten(log_pi[1]))))
        else:
            log_pi = sum(nest.flatten(log_pi))

        if self._prior_actor is not None:
            prior_step = self._prior_actor.train_step(exp, ())
            log_prior = dist_utils.compute_log_probability(
                prior_step.output, action)
            log_pi = log_pi - log_prior

        actor_state, actor_loss = self._actor_train_step(
            exp, state.actor, action, critics, log_pi, action_distribution)
        critic_state, critic_info = self._critic_train_step(
            exp, state.critic, action, log_pi, action_distribution)
        alpha_loss = self._alpha_train_step(log_pi)

        state = SacState(action=action_state,
                         actor=actor_state,
                         critic=critic_state)
        info = SacInfo(action_distribution=action_distribution,
                       actor=actor_loss,
                       critic=critic_info,
                       alpha=alpha_loss)
        return AlgStep(action, state, info)
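
A minimal standalone sketch (plain PyTorch, not ALF's ``common.detach``) of why the observation is detached for the actor: gradients do not flow through a detached tensor back to the module that produced it, while the non-detached copy used by the critic still propagates gradients upstream.

import torch

# Stand-in for the output of an upstream trainable module (e.g. an encoder).
encoder_out = torch.randn(3, 4, requires_grad=True)

# Critic path: uses the original observation, so gradients reach the encoder.
critic_loss = (encoder_out ** 2).sum()
critic_loss.backward()
assert encoder_out.grad is not None

# Actor path: uses the detached observation, so no gradient can flow back.
detached_obs = encoder_out.detach()
assert not detached_obs.requires_grad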
Example #3
    def _step(self, action):
        if self._done:
            return self.reset()

        if self._action_spec:
            nest.assert_same_structure(self._action_spec, action)

        self._num_steps += 1

        observation = self._get_observation()
        if self._num_steps < self._min_duration:
            self._done = False
        elif self._max_duration and self._num_steps >= self._max_duration:
            self._done = True
        else:
            self._done = self._rng.uniform() < self._episode_end_probability

        if self._batch_size:
            action = nest.map_structure(
                lambda t: np.concatenate(
                    [np.expand_dims(t, 0)] * self._batch_size), action)

        if self._done:
            reward = self._reward_fn(ds.StepType.LAST, action, observation)
            self._check_reward_shape(reward)
            time_step = ds.termination(observation,
                                       action,
                                       reward,
                                       env_id=self._env_id)
            self._num_steps = 0
        else:
            reward = self._reward_fn(ds.StepType.MID, action, observation)
            self._check_reward_shape(reward)
            time_step = ds.transition(observation,
                                      action,
                                      reward,
                                      discount=self._discount,
                                      env_id=self._env_id)

        return time_step
Example #4
    def _critic_train_step(self, exp: Experience, state: SacCriticState,
                           action, log_pi, action_distribution):
        critics, critics_state = self._compute_critics(self._critic_networks,
                                                       exp.observation,
                                                       exp.action,
                                                       state.critics)

        target_critics, target_critics_state = self._compute_critics(
            self._target_critic_networks, exp.observation, action,
            state.target_critics)
        target_critics = target_critics.min(dim=1)[0]

        if self._act_type == ActionType.Discrete:
            critics = self._select_q_value(exp.action, critics)
            target_critics = self._select_q_value(
                action, target_critics.unsqueeze(dim=1))

        elif self._act_type == ActionType.Mixed:
            critics = self._select_q_value(exp.action[0], critics)
            discrete_act_dist = action_distribution[0]
            target_critics = torch.sum(discrete_act_dist.probs *
                                       target_critics,
                                       dim=-1)

        target_critic = target_critics.reshape(exp.reward.shape)
        if self._use_entropy_reward:
            entropy_reward = nest.map_structure(
                lambda la, lp: -torch.exp(la) * lp, self._log_alpha, log_pi)
            entropy_reward = sum(nest.flatten(entropy_reward))
            target_critic = target_critic + entropy_reward

        target_critic = target_critic.detach()

        state = SacCriticState(critics=critics_state,
                               target_critics=target_critics_state)
        info = SacCriticInfo(critics=critics, target_critic=target_critic)

        return state, info
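
A minimal standalone sketch (made-up shapes, not ALF code) of the two key steps in forming the critic target above: take the minimum over critic replicas for a pessimistic bootstrap value, then add the entropy bonus ``-exp(log_alpha) * log_pi`` before detaching:

import torch

batch_size, num_replicas = 5, 2
target_critics = torch.randn(batch_size, num_replicas)  # Q_i(s', a') per replica
log_pi = torch.randn(batch_size)                        # log pi(a'|s')
log_alpha = torch.tensor(0.0)

target = target_critics.min(dim=1)[0]                # min over replicas
target = target + (-torch.exp(log_alpha) * log_pi)   # soft-value entropy bonus
target = target.detach()                             # no gradient through the target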
Example #5
    def _actor_train_step(self, exp: Experience, state: DdpgActorState):
        action, actor_state = self._actor_network(exp.observation,
                                                  state=state.actor)

        q_values, critic_states = self._critic_networks(
            (exp.observation, action), state=state.critics)
        if q_values.ndim == 3:
            # Multidimensional reward: [B, num_critic_replicas, reward_dim]
            if self._reward_weights is None:
                q_values = q_values.sum(dim=2)
            else:
                q_values = torch.tensordot(q_values,
                                           self._reward_weights,
                                           dims=1)

        if self._num_critic_replicas > 1:
            q_value = q_values.min(dim=1)[0]
        else:
            q_value = q_values.squeeze(dim=1)

        dqda = nest_utils.grad(action, q_value.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = torch.clamp(dqda, -self._dqda_clipping,
                                   self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            if self._action_l2 > 0:
                assert action.requires_grad
                loss += self._action_l2 * (action**2)
            loss = loss.sum(list(range(1, loss.ndim)))
            return loss

        actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
        state = DdpgActorState(actor=actor_state, critics=critic_states)
        info = LossInfo(loss=sum(nest.flatten(actor_loss)), extra=actor_loss)
        return AlgStep(output=action, state=state, info=info)
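
A minimal standalone check (plain PyTorch, not ALF code) that the surrogate loss in ``actor_loss_fn`` has gradient ``-dqda`` with respect to the action, i.e. minimizing it pushes the action in the direction that increases Q:

import torch

action = torch.randn(3, requires_grad=True)
dqda = torch.randn(3)  # stands in for dQ/da computed from the critic

# Same form as the loss above: 0.5 * ((dqda + action).detach() - action)^2
loss = 0.5 * ((dqda + action).detach() - action) ** 2
loss.sum().backward()

assert torch.allclose(action.grad, -dqda)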
Example #6
    def after_update(self, experience, train_info: SacInfo):
        self._update_target()
        if self._max_log_alpha is not None:
            nest.map_structure(
                lambda la: la.data.copy_(torch.min(la, self._max_log_alpha)),
                self._log_alpha)
Example #7
    def _alpha_train_step(self, log_pi):
        alpha_loss = nest.map_structure(
            lambda la, lp, t: la * (-lp - t).detach(), self._log_alpha, log_pi,
            self._target_entropy)
        return sum(nest.flatten(alpha_loss))
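
A minimal standalone sketch (plain PyTorch, not ALF code) of the temperature update above: the gradient of ``alpha_loss`` with respect to ``log_alpha`` is ``(-log_pi - target_entropy)``, so gradient descent lowers ``log_alpha`` when the policy entropy estimate ``-log_pi`` is above the target and raises it when the entropy is below the target:

import torch

log_alpha = torch.tensor(0.0, requires_grad=True)
log_pi = torch.tensor(-1.5)          # log-probability of the sampled action
target_entropy = torch.tensor(1.0)   # hypothetical target value

alpha_loss = log_alpha * (-log_pi - target_entropy).detach()
alpha_loss.backward()
assert torch.isclose(log_alpha.grad, -log_pi - target_entropy)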
Example #8
    def __init__(self,
                 observation_spec,
                 action_spec: BoundedTensorSpec,
                 actor_network_cls=ActorDistributionNetwork,
                 critic_network_cls=CriticNetwork,
                 q_network_cls=QNetwork,
                 reward_weights=None,
                 use_entropy_reward=True,
                 use_parallel_network=False,
                 num_critic_replicas=2,
                 env=None,
                 config: TrainerConfig = None,
                 critic_loss_ctor=None,
                 target_entropy=None,
                 prior_actor_ctor=None,
                 target_kld_per_dim=3.,
                 initial_log_alpha=0.0,
                 max_log_alpha=None,
                 target_update_tau=0.05,
                 target_update_period=1,
                 dqda_clipping=None,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 alpha_optimizer=None,
                 debug_summaries=False,
                 name="SacAlgorithm"):
        """
        Args:
            observation_spec (nested TensorSpec): representing the observations.
            action_spec (nested BoundedTensorSpec): representing the actions; can
                be a mixture of discrete and continuous actions. The number of
                continuous actions can be arbitrary while only one discrete
                action is allowed currently. If it's a mixture, then it must be
                a tuple/list ``(discrete_action_spec, continuous_action_spec)``.
            actor_network_cls (Callable): is used to construct the actor network.
                The constructed actor network will be called
                to sample continuous actions. All of its output specs must be
                continuous. Note that we don't need a discrete actor network
                because a discrete action can simply be sampled from the Q values.
            critic_network_cls (Callable): is used to construct a critic network
                for estimating ``Q(s,a)`` given that the action is continuous.
            q_network_cls (Callable): is used to construct a QNetwork for estimating ``Q(s,a)``
                given that the action is discrete. Its output spec must be consistent with
                the discrete action in ``action_spec``.
            reward_weights (None|list[float]): this is only used when the reward is
                multidimensional. In that case, the weighted sum of the q values
                is used for training the actor if reward_weights is not None.
                Otherwise, the sum of the q values is used.
            use_entropy_reward (bool): whether to include entropy as reward
            use_parallel_network (bool): whether to use parallel network for
                calculating critics.
            num_critic_replicas (int): number of critics to be used. Default is 2.
            env (Environment): The environment to interact with. ``env`` is a
                batched environment, which means that it runs multiple simulations
                simultaneously. ``env`` only needs to be provided to the root
                algorithm.
            config (TrainerConfig): config for training. It only needs to be
                provided to the algorithm which performs ``train_iter()`` by
                itself.
            critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
                constructor. If ``None``, a default ``OneStepTDLoss`` will be used.
            initial_log_alpha (float): initial value for variable ``log_alpha``.
            max_log_alpha (float|None): if not None, ``log_alpha`` will be
                capped at this value.
            target_entropy (float|Callable|None): If a floating value, it's the
                target average policy entropy, for updating ``alpha``. If a
                callable function, then it will be called on the action spec to
                calculate a target entropy. If ``None``, a default entropy will
                be calculated. For the mixed action type, discrete action and
                continuous action will have separate alphas and target entropies,
                so this argument can be a 2-element list/tuple, where the first
                is for discrete action and the second for continuous action.
            prior_actor_ctor (Callable): If provided, it will be called using
                ``prior_actor_ctor(observation_spec, action_spec, debug_summaries=debug_summaries)``
                to construct a prior actor. The output of the prior actor is
                the distribution of the next action. Two prior actors are implemented:
                ``alf.algorithms.prior_actor.SameActionPriorActor`` and
                ``alf.algorithms.prior_actor.UniformPriorActor``.
            target_kld_per_dim (float): ``alpha`` is dynamically adjusted so that
                the KLD is about ``target_kld_per_dim * dim``.
            target_update_tau (float): Factor for soft update of the target
                networks.
            target_update_period (int): Period for soft update of the target
                networks.
            dqda_clipping (float): when computing the actor loss, clips the
                gradient dqda element-wise between
                ``[-dqda_clipping, dqda_clipping]``. Will not perform clipping if
                ``dqda_clipping == 0``.
            actor_optimizer (torch.optim.optimizer): The optimizer for actor.
            critic_optimizer (torch.optim.optimizer): The optimizer for critic.
            alpha_optimizer (torch.optim.optimizer): The optimizer for alpha.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this algorithm.
        """
        self._num_critic_replicas = num_critic_replicas
        self._use_parallel_network = use_parallel_network

        critic_networks, actor_network, self._act_type, reward_dim = self._make_networks(
            observation_spec, action_spec, actor_network_cls,
            critic_network_cls, q_network_cls)

        self._use_entropy_reward = use_entropy_reward

        if reward_dim > 1:
            assert not use_entropy_reward, (
                "use_entropy_reward=True is not supported for multidimensional reward"
            )
            assert self._act_type == ActionType.Continuous, (
                "Only continuous action is supported for multidimensional reward"
            )

        self._reward_weights = None
        if reward_weights:
            assert reward_dim > 1, (
                "reward_weights cannot be used for one dimensional reward")
            assert len(reward_weights) == reward_dim, (
                "Mismatch between len(reward_weights)=%s and reward_dim=%s" %
                (len(reward_weights), reward_dim))
            self._reward_weights = torch.tensor(reward_weights,
                                                dtype=torch.float32)

        def _init_log_alpha():
            return nn.Parameter(torch.tensor(float(initial_log_alpha)))

        if self._act_type == ActionType.Mixed:
            # separate alphas for discrete and continuous actions
            log_alpha = type(action_spec)(
                (_init_log_alpha(), _init_log_alpha()))
        else:
            log_alpha = _init_log_alpha()

        action_state_spec = SacActionState(
            actor_network=(() if self._act_type == ActionType.Discrete else
                           actor_network.state_spec),
            critic=(() if self._act_type == ActionType.Continuous else
                    critic_networks.state_spec))
        super().__init__(
            observation_spec,
            action_spec,
            train_state_spec=SacState(
                action=action_state_spec,
                actor=(() if self._act_type != ActionType.Continuous else
                       critic_networks.state_spec),
                critic=SacCriticState(
                    critics=critic_networks.state_spec,
                    target_critics=critic_networks.state_spec)),
            predict_state_spec=SacState(action=action_state_spec),
            env=env,
            config=config,
            debug_summaries=debug_summaries,
            name=name)

        if actor_optimizer is not None:
            self.add_optimizer(actor_optimizer, [actor_network])
        if critic_optimizer is not None:
            self.add_optimizer(critic_optimizer, [critic_networks])
        if alpha_optimizer is not None:
            self.add_optimizer(alpha_optimizer, nest.flatten(log_alpha))

        self._log_alpha = log_alpha
        if self._act_type == ActionType.Mixed:
            self._log_alpha_paralist = nn.ParameterList(
                nest.flatten(log_alpha))

        if max_log_alpha is not None:
            self._max_log_alpha = torch.tensor(float(max_log_alpha))
        else:
            self._max_log_alpha = None

        self._actor_network = actor_network
        self._critic_networks = critic_networks
        self._target_critic_networks = self._critic_networks.copy(
            name='target_critic_networks')

        if critic_loss_ctor is None:
            critic_loss_ctor = OneStepTDLoss
        critic_loss_ctor = functools.partial(critic_loss_ctor,
                                             debug_summaries=debug_summaries)
        # Have different names to separate their summary curves
        self._critic_losses = []
        for i in range(num_critic_replicas):
            self._critic_losses.append(
                critic_loss_ctor(name="critic_loss%d" % (i + 1)))

        self._prior_actor = None
        if prior_actor_ctor is not None:
            assert self._act_type == ActionType.Continuous, (
                "Only continuous action is supported when using prior_actor")
            self._prior_actor = prior_actor_ctor(
                observation_spec=observation_spec,
                action_spec=action_spec,
                debug_summaries=debug_summaries)
            total_action_dims = sum(
                [spec.numel for spec in alf.nest.flatten(action_spec)])
            self._target_entropy = -target_kld_per_dim * total_action_dims
        else:
            if self._act_type == ActionType.Mixed:
                if not isinstance(target_entropy, (tuple, list)):
                    target_entropy = nest.map_structure(
                        lambda _: target_entropy, self._action_spec)
                # separate target entropies for discrete and continuous actions
                self._target_entropy = nest.map_structure(
                    lambda spec, t: _set_target_entropy(self.name, t, [spec]),
                    self._action_spec, target_entropy)
            else:
                self._target_entropy = _set_target_entropy(
                    self.name, target_entropy, nest.flatten(self._action_spec))

        self._dqda_clipping = dqda_clipping

        self._update_target = common.get_target_updater(
            models=[self._critic_networks],
            target_models=[self._target_critic_networks],
            tau=target_update_tau,
            period=target_update_period)
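
A hypothetical construction sketch for the constructor documented above. The import paths and spec values are assumptions made for illustration, not taken from this listing; the optimizer arguments are omitted since they default to ``None``:

import alf
from alf.algorithms.sac_algorithm import SacAlgorithm

# Assumed specs: an 8-dim observation and a 2-dim continuous action in [-1, 1].
observation_spec = alf.TensorSpec((8, ))
action_spec = alf.BoundedTensorSpec((2, ), minimum=-1., maximum=1.)

algorithm = SacAlgorithm(
    observation_spec=observation_spec,
    action_spec=action_spec,
    use_entropy_reward=True,
    num_critic_replicas=2,
    target_update_tau=0.05,
    target_update_period=1,
    initial_log_alpha=0.0,
    name="SacAlgorithm")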
Example #9
    def calc_loss(self, info: DynamicsInfo):
        # Here we take mean over the loss to avoid undesired additional
        # masking from base algorithm's ``update_with_gradient``.
        scalar_loss = nest.map_structure(torch.mean, info.loss)
        return LossInfo(scalar_loss=scalar_loss.loss, extra=scalar_loss.loss)
Example #10
    def _get_observation(self):
        batch_size = (self._batch_size, ) if self._batch_size else ()
        return nest.map_structure(
            lambda spec: self._sample_spec(spec, batch_size).cpu().numpy(),
            self._observation_spec)