예제 #1
0
    def __init__(self,
                 regularizer=None,
                 activity_regularizer=None,
                 use_operator=False,
                 var_name='v',
                 **kwargs):
        """Builds a MultiplyLayer and records its configuration.

        Either regularizer may be supplied as a dict; in that case it is run
        through `regularizers.deserialize` (with this module's globals as the
        custom objects) and the resulting object is what gets stored.

        Args:
          regularizer: The weight regularizer on the scalar variable, or a
            dict to deserialize into one.
          activity_regularizer: The activity regularizer, or a dict to
            deserialize into one.
          use_operator: If True, multiply using the * operator. If False,
            multiply using tf.multiply. The flag is only stored here.
          var_name: The name of the variable. Passing a name other than 'v'
            is useful for testing the case where the attribute name (self.v)
            differs from the variable name.
          **kwargs: Passed to the AssertTypeLayer constructor.
        """
        # Store each regularizer exactly as given first, then swap in the
        # deserialized object when the caller handed over a dict.
        for attr_name, spec in (('_regularizer', regularizer),
                                ('_activity_regularizer',
                                 activity_regularizer)):
            setattr(self, attr_name, spec)
            if isinstance(spec, dict):
                setattr(
                    self, attr_name,
                    regularizers.deserialize(spec, custom_objects=globals()))

        self._use_operator = use_operator
        self._var_name = var_name

        # The parent receives the (possibly deserialized) activity regularizer.
        parent = super(MultiplyLayer, self)
        parent.__init__(activity_regularizer=self._activity_regularizer,
                        **kwargs)
예제 #2
0
  def _init_from_metadata(cls, metadata):
    """Builds a revived layer from metadata stored in the SavedModel proto.

    `name` and `trainable` are always passed to the constructor; `dtype` and
    `batch_input_shape` are passed only when present and not None. The
    remaining metadata entries are applied as attributes while automatic
    dependency tracking is disabled for the new object.

    Args:
      metadata: Mapping of layer metadata. Must contain 'name', 'trainable'
        and 'expects_training_arg'; every other key read here is optional.

    Returns:
      A `(revived_obj, _revive_setter)` tuple.
    """
    ctor_kwargs = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    for optional_key in ('dtype', 'batch_input_shape'):
      optional_value = metadata.get(optional_key)
      if optional_value is not None:
        ctor_kwargs[optional_key] = optional_value

    layer = cls(**ctor_kwargs)

    with trackable.no_automatic_dependency_tracking_scope(layer):
      # pylint:disable=protected-access
      layer._expects_training_arg = metadata['expects_training_arg']

      # The config is kept only when validate_config accepts it.
      layer_config = metadata.get('config')
      if generic_utils.validate_config(layer_config):
        layer._config = layer_config

      serialized_spec = metadata.get('input_spec')
      if serialized_spec is not None:
        layer.input_spec = recursively_deserialize_keras_object(
            serialized_spec,
            module_objects={'InputSpec': input_spec.InputSpec})

      serialized_regularizer = metadata.get('activity_regularizer')
      if serialized_regularizer is not None:
        layer.activity_regularizer = regularizers.deserialize(
            serialized_regularizer)

      # Plain attributes that are copied over verbatim when recorded.
      for attr_name in ('_is_feature_layer', 'stateful'):
        attr_value = metadata.get(attr_name)
        if attr_value is not None:
          setattr(layer, attr_name, attr_value)
      # pylint:enable=protected-access

    return layer, _revive_setter
예제 #3
0
    def _init_from_metadata(cls, metadata):
        """Builds a revived layer from metadata stored in the SavedModel proto.

        `name` and `trainable` are always passed to the constructor; `dtype`
        and `batch_input_shape` are passed only when present and not None.
        The remaining metadata entries are applied as attributes while
        automatic dependency tracking is disabled for the new object, and the
        metadata itself is kept under `_serialized_attributes['metadata']`.

        Args:
          metadata: Mapping of layer metadata. Must contain 'name',
            'trainable' and 'expects_training_arg'; every other key read here
            is optional.

        Returns:
          A `(revived_obj, _revive_setter)` tuple.
        """
        ctor_kwargs = {
            'name': metadata['name'],
            'trainable': metadata['trainable'],
        }
        for optional_key in ('dtype', 'batch_input_shape'):
            optional_value = metadata.get(optional_key)
            if optional_value is not None:
                ctor_kwargs[optional_key] = optional_value

        layer = cls(**ctor_kwargs)

        with trackable.no_automatic_dependency_tracking_scope(layer):
            # pylint:disable=protected-access
            layer._expects_training_arg = metadata['expects_training_arg']

            layer_config = metadata.get('config')
            if layer_config is not None:
                layer._config = layer_config

            serialized_spec = metadata.get('input_spec')
            if serialized_spec is not None:
                layer.input_spec = recursively_deserialize_keras_object(
                    serialized_spec,
                    module_objects={'InputSpec': input_spec.InputSpec})

            serialized_regularizer = metadata.get('activity_regularizer')
            if serialized_regularizer is not None:
                layer.activity_regularizer = regularizers.deserialize(
                    serialized_regularizer)

            feature_layer_flag = metadata.get('_is_feature_layer')
            if feature_layer_flag is not None:
                layer._is_feature_layer = feature_layer_flag

            # Attributes revived from SerializedAttributes live in an
            # untracked dictionary. They are the ones listed in
            # CommonEndpoints, or "keras_api" for keras-specific attributes.
            layer._serialized_attributes = {'metadata': metadata}
            # pylint:enable=protected-access

        return layer, _revive_setter
예제 #4
0
    def _init_from_metadata(cls, metadata):
        """Builds a revived network from metadata stored in the SavedModel proto.

        Only `name` is passed to the constructor. Everything else is applied
        as attributes afterwards, while automatic dependency tracking is
        disabled for the new object.

        Args:
          metadata: Mapping of network metadata. Must contain 'name',
            'trainable' and 'expects_training_arg'; 'dtype', 'config' and
            'activity_regularizer' are applied only when present and not None.

        Returns:
          The revived object.
        """
        network = cls(name=metadata['name'])

        with trackable.no_automatic_dependency_tracking_scope(network):
            # pylint:disable=protected-access
            dtype = metadata.get('dtype')
            if dtype is not None:
                network._dtype = dtype
            network.trainable = metadata['trainable']
            network._expects_training_arg = metadata['expects_training_arg']

            config = metadata.get('config')
            if config is not None:
                network._config = config

            serialized_regularizer = metadata.get('activity_regularizer')
            if serialized_regularizer is not None:
                network.activity_regularizer = regularizers.deserialize(
                    serialized_regularizer)

            # Attributes revived from SerializedAttributes live in an
            # untracked dictionary. They are the ones listed in
            # CommonEndpoints, or "keras_api" for keras-specific attributes.
            network._serialized_attributes = {}
            # pylint:enable=protected-access

        return network
예제 #5
0
def _set_network_attributes_from_metadata(revived_obj):
    """Applies the attributes recorded in the object's stored metadata.

    The metadata is read from `revived_obj._serialized_attributes['metadata']`
    and applied while automatic dependency tracking is disabled for the
    object.

    Args:
        revived_obj: Object whose `_serialized_attributes['metadata']` mapping
            contains 'trainable' and 'expects_training_arg'; 'dtype', 'config'
            and 'activity_regularizer' are applied only when present and not
            None.
    """
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        recorded = revived_obj._serialized_attributes['metadata']

        dtype = recorded.get('dtype')
        if dtype is not None:
            revived_obj._dtype = dtype
        revived_obj.trainable = recorded['trainable']
        revived_obj._expects_training_arg = recorded['expects_training_arg']

        config = recorded.get('config')
        if config is not None:
            revived_obj._config = config

        serialized_regularizer = recorded.get('activity_regularizer')
        if serialized_regularizer is not None:
            revived_obj.activity_regularizer = regularizers.deserialize(
                serialized_regularizer)
예제 #6
0
def get(config):
    """Deserializes a regularizer config with this module's custom objects.

    `ctx_cond_neg_ent` and `l1_l2_tv` are registered under their own names as
    custom objects for the lookup.

    Args:
        config: The serialized regularizer, passed through unchanged to
            `regularizers.deserialize`.

    Returns:
        Whatever `regularizers.deserialize` returns for `config`.
    """
    custom_regularizers = {
        "ctx_cond_neg_ent": ctx_cond_neg_ent,
        "l1_l2_tv": l1_l2_tv,
    }
    return regularizers.deserialize(config=config,
                                    custom_objects=custom_regularizers)