def test_runs_policy_non_serialized(self):
    """Checks run_policy output passthrough when observations are not serialized.

    A mock network `apply` function asserts it receives the raw observations
    unchanged, then emits pre-generated log-prob/value sequences positioned at
    the last timestep of each trajectory; the test verifies run_policy extracts
    exactly those values.
    """
    timesteps = 5
    controls = 3
    actions = 2
    single_obs_shape = (2, 3)
    traj_lengths = np.array([2, 3])

    # Random fixtures: a batch of 2 trajectories plus the outputs the mock
    # network should surface through run_policy.
    obs_batch = np.random.uniform(
        0, 1, size=(2, timesteps) + single_obs_shape
    )
    want_log_probs = np.random.uniform(0, 1, size=(2, controls, actions))
    want_values = np.random.uniform(0, 1, size=(2, controls))

    def fake_apply(observations, *unused_args, **unused_kwargs):
        # In the non-serialized path the observations must reach the network
        # verbatim.
        np.testing.assert_array_equal(observations, obs_batch)
        # One network position per control at the final timestep of each
        # trajectory.
        final_positions = (traj_lengths - 1) * controls
        return self._make_log_prob_and_value_seqs(
            want_log_probs, want_values, final_positions, timesteps
        )

    obs_space = gym.spaces.Box(shape=single_obs_shape, low=0, high=1)
    act_space = gym.spaces.MultiDiscrete(nvec=(actions,) * controls)
    # vocab_size=None selects the non-serialized code path.
    (got_log_probs, got_values, _, _) = ppo.run_policy(
        fake_apply,
        observations=obs_batch,
        lengths=traj_lengths,
        **self._make_run_policy_kwargs(
            obs_space, act_space, timesteps, vocab_size=None
        )
    )

    np.testing.assert_array_equal(got_log_probs, want_log_probs)
    np.testing.assert_array_equal(got_values, want_values)
def _policy_fun(self, observations, lengths, state, rng):
    """Runs the policy-and-value network over a batch of observations.

    Thin wrapper that forwards the instance's network, weights, vocabulary
    and space configuration to `ppo.run_policy`.
    """
    # Argument order mirrors ppo.run_policy's positional signature.
    run_policy_args = (
        self._policy_and_value_net_apply,
        observations,
        lengths,
        self._policy_and_value_net_weights,
        state,
        rng,
        self._policy_and_value_vocab_size,
        self.train_env.observation_space,
        self.train_env.action_space,
        self._rewards_to_actions,
    )
    return ppo.run_policy(*run_policy_args)
def test_runs_policy_serialized(self):
    """Checks run_policy output extraction when observations are serialized.

    With a vocab size set, observations are serialized into discrete symbol
    sequences. The mock network `apply` verifies the serialized sequence
    length, then emits pre-generated log-prob/value sequences at the symbol
    positions following the last serialized observation; the test verifies
    run_policy recovers exactly those values.
    """
    precision = 2
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)

    timesteps = 5
    controls = 3
    actions = 2
    obs_length = 4
    single_obs_shape = (obs_length,)
    traj_lengths = np.array([2, 3])

    # Random fixtures: a batch of 2 trajectories plus the outputs the mock
    # network should surface through run_policy.
    obs_batch = np.random.uniform(
        0, 1, size=(2, timesteps) + single_obs_shape
    )
    want_log_probs = np.random.uniform(0, 1, size=(2, controls, actions))
    want_values = np.random.uniform(0, 1, size=(2, controls))

    def fake_apply(observations, *unused_args, **unused_kwargs):
        # Each timestep serializes to obs_length * precision observation
        # symbols followed by one symbol per control (the actions).
        symbols_per_step = obs_length * precision + controls
        total_symbols = timesteps * symbols_per_step
        self.assertEqual(observations.shape, (2, total_symbols))
        # Action symbols start right after the serialized observation of the
        # final timestep of each trajectory.
        action_positions = (
            (traj_lengths - 1) * symbols_per_step + obs_length * precision
        )
        return self._make_log_prob_and_value_seqs(
            want_log_probs, want_values, action_positions, total_symbols
        )

    obs_space = gym.spaces.Box(shape=single_obs_shape, low=0, high=1)
    act_space = gym.spaces.MultiDiscrete(nvec=(actions,) * controls)
    # A non-None vocab_size selects the serialized code path.
    (got_log_probs, got_values, _, _) = ppo.run_policy(
        fake_apply,
        observations=obs_batch,
        lengths=traj_lengths,
        **self._make_run_policy_kwargs(
            obs_space, act_space, timesteps, vocab_size=6
        )
    )

    np.testing.assert_array_equal(got_log_probs, want_log_probs)
    np.testing.assert_array_equal(got_values, want_values)