def test_serialize_custom_class_with_custom_name(self):
    """Round-trip a class registered under an explicit package and name.

    Registers a class as 'TestPackage>CustomName'. Checks that the name
    lookups (the `_GLOBAL_CUSTOM_NAMES` table, `get_registered_name`, and
    `get_registered_object`) all agree. Then serializes an instance and
    deserializes it back into a distinct, equivalent instance.
    """
    @serialization.register_softlearning_serializable(
        'TestPackage', 'CustomName')
    class OtherTestClass(object):

        def __init__(self, val):
            self._val = val

        def get_config(self):
            return {'val': self._val}

    expected_name = 'TestPackage>CustomName'
    original = OtherTestClass(val=5)

    # The decorator must have recorded the custom name in the global table.
    registered_name = serialization._GLOBAL_CUSTOM_NAMES[OtherTestClass]
    self.assertEqual(expected_name, registered_name)

    # The public accessors must agree with the table in both directions.
    looked_up_name = serialization.get_registered_name(OtherTestClass)
    self.assertEqual(looked_up_name, registered_name)
    self.assertEqual(
        OtherTestClass,
        serialization.get_registered_object(looked_up_name))

    serialized = serialization.serialize_softlearning_object(original)
    self.assertEqual(registered_name, serialized['class_name'])

    restored = serialization.deserialize_softlearning_object(serialized)
    self.assertIsNot(original, restored)
    self.assertIsInstance(restored, OtherTestClass)
    self.assertEqual(5, restored._val)
def deserialize(name, custom_objects=None):
    """Resolve a value function (or class) from its serialized identifier.

    The lookup is delegated to `deserialize_softlearning_object`, which
    searches this module's global namespace. Errors are reported under the
    printable name 'value function'.

    For example:

    >>> softlearning.value_functions.get('double_feedforward_Q_function')
    <function double_feedforward_Q_function at 0x7f86e3691e60>
    >>> softlearning.value_functions.get('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown value function: abcd

    Args:
        name: Identifier of the value function to resolve, e.g. the string
            'double_feedforward_Q_function'.
        custom_objects: Optional extra objects, forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to objects, consulted alongside the module globals).

    Returns:
        The value function or class denoted by `name`.

    Raises:
        ValueError: `Unknown value function` if `name` does not denote any
            defined value function.
    """
    module_namespace = globals()
    return deserialize_softlearning_object(
        name,
        module_objects=module_namespace,
        custom_objects=custom_objects,
        printable_module_name='value function')
def test_serialize_custom_class_with_default_name(self):
    """Round-trip a class registered under the default name and reject reuse.

    With no arguments, the decorator registers the class as
    'Custom>TestClass'. An instance must survive serialization. A second
    class registered under the same name must raise ValueError.
    """
    @serialization.register_softlearning_serializable()
    class TestClass(object):

        def __init__(self, value):
            self._value = value

        def get_config(self):
            return {'value': self._value}

    expected_name = 'Custom>TestClass'
    original = TestClass(value=10)

    registered_name = serialization._GLOBAL_CUSTOM_NAMES[TestClass]
    self.assertEqual(expected_name, registered_name)

    serialized = serialization.serialize_softlearning_object(original)
    self.assertEqual(registered_name, serialized['class_name'])

    restored = serialization.deserialize_softlearning_object(serialized)
    self.assertIsNot(original, restored)
    self.assertIsInstance(restored, TestClass)
    self.assertEqual(10, restored._value)

    # Registering another class whose default name collides must fail. The
    # class name has to stay `TestClass` so that the two names collide.
    with self.assertRaisesRegex(ValueError, ".*has already been registered.*"):
        @serialization.register_softlearning_serializable()  # pylint: disable=function-redefined
        class TestClass(object):

            def __init__(self, value):
                self._value = value

            def get_config(self):
                return {'value': self._value}
def deserialize(name, custom_objects=None):
    """Resolve a replay pool from its serialized identifier.

    The lookup is delegated to `deserialize_softlearning_object`, which
    searches this module's global namespace. Errors are reported under the
    printable name 'replay pool'.

    For example:

    >>> softlearning.replay_pools.get({'class_name': 'SimpleReplayPool', ...})
    <softlearning.replay_pools.simple_replay_pool.SimpleReplayPool object at 0x7fea93d6cdd0>
    >>> softlearning.replay_pools.get('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown replay pool: abcd

    Args:
        name: Identifier of the replay pool: either a string name or, as in
            the example above, a serialized config dict with a
            'class_name' key.
        custom_objects: Optional extra objects, forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to objects, consulted alongside the module globals).

    Returns:
        The replay pool denoted by `name`. Per the example, a config dict
        resolves to a constructed pool instance.

    Raises:
        ValueError: `Unknown replay pool` if `name` does not denote any
            defined replay pool.
    """
    module_namespace = globals()
    return deserialize_softlearning_object(
        name,
        module_objects=module_namespace,
        custom_objects=custom_objects,
        printable_module_name='replay pool')
def deserialize(name, custom_objects=None):
    """Resolve an algorithm from its serialized identifier.

    The lookup is delegated to `deserialize_softlearning_object`, which
    searches this module's global namespace. Errors are reported under the
    printable name 'algorithm'.

    For example:

    >>> softlearning.algorithms.get({'class_name': 'SAC', ...})
    <softlearning.algorithms.sac.SAC object at 0x7fea93d6cdd0>
    >>> softlearning.algorithms.get('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown algorithm: abcd

    Args:
        name: Identifier of the algorithm: either a string name or, as in
            the example above, a serialized config dict with a
            'class_name' key.
        custom_objects: Optional extra objects, forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to objects, consulted alongside the module globals).

    Returns:
        The algorithm denoted by `name`. Per the example, a config dict
        resolves to a constructed algorithm instance.

    Raises:
        ValueError: `Unknown algorithm` if `name` does not denote any
            defined algorithm.
    """
    module_namespace = globals()
    return deserialize_softlearning_object(
        name,
        module_objects=module_namespace,
        custom_objects=custom_objects,
        printable_module_name='algorithm')
def deserialize(name, custom_objects=None):
    """Resolve a preprocessor (function or class) from its serialized identifier.

    The lookup is delegated to `deserialize_softlearning_object`, which
    searches this module's global namespace. Errors are reported under the
    printable name 'preprocessor'.

    For example:

    >>> softlearning.preprocessors.get('convnet_preprocessor')
    <function convnet_preprocessor at 0x7fd170125950>
    >>> softlearning.preprocessors.get('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown preprocessor: abcd

    Args:
        name: Identifier of the preprocessor to resolve, e.g. the string
            'convnet_preprocessor'.
        custom_objects: Optional extra objects, forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to objects, consulted alongside the module globals).

    Returns:
        The preprocessor function or class denoted by `name`.

    Raises:
        ValueError: `Unknown preprocessor` if `name` does not denote any
            defined preprocessor.
    """
    module_namespace = globals()
    return deserialize_softlearning_object(
        name,
        module_objects=module_namespace,
        custom_objects=custom_objects,
        printable_module_name='preprocessor')
def deserialize(name, custom_objects=None):
    """Resolve a sampler from its serialized identifier.

    The lookup is delegated to `deserialize_softlearning_object`, which
    searches this module's global namespace. Errors are reported under the
    printable name 'sampler'.

    For example:

    >>> softlearning.samplers.get({'class_name': 'SimpleSampler', ...})
    <softlearning.samplers.simple_sampler.SimpleSampler object at 0x7fea93d6cdd0>
    >>> softlearning.samplers.get('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown sampler: abcd

    Args:
        name: Identifier of the sampler: either a string name or, as in the
            example above, a serialized config dict with a 'class_name'
            key.
        custom_objects: Optional extra objects, forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to objects, consulted alongside the module globals).

    Returns:
        The sampler denoted by `name`. Per the example, a config dict
        resolves to a constructed sampler instance.

    Raises:
        ValueError: `Unknown sampler` if `name` does not denote any defined
            sampler.
    """
    module_namespace = globals()
    return deserialize_softlearning_object(
        name,
        module_objects=module_namespace,
        custom_objects=custom_objects,
        printable_module_name='sampler')
def deserialize(name, custom_objects=None):
    """Resolve a policy from its serialized identifier.

    The lookup is delegated to `deserialize_softlearning_object`, which
    searches this module's global namespace. Errors are reported under the
    printable name 'policy'.

    For example:

    >>> softlearning.policies.get({
    ...     'class_name': 'ContinuousUniformPolicy',
    ...     'config': {
    ...         'action_range': [[-1], [1]],
    ...         'input_shapes': tf.TensorShape((3, )),
    ...         'output_shape': 2
    ...     }
    ... })
    <softlearning.policies.uniform_policy.ContinuousUniformPolicy object at 0x7fea93d6cdd0>
    >>> softlearning.policies.get('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown policy: abcd

    Args:
        name: Identifier of the policy: either a string name or, as in the
            example above, a serialized config dict with 'class_name' and
            'config' keys.
        custom_objects: Optional extra objects, forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to objects, consulted alongside the module globals).

    Returns:
        The policy denoted by `name`. Per the example, a config dict
        resolves to a constructed policy instance.

    Raises:
        ValueError: `Unknown policy` if `name` does not denote any defined
            policy.
    """
    module_namespace = globals()
    return deserialize_softlearning_object(
        name,
        module_objects=module_namespace,
        custom_objects=custom_objects,
        printable_module_name='policy')
def test_serialize_custom_function(self):
    """Round-trip a plain function registered under its default name.

    The function is registered as 'Custom>my_fn'. It must serialize to that
    name and be recoverable both by deserialization and by
    `get_registered_object`.
    """
    @serialization.register_softlearning_serializable()
    def my_fn():
        return 42

    expected_name = 'Custom>my_fn'
    registered_name = serialization._GLOBAL_CUSTOM_NAMES[my_fn]
    self.assertEqual(expected_name, registered_name)

    looked_up_name = serialization.get_registered_name(my_fn)
    self.assertEqual(looked_up_name, registered_name)

    # A function serializes to its bare registered name, not a config dict.
    serialized = serialization.serialize_softlearning_object(my_fn)
    self.assertEqual(registered_name, serialized)

    restored_fn = serialization.deserialize_softlearning_object(serialized)
    self.assertEqual(42, restored_fn())

    fetched_fn = serialization.get_registered_object(looked_up_name)
    self.assertEqual(42, fetched_fn())
def test_serialize_none(self):
    """`None` must pass through serialize/deserialize as `None` itself.

    Uses `assertIsNone` rather than `assertEqual(..., None)`. Identity with
    `None` is the property under test, and `assertIsNone` gives a clearer
    failure message than an equality check.
    """
    serialized = serialization.serialize_softlearning_object(None)
    self.assertIsNone(serialized)

    deserialized = serialization.deserialize_softlearning_object(serialized)
    self.assertIsNone(deserialized)