Example 1
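A test helper, apparently from trax's RL learning-rate tests, that builds a learning_rate.PolicySchedule: it wraps the observed metrics in a small policy-and-value network, writes an initial checkpoint via ppo.save_opt_state, and returns the schedule.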
 def _make_schedule(
     self,
     history,
     start_lr=1e-3,
     observation_metrics=(("eval", "metrics/accuracy"),),
     action_multipliers=(1.0,),
 ):
   policy_and_value_model = atari_cnn.FrameStackMLP
   net = ppo.policy_and_value_net(
       n_actions=len(action_multipliers),
       n_controls=1,
       bottom_layers_fn=policy_and_value_model,
       two_towers=False,
   )
   rng = jax_random.get_prng(seed=0)
   obs_dim = len(observation_metrics)
   (params, state) = net.initialize((1, 1, obs_dim), np.float32, rng)
   policy_dir = self.get_temp_dir()
   # Optimizer slots should not be used for anything.
   slots = None
   opt_state = (params, slots)
   ppo.save_opt_state(policy_dir, opt_state, state, epoch=0, total_opt_step=0)
   return learning_rate.PolicySchedule(
       history,
       observation_metrics=observation_metrics,
       include_lr_in_observation=False,
       action_multipliers=action_multipliers,
       start_lr=start_lr,
       policy_and_value_model=policy_and_value_model,
       policy_and_value_two_towers=False,
       policy_dir=policy_dir,
   )
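For orientation, here is a sketch of how a test might exercise this helper. The empty-history constructor (trax_history.History()) and the assumption that the returned schedule maps a step to a dict with a "learning_rate" entry are guesses about the surrounding trax code, not facts established by the snippet above.

 def test_returns_start_lr_when_history_is_empty(self):
   # Assumption: an empty metrics history; substitute whatever History
   # class the surrounding tests actually use.
   history = trax_history.History()
   start_lr = 1e-3
   schedule = self._make_schedule(history, start_lr=start_lr)
   # Assumption: the schedule returns a dict of control values per step.
   self.assertEqual(schedule(0)["learning_rate"], start_lr)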
Example 2
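Two call sites for ppo.save_opt_state: a round-trip test that saves and then restores the optimizer state, and the save() method of a PPO trainer that checkpoints its optimizer state, model state, epoch, and step counter.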
 def test_saves_and_restores_opt_state(self):
     opt_state = 123
     state = 456
     epoch = 7
     opt_step = 89
     output_dir = self.get_temp_dir()
     ppo.save_opt_state(output_dir, opt_state, state, epoch, opt_step)
     restored_data = ppo.maybe_restore_opt_state(output_dir)
     self.assertEqual(restored_data, (opt_state, state, epoch, opt_step))

 def save(self):
     """Save the agent parameters."""
     logging.vlog(1, "PPO epoch [% 6d]: saving model.", self._epoch)
     ppo.save_opt_state(
         self._output_dir,
         self._policy_and_value_opt_state,
         self._model_state,
         self._epoch,
         self._total_opt_step,
     )
     # Reset this number.
     self._n_trajectories_done = 0
     self._last_saved_at = self._epoch
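Both call sites rely on the contract that the test above verifies: whatever tuple save_opt_state writes under the output directory, maybe_restore_opt_state hands back unchanged. A minimal, library-agnostic sketch of that contract, using pickle and a hypothetical file name rather than trax's actual on-disk format:

import os
import pickle

def save_opt_state_sketch(output_dir, opt_state, state, epoch, total_opt_step):
  """Persist everything needed to resume training after a restart."""
  path = os.path.join(output_dir, "opt_state.pkl")  # hypothetical file name
  with open(path, "wb") as f:
    pickle.dump((opt_state, state, epoch, total_opt_step), f)

def maybe_restore_opt_state_sketch(output_dir):
  """Return the saved tuple, or fresh defaults if nothing was saved yet."""
  path = os.path.join(output_dir, "opt_state.pkl")
  if not os.path.exists(path):
    return (None, None, 0, 0)  # assumed defaults; the real API may differ
  with open(path, "rb") as f:
    return pickle.load(f)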