Example 1
  def test_restarts_after_interruption(self):
    # Run some initial training first.
    self.run_single_mode(
        TEST_MODE_1, target_local_step=TEST_NUM_STEPS,
        target_global_step=TEST_NUM_STEPS
    )
    global_step = TEST_NUM_STEPS

    restarter = Restarter(
        TEST_MODE_2, self.out_dir, target_local_step=2
    )
    with self.assertRaises(RuntimeError):
      global_step += 1
      with restarter.training_loop():
        self.create_checkpoint(global_step)
        # Simulate training interruption after the first step.
        raise RuntimeError
    restarter = Restarter(
        TEST_MODE_2, self.out_dir, target_local_step=2
    )

    self.assertFalse(restarter.should_skip)
    self.assertTrue(restarter.restarting)
    # Training should resume after the first step.
    self.assertEqual(restarter.steps_to_go, 1)
Example 2
def train_world_model(
    env, data_dir, output_dir, hparams, world_model_steps_num, epoch
):
  """Train the world model on problem_name."""
  world_model_steps_num += world_model_step_increment(
      hparams, is_initial_epoch=(epoch == 0)
  )
  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  model_hparams.learning_rate = model_hparams.learning_rate_constant
  if epoch > 0:
    model_hparams.learning_rate *= hparams.learning_rate_bump
  if hparams.wm_policy_param_sharing:
    model_hparams.optimizer_zero_grads = True

  restarter = Restarter("world_model", output_dir, world_model_steps_num)
  if restarter.should_skip:
    return world_model_steps_num
  with restarter.training_loop():
    train_supervised(
        problem=env,
        model_name=hparams.generative_model,
        hparams=model_hparams,
        data_dir=data_dir,
        output_dir=output_dir,
        train_steps=restarter.target_global_step,
        eval_steps=100,
        local_eval_frequency=2000
    )

  return world_model_steps_num
Example 3
  def test_runs_in_single_mode(self):
    restarter = Restarter(TEST_MODE_1, self.out_dir,
                          target_local_step=TEST_NUM_STEPS)
    self.assert_first_run(restarter, steps_to_go=TEST_NUM_STEPS,
                          target_global_step=TEST_NUM_STEPS)
Example 4
    def test_runs_in_two_modes(self):
        global_step = TEST_NUM_STEPS
        local_steps = {TEST_MODE_1: TEST_NUM_STEPS, TEST_MODE_2: 0}
        self.run_single_mode(TEST_MODE_1, local_steps[TEST_MODE_1],
                             global_step)

        for mode in [TEST_MODE_2, TEST_MODE_1]:
            global_step += TEST_NUM_STEPS
            local_steps[mode] += TEST_NUM_STEPS
            restarter = Restarter(mode,
                                  self.out_dir,
                                  target_local_step=local_steps[mode])
            self.assert_first_run(restarter,
                                  steps_to_go=TEST_NUM_STEPS,
                                  target_global_step=global_step)
            with restarter.training_loop():
                self.create_checkpoint(global_step)
Example 5
    def test_skips_already_done(self):
        self.run_single_mode(TEST_MODE_1,
                             target_local_step=TEST_NUM_STEPS,
                             target_global_step=TEST_NUM_STEPS)

        restarter = Restarter(TEST_MODE_1,
                              self.out_dir,
                              target_local_step=TEST_NUM_STEPS)
        # We should skip the training as those steps are already completed.
        self.assertTrue(restarter.should_skip)
Example 6
  def run_single_mode(self, mode, target_local_step, target_global_step):
    restarter = Restarter(mode, self.out_dir, target_local_step)
    with restarter.training_loop():
      self.create_checkpoint(target_global_step)
Example 7
    def train(self,
              env_fn,
              hparams,
              simulated,
              save_continuously,
              epoch,
              sampling_temp=1.0,
              num_env_steps=None,
              env_step_multiplier=1,
              eval_env_fn=None,
              report_fn=None,
              model_save_fn=None):
        assert sampling_temp == 1.0 or hparams.learning_rate == 0.0, \
            "Sampling with non-1 temperature does not make sense during training."

        if not save_continuously:
            # We do not save the model, as that resets frames we need at restarts.
            # We still need a save at the last step, so set the interval very high.
            hparams.save_models_every_epochs = 1000000

        if simulated:
            simulated_str = "sim"
        else:
            simulated_str = "real"
        name_scope = "ppo_{}{}".format(simulated_str, epoch + 1)
        event_dir = os.path.join(self.base_event_dir, "ppo_summaries",
                                 str(epoch) + simulated_str)

        with tf.Graph().as_default():
            with tf.name_scope(name_scope):
                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=tf.AUTO_REUSE):
                    env = env_fn(in_graph=True)
                    (train_summary_op, eval_summary_op,
                     initializers) = (_define_train(
                         env,
                         hparams,
                         eval_env_fn,
                         sampling_temp,
                         distributional_size=self._distributional_size,
                         distributional_subscale=self._distributional_subscale,
                         distributional_threshold=self._distributional_threshold,
                         epoch=epoch if simulated else -1,
                         frame_stack_size=self.frame_stack_size,
                         force_beginning_resets=simulated))

                if num_env_steps is None:
                    iteration_increment = hparams.epochs_num
                else:
                    iteration_increment = int(
                        math.ceil(num_env_steps /
                                  (env.batch_size * hparams.epoch_length)))
                iteration_increment *= env_step_multiplier

                self._num_completed_iterations += iteration_increment

                restarter = Restarter("policy", self.agent_model_dir,
                                      self._num_completed_iterations)
                if restarter.should_skip:
                    return

                if hparams.lr_decay_in_final_epoch:
                    if epoch != self.total_num_epochs - 1:
                        # Extend the warmup period to the end of this epoch.
                        hparams.learning_rate_warmup_steps = restarter.target_global_step
                    else:
                        if self._lr_decay_start is None:
                            # Stop the warmup at the beginning of this epoch.
                            self._lr_decay_start = \
                                restarter.target_global_step - iteration_increment
                        hparams.learning_rate_warmup_steps = self._lr_decay_start

                _run_train(hparams,
                           event_dir,
                           self.agent_model_dir,
                           restarter,
                           train_summary_op,
                           eval_summary_op,
                           initializers,
                           epoch,
                           report_fn=report_fn,
                           model_save_fn=model_save_fn)
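
The examples above share one idiom: build a Restarter for a named training mode, skip the work if the requested number of local steps has already been completed, and otherwise run training inside training_loop() up to restarter.target_global_step. The sketch below is a minimal, hypothetical condensation of that idiom; run_training and its train_steps parameter are placeholder stand-ins for whatever training call the caller uses (e.g. train_supervised in Example 2), and Restarter is assumed to be in scope as in the snippets above. Only the Restarter calls and attributes mirror the examples.

def train_with_restarts(mode, output_dir, target_local_step, run_training):
  # Sketch only: `run_training` is a hypothetical callback, not part of the
  # Restarter API shown in the examples.
  restarter = Restarter(mode, output_dir, target_local_step)
  if restarter.should_skip:
    # The requested local steps were already completed in a previous run.
    return
  with restarter.training_loop():
    # On a fresh run this trains up to target_global_step; after an
    # interruption only restarter.steps_to_go steps remain.
    run_training(train_steps=restarter.target_global_step)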