def _make_schedule(
    self,
    history,
    start_lr=1e-3,
    observation_metrics=(("eval", "metrics/accuracy"),),
    action_multipliers=(1.0,),
):
  """Creates a PolicySchedule backed by a freshly saved policy checkpoint."""
  policy_and_value_model = atari_cnn.FrameStackMLP
  net = ppo.policy_and_value_net(
      n_actions=len(action_multipliers),
      n_controls=1,
      bottom_layers_fn=policy_and_value_model,
      two_towers=False,
  )
  rng = jax_random.get_prng(seed=0)
  obs_dim = len(observation_metrics)
  (params, state) = net.initialize((1, 1, obs_dim), np.float32, rng)
  policy_dir = self.get_temp_dir()
  # Optimizer slots should not be used for anything.
  slots = None
  opt_state = (params, slots)
  # Save a checkpoint so PolicySchedule can restore the policy from disk.
  ppo.save_opt_state(policy_dir, opt_state, state, epoch=0, total_opt_step=0)
  return learning_rate.PolicySchedule(
      history,
      observation_metrics=observation_metrics,
      include_lr_in_observation=False,
      action_multipliers=action_multipliers,
      start_lr=start_lr,
      policy_and_value_model=policy_and_value_model,
      policy_and_value_two_towers=False,
      policy_dir=policy_dir,
  )
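# For context, a test built on the helper above might look roughly like the
# sketch below. This is an illustration only: `trainer_lib.History` and the
# assumption that the returned schedule maps a step index to a learning rate
# are not shown in the snippet above and may differ in the actual codebase.
def test_returns_start_lr_when_history_is_empty(self):
  history = trainer_lib.History()  # assumed: an empty metrics history
  schedule = self._make_schedule(history, start_lr=1e-3)
  # Assumption: with no observed metrics, the policy-based schedule falls
  # back to the configured start_lr.
  self.assertEqual(schedule(0), 1e-3)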
def test_saves_and_restores_opt_state(self):
  opt_state = 123
  state = 456
  epoch = 7
  opt_step = 89
  output_dir = self.get_temp_dir()
  ppo.save_opt_state(output_dir, opt_state, state, epoch, opt_step)
  restored_data = ppo.maybe_restore_opt_state(output_dir)
  self.assertEqual(restored_data, (opt_state, state, epoch, opt_step))
def save(self):
  """Save the agent parameters."""
  logging.vlog(1, "PPO epoch [% 6d]: saving model.", self._epoch)
  ppo.save_opt_state(
      self._output_dir,
      self._policy_and_value_opt_state,
      self._model_state,
      self._epoch,
      self._total_opt_step,
  )
  # Reset the count of completed trajectories now that progress has been saved.
  self._n_trajectories_done = 0
  self._last_saved_at = self._epoch
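# `save` is typically invoked periodically from a training loop. The sketch
# below illustrates that pattern; `train_epoch` and `_save_every_n_epochs` are
# hypothetical names used only for illustration and are not part of the code
# above.
def training_loop(self, n_epochs):
  for _ in range(n_epochs):
    self.train_epoch()  # hypothetical: runs one PPO epoch and advances self._epoch
    # Checkpoint once enough epochs have elapsed since the last save.
    if self._epoch - self._last_saved_at >= self._save_every_n_epochs:
      self.save()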