Example #1
0
def train_world_model(
    env, data_dir, output_dir, hparams, world_model_steps_num, epoch
):
  """Train the world model for one epoch and return the new global step.

  Bumps `world_model_steps_num` by this epoch's step budget, configures the
  generative model hparams, and runs supervised training unless a previous
  run already reached the target step count.

  Args:
    env: Problem/environment providing the training data.
    data_dir: Directory holding the generated training data.
    output_dir: Directory for checkpoints and the restarter's bookkeeping.
    hparams: Top-level loop hparams (holds the generative-model settings).
    world_model_steps_num: Global step count reached so far.
    epoch: Current epoch index; 0 means the initial (longer) training run.

  Returns:
    Updated `world_model_steps_num` after this epoch's training budget.
  """
  world_model_steps_num += world_model_step_increment(
      hparams, is_initial_epoch=(epoch == 0)
  )

  # Build the model hparams; start from the constant LR and bump it on
  # every epoch after the first.
  wm_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  wm_hparams.learning_rate = wm_hparams.learning_rate_constant
  if epoch > 0:
    wm_hparams.learning_rate *= hparams.learning_rate_bump
  if hparams.wm_policy_param_sharing:
    wm_hparams.optimizer_zero_grads = True

  restarter = Restarter("world_model", output_dir, world_model_steps_num)
  if not restarter.should_skip:
    # Only train when an earlier run hasn't already hit the target step.
    with restarter.training_loop():
      train_supervised(
          problem=env,
          model_name=hparams.generative_model,
          hparams=wm_hparams,
          data_dir=data_dir,
          output_dir=output_dir,
          train_steps=restarter.target_global_step,
          eval_steps=100,
          local_eval_frequency=2000
      )

  return world_model_steps_num
Example #2
0
  def test_restarts_after_interruption(self):
    """A crashed run should resume from its checkpoint, not start over."""
    # Warm up with a complete run in the first mode.
    self.run_single_mode(
        TEST_MODE_1, target_local_step=TEST_NUM_STEPS,
        target_global_step=TEST_NUM_STEPS
    )
    step = TEST_NUM_STEPS

    interrupted = Restarter(
        TEST_MODE_2, self.out_dir, target_local_step=2
    )
    with self.assertRaises(RuntimeError):
      step += 1
      with interrupted.training_loop():
        self.create_checkpoint(step)
        # Simulate training interruption after the first step.
        raise RuntimeError

    # A fresh restarter over the same state should detect the partial run.
    resumed = Restarter(
        TEST_MODE_2, self.out_dir, target_local_step=2
    )
    self.assertFalse(resumed.should_skip)
    self.assertTrue(resumed.restarting)
    # Training should resume after the first step.
    self.assertEqual(resumed.steps_to_go, 1)
Example #3
0
    def test_runs_in_two_modes(self):
        """Alternating modes each keep their own local step counters."""
        global_step = TEST_NUM_STEPS
        local_steps = {TEST_MODE_1: TEST_NUM_STEPS, TEST_MODE_2: 0}
        # Seed the state with a completed run in the first mode.
        self.run_single_mode(TEST_MODE_1, local_steps[TEST_MODE_1],
                             global_step)

        # Switch to the second mode, then back to the first; each switch
        # should look like a fresh (non-restarting) run of its mode.
        for current_mode in (TEST_MODE_2, TEST_MODE_1):
            global_step += TEST_NUM_STEPS
            local_steps[current_mode] += TEST_NUM_STEPS
            restarter = Restarter(current_mode,
                                  self.out_dir,
                                  target_local_step=local_steps[current_mode])
            self.assert_first_run(restarter,
                                  steps_to_go=TEST_NUM_STEPS,
                                  target_global_step=global_step)
            with restarter.training_loop():
                self.create_checkpoint(global_step)
Example #4
0
 def run_single_mode(self, mode, target_local_step, target_global_step):
     """Run one complete training pass in `mode` and record its checkpoint."""
     with Restarter(mode, self.out_dir, target_local_step).training_loop():
         self.create_checkpoint(target_global_step)