Example #1
0
    def __init__(self,
                 regularizer=None,
                 activity_regularizer=None,
                 use_operator=False,
                 var_name="v",
                 **kwargs):
        """Initializes the MultiplyLayer.

        Args:
          regularizer: The weight regularizer on the scalar variable. Either a
            regularizer object or a serialized regularizer config (a dict); a
            dict is deserialized here via `regularizers.deserialize`.
          activity_regularizer: The activity regularizer. Accepts the same
            forms as `regularizer` and is forwarded to the base constructor
            after deserialization.
          use_operator: If True, multiply using the * operator. If False,
            multiply using tf.multiply.
          var_name: The name of the variable. It can be useful to pass a name other
            than 'v', to test having the attribute name (self.v) being different
            from the variable name.
          **kwargs: Passed to AssertTypeLayer constructor.
        """

        def _maybe_deserialize(value):
            # A dict is treated as a serialized regularizer config (presumably
            # coming from a saved/from_config round trip -- confirm). Passing
            # this module's globals() as custom_objects lets the deserializer
            # resolve regularizer classes defined in this module.
            if isinstance(value, dict):
                return regularizers.deserialize(
                    value, custom_objects=globals())
            return value

        self._regularizer = _maybe_deserialize(regularizer)
        self._activity_regularizer = _maybe_deserialize(activity_regularizer)
        self._use_operator = use_operator
        self._var_name = var_name
        # The already-deserialized activity regularizer is what the base
        # class receives, so it never sees the raw dict form.
        super().__init__(activity_regularizer=self._activity_regularizer,
                         **kwargs)
Example #2
0
  def _init_from_metadata(cls, metadata):
    """Builds a revived layer from the metadata saved in a SavedModel proto.

    Args:
      metadata: Mapping of saved layer attributes.
        - Required keys: 'name', 'trainable' and 'expects_training_arg'.
        - Optional keys applied only when present and not None: 'dtype',
          'batch_input_shape', 'input_spec', 'activity_regularizer',
          '_is_feature_layer' and 'stateful'.
        - 'config' is applied only when `generic_utils.validate_config`
          accepts it.

    Returns:
      A `(revived_obj, _revive_setter)` tuple: the newly constructed instance
      of `cls` and the module-level `_revive_setter` callable.

    Raises:
      KeyError: presumably, if a required key is missing from a dict
        `metadata`.
    """
    constructor_kwargs = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    # Only forward optional constructor arguments that were actually saved.
    for optional_key in ('dtype', 'batch_input_shape'):
      saved_value = metadata.get(optional_key)
      if saved_value is not None:
        constructor_kwargs[optional_key] = saved_value
    layer = cls(**constructor_kwargs)

    # The assignments below run inside no_automatic_dependency_tracking_scope,
    # presumably so they are not recorded as trackable dependencies.
    with trackable.no_automatic_dependency_tracking_scope(layer):
      # pylint:disable=protected-access
      layer._expects_training_arg = metadata['expects_training_arg']

      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        layer._config = saved_config

      saved_spec = metadata.get('input_spec')
      if saved_spec is not None:
        layer.input_spec = recursively_deserialize_keras_object(
            saved_spec,
            module_objects={'InputSpec': input_spec.InputSpec})

      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        layer.activity_regularizer = regularizers.deserialize(
            saved_regularizer)

      # Plain attributes copied over verbatim when they were saved.
      for attr_name in ('_is_feature_layer', 'stateful'):
        attr_value = metadata.get(attr_name)
        if attr_value is not None:
          setattr(layer, attr_name, attr_value)
      # pylint:enable=protected-access

    return layer, _revive_setter
Example #3
0
  def _init_from_metadata(cls, metadata):
    """Builds a revived network from the metadata saved in a SavedModel proto.

    Args:
      metadata: Mapping of saved network attributes.
        - Required keys: 'name' and 'expects_training_arg'.
        - 'config' is applied only when `generic_utils.validate_config`
          accepts it.
        - 'activity_regularizer' is applied only when present and not None.

    Returns:
      A `(revived_obj, _revive_setter)` tuple holding the new instance of
      `cls` and the module-level `_revive_setter` callable.
    """
    network = cls(name=metadata['name'])

    # The assignments below run inside no_automatic_dependency_tracking_scope,
    # presumably so they are not recorded as trackable dependencies of the
    # revived network.
    with trackable.no_automatic_dependency_tracking_scope(network):
      # pylint:disable=protected-access
      network._expects_training_arg = metadata['expects_training_arg']

      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        network._config = saved_config

      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        network.activity_regularizer = regularizers.deserialize(
            saved_regularizer)
      # pylint:enable=protected-access

    return network, _revive_setter