def wrapped_policy_fn():
    """Wrap the model built by ``bottom_layers_fn`` into a serialized policy.

    Takes no arguments. ``bottom_layers_fn``, ``kwargs``, ``observation_space``,
    ``action_space`` and ``vocab_size`` are free names resolved from the
    enclosing scope (not visible in this chunk).

    Returns:
      Whatever ``serialization_utils.wrap_policy`` returns for the model built
      by ``bottom_layers_fn(**kwargs)`` and the given spaces and vocab size.
    """
    bottom_layers = bottom_layers_fn(**kwargs)
    return serialization_utils.wrap_policy(
        bottom_layers,
        observation_space=observation_space,
        action_space=action_space,
        vocab_size=vocab_size,
    )
def test_wrapped_policy_continuous(self, vocab_size):
    """Check output shapes of a policy wrapped over Box obs / MultiDiscrete acts.

    Args:
      vocab_size: Vocabulary size forwarded to both ``TestModel`` (as
        ``extra_dim``) and ``serialization_utils.wrap_policy``.
    """
    precision = 3
    n_controls = 2
    n_actions = 4
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)

    # One trajectory: 4 observation steps of 2-dim Box observations, and
    # 3 action steps of n_controls discrete controls each.
    obs = onp.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]])
    act = onp.array([[[0, 1], [2, 0], [1, 3]]])

    wrapped_policy = serialization_utils.wrap_policy(
        TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
        observation_space=gym.spaces.Box(shape=(2, ), low=-2, high=2),
        action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls),
        vocab_size=vocab_size,
    )

    example = (obs, act)
    wrapped_policy.init(shapes.signature(example))
    (act_logits, values) = wrapped_policy(example)
    # Both heads emit one output per observation step (the values check below
    # asserts this for values). The logits therefore share the leading
    # (batch, obs_time) dims, followed by one distribution over n_actions for
    # each of the n_controls controls. act.shape has only 3 time steps and
    # would disagree with the values head.
    self.assertEqual(
        act_logits.shape, obs.shape[:2] + (n_controls, n_actions))
    self.assertEqual(values.shape, obs.shape[:2])