Exemple #1
0
class WindowAverager(tf.Module):
    def __init__(self,
                 tensor_spec: tf.TensorSpec,
                 window_size,
                 name="WindowAverager"):
        """Create a WindowAverager.

        WindowAverager calculates the average of the past `window_size`
        samples. The samples are stored in a ``DataBuffer`` whose capacity is
        `window_size`.

        Args:
            tensor_spec (TensorSpec): the TensorSpec for the value to be
                averaged
            window_size (int): the size of the window
            name (str): name of this averager
        """
        super().__init__(name=name)
        self._buf = DataBuffer(tensor_spec, window_size)

    def update(self, tensor):
        """Update the average.

        Args:
            tensor (Tensor): a value for updating the average
        Returns:
            None
        """
        # The buffer stores batches, so a leading batch dimension of size 1
        # is added to the single sample.
        self._buf.add_batch(tf.expand_dims(tensor, axis=0))

    def get(self):
        """Get the current average.

        The result is the sum over axis 0 of everything returned by the
        buffer's ``get_all()`` divided by ``max(current_size, 1)``.

        Returns:
            Tensor: the current average
        """
        total = tf.reduce_sum(self._buf.get_all(), axis=0)
        # Use the dtype of the accumulated sum instead of a hard-coded
        # float32 so that non-float32 floating point specs (e.g. float64) do
        # not fail with a dtype mismatch in the multiplication below.
        # max(., 1) avoids a division by zero while the buffer is empty.
        n = tf.cast(tf.maximum(self._buf.current_size, 1), total.dtype)
        return total * (1. / n)

    def average(self, tensor):
        """Combines self.update and self.get in one step. Can be handy in practice.

        Args:
            tensor (Tensor): a value for updating the average
        Returns:
            Tensor: the current average
        """
        self.update(tensor)
        return self.get()
Exemple #2
0
    def test_data_buffer(self):
        """Check DataBuffer insertion, overflow, checkpointing and clearing."""
        dim = 20
        capacity = 256
        cut = dim // 3
        data_spec = (TensorSpec(shape=()), TensorSpec(shape=(cut - 1, )),
                     TensorSpec(shape=(dim - cut, )))

        data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)

        def _get_batch(batch_size):
            # One random matrix split column-wise into the three components
            # described by data_spec.
            full = torch.randn(batch_size, dim, requires_grad=True)
            return (full[:, 0], full[:, 1:cut], full[..., cut:])

        def _last_indices(count):
            # Indices of the `count` most recently stored positions.
            size = data_buffer.current_size
            return torch.arange(size - count, size)

        def _check_rows(indices, expected):
            # Fetch `indices` from the buffer and compare each of the three
            # components with the corresponding entry of `expected`.
            rows = data_buffer.get_batch_by_indices(indices)
            for i in range(3):
                self.assertEqual(rows[i], expected[i])

        data_buffer.add_batch(_get_batch(100))
        self.assertEqual(int(data_buffer.current_size), 100)

        batch = _get_batch(1000)
        # the freshly created batch must carry gradients
        self.assertTrue(batch[0].requires_grad)
        data_buffer.add_batch(batch)
        sampled = data_buffer.get_batch(2)
        # what comes back out of the DataBuffer must not require gradients
        self.assertFalse(sampled[0].requires_grad)
        # adding more than `capacity` samples saturates the buffer
        self.assertEqual(int(data_buffer.current_size), capacity)
        _check_rows(
            torch.arange(capacity),
            (batch[0][-capacity:], batch[1][-capacity:], batch[2][-capacity:]))

        batch = _get_batch(100)
        data_buffer.add_batch(batch)
        expected_tail = (batch[0], batch[1], batch[2][-capacity:])
        _check_rows(_last_indices(100), expected_tail)

        # Save the buffer, then restore it into a brand new DataBuffer.
        with tempfile.TemporaryDirectory() as checkpoint_directory:
            checkpoint = Checkpointer(checkpoint_directory,
                                      data_buffer=data_buffer)
            checkpoint.save(10)
            data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)
            checkpoint = Checkpointer(checkpoint_directory,
                                      data_buffer=data_buffer)
            global_step = checkpoint.load()
            self.assertEqual(global_step, 10)

        # The restored buffer must hold the same most recent 100 samples.
        _check_rows(_last_indices(100), expected_tail)

        data_buffer.clear()
        self.assertEqual(int(data_buffer.current_size), 0)
Exemple #3
0
    def test_data_buffer(self):
        """Check DataBuffer insertion, overflow and checkpoint round trip."""
        dim = 20
        capacity = 256
        cut = dim // 3
        data_spec = (tf.TensorSpec(shape=(), dtype=tf.float32),
                     tf.TensorSpec(shape=(cut - 1, ), dtype=tf.float32),
                     tf.TensorSpec(shape=(dim - cut, ), dtype=tf.float32))

        data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)

        def _get_batch(batch_size):
            # One random matrix split column-wise into the three components
            # described by data_spec.
            full = tf.random.normal(shape=(batch_size, dim))
            return (full[:, 0], full[:, 1:cut], full[..., cut:])

        def _last_indices(count):
            # Indices of the `count` most recently stored positions.
            size = data_buffer.current_size
            return tf.range(size - count, size)

        def _check_rows(indices, expected):
            # Fetch `indices` from the buffer and compare each of the three
            # components with the corresponding entry of `expected`.
            rows = data_buffer.get_batch_by_indices(indices)
            for i in range(3):
                self.assertArrayEqual(rows[i], expected[i])

        data_buffer.add_batch(_get_batch(100))
        self.assertEqual(int(data_buffer.current_size), 100)

        batch = _get_batch(1000)
        data_buffer.add_batch(batch)
        # adding more than `capacity` samples saturates the buffer
        self.assertEqual(int(data_buffer.current_size), capacity)
        _check_rows(
            tf.range(capacity),
            (batch[0][-capacity:], batch[1][-capacity:], batch[2][-capacity:]))

        batch = _get_batch(100)
        data_buffer.add_batch(batch)
        expected_tail = (batch[0], batch[1], batch[2][-capacity:])
        _check_rows(_last_indices(100), expected_tail)

        # Save the buffer, then restore it into a brand new DataBuffer.
        with tempfile.TemporaryDirectory() as checkpoint_directory:
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            checkpoint = tf.train.Checkpoint(data_buffer=data_buffer)
            checkpoint.save(file_prefix=checkpoint_prefix)

            data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)
            checkpoint = tf.train.Checkpoint(data_buffer=data_buffer)
            latest = tf.train.latest_checkpoint(checkpoint_directory)
            status = checkpoint.restore(latest)
            status.assert_consumed()

        # The restored buffer must hold the same most recent 100 samples.
        _check_rows(_last_indices(100), expected_tail)
Exemple #4
0
class MISCAlgorithm(Algorithm):
    """Mutual Information-based State Control (MISC)
    Author: Rui Zhao
    Work done during a research internship at Horizon Robotics.
    The paper is currently under review in a conference.

    This algorithm generates the intrinsic reward based on the mutual information
    estimation between the goal states and the controllable states.

    See Zhao et al "Mutual Information-based State-Control for Intrinsically Motivated Reinforcement Learning",
    https://arxiv.org/abs/2002.01963
    """
    def __init__(self,
                 batch_size,
                 observation_spec,
                 action_spec,
                 soi_spec,
                 soc_spec,
                 split_observation_fn: Callable,
                 network: Network = None,
                 mi_r_scale=5000.0,
                 hidden_size=128,
                 buffer_size=100,
                 n_objects=1,
                 name="MISCAlgorithm"):
        """Create an MISCAlgorithm.

        Args:
            batch_size (int): batch size
            observation_spec (tf.TensorSpec): observation size
            action_spec (tf.TensorSpec): action size
            soi_spec (tf.TensorSpec): state of interest size
            soc_spec (tf.TensorSpec): state of context size
            split_observation_fn (Callable): split observation function.
                The input is observation and action concatenated.
                The outputs are the context states and states of interest
                (one state of interest when ``n_objects < 2``, two when
                ``n_objects == 2``).
            network (Network): network for estimating mutual information (MI).
                If None, an ``EncodingNetwork`` with one hidden layer of
                ``hidden_size`` units is created.
            mi_r_scale (float): scale factor of MI estimation
            hidden_size (int): number of hidden units in neural nets. Only
                used if ``network`` is None.
            buffer_size (int): buffer size for the data buffer storing the trajectories
                for training the Mutual Information Neural Estimator
            n_objects (int): number of objects for estimating the mutual
                information reward. Only values up to 2 are supported.
            name (str): the algorithm name, "MISCAlgorithm"
        """

        super(MISCAlgorithm,
              self).__init__(train_state_spec=[observation_spec, action_spec],
                             name=name)

        assert isinstance(observation_spec, tf.TensorSpec), \
            "does not support nested observation_spec"
        assert isinstance(action_spec, tf.TensorSpec), \
            "does not support nested action_spec"

        if network is None:
            network = EncodingNetwork(input_tensor_spec=[soc_spec, soi_spec],
                                      fc_layer_params=(hidden_size, ),
                                      activation_fn='relu',
                                      last_layer_size=1,
                                      last_activation_fn='tanh')

        self._network = network

        # One buffer entry holds the concatenated (observation, action)
        # features of a whole batch: shape [batch_size, obs_dim + action_dim].
        self._traj_spec = tf.TensorSpec(shape=[batch_size] + [
            observation_spec.shape.as_list()[0] +
            action_spec.shape.as_list()[0]
        ],
                                        dtype=observation_spec.dtype)
        self._buffer_size = buffer_size
        self._buffer = DataBuffer(self._traj_spec, capacity=self._buffer_size)
        self._mi_r_scale = mi_r_scale
        self._n_objects = n_objects
        self._split_observation_fn = split_observation_fn
        self._batch_size = batch_size

    def _mine(self, x_in, y_in):
        """Mutual Information Neural Estimator.

        Implement mutual information neural estimator from
        Belghazi et al "Mutual Information Neural Estimation"
        http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf
        'DV':  sup_T E_P(T) - log E_Q(exp(T))
        where P is the joint distribution of X and Y, and Q is the product
         marginal distribution of P. DV is a lower bound for
         KLD(P||Q)=MI(X, Y).

        Args:
            x_in (Tensor): samples of X
            y_in (Tensor): samples of Y paired with ``x_in``
        Returns:
            Tensor: the DV estimate, reduced over axis 1 and with the last
                axis squeezed.
        """
        # Build samples for Q by shuffling y. NOTE(review): presumably
        # transpose2 swaps dims 0 and 1 so that math_ops.shuffle permutes
        # along the original dim 1 -- confirm against those helpers.
        y_in_tran = transpose2(y_in, 1, 0)
        y_shuffle_tran = math_ops.shuffle(y_in_tran)
        y_shuffle = transpose2(y_shuffle_tran, 1, 0)

        # propagate the forward pass
        T_xy, _ = self._network([x_in, y_in])
        T_x_y, _ = self._network([x_in, y_shuffle])

        # compute the negative loss (maximize loss == minimize -loss)
        mean_exp_T_x_y = tf.reduce_mean(tf.math.exp(T_x_y), axis=1)
        loss = tf.reduce_mean(T_xy, axis=1) - tf.math.log(mean_exp_T_x_y)
        loss = tf.squeeze(loss, axis=-1)  # Mutual Information

        return loss

    def _estimate_mi(self, features):
        """Split ``features`` and run MINE once per object of interest.

        Args:
            features (Tensor): input for ``split_observation_fn``
        Returns:
            list[Tensor]: one MI estimate per object (1 or 2 entries)
        Raises:
            ValueError: if ``n_objects`` is greater than 2
        """
        if self._n_objects < 2:
            obs_tau_excludes_goal, obs_tau_achieved_goal = (
                self._split_observation_fn(features))
            return [self._mine(obs_tau_excludes_goal, obs_tau_achieved_goal)]
        elif self._n_objects == 2:
            (obs_tau_excludes_goal, obs_tau_achieved_goal_1,
             obs_tau_achieved_goal_2) = self._split_observation_fn(features)
            return [
                self._mine(obs_tau_excludes_goal, obs_tau_achieved_goal_1),
                self._mine(obs_tau_excludes_goal, obs_tau_achieved_goal_2)
            ]
        # Fail with an explicit message instead of an UnboundLocalError on
        # the never-assigned loss.
        raise ValueError(
            "MISCAlgorithm supports at most 2 objects, got n_objects=%s" %
            self._n_objects)

    def train_step(self,
                   time_step: ActionTimeStep,
                   state,
                   calc_intrinsic_reward=True):
        """
        Args:
            time_step (ActionTimeStep): input time_step data
            state (tuple): state for MISC (previous observation,
                previous action)
            calc_intrinsic_reward (bool): if False, the intrinsic reward is
                not computed and the buffer is not updated; the returned
                reward is an empty tuple.
        Returns:
            TrainStep:
                outputs: empty tuple ()
                state: list of observation and previous action
                info: (MISCInfo): holds the intrinsic reward
        """
        feature_state = time_step.observation
        prev_action = time_step.prev_action
        feature = tf.concat([feature_state, prev_action], axis=-1)
        prev_feature = tf.concat(state, axis=-1)

        # Stack previous and current features along a new axis 1, giving a
        # two-step trajectory per batch element.
        feature_reshaped = tf.expand_dims(feature, axis=1)
        prev_feature_reshaped = tf.expand_dims(prev_feature, axis=1)
        feature_pair = tf.concat([prev_feature_reshaped, feature_reshaped], 1)
        feature_reshaped_tran = transpose2(feature_reshaped, 1, 0)

        # The reason that the batch_size of time_step can be different from
        # self._batch_size is that this is invoked in PREPARE_SPEC mode.
        # TODO: handle PREPARE_SPEC properly
        if (calc_intrinsic_reward
                and time_step.step_type.shape[0] == self._batch_size):
            self._buffer.add_batch(feature_reshaped_tran)

        # The estimate is computed even when the reward is not requested, as
        # in the original code, so the network is always called here.
        losses = self._estimate_mi(feature_pair)

        intrinsic_reward = ()
        if calc_intrinsic_reward:
            # scale/normalize the MISC intrinsic reward: each object's
            # estimate is scaled and clipped to [0, 1], then they are summed
            clipped = [
                tf.clip_by_value(self._mi_r_scale * loss, 0, 1)
                for loss in losses
            ]
            intrinsic_reward = clipped[0]
            for extra in clipped[1:]:
                intrinsic_reward = intrinsic_reward + extra

        return AlgorithmStep(outputs=(),
                             state=[feature_state, prev_action],
                             info=MISCInfo(reward=intrinsic_reward))

    def calc_loss(self, info: MISCInfo):
        """Compute the training loss of the MI estimator from the buffer.

        Args:
            info (MISCInfo): not used
        Returns:
            LossInfo: ``scalar_loss`` is the mean of the negated MI estimate
                (summed over objects) on ``buffer_size`` samples drawn from
                the buffer.
        """
        feature_tau_sampled = self._buffer.get_batch(
            batch_size=self._buffer_size)
        feature_tau_sampled_tran = transpose2(feature_tau_sampled, 1, 0)

        losses = self._estimate_mi(feature_tau_sampled_tran)
        loss = losses[0]
        for extra in losses[1:]:
            loss = loss + extra

        # Maximizing the MI lower bound == minimizing its negation.
        neg_loss = -loss
        neg_loss_scalar = tf.reduce_mean(neg_loss)
        return LossInfo(scalar_loss=neg_loss_scalar)
Exemple #5
0
class MIEstimator(Algorithm):
    """Mutual Information Estimator.

    Implements several mutual information estimators from
    Belghazi et al "Mutual Information Neural Estimation"
    http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf
    Hjelm et al "Learning Deep Representations by Mutual Information Estimation
    and Maximization" https://arxiv.org/pdf/1808.06670.pdf

    Currently 3 types of estimator are implemented, which are based on the
    following variational lower bounds:
    * 'DV':  sup_T E_P(T) - log E_Q(exp(T))
    * 'KLD': sup_T E_P(T) - E_Q(exp(T)) + 1
    * 'JSD': sup_T -E_P(softplus(-T)) - E_Q(softplus(T)) + log(4)

    where P is the joint distribution of X and Y, and Q is the product marginal
    distribution of P. Both DV and KLD are lower bounds for KLD(P||Q)=MI(X, Y).
    However, JSD is not a lower bound for mutual information, it is a lower
    bound for JSD(P||Q), which is closely correlated with MI as pointed out in
    Hjelm et al.

    Assuming the function class of T is rich enough to represent any function,
    for KLD and JSD, T will converge to log(P/Q) and hence E_P(T) can also be
    used as an estimator of KLD(P||Q)=MI(X,Y). For DV, T will converge to
    log(P/Q) + c, where c=log E_Q(exp(T)).

    Among these 3 estimators, 'DV' and 'KLD' seem to give a better estimation
    of PMI than 'JSD'. But 'JSD' might be numerically more stable than 'DV' and
    'KLD' because of the use of softplus instead of exp. And 'DV' is more stable
    than 'KLD' because of the logarithm.

    Several strategies are implemented in order to estimate E_Q(.):
    * 'buffer': store y to a buffer and randomly retrieve samples from the
       buffer.
    * 'double_buffer': store both x and y to buffers and randomly retrieve
       samples from the two buffers.
    * 'shuffle': randomly shuffle batch y
    * 'shift': shift batch y by one sample, i.e.
      tf.concat([y[-1:, ...], y[0:-1, ...]], axis=0)

    Among these, 'buffer' and 'shift' seem to perform better and 'shuffle'
    performs worst. 'buffer' incurs additional storage cost. 'shift' has the
    assumption that y samples from one batch are independent. If the additional
    memory is not a concern, we recommend 'buffer' sampler so that there is no
    need to worry about the assumption of independence.
    """

    # Sentinel marking "averager not provided". A plain ``None`` default
    # cannot be used for that because an explicit falsy averager already
    # means "do not average" in ``train_step``.
    _DEFAULT_AVERAGER = object()

    def __init__(self,
                 x_spec: tf.TensorSpec,
                 y_spec: tf.TensorSpec,
                 model=None,
                 fc_layers=(256, ),
                 sampler='buffer',
                 buffer_size=65536,
                 optimizer: tf.optimizers.Optimizer = None,
                 estimator_type='DV',
                 averager=_DEFAULT_AVERAGER,
                 name="MIEstimator"):
        """Create a MIEstimator.

        Args:
            x_spec (TensorSpec): spec of x
            y_spec (TensorSpec): spec of y
            model (Network): can be called as model([x, y]) and return a Tensor
                with shape=[batch_size, 1]. If None, a default MLP with
                fc_layers will be created.
            fc_layers (tuple[int]): size of hidden layers. Only used if model is
                None.
            sampler (str): type of sampler used to get samples from marginal
                distribution, should be one of ['buffer', 'double_buffer',
                'shuffle', 'shift']
            buffer_size (int): capacity of buffer for storing y for sampler
                'buffer' and 'double_buffer'
            optimizer (tf.optimizers.Optimizer): optimizer
            estimator_type (str): one of 'DV', 'KLD' or 'JSD'
            averager (EMAverager): averager used to maintain a moving average
                of exp(T). Only used for 'DV' estimator. If not provided, a
                new ScalarAdaptiveAverager is created for this instance. If a
                falsy value such as None is passed explicitly, ``train_step``
                uses the plain batch mean instead; ``calc_pmi`` requires an
                averager for 'DV'.
            name (str): name of this estimator
        Raises:
            TypeError: if ``sampler`` is not one of the supported names
        """
        assert estimator_type in ['DV', 'KLD', 'JSD'
                                  ], "Wrong estimator_type %s" % estimator_type
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
        self._x_spec = x_spec
        self._y_spec = y_spec
        if model is None:
            model = EncodingNetwork(name="MIEstimator",
                                    input_tensor_spec=[x_spec, y_spec],
                                    fc_layer_params=fc_layers,
                                    last_layer_size=1)
        self._model = model
        self._type = estimator_type
        if sampler == 'buffer':
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._buffer_sampler
        elif sampler == 'double_buffer':
            self._x_buffer = DataBuffer(x_spec, capacity=buffer_size)
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._double_buffer_sampler
        elif sampler == 'shuffle':
            self._sampler = self._shuffle_sampler
        elif sampler == 'shift':
            self._sampler = self._shift_sampler
        else:
            raise TypeError("Wrong type for sampler %s" % sampler)

        if estimator_type == 'DV':
            # Create the default averager here rather than in the signature:
            # a default argument is evaluated once at definition time, which
            # would make every estimator share one averager and its state.
            if averager is MIEstimator._DEFAULT_AVERAGER:
                averager = ScalarAdaptiveAverager()
            self._mean_averager = averager

    def _buffer_sampler(self, x, y):
        """Pair x with y samples drawn from the y buffer; also stores y."""
        batch_size = tf.cast(tf.shape(y)[0], tf.int64)
        if self._y_buffer.current_size >= batch_size:
            # Enough old samples: draw first so the current batch is not
            # among the candidates.
            y1 = self._y_buffer.get_batch(batch_size)
            self._y_buffer.add_batch(y)
        else:
            self._y_buffer.add_batch(y)
            y1 = self._y_buffer.get_batch(batch_size)
        return x, y1

    def _double_buffer_sampler(self, x, y):
        """Store x and y, then draw both from their respective buffers."""
        batch_size = tf.shape(y)[0]
        self._x_buffer.add_batch(x)
        x1 = self._x_buffer.get_batch(batch_size)
        self._y_buffer.add_batch(y)
        y1 = self._y_buffer.get_batch(batch_size)
        return x1, y1

    def _shuffle_sampler(self, x, y):
        """Pair x with a random shuffle of batch y."""
        return x, tf.random.shuffle(y)

    def _shift_sampler(self, x, y):
        """Pair x with batch y rotated by one sample along axis 0."""
        return x, tf.concat([y[-1:, ...], y[0:-1, ...]], axis=0)

    def train_step(self, inputs, state=None):
        """Perform training on one batch of inputs.

        Args:
            inputs (tuple(Tensor, Tensor)): tuple of x and y
            state: not used
        Returns:
            AlgorithmStep
                outputs (Tensor): shape=[batch_size], its mean is the estimated
                    MI
                state: not used
                info (LossInfo): info.loss is the loss
        """
        x, y = inputs
        # Merge any extra outer dims into one batch dim for the model; the
        # results are unflattened again before returning.
        num_outer_dims = get_outer_rank(x, self._x_spec)
        batch_squash = BatchSquash(num_outer_dims)
        x = batch_squash.flatten(x)
        y = batch_squash.flatten(y)
        x1, y1 = self._sampler(x, y)

        # T evaluated on joint samples (x, y) and on sampler output (x1, y1).
        log_ratio = self._model([x, y])[0]
        t1 = self._model([x1, y1])[0]

        if self._type == 'DV':
            # T is capped at 20 before exp to limit the magnitude of exp(T).
            ratio = tf.math.exp(tf.minimum(t1, 20))
            mean = tf.stop_gradient(tf.reduce_mean(ratio))
            if self._mean_averager:
                self._mean_averager.update(mean)
                unbiased_mean = tf.stop_gradient(self._mean_averager.get())
            else:
                unbiased_mean = mean
            # estimated MI = reduce_mean(mi)
            # ratio/mean-1 does not contribute to the final estimated MI, since
            # mean(ratio/mean-1) = 0. We add it so that we can have an estimation
            # of the variance of the MI estimator
            mi = log_ratio - (tf.math.log(mean) + ratio / mean - 1)
            loss = ratio / unbiased_mean - log_ratio
        elif self._type == 'KLD':
            ratio = tf.math.exp(tf.minimum(t1, 20))
            mi = log_ratio - ratio + 1
            loss = -mi
        elif self._type == 'JSD':
            mi = -tf.nn.softplus(-log_ratio) - tf.nn.softplus(t1) + math.log(4)
            loss = -mi

        mi = batch_squash.unflatten(mi)
        loss = batch_squash.unflatten(loss)

        return AlgorithmStep(outputs=mi,
                             state=(),
                             info=LossInfo(loss, extra=()))

    def calc_pmi(self, x, y):
        """Return estimated pointwise mutual information.

        The pointwise mutual information is defined as:
            log P(x|y)/P(x) = log P(y|x)/P(y)

        For the 'DV' estimator the log of the averaged exp(T) is subtracted
        from the model output, so an averager must be available.

        Args:
            x (tf.Tensor): x
            y (tf.Tensor): y
        Returns:
            tf.Tensor: pointwise mutual information between x and y
        """
        log_ratio = self._model([x, y])[0]
        if self._type == 'DV':
            log_ratio -= tf.math.log(self._mean_averager.get())
        return log_ratio
Exemple #6
0
class MIEstimator(Algorithm):
    r"""Mutual Infomation Estimator.

    Implements several mutual information estimator from
    Belghazi et al `Mutual Information Neural Estimation
    <http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf>`_
    Hjelm et al `Learning Deep Representations by Mutual Information Estimation
    and Maximization <https://arxiv.org/pdf/1808.06670.pdf>`_

    Currently, 3 types of estimator are implemented, which are based on the
    following variational lower bounds:

    * *DV*:  :math:`\sup_T E_P(T) - \log E_Q(\exp(T))`
    * *KLD*: :math:`\sup_T E_P(T) - E_Q(\exp(T)) + 1`
    * *JSD*: :math:`\sup_T -E_P(softplus(-T))) - E_Q(solftplus(T)) + \log(4)`
    * *ML*: :math:`\sup_q E_P(\log(q(y|x)) - \log(P(y)))`

    where P is the joint distribution of X and Y, and Q is the product marginal
    distribution of P. Both DV and KLD are lower bounds for :math:`KLD(P||Q)=MI(X, Y)`.
    However, *JSD* is not a lower bound for mutual information, it is a lower
    bound for :math:`JSD(P||Q)`, which is closely correlated with MI as pointed out in
    Hjelm et al.

    For *ML*, :math:`P(y)` is the margianl distribution of y, and it needs to be provided.
    The current implementation uses a normal distribution with diagonal variance
    for :math:`q(y|x)`. So it only support continous `y`. If :math:`P(y|x)` can be reasonably
    approximated as an diagonal normal distribution and :math:`P(y)` is known,
    then 'ML' may give better estimation for the mutual information.

    Assumming the function class of T is rich enough to represent any function,
    for *KLD* and *JSD*, T will converge to :math:`\log(\frac{P}{Q})` and hence
    :math:`E_P(T)` can also be used as an estimator of :math:`KLD(P||Q)=MI(X,Y)`.
    For *DV*, :math:`T` will converge to :math:`\log(\frac{P}{Q}) + c`, where
    :math:`c=\log E_Q(\exp(T))`.

    Among *DV*, *KLD* and *JSD*,  *DV* and *KLD* seem to give a better estimation
    of PMI than *JSD*. But *JSD* might be numerically more stable than *DV* and
    *KLD* because of the use of softplus instead of exp. And *DV* is more stable
    than *KLD* because of the logarithm.

    Several strategies are implemented in order to estimate :math:`E_Q(\cdot)`:

    * 'buffer': store :math:`y` to a buffer and randomly retrieve samples from
      the buffer.
    * 'double_buffer': stroe both :math:`x` and :math:`y` to buffers and randomly
      retrieve samples from the two buffers.
    * 'shuffle': randomly shuffle batch :math:`y`
    * 'shift': shift batch :math:`y` by one sample, i.e.
      ``torch.cat([y[-1:, ...], y[0:-1, ...]], dim=0)``
    * direct sampling: You can also provide the marginal distribution of :math:`y`
      to ``train_step()``. In this case, sampler is ignored and samples of :math:`y`
      for estimating :math:`E_Q(.)` are sampled from ``y_distribution``.

    If you need the gradient of :math:`y`, you should use sampler 'shift' and
    'shuffle'.

    Among these, 'buffer' and 'shift' seem to perform better and 'shuffle'
    performs worst. 'buffer' incurs additional storage cost. 'shift' has the
    assumption that y samples from one batch are independent. If the additional
    memory is not a concern, we recommend 'buffer' sampler so that there is no
    need to worry about the assumption of independence.

    ``MIEstimator`` can be also used to estimate conditional mutual information
    :math:`MI(X,Y|Z)` using *KLD*, *JSD* or *ML*. In this case, you should let
    ``x`` to represent :math:`X` and :math:`Z`, and ``y`` to represent :math:`Y`.
    And when calling ``train_step()``, you need to provide ``y_distribution``
    which is the distribution :math:`P(Y|z)`. Note that *DV* cannot be used for
    estimating conditional mutual information. See ``mi_estimator_test.py`` for
    an example.
    """
    def __init__(self,
                 x_spec,
                 y_spec,
                 model=None,
                 fc_layers=(256, ),
                 sampler='buffer',
                 buffer_size=65536,
                 optimizer: torch.optim.Optimizer = None,
                 estimator_type='DV',
                 averager: EMAverager = None,
                 name="MIEstimator"):
        """Create a MIEstimator.

        Args:
            x_spec (nested TensorSpec): spec of ``x``
            y_spec (nested TensorSpec): spec of ``y``. Must be a single
                continuous ``alf.TensorSpec`` for the 'ML' estimator.
            model (Network): for 'DV', 'KLD' and 'JSD', can be called as
                ``model([x, y])`` and return a Tensor with
                ``shape=[batch_size, 1]``. If None, a default MLP with
                ``fc_layers`` will be created; for 'ML' the default MLP takes
                only ``x`` as input.
            fc_layers (tuple[int]): size of hidden layers. Only used if model is
                None.
            sampler (str): type of sampler used to get samples from marginal
                distribution, should be one of ``['buffer', 'double_buffer',
                'shuffle', 'shift']``.
            buffer_size (int): capacity of buffer for storing y for sampler
                'buffer' and 'double_buffer'.
            optimizer (torch.optim.Optimizer): optimizer
            estimator_type (str): one of 'ML', 'DV', 'KLD' or 'JSD'
            averager (EMAverager): averager used to maintain a moving average
                of :math:`exp(T)`. Only used for 'DV' estimator. If None,
                a ScalarAdaptiveAverager will be created.
            name (str): name of this estimator
        Raises:
            TypeError: if ``sampler`` is not one of the supported names
        """
        assert estimator_type in ['ML', 'DV', 'KLD', 'JSD'
                                  ], "Wrong estimator_type %s" % estimator_type
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
        self._x_spec = x_spec
        self._y_spec = y_spec
        if model is None:
            if estimator_type == 'ML':
                # 'ML' encodes x only; the heads created at the end of this
                # constructor map the encoding to the dimension of y.
                model = EncodingNetwork(
                    name="MIEstimator",
                    input_tensor_spec=x_spec,
                    fc_layer_params=fc_layers,
                    preprocessing_combiner=NestConcat(dim=-1))
            else:
                # The other estimators score (x, y) pairs with a scalar
                # output.
                model = EncodingNetwork(
                    name="MIEstimator",
                    input_tensor_spec=[x_spec, y_spec],
                    preprocessing_combiner=NestConcat(dim=-1),
                    fc_layer_params=fc_layers,
                    last_layer_size=1,
                    last_activation=math_ops.identity)
        self._model = model
        self._type = estimator_type
        # Select the strategy used to obtain samples of the marginal; the
        # buffers are only created for the buffer-based samplers.
        if sampler == 'buffer':
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._buffer_sampler
        elif sampler == 'double_buffer':
            self._x_buffer = DataBuffer(x_spec, capacity=buffer_size)
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._double_buffer_sampler
        elif sampler == 'shuffle':
            self._sampler = self._shuffle_sampler
        elif sampler == 'shift':
            self._sampler = self._shift_sampler
        else:
            raise TypeError("Wrong type for sampler %s" % sampler)

        if estimator_type == 'DV':
            # Created per instance so that estimators do not share state.
            if averager is None:
                averager = ScalarAdaptiveAverager()
            self._mean_averager = averager
        if estimator_type == 'ML':
            assert isinstance(
                y_spec,
                alf.TensorSpec), ("Currently, 'ML' does "
                                  "not support nested y_spec: %s" % y_spec)
            assert y_spec.is_continuous, ("Currently, 'ML' does "
                                          "not support discreted y_spec: %s" %
                                          y_spec)
            hidden_size = self._model.output_spec.shape[-1]
            # Two heads on top of the encoding of x, both with zero-initialized
            # kernels so their initial output equals the bias.
            self._delta_loc_layer = alf.layers.FC(
                hidden_size,
                y_spec.shape[-1],
                kernel_initializer=torch.nn.init.zeros_,
                bias_init_value=0.0)
            # log(e - 1) is the value whose softplus is 1.
            # NOTE(review): presumably the scale head output goes through
            # softplus elsewhere, giving an initial scale of 1 -- confirm.
            self._delta_scale_layer = alf.layers.FC(
                hidden_size,
                y_spec.shape[-1],
                kernel_initializer=torch.nn.init.zeros_,
                bias_init_value=math.log(math.e - 1))

    def _buffer_sampler(self, x, y):
        """Pair ``x`` with ``y`` samples drawn from the y buffer.

        ``y`` is always added to the buffer. When the buffer already holds at
        least one batch worth of samples, the draw happens before ``y`` is
        added; otherwise ``y`` is added first.

        Args:
            x (nested Tensor): returned unchanged
            y (nested Tensor): current batch of y
        Returns:
            tuple: ``x`` and the detached samples drawn from the buffer
        """
        buf = self._y_buffer
        n = get_nest_batch_size(y)
        if buf.current_size >= n:
            drawn = buf.get_batch(n)
            buf.add_batch(y)
            return x, common.detach(drawn)
        buf.add_batch(y)
        drawn = buf.get_batch(n)
        return x, common.detach(drawn)

    def _double_buffer_sampler(self, x, y):
        """Store ``x`` and ``y``, then draw a batch from each buffer.

        Args:
            x (nested Tensor): current batch of x
            y (nested Tensor): current batch of y
        Returns:
            tuple: samples drawn from the x buffer and from the y buffer,
                each with the batch size of ``y``
        """
        n = get_nest_batch_size(y)
        x_buf, y_buf = self._x_buffer, self._y_buffer
        # Same operation order as before: x is stored and drawn first, then y.
        x_buf.add_batch(x)
        drawn_x = x_buf.get_batch(n)
        y_buf.add_batch(y)
        drawn_y = y_buf.get_batch(n)
        return x1_y1 if False else (drawn_x, drawn_y)

    def _shuffle_sampler(self, x, y):
        """Pair ``x`` with ``y`` passed through ``math_ops.shuffle``.

        Args:
            x (nested Tensor): returned unchanged
            y (nested Tensor): current batch of y
        Returns:
            tuple: ``x`` and the shuffled ``y``
        """
        shuffled_y = math_ops.shuffle(y)
        return x, shuffled_y

    def _shift_sampler(self, x, y):
        """Pair ``x`` with ``y`` rotated by one sample along dim 0.

        Every tensor in the nest ``y`` has its last row moved to the front.

        Args:
            x (nested Tensor): returned unchanged
            y (nested Tensor): current batch of y
        Returns:
            tuple: ``x`` and the rotated ``y``
        """

        def _rotate_one(tensor):
            last_row = tensor[-1:, ...]
            leading_rows = tensor[0:-1, ...]
            return torch.cat([last_row, leading_rows], dim=0)

        rotated_y = alf.nest.map_structure(_rotate_one, y)
        return x, rotated_y

    def train_step(self, inputs, y_distribution=None, state=None):
        """Run one training step of the estimator on a batch of ``(x, y)``.

        For estimator type 'ML' the work is delegated to ``_ml_step``. For the
        other types the model is evaluated once on the paired ``(x, y)`` and
        once on samples standing in for the product of the marginals, and the
        per-sample estimate and loss are built from these two outputs.

        Args:
            inputs (tuple(nested Tensor, nested Tensor)): the pair ``(x, y)``
            y_distribution (nested td.Distribution): marginal distribution of
                ``y``. When given, the marginal samples of :math:`Y` are drawn
                from it; when None they come from the ``sampler`` selected at
                construction.
            state: not used
        Returns:
            AlgStep:
            - output (Tensor): per-sample estimate. Its mean is the estimated
              MI for estimator 'DV' and 'KLD', and the Jensen-Shannon
              divergence bound for estimator 'JSD'
            - state: empty tuple
            - info (LossInfo): ``info.loss`` is the loss
        """
        x, y = inputs

        if self._type == 'ML':
            return self._ml_step(x, y, y_distribution)

        # Collapse the outer dims so that model and sampler see a flat batch.
        squash = BatchSquash(get_outer_rank(x, self._x_spec))
        x = squash.flatten(x)
        y = squash.flatten(y)
        if y_distribution is not None:
            neg_x = x
            neg_y = squash.flatten(y_distribution.sample())
        else:
            neg_x, neg_y = self._sampler(x, y)

        joint_score = self._model([x, y])[0]
        marginal_score = self._model([neg_x, neg_y])[0]

        if self._type == 'DV':
            # The score is capped at 20 before exponentiation.
            ratio = torch.exp(torch.min(marginal_score, torch.tensor(20.)))
            mean = torch.mean(ratio).detach()
            unbiased_mean = mean
            if self._mean_averager:
                self._mean_averager.update(mean)
                unbiased_mean = self._mean_averager.get().detach()
            # The estimated MI is the batch mean of ``mi``. The term
            # ratio / mean - 1 averages to zero, so it leaves that mean
            # untouched; it is kept so that the spread of ``mi`` reflects the
            # variance of the MI estimator.
            mi = joint_score - (torch.log(mean) + ratio / mean - 1)
            loss = ratio / unbiased_mean - joint_score
        elif self._type == 'KLD':
            ratio = torch.exp(torch.min(marginal_score, torch.tensor(20.)))
            mi = joint_score - ratio + 1
            loss = -mi
        elif self._type == 'JSD':
            mi = (-F.softplus(-joint_score) - F.softplus(marginal_score) +
                  math.log(4))
            loss = -mi

        # Restore the outer dims removed above.
        mi = squash.unflatten(mi)
        loss = squash.unflatten(loss)

        return AlgStep(output=mi, state=(), info=LossInfo(loss, extra=()))

    def _ml_pmi(self, x, y, y_distribution):
        """Compute the 'ML' estimate of the pointwise mutual information.

        The model output for ``x`` is mapped by two layers to an additive
        offset for the mean and a positive multiplicative factor for the
        stddev of ``y_distribution``. The resulting diagonal normal plays the
        role of the conditional distribution of ``y`` given ``x``.

        Args:
            x (nested Tensor): x
            y (Tensor): y
            y_distribution: marginal distribution of ``y``; its ``mean``,
                ``stddev`` and ``log_prob`` are used
        Returns:
            Tensor: log-probability of ``y`` under the conditional minus the
            (detached) log-probability of ``y`` under ``y_distribution``
        """
        squash = BatchSquash(get_outer_rank(x, self._x_spec))
        features = squash.flatten(self._model(x)[0])

        loc_shift = squash.unflatten(self._delta_loc_layer(features))
        # softplus keeps the multiplicative scale factor positive.
        scale_factor = squash.unflatten(
            F.softplus(self._delta_scale_layer(features)))

        conditional = DiagMultivariateNormal(
            loc=y_distribution.mean + loc_shift,
            scale=y_distribution.stddev * scale_factor)

        log_q = conditional.log_prob(y)
        # No gradient flows into the provided marginal.
        log_p = y_distribution.log_prob(y).detach()
        return log_q - log_p

    def _ml_step(self, x, y, y_distribution):
        """Training step for the 'ML' estimator.

        Args:
            x (nested Tensor): x
            y (Tensor): y
            y_distribution: marginal distribution of ``y``
        Returns:
            AlgStep: ``output`` is the pointwise MI from ``_ml_pmi``, ``state``
            is an empty tuple and ``info.loss`` is the negated pointwise MI
        """
        pmi = self._ml_pmi(x, y, y_distribution)
        loss_info = LossInfo(loss=-pmi)
        return AlgStep(output=pmi, state=(), info=loss_info)

    def calc_pmi(self, x, y, y_distribution=None):
        r"""Return estimated pointwise mutual information.

        The pointwise mutual information is defined as:

        .. math::

            \log \frac{P(x|y)}{P(x)} = \log \frac{P(y|x)}{P(y)}

        For the 'ML' estimator the value comes from ``_ml_pmi``. For the other
        estimators it is the model output for ``[x, y]`` with the last
        dimension squeezed; for 'DV' the log of the averaged mean kept by the
        mean averager is additionally subtracted.

        Args:
            x (Tensor): x
            y (Tensor): y
            y_distribution (DiagMultivariateNormal): needs to be provided for
                'ML' estimator.
        Returns:
            Tensor: pointwise mutual information between ``x`` and ``y``.
        """
        if self._type == 'ML':
            assert isinstance(y_distribution, DiagMultivariateNormal), (
                "y_distribution should be a DiagMultivariateNormal")
            return self._ml_pmi(x, y, y_distribution)
        log_ratio = self._model([x, y])[0]
        log_ratio = torch.squeeze(log_ratio, dim=-1)
        if self._type == 'DV':
            # Subtract out of place: ``torch.squeeze`` returns a view of the
            # model output, so ``-=`` would write through into that tensor,
            # altering what the model returned and invalidating it for any
            # autograd node that saved it for the backward pass.
            log_ratio = log_ratio - self._mean_averager.get().log()
        return log_ratio
Exemple #7
0
class MIEstimator(Algorithm):
    """Mutual Information Estimator.

    Implements several mutual information estimators from
    Belghazi et al "Mutual Information Neural Estimation"
    http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf
    Hjelm et al "Learning Deep Representations by Mutual Information Estimation
    and Maximization" https://arxiv.org/pdf/1808.06670.pdf

    Currently 4 types of estimator are implemented, which are based on the
    following variational lower bounds:
    * 'DV':  sup_T E_P(T) - log E_Q(exp(T))
    * 'KLD': sup_T E_P(T) - E_Q(exp(T)) + 1
    * 'JSD': sup_T -E_P(softplus(-T))) - E_Q(softplus(T)) + log(4)
    * 'ML': sup_q E_P(log(q(y|x)) - log(P(y)))

    where P is the joint distribution of X and Y, and Q is the product marginal
    distribution of P. Both DV and KLD are lower bounds for KLD(P||Q)=MI(X, Y).
    However, JSD is not a lower bound for mutual information, it is a lower
    bound for JSD(P||Q), which is closely correlated with MI as pointed out in
    Hjelm et al.

    For ML, P(y) is the marginal distribution of y, and it needs to be provided.
    The current implementation uses a normal distribution with diagonal variance
    for q(y|x). So it only supports continuous `y`. If P(y|x) can be reasonably
    approximated as a diagonal normal distribution and P(y) is known, then 'ML'
    may give better estimation for the mutual information.

    Assuming the function class of T is rich enough to represent any function,
    for KLD and JSD, T will converge to log(P/Q) and hence E_P(T) can also be
    used as an estimator of KLD(P||Q)=MI(X,Y). For DV, T will converge to
    log(P/Q) + c, where c=log E_Q(exp(T)).

    Among 'DV', 'KLD' and 'JSD',  'DV' and 'KLD' seem to give a better estimation
    of PMI than 'JSD'. But 'JSD' might be numerically more stable than 'DV' and
    'KLD' because of the use of softplus instead of exp. And 'DV' is more stable
    than 'KLD' because of the logarithm.

    Several strategies are implemented in order to estimate E_Q(.):
    * 'buffer': store y to a buffer and randomly retrieve samples from the
       buffer.
    * 'double_buffer': store both x and y to buffers and randomly retrieve
       samples from the two buffers.
    * 'shuffle': randomly shuffle batch y
    * 'shift': shift batch y by one sample, i.e.
      tf.concat([y[-1:, ...], y[0:-1, ...]], axis=0)
    * direct sampling: You can also provide the marginal distribution of y to
      train_step(). In this case, sampler is ignored and samples of y for
      estimating E_Q(.) are sampled from y_distribution.

    If you need the gradient of y, you should use sampler 'shift' and 'shuffle'.

    Among these, 'buffer' and 'shift' seem to perform better and 'shuffle'
    performs worst. 'buffer' incurs additional storage cost. 'shift' has the
    assumption that y samples from one batch are independent. If the additional
    memory is not a concern, we recommend 'buffer' sampler so that there is no
    need to worry about the assumption of independence.

    MIEstimator can be also used to estimate conditional mutual information
    MI(X,Y|Z) using 'KLD', 'JSD' or 'ML'. In this case, you should let `x` to
    represent X and Z, and `y` to represent Y. And when calling train_step(),
    you need to provide `y_distribution` which is the distribution P(Y|z).
    Note that 'DV' cannot be used for estimating conditional mutual information.
    See mi_estimator_test.py for example.
    """

    # Sentinel meaning "the caller did not pass `averager`". A default of
    # `ScalarAdaptiveAverager()` in the signature would be evaluated only once,
    # when the class is defined, so all estimators relying on the default would
    # share one averager (and its running mean). The sentinel lets each
    # estimator build its own while an explicit `averager=None` keeps its
    # meaning of "no moving average".
    _DEFAULT_AVERAGER = object()

    def __init__(self,
                 x_spec,
                 y_spec,
                 model=None,
                 fc_layers=(256, ),
                 sampler='buffer',
                 buffer_size=65536,
                 optimizer: tf.optimizers.Optimizer = None,
                 estimator_type='DV',
                 averager=_DEFAULT_AVERAGER,
                 name="MIEstimator"):
        """Create a MIEstimator.

        Args:
            x_spec (nested TensorSpec): spec of x
            y_spec (nested TensorSpec): spec of y
            model (Network): can be called as model([x, y]) and return a Tensor
                with shape=[batch_size, 1]. If None, a default MLP with
                fc_layers will be created.
            fc_layers (tuple[int]): size of hidden layers. Only used if model is
                None.
            sampler (str): type of sampler used to get samples from marginal
                distribution, should be one of ['buffer', 'double_buffer',
                'shuffle', 'shift']
            buffer_size (int): capacity of buffer for storing y for sampler
                'buffer' and 'double_buffer'
            optimizer (tf.optimizers.Optimizer): optimizer
            estimator_type (str): one of 'ML', 'DV', 'KLD' or 'JSD'
            averager (EMAverager): averager used to maintain a moving average
                of exp(T). Only used for 'DV' estimator. If not provided, a new
                ScalarAdaptiveAverager is created for this estimator. If None
                is passed explicitly, train_step() uses the plain batch mean
                instead and calc_pmi() cannot be used for 'DV'.
            name (str): name of this estimator
        Raises:
            TypeError: if `sampler` is not one of the supported names
        """
        assert estimator_type in ['ML', 'DV', 'KLD', 'JSD'
                                  ], "Wrong estimator_type %s" % estimator_type
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
        self._x_spec = x_spec
        self._y_spec = y_spec
        if model is None:
            if estimator_type == 'ML':
                # For 'ML' the model only encodes x; y enters through the
                # delta layers created below.
                model = TFAEncodingNetwork(
                    name="MIEstimator",
                    input_tensor_spec=x_spec,
                    fc_layer_params=fc_layers,
                    preprocessing_combiner=NestConcatenate(axis=-1))
            else:
                model = EncodingNetwork(
                    name="MIEstimator",
                    input_tensor_spec=[x_spec, y_spec],
                    fc_layer_params=fc_layers,
                    last_layer_size=1)
        self._model = model
        self._type = estimator_type
        if sampler == 'buffer':
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._buffer_sampler
        elif sampler == 'double_buffer':
            self._x_buffer = DataBuffer(x_spec, capacity=buffer_size)
            self._y_buffer = DataBuffer(y_spec, capacity=buffer_size)
            self._sampler = self._double_buffer_sampler
        elif sampler == 'shuffle':
            self._sampler = self._shuffle_sampler
        elif sampler == 'shift':
            self._sampler = self._shift_sampler
        else:
            raise TypeError("Wrong type for sampler %s" % sampler)

        if estimator_type == 'DV':
            if averager is self._DEFAULT_AVERAGER:
                # One averager per estimator, created only when it is needed.
                averager = ScalarAdaptiveAverager()
            self._mean_averager = averager
        if estimator_type == 'ML':
            assert isinstance(
                y_spec,
                tf.TensorSpec), ("Currently, 'ML' does "
                                 "not support nested y_spec: %s" % y_spec)
            assert tensor_spec.is_continuous(y_spec), (
                "Currently, 'ML' does "
                "not support discreted y_spec: %s" % y_spec)
            # Zero kernel and zero bias: the initial offset of the mean is 0.
            self._delta_loc_layer = tf.keras.layers.Dense(
                y_spec.shape[-1],
                kernel_initializer=tf.initializers.Zeros(),
                bias_initializer=tf.initializers.Zeros(),
                name='delta_loc_layer')
            # softplus(log(e - 1)) == 1, so with a zero kernel the initial
            # scale factor is 1.
            self._delta_scale_layer = tf.keras.layers.Dense(
                y_spec.shape[-1],
                kernel_initializer=tf.initializers.Zeros(),
                bias_initializer=tf.keras.initializers.Constant(
                    value=math.log(math.e - 1)),
                name='delta_scale_layer')

    def _buffer_sampler(self, x, y):
        """Return x together with y samples drawn from the y buffer.

        If the buffer already holds at least one batch worth of data, samples
        are drawn before the current y is stored; otherwise y is stored first.
        """
        batch_size = get_nest_batch_size(y)
        if self._y_buffer.current_size >= batch_size:
            y1 = self._y_buffer.get_batch(batch_size)
            self._y_buffer.add_batch(y)
        else:
            self._y_buffer.add_batch(y)
            y1 = self._y_buffer.get_batch(batch_size)
        # It seems that tf.stop_gradient() should be unnecessary. But somehow
        # TF will crash without this stop_gradient
        return x, tf.nest.map_structure(tf.stop_gradient, y1)

    def _double_buffer_sampler(self, x, y):
        """Store x and y in their buffers and return a batch from each."""
        batch_size = get_nest_batch_size(y)
        self._x_buffer.add_batch(x)
        x1 = self._x_buffer.get_batch(batch_size)
        self._y_buffer.add_batch(y)
        y1 = self._y_buffer.get_batch(batch_size)
        return x1, y1

    def _shuffle_sampler(self, x, y):
        """Return x together with y shuffled by tf.random.shuffle."""
        return x, tf.nest.map_structure(tf.random.shuffle, y)

    def _shift_sampler(self, x, y):
        """Return x together with y rotated by one sample along axis 0."""

        def _shift(y):
            return tf.concat([y[-1:, ...], y[0:-1, ...]], axis=0)

        return x, tf.nest.map_structure(_shift, y)

    def train_step(self, inputs, y_distribution=None, state=None):
        """Perform training on one batch of inputs.

        Args:
            inputs (tuple(nested Tensor, nested Tensor)): tuple of x and y
            y_distribution (nested tfp.distributions.Distribution): distribution
                for the marginal distribution of y. If None, will use the
                sampling method `sampler` provided at constructor to generate
                the samples for the marginal distribution of Y.
            state: not used
        Returns:
            AlgorithmStep
                outputs (Tensor): its mean is the estimated MI for estimator
                    'ML', 'DV' and 'KLD', and Jensen-Shannon divergence for
                    estimator 'JSD'
                state: not used
                info (LossInfo): info.loss is the loss
        """
        x, y = inputs

        if self._type == 'ML':
            return self._ml_step(x, y, y_distribution)

        # Collapse the outer dims so that model and sampler see a flat batch.
        num_outer_dims = get_outer_rank(x, self._x_spec)
        batch_squash = BatchSquash(num_outer_dims)
        x = batch_squash.flatten(x)
        y = batch_squash.flatten(y)
        if y_distribution is None:
            x1, y1 = self._sampler(x, y)
        else:
            x1 = x
            y1 = y_distribution.sample()
            y1 = batch_squash.flatten(y1)

        # T evaluated on the joint samples and on the marginal samples.
        log_ratio = self._model([x, y])[0]
        t1 = self._model([x1, y1])[0]

        if self._type == 'DV':
            # T is capped at 20 before exponentiation.
            ratio = tf.math.exp(tf.minimum(t1, 20))
            mean = tf.stop_gradient(tf.reduce_mean(ratio))
            if self._mean_averager:
                self._mean_averager.update(mean)
                unbiased_mean = tf.stop_gradient(self._mean_averager.get())
            else:
                unbiased_mean = mean
            # estimated MI = reduce_mean(mi)
            # ratio/mean-1 does not contribute to the final estimated MI, since
            # mean(ratio/mean-1) = 0. We add it so that we can have an estimation
            # of the variance of the MI estimator
            mi = log_ratio - (tf.math.log(mean) + ratio / mean - 1)
            loss = ratio / unbiased_mean - log_ratio
        elif self._type == 'KLD':
            ratio = tf.math.exp(tf.minimum(t1, 20))
            mi = log_ratio - ratio + 1
            loss = -mi
        elif self._type == 'JSD':
            mi = -tf.nn.softplus(-log_ratio) - tf.nn.softplus(t1) + math.log(4)
            loss = -mi
        # Restore the outer dims removed above.
        mi = batch_squash.unflatten(mi)
        loss = batch_squash.unflatten(loss)

        return AlgorithmStep(
            outputs=mi, state=(), info=LossInfo(loss, extra=()))

    def _ml_pmi(self, x, y, y_distribution):
        """Compute the 'ML' estimate of the pointwise mutual information.

        q(y|x) is a normal distribution whose loc is `y_distribution.loc`
        plus a learned offset and whose scale is `y_distribution.scale` times
        a learned positive factor, both computed from the encoding of x.

        Returns:
            Tensor: log q(y|x) - log P(y), with no gradient through log P(y)
        """
        num_outer_dims = get_outer_rank(x, self._x_spec)
        hidden = self._model(x)[0]
        batch_squash = BatchSquash(num_outer_dims)
        hidden = batch_squash.flatten(hidden)
        delta_loc = self._delta_loc_layer(hidden)
        # softplus keeps the multiplicative scale factor positive.
        delta_scale = tf.nn.softplus(self._delta_scale_layer(hidden))
        delta_loc = batch_squash.unflatten(delta_loc)
        delta_scale = batch_squash.unflatten(delta_scale)
        y_given_x_dist = tfp.distributions.Normal(
            loc=y_distribution.loc + delta_loc,
            scale=y_distribution.scale * delta_scale)

        # Because Normal.event_shape is [], the result of Normal.log_prob() is
        # the probabilities of individual dimensions. So we need to use
        # tfa_common.log_probability() instead.
        # TODO: implement a normal distribution with non-scalar event shape.
        pmi = tfa_common.log_probability(y_given_x_dist, y, self._y_spec)
        pmi -= tf.stop_gradient(
            tfa_common.log_probability(y_distribution, y, self._y_spec))
        return pmi

    def _ml_step(self, x, y, y_distribution):
        """Training step for 'ML': the loss is the negated pointwise MI."""
        pmi = self._ml_pmi(x, y, y_distribution)
        return AlgorithmStep(outputs=pmi, state=(), info=LossInfo(loss=-pmi))

    def calc_pmi(self, x, y, y_distribution=None):
        """Return estimated pointwise mutual information.

        The pointwise mutual information is defined as:
            log P(x|y)/P(x) = log P(y|x)/P(y)

        Args:
            x (tf.Tensor): x
            y (tf.Tensor): y
            y_distribution (tfp.distributions.Normal): needs to be provided for
                'ML' estimator.
        Returns:
            tf.Tensor: pointwise mutual information between x and y
        """
        if self._type == 'ML':
            assert y_distribution is not None, "y_distribution needs to be provided"
            return self._ml_pmi(x, y, y_distribution)
        log_ratio = self._model([x, y])[0]
        log_ratio = tf.squeeze(log_ratio, axis=-1)
        if self._type == 'DV':
            # Remove the log of the averaged mean kept by the averager.
            log_ratio -= tf.math.log(self._mean_averager.get())
        return log_ratio
Exemple #8
0
class MUSICAlgorithm(Algorithm):
    """Mutual Information State Intrinsic Control (MUSIC)
    Author: Rui Zhao
    Work done during a research internship at Horizon Robotics.
    The paper is accepted by International Conference on Learning Representations (ICLR) 2021 as a spotlight.

    This algorithm generates the intrinsic reward based on the mutual information
    estimation between the agent states and the surrounding states.

    See Zhao et al "Mutual Information State Intrinsic Control",
    https://openreview.net/forum?id=OthEq8I5v1
    """

    def __init__(self,
                 batch_size,
                 observation_spec,
                 action_spec,
                 sos_spec,
                 soa_spec,
                 split_observation_fn: Callable,
                 network: Network = None,
                 mi_r_scale=5000.0,
                 hidden_size=128,
                 buffer_size=100,
                 n_objects=1,
                 name="MUSICAlgorithm"):
        """Create an MUSICAlgorithm.

        Args:
            batch_size (int): batch size
            observation_spec (tf.TensorSpec): observation size
            action_spec (tf.TensorSpec): action size
            sos_spec (tf.TensorSpec): surrounding state size
            soa_spec (tf.TensorSpec): agent state size
            split_observation_fn (Callable): split observation function.
                The input is observation and action concatenated.
                The outputs are the agent states and the surrounding states
                (two outputs when ``n_objects < 2``, three outputs when
                ``n_objects == 2``).
            network (Network): network for estimating mutual information (MI)
            mi_r_scale (float): scale factor of MI estimation
            hidden_size (int): number of hidden units in neural nets
            buffer_size (int): buffer size for the data buffer storing the
                trajectories for training the Mutual Information Neural
                Estimator
            n_objects (int): number of objects for estimating the mutual
                information reward; at most 2 is supported
            name (str): the algorithm name, "MUSICAlgorithm"
        Raises:
            ValueError: if ``n_objects`` is greater than 2
        """

        super().__init__(
            train_state_spec=[observation_spec, action_spec], name=name)

        assert isinstance(observation_spec, tf.TensorSpec), \
            "does not support nested observation_spec"
        assert isinstance(action_spec, tf.TensorSpec), \
            "does not support nested action_spec"

        # Only the one-object and two-object layouts are implemented. Without
        # this check a larger value would silently yield an empty reward in
        # train_step() and an unbound local variable in calc_loss().
        if n_objects > 2:
            raise ValueError(
                "n_objects must be at most 2, got %s" % n_objects)

        if network is None:
            network = EncodingNetwork(
                input_tensor_spec=[soa_spec, sos_spec],
                fc_layer_params=(hidden_size, ),
                activation_fn='relu',
                last_layer_size=1,
                last_activation_fn='tanh')

        self._network = network

        # One buffer entry has shape [batch_size, observation_dim + action_dim].
        feature_dim = (observation_spec.shape.as_list()[0] +
                       action_spec.shape.as_list()[0])
        self._traj_spec = tf.TensorSpec(
            shape=[batch_size, feature_dim], dtype=observation_spec.dtype)
        self._buffer_size = buffer_size
        self._buffer = DataBuffer(self._traj_spec, capacity=self._buffer_size)
        self._mi_r_scale = mi_r_scale
        self._n_objects = n_objects
        self._split_observation_fn = split_observation_fn

    def _mine(self, x_in, y_in):
        """Mutual Information Neural Estimator.

        Implement mutual information neural estimator from
        Belghazi et al "Mutual Information Neural Estimation"
        http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf
        'DV':  sup_T E_P(T) - log E_Q(exp(T))
        where P is the joint distribution of X and Y, and Q is the product
         marginal distribution of P. DV is a lower bound for
         KLD(P||Q)=MI(X, Y).

        Args:
            x_in (Tensor): samples of X
            y_in (Tensor): samples of Y paired with ``x_in``
        Returns:
            Tensor: the DV bound, with means taken over axis 1 and the last
            axis squeezed. It is to be maximized, despite being called
            ``loss`` inside this method.
        """
        # Shuffle y to break its pairing with x, which gives samples from the
        # product of marginals. transpose2 brings the axis to be shuffled to
        # the front and back again.
        y_in_tran = transpose2(y_in, 1, 0)
        # tf.random.shuffle() has no gradient defined, so use tf.gather()
        y_shuffle_tran = tf.gather(
            y_in_tran, tf.random.shuffle(tf.range(tf.shape(y_in_tran)[0])))
        y_shuffle = transpose2(y_shuffle_tran, 1, 0)

        # propagate the forward pass
        T_xy, _ = self._network([x_in, y_in])
        T_x_y, _ = self._network([x_in, y_shuffle])

        # compute the negative loss (maximize loss == minimize -loss)
        mean_exp_T_x_y = tf.reduce_mean(tf.math.exp(T_x_y), axis=1)
        loss = tf.reduce_mean(T_xy, axis=1) - tf.math.log(mean_exp_T_x_y)
        loss = tf.squeeze(loss, axis=-1)  # Mutual Information

        return loss

    def _mine_losses(self, features):
        """Split ``features`` and return the list of MI estimates per object.

        The list has one element when ``n_objects < 2`` and two elements when
        ``n_objects == 2``.
        """
        parts = self._split_observation_fn(features)
        if self._n_objects < 2:
            obs_tau_excludes_goal, obs_tau_achieved_goal = parts
            return [self._mine(obs_tau_excludes_goal, obs_tau_achieved_goal)]
        (obs_tau_excludes_goal, obs_tau_achieved_goal_1,
         obs_tau_achieved_goal_2) = parts
        return [
            self._mine(obs_tau_excludes_goal, obs_tau_achieved_goal_1),
            self._mine(obs_tau_excludes_goal, obs_tau_achieved_goal_2),
        ]

    def train_step(self, inputs, state, calc_intrinsic_reward=True):
        """Compute the intrinsic reward for one step.

        Args:
            inputs (tuple): ``(feature_state, prev_action)``, i.e. the current
                observation and the previous action
            state (list[Tensor]): state for MUSIC (the ``feature_state`` and
                ``prev_action`` of the previous step)
            calc_intrinsic_reward (bool): if False, the current feature is not
                stored in the buffer and the returned reward is an empty tuple
        Returns:
            AlgorithmStep:
                outputs: empty tuple ()
                state: ``[feature_state, prev_action]`` of the current step
                info (MUSICInfo): ``reward`` is the intrinsic reward, or () if
                    ``calc_intrinsic_reward`` is False
        """
        feature_state, prev_action = inputs
        feature = tf.concat([feature_state, prev_action], axis=-1)
        prev_feature = tf.concat(state, axis=-1)

        # Stack previous and current feature along a new axis 1, forming a
        # two-step trajectory for every batch element.
        feature_reshaped = tf.expand_dims(feature, axis=1)
        prev_feature_reshaped = tf.expand_dims(prev_feature, axis=1)
        feature_pair = tf.concat([prev_feature_reshaped, feature_reshaped], 1)
        feature_reshaped_tran = transpose2(feature_reshaped, 1, 0)

        if calc_intrinsic_reward:
            self._buffer.add_batch(feature_reshaped_tran)

        # The estimator is evaluated even when no reward is requested, as in
        # the original control flow.
        losses = self._mine_losses(feature_pair)

        intrinsic_reward = ()
        if calc_intrinsic_reward:
            # scale/normalize the MUSIC intrinsic reward: every per-object
            # estimate is scaled and clipped to [0, 1] before being summed.
            intrinsic_reward = tf.clip_by_value(
                self._mi_r_scale * losses[0], 0, 1)
            for extra_loss in losses[1:]:
                intrinsic_reward = intrinsic_reward + tf.clip_by_value(
                    self._mi_r_scale * extra_loss, 0, 1)

        return AlgorithmStep(
            outputs=(),
            state=[feature_state, prev_action],
            info=MUSICInfo(reward=intrinsic_reward))

    def calc_loss(self, info: MUSICInfo):
        """Compute the training loss of the MI estimator from buffered data.

        Args:
            info (MUSICInfo): not used; the data is sampled from the buffer
        Returns:
            LossInfo: ``scalar_loss`` is the mean of the negated MI estimate
            (summed over objects), so minimizing it maximizes the estimate
        """
        feature_tau_sampled = self._buffer.get_batch(
            batch_size=self._buffer_size)
        feature_tau_sampled_tran = transpose2(feature_tau_sampled, 1, 0)

        losses = self._mine_losses(feature_tau_sampled_tran)
        loss = losses[0]
        for extra_loss in losses[1:]:
            loss = loss + extra_loss

        neg_loss = -loss
        neg_loss_scalar = tf.reduce_mean(neg_loss)
        return LossInfo(scalar_loss=neg_loss_scalar)