예제 #1
0
    def from_config(cls, config, custom_objects=None):
        """Builds an instance of `cls` from a serialized config dictionary.

        The caller's `config` is left untouched: all key removals and
        overwrites happen on a shallow copy, and the nested `arguments`
        dict is copied before any of its entries are replaced.

        Arguments:
            config: dict holding 'function_type' and 'function', optionally
                'arguments', plus any further keyword arguments for `cls`.
            custom_objects: Optional dict mapping names (strings) to custom
                objects made available while resolving the function.

        Returns:
            The result of `cls(**config)` once 'function' has been replaced
            by the resolved callable.

        Raises:
            TypeError: if 'function_type' is neither 'function' nor 'lambda'.
        """
        # Copy first so that popping / overwriting keys below cannot corrupt
        # a dictionary the caller may want to reuse.
        config = dict(config)

        globs = globals()
        if custom_objects:
            # Layer the custom objects over a copy of the module globals.
            globs = dict(globs)
            globs.update(custom_objects)

        function_type = config.pop('function_type')
        if function_type == 'function':
            # Simple lookup in custom objects
            function = deserialize_keras_object(
                config['function'],
                custom_objects=custom_objects,
                printable_module_name='function in Lambda layer')
        elif function_type == 'lambda':
            # Unsafe deserialization from bytecode.
            # NOTE(review): do not route untrusted configs through this branch.
            function = func_load(config['function'], globs=globs)
        else:
            raise TypeError('Unknown function type:', function_type)

        # If arguments were numpy array, they have been saved as
        # list. We need to recover the ndarray
        if 'arguments' in config:
            arguments = dict(config['arguments'])
            for key, value in arguments.items():
                if isinstance(value, dict) and value.get('type') == 'ndarray':
                    # Overwrite the argument with its numpy translation
                    arguments[key] = np.array(value['value'])
            config['arguments'] = arguments

        config['function'] = function
        return cls(**config)
예제 #2
0
def deserialize(config, custom_objects=None):
    """Inverse of the `serialize` function.

    Built-in optimizer names are matched case-insensitively. The `config`
    passed in is never modified: when the class name has to be normalized,
    a shallow copy carrying the lower-cased name is used for the lookup.

    Arguments:
        config: Optimizer configuration dictionary.
        custom_objects: Optional dictionary mapping
            names (strings) to custom objects
            (classes and functions)
            to be considered during deserialization.

    Returns:
        A Keras Optimizer instance.
    """
    all_classes = {
        'sgd': SGD,
        'rmsprop': RMSprop,
        'adagrad': Adagrad,
        'adadelta': Adadelta,
        'adam': Adam,
        'adamax': Adamax,
        'nadam': Nadam,
        'tfoptimizer': TFOptimizer,
    }
    # Make deserialization case-insensitive for built-in optimizers.
    lowered_name = config['class_name'].lower()
    if lowered_name in all_classes:
        # Rebind to a copy instead of writing into the caller's dictionary.
        config = dict(config, class_name=lowered_name)
    return deserialize_keras_object(
        config,
        module_objects=all_classes,
        custom_objects=custom_objects,
        printable_module_name='optimizer')
예제 #3
0
파일: core.py 프로젝트: chdinh/tensorflow
  def from_config(cls, config, custom_objects=None):
    """Builds an instance of `cls` from a serialized config dictionary.

    The caller's `config` is left untouched: all key removals and
    overwrites happen on a shallow copy, and the nested `arguments`
    dict is copied before any of its entries are replaced.

    Arguments:
        config: dict holding 'function_type' and 'function', optionally
            'arguments', plus any further keyword arguments for `cls`.
        custom_objects: Optional dict mapping names (strings) to custom
            objects made available while resolving the function.

    Returns:
        The result of `cls(**config)` once 'function' has been replaced
        by the resolved callable.

    Raises:
        TypeError: if 'function_type' is neither 'function' nor 'lambda'.
    """
    # Copy first so that popping / overwriting keys below cannot corrupt
    # a dictionary the caller may want to reuse.
    config = dict(config)

    globs = globals()
    if custom_objects:
      # Layer the custom objects over a copy of the module globals.
      globs = dict(globs)
      globs.update(custom_objects)

    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode.
      # NOTE(review): do not route untrusted configs through this branch.
      function = func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type:', function_type)

    # If arguments were numpy array, they have been saved as
    # list. We need to recover the ndarray
    if 'arguments' in config:
      arguments = dict(config['arguments'])
      for key, value in arguments.items():
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          # Overwrite the argument with its numpy translation
          arguments[key] = np.array(value['value'])
      config['arguments'] = arguments

    config['function'] = function
    return cls(**config)
예제 #4
0
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Built-in optimizer names are matched case-insensitively. The `config`
  passed in is never modified: when the class name has to be normalized,
  a shallow copy carrying the lower-cased name is used for the lookup.

  Arguments:
      config: Optimizer configuration dictionary.
      custom_objects: Optional dictionary mapping
          names (strings) to custom objects
          (classes and functions)
          to be considered during deserialization.

  Returns:
      A Keras Optimizer instance.
  """
  all_classes = {
      'sgd': SGD,
      'rmsprop': RMSprop,
      'adagrad': Adagrad,
      'adadelta': Adadelta,
      'adam': Adam,
      'adamax': Adamax,
      'nadam': Nadam,
      'tfoptimizer': TFOptimizer,
  }
  # Make deserialization case-insensitive for built-in optimizers.
  lowered_name = config['class_name'].lower()
  if lowered_name in all_classes:
    # Rebind to a copy instead of writing into the caller's dictionary.
    config = dict(config, class_name=lowered_name)
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')
예제 #5
0
  def from_config(cls, config, custom_objects=None):
    """Builds an instance of `cls` from a serialized config dictionary.

    The caller's `config` is left untouched: the 'function_type' removal
    and the 'function' overwrite happen on a shallow copy.

    Arguments:
        config: dict holding 'function_type' and 'function', plus any
            further keyword arguments for `cls`.
        custom_objects: Optional dict mapping names (strings) to custom
            objects made available while resolving the function.

    Returns:
        The result of `cls(**config)` once 'function' has been replaced
        by the resolved callable.

    Raises:
        TypeError: if 'function_type' is neither 'function' nor 'lambda'.
    """
    # Copy first so the caller's dictionary survives the pop/overwrite below.
    config = dict(config)

    globs = globals()
    if custom_objects:
      # Layer the custom objects over a copy of the module globals.
      globs = dict(globs)
      globs.update(custom_objects)

    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode.
      # NOTE(review): do not route untrusted configs through this branch.
      function = func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type:', function_type)

    config['function'] = function
    return cls(**config)
예제 #6
0
파일: core.py 프로젝트: maony/tensorflow
  def from_config(cls, config, custom_objects=None):
    """Builds an instance of `cls` from a serialized config dictionary.

    The caller's `config` is left untouched: the 'function_type' removal
    and the 'function' overwrite happen on a shallow copy.

    Arguments:
        config: dict holding 'function_type' and 'function', plus any
            further keyword arguments for `cls`.
        custom_objects: Optional dict mapping names (strings) to custom
            objects made available while resolving the function.

    Returns:
        The result of `cls(**config)` once 'function' has been replaced
        by the resolved callable.

    Raises:
        TypeError: if 'function_type' is neither 'function' nor 'lambda'.
    """
    # Copy first so the caller's dictionary survives the pop/overwrite below.
    config = dict(config)

    globs = globals()
    if custom_objects:
      # Layer the custom objects over a copy of the module globals.
      globs = dict(globs)
      globs.update(custom_objects)

    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode.
      # NOTE(review): do not route untrusted configs through this branch.
      function = func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type:', function_type)

    config['function'] = function
    return cls(**config)
def deserialize(config, custom_objects=None):
    """Builds a layer object from its config dictionary.

    Arguments:
        config: dict of the form {'class_name': str, 'config': dict}
        custom_objects: dict mapping class names (or function names)
            of custom (non-Keras) objects to class/functions

    Returns:
        Layer instance (may be Model, Sequential, Layer...)
    """
    # Function-scope import, kept exactly as in the original.
    # NOTE(review): presumably avoids a circular import -- confirm.
    from tensorflow.contrib.keras.python.keras import models  # pylint: disable=g-import-not-at-top
    layer_namespace = globals()  # All layers.
    # Register the container classes in this module's globals (a persistent
    # side effect) so they can be resolved by name like any other layer.
    for container_name in ('Model', 'Sequential'):
        layer_namespace[container_name] = getattr(models, container_name)
    lookup_options = {
        'module_objects': layer_namespace,
        'custom_objects': custom_objects,
        'printable_module_name': 'layer',
    }
    return deserialize_keras_object(config, **lookup_options)
예제 #8
0
def deserialize(config, custom_objects=None):
  """Builds a layer object from its config dictionary.

  Arguments:
      config: dict of the form {'class_name': str, 'config': dict}
      custom_objects: dict mapping class names (or function names)
          of custom (non-Keras) objects to class/functions

  Returns:
      Layer instance (may be Model, Sequential, Layer...)
  """
  # Function-scope import, kept exactly as in the original.
  # NOTE(review): presumably avoids a circular import -- confirm.
  from tensorflow.contrib.keras.python.keras import models  # pylint: disable=g-import-not-at-top
  module_scope = globals()  # All layers.
  # Expose the container classes through this module's globals (a persistent
  # side effect) so that name lookup finds them next to the layers.
  for extra_name in ('Model', 'Sequential'):
    module_scope[extra_name] = getattr(models, extra_name)
  layer = deserialize_keras_object(
      config,
      module_objects=module_scope,
      custom_objects=custom_objects,
      printable_module_name='layer')
  return layer
예제 #9
0
def deserialize(name, custom_objects=None):
  """Resolves a loss function from its serialized identifier.

  Arguments:
      name: identifier handed to `deserialize_keras_object`.
      custom_objects: optional mapping forwarded as `custom_objects`.

  Returns:
      Whatever `deserialize_keras_object` produces for the lookup.
  """
  lookup_options = dict(
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='loss function')
  return deserialize_keras_object(name, **lookup_options)
예제 #10
0
def deserialize(config, custom_objects=None):
  """Resolves a regularizer from its serialized config.

  Arguments:
      config: serialized form handed to `deserialize_keras_object`.
      custom_objects: optional mapping forwarded as `custom_objects`.

  Returns:
      Whatever `deserialize_keras_object` produces for the lookup.
  """
  # Names are looked up in this module's own globals.
  module_namespace = globals()
  regularizer = deserialize_keras_object(
      config,
      module_objects=module_namespace,
      custom_objects=custom_objects,
      printable_module_name='regularizer')
  return regularizer
예제 #11
0
def deserialize(config, custom_objects=None):
  """Resolves an initializer from its serialized config.

  Arguments:
      config: serialized form handed to `deserialize_keras_object`.
      custom_objects: optional mapping forwarded as `custom_objects`.

  Returns:
      Whatever `deserialize_keras_object` produces for the lookup.
  """
  lookup_options = {
      'module_objects': globals(),
      'custom_objects': custom_objects,
      'printable_module_name': 'initializer',
  }
  return deserialize_keras_object(config, **lookup_options)
예제 #12
0
def deserialize(name, custom_objects=None):
  """Resolves a metric function from its serialized identifier.

  Arguments:
      name: identifier handed to `deserialize_keras_object`.
      custom_objects: optional mapping forwarded as `custom_objects`.

  Returns:
      Whatever `deserialize_keras_object` produces for the lookup.
  """
  # Names are looked up in this module's own globals.
  module_namespace = globals()
  metric_fn = deserialize_keras_object(
      name,
      module_objects=module_namespace,
      custom_objects=custom_objects,
      printable_module_name='metric function')
  return metric_fn