def __init__(self,
             regularizer=None,
             activity_regularizer=None,
             use_operator=False,
             var_name='v',
             **kwargs):
    """Initializes the MultiplyLayer.

    Args:
      regularizer: The weight regularizer on the scalar variable. A dict is
        treated as a serialized config and deserialized; any other value is
        stored unchanged.
      activity_regularizer: The activity regularizer. Dict configs are
        deserialized the same way as `regularizer`; the resulting value is
        also forwarded to the parent constructor.
      use_operator: If True, multiply using the * operator. If False,
        multiply using tf.multiply.
      var_name: The name of the variable. It can be useful to pass a name
        other than 'v', to test having the attribute name (self.v) being
        different from the variable name.
      **kwargs: Passed to AssertTypeLayer constructor.
    """
    self._regularizer = regularizer
    if isinstance(regularizer, dict):
        # globals() is supplied as custom_objects so names defined in this
        # module can be resolved while deserializing the config.
        self._regularizer = regularizers.deserialize(
            regularizer, custom_objects=globals())
    self._activity_regularizer = activity_regularizer
    if isinstance(activity_regularizer, dict):
        self._activity_regularizer = regularizers.deserialize(
            activity_regularizer, custom_objects=globals())

    self._use_operator = use_operator
    self._var_name = var_name
    super(MultiplyLayer, self).__init__(
        activity_regularizer=self._activity_regularizer, **kwargs)
def _init_from_metadata(cls, metadata):
    """Build a revived layer from the metadata saved in a SavedModel proto.

    The layer is constructed from the saved name and trainable flag, plus
    dtype and batch_input_shape when they were recorded. The remaining saved
    attributes are then restored inside a no-automatic-dependency-tracking
    scope.

    Args:
      cls: Layer class to instantiate.
      metadata: Dict of saved layer attributes. 'name', 'trainable' and
        'expects_training_arg' are read unconditionally; every other key is
        skipped when missing or None ('config' is kept only if
        generic_utils.validate_config accepts it).

    Returns:
      A tuple ``(layer, _revive_setter)``.
    """
    constructor_kwargs = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    # Optional constructor arguments are forwarded only when recorded.
    for optional_key in ('dtype', 'batch_input_shape'):
        if metadata.get(optional_key) is not None:
            constructor_kwargs[optional_key] = metadata[optional_key]
    layer = cls(**constructor_kwargs)

    with trackable.no_automatic_dependency_tracking_scope(layer):
        # pylint:disable=protected-access
        layer._expects_training_arg = metadata['expects_training_arg']

        saved_config = metadata.get('config')
        # Keep the config only when generic_utils.validate_config accepts it.
        if generic_utils.validate_config(saved_config):
            layer._config = saved_config

        if metadata.get('input_spec') is not None:
            layer.input_spec = recursively_deserialize_keras_object(
                metadata['input_spec'],
                module_objects={'InputSpec': input_spec.InputSpec})

        if metadata.get('activity_regularizer') is not None:
            layer.activity_regularizer = regularizers.deserialize(
                metadata['activity_regularizer'])

        # Plain flags are copied verbatim when present.
        if metadata.get('_is_feature_layer') is not None:
            layer._is_feature_layer = metadata['_is_feature_layer']
        if metadata.get('stateful') is not None:
            layer.stateful = metadata['stateful']
        # pylint:enable=protected-access

    return layer, _revive_setter
def _init_from_metadata(cls, metadata):
    """Build a revived layer from the metadata saved in a SavedModel proto.

    Args:
      cls: Layer class to instantiate. It receives the saved name and
        trainable flag, plus dtype and batch_input_shape when recorded.
      metadata: Dict of saved layer attributes. 'name', 'trainable' and
        'expects_training_arg' are read unconditionally; the other keys are
        applied only when present and not None.

    Returns:
      A tuple ``(layer, _revive_setter)``. The layer keeps the raw metadata
      in ``_serialized_attributes['metadata']``.
    """
    constructor_kwargs = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    for optional_key in ('dtype', 'batch_input_shape'):
        if metadata.get(optional_key) is not None:
            constructor_kwargs[optional_key] = metadata[optional_key]
    layer = cls(**constructor_kwargs)

    with trackable.no_automatic_dependency_tracking_scope(layer):
        # pylint:disable=protected-access
        layer._expects_training_arg = metadata['expects_training_arg']

        if metadata.get('config') is not None:
            layer._config = metadata['config']

        if metadata.get('input_spec') is not None:
            layer.input_spec = recursively_deserialize_keras_object(
                metadata['input_spec'],
                module_objects={'InputSpec': input_spec.InputSpec})

        if metadata.get('activity_regularizer') is not None:
            layer.activity_regularizer = regularizers.deserialize(
                metadata['activity_regularizer'])

        if metadata.get('_is_feature_layer') is not None:
            layer._is_feature_layer = metadata['_is_feature_layer']

        # Attributes revived from SerializedAttributes live in this
        # untracked dict (CommonEndpoints / "keras_api" entries); it starts
        # out holding only the raw metadata.
        layer._serialized_attributes = {'metadata': metadata}
        # pylint:enable=protected-access

    return layer, _revive_setter
def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto.

    Args:
      cls: Network class to instantiate; it is called with only the saved
        name.
      metadata: Dict of saved attributes. 'name', 'trainable' and
        'expects_training_arg' are read unconditionally; 'dtype', 'config'
        and 'activity_regularizer' are applied only when present and not
        None.

    Returns:
      The revived network object. Its ``_serialized_attributes`` dict holds
      the raw metadata under the 'metadata' key.
    """
    revived_obj = cls(name=metadata['name'])
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        if metadata.get('dtype') is not None:
            revived_obj._dtype = metadata['dtype']
        revived_obj.trainable = metadata['trainable']
        revived_obj._expects_training_arg = metadata['expects_training_arg']
        if metadata.get('config') is not None:
            revived_obj._config = metadata['config']
        if metadata.get('activity_regularizer') is not None:
            revived_obj.activity_regularizer = regularizers.deserialize(
                metadata['activity_regularizer'])
        # Store attributes revived from SerializedAttributes in an untracked
        # dictionary. The attributes are the ones listed in CommonEndpoints or
        # "keras_api" for keras-specific attributes. The raw metadata is kept
        # under 'metadata' so that code reading
        # `_serialized_attributes['metadata']` (e.g.
        # _set_network_attributes_from_metadata) also works on revived
        # networks instead of raising KeyError.
        revived_obj._serialized_attributes = {'metadata': metadata}
        # pylint:enable=protected-access
    return revived_obj
def _set_network_attributes_from_metadata(revived_obj):
    """Copy the saved metadata attributes onto a revived network.

    The metadata dict is read from ``revived_obj._serialized_attributes``
    under the 'metadata' key. 'trainable' and 'expects_training_arg' are
    always applied; 'dtype', 'config' and 'activity_regularizer' only when
    present and not None. All assignments happen inside a
    no-automatic-dependency-tracking scope.

    Args:
      revived_obj: The revived network object to update in place.
    """
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        saved = revived_obj._serialized_attributes['metadata']

        saved_dtype = saved.get('dtype')
        if saved_dtype is not None:
            revived_obj._dtype = saved_dtype

        revived_obj.trainable = saved['trainable']
        revived_obj._expects_training_arg = saved['expects_training_arg']

        saved_config = saved.get('config')
        if saved_config is not None:
            revived_obj._config = saved_config

        saved_regularizer = saved.get('activity_regularizer')
        if saved_regularizer is not None:
            revived_obj.activity_regularizer = regularizers.deserialize(
                saved_regularizer)
def get(config):
    """Deserialize a regularizer config with this module's custom regularizers.

    Args:
      config: Serialized regularizer, forwarded unchanged to
        ``regularizers.deserialize``.

    Returns:
      Whatever ``regularizers.deserialize`` returns, with "ctx_cond_neg_ent"
      and "l1_l2_tv" registered as custom objects.
    """
    deserialize = regularizers.deserialize
    known_custom_objects = {
        "ctx_cond_neg_ent": ctx_cond_neg_ent,
        "l1_l2_tv": l1_l2_tv,
    }
    return deserialize(config=config, custom_objects=known_custom_objects)