예제 #1
0
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(1.0, ),
         vocab_size=None,
 ):
     """Create a `PolicySchedule` whose policy is a freshly initialized net.

     A tiny policy-and-value network (zero-layer `TransformerDecoder` bottom,
     d_model=2, d_ff=2) is built and initialized. Its weights are saved to a
     temporary directory, and a schedule pointing at that directory is
     returned.

     Args:
       history: Passed both to the saved checkpoint and to the schedule.
       control_configs: Control configurations; the action space gets one
         discrete dimension per entry.
       observation_metrics: Metric identifiers such as
         `('eval', 'metrics/accuracy')`; the observation space has one
         dimension per entry, bounded to [0, 1].
       action_multipliers: Candidate multipliers; their count is the number of
         discrete choices in every action dimension.
       vocab_size: Optional vocabulary size forwarded to the model builder,
         the network constructor and the schedule.

     Returns:
       An `lr_schedules.PolicySchedule` that reads its policy from the
       temporary checkpoint directory.
     """
     n_metrics = len(observation_metrics)
     n_choices = len(action_multipliers)

     bottom_layers = functools.partial(
         transformer.TransformerDecoder,
         d_model=2,
         d_ff=2,
         n_layers=0,
         vocab_size=vocab_size,
     )
     obs_space = gym.spaces.Box(low=0.0, high=1.0, shape=(n_metrics, ))
     act_space = gym.spaces.MultiDiscrete(
         nvec=(n_choices, ) * len(control_configs))

     model, _ = policy_based_utils.policy_and_value_net(
         bottom_layers_fn=bottom_layers,
         observation_space=obs_space,
         action_space=act_space,
         vocab_size=vocab_size,
         two_towers=False,
     )
     # Leading dims are (1, 2) for observations and (1, 1) for actions;
     # presumably (batch, time) — TODO confirm against policy_and_value_net.
     weights, model_state = model.init((
         shapes.ShapeDtype((1, 2) + obs_space.shape, obs_space.dtype),
         shapes.ShapeDtype((1, 1) + act_space.shape, act_space.dtype),
     ))

     checkpoint_dir = self.get_temp_dir()
     # The optimizer state is (weights, slots, opt_params). Slots and
     # optimizer params should never be used here, so they are stored as None.
     policy_based_utils.save_opt_state(
         checkpoint_dir,
         (weights, None, None),
         model_state,
         epoch=0,
         total_opt_step=0,
         history=history,
     )
     return lr_schedules.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=bottom_layers,
         policy_and_value_two_towers=False,
         policy_and_value_vocab_size=vocab_size,
         policy_dir=checkpoint_dir,
     )
예제 #2
0
 def test_saves_and_restores_opt_state(self):
     """Values passed to `save_opt_state` are returned by the restore call."""
     # Positional order matches save_opt_state:
     # (opt_state, state, epoch, total_opt_step, history).
     expected = (123, 456, 7, 89, 0)
     checkpoint_dir = self.get_temp_dir()
     policy_based_utils.save_opt_state(checkpoint_dir, *expected)
     self.assertEqual(
         policy_based_utils.maybe_restore_opt_state(checkpoint_dir), expected)
예제 #3
0
  def save(self):
    """Checkpoint the trainer state when checkpointing is enabled.

    Does nothing if `self._should_save_checkpoints` is falsy. Otherwise it
    logs the epoch and writes the following via
    `policy_based_utils.save_opt_state` into `self._output_dir`:
    policy-and-value optimizer state, model state, epoch, total optimizer
    step and history. It then resets the trajectories-since-last-save
    counter and records the epoch of this save.
    """
    if self._should_save_checkpoints:
      logging.vlog(1, 'PolicyBasedTrainer epoch [% 6d]: saving model.',
                   self.epoch)
      checkpoint = (
          self._policy_and_value_opt_state,
          self._model_state,
          self.epoch,
          self._total_opt_step,
          self._history,
      )
      policy_based_utils.save_opt_state(self._output_dir, *checkpoint)
      # Bookkeeping for the next save decision.
      self._last_saved_at_epoch = self.epoch
      self._n_trajectories_done_since_last_save = 0