def train_world_model(
    env, data_dir, output_dir, hparams, world_model_steps_num, epoch
):
  """Train the world model for one epoch.

  Args:
    env: Problem/environment providing the training data.
    data_dir: Directory with the generated training data.
    output_dir: Directory for checkpoints; also used by the Restarter.
    hparams: Top-level loop hparams (holds the generative model config).
    world_model_steps_num: Global step count reached so far.
    epoch: Current epoch index; epoch 0 gets the initial step increment.

  Returns:
    The updated world-model global step count.
  """
  is_initial = epoch == 0
  total_steps = world_model_steps_num + world_model_step_increment(
      hparams, is_initial_epoch=is_initial
  )

  wm_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  # Base learning rate; bumped after the first epoch.
  learning_rate = wm_hparams.learning_rate_constant
  if not is_initial:
    learning_rate *= hparams.learning_rate_bump
  wm_hparams.learning_rate = learning_rate
  if hparams.wm_policy_param_sharing:
    wm_hparams.optimizer_zero_grads = True

  restarter = Restarter("world_model", output_dir, total_steps)
  if not restarter.should_skip:
    with restarter.training_loop():
      train_supervised(
          problem=env,
          model_name=hparams.generative_model,
          hparams=wm_hparams,
          data_dir=data_dir,
          output_dir=output_dir,
          train_steps=restarter.target_global_step,
          eval_steps=100,
          local_eval_frequency=2000,
      )

  return total_steps
def test_restarts_after_interruption(self):
  """After a crash mid-run, a fresh Restarter resumes where it stopped."""
  # Warm up with one completed run in the first mode.
  self.run_single_mode(
      TEST_MODE_1,
      target_local_step=TEST_NUM_STEPS,
      target_global_step=TEST_NUM_STEPS,
  )

  step = TEST_NUM_STEPS
  interrupted = Restarter(TEST_MODE_2, self.out_dir, target_local_step=2)
  with self.assertRaises(RuntimeError):
    step += 1
    with interrupted.training_loop():
      self.create_checkpoint(step)
      # Simulate training interruption after the first step.
      raise RuntimeError

  resumed = Restarter(TEST_MODE_2, self.out_dir, target_local_step=2)
  self.assertFalse(resumed.should_skip)
  self.assertTrue(resumed.restarting)
  # One of the two requested local steps already ran before the crash.
  self.assertEqual(resumed.steps_to_go, 1)
def test_runs_in_two_modes(self):
  """Alternating modes each track their own local step counters."""
  global_step = TEST_NUM_STEPS
  local_steps = {TEST_MODE_1: TEST_NUM_STEPS, TEST_MODE_2: 0}
  self.run_single_mode(TEST_MODE_1, local_steps[TEST_MODE_1], global_step)

  for mode in (TEST_MODE_2, TEST_MODE_1):
    global_step += TEST_NUM_STEPS
    local_steps[mode] += TEST_NUM_STEPS
    restarter = Restarter(
        mode, self.out_dir, target_local_step=local_steps[mode]
    )
    # Each switch should look like a fresh (non-restarting) run.
    self.assert_first_run(
        restarter,
        steps_to_go=TEST_NUM_STEPS,
        target_global_step=global_step,
    )
    with restarter.training_loop():
      self.create_checkpoint(global_step)
def run_single_mode(self, mode, target_local_step, target_global_step):
  """Run one complete training pass in `mode`, checkpointing at the end."""
  with Restarter(mode, self.out_dir, target_local_step).training_loop():
    self.create_checkpoint(target_global_step)