def _write_dict_summaries(dictionary, writer, name, env_steps):
    """Emit every leaf of ``dictionary`` to ``writer`` as a ``zz_pbt/`` scalar.

    Leaves are visited with ``iterate_recursively``. Each leaf is written at
    step ``env_steps`` under the tag ``zz_pbt/{name}_{key}``:

    * bools are cast to int first;
    * ints and floats are written directly;
    * tuples and lists are fanned out into one scalar per element, with the
      element index appended to the tag;
    * anything else is reported through ``log.error`` and skipped.
    """
    for _parent, leaf_key, leaf in iterate_recursively(dictionary):
        # bool is a subclass of int; cast explicitly so the writer receives 0/1.
        leaf = int(leaf) if isinstance(leaf, bool) else leaf

        if isinstance(leaf, (tuple, list)):
            # One scalar per element, tag suffixed by the element's index.
            for idx, item in enumerate(leaf):
                writer.add_scalar(f'zz_pbt/{name}_{leaf_key}_{idx}', item, env_steps)
        elif isinstance(leaf, (int, float)):
            writer.add_scalar(f'zz_pbt/{name}_{leaf_key}', leaf, env_steps)
        else:
            log.error('Unsupported type in pbt summaries %r', type(leaf))
Пример #2
0
    def process_report(self, report):
        """Fold one stats report from a worker into the running statistics.

        Every key of ``report`` is optional:

        * ``policy_id`` enables the per-policy sections, which handle the
          following keys:

          - ``learner_env_steps`` is stored in ``self.env_steps[policy_id]``.
            When a previous value exists, the difference is added to
            ``self.total_env_steps_since_resume``.
          - ``episodic`` leaves (walked with ``iterate_recursively``) are
            appended to ``self.policy_avg_stats[key][policy_id]``. That is a
            list of ``deque(maxlen=self.cfg.stats_avg)``, one deque for each of
            ``self.cfg.num_policies`` policies, created on first sight of the
            key. Each leaf is also passed to every callable in
            ``EXTRA_EPISODIC_STATS_PROCESSING`` as
            ``(policy_id, key, value, self.cfg)``.
          - ``train`` is forwarded to ``self.report_train_summaries``.
          - ``samples`` is added to ``self.samples_collected[policy_id]``.

        * ``timing`` entries are appended to a window in ``self.avg_stats``
          that holds at most 50 values per timer name.
        * ``stats`` is merged into ``self.stats`` via ``update``.
        """
        if 'policy_id' in report:
            pid = report['policy_id']

            if 'learner_env_steps' in report:
                latest_steps = report['learner_env_steps']
                # A delta only makes sense once this policy has a prior reading.
                if pid in self.env_steps:
                    self.total_env_steps_since_resume += latest_steps - self.env_steps[pid]
                self.env_steps[pid] = latest_steps

            if 'episodic' in report:
                for _, stat_name, stat_value in iterate_recursively(report['episodic']):
                    if stat_name not in self.policy_avg_stats:
                        # Lazily allocate one bounded window per policy for this stat.
                        self.policy_avg_stats[stat_name] = [
                            deque(maxlen=self.cfg.stats_avg) for _ in range(self.cfg.num_policies)
                        ]
                    self.policy_avg_stats[stat_name][pid].append(stat_value)

                    for hook in EXTRA_EPISODIC_STATS_PROCESSING:
                        hook(pid, stat_name, stat_value, self.cfg)

            if 'train' in report:
                self.report_train_summaries(report['train'], pid)

            if 'samples' in report:
                self.samples_collected[pid] += report['samples']

        if 'timing' in report:
            for timer_name, elapsed in report['timing'].items():
                if timer_name not in self.avg_stats:
                    self.avg_stats[timer_name] = deque([], maxlen=50)
                self.avg_stats[timer_name].append(elapsed)

        if 'stats' in report:
            self.stats.update(report['stats'])
Пример #3
0
def ensure_memory_shared(*tensors):
    """Assert that every tensor in the given (possibly nested) containers is in shared memory.

    Intended as a guard against programming errors. Each positional argument is
    walked with ``iterate_recursively``, and every leaf must report
    ``is_shared()`` as true.

    Raises:
        AssertionError: if a leaf is not shared; the message names its key.
            As with any ``assert``, the check is skipped when Python runs with ``-O``.
    """
    for tensor_dict in tensors:
        for _, key, t in iterate_recursively(tensor_dict):
            # Name the offending key: a bare assert gives no clue which tensor failed.
            assert t.is_shared(), f'Tensor {key!r} is not in shared memory'