コード例 #1
0
ファイル: AEMuZero.py プロジェクト: Tubbz-alt/muzero-2
    def loss_function(self, observations, actions, target_vs, target_rs, target_pis, target_observations,
                      sample_weights) -> typing.Tuple[tf.Tensor, typing.List]:
        """
        Overrides super function to compute the loss for decoding the unrolled latent-states back to true future
        observations.

        In addition to the reward, value and policy losses, every unroll step k adds a squared-error
        reconstruction loss between the predicted observation and the true observation at that step: the root
        observations for k = 0, otherwise target_observations[:, k - 1].

        :param observations: tf.Tensor in R^(batch_size x width x height x (depth * time)). Stacked state observations.
        :param actions: tf.Tensor in {0, 1}^(batch_size x K x |action_space|). One-hot encoded actions for unrolling.
        :param target_vs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_rs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_pis: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x |action_space|)
        :param target_observations: tf.Tensor of same dimensions of observations for each unroll step in axis 1.
        :param sample_weights: tf.Tensor in [0, 1]^(batch_size). Of the form (batch_size * priority) ^ (-beta)
        :return: tuple of a tf.Tensor and a list of tf.Tensors containing the total loss and piecewise losses.
                 Each list entry is (v_loss, r_loss, pi_loss, absorb, o_loss) for one unroll step.
        :see: MuNeuralNet.unroll
        """
        loss_monitor = []  # Collect losses for logging.

        # Sum over target probabilities. Absorbing states should have a zero sum --> leaf node.
        absorb_k = 1.0 - tf.reduce_sum(target_pis, axis=-1)

        # Root inference. Collect predictions of the form: [w_i / K, s, v, r, pi] for each forward step k = 0...K.
        # In this override the second element is consumed as the predicted observation (compared to t_obs below).
        predictions = self.unroll(observations, actions)

        # Perform loss computation for each unrolling step.
        total_loss = tf.constant(0.0, dtype=tf.float32)
        for k, (loss_scale, p_obs, vs, rs, pis) in enumerate(predictions):  # Length = 1 + K (root + forward steps)
            t_vs, t_rs, t_pis = target_vs[k, ...], target_rs[k, ...], target_pis[k, ...]
            absorb = absorb_k[k, :]

            # Decoder target observations: the root observation at k = 0, afterwards the observed future states.
            t_obs = observations if k == 0 else target_observations[:, (k - 1), ...]

            # Calculate losses per head. Cancel gradients in prior for absorbing states, keep gradients for r and v.
            r_loss = scalar_loss(rs, t_rs) if (k > 0 and self.fit_rewards) else tf.constant(0, dtype=tf.float32)
            v_loss = scalar_loss(vs, t_vs)
            pi_loss = scalar_loss(pis, t_pis) * (1.0 - absorb)

            # keras' mean_squared_error only averages over the last axis. Average over all remaining non-batch
            # axes so o_loss holds exactly one value per sample, whatever the observation rank. For the documented
            # rank-4 observations this equals a mean over axes (1, 2) of the keras output.
            obs_error = tf.keras.losses.mean_squared_error(t_obs, p_obs)
            o_loss = tf.reduce_mean(tf.reshape(obs_error, [tf.shape(obs_error)[0], -1]), axis=-1)

            step_loss = scale_gradient(r_loss + v_loss + pi_loss + o_loss, loss_scale * sample_weights)
            total_loss += tf.reduce_sum(step_loss)  # Actually averages over batch : see sample_weights.

            # Logging
            loss_monitor.append((v_loss, r_loss, pi_loss, absorb, o_loss))

        # Penalize magnitude of weights using l2 norm
        l2_norm = tf.reduce_sum([safe_l2norm(x) for x in self.get_variables()])
        total_loss += self.net_args.l2 * l2_norm

        return total_loss, loss_monitor
コード例 #2
0
    def log_batch(self, data_batch: typing.List) -> None:
        """
        Log a large amount of neural network statistics based on the given batch.
        Functionality can be toggled on by specifying '--debug' as a console argument to Main.py.
        Note: toggling this functionality on will produce significantly larger tensorboard event files!

        Logging only happens when DEBUG_MODE is set and the reference's step counter is a multiple of LOG_RATE.

        Statistics include:
         - Priority sampling sample probabilities.
         - Policy cross-entropy of each unrolled step per sample as a distribution.
         - Predicted and target reward/ value distributions and their mean squared error (in scalar space).
         - Norm of the neural network's weights.
         - Only if net_args.dynamics_penalty > 0: KL divergence between the dynamics and encoder latent states,
           entropy of the dynamics latent states and, if the network has a decoder, its squared error.

        :param data_batch: List of per-sample tuples (observation, actions, targets, forward_observations,
                           sample_weight), where targets is a (values, rewards, policies) triple. After stacking,
                           all target arrays are batch-first, i.e. indexed as target[:, k] for unroll step k.
        """
        if DEBUG_MODE and self.reference.steps % LOG_RATE == 0:
            observations, actions, targets, forward_observations, sample_weight = list(zip(*data_batch))
            actions, sample_weight = np.asarray(actions), np.asarray(sample_weight)
            target_vs, target_rs, target_pis = list(map(np.asarray, zip(*targets)))

            priority = sample_weight * len(data_batch)  # Undo 1/n scaling to get priority
            tf.summary.histogram("sample probability", data=priority, step=self.reference.steps)

            # Root inference (unroll step k = 0).
            s, pi, v = self.reference.neural_net.forward.predict_on_batch(np.asarray(observations))

            v_real = support_to_scalar(v, self.reference.net_args.support_size).ravel()

            tf.summary.histogram("v_predict_0", data=v_real, step=self.reference.steps)
            tf.summary.histogram("v_target_0", data=target_vs[:, 0], step=self.reference.steps)
            tf.summary.scalar("v_mse_0", data=np.mean((v_real - target_vs[:, 0]) ** 2), step=self.reference.steps)

            # Sum over target policies. If this sum is zero, there are no search statistics --> leaf node.
            # Unlike in the loss function, the targets here are batch-first, so absorb_k has shape
            # (batch_size x (K + 1)) and an unroll step must be selected along axis 1.
            absorb_k = 1.0 - np.sum(target_pis, axis=-1)

            # Recurrent inference: the t-th action moves the latent state to unroll step k = t + 1.
            collect = list()
            for t in range(actions.shape[1]):
                r, s, pi, v = self.reference.neural_net.recurrent.predict_on_batch([s, actions[:, t, :]])

                collect.append((s, v, r, pi, absorb_k[:, t + 1]))

            for t, (s, v, r, pi, absorb) in enumerate(collect):
                k = t + 1

                # Per-sample cross-entropy between target and predicted policy (epsilon guards log(0)).
                pi_loss = -np.sum(target_pis[:, k] * np.log(pi + 1e-8), axis=-1)
                self.log_distribution(pi_loss, f"pi_dist_{k}")

                v_real = support_to_scalar(v, self.reference.net_args.support_size).ravel()
                r_real = support_to_scalar(r, self.reference.net_args.support_size).ravel()

                self.log_distribution(r_real, f"r_predict_{k}")
                self.log_distribution(v_real, f"v_predict_{k}")

                self.log_distribution(target_rs[:, k], f"r_target_{k}")
                self.log_distribution(target_vs[:, k], f"v_target_{k}")

                self.log(np.mean((r_real - target_rs[:, k]) ** 2), f"r_mse_{k}")
                self.log(np.mean((v_real - target_vs[:, k]) ** 2), f"v_mse_{k}")

            l2_norm = tf.reduce_sum([safe_l2norm(x) for x in self.reference.get_variables()])
            self.log(l2_norm, "l2 norm")

            # Option to track statistical properties of the dynamics model.
            # NOTE(review): the decoder error below is also gated by dynamics_penalty > 0 — confirm intended.
            if self.reference.net_args.dynamics_penalty > 0:
                forward_observations = np.asarray(forward_observations)
                # Compute statistics related to auto-encoding state dynamics:
                for t, (s, v, r, pi, absorb) in enumerate(collect):
                    k = t + 1
                    # forward_observations[:, t] is the observation reached after t + 1 actions.
                    stacked_obs = forward_observations[:, t, ...]

                    s_enc = self.reference.neural_net.encoder.predict_on_batch(stacked_obs)
                    kl_divergence = tf.keras.losses.kullback_leibler_divergence(s_enc, s)

                    # Relative entropy of dynamics model and encoder.
                    # Lower values indicate that the prediction model receives more stable input.
                    self.log_distribution(kl_divergence, f"KL_Divergence_{k}")
                    self.log(np.mean(kl_divergence), f"Mean_KLDivergence_{k}")

                    # Internal entropy of the dynamics model: cross-entropy of s with itself.
                    s_entropy = tf.keras.losses.categorical_crossentropy(s, s)
                    self.log(np.mean(s_entropy), f"mean_dynamics_entropy_{k}")

                    if hasattr(self.reference.neural_net, "decoder"):
                        # If available, track the performance of a neural decoder from latent to real state.
                        stacked_obs_predict = self.reference.neural_net.decoder.predict_on_batch(s)
                        se = (stacked_obs - stacked_obs_predict) ** 2

                        self.log(np.mean(se), f"decoder_error_{k}")
コード例 #3
0
    def loss_function(self, observations, actions, target_vs, target_rs, target_pis,
                      target_observations, sample_weights) -> typing.Tuple[tf.Tensor, typing.List]:
        """
        Defines the computation graph for computing the loss of a MuZero model given data.

        The function recurrently unrolls the MuZero neural network based on data trajectories.
        From the collected output tensors, this function aggregates the loss for each prediction head for
        each unrolled time step k = 0, 1, ..., K.

        We expect target_pis/ MCTS probability vectors extrapolated beyond terminal states (i.e., no valid search
        statistics) to be a zero-vector. This is important as we infer the unrolling beyond terminal states by
        summing the MCTS probability vectors assuming that they should define proper distributions.

        For unrolled states beyond terminal environment states, we cancel the gradient for the probability vector.
        We keep gradients for the value and reward prediction, so that they learn to recognize terminal states
        during MCTS search. Note that the root state could be provided as a terminal state, this would mean that
        the probability vector head would receive zero gradient for the entire unrolling.

        If specified, the dynamics function will receive a slight differentiable penalty based on the
        target_observations and the predicted latent state by the encoder network.

        :param observations: tf.Tensor in R^(batch_size x width x height x (depth * time)). Stacked state observations.
        :param actions: tf.Tensor in {0, 1}^(batch_size x K x |action_space|). One-hot encoded actions for unrolling.
        :param target_vs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_rs: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x support_size)
        :param target_pis: tf.Tensor either in [0,1] or R with dimensions ((K + 1) x batch_size x |action_space|)
        :param target_observations: tf.Tensor of same dimensions of observations for each unroll step in axis 1.
        :param sample_weights: tf.Tensor in [0, 1]^(batch_size). Of the form (batch_size * priority) ^ (-beta)
        :return: tuple of a tf.Tensor and a list of tf.Tensors containing the total loss and piecewise losses.
                 Each list entry is (v_loss, r_loss, pi_loss, absorb) for one unroll step.
        :see: MuNeuralNet.unroll
        """
        loss_monitor = []  # Collect losses for logging.

        # Sum over target probabilities. Absorbing states should have a zero sum --> leaf node.
        absorb_k = 1.0 - tf.reduce_sum(target_pis, axis=-1)

        # Root inference. Collect predictions of the form: [w_i / K, s, v, r, pi] for each forward step k = 0...K
        predictions = self.unroll(observations, actions)

        # Perform loss computation for each unrolling step.
        total_loss = tf.constant(0.0, dtype=tf.float32)
        for k, (loss_scale, states, vs, rs, pis) in enumerate(predictions):  # Length = 1 + K (root + forward steps)
            t_vs, t_rs, t_pis = target_vs[k, ...], target_rs[k, ...], target_pis[k, ...]
            absorb = absorb_k[k, :]

            # Calculate losses per head. Cancel gradients in prior for absorbing states, keep gradients for r and v.
            r_loss = scalar_loss(rs, t_rs) if (k > 0 and self.fit_rewards) else tf.constant(0, dtype=tf.float32)
            v_loss = scalar_loss(vs, t_vs)
            pi_loss = scalar_loss(pis, t_pis) * (1.0 - absorb)

            step_loss = scale_gradient(r_loss + v_loss + pi_loss, loss_scale * sample_weights)
            total_loss += tf.reduce_sum(step_loss)  # Actually averages over batch : see sample_weights.

            # If specified, slightly regularize the dynamics model using the discrepancy between the abstract state
            # predicted by the dynamics model with the encoder. This penalty should be low to emphasize
            # value prediction, but may aid stability of learning.
            if self.net_args.dynamics_penalty > 0 and k > 0:
                # Infer latent states as predicted by the encoder and cancel the gradients for the encoder
                encoded_states = self.neural_net.encoder(target_observations[:, (k - 1), ...])
                encoded_states = tf.stop_gradient(encoded_states)

                # keras' mean_squared_error already averages over the last axis. Average over every remaining
                # non-batch axis so there is exactly one loss per sample (batch_size,), whatever the latent-state
                # rank. A single extra reduce_mean(axis=-1) is only per-sample for rank-3 states: for rank-2 states
                # it averages away the batch axis, so sample_weights would no longer weight individual samples.
                state_error = tf.keras.losses.mean_squared_error(states, encoded_states)
                contrastive_loss = tf.reduce_mean(tf.reshape(state_error, [tf.shape(state_error)[0], -1]), axis=-1)
                contrastive_loss = scale_gradient(contrastive_loss, loss_scale * sample_weights)

                total_loss += self.net_args.dynamics_penalty * tf.reduce_sum(contrastive_loss)

            # Logging
            loss_monitor.append((v_loss, r_loss, pi_loss, absorb))

        # Penalize magnitude of weights using l2 norm
        l2_norm = tf.reduce_sum([safe_l2norm(x) for x in self.get_variables()])
        total_loss += self.net_args.l2 * l2_norm

        return total_loss, loss_monitor