Example #1
0
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Arguments:
      config: Optimizer configuration dictionary with a 'class_name' key.
          The dictionary passed in is not modified.
      custom_objects: Optional dictionary mapping
          names (strings) to custom objects
          (classes and functions)
          to be considered during deserialization.

  Returns:
      A Keras Optimizer instance.
  """
  all_classes = {
      'sgd': SGD,
      'rmsprop': RMSprop,
      'adagrad': Adagrad,
      'adadelta': Adadelta,
      'adam': Adam,
      'adamax': Adamax,
      'nadam': Nadam,
      'tfoptimizer': TFOptimizer,
  }
  # Shallow copy so that normalizing 'class_name' below does not rewrite the
  # caller's dictionary as a side effect.
  config = dict(config)
  # Make deserialization case-insensitive for built-in optimizers.
  if config['class_name'].lower() in all_classes:
    config['class_name'] = config['class_name'].lower()
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')
Example #2
0
def deserialize(config, custom_objects=None):
  """Instantiates a layer from a config dictionary.

  Arguments:
      config: dict of the form {'class_name': str, 'config': dict}. The dict
          passed in is not modified.
      custom_objects: dict mapping class names (or function names)
          of custom (non-Keras) objects to class/functions

  Returns:
      Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Imported here rather than at module top, presumably to avoid an import
  # cycle with the models module.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  globs = globals()  # All layers.
  # NOTE: `globs` is this module's real namespace, so these entries are
  # written into the module's globals.
  globs['Network'] = models.Network
  globs['Model'] = models.Model
  globs['Sequential'] = models.Sequential
  layer_class_name = config['class_name']
  if layer_class_name in _DESERIALIZATION_TABLE:
    # Remap on a shallow copy so the caller's dict keeps its original name.
    config = dict(config)
    config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]

  return deserialize_keras_object(
      config,
      module_objects=globs,
      custom_objects=custom_objects,
      printable_module_name='layer')
Example #3
0
  def from_config(cls, config, custom_objects=None):
    """Recreates a Lambda layer from its config.

    Args:
      config: dict produced by the layer's serialization. Keys read here are
        'function', 'function_type', 'output_shape', 'output_shape_type' and,
        optionally, 'arguments'. Neither this dict nor its nested 'arguments'
        dict is modified.
      custom_objects: Optional dict of name -> object used for name lookup
        and merged into the globals of bytecode-loaded lambdas.

    Returns:
      An instance of `cls`.

    Raises:
      TypeError: if 'function_type' is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    globs = globals()
    if custom_objects:
      # Build a new dict; custom objects override same-named module globals.
      globs = dict(list(globs.items()) + list(custom_objects.items()))
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = generic_utils.deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode: only load trusted configs.
      function = generic_utils.func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type: {}'.format(function_type))

    output_shape_type = config.pop('output_shape_type')
    if output_shape_type == 'function':
      # Simple lookup in custom objects
      output_shape = generic_utils.deserialize_keras_object(
          config['output_shape'],
          custom_objects=custom_objects,
          printable_module_name='output_shape function in Lambda layer')
    elif output_shape_type == 'lambda':
      # Unsafe deserialization from bytecode: only load trusted configs.
      output_shape = generic_utils.func_load(config['output_shape'],
                                             globs=globs)
    else:
      output_shape = config['output_shape']

    # Numpy array arguments were saved as dicts with 'type' == 'ndarray' and
    # a list 'value'; turn them back into ndarrays. The nested dict is copied
    # first because config.copy() above is shallow and would otherwise let
    # this loop rewrite the caller's config.
    if 'arguments' in config:
      arguments = dict(config['arguments'])
      for key, value in list(arguments.items()):
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          arguments[key] = np.array(value['value'])
      config['arguments'] = arguments

    config['function'] = function
    config['output_shape'] = output_shape
    return cls(**config)
Example #4
0
def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config.

  Args:
    config: serialized initializer, forwarded to `deserialize_keras_object`.
    custom_objects: optional dict of user-provided objects to consider.

  Returns:
    The object `deserialize_keras_object` resolves `config` to.
  """
  if not tf2.enabled():
    lookup = globals()
  else:
    # V1 and V2 classes share names, and this module aliases the V2 ones, so
    # under TF2 the table is taken directly from `init_ops_v2`.
    lookup = dict(
        (attr, getattr(init_ops_v2, attr)) for attr in dir(init_ops_v2))
  return deserialize_keras_object(
      config,
      module_objects=lookup,
      custom_objects=custom_objects,
      printable_module_name='initializer')
Example #5
0
def recursively_deserialize_keras_object(config, module_objects=None):
    """Deserialize Keras object from a nested structure.

    A dict that has a 'class_name' key is handed to
    `generic_utils.deserialize_keras_object`. Any other dict is rebuilt key by
    key, and a list or tuple is rebuilt as a list. Both recurse into their
    values.

    Raises:
        ValueError: if a value is none of the above (e.g. a number or string).
    """
    if isinstance(config, dict):
        if 'class_name' not in config:
            return {
                key: recursively_deserialize_keras_object(config[key],
                                                         module_objects)
                for key in config
            }
        return generic_utils.deserialize_keras_object(
            config, module_objects=module_objects)
    if not isinstance(config, (tuple, list)):
        raise ValueError('Unable to decode config: {}'.format(config))
    return [
        recursively_deserialize_keras_object(item, module_objects)
        for item in config
    ]
Example #6
0
def deserialize(config, custom_objects=None):
  """Instantiates a layer from a config dictionary.

  Arguments:
      config: dict of the form {'class_name': str, 'config': dict}
      custom_objects: dict mapping class names (or function names)
          of custom (non-Keras) objects to class/functions

  Returns:
      Layer instance (may be Model, Sequential, Layer...)
  """
  # Deferred import, presumably to avoid an import cycle with `models`.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  module_namespace = globals()  # All layers.
  # Writes into this module's actual global namespace (not a copy).
  module_namespace.update(Model=models.Model, Sequential=models.Sequential)
  return deserialize_keras_object(
      config,
      custom_objects=custom_objects,
      module_objects=module_namespace,
      printable_module_name='layer')
Example #7
0
def deserialize(config, custom_objects=None):
    """Inverse of the `serialize` function.

    Arguments:
        config: Optimizer configuration dictionary with a 'class_name' key.
            The dictionary passed in is not modified.
        custom_objects: Optional dictionary mapping
            names (strings) to custom objects
            (classes and functions)
            to be considered during deserialization.

    Returns:
        A Keras Optimizer instance.
    """
    if tf2.enabled():
        # Under TF2 only the V2 optimizers are registered (no 'tfoptimizer').
        all_classes = {
            'adadelta': adadelta_v2.Adadelta,
            'adagrad': adagrad_v2.Adagrad,
            'adam': adam_v2.Adam,
            'adamax': adamax_v2.Adamax,
            'nadam': nadam_v2.Nadam,
            'rmsprop': rmsprop_v2.RMSprop,
            'sgd': gradient_descent_v2.SGD
        }
    else:
        all_classes = {
            'adadelta': Adadelta,
            'adagrad': Adagrad,
            'adam': Adam,
            'adamax': Adamax,
            'nadam': Nadam,
            'rmsprop': RMSprop,
            'sgd': SGD,
            'tfoptimizer': TFOptimizer
        }
    # Shallow copy so the case normalization below does not rewrite the
    # caller's dictionary.
    config = dict(config)
    # Make deserialization case-insensitive for built-in optimizers.
    if config['class_name'].lower() in all_classes:
        config['class_name'] = config['class_name'].lower()
    return deserialize_keras_object(config,
                                    module_objects=all_classes,
                                    custom_objects=custom_objects,
                                    printable_module_name='optimizer')
    def testSerialization(self, quantizer_type):
        """A quantizer survives a serialize/deserialize round trip."""
        original = quantizer_type(**self.quant_params)

        serialized = serialize_keras_object(original)
        self.assertEqual(
            {
                'class_name': quantizer_type.__name__,
                'config': {
                    'num_bits': 8,
                    'per_axis': False,
                    'symmetric': False,
                },
            },
            serialized)

        restored = deserialize_keras_object(
            serialized,
            module_objects=globals(),
            custom_objects=quantizers._types_dict())
        self.assertEqual(original, restored)
Example #9
0
    def from_config(cls, config):
        """Rebuilds an instance of `cls` from `config`.

        The 'pruning_schedule' entry is deserialized with ConstantSparsity
        and PolynomialDecay registered as custom objects. The 'layer' entry
        goes through the Keras layer deserializer. `config` is shallow-copied
        before any entry is replaced.
        """
        kwargs = dict(config)

        serialized_schedule = kwargs.pop('pruning_schedule')
        from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object  # pylint: disable=g-import-not-at-top
        # TODO(pulkitb): This should ideally be fetched from pruning_schedule,
        # which should maintain a list of all the pruning_schedules.
        known_schedules = dict(
            ConstantSparsity=pruning_sched.ConstantSparsity,
            PolynomialDecay=pruning_sched.PolynomialDecay)
        kwargs['pruning_schedule'] = deserialize_keras_object(
            serialized_schedule,
            module_objects=globals(),
            custom_objects=known_schedules)

        from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
        kwargs['layer'] = deserialize_layer(kwargs.pop('layer'))

        return cls(**kwargs)
    def testSerialization(self):
        """The RNN quantize provider survives a serialization round trip."""
        serialized = serialize_keras_object(self.quantize_provider)

        weight_attrs = [['kernel', 'recurrent_kernel'],
                        ['kernel', 'recurrent_kernel']]
        activation_attrs = [['activation', 'recurrent_activation'],
                            ['activation', 'recurrent_activation']]
        self.assertEqual(
            {
                'class_name': 'TFLiteQuantizeProviderRNN',
                'config': {
                    'weight_attrs': weight_attrs,
                    'activation_attrs': activation_attrs,
                },
            },
            serialized)

        restored = deserialize_keras_object(
            serialized,
            module_objects=globals(),
            custom_objects=tflite_quantize_registry._types_dict())
        self.assertEqual(self.quantize_provider, restored)
Example #11
0
    def from_config(cls, config, custom_objects=None):
        """Recreates an instance of `cls` from `config`.

        `config['fn_layer_creator']` is unpacked as a `(fn, fn_type,
        fn_module)` triple:
          * 'function': `fn` is resolved by name through
            `generic_utils.deserialize_keras_object` and `custom_objects`.
          * 'lambda': `fn` is rebuilt with `generic_utils.func_load`. Its
            globals are this module's globals, plus the namespace of
            `fn_module` if that module is already imported, plus
            `custom_objects`.
          * anything else: `fn` is kept as stored.

        Neither the caller's `config` nor this module's globals are modified.
        """
        config = config.copy()
        fn, fn_type, fn_module = config['fn_layer_creator']

        # Copy so the updates below don't overwrite names in this module's
        # namespace.
        globs = globals().copy()
        # fn_module is the module name itself and is looked up directly in
        # sys.modules.
        if fn_module in sys.modules:
            globs.update(sys.modules[fn_module].__dict__)
        if custom_objects:
            globs.update(custom_objects)

        if fn_type == 'function':
            # Simple lookup in custom objects
            fn = generic_utils.deserialize_keras_object(
                fn,
                custom_objects=custom_objects,
                printable_module_name='function in Lambda layer')
        elif fn_type == 'lambda':
            # Unsafe deserialization from bytecode: only load trusted configs.
            fn = generic_utils.func_load(fn, globs=globs)
        config['fn_layer_creator'] = fn

        return cls(**config)
Example #12
0
def deserialize(name, custom_objects=None):
    """Returns activation function given a string identifier.

    Args:
      name: The name of the activation function.
      custom_objects: Optional `{function_name: function_obj}`
        dictionary listing user-provided activation functions.

    Returns:
        Corresponding activation function.

    For example:

    >>> tf.keras.activations.deserialize('linear')
     <function linear at 0x1239596a8>
    >>> tf.keras.activations.deserialize('sigmoid')
     <function sigmoid at 0x123959510>
    >>> tf.keras.activations.deserialize('abcd')
    Traceback (most recent call last):
    ...
    ValueError: Unknown activation function:abcd

    Raises:
        ValueError: `Unknown activation function` if the input string does not
        denote any defined Tensorflow activation function.
    """
    # Names defined in this module take precedence; advanced activations only
    # fill in names missing here. A merged copy is built instead of writing
    # into globals(), so this module's namespace is left unchanged.
    globs = dict(advanced_activations.get_globals())
    globs.update(globals())

    return deserialize_keras_object(
        name,
        module_objects=globs,
        custom_objects=custom_objects,
        printable_module_name='activation function')
    def testSerializationReturnsWrappedActivation(self, activation,
                                                  activation_config):
        """Round-tripping a QuantizeAwareActivation yields `activation`."""
        wrapper = QuantizeAwareActivation(
            activation, self.quantizer, 0, self.TestLayer())

        serialized = serialize_keras_object(wrapper)
        self.assertEqual(
            {'class_name': 'QuantizeAwareActivation',
             'config': activation_config},
            serialized)

        lookup = {
            'QuantizeAwareActivation': QuantizeAwareActivation,
            'NoOpActivation': quantize_aware_activation.NoOpActivation,
        }
        restored = deserialize_keras_object(serialized, custom_objects=lookup)
        self.assertEqual(activation, restored)
Example #14
0
  def testSerialization(self):
    """TFLiteQuantizeProvider survives a serialize/deserialize round trip."""
    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)

    serialized = serialize_keras_object(provider)
    self.assertEqual(
        {
            'class_name': 'TFLiteQuantizeProvider',
            'config': {
                'weight_attrs': ['kernel'],
                'activation_attrs': ['activation'],
                'quantize_output': False,
            },
        },
        serialized)

    restored = deserialize_keras_object(
        serialized,
        module_objects=globals(),
        custom_objects=tflite_quantize_registry._types_dict())
    self.assertEqual(provider, restored)
Example #15
0
def deserialize(config, custom_objects=None):
  """Instantiates a layer from a config dictionary.

  Arguments:
      config: dict of the form {'class_name': str, 'config': dict}. The dict
          passed in is not modified.
      custom_objects: dict mapping class names (or function names)
          of custom (non-Keras) objects to class/functions

  Returns:
      Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Prevent circular dependencies.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.premade.linear import LinearModel  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.premade.wide_deep import WideDeepModel  # pylint: disable=g-import-not-at-top
  from tensorflow.python.feature_column import dense_features  # pylint: disable=g-import-not-at-top
  from tensorflow.python.feature_column import sequence_feature_column as sfc  # pylint: disable=g-import-not-at-top

  globs = globals()  # All layers.
  # NOTE: `globs` is this module's real namespace; these entries are written
  # into the module's globals.
  globs['Network'] = models.Network
  globs['Model'] = models.Model
  globs['Sequential'] = models.Sequential
  globs['LinearModel'] = LinearModel
  globs['WideDeepModel'] = WideDeepModel

  # Prevent circular dependencies with FeatureColumn serialization.
  globs['DenseFeatures'] = dense_features.DenseFeatures
  globs['SequenceFeatures'] = sfc.SequenceFeatures

  layer_class_name = config['class_name']
  if layer_class_name in _DESERIALIZATION_TABLE:
    # Remap on a shallow copy so the caller's dict keeps its original name.
    config = dict(config)
    config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]

  return deserialize_keras_object(
      config,
      module_objects=globs,
      custom_objects=custom_objects,
      printable_module_name='layer')
Example #16
0
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Arguments:
      config: Optimizer configuration dictionary with a 'class_name' key. The
        dictionary passed in is not modified.
      custom_objects: Optional dictionary mapping names (strings) to custom
        objects (classes and functions) to be considered during deserialization.

  Returns:
      A Keras Optimizer instance.
  """
  # loss_scale_optimizer has a direct dependency of optimizer, import here
  # rather than top to avoid the cyclic dependency.
  from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Shallow copy so the case normalization below does not rewrite the
  # caller's dictionary.
  config = dict(config)
  # Make deserialization case-insensitive for built-in optimizers.
  if config['class_name'].lower() in all_classes:
    config['class_name'] = config['class_name'].lower()
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')
    def testSerializationReturnsWrappedActivation_BuiltInActivation(self):
        """A wrapped built-in activation serializes by name and round-trips
        back to the original activation function."""
        softmax = activations.get('softmax')
        wrapper = QuantizeAwareActivation(
            softmax, self.quantizer, 0, self.TestLayer())

        serialized = serialize_keras_object(wrapper)
        self.assertEqual(
            {
                'class_name': 'QuantizeAwareActivation',
                'config': {'activation': 'softmax'},
            },
            serialized)

        restored = deserialize_keras_object(
            serialized,
            custom_objects={'QuantizeAwareActivation': QuantizeAwareActivation})
        self.assertEqual(softmax, restored)
Example #18
0
def deserialize(name, custom_objects=None):
    """Resolve a serialized loss function.

    Args:
        name: serialized loss, forwarded to `deserialize_keras_object`.
        custom_objects: optional mapping of user-defined names to objects.

    Returns:
        Whatever `deserialize_keras_object` resolves `name` to, looking names
        up in this module's globals.
    """
    lookup_table = globals()
    return deserialize_keras_object(
        name,
        custom_objects=custom_objects,
        module_objects=lookup_table,
        printable_module_name='loss function',
    )
Example #19
0
def deserialize(config, custom_objects=None):
  """Resolve a serialized metric function.

  Args:
    config: serialized metric, forwarded to `deserialize_keras_object`.
    custom_objects: optional mapping of user-defined names to objects.

  Returns:
    The object resolved against this module's globals and `custom_objects`.
  """
  namespace = globals()
  return deserialize_keras_object(
      config,
      custom_objects=custom_objects,
      module_objects=namespace,
      printable_module_name='metric function',
  )
Example #20
0
  def from_config(cls, config, custom_objects=None):
    """Recreates a Lambda layer from its config.

    Args:
      config: dict produced by the layer's serialization. Keys consumed here
        include 'function_type' and 'output_shape_type' and, optionally,
        'module', 'output_shape_module' and 'arguments'. Neither this dict
        nor its nested 'arguments' dict is modified.
      custom_objects: Optional dict of name -> object used for name lookup
        and added to the globals of bytecode-loaded lambdas.

    Returns:
      An instance of `cls`.

    Raises:
      TypeError: if 'function_type' is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    # Work on a copy. Updating globals() in place would overwrite this
    # module's own names (even `__name__`) with the user module's namespace.
    globs = globals().copy()
    module = config.pop('module', None)
    if module in sys.modules:
      globs.update(sys.modules[module].__dict__)
    elif module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(module)
                    , UserWarning)
    if custom_objects:
      globs.update(custom_objects)
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = generic_utils.deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode: only load trusted configs.
      function = generic_utils.func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type: {}'.format(function_type))

    output_shape_module = config.pop('output_shape_module', None)
    if output_shape_module in sys.modules:
      globs.update(sys.modules[output_shape_module].__dict__)
    elif output_shape_module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(output_shape_module)
                    , UserWarning)
    output_shape_type = config.pop('output_shape_type')
    if output_shape_type == 'function':
      # Simple lookup in custom objects
      output_shape = generic_utils.deserialize_keras_object(
          config['output_shape'],
          custom_objects=custom_objects,
          printable_module_name='output_shape function in Lambda layer')
    elif output_shape_type == 'lambda':
      # Unsafe deserialization from bytecode: only load trusted configs.
      output_shape = generic_utils.func_load(config['output_shape'],
                                             globs=globs)
    else:
      output_shape = config['output_shape']

    # Numpy array arguments were saved as dicts with 'type' == 'ndarray' and
    # a list 'value'; turn them back into ndarrays. The nested dict is copied
    # first because config.copy() above is shallow.
    if 'arguments' in config:
      arguments = dict(config['arguments'])
      for key, value in list(arguments.items()):
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          arguments[key] = np.array(value['value'])
      config['arguments'] = arguments

    config['function'] = function
    config['output_shape'] = output_shape
    return cls(**config)
Example #21
0
  def from_config(cls, config, custom_objects=None):
    """Recreates a Lambda layer from its config.

    Args:
      config: dict produced by the layer's serialization. Keys consumed here
        include 'function_type' and 'output_shape_type' and, optionally,
        'module', 'output_shape_module' and 'arguments'. Neither this dict
        nor its nested 'arguments' dict is modified.
      custom_objects: Optional dict of name -> object used for name lookup
        and added to the globals of bytecode-loaded lambdas.

    Returns:
      An instance of `cls`.

    Raises:
      TypeError: if 'function_type' is neither 'function' nor 'lambda'.
    """
    config = config.copy()
    # Work on a copy. Updating globals() in place would overwrite this
    # module's own names (even `__name__`) with the user module's namespace.
    globs = globals().copy()
    module = config.pop('module', None)
    if module in sys.modules:
      globs.update(sys.modules[module].__dict__)
    elif module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(module)
                    , UserWarning)
    if custom_objects:
      globs.update(custom_objects)
    function_type = config.pop('function_type')
    if function_type == 'function':
      # Simple lookup in custom objects
      function = generic_utils.deserialize_keras_object(
          config['function'],
          custom_objects=custom_objects,
          printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
      # Unsafe deserialization from bytecode: only load trusted configs.
      function = generic_utils.func_load(config['function'], globs=globs)
    else:
      raise TypeError('Unknown function type: {}'.format(function_type))

    output_shape_module = config.pop('output_shape_module', None)
    if output_shape_module in sys.modules:
      globs.update(sys.modules[output_shape_module].__dict__)
    elif output_shape_module is not None:
      # Note: we don't know the name of the function if it's a lambda.
      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
                    'It may cause errors.'.format(output_shape_module)
                    , UserWarning)
    output_shape_type = config.pop('output_shape_type')
    if output_shape_type == 'function':
      # Simple lookup in custom objects
      output_shape = generic_utils.deserialize_keras_object(
          config['output_shape'],
          custom_objects=custom_objects,
          printable_module_name='output_shape function in Lambda layer')
    elif output_shape_type == 'lambda':
      # Unsafe deserialization from bytecode: only load trusted configs.
      output_shape = generic_utils.func_load(config['output_shape'],
                                             globs=globs)
    else:
      output_shape = config['output_shape']

    # Numpy array arguments were saved as dicts with 'type' == 'ndarray' and
    # a list 'value'; turn them back into ndarrays. The nested dict is copied
    # first because config.copy() above is shallow.
    if 'arguments' in config:
      arguments = dict(config['arguments'])
      for key, value in list(arguments.items()):
        if isinstance(value, dict) and value.get('type') == 'ndarray':
          arguments[key] = np.array(value['value'])
      config['arguments'] = arguments

    config['function'] = function
    config['output_shape'] = output_shape
    return cls(**config)
Example #22
0
def deserialize(config, custom_objects=None):
    """Look up a metric function from its serialized `config`.

    Names are resolved against this module's globals and `custom_objects`.
    """
    candidates = globals()
    return deserialize_keras_object(
        config,
        custom_objects=custom_objects,
        module_objects=candidates,
        printable_module_name='metric function',
    )
Example #23
0
 def from_config(cls, config):
   """Builds `cls` around the deserialized 'inner_layer' entry of `config`."""
   inner = generic_utils.deserialize_keras_object(config['inner_layer'])
   return cls(inner)
Example #24
0
def deserialize(name, custom_objects=None):
  """Resolve a serialized loss function.

  Names are resolved against this module's globals and `custom_objects`.
  """
  namespace = globals()
  return deserialize_keras_object(
      name,
      custom_objects=custom_objects,
      module_objects=namespace,
      printable_module_name='loss function',
  )
Example #25
0
def _deserialize(config, custom_objects=None):
    """Deserialize a gating function from `config`.

    Args:
        config: serialized object, forwarded to `deserialize_keras_object`.
        custom_objects: optional mapping of extra names to objects. Entries in
            `object_scope` override entries with the same name given here.

    Returns:
        The object `deserialize_keras_object` resolves `config` to.
    """
    # `None` replaces the shared mutable `{}` default. The merge builds a new
    # dict and never mutates either input.
    custom_objects = {**(custom_objects or {}), **object_scope}
    return deserialize_keras_object(config,
                                    module_objects=globals(),
                                    custom_objects=custom_objects,
                                    printable_module_name='gating_functions')
Example #26
0
def deserialize(config, custom_objects=None):
  """Rebuild a regularizer from its serialized `config`.

  Names are resolved against this module's globals and `custom_objects`.
  """
  lookup = globals()
  return deserialize_keras_object(
      config,
      custom_objects=custom_objects,
      module_objects=lookup,
      printable_module_name='regularizer',
  )
def deserialize(config, custom_objects=None):
    """Rebuild a constraint from its serialized `config`.

    Names are resolved against this module's globals and `custom_objects`.
    """
    lookup = globals()
    return deserialize_keras_object(
        config,
        custom_objects=custom_objects,
        module_objects=lookup,
        printable_module_name='constraint',
    )
Example #28
0
def deserialize(config, custom_objects=None):
    """Deserialize a "decay" object from `config`.

    Names are resolved against this module's globals and `custom_objects`.
    """
    lookup = globals()
    return generic_utils.deserialize_keras_object(
        config,
        custom_objects=custom_objects,
        module_objects=lookup,
        printable_module_name="decay",
    )
def deserialize(config, custom_objects=None):
  """Deserialize a "decay" object from `config`.

  Names are resolved against this module's globals and `custom_objects`.
  """
  lookup = globals()
  return generic_utils.deserialize_keras_object(
      config,
      custom_objects=custom_objects,
      module_objects=lookup,
      printable_module_name="decay",
  )
Example #30
0
from tensorflow.keras.layers import Layer, Dense
import tensorflow as tf
# `kl` is used by the cells below but was never imported.
from tensorflow.keras import layers as kl
from tensorflow.keras import regularizers

#%%
# Serialize a Dense layer that carries a string-specified regularizer.
l = kl.Dense(10, kernel_regularizer="l2")
kl.serialize(l)
#%%
# Rebuild a layer from a minimal config and inspect the result.
l1 = kl.deserialize({'class_name': 'Dense', 'config': {'units': 5}})
l1.get_config()
#%%
# Same round trip through the lower-level helper. 'Dense' is resolved via
# this script's globals (imported above).
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
deserialize_keras_object(
    {'class_name': 'Dense', 'config': {'units': 5}},
    module_objects=globals())


#%%
class ConstantMultiple(Layer):
    def __init__(self, init_val: float = 1, regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.init_val = init_val
        self.regularizer = regularizers.get(regularizer)

        self.c = self.add_weight(name="c", shape=(), regularizer=regularizer)

    def call(self, input):