Beispiel #1
0
    def from_config(cls, config):
        """Rebuild a ``Merge`` layer from a config dict.

        ``mode``, ``output_shape`` and ``output_mask`` may each be stored in
        one of three forms, selected by a companion ``*_type`` key:

        * ``'function'`` -- the stored value is a name looked up in this
          module's ``globals()``;
        * ``'lambda'``   -- the stored value is restored with ``func_load``
          (bytecode deserialization) against this module's globals;
        * anything else  -- the stored value is used as-is.

        ``mode_type`` and ``mode`` are required (KeyError if absent); the
        output_shape/output_mask entries are optional. The ``*_type`` keys
        are dropped and the resolved values written back before delegating
        to the parent ``from_config``. The caller's dict is copied first and
        is not modified.
        """
        cfg = config.copy()

        def _restore(kind, key, required):
            # Resolve one serialized entry according to its recorded kind.
            if kind == 'function':
                return globals()[cfg[key]]
            if kind == 'lambda':
                return func_load(cfg[key], globs=globals())
            return cfg[key] if required else cfg.get(key)

        # Each *_type key is popped immediately before its value is resolved,
        # mirroring the strict sequential order of the original branches.
        mode = _restore(cfg.pop('mode_type'), 'mode', True)
        output_shape = _restore(cfg.pop('output_shape_type', None),
                                'output_shape', False)
        output_mask = _restore(cfg.pop('output_mask_type', None),
                               'output_mask', False)

        cfg['mode'] = mode
        cfg['output_shape'] = output_shape
        cfg['output_mask'] = output_mask
        return super(Merge, cls).from_config(cfg)
Beispiel #2
0
    def serialize_one(activation):
        """Convert one activation spec into a serializable form.

        Args:
            activation: an activation name, a Keras ``Layer`` (advanced
                activation), a plain callable, an already-serialized Keras
                ``{"class_name": ..., "config": ...}`` dict, or a 3-element
                list/tuple whose first item is a string and which may
                already be a marshalled function.

        Returns:
            The string or dict unchanged; ``serialize_keras_object(layer)``
            for a Layer; ``func_dump(fn)`` for any other callable; the
            list/tuple unchanged if ``func_load`` accepts it; otherwise
            ``None``.
        """
        if isinstance(activation, six.string_types):
            return activation

        if isinstance(activation, keras.engine.Layer):  # Advanced activation
            return serialize_keras_object(activation)

        # The order matters here, since Layers are also callable.
        if callable(activation):  # A function
            return func_dump(activation)

        # Keras serialized config
        if isinstance(activation, dict) \
                and "class_name" in activation \
                and "config" in activation:
            return activation

        # Could be a marshalled function
        if isinstance(activation, (list, tuple)) \
                and len(activation) == 3 \
                and isinstance(activation[0], six.string_types):
            try:
                # TODO: Better way to check if it is a marshalled function!
                func_load(activation)  # Try to unmarshal it
            except (EOFError, TypeError, ValueError):
                # Not a marshalled function. marshal.loads reports invalid
                # data with any of these (e.g. EOFError for "marshal data
                # too short"); catching only ValueError let the others crash.
                pass
            else:
                return activation

        return None
Beispiel #3
0
def _parse_config_to_function(config, custom_objects, func_attr_name,
                              func_type_attr_name, module_attr_name):
  """Reconstruct the function from the config.

  `config` is mutated: the entries under `module_attr_name` and
  `func_type_attr_name` are popped from it.

  Args:
    config: dict holding the serialized function and its metadata.
    custom_objects: optional mapping of extra names made visible to the
      function (and used for name lookup of "function" entries).
    func_attr_name: key of the serialized function in `config`.
    func_type_attr_name: key of the serialization kind, which must be
      "function" or "lambda".
    module_attr_name: key of the recorded module name, if any.

  Returns:
    The reconstructed callable.

  Raises:
    TypeError: if the recorded function type is not "function" or "lambda".
  """
  # Work on a copy: updating the dict returned by globals() directly would
  # permanently inject the recorded module's namespace and custom_objects
  # into this module's own globals.
  globs = globals().copy()
  module = config.pop(module_attr_name, None)
  if module in sys.modules:
    globs.update(sys.modules[module].__dict__)
  elif module is not None:
    # Note: we don't know the name of the function if it's a lambda.
    warnings.warn("{} is not loaded, but a layer uses it. "
                  "It may cause errors.".format(module), UserWarning,
                  stacklevel=2)
  if custom_objects:
    globs.update(custom_objects)
  function_type = config.pop(func_type_attr_name)
  if function_type == "function":
    # Simple lookup in custom objects
    return generic_utils.deserialize_keras_object(
        config[func_attr_name],
        custom_objects=custom_objects,
        printable_module_name="function in wrapper")
  if function_type == "lambda":
    # Unsafe deserialization from bytecode
    return generic_utils.func_load(config[func_attr_name], globs=globs)
  raise TypeError("Unknown function type: {}".format(function_type))
def test_func_dump_and_load(test_function_type):
    """func_load(func_dump(f)) must reproduce f's code, defaults and closure.

    ``test_function_type`` selects the case: 'simple function' (no free
    variables) or 'closured function' (captures one local variable). Both
    functions return a raw string containing a backslash followed by 'u',
    presumably chosen to exercise escape handling in the serializer.
    """
    if test_function_type == 'closured function':
        def make_closure():
            captured = r'\u'

            def test_func():
                return captured

            return test_func

        test_func = make_closure()
    elif test_function_type == 'simple function':
        def test_func():
            return r'\u'
    else:
        raise Exception('Unknown test case for test_func_dump_and_load')

    restored = func_load(func_dump(test_func))
    for attr in ('__code__', '__defaults__', '__closure__'):
        assert getattr(restored, attr) == getattr(test_func, attr), attr
Beispiel #5
0
 def _parse_function_from_config(cls, config, custom_objects,
                                 func_attr_name, module_attr_name,
                                 func_type_attr_name):
     """Rebuild the callable stored under ``func_attr_name`` in ``config``.

     ``config`` is mutated: the ``module_attr_name`` and
     ``func_type_attr_name`` entries are popped. The popped function type
     decides how the stored value is interpreted:

     * ``'function'`` -- resolved via
       ``generic_utils.deserialize_keras_object`` with ``custom_objects``;
     * ``'lambda'``   -- bytecode loaded by ``generic_utils.func_load``
       against a copy of this module's globals, extended with the recorded
       module's namespace (when it is in ``sys.modules``) and with
       ``custom_objects``;
     * ``'raw'``      -- the stored value is returned unchanged.

     Raises:
         TypeError: for any other function type.
     """
     # Private copy so the module's real globals are never modified.
     namespace = dict(globals())
     module_name = config.pop(module_attr_name, None)
     if module_name in sys.modules:
         namespace.update(sys.modules[module_name].__dict__)
     elif module_name is not None:
         # A lambda has no usable name to report, only its module.
         warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                       'It may cause errors.'.format(module_name),
                       UserWarning,
                       stacklevel=2)
     if custom_objects:
         namespace.update(custom_objects)

     function_type = config.pop(func_type_attr_name)
     if function_type == 'raw':
         return config[func_attr_name]
     if function_type == 'lambda':
         # Unsafe deserialization from bytecode
         return generic_utils.func_load(config[func_attr_name],
                                        globs=namespace)
     if function_type == 'function':
         # Simple lookup in custom objects
         return generic_utils.deserialize_keras_object(
             config[func_attr_name],
             custom_objects=custom_objects,
             printable_module_name='function in Lambda layer')

     supported_types = ['function', 'lambda', 'raw']
     raise TypeError(
         f'Unsupported value for `function_type` argument. Received: '
         f'function_type={function_type}. Expected one of {supported_types}'
     )
def test_func_dump_and_load_closure():
    """A lambda's closure survives func_dump/func_load when passed back in.

    func_dump returns a 3-item result; its first and third items are fed
    back to func_load as the serialized code and the ``closure`` argument.
    """
    y = 0
    test_func = lambda x: x + y
    code_blob, _defaults, dumped_closure = func_dump(test_func)
    restored = func_load(code_blob, closure=dumped_closure)
    for attr in ('__code__', '__defaults__', '__closure__'):
        assert getattr(restored, attr) == getattr(test_func, attr), attr
Beispiel #7
0
def test_func_dump_and_load_closure():
    """Round-trip a closure-capturing lambda and compare its key attributes."""
    offset = 0
    add_offset = lambda x: x + offset
    dumped_code, _, dumped_closure = func_dump(add_offset)
    rebuilt = func_load(dumped_code, closure=dumped_closure)
    expected = (add_offset.__code__, add_offset.__defaults__,
                add_offset.__closure__)
    actual = (rebuilt.__code__, rebuilt.__defaults__, rebuilt.__closure__)
    assert actual == expected
Beispiel #8
0
def test_func_dump_and_load():
    """A closure-free function survives a func_dump/func_load round trip."""
    def test_func():
        return r'\u'

    loaded = func_load(func_dump(test_func))
    for attr in ('__code__', '__defaults__', '__closure__'):
        assert getattr(loaded, attr) == getattr(test_func, attr), attr
def test_func_dump_and_load():
    """Dump a zero-argument function, load it back, and compare attributes."""
    def original():
        return r'\u'

    payload = func_dump(original)
    clone = func_load(payload)

    assert clone.__code__ == original.__code__
    assert clone.__defaults__ == original.__defaults__
    assert clone.__closure__ == original.__closure__
Beispiel #10
0
def test_func_dump_and_load_backwards_compat(test_func):
    """Payloads serialized before version 2.1.2 must still deserialize.

    The legacy format (see
    https://github.com/evhub/keras/blob/2.1.1/keras/utils/generic_utils.py#L166)
    is rebuilt here: the marshalled code object decoded with
    'raw_unicode_escape', with defaults passed separately.
    """
    raw_bytes = marshal.dumps(test_func.__code__)
    legacy_payload = raw_bytes.decode('raw_unicode_escape')

    restored = func_load(legacy_payload, defaults=test_func.__defaults__)
    for attr in ('__code__', '__defaults__', '__closure__'):
        assert getattr(restored, attr) == getattr(test_func, attr), attr
Beispiel #11
0
def test_func_dump_and_load_backwards_compat(test_func):
    """Loading the pre-2.1.2 serialization format must keep working."""
    # Legacy encoding, per
    # https://github.com/evhub/keras/blob/2.1.1/keras/utils/generic_utils.py#L166
    code_bytes = marshal.dumps(test_func.__code__)
    old_style = code_bytes.decode('raw_unicode_escape')

    revived = func_load(old_style, defaults=test_func.__defaults__)

    assert revived.__code__ == test_func.__code__
    assert revived.__defaults__ == test_func.__defaults__
    assert revived.__closure__ == test_func.__closure__
Beispiel #12
0
    def deserialize_one(activation):
        """Turn one serialized activation spec back into a usable object.

        Args:
            activation: ``None`` or an activation name; an already-built
                Keras ``Layer`` or callable; a Keras
                ``{"class_name": ..., "config": ...}`` dict; or a 3-element
                list/tuple whose first item is a string and which may be a
                marshalled function (presumably produced by ``func_dump``).

        Returns:
            ``with_device(device, Activation, activation)`` for ``None`` or a
            name; the object itself for a Layer/callable; the result of
            ``keras_activations.deserialize`` (run via ``with_device``) for a
            config dict; the function rebuilt by ``func_load`` for a
            marshalled function; otherwise ``None``.
        """
        # Simple activation
        if (activation is None) or isinstance(activation, six.string_types):
            return with_device(device, Activation, activation)

        # Advanced activation (it has already been created, nothing we can do)
        if isinstance(activation, keras.engine.Layer):
            return activation

        # Function (it has already been created, nothing we can do)
        if callable(activation):
            return activation

        # Keras serialized config
        if isinstance(activation, dict) \
                and "class_name" in activation \
                and "config" in activation:

            # Bind custom_objects on every path: configs whose class_name is
            # not an advanced activation must still reach deserialize()
            # instead of failing with UnboundLocalError.
            custom_objects = {}
            class_name = activation["class_name"]
            # Make advanced activation functions available per default
            if class_name in dir(keras_advanced_activations):
                custom_objects[class_name] = getattr(
                    keras_advanced_activations, class_name)

            return with_device(device,
                               keras_activations.deserialize,
                               activation,
                               custom_objects=custom_objects)

        # Could be a marshalled function
        if isinstance(activation, (list, tuple)) \
                and len(activation) == 3 \
                and isinstance(activation[0], six.string_types):
            try:
                # TODO: Better way to check if it is a marshalled function!
                return func_load(activation)  # Try to unmarshal it
            except (EOFError, TypeError, ValueError):
                # Not a marshalled function: marshal.loads reports invalid
                # data with one of these (e.g. EOFError for "marshal data
                # too short").
                pass

        return None
Beispiel #13
0
def test_func_dump_and_load(test_function_type):
    """Round-trip a plain or a closure function through func_dump/func_load."""
    def simple():
        def test_func():
            return r'\u'
        return test_func

    def closured():
        x = r'\u'

        def test_func():
            return x
        return test_func

    factories = {'simple function': simple, 'closured function': closured}
    if test_function_type not in factories:
        raise Exception('Unknown test case for test_func_dump_and_load')
    test_func = factories[test_function_type]()

    round_tripped = func_load(func_dump(test_func))
    assert round_tripped.__code__ == test_func.__code__
    assert round_tripped.__defaults__ == test_func.__defaults__
    assert round_tripped.__closure__ == test_func.__closure__