Example #1
 def test_saves_and_restores_opt_state(self):
     # Placeholder values: the test only checks that the save/restore
     # round trip returns them unchanged.
     opt_state = 123
     state = 456
     epoch = 7
     opt_step = 89
     output_dir = self.get_temp_dir()
     ppo.save_opt_state(output_dir, opt_state, state, epoch, opt_step)
     restored_data = ppo.maybe_restore_opt_state(output_dir)
     self.assertEqual(restored_data, (opt_state, state, epoch, opt_step))
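The round trip above is the whole contract: whatever tuple of (opt_state, state, epoch, opt_step) goes into ppo.save_opt_state comes back out of ppo.maybe_restore_opt_state. A minimal standalone sketch of the same checkpointing pattern, using only pickle and hypothetical function names (save_checkpoint and load_checkpoint are not part of the library):

    import os
    import pickle

    def save_checkpoint(output_dir, opt_state, state, epoch, opt_step):
        # Persist everything needed to resume training as a single tuple.
        with open(os.path.join(output_dir, "model.pkl"), "wb") as f:
            pickle.dump((opt_state, state, epoch, opt_step), f)

    def load_checkpoint(output_dir):
        # Return None when there is no checkpoint to restore yet.
        path = os.path.join(output_dir, "model.pkl")
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return pickle.load(f)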
Example #2
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(1.0, ),
         vocab_size=None,
 ):
     policy_and_value_model = functools.partial(
         transformer.TransformerDecoder,
         d_model=2,
         d_ff=2,
         n_layers=0,
         vocab_size=vocab_size,
     )
     net = ppo.policy_and_value_net(
         n_actions=len(action_multipliers),
         n_controls=len(control_configs),
         vocab_size=None,
         bottom_layers_fn=policy_and_value_model,
         two_towers=False,
     )
     obs_dim = len(observation_metrics)
     # No vocabulary: the policy reads raw float metrics. With a vocabulary,
     # it reads serialized observations as integer tokens instead.
     if vocab_size is None:
         shape = (1, 1, obs_dim)
         dtype = np.float32
     else:
         shape = (1, 1)
         dtype = np.int32
     input_signature = ShapeDtype(shape, dtype)
     (params, state) = net.init(input_signature)
     policy_dir = self.get_temp_dir()
     # Optimizer slots and parameters should not be used for anything.
     slots = None
     opt_params = None
     opt_state = (params, slots, opt_params)
     ppo.save_opt_state(policy_dir,
                        opt_state,
                        state,
                        epoch=0,
                        total_opt_step=0,
                        history=history)
     return lr_schedules.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=policy_and_value_model,
         policy_and_value_two_towers=False,
         policy_and_value_vocab_size=vocab_size,
         policy_dir=policy_dir,
     )
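The vocab_size branch above decides what the restored policy will consume at evaluation time: with vocab_size=None the observation is a float vector of metric values, otherwise it is a stream of integer tokens (the serialized-policy path). The two dummy inputs implied by those ShapeDtype signatures, written out with numpy (batch and time dimensions are both 1):

    import numpy as np

    # vocab_size is None: one float32 vector of metric values per step.
    float_obs = np.zeros((1, 1, 2), dtype=np.float32)  # obs_dim == 2 metrics

    # vocab_size is set: one int32 token id per step, embedded by the decoder.
    token_obs = np.zeros((1, 1), dtype=np.int32)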
Example #3
 def save(self):
     """Save the agent parameters."""
     logging.vlog(1, "PPO epoch [% 6d]: saving model.", self._epoch)
     ppo.save_opt_state(
         self._output_dir,
         self._policy_and_value_opt_state,
         self._model_state,
         self._epoch,
         self._total_opt_step,
     )
     # Reset the trajectory counter now that a checkpoint has been written.
     self._n_trajectories_done = 0
     self._last_saved_at = self._epoch
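_last_saved_at is what lets a training loop avoid writing redundant checkpoints. A sketch of how a caller might use it; maybe_save is a hypothetical helper, not part of the library:

    def maybe_save(agent, save_every=10):
        # Checkpoint at most once every `save_every` epochs.
        if agent._epoch - agent._last_saved_at >= save_every:
            agent.save()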
Example #4
    def save(self):
        """Save the agent parameters."""
        if not self._should_save_checkpoints:
            return

        logging.vlog(1, 'PPO epoch [% 6d]: saving model.', self._epoch)
        ppo.save_opt_state(
            self._output_dir,
            self._policy_and_value_opt_state,
            self._model_state,
            self._epoch,
            self._total_opt_step,
            self._history,
        )
        # Reset the trajectory counter now that a checkpoint has been written.
        self._n_trajectories_done = 0
        self._last_saved_at = self._epoch
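Compared with Example #3, this revision adds two things: a _should_save_checkpoints flag so lightweight runs such as unit tests can skip disk I/O entirely, and the metric history, which Example #2 shows being restored into a PolicySchedule. A minimal self-contained sketch of the early-return guard (FakeAgent is a hypothetical stand-in for the trainer):

    class FakeAgent:
        def __init__(self, should_save_checkpoints):
            self._should_save_checkpoints = should_save_checkpoints
            self.saved = False

        def save(self):
            # Same guard as above: bail out before touching disk.
            if not self._should_save_checkpoints:
                return
            self.saved = True

    agent = FakeAgent(should_save_checkpoints=False)
    agent.save()
    assert not agent.saved  # the guard short-circuited the save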
Example #5
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(("eval", "metrics/accuracy"), ),
         action_multipliers=(1.0, ),
 ):
     policy_and_value_model = atari_cnn.FrameStackMLP
     net = ppo.policy_and_value_net(
         n_actions=len(action_multipliers),
         n_controls=len(control_configs),
         vocab_size=None,
         bottom_layers_fn=policy_and_value_model,
         two_towers=False,
     )
     rng = jax_random.get_prng(seed=0)
     obs_dim = len(observation_metrics)
     (params, state) = net.initialize_once((1, 1, obs_dim), np.float32, rng)
     policy_dir = self.get_temp_dir()
     # Optimizer slots should not be used for anything.
     slots = None
     opt_state = (params, slots)
     ppo.save_opt_state(policy_dir,
                        opt_state,
                        state,
                        epoch=0,
                        total_opt_step=0)
     return learning_rate.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=policy_and_value_model,
         policy_and_value_two_towers=False,
         policy_dir=policy_dir,
     )
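This snippet predates Example #2: the net is initialized with initialize_once(shape, dtype, rng) and an explicit PRNG key instead of init(ShapeDtype(...)), the saved opt_state tuple has two elements rather than three, save_opt_state takes no history, and PolicySchedule lives in learning_rate rather than lr_schedules. A hypothetical compatibility shim for code that must run against both versions, assuming ShapeDtype is importable as used in Example #2:

    from trax.shapes import ShapeDtype  # import path assumed

    def init_policy_net(net, shape, dtype, rng):
        # Newer layers expose init(signature); older ones initialize_once.
        if hasattr(net, "init"):
            return net.init(ShapeDtype(shape, dtype))
        return net.initialize_once(shape, dtype, rng)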