def test_restarts_after_interruption(self):
  # Run some initial training first.
  self.run_single_mode(
      TEST_MODE_1, target_local_step=TEST_NUM_STEPS,
      target_global_step=TEST_NUM_STEPS
  )
  global_step = TEST_NUM_STEPS

  restarter = Restarter(TEST_MODE_2, self.out_dir, target_local_step=2)
  with self.assertRaises(RuntimeError):
    global_step += 1
    with restarter.training_loop():
      self.create_checkpoint(global_step)
      # Simulate training interruption after the first step.
      raise RuntimeError

  restarter = Restarter(TEST_MODE_2, self.out_dir, target_local_step=2)
  self.assertFalse(restarter.should_skip)
  self.assertTrue(restarter.restarting)
  # Training should resume after the first step.
  self.assertEqual(restarter.steps_to_go, 1)
def train_world_model(
    env, data_dir, output_dir, hparams, world_model_steps_num, epoch
):
  """Train the world model on data from env; return the updated step count."""
  world_model_steps_num += world_model_step_increment(
      hparams, is_initial_epoch=(epoch == 0)
  )
  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  model_hparams.learning_rate = model_hparams.learning_rate_constant
  if epoch > 0:
    model_hparams.learning_rate *= hparams.learning_rate_bump
  if hparams.wm_policy_param_sharing:
    model_hparams.optimizer_zero_grads = True

  restarter = Restarter("world_model", output_dir, world_model_steps_num)
  if restarter.should_skip:
    return world_model_steps_num

  with restarter.training_loop():
    train_supervised(
        problem=env,
        model_name=hparams.generative_model,
        hparams=model_hparams,
        data_dir=data_dir,
        output_dir=output_dir,
        train_steps=restarter.target_global_step,
        eval_steps=100,
        local_eval_frequency=2000
    )

  return world_model_steps_num
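# Hypothetical driver sketch, not part of the original file: it only
# illustrates how the step count returned by train_world_model would be
# threaded through successive epochs so that the Restarter sees an ever
# increasing global target. `num_epochs`, `env`, `data_dir`, `output_dir` and
# `hparams` are assumed to come from the surrounding training setup.
world_model_steps_num = 0
for epoch in range(num_epochs):
  world_model_steps_num = train_world_model(
      env, data_dir, output_dir, hparams, world_model_steps_num, epoch
  )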
def test_runs_in_single_mode(self):
  restarter = Restarter(
      TEST_MODE_1, self.out_dir, target_local_step=TEST_NUM_STEPS
  )
  self.assert_first_run(
      restarter, steps_to_go=TEST_NUM_STEPS, target_global_step=TEST_NUM_STEPS
  )
def test_runs_in_two_modes(self):
  global_step = TEST_NUM_STEPS
  local_steps = {TEST_MODE_1: TEST_NUM_STEPS, TEST_MODE_2: 0}
  self.run_single_mode(TEST_MODE_1, local_steps[TEST_MODE_1], global_step)

  # Alternate between the two modes; each run should continue from the
  # shared global step counter.
  for mode in [TEST_MODE_2, TEST_MODE_1]:
    global_step += TEST_NUM_STEPS
    local_steps[mode] += TEST_NUM_STEPS
    restarter = Restarter(
        mode, self.out_dir, target_local_step=local_steps[mode]
    )
    self.assert_first_run(
        restarter, steps_to_go=TEST_NUM_STEPS, target_global_step=global_step
    )
    with restarter.training_loop():
      self.create_checkpoint(global_step)
def test_skips_already_done(self):
  self.run_single_mode(
      TEST_MODE_1, target_local_step=TEST_NUM_STEPS,
      target_global_step=TEST_NUM_STEPS
  )

  restarter = Restarter(
      TEST_MODE_1, self.out_dir, target_local_step=TEST_NUM_STEPS
  )
  # We should skip the training as those steps are already completed.
  self.assertTrue(restarter.should_skip)
def run_single_mode(self, mode, target_local_step, target_global_step):
  restarter = Restarter(mode, self.out_dir, target_local_step)
  with restarter.training_loop():
    self.create_checkpoint(target_global_step)
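# Assumed test fixture sketch, not copied from the original test file: the
# test methods above would live in a test case roughly shaped like this. The
# mode names and step count below are illustrative placeholders; the real
# suite also defines create_checkpoint (writes a checkpoint at a given global
# step) and assert_first_run (checks should_skip, restarting, steps_to_go and
# target_global_step on a fresh Restarter), which are omitted here.
import shutil
import tempfile

import tensorflow as tf

TEST_MODE_1 = "mode1"    # assumed value
TEST_MODE_2 = "mode2"    # assumed value
TEST_NUM_STEPS = 2       # assumed value


class RestarterTest(tf.test.TestCase):

  def setUp(self):
    self.out_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.out_dir)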
def train(self, env_fn, hparams, simulated, save_continuously, epoch,
          sampling_temp=1.0, num_env_steps=None, env_step_multiplier=1,
          eval_env_fn=None, report_fn=None, model_save_fn=None):
  assert sampling_temp == 1.0 or hparams.learning_rate == 0.0, \
      "Sampling with non-1 temperature does not make sense during training."

  if not save_continuously:
    # We do not save the model, as that resets frames that we need at
    # restarts. But we need to save at the last step, so we set it very high.
    hparams.save_models_every_epochs = 1000000

  if simulated:
    simulated_str = "sim"
  else:
    simulated_str = "real"
  name_scope = "ppo_{}{}".format(simulated_str, epoch + 1)
  event_dir = os.path.join(
      self.base_event_dir, "ppo_summaries", str(epoch) + simulated_str
  )

  with tf.Graph().as_default():
    with tf.name_scope(name_scope):
      with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        env = env_fn(in_graph=True)
        (train_summary_op, eval_summary_op, initializers) = _define_train(
            env, hparams, eval_env_fn, sampling_temp,
            distributional_size=self._distributional_size,
            distributional_subscale=self._distributional_subscale,
            distributional_threshold=self._distributional_threshold,
            epoch=epoch if simulated else -1,
            frame_stack_size=self.frame_stack_size,
            force_beginning_resets=simulated
        )

      if num_env_steps is None:
        iteration_increment = hparams.epochs_num
      else:
        iteration_increment = int(math.ceil(
            num_env_steps / (env.batch_size * hparams.epoch_length)
        ))
      iteration_increment *= env_step_multiplier

      self._num_completed_iterations += iteration_increment

      restarter = Restarter(
          "policy", self.agent_model_dir, self._num_completed_iterations
      )
      if restarter.should_skip:
        return

      if hparams.lr_decay_in_final_epoch:
        if epoch != self.total_num_epochs - 1:
          # Extend the warmup period to the end of this epoch.
          hparams.learning_rate_warmup_steps = restarter.target_global_step
        else:
          if self._lr_decay_start is None:
            # Stop the warmup at the beginning of this epoch.
            self._lr_decay_start = (
                restarter.target_global_step - iteration_increment
            )
          hparams.learning_rate_warmup_steps = self._lr_decay_start

      _run_train(hparams, event_dir, self.agent_model_dir, restarter,
                 train_summary_op, eval_summary_op, initializers, epoch,
                 report_fn=report_fn, model_save_fn=model_save_fn)
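# Worked example (hypothetical numbers) for the iteration_increment arithmetic
# above, assuming Python 3 true division: with num_env_steps=10000,
# env.batch_size=16 and hparams.epoch_length=50, one PPO iteration consumes
# 16 * 50 = 800 environment steps, so the increment rounds up to 13.
import math

print(int(math.ceil(10000 / (16 * 50))))  # -> 13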