Esempio n. 1
0
def _parse_config_to_function(config, custom_objects, func_attr_name,
                              func_type_attr_name, module_attr_name):
    """Reconstruct the function from the config."""
    globs = globals()
    module = config.pop(module_attr_name, None)
    if module in sys.modules:
        globs.update(sys.modules[module].__dict__)
    elif module is not None:
        # Note: we don't know the name of the function if it's a lambda.
        warnings.warn(
            "{} is not loaded, but a layer uses it. "
            "It may cause errors.".format(module), UserWarning)
    if custom_objects:
        globs.update(custom_objects)
    function_type = config.pop(func_type_attr_name)
    if function_type == "function":
        # Simple lookup in custom objects
        function = generic_utils.deserialize_keras_object(
            config[func_attr_name],
            custom_objects=custom_objects,
            printable_module_name="function in wrapper")
    elif function_type == "lambda":
        # Unsafe deserialization from bytecode
        function = generic_utils.func_load(config[func_attr_name], globs=globs)
    else:
        raise TypeError("Unknown function type:", function_type)
    return function
Esempio n. 2
0
 def _parse_function_from_config(cls, config, custom_objects,
                                 func_attr_name, module_attr_name,
                                 func_type_attr_name):
     globs = globals()
     module = config.pop(module_attr_name, None)
     if module in sys.modules:
         globs.update(sys.modules[module].__dict__)
     elif module is not None:
         # Note: we don't know the name of the function if it's a lambda.
         warnings.warn(
             '{} is not loaded, but a Lambda layer uses it. '
             'It may cause errors.'.format(module), UserWarning)
     if custom_objects:
         globs.update(custom_objects)
     function_type = config.pop(func_type_attr_name)
     if function_type == 'function':
         # Simple lookup in custom objects
         function = generic_utils.deserialize_keras_object(
             config[func_attr_name],
             custom_objects=custom_objects,
             printable_module_name='function in Lambda layer')
     elif function_type == 'lambda':
         # Unsafe deserialization from bytecode
         function = generic_utils.func_load(config[func_attr_name],
                                            globs=globs)
     elif function_type == 'raw':
         function = config[func_attr_name]
     else:
         raise TypeError('Unknown function type:', function_type)
     return function
Esempio n. 3
0
  def from_config(cls, config, custom_objects=None):
    """Rebuild an instance from a serialized config.

    Reads 'function_type', 'function', 'output_shape_type', 'output_shape'
    and (optionally) 'arguments'. The caller's ``config`` is not modified.

    Args:
      config: Serialized config dict.
      custom_objects: Optional mapping of names to objects made visible
        while deserializing functions.

    Returns:
      ``cls(**config)`` with 'function' and 'output_shape' replaced by the
      deserialized values.

    Raises:
      KeyError: If 'function_type' or 'output_shape_type' is missing.
      TypeError: If 'function_type' is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    globs = globals()
    if custom_objects:
      # Merge into a fresh dict; globals() itself is never modified.
      globs = dict(globs)
      globs.update(custom_objects)
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = generic_utils.deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode
      function = generic_utils.func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type: {!r}'.format(function_type))

    output_shape_type = config.pop('output_shape_type')
    if output_shape_type == 'function':
      # Simple lookup in custom objects
      output_shape = generic_utils.deserialize_keras_object(
          config['output_shape'],
          custom_objects=custom_objects,
          printable_module_name='output_shape function in Lambda layer')
    elif output_shape_type == 'lambda':
      # Unsafe deserialization from bytecode
      output_shape = generic_utils.func_load(config['output_shape'],
                                             globs=globs)
    else:
      output_shape = config['output_shape']

    # Numpy-array arguments were saved as {'type': 'ndarray', 'value': list};
    # turn them back into ndarrays. Build a new dict because the shallow copy
    # above still shares the caller's 'arguments' dict.
    arguments = config.get('arguments')
    if arguments:
      restored = {}
      for key, value in arguments.items():
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          value = np.array(value['value'])
        restored[key] = value
      config['arguments'] = restored

    config['function'] = function
    config['output_shape'] = output_shape
    return cls(**config)
Esempio n. 4
0
    def from_config(cls, config, custom_objects=None):
        """Rebuild an instance from a serialized config.

        Reads 'function_type', 'function', 'output_shape_type',
        'output_shape' and (optionally) 'arguments'. The caller's ``config``
        is not modified.

        Args:
            config: Serialized config dict.
            custom_objects: Optional mapping of names to objects made visible
                while deserializing functions.

        Returns:
            ``cls(**config)`` with 'function' and 'output_shape' replaced by
            the deserialized values.

        Raises:
            KeyError: If 'function_type' or 'output_shape_type' is missing.
            TypeError: If 'function_type' is neither 'function' nor 'lambda'.
        """
        config = config.copy()
        globs = globals()
        if custom_objects:
            # Merge into a fresh dict; globals() itself is never modified.
            globs = dict(globs)
            globs.update(custom_objects)
        function_type = config.pop('function_type')
        if function_type == 'function':
            # Simple lookup in custom objects
            function = generic_utils.deserialize_keras_object(
                config['function'],
                custom_objects=custom_objects,
                printable_module_name='function in Lambda layer')
        elif function_type == 'lambda':
            # Unsafe deserialization from bytecode
            function = generic_utils.func_load(config['function'], globs=globs)
        else:
            raise TypeError(
                'Unknown function type: {!r}'.format(function_type))

        output_shape_type = config.pop('output_shape_type')
        if output_shape_type == 'function':
            # Simple lookup in custom objects
            output_shape = generic_utils.deserialize_keras_object(
                config['output_shape'],
                custom_objects=custom_objects,
                printable_module_name='output_shape function in Lambda layer')
        elif output_shape_type == 'lambda':
            # Unsafe deserialization from bytecode
            output_shape = generic_utils.func_load(config['output_shape'],
                                                   globs=globs)
        else:
            output_shape = config['output_shape']

        # Numpy-array arguments were saved as {'type': 'ndarray', 'value':
        # list}; turn them back into ndarrays. Build a new dict because the
        # shallow copy above still shares the caller's 'arguments' dict.
        arguments = config.get('arguments')
        if arguments:
            restored = {}
            for key, value in arguments.items():
                if isinstance(value, dict) and value.get('type') == 'ndarray':
                    value = np.array(value['value'])
                restored[key] = value
            config['arguments'] = restored

        config['function'] = function
        config['output_shape'] = output_shape
        return cls(**config)
Esempio n. 5
0
def deserialize_function(serial, function_type):
    """Deserializes the Keras-serialized function.

    (De)serializing Python functions from/to bytecode is unsafe. Therefore we
    also use the function's type as an anonymous function ('lambda') or named
    function in the Python environment ('function'). In the latter case, this
    lets us use the Python scope to obtain the function rather than reload it
    from bytecode. (Note that both cases are brittle!)

    Keras-deserialized functions do not perform lexical scoping. Any modules
    that the function requires must be imported within the function itself.

    This serialization mimics the implementation in `tf.keras.layers.Lambda`.

    Args:
      serial: Serialized Keras object: typically a dict, string, or bytecode.
      function_type: Python string denoting 'function' or 'lambda'.

    Returns:
      function: Function the serialized Keras object represents.

    Raises:
      TypeError: If `function_type` is neither 'function' nor 'lambda'.

    #### Examples

    ```python
    serial, function_type = serialize_function(lambda x: x)
    function = deserialize_function(serial, function_type)
    assert function(2.3) == 2.3  # function is identity
    ```

    """
    if function_type == 'function':
        # Simple lookup in custom objects
        function = tf.keras.utils.deserialize_keras_object(serial)
    elif function_type == 'lambda':
        # Unsafe deserialization from bytecode
        function = generic_utils.func_load(serial)
    else:
        raise TypeError('Unknown function type: {!r}'.format(function_type))
    return function
Esempio n. 6
0
    def from_config(cls, config, custom_objects=None):
        """Rebuild an instance, deserializing ``config['fn_layer_creator']``.

        ``config['fn_layer_creator']`` is unpacked as ``(fn, fn_type,
        fn_module)``. The caller's ``config`` is not modified.

        Args:
            config: Serialized config dict.
            custom_objects: Optional mapping of names to objects made visible
                while deserializing ``fn``.

        Returns:
            ``cls(**config)`` with 'fn_layer_creator' replaced by the
            deserialized callable.
        """
        config = config.copy()
        fn, fn_type, fn_module = config['fn_layer_creator']

        # fn_module is the module *name*; look it up directly. (Using it as a
        # config key, as before, meant the module namespace was never loaded.)
        # Work on a copy so this module's globals() are never modified.
        globs = globals().copy()
        if fn_module in sys.modules:
            globs.update(sys.modules[fn_module].__dict__)
        if custom_objects:
            globs.update(custom_objects)

        if fn_type == 'function':
            # Simple lookup in custom objects
            fn = generic_utils.deserialize_keras_object(
                fn,
                custom_objects=custom_objects,
                printable_module_name='function in Lambda layer')
        elif fn_type == 'lambda':
            # Unsafe deserialization from bytecode
            fn = generic_utils.func_load(fn, globs=globs)
        # NOTE(review): any other fn_type passes the stored value through
        # unchanged; confirm that is intended.
        config['fn_layer_creator'] = fn

        return cls(**config)
Esempio n. 7
0
def deserialize_function(serial, function_type):
  """Deserializes the Keras-serialized function.

  (De)serializing Python functions from/to bytecode is unsafe. Therefore we
  also use the function's type as an anonymous function ('lambda') or named
  function in the Python environment ('function'). In the latter case, this lets
  us use the Python scope to obtain the function rather than reload it from
  bytecode. (Note that both cases are brittle!)

  Keras-deserialized functions do not perform lexical scoping. Any modules that
  the function requires must be imported within the function itself.

  This serialization mimics the implementation in `tf.keras.layers.Lambda`.

  Args:
    serial: Serialized Keras object: typically a dict, string, or bytecode.
    function_type: Python string denoting 'function' or 'lambda'.

  Returns:
    function: Function the serialized Keras object represents.

  Raises:
    TypeError: If `function_type` is neither 'function' nor 'lambda'.

  #### Examples

  ```python
  serial, function_type = serialize_function(lambda x: x)
  function = deserialize_function(serial, function_type)
  assert function(2.3) == 2.3  # function is identity
  ```

  """
  if function_type == 'function':
    # Simple lookup in custom objects
    function = tf.keras.utils.deserialize_keras_object(serial)
  elif function_type == 'lambda':
    # Unsafe deserialization from bytecode
    function = generic_utils.func_load(serial)
  else:
    raise TypeError('Unknown function type: {!r}'.format(function_type))
  return function
Esempio n. 8
0
  def from_config(cls, config, custom_objects=None):
    """Rebuild an instance from a serialized config.

    Reads 'module', 'function_type', 'function', 'output_shape_module',
    'output_shape_type', 'output_shape' and (optionally) 'arguments'. The
    caller's ``config`` is not modified.

    Args:
      config: Serialized config dict.
      custom_objects: Optional mapping of names to objects made visible
        while deserializing functions.

    Returns:
      ``cls(**config)`` with 'function' and 'output_shape' replaced by the
      deserialized values.

    Raises:
      KeyError: If 'function_type' or 'output_shape_type' is missing.
      TypeError: If 'function_type' is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    # Work on a copy: updating globals() in place would permanently leak
    # module namespaces and custom objects into this module.
    globs = globals().copy()
    module = config.pop('module', None)
    if module in sys.modules:
      globs.update(sys.modules[module].__dict__)
    elif module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(module)
                    , UserWarning)
    if custom_objects:
      globs.update(custom_objects)
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = generic_utils.deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode
      function = generic_utils.func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type: {!r}'.format(function_type))

    output_shape_module = config.pop('output_shape_module', None)
    if output_shape_module in sys.modules:
      globs.update(sys.modules[output_shape_module].__dict__)
    elif output_shape_module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(output_shape_module)
                    , UserWarning)
    output_shape_type = config.pop('output_shape_type')
    if output_shape_type == 'function':
      # Simple lookup in custom objects
      output_shape = generic_utils.deserialize_keras_object(
          config['output_shape'],
          custom_objects=custom_objects,
          printable_module_name='output_shape function in Lambda layer')
    elif output_shape_type == 'lambda':
      # Unsafe deserialization from bytecode
      output_shape = generic_utils.func_load(config['output_shape'],
                                             globs=globs)
    else:
      output_shape = config['output_shape']

    # Numpy-array arguments were saved as {'type': 'ndarray', 'value': list};
    # turn them back into ndarrays. Build a new dict because the shallow copy
    # above still shares the caller's 'arguments' dict.
    arguments = config.get('arguments')
    if arguments:
      restored = {}
      for key, value in arguments.items():
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          value = np.array(value['value'])
        restored[key] = value
      config['arguments'] = restored

    config['function'] = function
    config['output_shape'] = output_shape
    return cls(**config)
Esempio n. 9
0
  def from_config(cls, config, custom_objects=None):
    """Rebuild an instance from a serialized config.

    Reads 'module', 'function_type', 'function', 'output_shape_module',
    'output_shape_type', 'output_shape' and (optionally) 'arguments'. The
    caller's ``config`` is not modified.

    Args:
      config: Serialized config dict.
      custom_objects: Optional mapping of names to objects made visible
        while deserializing functions.

    Returns:
      ``cls(**config)`` with 'function' and 'output_shape' replaced by the
      deserialized values.

    Raises:
      KeyError: If 'function_type' or 'output_shape_type' is missing.
      TypeError: If 'function_type' is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    # Work on a copy: updating globals() in place would permanently leak
    # module namespaces and custom objects into this module.
    globs = globals().copy()
    module = config.pop('module', None)
    if module in sys.modules:
      globs.update(sys.modules[module].__dict__)
    elif module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(module)
                    , UserWarning)
    if custom_objects:
      globs.update(custom_objects)
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = generic_utils.deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode
      function = generic_utils.func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type: {!r}'.format(function_type))

    output_shape_module = config.pop('output_shape_module', None)
    if output_shape_module in sys.modules:
      globs.update(sys.modules[output_shape_module].__dict__)
    elif output_shape_module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(output_shape_module)
                    , UserWarning)
    output_shape_type = config.pop('output_shape_type')
    if output_shape_type == 'function':
      # Simple lookup in custom objects
      output_shape = generic_utils.deserialize_keras_object(
          config['output_shape'],
          custom_objects=custom_objects,
          printable_module_name='output_shape function in Lambda layer')
    elif output_shape_type == 'lambda':
      # Unsafe deserialization from bytecode
      output_shape = generic_utils.func_load(config['output_shape'],
                                             globs=globs)
    else:
      output_shape = config['output_shape']

    # Numpy-array arguments were saved as {'type': 'ndarray', 'value': list};
    # turn them back into ndarrays. Build a new dict because the shallow copy
    # above still shares the caller's 'arguments' dict.
    arguments = config.get('arguments')
    if arguments:
      restored = {}
      for key, value in arguments.items():
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          value = np.array(value['value'])
        restored[key] = value
      config['arguments'] = restored

    config['function'] = function
    config['output_shape'] = output_shape
    return cls(**config)