def run(self):
    """Drive the alternating train / eval iterations inside one TF1 session.

    A session is opened with soft device placement, the graph is
    initialized, and an initial collect is performed. Then, until the
    iteration metric reaches ``self._num_iterations``, each iteration:

    1. resets the train-phase metrics and runs training episodes until
       ``self._train_steps_per_iteration`` environment steps are reached,
       then logs and summarizes those metrics;
    2. reads the current global step;
    3. if ``self._do_eval`` is set, does the same for the eval metrics
       (invoking ``self._eval_metrics_callback`` with a name -> result
       dict when a callback is configured);
    4. advances the iteration metric, saves the train / policy / replay
       buffer checkpoints, and exports the policy together with the
       collect policy's trajectory spec.
    """
    # NOTE(review): the nesting below was reconstructed from a source that
    # had lost its line structure. The checkpoint/export steps are assumed
    # to run once per iteration (they use that iteration's global step) --
    # confirm against the original layout.
    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    with tf.compat.v1.Session(config=session_config) as sess:
        # Graph initialization must precede any collection.
        self._initialize_graph(sess)
        self._initial_collect()

        while self._iteration_metric.result() < self._num_iterations:
            # ---- Train phase ----
            for phase_metric in self._train_phase_metrics:
                phase_metric.reset()
            steps_taken = 0
            while steps_taken < self._train_steps_per_iteration:
                steps_taken += self._run_episode(
                    sess,
                    self._train_metrics + self._train_phase_metrics,
                    train=True)
            for phase_metric in self._train_phase_metrics:
                log_metric(phase_metric, prefix='Train/Metrics')
            py_metric.run_summaries(
                self._train_phase_metrics + [self._iteration_metric])

            # Read once; reused for the eval callback, checkpoints and export.
            global_step_val = sess.run(self._global_step)

            if self._do_eval:
                # ---- Eval phase ----
                for eval_metric in self._eval_metrics:
                    eval_metric.reset()
                steps_taken = 0
                while steps_taken < self._eval_steps_per_iteration:
                    steps_taken += self._run_episode(
                        sess, self._eval_metrics, train=False)
                py_metric.run_summaries(
                    self._eval_metrics + [self._iteration_metric])
                if self._eval_metrics_callback:
                    results = {
                        eval_metric.name: eval_metric.result()
                        for eval_metric in self._eval_metrics
                    }
                    self._eval_metrics_callback(results, global_step_val)
                for eval_metric in self._eval_metrics:
                    log_metric(eval_metric, prefix='Eval/Metrics')

            # Calling the iteration metric advances the outer loop's counter.
            self._iteration_metric()

            # Persist training state, all tagged with the same global step.
            self._train_checkpointer.save(global_step=global_step_val)
            self._policy_checkpointer.save(global_step=global_step_val)
            self._rb_checkpointer.save(global_step=global_step_val)

            # Export directory name carries the step zero-padded to 8 chars.
            step_label = ('%d' % global_step_val).zfill(8)
            export_dir = os.path.join(
                self._train_dir, 'saved_policy', 'step_' + step_label)
            self._policy_exporter.save(export_dir)
            spec_path = os.path.join(export_dir, 'trajectory_spec')
            common.save_spec(self._collect_policy.trajectory_spec, spec_path)
def test_save_and_load(self):
    """Round-trip a nested spec through `common.save_spec` / `load_spec`.

    The spec mixes plain and bounded tensor specs nested inside dicts,
    tuples and lists. After saving to a file under the test tmpdir and
    loading it back, the test checks that:

    * the top-level keys are the same,
    * the flattened structures have the same number of leaves, and
    * each pair of corresponding leaves agrees on shape and dtype.

    Bounds (minimum / maximum) of the bounded specs are not compared.
    """
    spec = {
        'spec_1':
            tensor_spec.TensorSpec((2, 3), tf.int32),
        'bounded_spec_1':
            tensor_spec.BoundedTensorSpec((2, 3), tf.float32, -10, 10),
        'bounded_spec_2':
            tensor_spec.BoundedTensorSpec((2, 3), tf.int8, -10, -10),
        'bounded_array_spec_3':
            tensor_spec.BoundedTensorSpec((2,), tf.int32, [-10, -10],
                                          [10, 10]),
        'bounded_array_spec_4':
            tensor_spec.BoundedTensorSpec((2,), tf.float16, [-10, -9],
                                          [10, 9]),
        'dict_spec': {
            'spec_2':
                tensor_spec.TensorSpec((2, 3), tf.float32),
            'bounded_spec_2':
                tensor_spec.BoundedTensorSpec((2, 3), tf.int16, -10, 10),
        },
        'tuple_spec': (
            tensor_spec.TensorSpec((2, 3), tf.int32),
            tensor_spec.BoundedTensorSpec((2, 3), tf.float64, -10, 10),
        ),
        'list_spec': [
            tensor_spec.TensorSpec((2, 3), tf.int64),
            (tensor_spec.TensorSpec((2, 3), tf.float32),
             tensor_spec.BoundedTensorSpec((2, 3), tf.float32, -10, 10)),
        ],
    }

    spec_save_path = os.path.join(flags.FLAGS.test_tmpdir, 'spec.tfrecord')
    common.save_spec(spec, spec_save_path)

    loaded_spec_nest = common.load_spec(spec_save_path)

    self.assertAllEqual(sorted(spec.keys()), sorted(loaded_spec_nest.keys()))

    expected_leaves = tf.nest.flatten(spec)
    loaded_leaves = tf.nest.flatten(loaded_spec_nest)
    # zip() stops at the shorter sequence, so without this check a loaded
    # spec with missing (or extra) nested leaves would pass unnoticed.
    self.assertEqual(len(expected_leaves), len(loaded_leaves))

    for expected_spec, loaded_spec in zip(expected_leaves, loaded_leaves):
        self.assertAllEqual(expected_spec.shape, loaded_spec.shape)
        self.assertEqual(expected_spec.dtype, loaded_spec.dtype)