Example #1
def test_runs_policy_non_serialized(self):
  n_timesteps = 5
  n_controls = 3
  n_actions = 2
  obs_shape = (2, 3)
  lengths = np.array([2, 3])
  input_observations = np.random.uniform(
      0, 1, size=((2, n_timesteps) + obs_shape)
  )
  expected_log_probs = np.random.uniform(
      0, 1, size=(2, n_controls, n_actions)
  )
  expected_values = np.random.uniform(0, 1, size=(2, n_controls))

  def mock_apply(observations, *unused_args, **unused_kwargs):
    # Without serialization, run_policy should pass the observations
    # through to the network unchanged.
    np.testing.assert_array_equal(observations, input_observations)
    start_indices = (lengths - 1) * n_controls
    return self._make_log_prob_and_value_seqs(
        expected_log_probs, expected_values, start_indices, n_timesteps
    )

  observation_space = gym.spaces.Box(shape=obs_shape, low=0, high=1)
  action_space = gym.spaces.MultiDiscrete(nvec=((n_actions,) * n_controls))
  (log_probs, values, _, _) = ppo.run_policy(
      mock_apply,
      observations=input_observations,
      lengths=lengths,
      **self._make_run_policy_kwargs(
          observation_space, action_space, n_timesteps, vocab_size=None
      )
  )
  # The outputs extracted at the last real timestep of each example
  # should match what the mock embedded there.
  np.testing.assert_array_equal(log_probs, expected_log_probs)
  np.testing.assert_array_equal(values, expected_values)
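This test (and the serialized variant below) builds the mocked network output through a _make_log_prob_and_value_seqs helper that is not shown on this page; only its call signature is visible. A minimal sketch of the pattern, with assumed shapes (the real helper's output layout may differ; numpy is assumed imported as np, as in the snippets): zero-fill the per-position output sequences and scatter the expected per-control log-probs and values at each example's start index, so that the slice run_policy takes at those positions recovers them exactly.

def _make_log_prob_and_value_seqs(
    self, log_probs, values, start_indices, seq_len):
  # Sketch only. Assumed shapes: log_probs is (batch, n_controls,
  # n_actions), values is (batch, n_controls); start_indices gives the
  # flat position of each example's first action output.
  (batch_size, n_controls, n_actions) = log_probs.shape
  # Make the sequences long enough for every start index to fit.
  length = max(seq_len, int(max(start_indices)) + n_controls)
  log_prob_seqs = np.zeros((batch_size, length, n_actions))
  value_seqs = np.zeros((batch_size, length))
  for (i, start) in enumerate(start_indices):
    log_prob_seqs[i, start:(start + n_controls)] = log_probs[i]
    value_seqs[i, start:(start + n_controls)] = values[i]
  return (log_prob_seqs, value_seqs)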
Example #2
def _policy_fun(self, observations, lengths, state, rng):
  return ppo.run_policy(
      self._policy_and_value_net_apply,
      observations,
      lengths,
      self._policy_and_value_net_weights,
      state,
      rng,
      self._policy_and_value_vocab_size,
      self.train_env.observation_space,
      self.train_env.action_space,
      self._rewards_to_actions,
  )
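Taken together with the keyword-style calls in the tests, this positional call fixes run_policy's argument order: the apply function, observations, lengths, weights, state, rng, vocab_size, the observation and action spaces, and a rewards-to-actions mapping. The _make_run_policy_kwargs helper the tests use is likewise not shown; a minimal sketch under that reading, with dummy weights/state/rng (the mocked network ignores them) and a placeholder rewards_to_actions matrix standing in for the real mapping:

def _make_run_policy_kwargs(
    self, observation_space, action_space, n_timesteps, vocab_size):
  # Sketch only: supplies everything run_policy needs beyond the apply
  # function, observations and lengths. None is fine for weights,
  # state and rng because mock_apply never reads them; the identity
  # matrix is a stand-in for the real rewards-to-actions mapping built
  # by the serialization utilities.
  return {
      'weights': None,
      'state': None,
      'rng': None,
      'vocab_size': vocab_size,
      'observation_space': observation_space,
      'action_space': action_space,
      'rewards_to_actions': np.eye(n_timesteps),
  }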
def test_runs_policy_serialized(self):
  precision = 2
  gin.bind_parameter('BoxSpaceSerializer.precision', precision)
  n_timesteps = 5
  n_controls = 3
  n_actions = 2
  obs_length = 4
  obs_shape = (obs_length,)
  lengths = np.array([2, 3])
  input_observations = np.random.uniform(
      0, 1, size=((2, n_timesteps) + obs_shape)
  )
  expected_log_probs = np.random.uniform(
      0, 1, size=(2, n_controls, n_actions)
  )
  expected_values = np.random.uniform(0, 1, size=(2, n_controls))

  def mock_apply(observations, *unused_args, **unused_kwargs):
    # With serialization, each timestep is encoded as
    # obs_length * precision observation symbols followed by
    # n_controls action symbols.
    step_repr_length = obs_length * precision + n_controls
    n_symbols = n_timesteps * step_repr_length
    self.assertEqual(observations.shape, (2, n_symbols))
    # The log-probs of interest start right after the observation
    # symbols of each example's last real timestep.
    start_indices = (
        (lengths - 1) * step_repr_length + obs_length * precision
    )
    return self._make_log_prob_and_value_seqs(
        expected_log_probs, expected_values, start_indices, n_symbols
    )

  observation_space = gym.spaces.Box(shape=obs_shape, low=0, high=1)
  action_space = gym.spaces.MultiDiscrete(nvec=((n_actions,) * n_controls))
  (log_probs, values, _, _) = ppo.run_policy(
      mock_apply,
      observations=input_observations,
      lengths=lengths,
      **self._make_run_policy_kwargs(
          observation_space, action_space, n_timesteps, vocab_size=6
      )
  )
  np.testing.assert_array_equal(log_probs, expected_log_probs)
  np.testing.assert_array_equal(values, expected_values)
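For concreteness, the serialized layout here works out to step_repr_length = 4 * 2 + 3 = 11 symbols per timestep and n_symbols = 5 * 11 = 55, so start_indices = (lengths - 1) * 11 + 8 = [19, 30]: the expected log-probs sit immediately after the eight observation symbols of each example's last real timestep.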