Example #1
0
    def test_serialize_custom_class_with_custom_name(self):
        """A class registered under an explicit package and name round-trips.

        Registers ``OtherTestClass`` as ``'TestPackage>CustomName'``, checks
        the registry in both directions (class -> name, name -> class), then
        serializes an instance and deserializes it into a distinct instance
        carrying the same value.
        """
        expected_name = 'TestPackage>CustomName'

        @serialization.register_softlearning_serializable(
            'TestPackage', 'CustomName')
        class OtherTestClass(object):
            def __init__(self, val):
                self._val = val

            def get_config(self):
                return {'val': self._val}

        # Class -> name, through the raw registry and through the helper.
        registered_name = serialization._GLOBAL_CUSTOM_NAMES[OtherTestClass]
        self.assertEqual(expected_name, registered_name)
        looked_up_name = serialization.get_registered_name(OtherTestClass)
        self.assertEqual(looked_up_name, registered_name)

        # Name -> class.
        self.assertEqual(
            OtherTestClass,
            serialization.get_registered_object(looked_up_name))

        # Full serialize/deserialize round trip.
        original = OtherTestClass(val=5)
        config = serialization.serialize_softlearning_object(original)
        self.assertEqual(registered_name, config['class_name'])
        restored = serialization.deserialize_softlearning_object(config)
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, OtherTestClass)
        self.assertEqual(5, restored._val)
Example #2
0
def deserialize(name, custom_objects=None):
    """Returns a value function or class denoted by `name`.

    Args:
        name: A string name, or a serialized config dict (with a
            `'class_name'` key), identifying the value function. Objects
            defined in this module are candidates, since this module's
            `globals()` are passed as `module_objects`.
        custom_objects: Optional additional objects forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to classes/functions). Defaults to `None`.

    Returns:
        The value function (or class) denoted by `name`.

    Raises:
        ValueError: `Unknown value function` if `name` does not denote any
            defined value function.

    Example:
    >>> softlearning.value_functions.get('double_feedforward_Q_function')
      <function double_feedforward_Q_function at 0x7f86e3691e60>
    >>> softlearning.value_functions.get('abcd')
      Traceback (most recent call last):
      ...
      ValueError: Unknown value function: abcd
    """
    return deserialize_softlearning_object(
        name,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name='value function')
Example #3
0
    def test_serialize_custom_class_with_default_name(self):
        """A class registered without arguments is named 'Custom>TestClass'.

        Also checks that registering a second class that resolves to the
        same default name is rejected with a ``ValueError``.
        """
        @serialization.register_softlearning_serializable()
        class TestClass(object):
            def __init__(self, value):
                self._value = value

            def get_config(self):
                return {'value': self._value}

        registered_name = serialization._GLOBAL_CUSTOM_NAMES[TestClass]
        self.assertEqual('Custom>TestClass', registered_name)

        original = TestClass(value=10)
        config = serialization.serialize_softlearning_object(original)
        self.assertEqual(registered_name, config['class_name'])

        restored = serialization.deserialize_softlearning_object(config)
        self.assertIsNot(original, restored)
        self.assertIsInstance(restored, TestClass)
        self.assertEqual(10, restored._value)

        # A new class with the same ``__name__`` resolves to the same default
        # registered name, so the decorator must refuse it.
        with self.assertRaisesRegex(ValueError,
                                    ".*has already been registered.*"):

            @serialization.register_softlearning_serializable()  # pylint: disable=function-redefined
            class TestClass(object):
                def __init__(self, value):
                    self._value = value

                def get_config(self):
                    return {'value': self._value}
Example #4
0
def deserialize(name, custom_objects=None):
    """Returns a replay pool function or class denoted by `name`.

    Args:
        name: A string name, or a serialized config dict (with a
            `'class_name'` key), identifying the replay pool. Objects
            defined in this module are candidates, since this module's
            `globals()` are passed as `module_objects`.
        custom_objects: Optional additional objects forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to classes/functions). Defaults to `None`.

    Returns:
        The replay pool function, class, or instance denoted by `name`
        (a config dict yields an instance, as in the example below).

    Raises:
        ValueError: `Unknown replay pool` if `name` does not denote any
            defined replay pool.

    Example:
    >>> softlearning.replay_pools.get({'class_name': 'SimpleReplayPool', ...})
      <softlearning.replay_pools.simple_replay_pool.SimpleReplayPool object at 0x7fea93d6cdd0>
    >>> softlearning.replay_pools.get('abcd')
      Traceback (most recent call last):
      ...
      ValueError: Unknown replay pool: abcd
    """
    return deserialize_softlearning_object(
        name,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name='replay pool')
Example #5
0
def deserialize(name, custom_objects=None):
    """Returns an algorithm function or class denoted by `name`.

    Args:
        name: A string name, or a serialized config dict (with a
            `'class_name'` key), identifying the algorithm. Objects
            defined in this module are candidates, since this module's
            `globals()` are passed as `module_objects`.
        custom_objects: Optional additional objects forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to classes/functions). Defaults to `None`.

    Returns:
        The algorithm function, class, or instance denoted by `name`
        (a config dict yields an instance, as in the example below).

    Raises:
        ValueError: `Unknown algorithm` if `name` does not denote any
            defined algorithm.

    Example:
    >>> softlearning.algorithms.get({'class_name': 'SAC', ...})
      <softlearning.algorithms.sac.SAC object at 0x7fea93d6cdd0>
    >>> softlearning.algorithms.get('abcd')
      Traceback (most recent call last):
      ...
      ValueError: Unknown algorithm: abcd
    """
    return deserialize_softlearning_object(
        name,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name='algorithm')
Example #6
0
def deserialize(name, custom_objects=None):
    """Returns a preprocessor function or class denoted by `name`.

    Args:
        name: A string name, or a serialized config dict (with a
            `'class_name'` key), identifying the preprocessor. Objects
            defined in this module are candidates, since this module's
            `globals()` are passed as `module_objects`.
        custom_objects: Optional additional objects forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to classes/functions). Defaults to `None`.

    Returns:
        The preprocessor function (or class) denoted by `name`.

    Raises:
        ValueError: `Unknown preprocessor` if `name` does not denote any
            defined preprocessor.

    Example:
    >>> softlearning.preprocessors.get('convnet_preprocessor')
      <function convnet_preprocessor at 0x7fd170125950>
    >>> softlearning.preprocessors.get('abcd')
      Traceback (most recent call last):
      ...
      ValueError: Unknown preprocessor: abcd
    """
    return deserialize_softlearning_object(
        name,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name='preprocessor')
Example #7
0
def deserialize(name, custom_objects=None):
    """Returns a sampler function or class denoted by `name`.

    Args:
        name: A string name, or a serialized config dict (with a
            `'class_name'` key), identifying the sampler. Objects
            defined in this module are candidates, since this module's
            `globals()` are passed as `module_objects`.
        custom_objects: Optional additional objects forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to classes/functions). Defaults to `None`.

    Returns:
        The sampler function, class, or instance denoted by `name`
        (a config dict yields an instance, as in the example below).

    Raises:
        ValueError: `Unknown sampler` if `name` does not denote any
            defined sampler.

    Example:
    >>> softlearning.samplers.get({'class_name': 'SimpleSampler', ...})
      <softlearning.samplers.simple_sampler.SimpleSampler object at 0x7fea93d6cdd0>
    >>> softlearning.samplers.get('abcd')
      Traceback (most recent call last):
      ...
      ValueError: Unknown sampler: abcd
    """
    return deserialize_softlearning_object(
        name,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name='sampler')
Example #8
0
def deserialize(name, custom_objects=None):
    """Returns a policy function or class denoted by `name`.

    Args:
        name: A string name, or a serialized config dict (with a
            `'class_name'` key and, optionally, a `'config'` dict of
            constructor arguments as in the example below), identifying the
            policy. Objects defined in this module are candidates, since
            this module's `globals()` are passed as `module_objects`.
        custom_objects: Optional additional objects forwarded unchanged to
            `deserialize_softlearning_object` (presumably a dict mapping
            names to classes/functions). Defaults to `None`.

    Returns:
        The policy function, class, or instance denoted by `name`
        (a config dict yields an instance, as in the example below).

    Raises:
        ValueError: `Unknown policy` if `name` does not denote any
            defined policy.

    Example:
    >>> softlearning.policies.get({
    ...     'class_name': 'ContinuousUniformPolicy',
    ...     'config': {
    ...         'action_range': [[-1], [1]],
    ...         'input_shapes': tf.TensorShape((3, )),
    ...         'output_shape': 2
    ...      }
    ... })
      <softlearning.policies.uniform_policy.ContinuousUniformPolicy object at 0x7fea93d6cdd0>
    >>> softlearning.policies.get('abcd')
      Traceback (most recent call last):
      ...
      ValueError: Unknown policy: abcd
    """
    return deserialize_softlearning_object(
        name,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name='policy')
Example #9
0
    def test_serialize_custom_function(self):
        """A registered plain function serializes to its registered name.

        Unlike classes, a function's serialized form is the name string
        itself; both deserialization and a direct registry lookup by that
        name must yield a callable returning the original result.
        """
        @serialization.register_softlearning_serializable()
        def my_fn():
            return 42

        registered_name = serialization._GLOBAL_CUSTOM_NAMES[my_fn]
        self.assertEqual('Custom>my_fn', registered_name)
        looked_up_name = serialization.get_registered_name(my_fn)
        self.assertEqual(looked_up_name, registered_name)

        # The serialized form of a function is just its name string.
        config = serialization.serialize_softlearning_object(my_fn)
        self.assertEqual(registered_name, config)
        restored_fn = serialization.deserialize_softlearning_object(config)
        self.assertEqual(42, restored_fn())

        # Looking the name up directly in the registry works as well.
        registry_fn = serialization.get_registered_object(looked_up_name)
        self.assertEqual(42, registry_fn())
Example #10
0
 def test_serialize_none(self):
     """``None`` serializes to ``None`` and deserializes back to ``None``."""
     serialized = serialization.serialize_softlearning_object(None)
     # assertIsNone checks identity (the idiomatic None test) and reports a
     # clearer failure message than assertEqual(x, None).
     self.assertIsNone(serialized)
     deserialized = serialization.deserialize_softlearning_object(
         serialized)
     self.assertIsNone(deserialized)