Exemplo n.º 1
0
  def from_config(cls, config, custom_objects=None):
    """Rebuilds an instance of `cls` from a config dict.

    The callable stored under 'function' is restored according to
    'function_type'. Any 'arguments' entries that were serialized as
    ndarrays are converted back with `np.array`.

    Args:
      config: Mapping with at least 'function_type' and 'function' keys. An
        optional 'arguments' mapping is also processed. Neither `config` nor
        its 'arguments' mapping is modified by this method.
      custom_objects: Optional dict of extra objects. It is passed to
        `deserialize_keras_object` for named functions. When a lambda is
        loaded from bytecode, it is merged over this module's globals.

    Returns:
      `cls(**config)`, where 'function' is replaced by the restored callable
      and 'function_type' is removed.

    Raises:
      TypeError: If 'function_type' is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects.
      function = deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode: only load configs from trusted
      # sources. Custom objects are layered over (a copy of) this module's
      # globals so the loaded code can resolve them by name.
      globs = globals()
      if custom_objects:
        globs = dict(globs)
        globs.update(custom_objects)
      function = func_load(config['function'], globs=globs)
    else:
      # Use a single formatted message. Passing several positional args
      # would make str(exc) render as a tuple.
      raise TypeError('Unknown function type: %r' % (function_type,))

    # Numpy array arguments were saved as {'type': 'ndarray', 'value': list},
    # so they must be turned back into ndarrays here. `config.copy()` is
    # shallow, so build a fresh dict rather than writing into the caller's
    # 'arguments' mapping.
    if config.get('arguments'):
      restored = {}
      for key, value in config['arguments'].items():
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          value = np.array(value['value'])
        restored[key] = value
      config['arguments'] = restored

    config['function'] = function
    return cls(**config)
Exemplo n.º 2
0
def deserialize_function(serial, function_type):
    """Deserializes the Keras-serialized function.

    (De)serializing Python functions from/to bytecode is unsafe. Therefore we
    also record the function's type: either an anonymous function ('lambda')
    or a named function in the Python environment ('function'). In the latter
    case, this lets us use the Python scope to obtain the function rather than
    reload it from bytecode. (Note that both cases are brittle!)

    Keras-deserialized functions do not perform lexical scoping. Any modules
    that the function requires must be imported within the function itself.

    This serialization mimics the implementation in `tf.keras.layers.Lambda`.

    Args:
      serial: Serialized Keras object: typically a dict, string, or bytecode.
      function_type: Python string denoting 'function' or 'lambda'.

    Returns:
      function: Function the serialized Keras object represents.

    Raises:
      TypeError: If `function_type` is neither 'function' nor 'lambda'.

    #### Examples

    ```python
    serial, function_type = serialize_function(lambda x: x)
    function = deserialize_function(serial, function_type)
    assert function(2.3) == 2.3  # function is identity
    ```

    """
    if function_type == 'function':
        # Simple lookup in custom objects; no bytecode involved.
        function = tf.keras.utils.deserialize_keras_object(serial)
    elif function_type == 'lambda':
        # Unsafe deserialization from bytecode: only pass trusted input.
        function = generic_utils.func_load(serial)
    else:
        # Use a single formatted message. Passing several positional args
        # would make str(exc) render as a tuple.
        raise TypeError('Unknown function type: %r' % (function_type,))
    return function