def test_serialize_custom_class_with_custom_name(self):
    """Round-trip a class registered under an explicit package and name.

    Checks that the registry records the name 'TestPackage>CustomName', that
    the public lookup helpers agree with it in both directions, and that
    serializing then deserializing an instance yields a new, equivalent
    instance.
    """

    @serialization.register_softlearning_serializable(
        'TestPackage', 'CustomName')
    class OtherTestClass(object):

        def __init__(self, val):
            self._val = val

        def get_config(self):
            return {'val': self._val}

    expected_name = 'TestPackage>CustomName'
    original = OtherTestClass(val=5)

    # The private registry maps the class object to its serialized name.
    registered_name = serialization._GLOBAL_CUSTOM_NAMES[OtherTestClass]
    self.assertEqual(expected_name, registered_name)

    # Name lookup (class -> name) and object lookup (name -> class) must agree
    # with the registry entry.
    looked_up_name = serialization.get_registered_name(OtherTestClass)
    self.assertEqual(looked_up_name, registered_name)
    self.assertEqual(
        OtherTestClass, serialization.get_registered_object(looked_up_name))

    # The serialized config carries the registered name under 'class_name'.
    config = serialization.serialize_softlearning_object(original)
    self.assertEqual(registered_name, config['class_name'])

    # Deserialization builds a distinct object with the same state.
    restored = serialization.deserialize_softlearning_object(config)
    self.assertIsNot(original, restored)
    self.assertIsInstance(restored, OtherTestClass)
    self.assertEqual(5, restored._val)
def test_serialize_custom_class_with_default_name(self):
    """Round-trip a class registered with no explicit package or name.

    The test expects the default registered name 'Custom>TestClass'. It also
    checks that registering a second class under that same name raises
    ValueError.
    """

    @serialization.register_softlearning_serializable()
    class TestClass(object):

        def __init__(self, value):
            self._value = value

        def get_config(self):
            return {'value': self._value}

    original = TestClass(value=10)

    registered_name = serialization._GLOBAL_CUSTOM_NAMES[TestClass]
    self.assertEqual('Custom>TestClass', registered_name)

    config = serialization.serialize_softlearning_object(original)
    self.assertEqual(registered_name, config['class_name'])

    restored = serialization.deserialize_softlearning_object(config)
    self.assertIsNot(original, restored)
    self.assertIsInstance(restored, TestClass)
    self.assertEqual(10, restored._value)

    # A second class with the same default name must be rejected rather than
    # silently replacing the first registration.
    with self.assertRaisesRegex(ValueError, ".*has already been registered.*"):

        @serialization.register_softlearning_serializable()  # pylint: disable=function-redefined
        class TestClass(object):

            def __init__(self, value):
                self._value = value

            def get_config(self):
                return {'value': self._value}
def test_serialize_custom_function(self):
    """Round-trip a plain function through the serialization registry.

    The registered name is expected to be 'Custom>my_fn'. Serializing the
    function yields that name directly. Both deserialization and registry
    lookup must return a callable that still returns 42.
    """

    @serialization.register_softlearning_serializable()
    def my_fn():
        return 42

    registered_name = serialization._GLOBAL_CUSTOM_NAMES[my_fn]
    self.assertEqual('Custom>my_fn', registered_name)

    looked_up_name = serialization.get_registered_name(my_fn)
    self.assertEqual(looked_up_name, registered_name)

    # For a function, the serialized value is compared directly against the
    # registered name (no 'class_name' key lookup as for classes).
    config = serialization.serialize_softlearning_object(my_fn)
    self.assertEqual(registered_name, config)

    restored_fn = serialization.deserialize_softlearning_object(config)
    self.assertEqual(42, restored_fn())

    looked_up_fn = serialization.get_registered_object(looked_up_name)
    self.assertEqual(42, looked_up_fn())
def test_serialize_none(self):
    """None must pass through serialization and deserialization unchanged."""
    serialized = serialization.serialize_softlearning_object(None)
    # assertIsNone checks identity (`is None`), not `== None`, and reports a
    # clearer failure message than assertEqual(..., None).
    self.assertIsNone(serialized)

    deserialized = serialization.deserialize_softlearning_object(serialized)
    self.assertIsNone(deserialized)
def serialize(value_function):
    """Return the serialized representation of ``value_function``.

    Delegates entirely to ``serialize_softlearning_object``.
    """
    serialized = serialize_softlearning_object(value_function)
    return serialized
def serialize(replay_pool):
    """Return the serialized representation of ``replay_pool``.

    Delegates entirely to ``serialize_softlearning_object``.
    """
    serialized = serialize_softlearning_object(replay_pool)
    return serialized
def serialize(algorithm):
    """Return the serialized representation of ``algorithm``.

    Delegates entirely to ``serialize_softlearning_object``.
    """
    serialized = serialize_softlearning_object(algorithm)
    return serialized
def serialize(preprocessor):
    """Return the serialized representation of ``preprocessor``.

    Delegates entirely to ``serialize_softlearning_object``.
    """
    serialized = serialize_softlearning_object(preprocessor)
    return serialized
def serialize(sampler):
    """Return the serialized representation of ``sampler``.

    Delegates entirely to ``serialize_softlearning_object``.
    """
    serialized = serialize_softlearning_object(sampler)
    return serialized
def serialize(policy):
    """Return the serialized representation of ``policy``.

    Delegates entirely to ``serialize_softlearning_object``.
    """
    serialized = serialize_softlearning_object(policy)
    return serialized