예제 #1
0
    def unroll(self, observations: tf.Tensor, actions: tf.Tensor) -> \
            typing.List[typing.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]]:
        """
        Overrides the base unroll so that every latent state is additionally decoded back to observation space.

        Each latent state is passed through ``scale_gradient`` with ``self.net_args.dynamics_penalty`` before it
        reaches the decoder. This presumably scales the decoder gradient that flows back into the
        representation/dynamics networks by that penalty (see ``scale_gradient``).

        :param observations: tf.Tensor in R^(batch_size x width x height x (depth * time))
        :param actions: tf.Tensor consisting of one-hot-encoded actions in {0, 1}^(batch_size x K x |action_space|)
        :return: List of 1 + K tuples (loss_scale, decoded_observation, value, reward, policy). The root tuple has
                 loss_scale 1.0 and reward 0; each subsequent tuple has loss_scale 1 / K.
        """
        neural_net = self.neural_net
        unroll_steps = actions.shape[1]

        def decode(latent):
            # Put a gradient-scaling layer between the latent state and the decoder so that the decoder's
            # gradient into the dynamics model is modulated by the dynamics penalty.
            return neural_net.decoder(scale_gradient(latent, self.net_args.dynamics_penalty))

        # Root inference. The root may itself be terminal, so its heads get weight 1.0 rather than 1 / K.
        latent, root_pi, root_v = neural_net.forward(observations)
        predictions = [(1.0, decode(latent), root_v, 0, root_pi)]

        for step in range(unroll_steps):
            reward, latent, pi, value = neural_net.recurrent([latent, actions[:, step, :]])
            predictions.append((1.0 / unroll_steps, decode(latent), value, reward, pi))

            # Halve the gradient at the entry of the next dynamics step.
            latent = scale_gradient(latent, 0.5)

        return predictions
예제 #2
0
    def loss_function(self, observations, actions, target_vs, target_rs, target_pis, target_observations,
                      sample_weights) -> typing.Tuple[tf.Tensor, typing.List]:
        """
        Overrides super function to additionally compute the loss for decoding the unrolled latent states back
        to the true (future) observations.

        For every unrolled step k = 0, ..., K the following losses are summed, and their gradients are scaled by
        ``loss_scale * sample_weights``:

        - the value loss;
        - the reward loss (only for k > 0 and only if ``self.fit_rewards``);
        - the policy loss, masked out for absorbing states (target policy sums to zero);
        - the reconstruction (mean squared error) loss.

        The reconstruction target is ``observations`` at k = 0 and ``target_observations[:, k - 1]`` at k > 0.
        The reconstruction loss is averaged over all non-batch axes, so observations of any rank are supported.

        :param observations: tf.Tensor in R^(batch_size x width x height x (depth * time)). Stacked state observations.
        :param actions: tf.Tensor in {0, 1}^(batch_size x K x |action_space|). One-hot encoded actions for unrolling.
        :param target_vs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_rs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_pis: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x |action_space|)
        :param target_observations: tf.Tensor of same dimensions of observations for each unroll step in axis 1.
        :param sample_weights: tf.Tensor in [0, 1]^(batch_size). Of the form (batch_size * priority) ^ (-beta)
        :return: tuple of a tf.Tensor and a list of tf.Tensors containing the total loss and piecewise losses.
                 Each list entry is (v_loss, r_loss, pi_loss, absorb, o_loss) for one unrolled step.
        :see: MuNeuralNet.unroll
        """
        loss_monitor = []  # Collect losses for logging.

        # Sum over target probabilities. Absorbing states should have a zero sum --> leaf node.
        absorb_k = 1.0 - tf.reduce_sum(target_pis, axis=-1)

        # Root inference. Collect predictions of the form: [w_i / K, o_k, v, r, pi] for each forward step k = 0...K
        predictions = self.unroll(observations, actions)

        # Perform loss computation for each unrolling step (1 + K entries: root + hypothetical forward steps).
        total_loss = tf.constant(0.0, dtype=tf.float32)
        for k, (loss_scale, p_obs, vs, rs, pis) in enumerate(predictions):
            t_vs, t_rs, t_pis = target_vs[k, ...], target_rs[k, ...], target_pis[k, ...]
            absorb = absorb_k[k, :]

            # Decoder target: the root reconstructs the input observations, step k > 0 the (k - 1)-th future one.
            t_obs = observations if k == 0 else target_observations[:, (k - 1), ...]

            # Calculate losses per head. Cancel gradients in prior for absorbing states, keep gradients for r and v.
            r_loss = scalar_loss(rs, t_rs) if (k > 0 and self.fit_rewards) else tf.constant(0, dtype=tf.float32)
            v_loss = scalar_loss(vs, t_vs)
            pi_loss = scalar_loss(pis, t_pis) * (1.0 - absorb)

            # mean_squared_error only averages the last axis. Average every remaining non-batch axis as well so
            # o_loss has shape (batch_size,) for observations of any rank. For 4-D observations this equals the
            # former axis=(1, 2); for vector observations it is a no-op instead of an invalid-axis error.
            squared_error = tf.keras.losses.mean_squared_error(t_obs, p_obs)
            o_loss = tf.reduce_mean(squared_error, axis=tf.range(1, tf.rank(squared_error)))

            step_loss = scale_gradient(r_loss + v_loss + pi_loss + o_loss, loss_scale * sample_weights)
            total_loss += tf.reduce_sum(step_loss)  # Actually averages over batch : see sample_weights.

            # Logging
            loss_monitor.append((v_loss, r_loss, pi_loss, absorb, o_loss))

        # Penalize magnitude of weights using l2 norm
        l2_norm = tf.reduce_sum([safe_l2norm(x) for x in self.get_variables()])
        total_loss += self.net_args.l2 * l2_norm

        return total_loss, loss_monitor
예제 #3
0
    def unroll(self, observations: tf.Tensor, actions: tf.Tensor) -> \
            typing.List[typing.Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]]:
        """
        Recurrently unroll the MuZero model and collect the output tensors of every step.

        The representation network (``forward``) is applied once to the observations to obtain the root. After
        that, the dynamics network (``recurrent``) is applied once per action slice ``actions[:, k, :]``. After
        each recurrent step the latent state is passed through ``scale_gradient(..., 0.5)``, which is intended to
        halve the gradient flowing back through the recurrence.

        :param observations: tf.Tensor in R^(batch_size x width x height x (depth * time))
        :param actions: tf.Tensor consisting of one-hot-encoded actions in {0, 1}^(batch_size x K x |action_space|)
        :return: List of 1 + K tuples (loss_scale, latent_state, value, reward, policy). The root tuple has
                 loss_scale 1.0 and reward 0; each subsequent tuple has loss_scale 1 / K.
        """
        horizon = actions.shape[1]

        # The root may be a terminal state; its heads are weighted by 1.0 instead of 1 / K.
        state, root_pi, root_v = self.neural_net.forward(observations)
        outputs = [(1.0, state, root_v, 0, root_pi)]

        for step in range(horizon):
            reward, state, policy, value = self.neural_net.recurrent([state, actions[:, step, :]])
            outputs.append((1.0 / horizon, state, value, reward, policy))

            # Halve the gradient at the entry of the next dynamics step.
            state = scale_gradient(state, 0.5)

        return outputs
예제 #4
0
    def loss_function(self, observations, actions, target_vs, target_rs, target_pis,
                      target_observations, sample_weights) -> typing.Tuple[tf.Tensor, typing.List]:
        """
        Defines the computation graph for computing the loss of a MuZero model given data.

        The function recurrently unrolls the MuZero neural network based on data trajectories.
        From the collected output tensors, this function aggregates the loss for each prediction head for
        each unrolled time step k = 0, 1, ..., K.

        We expect target_pis/ MCTS probability vectors extrapolated beyond terminal states (i.e., no valid search
        statistics) to be a zero-vector. This is important as we infer the unrolling beyond terminal states by
        summing the MCTS probability vectors assuming that they should define proper distributions.

        For unrolled states beyond terminal environment states, we cancel the gradient for the probability vector.
        We keep gradients for the value and reward prediction, so that they learn to recognize terminal states
        during MCTS search. Note that the root state could be provided as a terminal state, this would mean that
        the probability vector head would receive zero gradient for the entire unrolling.

        If ``self.net_args.dynamics_penalty > 0``, the dynamics function additionally receives a penalty for each
        step k > 0. The penalty is the per-sample mean squared discrepancy between the latent state predicted by
        the dynamics network and the encoder's (gradient-stopped) encoding of ``target_observations[:, k - 1]``.
        It is weighted by ``loss_scale * sample_weights`` like the other heads.

        :param observations: tf.Tensor in R^(batch_size x width x height x (depth * time)). Stacked state observations.
        :param actions: tf.Tensor in {0, 1}^(batch_size x K x |action_space|). One-hot encoded actions for unrolling.
        :param target_vs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_rs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_pis: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x |action_space|)
        :param target_observations: tf.Tensor of same dimensions of observations for each unroll step in axis 1.
        :param sample_weights: tf.Tensor in [0, 1]^(batch_size). Of the form (batch_size * priority) ^ (-beta)
        :return: tuple of a tf.Tensor and a list of tf.Tensors containing the total loss and piecewise losses.
                 Each list entry is (v_loss, r_loss, pi_loss, absorb) for one unrolled step.
        :see: MuNeuralNet.unroll
        """
        loss_monitor = []  # Collect losses for logging.

        # Sum over target probabilities. Absorbing states should have a zero sum --> leaf node.
        absorb_k = 1.0 - tf.reduce_sum(target_pis, axis=-1)

        # Root inference. Collect predictions of the form: [w_i / K, s, v, r, pi] for each forward step k = 0...K
        predictions = self.unroll(observations, actions)

        # Perform loss computation for each unrolling step (1 + K entries: root + hypothetical forward steps).
        total_loss = tf.constant(0.0, dtype=tf.float32)
        for k, (loss_scale, states, vs, rs, pis) in enumerate(predictions):
            t_vs, t_rs, t_pis = target_vs[k, ...], target_rs[k, ...], target_pis[k, ...]
            absorb = absorb_k[k, :]

            # Calculate losses per head. Cancel gradients in prior for absorbing states, keep gradients for r and v.
            r_loss = scalar_loss(rs, t_rs) if (k > 0 and self.fit_rewards) else tf.constant(0, dtype=tf.float32)
            v_loss = scalar_loss(vs, t_vs)
            pi_loss = scalar_loss(pis, t_pis) * (1.0 - absorb)

            step_loss = scale_gradient(r_loss + v_loss + pi_loss, loss_scale * sample_weights)
            total_loss += tf.reduce_sum(step_loss)  # Actually averages over batch : see sample_weights.

            # If specified, slightly regularize the dynamics model using the discrepancy between the abstract state
            # predicted by the dynamics model with the encoder. This penalty should be low to emphasize
            # value prediction, but may aid stability of learning.
            if self.net_args.dynamics_penalty > 0 and k > 0:
                # Infer latent states as predicted by the encoder and cancel the gradients for the encoder.
                encoded_states = tf.stop_gradient(self.neural_net.encoder(target_observations[:, (k - 1), ...]))

                # Reduce to exactly one value per sample, shape (batch_size,), so that the loss aligns elementwise
                # with sample_weights like the other heads. Flat latents keep their per-sample weighting, and
                # spatial latents no longer leave extra axes.
                squared_error = tf.keras.losses.mean_squared_error(states, encoded_states)
                contrastive_loss = tf.reduce_mean(squared_error, axis=tf.range(1, tf.rank(squared_error)))
                contrastive_loss = scale_gradient(contrastive_loss, loss_scale * sample_weights)

                total_loss += self.net_args.dynamics_penalty * tf.reduce_sum(contrastive_loss)

            # Logging
            loss_monitor.append((v_loss, r_loss, pi_loss, absorb))

        # Penalize magnitude of weights using l2 norm
        l2_norm = tf.reduce_sum([safe_l2norm(x) for x in self.get_variables()])
        total_loss += self.net_args.l2 * l2_norm

        return total_loss, loss_monitor