예제 #1
0
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer from a config produced by `get_config()`.

        Args:
          config: Dict holding 'user_cells' and 'session_cells' (iterables of
            serialized layer configs) and 'embedding_layer' (a serialized
            layer config). All remaining entries are forwarded to the
            constructor as keyword arguments. The caller's dict is not
            modified.
          custom_objects: Optional mapping of names to custom classes,
            forwarded to every `deserialize_layer` call.

        Returns:
          A new instance of `cls`.
        """
        import copy  # pylint: disable=g-import-not-at-top
        # Work on a deep copy: the keys below are popped, and the nested layer
        # configs are handed to `deserialize_layer`, whose nested `from_config`
        # implementations may pop keys from them too.
        config = copy.deepcopy(config)

        user_cells = [
            deserialize_layer(cell_config, custom_objects=custom_objects)
            for cell_config in config.pop('user_cells')
        ]
        session_cells = [
            deserialize_layer(cell_config, custom_objects=custom_objects)
            for cell_config in config.pop('session_cells')
        ]
        embedding = deserialize_layer(config.pop('embedding_layer'),
                                      custom_objects=custom_objects)

        return cls(user_cells, session_cells, embedding, **config)
예제 #2
0
 def from_config(cls, config, custom_objects=None):
     """Rebuild the stacked cells from a config produced by `get_config()`.

     Args:
       config: Dict whose 'cells' entry is an iterable of serialized cell
         configs; all other entries are forwarded to the constructor as
         keyword arguments. The caller's dict is not modified.
       custom_objects: Optional mapping of names to custom classes, forwarded
         to `deserialize_layer`.

     Returns:
       A new instance of `cls`.
     """
     import copy  # pylint: disable=g-import-not-at-top
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     # Copy first so popping 'cells' (and any popping done while deserializing
     # the nested cell configs) does not mutate the caller's dict.
     config = copy.deepcopy(config)
     cells = [
         deserialize_layer(cell_config, custom_objects=custom_objects)
         for cell_config in config.pop('cells')
     ]
     return cls(cells, **config)
예제 #3
0
 def from_config(cls, config, custom_objects=None):
     """Recreate the wrapper from its serialized config.

     The wrapped layer's config stored under 'layer' is deserialized and
     passed positionally to the constructor; every other entry is passed as a
     keyword argument. The caller's dict is left untouched.
     """
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     local_config = copy.deepcopy(config)  # Never mutate the input dict.
     serialized_inner = local_config.pop('layer')
     inner_layer = deserialize_layer(serialized_inner,
                                     custom_objects=custom_objects)
     return cls(inner_layer, **local_config)
예제 #4
0
  def from_config(cls, config):
    """Recreate the layer from its serialized config.

    The serialized 'layer' entry is replaced by the deserialized layer object,
    and the whole (copied) config is passed to the constructor as keyword
    arguments.
    """
    kwargs = config.copy()  # Keep the caller's dict unchanged.

    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    # Pop-then-assign moves 'layer' to the end, exactly as before.
    kwargs['layer'] = deserialize_layer(kwargs.pop('layer'))

    return cls(**kwargs)
예제 #5
0
 def from_config(cls, config, custom_objects=None):
     """Recreate a Bidirectional wrapper from its serialized config.

     Deserializes the forward 'layer' entry and, when present, the optional
     'backward_layer' entry, builds the wrapper, then restores the
     'num_constants' value (0 when absent) on the new instance.
     """
     # Work on a deep copy so the caller's config is left untouched.
     cfg = copy.deepcopy(config)
     constants_count = cfg.pop('num_constants', 0)

     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

     # Forward layer: instantiated here, as the parent class would do.
     cfg['layer'] = deserialize_layer(cfg['layer'],
                                      custom_objects=custom_objects)

     # Backward layer is optional; only rebuild it when it was serialized.
     serialized_backward = cfg.pop('backward_layer', None)
     if serialized_backward is not None:
         cfg['backward_layer'] = deserialize_layer(
             serialized_backward, custom_objects=custom_objects)

     wrapper = cls(**cfg)
     wrapper._num_constants = constants_count  # pylint: disable=protected-access
     return wrapper
예제 #6
0
    def from_config(cls, config):
        """Recreate the wrapper from its serialized config.

        The 'quantize_provider' entry is resolved with
        `deserialize_keras_object` against this module's globals, and the
        'layer' entry with `deserialize_layer`. Remaining entries are passed
        through to the constructor. The caller's dict is not modified.
        """
        remaining = config.copy()

        provider_config = remaining.pop('quantize_provider')
        provider = deserialize_keras_object(
            provider_config,
            module_objects=globals(),
            custom_objects=None)

        wrapped_layer = deserialize_layer(remaining.pop('layer'))

        return cls(
            layer=wrapped_layer, quantize_provider=provider, **remaining)
예제 #7
0
  def _from_config(cls_initializer, config):
    """Shared `from_config` logic for the fused layers.

    Args:
      cls_initializer: Callable (presumably the layer class) invoked with the
        cleaned-up config as keyword arguments.
      config: Serialized config dict; it is copied, not mutated.

    Returns:
      Whatever `cls_initializer(**config)` returns.
    """
    kwargs = config.copy()
    # 'use_bias' is stored in the config but is not a constructor argument of
    # this class (see the comment in __init__), so it is dropped. A missing
    # key raises KeyError.
    kwargs.pop('use_bias')

    post_activation = kwargs['post_activation']
    # A 'class_name' entry marks a serialized advanced-activation layer;
    # anything else goes through the plain activation deserializer.
    if 'class_name' in post_activation:
      kwargs['post_activation'] = deserialize_layer(post_activation)
    else:
      kwargs['post_activation'] = activations.deserialize(post_activation)

    return cls_initializer(**kwargs)
    def from_config(cls, config, custom_objects=None):
        """Recreate the clustering wrapper from its serialized config.

        Requires 'number_of_clusters' and 'cluster_centroids_init' entries (a
        KeyError is raised if either is missing). The 'layer' entry is
        deserialized before everything is passed to the constructor as
        keyword arguments. The caller's dict is not modified.
        """
        kwargs = config.copy()

        # Popping and re-inserting leaves these values unchanged. The only
        # effects are a KeyError when a key is absent and moving the key to
        # the end of the dict.
        for required_key in ('number_of_clusters', 'cluster_centroids_init'):
            kwargs[required_key] = kwargs.pop(required_key)

        from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
        kwargs['layer'] = deserialize_layer(kwargs.pop('layer'),
                                            custom_objects=custom_objects)

        return cls(**kwargs)
예제 #9
0
  def from_config(cls, config, custom_objects=None):
    """Recreate a Bidirectional wrapper from its serialized config.

    Args:
      config: Serialized config dict. A shallow copy is modified, never the
        caller's dict. The optional 'backward_layer' entry is deserialized
        here, and the rest is handled by the parent class's `from_config`.
      custom_objects: Optional mapping of names to custom classes, forwarded
        to layer deserialization.

    Returns:
      The new Bidirectional layer with `_num_constants` restored.
    """
    # Instead of updating the input, create a copy and use that.
    config = config.copy()
    # Default to 0 (no constants) rather than None, so that `_num_constants`
    # stays an integer count even when the config omits the key.
    # NOTE(review): presumably `get_config` omits it when there are no
    # constants; confirm.
    num_constants = config.pop('num_constants', 0)
    backward_layer_config = config.pop('backward_layer', None)
    if backward_layer_config is not None:
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
      backward_layer = deserialize_layer(
          backward_layer_config, custom_objects=custom_objects)
      config['backward_layer'] = backward_layer

    layer = super(Bidirectional, cls).from_config(config,
                                                  custom_objects=custom_objects)
    layer._num_constants = num_constants  # pylint: disable=protected-access
    return layer
    def from_config(cls, config):
        """Recreate the QuantizeWrapper from its serialized config.

        Builds a new `cls` from the deserialized 'layer' and
        'quantize_provider' entries plus any remaining config entries. The
        caller's dict is not modified.
        """
        kwargs = config.copy()

        # QuantizeWrapper may be constructed with any QuantizeProvider and the
        # wrapper itself cannot know all the possible provider classes.
        # The deserialization code should ensure the QuantizeProvider is in
        # keras serialization scope.
        serialized_provider = kwargs.pop('quantize_provider')
        provider = deserialize_keras_object(
            serialized_provider,
            module_objects=globals(),
            custom_objects=None)

        inner_layer = deserialize_layer(kwargs.pop('layer'))

        return cls(layer=inner_layer, quantize_provider=provider, **kwargs)
예제 #11
0
  def from_config(cls, config, custom_objects=None):
    """Recreate a Bidirectional wrapper from its serialized config.

    An optional serialized 'backward_layer' is deserialized here, and the rest
    of the work is delegated to the parent class's `from_config`. The
    'num_constants' entry (0 when absent) is restored on the result.
    """
    # Operate on a shallow copy so the caller's dict is left intact.
    cfg = config.copy()
    constants_count = cfg.pop('num_constants', 0)

    serialized_backward = cfg.pop('backward_layer', None)
    if serialized_backward is not None:
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
      cfg['backward_layer'] = deserialize_layer(
          serialized_backward, custom_objects=custom_objects)

    parent = super(Bidirectional, cls)
    wrapper = parent.from_config(cfg, custom_objects=custom_objects)
    wrapper._num_constants = constants_count  # pylint: disable=protected-access
    return wrapper
예제 #12
0
  def testSerializationQuantizeAnnotate(self):
    """A QuantizeAnnotate wrapper survives a serialize/deserialize round trip."""
    dense = keras.layers.Dense(3)
    original = quantize_annotate.QuantizeAnnotate(
        layer=dense,
        quantize_provider=self.TestQuantizeProvider(),
        input_shape=(2,))

    # The wrapper class and the test provider must both be resolvable by
    # name while deserializing.
    scope_objects = dict(
        QuantizeAnnotate=quantize_annotate.QuantizeAnnotate,
        TestQuantizeProvider=self.TestQuantizeProvider)

    serialized = serialize_layer(original)
    with keras.utils.custom_object_scope(scope_objects):
      restored = deserialize_layer(serialized)

    self.assertEqual(restored.get_config(), original.get_config())
예제 #13
0
  def testSerializationQuantizeWrapper(self):
    """A QuantizeWrapper survives a serialize/deserialize round trip."""
    dense = keras.layers.Dense(3)
    provider = self.quantize_registry.get_quantize_provider(dense)
    original = QuantizeWrapper(
        layer=dense, quantize_provider=provider, input_shape=(2,))

    # Names the serialized config may reference: the wrapper, the activation
    # class, and whatever the registry's `_types_dict()` provides.
    scope_objects = {
        'QuantizeAwareActivation': QuantizeAwareActivation,
        'QuantizeWrapper': QuantizeWrapper,
    }
    scope_objects.update(tflite_quantize_registry._types_dict())  # pylint: disable=protected-access

    serialized = serialize_layer(original)
    with keras.utils.custom_object_scope(scope_objects):
      restored = deserialize_layer(serialized)

    self.assertEqual(restored.get_config(), original.get_config())
  def from_config(cls, config):
    """Recreate the annotated layer from its serialized config.

    The 'quantize_provider' and 'layer' entries are replaced by their
    deserialized objects, and the resulting config is passed to the
    constructor as keyword arguments. The caller's dict is not modified.
    """
    kwargs = config.copy()

    serialized_provider = kwargs.pop('quantize_provider')
    from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object  # pylint: disable=g-import-not-at-top
    # TODO(pulkitb): Add all known `QuantizeProvider`s to custom_objects
    known_providers = {
        'QuantizeProvider': quantize_provider_mod.QuantizeProvider
    }
    kwargs['quantize_provider'] = deserialize_keras_object(
        serialized_provider,
        module_objects=globals(),
        custom_objects=known_providers)

    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    kwargs['layer'] = deserialize_layer(kwargs.pop('layer'))

    return cls(**kwargs)
예제 #15
0
    def from_config(cls, config):
        """Recreate the pruning wrapper from its serialized config.

        'pruning_schedule' is resolved against the known schedule classes and
        'layer' is deserialized. Everything else is forwarded to the
        constructor as keyword arguments. The caller's dict is not modified.
        """
        kwargs = config.copy()

        serialized_schedule = kwargs.pop('pruning_schedule')
        from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object  # pylint: disable=g-import-not-at-top
        # TODO(pulkitb): This should ideally be fetched from pruning_schedule,
        # which should maintain a list of all the pruning_schedules.
        schedule_classes = {
            'ConstantSparsity': pruning_sched.ConstantSparsity,
            'PolynomialDecay': pruning_sched.PolynomialDecay,
        }
        kwargs['pruning_schedule'] = deserialize_keras_object(
            serialized_schedule,
            module_objects=globals(),
            custom_objects=schedule_classes)

        from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
        kwargs['layer'] = deserialize_layer(kwargs.pop('layer'))

        return cls(**kwargs)
예제 #16
0
 def from_config(cls, config, custom_objects=None):
     """Recreate the wrapper from its serialized config.

     Args:
       config: Dict whose 'layer' entry is the wrapped layer's serialized
         config; the remaining entries are passed to the constructor as
         keyword arguments. The caller's dict is not modified.
       custom_objects: Optional mapping of names to custom classes, forwarded
         to `deserialize_layer`.

     Returns:
       A new instance of `cls` wrapping the deserialized layer.
     """
     import copy  # pylint: disable=g-import-not-at-top
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     # Avoid mutating the input dict. Without this copy, popping 'layer'
     # removes it from the caller's config.
     config = copy.deepcopy(config)
     layer = deserialize_layer(config.pop('layer'),
                               custom_objects=custom_objects)
     return cls(layer, **config)
예제 #17
0
 def from_config(cls, config, custom_objects=None):
     """Recreate the cell wrapper from its serialized config.

     Args:
       config: Dict whose 'cell' entry is the wrapped cell's serialized
         config; remaining entries are passed to the constructor as keyword
         arguments. The caller's dict is not modified.
       custom_objects: Optional mapping of names to custom classes, forwarded
         to `deserialize_layer` so custom cell classes can be resolved.

     Returns:
       A new instance of `cls` wrapping the deserialized cell.
     """
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     # Shallow copy so popping 'cell' does not mutate the caller's dict.
     config = config.copy()
     cell = deserialize_layer(config.pop('cell'),
                              custom_objects=custom_objects)
     return cls(cell, **config)
예제 #18
0
 def from_config(cls, config, custom_objects=None):
     """Recreate the cell wrapper from its serialized config.

     The "cell" entry is deserialized (honouring `custom_objects`) and passed
     positionally; the other entries become keyword arguments. This works on
     a shallow copy, so the caller's dict is not modified.
     """
     remaining = config.copy()
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     serialized_cell = remaining.pop("cell")
     inner_cell = deserialize_layer(serialized_cell,
                                    custom_objects=custom_objects)
     return cls(inner_cell, **remaining)
예제 #19
0
 def from_config(cls, config, custom_objects=None):
   """Recreate the wrapper from its serialized config.

   Args:
     config: Dict whose 'layer' entry is the wrapped layer's serialized
       config; the remaining entries are passed to the constructor as keyword
       arguments. The caller's dict is not modified.
     custom_objects: Optional mapping of names to custom classes, forwarded to
       `deserialize_layer`.

   Returns:
     A new instance of `cls` wrapping the deserialized layer.
   """
   import copy  # pylint: disable=g-import-not-at-top
   from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   # Avoid mutating the input dict. Without this copy, popping 'layer'
   # removes it from the caller's config.
   config = copy.deepcopy(config)
   layer = deserialize_layer(
       config.pop('layer'), custom_objects=custom_objects)
   return cls(layer, **config)