Example #3
class DIAYNAlgorithm(Algorithm):
    """Diversity is All You Need Module

    This module learns a set of skill-conditional policies in an unsupervised
    way. See Eysenbach et al "Diversity is All You Need: Learning Diverse Skills
    without a Reward Function" for more details.
    """

    def __init__(self,
                 skill_spec,
                 encoding_net: EncodingNetwork,
                 reward_adapt_speed=8.0,
                 observation_spec=None,
                 hidden_size=(),
                 hidden_activation=torch.relu_,
                 name="DIAYNAlgorithm"):
        """Create a DIAYNAlgorithm.

        Args:
            skill_spec (TensorSpec): supports both discrete and continuous skills.
                In the discrete case, the algorithm will predict 1-of-K skills
                using the cross entropy loss; in the continuous case, the
                algorithm will predict the skill vector itself using the mean
                square error loss.
            encoding_net (EncodingNetwork): network for encoding observation into
                a latent feature.
            reward_adapt_speed (float): how fast to adapt the reward normalizer.
                Roughly speaking, the statistics for the normalization are
                calculated mostly based on the most recent `T/speed` samples,
                where `T` is the total number of samples.
            observation_spec (TensorSpec): If not None, this spec is to be used
                by an observation normalizer to normalize incoming observations.
                In some cases, the normalized observation can be easier for
                training the discriminator.
            hidden_size (tuple[int]): a tuple of hidden layer sizes used by the
                discriminator.
            hidden_activation (torch.nn.functional): activation for the hidden
                layers.
            name (str): module's name
        """
        assert isinstance(skill_spec, TensorSpec)

        self._skill_spec = skill_spec
        if skill_spec.is_discrete:
            assert isinstance(skill_spec, BoundedTensorSpec)
            skill_dim = skill_spec.maximum - skill_spec.minimum + 1
        else:
            assert len(
                skill_spec.shape) == 1, "Only 1D skill vector is supported"
            skill_dim = skill_spec.shape[0]

        super().__init__(
            train_state_spec=TensorSpec((skill_dim, )),
            predict_state_spec=(),  # won't be needed for predict_step
            name=name)

        self._encoding_net = encoding_net

        self._discriminator_net = EncodingNetwork(
            input_tensor_spec=encoding_net.output_spec,
            fc_layer_params=hidden_size,
            activation=hidden_activation,
            last_layer_size=skill_dim,
            last_activation=math_ops.identity)

        self._reward_normalizer = ScalarAdaptiveNormalizer(
            speed=reward_adapt_speed)

        self._observation_normalizer = None
        if observation_spec is not None:
            self._observation_normalizer = AdaptiveNormalizer(
                tensor_spec=observation_spec)

    def _step(self, time_step: TimeStep, state, calc_rewards=True):
        """
        Args:
            time_step (TimeStep): input time step data, where the
                observation is the skill-augmented observation. The skill should be
                a one-hot vector.
            state (Tensor): state for DIAYN (previous skill) which should be
                a one-hot vector.
            calc_rewards (bool): if False, only return the losses.

        Returns:
            AlgStep:
                output: empty tuple ()
                state: skill
                info (DIAYNInfo):
        """
        observations_aug = time_step.observation
        step_type = time_step.step_type
        observation, skill = observations_aug
        prev_skill = state.detach()

        # normalize observation for easier prediction
        if self._observation_normalizer is not None:
            observation = self._observation_normalizer.normalize(observation)

        if self._encoding_net is not None:
            feature, _ = self._encoding_net(observation)

        skill_pred, _ = self._discriminator_net(feature)

        if self._skill_spec.is_discrete:
            loss = torch.nn.CrossEntropyLoss(reduction='none')(
                input=skill_pred, target=torch.argmax(prev_skill, dim=-1))
        else:
            # nn.MSELoss doesn't support reducing along a dim
            loss = torch.sum(math_ops.square(skill_pred - prev_skill), dim=-1)

        valid_masks = (step_type != to_tensor(StepType.FIRST)).to(
            torch.float32)
        loss *= valid_masks

        intrinsic_reward = ()
        if calc_rewards:
            intrinsic_reward = -loss.detach()
            intrinsic_reward = self._reward_normalizer.normalize(
                intrinsic_reward)

        return AlgStep(
            output=(),
            state=skill,
            info=DIAYNInfo(reward=intrinsic_reward, loss=loss))

    def rollout_step(self, time_step, state):
        return self._step(time_step, state)

    def train_step(self, time_step, state):
        return self._step(time_step, state, calc_rewards=False)

    def calc_loss(self, experience, info: DIAYNInfo):
        loss = torch.mean(info.loss)
        return LossInfo(
            scalar_loss=loss, extra=dict(skill_discriminate_loss=info.loss))
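
# A minimal plain-PyTorch sketch (no ALF dependencies) of the idea behind
# DIAYNAlgorithm._step above: a discriminator tries to recover the current
# skill from the observation feature, and the negative cross-entropy loss is
# used as the intrinsic reward (reward normalization and step masking are
# omitted here). The function and argument names are illustrative only.
import torch
import torch.nn.functional as F


def diayn_reward_sketch(feature, skill_onehot, discriminator):
    """feature: [B, D] latent features; skill_onehot: [B, K] one-hot skills."""
    logits = discriminator(feature)                              # [B, K]
    loss = F.cross_entropy(
        logits, skill_onehot.argmax(dim=-1), reduction='none')   # [B]
    return -loss.detach(), loss                                  # reward, loss
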
class SimpleDecoder(Algorithm):
    """A simple decoder with elementwise loss between the target and the predicted value.

    It is used to predict the target value from the given representation. Its
    loss can be used to train the representation.
    """
    def __init__(self,
                 input_tensor_spec,
                 target_field,
                 decoder_net_ctor,
                 loss_ctor=partial(torch.nn.SmoothL1Loss, reduction='none'),
                 loss_weight=1.0,
                 summarize_each_dimension=False,
                 optimizer=None,
                 normalize_target=False,
                 append_target_field_to_name=True,
                 debug_summaries=False,
                 name="SimpleDecoder"):
        """
        Args:
            input_tensor_spec (TensorSpec): describing the input tensor.
            target_field (str): name of the field in the experience to be used
                as the decoding target.
            decoder_net_ctor (Callable): called as ``decoder_net_ctor(input_tensor_spec=input_tensor_spec)``
                to construct an instance of ``Network`` for decoding. The network
                should take the latent representation as input and output the
                predicted value of the target.
            loss_ctor (Callable): called as ``loss_ctor()`` to construct the
                loss function, which should have the signature
                ``loss(y_pred, y_true)``. Note that it should not reduce to a
                scalar; it should at least keep the batch dimension in the
                returned loss.
            loss_weight (float): weight for the loss.
            optimizer (Optimizer|None): if provided, it will be used to optimize
                the parameters of ``decoder_net``.
            normalize_target (bool): whether to normalize the target.
                Note that this only changes the loss; the predicted value itself
                is not normalized.
            append_target_field_to_name (bool): whether to append ``target_field``
                to the name of the decoder. If True, the actual name used will be
                ``name.target_field``.
            debug_summaries (bool): whether to generate debug summaries
            name (str): name of this instance
        """
        if append_target_field_to_name:
            name = name + "." + target_field

        super().__init__(optimizer=optimizer,
                         debug_summaries=debug_summaries,
                         name=name)
        self._decoder_net = decoder_net_ctor(
            input_tensor_spec=input_tensor_spec)
        assert self._decoder_net.state_spec == (
        ), "RNN decoder is not supported"
        self._summarize_each_dimension = summarize_each_dimension
        self._target_field = target_field
        self._loss = loss_ctor()
        self._loss_weight = loss_weight
        if normalize_target:
            self._target_normalizer = AdaptiveNormalizer(
                self._decoder_net.output_spec,
                auto_update=False,
                name=name + ".target_normalizer")
        else:
            self._target_normalizer = None

    def get_target_fields(self):
        return self._target_field

    def train_step(self, repr, state=()):
        predicted_target = self._decoder_net(repr)[0]
        return AlgStep(output=predicted_target,
                       state=state,
                       info=predicted_target)

    def predict_step(self, repr, state=()):
        predicted_target = self._decoder_net(repr)[0]
        return AlgStep(output=predicted_target,
                       state=state,
                       info=predicted_target)

    def calc_loss(self, target, predicted, mask=None):
        """Calculate the loss between ``target`` and ``predicted``.

        Args:
            target (Tensor): target to be predicted. Its shape is [T, B, ...]
            predicted (Tensor): predicted target. Its shape is [T, B, ...]
            mask (bool Tensor): indicating which target should be predicted.
                Its shape is [T, B].
        Returns:
            LossInfo
        """
        if self._target_normalizer:
            self._target_normalizer.update(target)
            target = self._target_normalizer.normalize(target)
            predicted = self._target_normalizer.normalize(predicted)

        loss = self._loss(predicted, target)
        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):

                def _summarize1(pred, tgt, loss, mask, suffix):
                    alf.summary.scalar(
                        "explained_variance" + suffix,
                        tensor_utils.explained_variance(pred, tgt, mask))
                    safe_mean_hist_summary('predict' + suffix, pred, mask)
                    safe_mean_hist_summary('target' + suffix, tgt, mask)
                    safe_mean_summary("loss" + suffix, loss, mask)

                def _summarize(pred, tgt, loss, mask, suffix):
                    _summarize1(pred[0], tgt[0], loss[0], mask[0],
                                suffix + "/current")
                    if pred.shape[0] > 1:
                        _summarize1(pred[1:], tgt[1:], loss[1:], mask[1:],
                                    suffix + "/future")

                if loss.ndim == 2:
                    _summarize(predicted, target, loss, mask, '')
                elif not self._summarize_each_dimension:
                    m = mask
                    if m is not None:
                        m = m.unsqueeze(-1).expand_as(predicted)
                    _summarize(predicted, target, loss, m, '')
                else:
                    for i in range(predicted.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(predicted[..., i], target[..., i],
                                   loss[..., i], mask, suffix)

        if loss.ndim == 3:
            loss = loss.mean(dim=2)

        if mask is not None:
            loss = loss * mask

        return LossInfo(loss=loss * self._loss_weight, extra=loss)
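
# A plain-PyTorch sketch of what SimpleDecoder.calc_loss above computes before
# the summaries: an elementwise SmoothL1 loss between prediction and target,
# averaged over the feature dimension and masked per [T, B] step. Names are
# illustrative; target normalization and loss_weight are omitted.
import torch


def masked_decoder_loss_sketch(predicted, target, mask=None):
    loss = torch.nn.SmoothL1Loss(reduction='none')(predicted, target)
    if loss.ndim == 3:
        loss = loss.mean(dim=2)   # average over the feature dimension -> [T, B]
    if mask is not None:
        loss = loss * mask        # only count the steps selected by ``mask``
    return loss
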
Example #5
class RNDAlgorithm(Algorithm):
    """Exploration by Random Network Distillation, Burda et al. 2019.

    This module generates the intrinsic reward based on the prediction errors of
    randomly generated state embeddings.

    Suppose we have a fixed randomly initialized target network g: s -> e_t and
    a trainable predictor network h: s -> e_p, then the intrinsic reward is

    r = |e_t - e_p|^2

    The reward is expected to be higher for novel states.
    """
    def __init__(self,
                 target_net: EncodingNetwork,
                 predictor_net: EncodingNetwork,
                 encoder_net: EncodingNetwork = None,
                 reward_adapt_speed=None,
                 observation_adapt_speed=None,
                 observation_spec=None,
                 optimizer=None,
                 clip_value=-1.0,
                 keep_stacked_frames=1,
                 name="RNDAlgorithm"):
        """
        Args:
            encoder_net (EncodingNetwork): a shared network that encodes
                observation to embeddings before being input to `target_net` or
                `predictor_net`; its parameters are not trainable.
            target_net (EncodingNetwork): the random fixed network that generates
                target state embeddings to be fitted.
            predictor_net (EncodingNetwork): the trainable network that predicts
                target embeddings. If fully trained with enough data, the outputs
                of `predictor_net` will eventually match those of `target_net`.
            reward_adapt_speed (float): speed for adaptively normalizing intrinsic
                rewards; if None, no normalizer is used.
            observation_adapt_speed (float): speed for adaptively normalizing
                observations. Only useful if `observation_spec` is not None.
            observation_spec (TensorSpec): the observation tensor spec; used
                for creating an adaptive observation normalizer.
            optimizer (torch.optim.Optimizer): The optimizer for training
            clip_value (float): if positive, the rewards will be clipped to
                [-clip_value, clip_value]; only used for reward normalization.
            keep_stacked_frames (int): a non-negative integer indicating how many
                stacked frames we want to keep as the observation. If >0, only
                the most recent ``keep_stacked_frames`` frames are kept for RND
                to make predictions on, as suggested by the original paper
                (Burda et al. 2019). For Atari games, this argument is usually 1
                (with `frame_stacking==4`). If it is 0, the observation is
                unchanged. For other games, the user is responsible for setting
                this value correctly, depending on how many channels an
                observation has at each time step.
            name (str):
        """
        super(RNDAlgorithm, self).__init__(train_state_spec=(),
                                           optimizer=optimizer,
                                           name=name)
        self._encoder_net = encoder_net
        self._target_net = target_net  # fixed
        self._predictor_net = predictor_net  # trainable
        if reward_adapt_speed is not None:
            self._reward_normalizer = ScalarAdaptiveNormalizer(
                speed=reward_adapt_speed)
            self._reward_clip_value = clip_value
        else:
            self._reward_normalizer = None

        self._keep_stacked_frames = keep_stacked_frames
        if keep_stacked_frames > 0 and (observation_spec is not None):
            # Assuming stacking in the first dim, we only keep the last frames.
            shape = observation_spec.shape
            assert keep_stacked_frames <= shape[0]
            new_shape = (keep_stacked_frames, ) + tuple(shape[1:])
            observation_spec = TensorSpec(shape=new_shape,
                                          dtype=observation_spec.dtype)

        # The paper suggests also normalizing observations, because the
        # original observation subspace might be small and the target network
        # would yield random embeddings that are indistinguishable.
        self._observation_normalizer = None
        if observation_adapt_speed is not None:
            assert observation_spec is not None, \
                "Observation normalizer requires its input tensor spec!"
            self._observation_normalizer = AdaptiveNormalizer(
                tensor_spec=observation_spec, speed=observation_adapt_speed)

    def _step(self, time_step: TimeStep, state, calc_rewards=True):
        """
        Args:
            time_step (TimeStep): input time_step data
            state (tuple):  empty tuple ()
            calc_rewards (bool): whether to calculate rewards

        Returns:
            AlgStep:
                output: empty tuple ()
                state: empty tuple ()
                info: ICMInfo
        """
        observation = time_step.observation

        if self._keep_stacked_frames > 0:
            # Assuming stacking in the first dim, we only keep the last frames.
            observation = observation[:, -self._keep_stacked_frames:, ...]

        if self._observation_normalizer is not None:
            observation = self._observation_normalizer.normalize(observation)

        if self._encoder_net is not None:
            with torch.no_grad():
                observation, _ = self._encoder_net(observation)

        pred_embedding, _ = self._predictor_net(observation)
        with torch.no_grad():
            target_embedding, _ = self._target_net(observation)

        loss = torch.sum(math_ops.square(pred_embedding - target_embedding),
                         dim=-1)

        intrinsic_reward = ()
        if calc_rewards:
            intrinsic_reward = loss.detach()
            if self._reward_normalizer:
                intrinsic_reward = self._reward_normalizer.normalize(
                    intrinsic_reward, clip_value=self._reward_clip_value)

        return AlgStep(output=(),
                       state=(),
                       info=ICMInfo(reward=intrinsic_reward,
                                    loss=LossInfo(loss=loss)))

    def rollout_step(self, time_step: TimeStep, state):
        return self._step(time_step, state)

    def train_step(self, time_step: TimeStep, state):
        return self._step(time_step, state, calc_rewards=False)

    def calc_loss(self, experience, info: ICMInfo):
        return LossInfo(scalar_loss=torch.mean(info.loss.loss))
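
# A minimal sketch of the RND intrinsic reward r = |e_t - e_p|^2 described in
# the class docstring, using plain PyTorch. ``target`` stands for the fixed
# random network g and ``predictor`` for the trainable network h; reward
# normalization and clipping are omitted. Names are illustrative only.
import torch


def rnd_reward_sketch(observation, target, predictor):
    with torch.no_grad():
        target_embedding = target(observation)      # fixed random embedding
    pred_embedding = predictor(observation)         # trainable prediction
    loss = torch.sum((pred_embedding - target_embedding) ** 2, dim=-1)  # [B]
    return loss.detach(), loss                      # intrinsic reward, loss
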
Example #7
class TDLoss(nn.Module):
    def __init__(self,
                 gamma=0.99,
                 td_error_loss_fn=element_wise_squared_loss,
                 td_lambda=0.95,
                 normalize_target=False,
                 debug_summaries=False,
                 name="TDLoss"):
        r"""Create a TDLoss object.

        Let :math:`G_{t:T}` be the bootstrapped return from t to T:
            :math:`G_{t:T} = \sum_{i=t+1}^T \gamma^{i-t-1}R_i + \gamma^{T-t} V(s_T)`
        If ``td_lambda`` = 1, the target for step t is :math:`G_{t:T}`.
        If ``td_lambda`` = 0, the target for step t is :math:`G_{t:t+1}`.
        If 0 < ``td_lambda`` < 1, the target for step t is the :math:`\lambda`-return:
            :math:`G_t^\lambda = (1 - \lambda) \sum_{i=t+1}^{T-1} \lambda^{i-t-1}G_{t:i} + \lambda^{T-t-1} G_{t:T}`
        There is a simple relationship between :math:`\lambda`-return and
        the generalized advantage estimation :math:`\hat{A}^{GAE}_t`:
            :math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
        where the generalized advantage estimation is defined as:
            :math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

        References:

        Schulman et al. `High-Dimensional Continuous Control Using Generalized Advantage Estimation
        <https://arxiv.org/abs/1506.02438>`_

        Sutton et al. `Reinforcement Learning: An Introduction
        <http://incompleteideas.net/book/the-book.html>`_, Chapter 12, 2018

        Args:
            gamma (float): A discount factor for future rewards.
            td_error_loss_fn (Callable): A function for computing the TD errors
                loss. This function takes as input the target and the estimated
                Q values and returns the loss for each element of the batch.
            td_lambda (float): Lambda parameter for TD-lambda computation.
            normalize_target (bool): whether to normalize the target.
                Note that the effect of this is to change the loss. The critic
                value itself is not normalized.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this loss.
        """
        super().__init__()

        self._name = name
        self._gamma = gamma
        self._td_error_loss_fn = td_error_loss_fn
        self._lambda = td_lambda
        self._debug_summaries = debug_summaries
        self._normalize_target = normalize_target
        self._target_normalizer = None

    def forward(self, experience, value, target_value):
        """Cacluate the loss.

        The first dimension of all the tensors is time dimension and the second
        dimesion is the batch dimension.

        Args:
            experience (Experience): experience collected from ``unroll()`` or
                a replay buffer. All tensors are time-major.
            value (torch.Tensor): the time-major tensor for the value at each time
                step. The loss is between this and the calculated return.
            target_value (torch.Tensor): the time-major tensor for the value at
                each time step. This is used to calculate the return. ``target_value``
                can be the same as ``value``.
        Returns:
            LossInfo: with the ``extra`` field same as ``loss``.
        """
        if self._lambda == 1.0:
            returns = value_ops.discounted_return(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma)
        elif self._lambda == 0.0:
            returns = value_ops.one_step_discounted_return(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma)
        else:
            advantages = value_ops.generalized_advantage_estimation(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma,
                td_lambda=self._lambda)
            returns = advantages + target_value[:-1]

        value = value[:-1]
        if self._normalize_target:
            if self._target_normalizer is None:
                self._target_normalizer = AdaptiveNormalizer(
                    alf.TensorSpec(value.shape[2:]),
                    auto_update=False,
                    debug_summaries=self._debug_summaries,
                    name=self._name + ".target_normalizer")

            self._target_normalizer.update(returns)
            returns = self._target_normalizer.normalize(returns)
            value = self._target_normalizer.normalize(value)

        if self._debug_summaries and alf.summary.should_record_summaries():
            mask = experience.step_type[:-1] != StepType.LAST
            with alf.summary.scope(self._name):

                def _summarize(v, r, td, suffix):
                    alf.summary.scalar(
                        "explained_variance_of_return_by_value" + suffix,
                        tensor_utils.explained_variance(v, r, mask))
                    safe_mean_hist_summary('values' + suffix, v, mask)
                    safe_mean_hist_summary('returns' + suffix, r, mask)
                    safe_mean_hist_summary("td_error" + suffix, td, mask)

                if value.ndim == 2:
                    _summarize(value, returns, returns - value, '')
                else:
                    td = returns - value
                    for i in range(value.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(value[..., i], returns[..., i], td[..., i],
                                   suffix)

        loss = self._td_error_loss_fn(returns.detach(), value)

        if loss.ndim == 3:
            # Multidimensional reward. Average over the critic loss for all dimensions
            loss = loss.mean(dim=2)

        # The shape of the loss expected by Algorithm.update_with_gradient is
        # [T, B], so we need to augment it with additional zeros.
        loss = tensor_utils.tensor_extend_zero(loss)
        return LossInfo(loss=loss, extra=loss)
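
# A small self-contained sketch of the lambda-return target described in
# TDLoss's docstring, using plain PyTorch instead of value_ops. All tensors
# are time-major [T+1, B]; rewards[t+1] and discounts[t+1] refer to the
# transition from step t to t+1, and episode-boundary handling via step_type
# is omitted for brevity. This mirrors the documented math, not ALF's exact
# implementation.
import torch


def lambda_return_sketch(rewards, values, discounts, td_lambda):
    T = rewards.shape[0] - 1
    advantages = torch.zeros_like(values[:-1])
    gae = torch.zeros_like(values[0])
    for t in reversed(range(T)):
        delta = rewards[t + 1] + discounts[t + 1] * values[t + 1] - values[t]
        gae = delta + discounts[t + 1] * td_lambda * gae
        advantages[t] = gae
    # Relation from the docstring: G_t^lambda = A_t^GAE + V(s_t)
    return advantages + values[:-1]
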
Example #10
class ICMAlgorithm(Algorithm):
    """Intrinsic Curiosity Module

    This module generates the intrinsic reward based on the prediction error
    of the observation.

    See Pathak et al "Curiosity-driven Exploration by Self-supervised Prediction"
    """

    def __init__(self,
                 action_spec,
                 observation_spec=None,
                 hidden_size=256,
                 reward_adapt_speed=8.0,
                 encoding_net: EncodingNetwork = None,
                 forward_net: EncodingNetwork = None,
                 inverse_net: EncodingNetwork = None,
                 activation=torch.relu_,
                 optimizer=None,
                 name="ICMAlgorithm"):
        """Create an ICMAlgorithm.

        Args:
            action_spec (nested TensorSpec): agent's action spec
            observation_spec (nested TensorSpec): agent's observation spec. If
                not None, then a normalizer will be used to normalize the
                observation.
            hidden_size (int or tuple[int]): size of hidden layer(s)
            reward_adapt_speed (float): how fast to adapt the reward normalizer.
                Roughly speaking, the statistics for the normalization are
                calculated mostly based on the most recent T/speed samples,
                where T is the total number of samples.
            encoding_net (Network): network for encoding observation into a
                latent feature. Its input is the same as the input of this
                algorithm.
            forward_net (Network): network for predicting the next feature based
                on the previous feature and action. It should accept input with
                spec [feature_spec, encoded_action_spec] and output a tensor with
                the shape of feature_spec. For a discrete action, encoded_action
                is a one-hot representation of the action. For a continuous
                action, the encoded action is the same as the original action.
            inverse_net (Network): network for predicting previous action given
                the previous feature and current feature. It should accept input
                with spec [feature_spec, feature_spec] and output a tensor of
                shape (num_actions,).
            activation (torch.nn.functional): activation used for constructing
                the forward net and the inverse net when they are not provided.
            optimizer (torch.optim.Optimizer): The optimizer for training
            name (str):
        """
        if encoding_net is not None:
            feature_spec = encoding_net.output_spec
        else:
            feature_spec = observation_spec

        super(ICMAlgorithm, self).__init__(
            train_state_spec=feature_spec,
            predict_state_spec=(),
            optimizer=optimizer,
            name=name)

        flat_action_spec = alf.nest.flatten(action_spec)
        assert len(
            flat_action_spec) == 1, "ICM doesn't support nested action_spec"

        flat_feature_spec = alf.nest.flatten(feature_spec)
        assert len(
            flat_feature_spec) == 1, "ICM doesn't support nested feature_spec"

        action_spec = flat_action_spec[0]

        if action_spec.is_discrete:
            self._num_actions = int(action_spec.maximum - action_spec.minimum +
                                    1)
        else:
            self._num_actions = action_spec.shape[-1]

        self._action_spec = action_spec
        self._observation_normalizer = None
        if observation_spec is not None:
            self._observation_normalizer = AdaptiveNormalizer(
                tensor_spec=observation_spec)

        feature_dim = flat_feature_spec[0].shape[-1]

        self._encoding_net = encoding_net

        if isinstance(hidden_size, int):
            hidden_size = (hidden_size, )

        if forward_net is None:
            encoded_action_spec = TensorSpec((self._num_actions, ),
                                             dtype=torch.float32)
            forward_net = EncodingNetwork(
                name="forward_net",
                input_tensor_spec=[feature_spec, encoded_action_spec],
                preprocessing_combiner=NestConcat(),
                fc_layer_params=hidden_size,
                activation=activation,
                last_layer_size=feature_dim,
                last_activation=math_ops.identity)

        self._forward_net = forward_net

        if inverse_net is None:
            inverse_net = EncodingNetwork(
                name="inverse_net",
                input_tensor_spec=[feature_spec, feature_spec],
                preprocessing_combiner=NestConcat(),
                fc_layer_params=hidden_size,
                activation=activation,
                last_layer_size=self._num_actions,
                last_activation=math_ops.identity,
                last_kernel_initializer=torch.nn.init.zeros_)

        self._inverse_net = inverse_net

        self._reward_normalizer = ScalarAdaptiveNormalizer(
            speed=reward_adapt_speed)

    def _encode_action(self, action):
        if self._action_spec.is_discrete:
            return torch.nn.functional.one_hot(action, self._num_actions).to(
                torch.float32)
        else:
            return action

    def _step(self, time_step: TimeStep, state, calc_rewards=True):
        """This step is for both `rollout_step` and `train_step`.

        Args:
            time_step (TimeStep): input time_step data for ICM
            state (Tensor): state for ICM (previous observation)
            calc_rewards (bool): whether to calculate rewards

        Returns:
            AlgStep:
                output: empty tuple ()
                state: observation
                info (ICMInfo):
        """
        feature = time_step.observation
        prev_action = time_step.prev_action.detach()

        # normalize observation for easier prediction
        if self._observation_normalizer is not None:
            feature = self._observation_normalizer.normalize(feature)

        if self._encoding_net is not None:
            feature, _ = self._encoding_net(feature)
        prev_feature = state

        forward_pred, _ = self._forward_net(
            inputs=[prev_feature.detach(),
                    self._encode_action(prev_action)])
        # nn.MSELoss doesn't support reducing along a dim
        forward_loss = 0.5 * torch.mean(
            math_ops.square(forward_pred - feature.detach()), dim=-1)

        action_pred, _ = self._inverse_net([prev_feature, feature])

        if self._action_spec.is_discrete:
            inverse_loss = torch.nn.CrossEntropyLoss(reduction='none')(
                input=action_pred, target=prev_action.to(torch.int64))
        else:
            # nn.MSELoss doesn't support reducing along a dim
            inverse_loss = 0.5 * torch.mean(
                math_ops.square(action_pred - prev_action), dim=-1)

        intrinsic_reward = ()
        if calc_rewards:
            intrinsic_reward = forward_loss.detach()
            intrinsic_reward = self._reward_normalizer.normalize(
                intrinsic_reward)

        return AlgStep(
            output=(),
            state=feature,
            info=ICMInfo(
                reward=intrinsic_reward,
                loss=LossInfo(
                    loss=forward_loss + inverse_loss,
                    extra=dict(
                        forward_loss=forward_loss,
                        inverse_loss=inverse_loss))))

    def rollout_step(self, time_step: TimeStep, state):
        return self._step(time_step, state)

    def train_step(self, time_step: TimeStep, state):
        return self._step(time_step, state, calc_rewards=False)

    def calc_loss(self, experience, info: ICMInfo):
        loss = alf.nest.map_structure(torch.mean, info.loss)
        return LossInfo(scalar_loss=loss.loss, extra=loss.extra)
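
# A plain-PyTorch sketch of the two ICM losses computed in _step above, for the
# discrete-action case: the forward model predicts the next feature from the
# previous feature and the (one-hot) action, and its error is the intrinsic
# reward; the inverse model predicts the action from the feature pair. The
# network arguments are illustrative stand-ins; reward normalization is omitted.
import torch
import torch.nn.functional as F


def icm_losses_sketch(prev_feature, feature, action_onehot,
                      forward_net, inverse_net):
    pred_next = forward_net(
        torch.cat([prev_feature.detach(), action_onehot], dim=-1))
    forward_loss = 0.5 * torch.mean(
        (pred_next - feature.detach()) ** 2, dim=-1)                     # [B]
    action_logits = inverse_net(torch.cat([prev_feature, feature], dim=-1))
    inverse_loss = F.cross_entropy(
        action_logits, action_onehot.argmax(dim=-1), reduction='none')   # [B]
    return forward_loss.detach(), forward_loss + inverse_loss  # reward, loss
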
Example #11
    def __init__(self,
                 observation_spec,
                 fields=None,
                 clipping=0.,
                 window_size=10000,
                 update_rate=1e-4,
                 speed=8.0,
                 zero_mean=True,
                 update_mode="replay",
                 mode="adaptive"):
        """Create an observation normalizer with optional value clipping to be
        used as the ``data_transformer`` of an algorithm. It will be called
        before both ``rollout_step()`` and ``train_step()``.

        The normalizer by default doesn't automatically update the mean and std.
        Instead, when ``self.forward()`` is called, it will check whether the
        algorithm is unrolling or training, and only update the mean and std
        during unroll. This is the suggested way of using an observation
        normalizer (i.e., update the stats when encountering new data for the
        first time). This same strategy has been used by OpenAI's baselines for
        training their Robotics environments.

        Args:
            observation_spec (nested TensorSpec): describing the observation in the timestep
            fields (None|list[str]): If None, normalize all fields. Otherwise,
                only normalize the specified fields. Each string in ``fields``
                is a multi-step path denoted by "A.B.C".
            clipping (float): a floating value for clipping the normalized
                observation into ``[-clipping, clipping]``. Only valid if it's
                greater than 0.
            window_size (int): the window size of ``WindowNormalizer``.
            update_rate (float): the update rate of ``EMNormalizer``.
            speed (float): the speed of updating for ``AdaptiveNormalizer``.
            zero_mean (bool): whether to make the normalized value be zero-mean
            update_mode (str): update stats during either "replay" or "rollout".
            mode (str): a value in ["adaptive", "window", "em"] indicates which
                normalizer to use.
        """
        super().__init__(observation_spec)
        self._update_mode = update_mode
        self._clipping = float(clipping)
        self._fields = fields
        if fields is not None:
            observation_spec = dict([(field,
                                      alf.nest.get_field(
                                          observation_spec, field))
                                     for field in fields])
        if mode == "adaptive":
            self._normalizer = AdaptiveNormalizer(
                tensor_spec=observation_spec,
                speed=float(speed),
                auto_update=False,
                zero_mean=zero_mean,
                name="observations/adaptive_normalizer")
        elif mode == "window":
            self._normalizer = WindowNormalizer(
                tensor_spec=observation_spec,
                window_size=int(window_size),
                zero_mean=zero_mean,
                auto_update=False)
        elif mode == "em":
            self._normalizer = EMNormalizer(
                tensor_spec=observation_spec,
                update_rate=float(update_rate),
                zero_mean=zero_mean,
                auto_update=False)
        else:
            raise ValueError("Unsupported mode: " + mode)
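
# A plain-Python sketch of the "adaptive" update behavior referenced above and
# by the ``reward_adapt_speed``/``speed`` arguments in this file: the update
# rate decays like speed / total_steps, so the running statistics are dominated
# by roughly the most recent total_steps / speed samples. Only the mean is
# tracked here; the real normalizers also track second moments. This mirrors
# the documented behavior, not ALF's exact implementation.
class AdaptiveMeanSketch:
    def __init__(self, speed=8.0):
        self._speed = speed
        self._total = 0.0
        self._mean = 0.0

    def update(self, x):
        self._total += 1.0
        rate = min(self._speed / self._total, 1.0)
        self._mean += rate * (x - self._mean)
        return self._mean
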
Example #13
class RNDAlgorithm(Algorithm):
    """Exploration by Random Network Distillation, Burda et al. 2019.

    This module generates the intrinsic reward based on the prediction errors of
    randomly generated state embeddings.

    Suppose we have a fixed randomly initialized target network g: s -> e_t and
    a trainable predictor network h: s -> e_p, then the intrinsic reward is

    r = |e_t - e_p|^2

    The reward is expected to be higher for novel states.
    """

    def __init__(self,
                 target_net: Network,
                 predictor_net: Network,
                 encoder_net: Network = None,
                 reward_adapt_speed=None,
                 observation_adapt_speed=None,
                 observation_spec=None,
                 learning_rate=None,
                 clip_value=-1.0,
                 stacked_frames=True,
                 name="RNDAlgorithm"):
        """
        Args:
            encoder_net (Network): a shared network that encodes observation to
                embeddings before being input to `target_net` or `predictor_net`;
                its parameters are not trainable
            target_net (Network): the random fixed network that generates target
                state embeddings to be fitted
            predictor_net (Network): the trainable network that predicts target
                embeddings. If fully trained with enough data, the outputs of
                predictor_net will eventually match those of target_net.
            reward_adapt_speed (float): speed for adaptively normalizing intrinsic
                rewards; if None, no normalizer is used
            observation_adapt_speed (float): speed for adaptively normalizing
                observations. Only useful if `observation_spec` is not None.
            observation_spec (TensorSpec): the observation tensor spec; used
                for creating an adaptive observation normalizer
            learning_rate (float): the learning rate for prediction cost; if None,
                a global learning rate will be used
            clip_value (float): if positive, the rewards will be clipped to
                [-clip_value, clip_value]; only used for reward normalization
            stacked_frames (bool): a boolean flag indicating whether the input
                observation has stacked frames. If True, then we only keep the
                last frame for RND to make predictions on, as suggested by the
                original paper Burda et al. 2019. For Atari games, this flag is
                usually True (`frame_stacking==4`).
            name (str):
        """
        optimizer = None
        if learning_rate is not None:
            optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
        super(RNDAlgorithm, self).__init__(
            train_state_spec=(), optimizer=optimizer, name=name)
        self._encoder_net = encoder_net
        self._target_net = target_net  # fixed
        self._predictor_net = predictor_net  # trainable
        if reward_adapt_speed is not None:
            self._reward_normalizer = ScalarAdaptiveNormalizer(
                speed=reward_adapt_speed)
            self._reward_clip_value = clip_value
        else:
            self._reward_normalizer = None

        self._stacked_frames = stacked_frames
        if stacked_frames and (observation_spec is not None):
            # Assuming stacking in the last dim, we only keep the last frame.
            shape = observation_spec.shape
            new_shape = shape[:-1] + (1, )
            observation_spec = tf.TensorSpec(
                shape=new_shape, dtype=observation_spec.dtype)

        # The paper suggests also normalizing observations, because the
        # original observation subspace might be small and the target network
        # would yield random embeddings that are indistinguishable.
        self._observation_normalizer = None
        if observation_adapt_speed is not None:
            assert observation_spec is not None, \
                "Observation normalizer requires its input tensor spec!"
            self._observation_normalizer = AdaptiveNormalizer(
                tensor_spec=observation_spec, speed=observation_adapt_speed)

    def train_step(self,
                   time_step: ActionTimeStep,
                   state,
                   calc_intrinsic_reward=True):
        """
        Args:
            time_step (ActionTimeStep): input time_step data
            state (tuple):  empty tuple ()
            calc_intrinsic_reward (bool): if False, only return the losses
        Returns:
            TrainStep:
                outputs: empty tuple ()
                state: empty tuple ()
                info: ICMInfo
        """
        observation = time_step.observation

        if self._stacked_frames:
            # Assuming stacking in the last dim, we only keep the last frame.
            observation = observation[..., -1:]

        if self._observation_normalizer is not None:
            observation = self._observation_normalizer.normalize(observation)

        if self._encoder_net is not None:
            observation = tf.stop_gradient(self._encoder_net(observation)[0])

        pred_embedding, _ = self._predictor_net(observation)
        target_embedding, _ = self._target_net(observation)

        loss = tf.reduce_sum(
            tf.square(pred_embedding - tf.stop_gradient(target_embedding)),
            axis=-1)

        intrinsic_reward = ()
        if calc_intrinsic_reward:
            intrinsic_reward = tf.stop_gradient(loss)
            if self._reward_normalizer:
                intrinsic_reward = self._reward_normalizer.normalize(
                    intrinsic_reward, clip_value=self._reward_clip_value)

        return AlgorithmStep(
            outputs=(),
            state=(),
            info=ICMInfo(reward=intrinsic_reward, loss=LossInfo(loss=loss)))

    def calc_loss(self, info: ICMInfo):
        return LossInfo(scalar_loss=tf.reduce_mean(info.loss.loss))
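
# A minimal sketch of the reward normalization + clipping that both
# RNDAlgorithm variants rely on: standardize the intrinsic reward with running
# statistics, then clip to [-clip_value, clip_value] when clip_value > 0.
# The running mean/std are passed in explicitly here; this is illustrative
# only, not the ScalarAdaptiveNormalizer API.
import torch


def normalize_and_clip_sketch(reward, mean, std, clip_value=-1.0, eps=1e-8):
    normalized = (reward - mean) / (std + eps)
    if clip_value > 0:
        normalized = torch.clamp(normalized, -clip_value, clip_value)
    return normalized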