示例#1
0
    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        """Send the agent's training metrics to ``data_logger``.

        Every call records:
          * the actor and critic losses as scalars,
          * the items obtained by iterating each policy parameter, keyed by
            their running position as a string (presumably torch tensors, in
            which case each item is a slice along the first dimension),
          * histograms of the latest batch errors and batch value distribution.

        When ``full_log`` is set, the stored display distribution is also
        logged as a raw histogram over the critic's atom support.

        Args:
            data_logger: Sink receiving scalar values, value dicts and histograms.
            step: Global step the entries are recorded under.
            full_log: Whether to additionally emit the distribution histogram.
        """
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

        # One level of flattening: each parameter is iterated and every yielded
        # item gets its own positional key.
        flattened = itertools.chain.from_iterable(self.policy.parameters())
        policy_params = {}
        for position, value in enumerate(flattened):
            policy_params[str(position)] = value
        data_logger.log_values_dict("policy/param", policy_params, step)

        for tag, values in (('metric/batch_errors', self._metric_batch_error),
                            ('metric/batch_value_dist', self._metric_batch_value_dist)):
            data_logger.create_histogram(tag, values, step)

        if not full_log:
            return

        displayed = self._display_dist
        atoms = self.critic.z_atoms
        spacing = self.critic.z_delta
        # Bucket edges are the atoms shifted by the critic's atom spacing.
        data_logger.add_histogram(
            'dist/dist_value',
            min=atoms[0],
            max=atoms[-1],
            num=self.num_atoms,
            sum=displayed.sum(),
            sum_squares=displayed.pow(2).sum(),
            bucket_limits=atoms + spacing,
            bucket_counts=displayed,
            global_step=step,
        )
示例#2
0
    def log_metrics(self, data_logger: "DataLogger", step: int, full_log: bool = False):
        """Log the agent's loss and, on full logs, its value distributions and layer parameters.

        Args:
            data_logger: Sink receiving scalar values and histograms.
            step: Global step the entries are recorded under.
            full_log: When True, additionally log per-action distribution statistics
                (only if ``self.dist_probs`` is set) and weight/bias histograms for every
                layer of the value and advantage networks.
        """
        data_logger.log_value("loss/agent", self._loss, step)

        # This method isn't executed on every iteration, but distributions and weights are
        # only plotted on full logs anyway since doing so might be quite costly. Tread wisely.
        if not full_log:
            return

        if self.dist_probs is not None:
            z_atoms = self.z_atoms
            # Histogram support and bucket edges are the same for every action; compute once.
            hist_min, hist_max = z_atoms[0], z_atoms[-1]
            num_atoms = len(z_atoms)
            bucket_limits = z_atoms + self.z_delta
            for action_idx in range(self.action_size):
                # Only index 0 of the leading dim is plotted (presumably the first batch sample; confirm).
                dist = self.dist_probs[0, action_idx]
                # Probability-weighted sum over the atom support.
                data_logger.log_value(f'dist/expected_{action_idx}', (dist * z_atoms).sum().item(), step)
                data_logger.add_histogram(
                    f'dist/Q_{action_idx}', min=hist_min, max=hist_max, num=num_atoms,
                    sum=dist.sum(), sum_squares=dist.pow(2).sum(), bucket_limits=bucket_limits,
                    bucket_counts=dist, global_step=step,
                )

        # NOTE: weight tags differ between nets ("layer_weights_" vs "layer_"); kept unchanged so
        # previously logged series remain continuous.
        for net_name, layers, weight_tag in (
            ("value_net", self.net.value_net.layers, "layer_weights"),
            ("advantage_net", self.net.advantage_net.layers, "layer"),
        ):
            for idx, layer in enumerate(layers):
                # Layers like LayerNorm(elementwise_affine=False) register weight/bias as None,
                # so test the value, not merely the attribute's presence.
                weight = getattr(layer, "weight", None)
                if weight is not None:
                    data_logger.create_histogram(f"{net_name}/{weight_tag}_{idx}", weight.cpu(), step)
                bias = getattr(layer, "bias", None)
                if bias is not None:
                    data_logger.create_histogram(f"{net_name}/layer_bias_{idx}", bias.cpu(), step)