Example #1
0
    def test_serialize_custom_class_with_custom_name(self):
        """Round-trip a class registered under an explicit package and name.

        Verifies the registry entry, the name/object lookup helpers, and that
        serializing followed by deserializing produces a new instance of the
        same class carrying the same value.
        """
        @serialization.register_softlearning_serializable(
            'TestPackage', 'CustomName')
        class OtherTestClass(object):
            def __init__(self, val):
                self._val = val

            def get_config(self):
                config = {'val': self._val}
                return config

        expected_name = 'TestPackage>CustomName'
        original = OtherTestClass(val=5)

        # The registry entry and the public name lookup must agree.
        registered_name = serialization._GLOBAL_CUSTOM_NAMES[OtherTestClass]
        self.assertEqual(expected_name, registered_name)
        looked_up_name = serialization.get_registered_name(OtherTestClass)
        self.assertEqual(looked_up_name, registered_name)

        # Looking the name up again must hand back the very same class.
        looked_up_cls = serialization.get_registered_object(looked_up_name)
        self.assertEqual(OtherTestClass, looked_up_cls)

        # Serialize, then rebuild a distinct instance from the config.
        serialized = serialization.serialize_softlearning_object(original)
        self.assertEqual(registered_name, serialized['class_name'])
        restored = serialization.deserialize_softlearning_object(serialized)
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, OtherTestClass)
        self.assertEqual(5, restored._val)
Example #2
0
    def test_serialize_custom_class_with_default_name(self):
        """Round-trip a class registered with the default package and name.

        Also checks that registering a second class under the same name is
        rejected with a ``ValueError``.
        """
        @serialization.register_softlearning_serializable()
        class TestClass(object):
            def __init__(self, value):
                self._value = value

            def get_config(self):
                config = {'value': self._value}
                return config

        expected_name = 'Custom>TestClass'
        original = TestClass(value=10)

        registered_name = serialization._GLOBAL_CUSTOM_NAMES[TestClass]
        self.assertEqual(expected_name, registered_name)

        # Serialize, then rebuild a distinct instance from the config.
        serialized = serialization.serialize_softlearning_object(original)
        self.assertEqual(registered_name, serialized['class_name'])
        restored = serialization.deserialize_softlearning_object(serialized)
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, TestClass)
        self.assertEqual(10, restored._value)

        # A second registration under the same name is expected to fail.
        duplicate_error = self.assertRaisesRegex(
            ValueError, ".*has already been registered.*")
        with duplicate_error:

            @serialization.register_softlearning_serializable()  # pylint: disable=function-redefined
            class TestClass(object):
                def __init__(self, value):
                    self._value = value

                def get_config(self):
                    return {'value': self._value}
Example #3
0
    def test_serialize_custom_function(self):
        """Round-trip a function registered with the default package and name.

        The serialized form is expected to equal the registered name, and both
        deserialization and a registry lookup must return a callable that
        yields 42.
        """
        @serialization.register_softlearning_serializable()
        def my_fn():
            return 42

        expected_name = 'Custom>my_fn'

        # The registry entry and the public name lookup must agree.
        registered_name = serialization._GLOBAL_CUSTOM_NAMES[my_fn]
        self.assertEqual(expected_name, registered_name)
        looked_up_name = serialization.get_registered_name(my_fn)
        self.assertEqual(looked_up_name, registered_name)

        # The test expects the serialized function to be its registered name.
        serialized = serialization.serialize_softlearning_object(my_fn)
        self.assertEqual(registered_name, serialized)
        restored_fn = serialization.deserialize_softlearning_object(serialized)
        self.assertEqual(42, restored_fn())

        looked_up_fn = serialization.get_registered_object(looked_up_name)
        self.assertEqual(42, looked_up_fn())
Example #4
0
 def test_serialize_none(self):
     """Check that ``None`` survives a serialize/deserialize round trip.

     ``assertIsNone`` is used rather than ``assertEqual(..., None)`` so the
     check is an identity test against the ``None`` singleton and failures
     report a clearer message.
     """
     serialized = serialization.serialize_softlearning_object(None)
     self.assertIsNone(serialized)
     deserialized = serialization.deserialize_softlearning_object(
         serialized)
     self.assertIsNone(deserialized)
Example #5
0
def serialize(value_function):
    """Serialize ``value_function`` via ``serialize_softlearning_object``.

    Returns whatever that helper produces for the given object.
    """
    serialized = serialize_softlearning_object(value_function)
    return serialized
Example #6
0
def serialize(replay_pool):
    """Serialize ``replay_pool`` via ``serialize_softlearning_object``.

    Returns whatever that helper produces for the given object.
    """
    serialized = serialize_softlearning_object(replay_pool)
    return serialized
Example #7
0
def serialize(algorithm):
    """Serialize ``algorithm`` via ``serialize_softlearning_object``.

    Returns whatever that helper produces for the given object.
    """
    serialized = serialize_softlearning_object(algorithm)
    return serialized
Example #8
0
def serialize(preprocessor):
    """Serialize ``preprocessor`` via ``serialize_softlearning_object``.

    Returns whatever that helper produces for the given object.
    """
    serialized = serialize_softlearning_object(preprocessor)
    return serialized
Example #9
0
def serialize(sampler):
    """Serialize ``sampler`` via ``serialize_softlearning_object``.

    Returns whatever that helper produces for the given object.
    """
    serialized = serialize_softlearning_object(sampler)
    return serialized
Example #10
0
def serialize(policy):
    """Serialize ``policy`` via ``serialize_softlearning_object``.

    Returns whatever that helper produces for the given object.
    """
    serialized = serialize_softlearning_object(policy)
    return serialized