def test_serialize_deserialize(self):
    """Round-trip a ContinuousUniformPolicy through serialize/deserialize.

    Builds a policy from the test environment, serializes it with
    `policies.serialize`, restores it with `policies.deserialize`, and
    checks that the stored constructor arguments match and that both
    policies emit actions of the same shape for rolled-out observations.
    """
    env = self.env
    action_bounds = (
        env.action_space.low,
        env.action_space.high,
    )
    original = ContinuousUniformPolicy(
        action_range=action_bounds,
        input_shapes=env.observation_shape,
        output_shape=env.action_shape,
        observation_keys=env.observation_keys)

    # The test expects this policy to report no trainable weights.
    self.assertFalse(original.trainable_weights)

    restored = policies.deserialize(policies.serialize(original))

    # Every stored constructor argument must survive the round trip.
    for attribute in (
            '_action_range',
            '_input_shapes',
            '_output_shape',
            '_observation_keys'):
        self.assertEqual(
            getattr(restored, attribute), getattr(original, attribute))

    # Collect observations by rolling out the restored policy.
    rollout_path = sampler_utils.rollout(
        env, restored, path_length=10, break_on_terminal=False)
    observations = rollout_path['observations']

    # Only the action shapes are compared, not the action values.
    original_shape = original.actions(observations).numpy().shape
    restored_shape = restored.actions(observations).numpy().shape
    np.testing.assert_equal(original_shape, restored_shape)
def test_serialize_deserialize(self):
    """Round-trip a FeedforwardGaussianPolicy through serialize/deserialize.

    After copying the original policy's weights into the deserialized
    policy, the weights and the log-probabilities of a fixed batch of
    actions must match exactly, and newly computed actions must have the
    same shape as the original ones.
    """
    env = self.env
    action_bounds = (
        env.action_space.low,
        env.action_space.high,
    )
    original = FeedforwardGaussianPolicy(
        input_shapes=env.observation_shape,
        output_shape=env.action_space.shape,
        action_range=action_bounds,
        hidden_layer_sizes=self.hidden_layer_sizes,
        observation_keys=env.observation_keys)

    # Collect observations by rolling out the original policy.
    rollout_path = sampler_utils.rollout(
        env, original, path_length=10, break_on_terminal=False)
    observations = rollout_path['observations']

    # Reference values taken from the original policy.
    original_weights = original.get_weights()
    original_actions = original.actions(observations)
    original_log_pis = original.log_probs(observations, original_actions)

    restored = policies.deserialize(policies.serialize(original))
    # Weights are copied over explicitly before any comparison is made.
    restored.set_weights(original.get_weights())

    restored_weights = restored.get_weights()
    # Both policies are evaluated on the same actions.
    restored_log_pis = restored.log_probs(observations, original_actions)

    for expected, actual in zip(original_weights, restored_weights):
        np.testing.assert_array_equal(expected, actual)

    np.testing.assert_array_equal(original_log_pis, restored_log_pis)

    # Only the action shapes are compared, not the action values.
    restored_actions = restored.actions(observations)
    np.testing.assert_equal(original_actions.shape, restored_actions.shape)
def test_nested_serializable_fn(self):
    """Round-trip a policy whose config holds objects that carry a function.

    The action range is built from `SerializableNestedInt` values, each of
    which carries a reference to `serializable_fn`. After
    serialize/deserialize with both names supplied as custom objects, the
    restored elements must be `SerializableNestedInt` instances that
    reference the very same function object.
    """

    def serializable_fn(x):
        """Identity function embedded in a serializable object's config."""
        return x

    class SerializableNestedInt(int):
        """An int subclass whose config also carries a function."""

        def __new__(cls, value, fn):
            instance = int.__new__(cls, value)
            instance.fn = fn
            return instance

        def get_config(self):
            return dict(value=int(self), fn=self.fn)

        @classmethod
        def from_config(cls, config):
            return cls(**config)

    lower_bound = [SerializableNestedInt(-1, serializable_fn)]
    upper_bound = [SerializableNestedInt(1, serializable_fn)]
    policy = policies.ContinuousUniformPolicy(
        action_range=(lower_bound, upper_bound),
        input_shapes={'what': tf.TensorShape((3, ))},
        output_shape=(1, ),
        observation_keys=None)

    custom_objects = {
        'serializable_fn': serializable_fn,
        'SerializableNestedInt': SerializableNestedInt,
    }
    config = policies.serialize(policy)
    new_policy = policies.deserialize(config, custom_objects=custom_objects)

    def assert_preserved(attribute):
        # Compare one stored attribute between restored and original.
        self.assertEqual(
            getattr(new_policy, attribute), getattr(policy, attribute))

    assert_preserved('_action_range')
    assert_preserved('_input_shapes')
    self.assertIsInstance(new_policy._input_shapes['what'], tf.TensorShape)
    assert_preserved('_output_shape')
    assert_preserved('_observation_keys')

    # Walk every element of both action bounds.
    restored_elements = (
        element
        for action_bound in new_policy._action_range
        for element in action_bound)
    for element in restored_elements:
        self.assertIsInstance(element, SerializableNestedInt)
        self.assertIs(element.fn, serializable_fn)
def test_serializable_object(self):
    """Round-trip a policy whose config holds custom serializable objects.

    The action range is built from `SerializableInt` values. After
    serialize/deserialize with `SerializableInt` supplied as a custom
    object, the stored attributes must compare equal and every restored
    action-range element must be a `SerializableInt`.
    """

    class SerializableInt(int):
        """An int subclass exposing get_config/from_config."""

        def __new__(cls, value):
            return int.__new__(cls, value)

        def get_config(self):
            return dict(value=int(self))

        @classmethod
        def from_config(cls, config):
            return cls(**config)

    lower_bound = [SerializableInt(-1)]
    upper_bound = [SerializableInt(1)]
    policy = policies.ContinuousUniformPolicy(
        action_range=(lower_bound, upper_bound),
        input_shapes={'what': tf.TensorShape((3, ))},
        output_shape=(1, ),
        observation_keys=None,
        name='SerializableNestedInt')

    custom_objects = {
        'SerializableInt': SerializableInt,
    }
    config = policies.serialize(policy)
    new_policy = policies.deserialize(config, custom_objects=custom_objects)

    def assert_preserved(attribute):
        # Compare one stored attribute between restored and original.
        self.assertEqual(
            getattr(new_policy, attribute), getattr(policy, attribute))

    assert_preserved('_action_range')
    assert_preserved('_input_shapes')
    self.assertIsInstance(new_policy._input_shapes['what'], tf.TensorShape)
    assert_preserved('_output_shape')
    assert_preserved('_observation_keys')

    # Walk every element of both action bounds.
    restored_elements = (
        element
        for action_bound in new_policy._action_range
        for element in action_bound)
    for element in restored_elements:
        self.assertIsInstance(element, SerializableInt)
def test_nested_serializable_object(self):
    """Round-trip a policy whose config holds nested serializable objects.

    Each action-range element is a `SerializableNestedInt` that carries a
    `SerializableInt`. The policy is also named 'SerializableNestedInt',
    the same string as one of the custom-object keys, and the test checks
    that the name is still that plain string after deserialization.
    """

    class SerializableInt(int):
        """An int subclass exposing get_config/from_config."""

        def __new__(cls, value):
            return int.__new__(cls, value)

        def get_config(self):
            return dict(value=int(self))

        @classmethod
        def from_config(cls, config):
            return cls(**config)

    class SerializableNestedInt(int):
        """An int subclass whose config carries another serializable."""

        def __new__(cls, value, int_obj):
            instance = int.__new__(cls, value)
            instance.int_obj = int_obj
            return instance

        def get_config(self):
            return dict(value=int(self), int_obj=self.int_obj)

        @classmethod
        def from_config(cls, config):
            return cls(**config)

    nested_int = SerializableInt(4)
    lower_bound = [SerializableNestedInt(-1, nested_int)]
    upper_bound = [SerializableNestedInt(1, nested_int)]
    policy = policies.ContinuousUniformPolicy(
        action_range=(lower_bound, upper_bound),
        input_shapes={'what': tf.TensorShape((3, ))},
        output_shape=(1, ),
        observation_keys=None,
        name='SerializableNestedInt')

    custom_objects = {
        'SerializableInt': SerializableInt,
        'SerializableNestedInt': SerializableNestedInt,
    }
    config = policies.serialize(policy)
    new_policy = policies.deserialize(config, custom_objects=custom_objects)

    # A plain string field must not be converted into a custom object,
    # even when its value equals a custom-object key.
    self.assertEqual(new_policy.name, 'SerializableNestedInt')

    def assert_preserved(attribute):
        # Compare one stored attribute between restored and original.
        self.assertEqual(
            getattr(new_policy, attribute), getattr(policy, attribute))

    assert_preserved('_action_range')
    assert_preserved('_input_shapes')
    self.assertIsInstance(new_policy._input_shapes['what'], tf.TensorShape)
    assert_preserved('_output_shape')
    assert_preserved('_observation_keys')

    # Walk every element of both action bounds.
    restored_elements = (
        element
        for action_bound in new_policy._action_range
        for element in action_bound)
    for element in restored_elements:
        self.assertIsInstance(element, SerializableNestedInt)
        self.assertIsInstance(element.int_obj, SerializableInt)
        self.assertEqual(element.int_obj, 4)