Exemple #1
0
        def run(self):
            """Drive the alternating train/eval iterations inside a TF1 session.

            The graph is initialized and the initial collect is run first. The
            outer loop then repeats while the iteration metric's result is
            below ``self._num_iterations``. Each pass:

            1. resets the train-phase metrics and runs training episodes until
               the accumulated return values of ``_run_episode`` reach
               ``self._train_steps_per_iteration``, then logs and summarizes;
            2. reads the current global step value from the session;
            3. if ``self._do_eval`` is set, does the same with the eval metrics
               and, when a callback is configured, hands it a
               ``{metric.name: metric.result()}`` dict plus the global step;
            4. calls the iteration metric, saves the train / policy / replay
               buffer checkpointers, and exports the policy together with the
               collect policy's trajectory spec under
               ``<train_dir>/saved_policy/step_<zero-padded global step>``.
            """
            session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
            with tf.compat.v1.Session(config=session_config) as sess:
                self._initialize_graph(sess)
                self._initial_collect()

                while self._iteration_metric.result() < self._num_iterations:
                    # ---- Train phase ----
                    for train_metric in self._train_phase_metrics:
                        train_metric.reset()
                    steps_done = 0
                    while steps_done < self._train_steps_per_iteration:
                        steps_done += self._run_episode(
                            sess,
                            self._train_metrics + self._train_phase_metrics,
                            train=True)
                    for train_metric in self._train_phase_metrics:
                        log_metric(train_metric, prefix='Train/Metrics')
                    py_metric.run_summaries(
                        self._train_phase_metrics + [self._iteration_metric])

                    # Read once after training; the same value tags the eval
                    # callback, the checkpoints and the export directory.
                    step_now = sess.run(self._global_step)

                    if self._do_eval:
                        # ---- Eval phase ----
                        for eval_metric in self._eval_metrics:
                            eval_metric.reset()
                        steps_done = 0
                        while steps_done < self._eval_steps_per_iteration:
                            steps_done += self._run_episode(
                                sess, self._eval_metrics, train=False)

                        py_metric.run_summaries(
                            self._eval_metrics + [self._iteration_metric])
                        if self._eval_metrics_callback:
                            eval_results = {
                                eval_metric.name: eval_metric.result()
                                for eval_metric in self._eval_metrics
                            }
                            self._eval_metrics_callback(eval_results, step_now)
                        for eval_metric in self._eval_metrics:
                            log_metric(eval_metric, prefix='Eval/Metrics')

                    # Presumably advances the counter that the outer loop
                    # condition reads -- confirm against the metric's class.
                    self._iteration_metric()

                    for checkpointer in (self._train_checkpointer,
                                         self._policy_checkpointer,
                                         self._rb_checkpointer):
                        checkpointer.save(global_step=step_now)

                    # Directory name is the global step left-padded to 8 chars.
                    step_tag = ('%d' % step_now).zfill(8)
                    export_dir = os.path.join(
                        self._train_dir, 'saved_policy', 'step_' + step_tag)
                    self._policy_exporter.save(export_dir)
                    common.save_spec(
                        self._collect_policy.trajectory_spec,
                        os.path.join(export_dir, 'trajectory_spec'))
Exemple #2
0
    def test_save_and_load(self):
        """Round-trip a nested spec through ``common.save_spec``/``load_spec``.

        The spec nest mixes plain and bounded tensor specs inside a dict, a
        tuple and a list. After saving to a file under the test tmpdir and
        loading it back, the test checks that:

        * the top-level keys match,
        * both nests flatten to the same number of leaves, and
        * each pair of leaves agrees on ``shape`` and ``dtype``.

        Bounds (minimum/maximum) of the bounded specs are not compared here.
        """
        spec = {
            'spec_1':
            tensor_spec.TensorSpec((2, 3), tf.int32),
            'bounded_spec_1':
            tensor_spec.BoundedTensorSpec((2, 3), tf.float32, -10, 10),
            'bounded_spec_2':
            tensor_spec.BoundedTensorSpec((2, 3), tf.int8, -10, -10),
            'bounded_array_spec_3':
            tensor_spec.BoundedTensorSpec((2, ), tf.int32, [-10, -10],
                                          [10, 10]),
            'bounded_array_spec_4':
            tensor_spec.BoundedTensorSpec((2, ), tf.float16, [-10, -9],
                                          [10, 9]),
            'dict_spec': {
                'spec_2':
                tensor_spec.TensorSpec((2, 3), tf.float32),
                'bounded_spec_2':
                tensor_spec.BoundedTensorSpec((2, 3), tf.int16, -10, 10)
            },
            'tuple_spec': (
                tensor_spec.TensorSpec((2, 3), tf.int32),
                tensor_spec.BoundedTensorSpec((2, 3), tf.float64, -10, 10),
            ),
            'list_spec': [
                tensor_spec.TensorSpec((2, 3), tf.int64),
                (tensor_spec.TensorSpec((2, 3), tf.float32),
                 tensor_spec.BoundedTensorSpec((2, 3), tf.float32, -10, 10)),
            ],
        }

        spec_save_path = os.path.join(flags.FLAGS.test_tmpdir, 'spec.tfrecord')
        common.save_spec(spec, spec_save_path)

        loaded_spec_nest = common.load_spec(spec_save_path)

        self.assertAllEqual(sorted(spec.keys()),
                            sorted(loaded_spec_nest.keys()))

        expected_leaves = tf.nest.flatten(spec)
        loaded_leaves = tf.nest.flatten(loaded_spec_nest)

        # zip() stops at the shorter sequence, so without this check a loaded
        # nest with missing (or extra) leaves would pass the loop below.
        self.assertEqual(len(expected_leaves), len(loaded_leaves))

        for expected_spec, loaded_spec in zip(expected_leaves, loaded_leaves):
            self.assertAllEqual(expected_spec.shape, loaded_spec.shape)
            self.assertEqual(expected_spec.dtype, loaded_spec.dtype)