Example No. 1
    def _actor_train_step(self, exp: Experience, state, action, critics,
                          log_pi, action_distribution):
        neg_entropy = sum(nest.flatten(log_pi))

        if self._act_type == ActionType.Discrete:
            # Pure discrete case doesn't need to learn an actor network
            return (), LossInfo(extra=SacActorInfo(neg_entropy=neg_entropy))

        if self._act_type == ActionType.Continuous:
            critics, critics_state = self._compute_critics(
                self._critic_networks, exp.observation, action, state)
            if critics.ndim == 3:
                # Multidimensional reward: [B, num_critic_replicas, reward_dim]
                if self._reward_weights is None:
                    critics = critics.sum(dim=2)
                else:
                    critics = torch.tensordot(critics,
                                              self._reward_weights,
                                              dims=1)

            target_q_value = critics.min(dim=1)[0]
            continuous_log_pi = log_pi
            cont_alpha = torch.exp(self._log_alpha).detach()
        else:
            # use the critics computed during action prediction for Mixed type
            critics_state = ()
            discrete_act_dist = action_distribution[0]
            discrete_entropy = discrete_act_dist.entropy()
            # critics have already been min-reduced over the replicas
            weighted_q_value = torch.sum(discrete_act_dist.probs * critics,
                                         dim=-1)
            discrete_alpha = torch.exp(self._log_alpha[0]).detach()
            target_q_value = weighted_q_value + discrete_alpha * discrete_entropy
            action, continuous_log_pi = action[1], log_pi[1]
            cont_alpha = torch.exp(self._log_alpha[1]).detach()

        dqda = nest_utils.grad(action, target_q_value.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = torch.clamp(dqda, -self._dqda_clipping,
                                   self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            return loss.sum(list(range(1, loss.ndim)))

        actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(nest.flatten(actor_loss))
        actor_info = LossInfo(loss=actor_loss + cont_alpha * continuous_log_pi,
                              extra=SacActorInfo(actor_loss=actor_loss,
                                                 neg_entropy=neg_entropy))
        return critics_state, actor_info
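
The surrogate loss in ``actor_loss_fn`` above is the standard trick for propagating a
critic gradient into the actor: because ``(dqda + action).detach()`` is treated as a
constant, the squared loss has gradient ``-dqda`` with respect to ``action``, so
minimizing it moves the action along the critic's ascent direction. A minimal,
standalone sketch of that identity (assuming PyTorch; the tensors are illustrative,
not taken from ALF):

import torch

action = torch.tensor([0.3, -0.7], requires_grad=True)
dqda = torch.tensor([1.5, -2.0])  # stand-in for dQ/da computed from a critic

# Same surrogate as actor_loss_fn: 0.5 * ((dqda + a).detach() - a)^2
loss = 0.5 * ((dqda + action).detach() - action) ** 2
loss.sum().backward()

# The gradient w.r.t. the action equals -dqda, so a descent step on this loss
# is an ascent step on Q along dqda.
assert torch.allclose(action.grad, -dqda)
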
Example No. 2
    def calc_loss(self, training_info: TrainingInfo):
        info = training_info.info  # SarsaInfo
        critic_loss = losses.element_wise_squared_loss(info.returns,
                                                       info.critic)
        not_first_step = tf.not_equal(training_info.step_type, StepType.FIRST)
        critic_loss *= tf.cast(not_first_step, tf.float32)

        def _summary():
            with self.name_scope:
                tf.summary.scalar("values", tf.reduce_mean(info.critic))
                tf.summary.scalar("returns", tf.reduce_mean(info.returns))
                safe_mean_hist_summary("td_error", info.returns - info.critic)
                tf.summary.scalar(
                    "explained_variance_of_return_by_value",
                    common.explained_variance(info.critic, info.returns))

        if self._debug_summaries:
            common.run_if(common.should_record_summaries(), _summary)

        return LossInfo(
            loss=info.actor_loss,
            # put critic_loss to scalar_loss because loss will be masked by
            # ~is_last at train_complete(). The critic_loss here should be
            # masked by ~is_first instead, which is done above.
            scalar_loss=tf.reduce_mean(critic_loss),
            extra=SarsaLossInfo(actor=info.actor_loss, critic=critic_loss))
Example No. 3
    def _actor_train_step(self, exp: Experience, state: DdpgActorState):
        action, actor_state = self._actor_network(exp.observation,
                                                  exp.step_type,
                                                  network_state=state.actor)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(action)
            q_value, critic_state = self._critic_network(
                (exp.observation, action), network_state=state.critic)

        dqda = tape.gradient(q_value, action)

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                        self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                tf.stop_gradient(dqda + action), action)
            loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
            return loss

        actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
        state = DdpgActorState(actor=actor_state, critic=critic_state)
        info = LossInfo(loss=tf.add_n(tf.nest.flatten(actor_loss)),
                        extra=actor_loss)
        return PolicyStep(action=action, state=state, info=info)
Example No. 4
    def train_step(self, exp: TimeStep, state):
        # [B, num_unroll_steps + 1]
        info = exp.rollout_info
        targets = common.as_list(info.target)
        batch_size = exp.step_type.shape[0]
        latent, state = self._encoding_net(exp.observation, state)

        sim_latent = self._multi_step_latent_rollout(latent,
                                                     self._num_unroll_steps,
                                                     info.action, state)

        loss = 0
        for i, decoder in enumerate(self._decoders):
            # [(num_unroll_steps + 1)*B, ...]
            train_info = decoder.train_step(sim_latent).info
            train_info_spec = dist_utils.extract_spec(train_info)
            train_info = dist_utils.distributions_to_params(train_info)
            train_info = alf.nest.map_structure(
                lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                    *x.shape[1:]), train_info)
            # [num_unroll_steps + 1, B, ...]
            train_info = dist_utils.params_to_distributions(
                train_info, train_info_spec)
            target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                            targets[i])
            loss_info = decoder.calc_loss(target, train_info, info.mask.t())
            loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0),
                                               loss_info)
            loss += loss_info.loss

        loss_info = LossInfo(loss=loss, extra=loss)

        return AlgStep(output=latent, state=state, info=loss_info)
Example No. 5
    def train_step(self, exp: Experience, state, trainable=True):
        """This function trains the discriminator or generates intrinsic rewards.

        If ``trainable=True``, then it only generates and returns the pred loss;
        otherwise it only generates rewards with no grad.
        """
        # Either train the discriminator from its own replay buffer, or
        # compute intrinsic rewards for training the lower-level RL (lower_rl).
        untrans_observation, prev_skill, switch_skill, steps = exp.observation

        observation = self._observation_transformer(untrans_observation)
        loss = self._predict_skill_loss(observation, exp.prev_action,
                                        prev_skill, steps, state)

        first_observation = self._update_state_if_necessary(
            switch_skill, observation, state.first_observation)
        subtrajectory = self._clear_subtrajectory_if_necessary(
            state.subtrajectory, switch_skill)
        new_state = DiscriminatorState(first_observation=first_observation,
                                       untrans_observation=untrans_observation,
                                       subtrajectory=subtrajectory)

        valid_masks = (exp.step_type != StepType.FIRST)
        if self._sparse_reward:
            # Only give intrinsic rewards at the last step of the skill
            valid_masks &= switch_skill
        loss *= valid_masks.to(torch.float32)

        if trainable:
            info = LossInfo(loss=loss, extra=dict(discriminator_loss=loss))
            return AlgStep(state=new_state, info=info)
        else:
            intrinsic_reward = -loss.detach() / self._skill_dim
            return AlgStep(state=common.detach(new_state),
                           info=intrinsic_reward)
Example No. 6
    def _calc_critic_loss(self, experience, train_info: SacInfo):
        critic_info = train_info.critic

        critic_losses = []
        for i, l in enumerate(self._critic_losses):
            critic_losses.append(
                l(experience=experience,
                  value=critic_info.critics[:, :, i, ...],
                  target_value=critic_info.target_critic).loss)

        critic_loss = math_ops.add_n(critic_losses)

        if (experience.batch_info != ()
                and experience.batch_info.importance_weights != ()):
            valid_masks = (experience.step_type != StepType.LAST).to(
                torch.float32)
            valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
            priority = ((critic_loss * valid_masks).sum(dim=0) /
                        valid_n).sqrt()
        else:
            priority = ()

        return LossInfo(loss=critic_loss,
                        priority=priority,
                        extra=critic_loss / float(self._num_critic_replicas))
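
The priority computation above turns the per-step critic loss into one scalar per
sample for the prioritized replay buffer: mask out LAST steps, average the loss over
the remaining valid steps, and take the square root. A minimal sketch with made-up
shapes (assuming PyTorch; the time-major [T, B] layout here is illustrative):

import torch

T, B = 4, 3
critic_loss = torch.rand(T, B)                      # per-step critic loss
is_last = torch.zeros(T, B, dtype=torch.bool)
is_last[-1] = True                                  # pretend the final step is LAST

valid_masks = (~is_last).to(torch.float32)
valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)  # avoid dividing by zero
priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
print(priority.shape)                               # torch.Size([3]): one per sample
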
Example No. 7
    def calc_loss(self, experience, train_info: DdpgInfo):

        critic_losses = [None] * self._num_critic_replicas
        for i in range(self._num_critic_replicas):
            critic_losses[i] = self._critic_losses[i](
                experience=experience,
                value=train_info.critic.q_values[:, :, i, ...],
                target_value=train_info.critic.target_q_values).loss

        critic_loss = math_ops.add_n(critic_losses)

        if (experience.batch_info != ()
                and experience.batch_info.importance_weights != ()):
            valid_masks = (experience.step_type != StepType.LAST).to(
                torch.float32)
            valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
            priority = ((critic_loss * valid_masks).sum(dim=0) /
                        valid_n).sqrt()
        else:
            priority = ()

        actor_loss = train_info.actor_loss

        return LossInfo(loss=critic_loss + actor_loss.loss,
                        priority=priority,
                        extra=DdpgLossInfo(critic=critic_loss,
                                           actor=actor_loss.extra))
Example No. 8
    def _minmax_grad(self,
                     inputs,
                     outputs,
                     loss_func,
                     entropy_regularization,
                     transform_func=None):
        """
        Compute particle gradients via minmax svgd (Fisher Neural Sampler). 
        """
        assert inputs is None, '"minmax" does not support conditional generator'

        # optimize the critic using resampled particles
        assert transform_func is None, (
            "function value based vi is not supported for minmax_grad")
        num_particles = outputs.shape[0]

        for i in range(self._critic_iter_num):

            if self._minmax_resample:
                critic_inputs, _ = self._predict(inputs,
                                                 batch_size=num_particles)
            else:
                critic_inputs = outputs.detach().clone()
                critic_inputs.requires_grad = True

            critic_loss = self._critic_train_step(critic_inputs, loss_func,
                                                  entropy_regularization)
            self._critic.update_with_gradient(LossInfo(loss=critic_loss))

        # compute amortized svgd
        loss = loss_func(outputs.detach())
        critic_outputs = self._critic.predict_step(outputs.detach()).output
        loss_propagated = torch.sum(-critic_outputs.detach() * outputs, dim=-1)

        return loss, loss_propagated
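
The ``loss_propagated`` term above is a linear surrogate for routing the critic's
output back into the generator particles: ``torch.sum(-c.detach() * outputs, dim=-1)``
has gradient ``-c`` with respect to ``outputs``, while the critic itself receives no
gradient from this term. A small sketch of that identity (illustrative tensors,
assuming PyTorch):

import torch

outputs = torch.randn(5, 2, requires_grad=True)
critic_outputs = torch.randn(5, 2)   # stand-in for self._critic.predict_step(...).output

loss_propagated = torch.sum(-critic_outputs.detach() * outputs, dim=-1)
loss_propagated.sum().backward()

# Each particle receives exactly the (detached) critic direction as its gradient.
assert torch.allclose(outputs.grad, -critic_outputs)
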
Example No. 9
    def train_step(self, time_step: TimeStep, state: DynamicsState):
        """
        Args:
            time_step (TimeStep): input data for dynamics learning
            state (DynamicsState): state for dynamics learning (previous observation)
        Returns:
            AlgStep:
                output: empty tuple ()
                state (DynamicsState): state for training
                info (DynamicsInfo):
        """
        feature = time_step.observation
        dynamics_step = self.predict_step(time_step, state)
        forward_pred = dynamics_step.output
        forward_loss = (feature - forward_pred)**2
        forward_loss = 0.5 * forward_loss.mean(
            list(range(1, forward_loss.ndim)))

        # we mask out FIRST as its state is invalid
        valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
        forward_loss = forward_loss * valid_masks

        info = DynamicsInfo(loss=LossInfo(
            loss=forward_loss, extra=dict(forward_loss=forward_loss)))

        state = state._replace(feature=feature)

        return AlgStep(output=(), state=state, info=info)
Example No. 10
    def _step(self, time_step: TimeStep, state, calc_rewards=True):
        """This step is for both `rollout_step` and `train_step`.

        Args:
            time_step (TimeStep): input time_step data for ICM
            state (Tensor): state for ICM (previous observation)
            calc_rewards (bool): whether to calculate rewards

        Returns:
            AlgStep:
                output: empty tuple ()
                state: observation
                info (ICMInfo):
        """
        feature = time_step.observation
        prev_action = time_step.prev_action.detach()

        # normalize observation for easier prediction
        if self._observation_normalizer is not None:
            feature = self._observation_normalizer.normalize(feature)

        if self._encoding_net is not None:
            feature, _ = self._encoding_net(feature)
        prev_feature = state

        forward_pred, _ = self._forward_net(
            inputs=[prev_feature.detach(),
                    self._encode_action(prev_action)])
        # nn.MSELoss doesn't support reducing along a dim
        forward_loss = 0.5 * torch.mean(
            math_ops.square(forward_pred - feature.detach()), dim=-1)

        action_pred, _ = self._inverse_net([prev_feature, feature])

        if self._action_spec.is_discrete:
            inverse_loss = torch.nn.CrossEntropyLoss(reduction='none')(
                input=action_pred, target=prev_action.to(torch.int64))
        else:
            # nn.MSELoss doesn't support reducing along a dim
            inverse_loss = 0.5 * torch.mean(
                math_ops.square(action_pred - prev_action), dim=-1)

        intrinsic_reward = ()
        if calc_rewards:
            intrinsic_reward = forward_loss.detach()
            intrinsic_reward = self._reward_normalizer.normalize(
                intrinsic_reward)

        return AlgStep(
            output=(),
            state=feature,
            info=ICMInfo(
                reward=intrinsic_reward,
                loss=LossInfo(
                    loss=forward_loss + inverse_loss,
                    extra=dict(
                        forward_loss=forward_loss,
                        inverse_loss=inverse_loss))))
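
Stripped of the ALF plumbing, the two ICM heads above reduce to a few lines: the
forward model's per-sample prediction error doubles as the curiosity reward, and for
discrete actions the inverse model is scored with a per-sample cross-entropy
(``reduction='none'``). A minimal sketch with stand-in network outputs (illustrative
names, assuming PyTorch):

import torch
import torch.nn.functional as F

B, D, A = 4, 8, 5                      # batch size, feature dim, number of actions
feature = torch.randn(B, D)
prev_action = torch.randint(0, A, (B,))

forward_pred = torch.randn(B, D)       # stand-in for the forward net's output
forward_loss = 0.5 * ((forward_pred - feature.detach()) ** 2).mean(dim=-1)     # [B]

action_logits = torch.randn(B, A)      # stand-in for the inverse net's output
inverse_loss = F.cross_entropy(action_logits, prev_action, reduction='none')   # [B]

intrinsic_reward = forward_loss.detach()   # novelty signal, normalized elsewhere
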
Example No. 11
    def calc_loss(self, experience, train_info: MerlinInfo):
        """Calculate loss."""
        self.summarize_reward("reward", experience.reward)
        mbp_loss_info = self._mbp.calc_loss(experience, train_info.mbp_info)
        mba_loss_info = self._mba.calc_loss(experience, train_info.mba_info)

        return LossInfo(loss=mbp_loss_info.loss + mba_loss_info.loss,
                        extra=MerlinLossInfo(mbp=mbp_loss_info.extra,
                                             mba=mba_loss_info.extra))
Example No. 12
 def calc_loss(self, training_info: TrainingInfo):
     critic_loss = self._calc_critic_loss(training_info)
     alpha_loss = training_info.info.alpha.loss
     actor_loss = training_info.info.actor.loss
     return LossInfo(loss=actor_loss.loss + critic_loss.loss +
                     alpha_loss.loss,
                     extra=SacLossInfo(actor=actor_loss.extra,
                                       critic=critic_loss.extra,
                                       alpha=alpha_loss.extra))
Example No. 13
    def train_step(self, time_step: TimeStep, state: DynamicsState):
        """
        Args:
            time_step (TimeStep): time step structure. The ``prev_action`` from
                time_step will be used for predicting feature of the next step.
                It should be a Tensor of the shape [B, ...] or [B, n, ...] when
                n > 1, where n denotes the number of dynamics network replicas.
                When the input tensor has the shape of [B, ...] and n > 1, it
                will be first expanded to [B, n, ...] to match the number of
                dynamics network replicas.
            state (DynamicsState): state for dynamics learning with the
                following fields:
                - feature (Tensor): features of the previous observation of the
                    shape [B, ...] or [B, n, ...] when n > 1. When
                    ``state.feature`` has the shape of [B, ...] and n > 1, it
                    will be first expanded to [B, n, ...] to match the number
                    of dynamics network replicas.
                    It is used for predicting the feature of the next step
                    together with ``time_step.prev_action``.
                - network: the input state of the dynamics network
        Returns:
            AlgStep:
                output: empty tuple ()
                state (DynamicsState): with the following fields
                    - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                        shape tensor representing the predicted feature of
                        the next step
                    - network: the updated state of the dynamics network
                info (DynamicsInfo): with the following fields being updated:
                    - loss (LossInfo):
                    - dist (td.Distribution): the predictive distribution which
                        can be used for further calculation or summarization.
        """

        feature = time_step.observation
        feature = self._expand_to_replica(feature, self._feature_spec)
        dynamics_step = self.predict_step(time_step, state)

        dist = dynamics_step.info.dist
        forward_loss = -dist.log_prob(feature - state.feature)

        if forward_loss.ndim > 2:
            # [B, n, ...] -> [B, ...]
            forward_loss = forward_loss.sum(1)
        if forward_loss.ndim > 1:
            forward_loss = forward_loss.mean(list(range(1, forward_loss.ndim)))

        valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)

        forward_loss = forward_loss * valid_masks

        info = DynamicsInfo(loss=LossInfo(
            loss=forward_loss, extra=dict(forward_loss=forward_loss)),
                            dist=dist)
        state = state._replace(feature=feature)

        return AlgStep(output=(), state=state, info=info)
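
The forward loss above is the negative log-likelihood of the observed feature change
under the predictive distribution returned by the dynamics network. A minimal sketch
of that loss, with a unit-variance Normal standing in for ``dynamics_step.info.dist``
(illustrative, assuming PyTorch):

import torch
import torch.distributions as td

B, D = 4, 8
prev_feature = torch.randn(B, D)        # corresponds to state.feature
feature = torch.randn(B, D)             # features of the current observation
pred_mean = torch.randn(B, D)           # stand-in for the dynamics net's prediction

# Predictive distribution over the feature *delta*, reduced over the last dim.
dist = td.Independent(td.Normal(pred_mean, torch.ones_like(pred_mean)), 1)
forward_loss = -dist.log_prob(feature - prev_feature)   # [B]
print(forward_loss.shape)               # torch.Size([4])
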
Example No. 14
    def __call__(self, training_info: TrainingInfo, value):
        """Cacluate actor critic loss

        The first dimension of all the tensors is time dimension and the second
        dimesion is the batch dimension.

        Args:
            training_info (TrainingInfo): training_info collected by
                OnPolicyDriver/OffPolicyAlgorithm. All tensors in training_info
                are time-major
            value (tf.Tensor): the time-major tensor for the value at each time
                step
        Returns:
            loss_info (LossInfo): with loss_info.extra being ActorCriticLossInfo
        """

        returns, advantages = self._calc_returns_and_advantages(
            training_info, value)

        def _summary():
            with self.name_scope:
                tf.summary.scalar("values", tf.reduce_mean(value))
                tf.summary.scalar("returns", tf.reduce_mean(returns))
                tf.summary.scalar("advantages/mean",
                                  tf.reduce_mean(advantages))
                tf.summary.histogram("advantages/value", advantages)
                tf.summary.scalar("explained_variance_of_return_by_value",
                                  common.explained_variance(value, returns))

        if self._debug_summaries:
            common.run_if(common.should_record_summaries(), _summary)

        if self._normalize_advantages:
            advantages = _normalize_advantages(advantages, axes=(0, 1))

        if self._advantage_clip:
            advantages = tf.clip_by_value(advantages, -self._advantage_clip,
                                          self._advantage_clip)

        pg_loss = self._pg_loss(training_info, tf.stop_gradient(advantages))

        td_loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)

        loss = pg_loss + self._td_loss_weight * td_loss

        entropy_loss = ()
        if self._entropy_regularization is not None:
            entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
                training_info.info.action_distribution, self._action_spec)
            entropy_loss = -entropy
            loss -= self._entropy_regularization * entropy_for_gradient

        return LossInfo(loss=loss,
                        extra=ActorCriticLossInfo(td_loss=td_loss,
                                                  pg_loss=pg_loss,
                                                  neg_entropy=entropy_loss))
Example No. 15
    def calc_loss(self, training_info: TrainingInfo):
        """Calculate loss."""
        self.add_reward_summary("reward", training_info.reward)
        mbp_loss_info = self._mbp.calc_loss(training_info.info.mbp_info)
        mba_loss_info = self._mba.calc_loss(
            training_info._replace(info=training_info.info.mba_info))

        return LossInfo(loss=mbp_loss_info.loss + mba_loss_info.loss,
                        extra=MerlinLossInfo(mbp=mbp_loss_info.extra,
                                             mba=mba_loss_info.extra))
Example No. 16
 def _update_loss(loss_info, training_info, name, algorithm):
     if algorithm is None:
         return loss_info
     new_loss_info = algorithm.calc_loss(
         getattr(training_info.info, name))
     return LossInfo(
         loss=add_ignore_empty(loss_info.loss, new_loss_info.loss),
         scalar_loss=add_ignore_empty(loss_info.scalar_loss,
                                      new_loss_info.scalar_loss),
         extra=loss_info.extra._replace(**{name: new_loss_info.extra}))
Example No. 17
    def calc_loss(self, target, predicted, mask=None):
        """Calculate the loss between ``target`` and ``predicted``.

        Args:
            target (Tensor): target to be predicted. Its shape is [T, B, ...]
            predicted (Tensor): predicted target. Its shape is [T, B, ...]
            mask (bool Tensor): indicating which target should be predicted.
                Its shape is [T, B].
        Returns:
            LossInfo
        """
        if self._target_normalizer:
            self._target_normalizer.update(target)
            target = self._target_normalizer.normalize(target)
            predicted = self._target_normalizer.normalize(predicted)

        loss = self._loss(predicted, target)
        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):

                def _summarize1(pred, tgt, loss, mask, suffix):
                    alf.summary.scalar(
                        "explained_variance" + suffix,
                        tensor_utils.explained_variance(pred, tgt, mask))
                    safe_mean_hist_summary('predict' + suffix, pred, mask)
                    safe_mean_hist_summary('target' + suffix, tgt, mask)
                    safe_mean_summary("loss" + suffix, loss, mask)

                def _summarize(pred, tgt, loss, mask, suffix):
                    _summarize1(pred[0], tgt[0], loss[0], mask[0],
                                suffix + "/current")
                    if pred.shape[0] > 1:
                        _summarize1(pred[1:], tgt[1:], loss[1:], mask[1:],
                                    suffix + "/future")

                if loss.ndim == 2:
                    _summarize(predicted, target, loss, mask, '')
                elif not self._summarize_each_dimension:
                    m = mask
                    if m is not None:
                        m = m.unsqueeze(-1).expand_as(predicted)
                    _summarize(predicted, target, loss, m, '')
                else:
                    for i in range(predicted.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(predicted[..., i], target[..., i],
                                   loss[..., i], mask, suffix)

        if loss.ndim == 3:
            loss = loss.mean(dim=2)

        if mask is not None:
            loss = loss * mask

        return LossInfo(loss=loss * self._loss_weight, extra=loss)
Example No. 18
    def calc_loss(self, training_info: TrainingInfo):
        critic_loss = self._critic_loss(
            training_info=training_info,
            value=training_info.info.critic.q_value,
            target_value=training_info.info.critic.target_q_value)

        actor_loss = training_info.info.actor_loss

        return LossInfo(loss=critic_loss.loss + actor_loss.loss,
                        extra=DdpgLossInfo(critic=critic_loss.extra,
                                           actor=actor_loss.extra))
Example No. 19
    def calc_loss(self, experience, train_info: MdqInfo):
        alpha_loss = train_info.alpha
        critic_loss, distill_loss = self._calc_critic_loss(
            experience, train_info)

        total_loss = critic_loss.loss + distill_loss + alpha_loss.loss.squeeze(
            -1)
        return LossInfo(loss=total_loss,
                        extra=MdqLossInfo(critic=critic_loss.extra,
                                          alpha=alpha_loss.extra,
                                          distill=distill_loss))
Example No. 20
    def forward(self, experience, train_info):
        """Cacluate actor critic loss. The first dimension of all the tensors is
        time dimension and the second dimesion is the batch dimension.

        Args:
            experience (nest): experience used for training. All tensors are
                time-major.
            train_info (nest): information collected for training. It is batched
                from each ``AlgStep.info`` returned by ``rollout_step()``
                (on-policy training) or ``train_step()`` (off-policy training).
                All tensors in ``train_info`` are time-major.
        Returns:
            LossInfo: with ``extra`` being ``ActorCriticLossInfo``.
        """

        value = train_info.value
        returns, advantages = self._calc_returns_and_advantages(
            experience, value)

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                alf.summary.scalar("values", value.mean())
                alf.summary.scalar("returns", returns.mean())
                alf.summary.scalar("advantages/mean", advantages.mean())
                alf.summary.histogram("advantages/value", advantages)
                alf.summary.scalar(
                    "explained_variance_of_return_by_value",
                    tensor_utils.explained_variance(value, returns))

        if self._normalize_advantages:
            advantages = _normalize_advantages(advantages)

        if self._advantage_clip:
            advantages = torch.clamp(advantages, -self._advantage_clip,
                                     self._advantage_clip)

        pg_loss = self._pg_loss(experience, train_info, advantages.detach())

        td_loss = self._td_error_loss_fn(returns.detach(), value)

        loss = pg_loss + self._td_loss_weight * td_loss

        entropy_loss = ()
        if self._entropy_regularization is not None:
            entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
                train_info.action_distribution)
            entropy_loss = -entropy
            loss -= self._entropy_regularization * entropy_for_gradient

        return LossInfo(loss=loss,
                        extra=ActorCriticLossInfo(td_loss=td_loss,
                                                  pg_loss=pg_loss,
                                                  neg_entropy=entropy_loss))
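
Stripped to its essentials, the loss assembled above combines three terms: a
policy-gradient term weighted by detached (optionally normalized and clipped)
advantages, a TD term that regresses the value toward detached returns, and an
entropy bonus. A compact sketch of that combination, with simple one-step advantages
standing in for ``_calc_returns_and_advantages`` and made-up coefficients (assuming
PyTorch):

import torch

T, B = 6, 3                             # time-major tensors, as in the docstring
log_prob = torch.randn(T, B)            # log pi(a_t | s_t)
value = torch.randn(T, B, requires_grad=True)
returns = torch.randn(T, B)
entropy = torch.rand(T, B)

advantages = returns - value.detach()   # stand-in for the GAE computation
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
advantages = torch.clamp(advantages, -2.0, 2.0)          # advantage clipping

pg_loss = -log_prob * advantages.detach()
td_loss = 0.5 * (returns.detach() - value) ** 2
loss = pg_loss + 0.5 * td_loss - 0.01 * entropy          # td weight, entropy coeff
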
Example No. 21
 def decode_step(self, latent_vector, observations):
     """Calculate decoding loss."""
     decoders = tf.nest.flatten(self._decoders)
     observations = tf.nest.flatten(observations)
     decoder_losses = [
         decoder.train_step((latent_vector, obs)).info
         for decoder, obs in zip(decoders, observations)
     ]
     loss = tf.add_n([decoder_loss.loss for decoder_loss in decoder_losses])
     decoder_losses = tf.nest.pack_sequence_as(self._decoders,
                                               decoder_losses)
     return LossInfo(loss=loss, extra=decoder_losses)
Example No. 22
    def calc_loss(self, experience, train_info: LossInfo):
        assert experience.batch_info != ()
        if (experience.batch_info != ()
                and experience.batch_info.importance_weights != ()):
            priority = (experience.rollout_info.value -
                        experience.rollout_info.target.value[..., 0])
            priority = priority.abs().sum(dim=1)
        else:
            priority = ()

        return train_info._replace(loss=(),
                                   scalar_loss=train_info.loss.mean(),
                                   priority=priority)
Example No. 23
 def __call__(self, training_info: TrainingInfo, value, target_value):
     returns = value_ops.one_step_discounted_return(
         rewards=training_info.reward,
         values=target_value,
         step_types=training_info.step_type,
         discounts=training_info.discount * self._gamma)
     returns = common.tensor_extend(returns, value[-1])
     if self._debug_summaries:
         with self.name_scope:
             tf.summary.scalar("values", tf.reduce_mean(value))
             tf.summary.scalar("returns", tf.reduce_mean(returns))
     loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)
     return LossInfo(loss=loss, extra=loss)
Example No. 24
 def _update_loss(loss_info, algorithm, name):
     info = getattr(train_info, name)
     exp = _make_alg_experience(experience, name)
     new_loss_info = algorithm.calc_loss(exp, info)
     if loss_info is None:
         return new_loss_info._replace(
             extra={name: new_loss_info.extra})
     else:
         loss_info.extra[name] = new_loss_info.extra
         return LossInfo(loss=add_ignore_empty(loss_info.loss,
                                               new_loss_info.loss),
                         scalar_loss=add_ignore_empty(
                             loss_info.scalar_loss,
                             new_loss_info.scalar_loss),
                         extra=loss_info.extra)
Example No. 25
    def _calc_critic_loss(self, training_info):
        critic_info = training_info.info.critic

        target_critic = critic_info.target_critic

        critic_loss1 = self._critic_loss(training_info=training_info,
                                         value=critic_info.critic1,
                                         target_value=target_critic)

        critic_loss2 = self._critic_loss(training_info=training_info,
                                         value=critic_info.critic2,
                                         target_value=target_critic)

        critic_loss = critic_loss1.loss + critic_loss2.loss
        return LossInfo(loss=critic_loss, extra=critic_loss)
Example No. 26
    def _calc_critic_loss(self, experience, train_info: MdqInfo):
        critic_info = train_info.critic

        # [t, B, n]
        critic_free_form = critic_info.critic_free_form
        # [t, B, n, action_dim]
        critic_adv_form = critic_info.critic_adv_form
        target_critic_free_form = critic_info.target_critic_free_form
        distill_target = critic_info.distill_target

        num_critic_replicas = critic_free_form.shape[2]

        alpha = torch.exp(self._log_alpha).detach()
        kl_wrt_prior = critic_info.kl_wrt_prior

        # [t, B, n, action_dim] -> [t, B]
        # Note that currently kl_wrt_prior is independent of the ensembles, so
        # we slice over the ensemble dimension by taking the first element;
        # for the action dimension, the first element is the full KL.
        kl_wrt_prior = kl_wrt_prior[..., 0, 0]

        # [t, B, n] -> [t, B]
        target_critic, min_target_ind = torch.min(
            target_critic_free_form, dim=2)

        # [t, B, n] -> [t, B]
        distill_target, _ = torch.min(distill_target, dim=2)

        target_critic_corrected = target_critic - alpha * kl_wrt_prior

        critic_losses = []
        for j in range(num_critic_replicas):
            critic_losses.append(self._critic_losses[j](
                experience=experience,
                value=critic_free_form[:, :, j],
                target_value=target_critic_corrected).loss)

        critic_loss = math_ops.add_n(critic_losses)

        distill_loss = (
            critic_adv_form[..., -1] - distill_target.unsqueeze(2).detach())**2
        # mean over the critic replicas
        distill_loss = distill_loss.mean(dim=2)

        return LossInfo(
            loss=critic_loss,
            extra=critic_loss / len(critic_losses)), distill_loss
Example No. 27
    def train_step(self, distribution, step_type):
        """Train step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
        Returns:
            AlgStep: ``info`` field is ``LossInfo``, other fields are empty.
        """
        entropy, entropy_for_gradient = entropy_with_fallback(distribution)
        return AlgStep(
            output=(),
            state=(),
            info=EntropyTargetInfo(loss=LossInfo(loss=-entropy_for_gradient,
                                                 extra=EntropyTargetLossInfo(
                                                     neg_entropy=-entropy))))
Example No. 28
    def train_step(self, inputs, state: MBPState):
        """Train one step.

        Args:
            inputs (tuple): a tuple of (observation, action)
        """
        observation, _ = inputs
        latent_vector, kld, next_state = self.encode_step(inputs, state)

        # TODO: decoder for action
        decoder_loss = self.decode_step(latent_vector, observation)

        return AlgorithmStep(
            outputs=latent_vector,
            state=next_state,
            info=LossInfo(loss=self._loss_weight * (decoder_loss.loss + kld),
                          extra=MBPLossInfo(decoder=decoder_loss, vae=kld)))
Example No. 29
    def train_step(self,
                   inputs,
                   loss_func,
                   outputs=None,
                   batch_size=None,
                   entropy_regularization=None,
                   state=None):
        """
        Args:
            inputs (nested Tensor): if None, the outputs are generated only from
                noise.
            outputs (Tensor): generator's output (possibly from previous runs)
                used for this train_step.
            loss_func (Callable): loss_func([outputs, inputs])
                (or loss_func(outputs) if inputs is None) returns a Tensor or a
                namedtuple of tensors with field `loss`, which is a Tensor of
                shape [batch_size]: a loss term for optimizing the generator.
            batch_size (int): batch size. Must be provided if inputs is None.
                It is ignored if inputs is not None.
            state: not used

        Returns:
            AlgStep:
                output: Tensor with shape (batch_size, dim)
                info: LossInfo
        """
        if outputs is None:
            outputs, gen_inputs = self._predict(inputs, batch_size=batch_size)
        if entropy_regularization is None:
            entropy_regularization = self._entropy_regularization
        loss, loss_propagated = self._grad_func(inputs, outputs, loss_func,
                                                entropy_regularization)
        mi_loss = ()
        if self._mi_estimator is not None:
            mi_step = self._mi_estimator.train_step([gen_inputs, outputs])
            mi_loss = mi_step.info.loss
            loss_propagated = loss_propagated + self._mi_weight * mi_loss

        return AlgStep(output=outputs,
                       state=(),
                       info=LossInfo(
                           loss=loss_propagated,
                           extra=GeneratorLossInfo(generator=loss,
                                                   mi_estimator=mi_loss)))
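
The docstring above fixes the contract for ``loss_func``: given generator outputs of
shape [batch_size, dim] (or ``[outputs, inputs]`` in the conditional case), it must
return a per-sample loss of shape [batch_size] (or a namedtuple with such a ``loss``
field). A minimal conforming example (the zero target is hypothetical, assuming
PyTorch):

import torch

def loss_func(outputs):
    """Per-sample loss for unconditional generation: score each sample against
    a fixed (here, zero) target."""
    target = torch.zeros_like(outputs)
    return 0.5 * ((outputs - target) ** 2).sum(dim=-1)   # shape [batch_size]

print(loss_func(torch.randn(7, 3)).shape)                # torch.Size([7])
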
Example No. 30
    def _step(self, time_step: TimeStep, state, calc_rewards=True):
        """
        Args:
            time_step (TimeStep): input time_step data
            state (tuple):  empty tuple ()
            calc_rewards (bool): whether to calculate rewards

        Returns:
            AlgStep:
                output: empty tuple ()
                state: empty tuple ()
                info: ICMInfo
        """
        observation = time_step.observation

        if self._keep_stacked_frames > 0:
            # Assuming stacking in the first dim, we only keep the last frames.
            observation = observation[:, -self._keep_stacked_frames:, ...]

        if self._observation_normalizer is not None:
            observation = self._observation_normalizer.normalize(observation)

        if self._encoder_net is not None:
            with torch.no_grad():
                observation, _ = self._encoder_net(observation)

        pred_embedding, _ = self._predictor_net(observation)
        with torch.no_grad():
            target_embedding, _ = self._target_net(observation)

        loss = torch.sum(math_ops.square(pred_embedding - target_embedding),
                         dim=-1)

        intrinsic_reward = ()
        if calc_rewards:
            intrinsic_reward = loss.detach()
            if self._reward_normalizer:
                intrinsic_reward = self._reward_normalizer.normalize(
                    intrinsic_reward, clip_value=self._reward_clip_value)

        return AlgStep(output=(),
                       state=(),
                       info=ICMInfo(reward=intrinsic_reward,
                                    loss=LossInfo(loss=loss)))
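
The pattern above is the RND-style novelty signal: a trainable predictor network is
regressed onto the features of a frozen, randomly initialized target network, and the
squared prediction error, which is large for rarely visited observations, is detached
and used as the intrinsic reward. A self-contained sketch (illustrative network
sizes, assuming PyTorch):

import torch
import torch.nn as nn

target_net = nn.Linear(16, 32)
predictor_net = nn.Linear(16, 32)
for p in target_net.parameters():
    p.requires_grad_(False)             # the target network is never trained

observation = torch.randn(8, 16)
with torch.no_grad():
    target_embedding = target_net(observation)
pred_embedding = predictor_net(observation)

loss = torch.sum((pred_embedding - target_embedding) ** 2, dim=-1)   # [B]
intrinsic_reward = loss.detach()        # novelty signal, normalized elsewhere
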