def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    Args:
      identifier: The serialized form of the object: a config dict, the string
        name of an object, an already-deserialized function, or `None`.
      module_objects: Optional dict of built-in objects to look string names
        up in. May be `None`, in which case string names are resolved only
        against `custom_objects` and the global custom objects.
      custom_objects: Optional dict of user-provided objects. For string
        identifiers it is consulted before the global custom objects and
        `module_objects`.
      printable_module_name: Human-readable name of the object type, used in
        error messages.

    Returns:
      The deserialized object, or `None` if `identifier` is `None`.

    Raises:
      ValueError: If a string `identifier` cannot be resolved, or if
        `identifier` is of an unsupported type.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # Global custom objects first, so that local entries override
                # them when the two dicts share a key.
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                return cls(**cls_config)
    elif isinstance(identifier, six.string_types):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as "no built-ins"
            # so an unknown name reaches the descriptive ValueError below
            # instead of failing with AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
def decorator(f):
    """Skips the decorated test class or test method when `condition` holds.

    `condition` and `reason` are free variables bound in the enclosing scope
    (not visible here).

    Args:
      f: A test class or a test method.

    Returns:
      For a class, the result of `unittest.skipIf(...)` applied to `f`. For a
      method, a wrapper that calls `self.skipTest(reason)` when `condition`
      is truthy and otherwise forwards to `f`.
    """
    if tf_inspect.isclass(f):
        # Apply the skip to `f` itself, i.e. the object actually being
        # decorated. Decorating any other object would replace the class
        # with an unrelated value.
        return unittest.skipIf(condition=condition, reason=reason)(f)

    def decorated(self, *args, **kwargs):
        if condition:
            self.skipTest(reason)
        return f(self, *args, **kwargs)

    return decorated
def decorator(f):
    """Wraps test method `f` so that it only runs when TF2 behavior is on.

    The TF2 check happens each time the wrapped test is called, not when the
    decorator is applied.

    Args:
      f: A test method taking `self` as its first argument.

    Returns:
      A wrapper that calls `self.skipTest(...)` when TF2 is disabled and
      otherwise forwards all arguments to `f`.

    Raises:
      ValueError: If `f` is a class; only test methods are supported.
    """
    is_class = tf_inspect.isclass(f)
    if is_class:
        raise ValueError('`run_v2_only` only supports test methods.')

    def decorated(self, *args, **kwargs):
        v2_active = tf.__internal__.tf2.enabled()
        if not v2_active:
            self.skipTest('Test is only compatible with v2')
        return f(self, *args, **kwargs)

    return decorated
def get(identifier):
    """Resolves `identifier` into an initializer.

    Args:
      identifier: `None`, a config dict, a string name, or a callable. A class
        is instantiated with no arguments; other callables are returned
        unchanged.

    Returns:
      `None` when `identifier` is `None`, otherwise the resolved initializer.

    Raises:
      ValueError: If `identifier` is not one of the supported kinds.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, six.string_types):
        # Normalize string subclasses (and py2 unicode) to a plain `str`.
        return deserialize(str(identifier))
    if not callable(identifier):
        raise ValueError('Could not interpret initializer identifier: ' +
                         str(identifier))
    return identifier() if inspect.isclass(identifier) else identifier
def get(identifier):
    """Retrieves a Keras initializer by the identifier.

    The `identifier` may be the string name of an initializer function or
    class (case-sensitive).

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.deserialize(identifier)
    <...keras.initializers.initializers_v2.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Note that the
    `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.deserialize(cfg)
    <...keras.initializers.initializers_v2.Ones...>

    If the `identifier` is a class, a new instance is created by calling its
    constructor with no arguments. Any other callable is returned as-is.

    Args:
      identifier: `None`, a string, a config dict, or a callable.

    Returns:
      An initializer instance based on the input identifier, or `None` when
      `identifier` is `None`.

    Raises:
      ValueError: If the input identifier is not a supported type or is in a
        bad format.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, str):
        # Normalize str subclasses to a plain `str` before lookup.
        return deserialize(str(identifier))
    if not callable(identifier):
        raise ValueError("Could not interpret initializer identifier: " +
                         str(identifier))
    if inspect.isclass(identifier):
        return identifier()
    return identifier
def make_variable(name,
                  shape=None,
                  dtype=tf.float32,
                  initializer=None,
                  layout=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf.VariableSynchronization.AUTO,
                  aggregation=tf.VariableAggregation.NONE,
                  partitioner=None):
    """Creates a `dtensor.DVariable`.

    This is a copy of `keras.engine.base_layer_utils.make_variable`. The only
    difference is the variable class: the original used
    `tf.compat.v1.Variable` for backward compatibility with estimator.

    `use_resource`, `collections` and `partitioner` are accepted for signature
    compatibility and are not used here.

    Args:
      name: Variable name.
      shape: Variable shape.
      dtype: Variable dtype. Used only when `initializer` is callable.
      initializer: Either a concrete initial value (any non-callable), a
        callable initializer, or an initializer class (instantiated with no
        arguments).
      layout: Forwarded to the initializer call as `layout=`.
      trainable, caching_device, validate_shape, constraint, synchronization,
        aggregation: Passed through to `dtensor.DVariable`.

    Returns:
      The created `dtensor.DVariable`.
    """
    from_value = initializer is not None and not callable(initializer)
    if from_value:
        # A concrete value: let the variable infer the dtype from it.
        init_val, variable_dtype = initializer, None
    else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        init_val = functools.partial(
            initializer, shape, dtype=dtype, layout=layout)
        variable_dtype = dtype.base_dtype

    variable_shape = tf.TensorShape(shape)
    # An unknown-rank TensorShape is falsy; pass None in that case.
    final_shape = variable_shape if variable_shape else None
    return dtensor.DVariable(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=final_shape)
def get(identifier):
    """Retrieve an initializer by the identifier.

    This function is copied from keras; the only reason for the copy is to
    route dict and string identifiers through this module's `deserialize()`.

    Args:
      identifier: `None`, a config dict, a string name, or a callable (a class
        is instantiated with no arguments).

    Returns:
      `None` when `identifier` is `None`, otherwise the resolved initializer.

    Raises:
      ValueError: If `identifier` is not one of the supported kinds.
    """
    if identifier is None:
        return None
    if isinstance(identifier, (dict,)):
        return deserialize(identifier)
    if isinstance(identifier, str):
        return deserialize(str(identifier))
    if callable(identifier):
        return identifier() if tf_inspect.isclass(identifier) else identifier
    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))
def test_single_element(self, layer):
    """A layer applied to a one-element list acts as identity; a bare tensor is rejected."""
    # `layer` may be given as a Layer subclass; instantiate it with defaults.
    needs_instantiation = (tf_inspect.isclass(layer) and
                           issubclass(layer, keras.layers.Layer))
    if needs_instantiation:
        layer = layer()

    # Processing a single element list should behave as identity.
    inp = keras.layers.Input(shape=(4, 5))
    result = layer([inp])
    self.assertListEqual(result.shape.as_list(), [None, 4, 5])

    model = keras.models.Model(inp, result)
    model.run_eagerly = test_utils.should_run_eagerly()
    data = np.random.random((2, 4, 5))
    prediction = model.predict(data)
    self.assertEqual(prediction.shape, (2, 4, 5))
    self.assertAllClose(prediction, data)

    # A single element must be passed as a list, not by itself.
    with self.assertRaisesRegex(ValueError, "called on a list"):
        layer(inp)
def decorator(arg):
    """Registers a class with the Keras serialization framework.

    The object is recorded in `_GLOBAL_CUSTOM_OBJECTS` under the key
    `package + '>' + class_name`, and the reverse mapping is recorded in
    `_GLOBAL_CUSTOM_NAMES`. `name` and `package` are free variables bound in
    the enclosing scope (not visible here).

    Returns:
      `arg`, unchanged, so this can be used as a decorator.

    Raises:
      ValueError: If `arg` is a class without `get_config`, or if the
        registered name or the object itself is already registered.
    """
    if name is None:
        class_name = arg.__name__
    else:
        class_name = name
    registered_name = package + '>' + class_name

    # Registered classes must be serializable via get_config().
    lacks_config = tf_inspect.isclass(arg) and not hasattr(arg, 'get_config')
    if lacks_config:
        raise ValueError(
            'Cannot register a class that does not have a get_config() method.')

    # Reject duplicates on either side of the name <-> object mapping.
    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
        previous = _GLOBAL_CUSTOM_OBJECTS[registered_name]
        raise ValueError('%s has already been registered to %s' %
                         (registered_name, previous))
    if arg in _GLOBAL_CUSTOM_NAMES:
        previous = _GLOBAL_CUSTOM_NAMES[arg]
        raise ValueError('%s has already been registered to %s' %
                         (arg, previous))

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg
def decorator(arg):
    """Registers a class with the Keras serialization framework.

    Stores `arg` in `_GLOBAL_CUSTOM_OBJECTS` under `package + ">" + class_name`
    and the reverse mapping in `_GLOBAL_CUSTOM_NAMES`. `name` and `package`
    are free variables bound in the enclosing scope (not visible here).

    Returns:
      `arg`, unchanged, so this can be used as a decorator.

    Raises:
      ValueError: If `arg` is a class without `get_config`, or if the
        registered name or the object itself is already registered.
    """
    class_name = arg.__name__ if name is None else name
    registered_name = package + ">" + class_name

    if tf_inspect.isclass(arg):
        if not hasattr(arg, "get_config"):
            raise ValueError(
                "Cannot register a class that does not have a "
                "get_config() method."
            )

    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
        existing_obj = _GLOBAL_CUSTOM_OBJECTS[registered_name]
        raise ValueError(
            f"{registered_name} has already been registered to "
            f"{existing_obj}"
        )
    if arg in _GLOBAL_CUSTOM_NAMES:
        existing_name = _GLOBAL_CUSTOM_NAMES[arg]
        raise ValueError(
            f"{arg} has already been registered to " f"{existing_name}"
        )

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg
def make_variable(name,
                  shape=None,
                  dtype=tf.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf.VariableSynchronization.AUTO,
                  aggregation=tf.compat.v1.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
    """Temporary util to create a variable (relies on `variable_scope.variable`).

    Some reuse-related technicalities prevent us from using
    `variable_scope.get_variable()` directly, so we use a subcomponent
    that has fewer constraints (`variable_scope.variable()`).

    In the longer term, it seems like a similar "default variable creator"
    method should exist in `Trackable` instead. When this happens, we can get
    rid of this temporary solution.

    TODO(fchollet): remove this method when no longer needed.

    Args:
      name: Variable name.
      shape: Variable shape.
      dtype: The type of the variable. Defaults to `float32`. Used only when
        `initializer` is callable.
      initializer: Initializer instance (callable), initializer class
        (instantiated with no arguments), or a concrete initial value.
      trainable: Whether the variable should be part of the layer's
        "trainable_variables" (e.g. weights, biases) or
        "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      caching_device: Passed to `tf.Variable`.
      validate_shape: Passed to `tf.Variable`.
      constraint: Constraint instance (callable).
      use_resource: Whether to use a `ResourceVariable`. `None` is treated
        as `True`.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: Not handled at this time.

    Returns:
      Variable instance.
    """
    if initializer is not None and not callable(initializer):
        # A concrete initial value: the variable infers its dtype from it.
        init_val = initializer
        variable_dtype = None
    else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        init_val = functools.partial(initializer, shape, dtype=dtype)
        variable_dtype = dtype.base_dtype

    use_resource = True if use_resource is None else use_resource

    # TODO(apassos,rohanj) figure out how to remove collections from here so we
    # can remove the V1.
    variable_shape = tf.TensorShape(shape)
    return tf.compat.v1.Variable(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        collections=collections,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=variable_shape if variable_shape else None)
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in layer.

    The dict is cached on `LOCAL` together with the TF2-enabled flag it was
    built under. It is rebuilt only when it is empty or that flag changed.
    """
    global LOCAL
    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Objects dict is already generated for the proper TF version.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    base_cls = base_layer.Layer

    def is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, base_cls)

    generic_utils.populate_dict_with_module_objects(
        LOCAL.ALL_OBJECTS, ALL_MODULES, obj_filter=is_layer_class)
    if v2_enabled:
        # Overwrite certain V1 objects with V2 versions.
        generic_utils.populate_dict_with_module_objects(
            LOCAL.ALL_OBJECTS, ALL_V2_MODULES, obj_filter=is_layer_class)

    # Backward-compatibility aliases: in TF 1.13 "BatchNormalizationV1" and
    # "BatchNormalizationV2" were the class names of the v1 and v2
    # BatchNormalization; map them to their canonical classes.
    LOCAL.ALL_OBJECTS.update({
        'BatchNormalizationV1': normalization.BatchNormalization,
        'BatchNormalizationV2': normalization_v2.BatchNormalization,
    })

    # Imported here to prevent circular dependencies.
    from keras import models  # pylint: disable=g-import-not-at-top
    from keras.premade.linear import LinearModel  # pylint: disable=g-import-not-at-top
    from keras.premade.wide_deep import WideDeepModel  # pylint: disable=g-import-not-at-top
    from keras.feature_column.sequence_feature_column import SequenceFeatures  # pylint: disable=g-import-not-at-top

    if v2_enabled:
        from keras.feature_column.dense_features_v2 import DenseFeatures  # pylint: disable=g-import-not-at-top
    else:
        from keras.feature_column.dense_features import DenseFeatures  # pylint: disable=g-import-not-at-top

    LOCAL.ALL_OBJECTS.update({
        'Input': input_layer.Input,
        'InputSpec': input_spec.InputSpec,
        'Functional': models.Functional,
        'Model': models.Model,
        'SequenceFeatures': SequenceFeatures,
        'Sequential': models.Sequential,
        'LinearModel': LinearModel,
        'WideDeepModel': WideDeepModel,
        'DenseFeatures': DenseFeatures,
        # Merge layers, function versions.
        'add': merge.add,
        'subtract': merge.subtract,
        'multiply': merge.multiply,
        'average': merge.average,
        'maximum': merge.maximum,
        'minimum': merge.minimum,
        'concatenate': merge.concatenate,
        'dot': merge.dot,
    })
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
      identifier: the serialized form of the object: a config dict, the string
        name of an object, an already-deserialized function, or `None`.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by mid-level library
        implementers. May be `None`, in which case string names are resolved
        only against custom objects.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.
      printable_module_name: A human-readable string representing the type of
        the object. Printed in case of exception.

    Returns:
      The deserialized object, or `None` if `identifier` is `None`.

    Raises:
      ValueError: If a string `identifier` cannot be resolved, or if
        `identifier` is of an unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
       return deserialize_keras_object(
         config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name="MyObjectType",
       )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
        if shared_object is not None:
            return shared_object

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # Global custom objects first, so that local entries override
                # them when the two dicts share a key.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            else:
                with CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as "no built-ins"
            # so an unknown name reaches the descriptive ValueError below
            # instead of failing with AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in layer.

    The dict is cached on `LOCAL` together with the TF2-enabled flag it was
    built under. It is rebuilt only when it is empty or that flag changed.
    """
    global LOCAL
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Objects dict is already generated for the proper TF version.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    base_cls = base_layer.Layer

    def is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, base_cls)

    generic_utils.populate_dict_with_module_objects(
        LOCAL.ALL_OBJECTS, ALL_MODULES, obj_filter=is_layer_class
    )
    if v2_enabled:
        # Overwrite certain V1 objects with V2 versions.
        generic_utils.populate_dict_with_module_objects(
            LOCAL.ALL_OBJECTS, ALL_V2_MODULES, obj_filter=is_layer_class
        )

    # Backward-compatibility aliases: in TF 1.13 "BatchNormalizationV1" and
    # "BatchNormalizationV2" were the class names of the v1 and v2
    # BatchNormalization; map them to their canonical classes.
    LOCAL.ALL_OBJECTS.update(
        {
            "BatchNormalizationV1": batch_normalization_v1.BatchNormalization,
            "BatchNormalizationV2": batch_normalization.BatchNormalization,
        }
    )

    # Imported here to prevent circular dependencies.
    from keras import models
    from keras.feature_column.sequence_feature_column import SequenceFeatures
    from keras.premade_models.linear import LinearModel
    from keras.premade_models.wide_deep import WideDeepModel

    if v2_enabled:
        from keras.feature_column.dense_features_v2 import DenseFeatures
    else:
        from keras.feature_column.dense_features import DenseFeatures

    LOCAL.ALL_OBJECTS.update(
        {
            "Input": input_layer.Input,
            "InputSpec": input_spec.InputSpec,
            "Functional": models.Functional,
            "Model": models.Model,
            "SequenceFeatures": SequenceFeatures,
            "Sequential": models.Sequential,
            "LinearModel": LinearModel,
            "WideDeepModel": WideDeepModel,
            "DenseFeatures": DenseFeatures,
            # Merging layers, function versions.
            "add": merging.add,
            "subtract": merging.subtract,
            "multiply": merging.multiply,
            "average": merging.average,
            "maximum": merging.maximum,
            "minimum": merging.minimum,
            "concatenate": merging.concatenate,
            "dot": merging.dot,
        }
    )
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in initializer.

    The dict is cached on `LOCAL` together with the TF2-enabled flag it was
    built under. It is rebuilt only when it is empty or that flag changed.
    Every canonical entry is also registered under its snake_case name.
    """
    global LOCAL
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Objects dict is already generated for the proper TF version.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    LOCAL.ALL_OBJECTS.update(
        {
            # Compatibility aliases (need to exist in both V1 and V2).
            "ConstantV2": initializers_v2.Constant,
            "GlorotNormalV2": initializers_v2.GlorotNormal,
            "GlorotUniformV2": initializers_v2.GlorotUniform,
            "HeNormalV2": initializers_v2.HeNormal,
            "HeUniformV2": initializers_v2.HeUniform,
            "IdentityV2": initializers_v2.Identity,
            "LecunNormalV2": initializers_v2.LecunNormal,
            "LecunUniformV2": initializers_v2.LecunUniform,
            "OnesV2": initializers_v2.Ones,
            "OrthogonalV2": initializers_v2.Orthogonal,
            "RandomNormalV2": initializers_v2.RandomNormal,
            "RandomUniformV2": initializers_v2.RandomUniform,
            "TruncatedNormalV2": initializers_v2.TruncatedNormal,
            "VarianceScalingV2": initializers_v2.VarianceScaling,
            "ZerosV2": initializers_v2.Zeros,
            # Out of an abundance of caution we also include these aliases
            # that have a non-zero probability of having been included in
            # saved configs in the past.
            "glorot_normalV2": initializers_v2.GlorotNormal,
            "glorot_uniformV2": initializers_v2.GlorotUniform,
            "he_normalV2": initializers_v2.HeNormal,
            "he_uniformV2": initializers_v2.HeUniform,
            "lecun_normalV2": initializers_v2.LecunNormal,
            "lecun_uniformV2": initializers_v2.LecunUniform,
        }
    )

    if v2_enabled:
        # For V2, entries are generated automatically based on the content of
        # initializers_v2.py.
        canonical = {}
        base_cls = initializers_v2.Initializer
        generic_utils.populate_dict_with_module_objects(
            canonical,
            [initializers_v2],
            obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls),
        )
    else:
        # V1 initializers.
        canonical = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    for class_name, initializer_obj in canonical.items():
        LOCAL.ALL_OBJECTS[class_name] = initializer_obj
        # Functional (snake_case) alias.
        LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(class_name)] = initializer_obj

    # More compatibility aliases, pointing at the snake_case entries above.
    for alias, target in (
        ("normal", "random_normal"),
        ("uniform", "random_uniform"),
        ("one", "ones"),
        ("zero", "zeros"),
    ):
        LOCAL.ALL_OBJECTS[alias] = LOCAL.ALL_OBJECTS[target]
def _get_single_variable(
    self,
    name,
    shape=None,
    dtype=tf.float32,
    initializer=None,
    regularizer=None,
    partition_info=None,
    reuse=None,
    trainable=None,
    caching_device=None,
    validate_shape=True,
    constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.compat.v1.VariableAggregation.NONE,
):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning
    components) for details.

    An existing entry in `self._vars` under `name` is returned after its shape
    and dtype are checked for compatibility. Otherwise a new `tf.Variable` is
    created inside `tf.init_scope()`, stored in `self._vars`, and passed to
    `self.add_regularizer` when `regularizer` is set.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # A non-callable initializer is a constant initial value.
    initializing_from_value = (
        initializer is not None and not callable(initializer)
    )
    if initializing_from_value and shape is not None:
        raise ValueError(
            "If initializer is a constant, do not specify shape."
        )

    dtype = tf.as_dtype(dtype)
    shape = as_shape(shape)

    if name in self._vars:
        # Reuse path: the stored variable must agree on shape and dtype.
        existing = self._vars[name]
        existing_shape = existing.get_shape()
        if not shape.is_compatible_with(existing_shape):
            raise ValueError(
                "Trying to share variable %s, but specified shape %s"
                " and found shape %s." % (name, shape, existing_shape)
            )
        if not dtype.is_compatible_with(existing.dtype):
            raise ValueError(
                "Trying to share variable %s, but specified dtype %s"
                " and found dtype %s."
                % (name, dtype.name, existing.dtype.name)
            )
        return existing

    # From here on, a new variable is being created.
    if reuse is True:  # pylint: disable=g-bool-id-comparison
        raise ValueError(
            "Variable %s does not exist, or was not created with "
            "tf.get_variable(). Did you mean to set "
            "reuse=tf.AUTO_REUSE in VarScope?" % name
        )

    # Fall back to the default initializer when none was given.
    if initializer is None:
        (
            initializer,
            initializing_from_value,
        ) = self._get_default_initializer(name=name, shape=shape, dtype=dtype)

    # Enter an init scope when creating the initializer.
    with tf.init_scope():
        if initializing_from_value:
            init_val, variable_dtype = initializer, None
        else:
            # Instantiate initializer if provided initializer is a type object.
            if tf_inspect.isclass(initializer):
                initializer = initializer()
            if not shape.is_fully_defined():
                init_val, variable_dtype = initializer, None
            else:
                partial_kwargs = {"dtype": dtype}
                accepts_partition_info = (
                    "partition_info" in tf_inspect.getargspec(initializer).args
                )
                if accepts_partition_info:
                    partial_kwargs["partition_info"] = partition_info
                init_val = functools.partial(
                    initializer, shape.as_list(), **partial_kwargs
                )
                variable_dtype = dtype.base_dtype

    # Create the variable (Always eagerly as a workaround for a strange
    # tpu / funcgraph / keras functional model interaction )
    with tf.init_scope():
        variable = tf.Variable(
            initial_value=init_val,
            name=name,
            trainable=trainable,
            caching_device=caching_device,
            dtype=variable_dtype,
            validate_shape=validate_shape,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation,
        )

    self._vars[name] = variable
    logging.vlog(
        1,
        "Created variable %s with shape %s and init %s",
        variable.name,
        format(shape),
        initializer,
    )

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
        self.add_regularizer(variable, regularizer)

    return variable