def test_saves_and_restores_opt_state(self):
    """Round-trips optimizer state through save_opt_state/maybe_restore_opt_state."""
    # (opt_state, state, epoch, opt_step) — arbitrary sentinel values.
    saved = (123, 456, 7, 89)
    output_dir = self.get_temp_dir()
    ppo.save_opt_state(output_dir, *saved)
    restored = ppo.maybe_restore_opt_state(output_dir)
    self.assertEqual(restored, saved)
def _make_schedule(
    self,
    history,
    control_configs,
    observation_metrics=(('eval', 'metrics/accuracy'),),
    action_multipliers=(1.0,),
    vocab_size=None,
):
    """Builds a PolicySchedule backed by a freshly initialized, saved policy.

    Initializes a tiny Transformer-decoder policy/value net, saves its
    (otherwise unused) optimizer state into a temp directory, and returns a
    PolicySchedule reading from that directory.
    """
    policy_and_value_model = functools.partial(
        transformer.TransformerDecoder,
        d_model=2,
        d_ff=2,
        n_layers=0,
        vocab_size=vocab_size,
    )
    net = ppo.policy_and_value_net(
        n_actions=len(action_multipliers),
        n_controls=len(control_configs),
        vocab_size=None,
        bottom_layers_fn=policy_and_value_model,
        two_towers=False,
    )
    # Discrete (token) observations when vocab_size is given, float vectors
    # over the observed metrics otherwise.
    if vocab_size is None:
        input_signature = ShapeDtype((1, 1, len(observation_metrics)), np.float32)
    else:
        input_signature = ShapeDtype((1, 1), np.int32)
    (params, state) = net.init(input_signature)
    policy_dir = self.get_temp_dir()
    # Optimizer slots and parameters should not be used for anything.
    opt_state = (params, None, None)
    ppo.save_opt_state(
        policy_dir, opt_state, state, epoch=0, total_opt_step=0, history=history
    )
    return lr_schedules.PolicySchedule(
        history,
        observation_metrics=observation_metrics,
        include_controls_in_observation=False,
        action_multipliers=action_multipliers,
        control_configs=control_configs,
        policy_and_value_model=policy_and_value_model,
        policy_and_value_two_towers=False,
        policy_and_value_vocab_size=vocab_size,
        policy_dir=policy_dir,
    )
def save(self):
    """Save the agent parameters."""
    logging.vlog(1, "PPO epoch [% 6d]: saving model.", self._epoch)
    checkpoint_args = (
        self._output_dir,
        self._policy_and_value_opt_state,
        self._model_state,
        self._epoch,
        self._total_opt_step,
    )
    ppo.save_opt_state(*checkpoint_args)
    # Reset the counter of finished trajectories since the last save.
    self._n_trajectories_done = 0
    self._last_saved_at = self._epoch
def save(self):
    """Save the agent parameters."""
    # Checkpointing may be disabled (e.g. for evaluation-only runs).
    if not self._should_save_checkpoints:
        return
    logging.vlog(1, 'PPO epoch [% 6d]: saving model.', self._epoch)
    checkpoint_args = (
        self._output_dir,
        self._policy_and_value_opt_state,
        self._model_state,
        self._epoch,
        self._total_opt_step,
        self._history,
    )
    ppo.save_opt_state(*checkpoint_args)
    # Reset the counter of finished trajectories since the last save.
    self._n_trajectories_done = 0
    self._last_saved_at = self._epoch
def _make_schedule(
    self,
    history,
    control_configs,
    observation_metrics=(("eval", "metrics/accuracy"),),
    action_multipliers=(1.0,),
):
    """Builds a PolicySchedule backed by a freshly initialized, saved policy.

    Initializes a FrameStackMLP policy/value net, saves its (otherwise
    unused) optimizer state into a temp directory, and returns a
    PolicySchedule reading from that directory.
    """
    policy_and_value_model = atari_cnn.FrameStackMLP
    net = ppo.policy_and_value_net(
        n_actions=len(action_multipliers),
        n_controls=len(control_configs),
        vocab_size=None,
        bottom_layers_fn=policy_and_value_model,
        two_towers=False,
    )
    rng = jax_random.get_prng(seed=0)
    # One timestep of a single trajectory, one float per observed metric.
    input_shape = (1, 1, len(observation_metrics))
    (params, state) = net.initialize_once(input_shape, np.float32, rng)
    policy_dir = self.get_temp_dir()
    # Optimizer slots should not be used for anything.
    opt_state = (params, None)
    ppo.save_opt_state(policy_dir, opt_state, state, epoch=0, total_opt_step=0)
    return learning_rate.PolicySchedule(
        history,
        observation_metrics=observation_metrics,
        include_controls_in_observation=False,
        action_multipliers=action_multipliers,
        control_configs=control_configs,
        policy_and_value_model=policy_and_value_model,
        policy_and_value_two_towers=False,
        policy_dir=policy_dir,
    )