def _make_schedule(
    self,
    history,
    control_configs,
    observation_metrics=(('eval', 'metrics/accuracy'),),
    action_multipliers=(1.0,),
    vocab_size=None,
):
  """Builds a PolicySchedule backed by a tiny policy-and-value network."""
  policy_and_value_model = functools.partial(
      transformer.TransformerDecoder,
      d_model=2,
      d_ff=2,
      n_layers=0,
      vocab_size=vocab_size,
  )
  # Observations are the requested metrics; actions pick one multiplier
  # per control.
  observation_space = gym.spaces.Box(
      shape=(len(observation_metrics),), low=0.0, high=1.0,
  )
  action_space = gym.spaces.MultiDiscrete(
      nvec=(len(action_multipliers),) * len(control_configs),
  )
  (net, _) = policy_based_utils.policy_and_value_net(
      bottom_layers_fn=policy_and_value_model,
      observation_space=observation_space,
      action_space=action_space,
      vocab_size=vocab_size,
      two_towers=False,
  )
  input_signature = (
      shapes.ShapeDtype(
          (1, 2) + observation_space.shape, observation_space.dtype),
      shapes.ShapeDtype(
          (1, 1) + action_space.shape, action_space.dtype),
  )
  (params, state) = net.init(input_signature)

  # Save a checkpoint so PolicySchedule can restore the policy from disk.
  policy_dir = self.get_temp_dir()
  # Optimizer slots and parameters should not be used for anything.
  slots = None
  opt_params = None
  opt_state = (params, slots, opt_params)
  policy_based_utils.save_opt_state(
      policy_dir, opt_state, state, epoch=0, total_opt_step=0,
      history=history,
  )

  return lr_schedules.PolicySchedule(
      history,
      observation_metrics=observation_metrics,
      include_controls_in_observation=False,
      action_multipliers=action_multipliers,
      control_configs=control_configs,
      policy_and_value_model=policy_and_value_model,
      policy_and_value_two_towers=False,
      policy_and_value_vocab_size=vocab_size,
      policy_dir=policy_dir,
  )
def test_saves_and_restores_opt_state(self):
  opt_state = 123
  state = 456
  epoch = 7
  opt_step = 89
  history = 0
  output_dir = self.get_temp_dir()
  policy_based_utils.save_opt_state(
      output_dir, opt_state, state, epoch, opt_step, history)
  restored_data = policy_based_utils.maybe_restore_opt_state(output_dir)
  self.assertEqual(
      restored_data, (opt_state, state, epoch, opt_step, history))
def save(self):
  """Save the agent parameters."""
  if not self._should_save_checkpoints:
    return
  logging.vlog(
      1, 'PolicyBasedTrainer epoch [% 6d]: saving model.', self.epoch)
  policy_based_utils.save_opt_state(
      self._output_dir,
      self._policy_and_value_opt_state,
      self._model_state,
      self.epoch,
      self._total_opt_step,
      self._history,
  )
  # Reset this number.
  self._n_trajectories_done_since_last_save = 0
  self._last_saved_at_epoch = self.epoch
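# The restore path is symmetric: as the round-trip test above shows,
# policy_based_utils.maybe_restore_opt_state(output_dir) returns the same
# (opt_state, state, epoch, total_opt_step, history) tuple that
# save_opt_state wrote. The sketch below is illustrative only: the method
# name `_restore_from_checkpoint` and the `self._epoch` backing attribute
# (assumed to back the `self.epoch` property used in save()) are
# assumptions, not the trainer's actual implementation.
def _restore_from_checkpoint(self):
  """Illustrative sketch: reload whatever save() last wrote to disk."""
  (self._policy_and_value_opt_state,
   self._model_state,
   self._epoch,
   self._total_opt_step,
   self._history) = policy_based_utils.maybe_restore_opt_state(
       self._output_dir)
  # Avoid immediately re-saving the checkpoint we just restored.
  self._last_saved_at_epoch = self._epoch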