def populate_deserializable_objects():
    """Fill ``LOCAL.ALL_OBJECTS`` with every built-in layer.

    The registry is left untouched when it is already non-empty and was built
    under the current ``tf2.enabled()`` setting. Otherwise it is rebuilt from
    scratch. Later registrations overwrite earlier ones, in this order:

      1. every ``base_layer.Layer`` subclass found in ``ALL_MODULES``;
      2. when TF2 behavior is enabled, the same scan over ``ALL_V2_MODULES``;
      3. the explicit aliases below (legacy BatchNormalization names, input
         and model classes, and the function-style merge layers).
    """
    v2_enabled = tf2.enabled()

    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Registry already matches the active TF version.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    layer_base = base_layer.Layer

    def _is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, layer_base)

    scanned_modules = [ALL_MODULES]
    if v2_enabled:
        # Scanned last so that V2 objects overwrite certain V1 objects.
        scanned_modules.append(ALL_V2_MODULES)
    for modules in scanned_modules:
        generic_utils.populate_dict_with_module_objects(
            LOCAL.ALL_OBJECTS, modules, obj_filter=_is_layer_class)

    # Backward compatibility: TF 1.13 used "BatchNormalizationV1" and
    # "BatchNormalizationV2" as the class names of the v1 and v2
    # BatchNormalization layers; map them to their canonical classes.
    LOCAL.ALL_OBJECTS.update({
        'BatchNormalizationV1': batch_normalization_v1.BatchNormalization,
        'BatchNormalizationV2': batch_normalization.BatchNormalization,
    })

    # Imported here rather than at module level to prevent a circular import.
    from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

    LOCAL.ALL_OBJECTS.update({
        'Input': input_layer.Input,
        'InputSpec': input_spec.InputSpec,
        'Functional': models.Functional,
        'Model': models.Model,
        'Sequential': models.Sequential,
    })

    # Merge layers, function versions (registered under lowercase names).
    for merge_fn_name in ('add', 'subtract', 'multiply', 'average', 'maximum',
                          'minimum', 'concatenate', 'dot'):
        LOCAL.ALL_OBJECTS[merge_fn_name] = getattr(merge, merge_fn_name)
def populate_deserializable_objects():
    """Fill ``LOCAL.ALL_OBJECTS`` with every built-in layer.

    Does nothing when the registry is non-empty and was generated under the
    current ``tf2.enabled()`` setting. Otherwise the registry is rebuilt:
    ``base_layer.Layer`` subclasses from ``ALL_MODULES`` first, then (TF2 only)
    those from ``ALL_V2_MODULES`` so V2 objects overwrite V1 ones, and finally
    explicit entries for input/model classes and function-style merge layers.
    """
    v2_enabled = tf2.enabled()

    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Registry already matches the active TF version.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    layer_base = base_layer.Layer

    def _is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, layer_base)

    scanned_modules = [ALL_MODULES]
    if v2_enabled:
        # Scanned last so that V2 objects overwrite certain V1 objects.
        scanned_modules.append(ALL_V2_MODULES)
    for modules in scanned_modules:
        generic_utils.populate_dict_with_module_objects(
            LOCAL.ALL_OBJECTS, modules, obj_filter=_is_layer_class)

    # Imported here rather than at module level to prevent a circular import.
    from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

    LOCAL.ALL_OBJECTS.update({
        'Input': input_layer.Input,
        'InputSpec': input_spec.InputSpec,
        'Functional': models.Functional,
        'Model': models.Model,
        'Sequential': models.Sequential,
    })

    # Merge layers, function versions (registered under lowercase names).
    for merge_fn_name in ('add', 'subtract', 'multiply', 'average', 'maximum',
                          'minimum', 'concatenate', 'dot'):
        LOCAL.ALL_OBJECTS[merge_fn_name] = getattr(merge, merge_fn_name)
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    Args:
      identifier: The serialized form. Handled cases:
        * `None`: returns `None`.
        * a config dict: resolved to `(cls, cls_config)` via
          `class_and_config_for_serialized_keras_object`. If `cls` has
          `from_config`, that is called, and it receives the merged global and
          user custom objects when its signature accepts `custom_objects`.
          Otherwise it runs inside a `CustomObjectScope`. If `cls` has no
          `from_config`, it is called with `cls_config` as keyword arguments.
        * a string: looked up in `custom_objects`, then in the global custom
          object registry, then in `module_objects`. Classes found this way
          are instantiated with no arguments; anything else is returned as is.
        * a function: returned unchanged.
      module_objects: Optional dict of built-in objects to look names up in.
        May be `None`, in which case only custom objects are searched.
      custom_objects: Optional dict of user-provided objects; takes precedence
        over the global registry and `module_objects` for string lookups.
      printable_module_name: Human-readable kind of object, used in errors.

    Returns:
      The deserialized object, or `None` if `identifier` is `None`.

    Raises:
      ValueError: If a string name cannot be found, or if `identifier` is of
        an unsupported type.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # User-supplied entries override global ones on name clashes.
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                return cls(**cls_config)
    elif isinstance(identifier, six.string_types):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as "no built-ins"
            # so an unknown name yields the informative ValueError below
            # instead of an AttributeError on None.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))
        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
def decorator(f):
    """Wrap test method `f` so it is skipped unless TF2 behavior is enabled.

    Args:
      f: A test method (`self` is its first argument). Classes are rejected.

    Returns:
      A wrapper that calls `self.skipTest(...)` when `tf2.enabled()` is false
      and otherwise delegates to `f`. `functools.wraps` is applied so the
      wrapper keeps `f`'s name, qualname and docstring for test reporting.

    Raises:
      ValueError: If `f` is a class.
    """
    import functools  # pylint: disable=g-import-not-at-top

    if tf_inspect.isclass(f):
        raise ValueError('`run_v2_only` only supports test methods.')

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
        if not tf2.enabled():
            self.skipTest('Test is only compatible with v2')
        return f(self, *args, **kwargs)

    return decorated
def get(identifier):
    """Resolve `identifier` into an initializer.

    Args:
      identifier: `None`, a config dict, a string name, an initializer class,
        or any other callable.

    Returns:
      `None` for `None`. Dicts and strings are delegated to `deserialize`
      (strings are first converted with `str`). A class is instantiated with no
      arguments, and any other callable is returned unchanged.

    Raises:
      ValueError: If `identifier` is none of the above.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, six.string_types):
        return deserialize(str(identifier))
    if not callable(identifier):
        raise ValueError('Could not interpret initializer identifier: ' +
                         str(identifier))
    return identifier() if inspect.isclass(identifier) else identifier
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the string name of an initializer function or
    class (case-sensitive).

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...tensorflow.python.keras.initializers.initializers_v2.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Note that the
    `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...tensorflow.python.keras.initializers.initializers_v2.Ones...>

    If the `identifier` is a class, a new instance is created by calling its
    constructor with no arguments. Any other callable (e.g. an initializer
    instance) is returned unchanged.

    Args:
      identifier: `None`, a string name, a config dict, an initializer class,
        or a callable.

    Returns:
      An initializer based on the input identifier, or `None` if `identifier`
      is `None`.

    Raises:
      ValueError: If the input identifier is not a supported type. Strings and
        dicts are handed to `deserialize`, which may raise for names or
        configs it cannot resolve.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif isinstance(identifier, str):
        # Normalize `str` subclasses to a plain `str` before lookup.
        identifier = str(identifier)
        return deserialize(identifier)
    elif callable(identifier):
        if inspect.isclass(identifier):
            identifier = identifier()
        return identifier
    else:
        raise ValueError('Could not interpret initializer identifier: ' +
                         str(identifier))
def decorator(arg):
    """Registers a class with the Keras serialization framework.

    `name` and `package` are read from the enclosing scope (presumably the
    arguments of the outer registration function; confirm against caller).
    The key is `package + '>' + class_name`, where `class_name` is `name` when
    it is not `None` and `arg.__name__` otherwise. Both global registries are
    updated: key -> object and object -> key.

    Args:
      arg: The class or function to register.

    Returns:
      `arg` itself, unchanged, so this can be used as a decorator.

    Raises:
      ValueError: If `arg` is a class without a `get_config` attribute, if the
        key is already registered, or if `arg` is already registered.
    """
    if name is None:
        class_name = arg.__name__
    else:
        class_name = name
    registered_name = package + '>' + class_name

    lacks_get_config = (tf_inspect.isclass(arg) and
                        not hasattr(arg, 'get_config'))
    if lacks_get_config:
        raise ValueError(
            'Cannot register a class that does not have a get_config() method.')

    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
        previous_object = _GLOBAL_CUSTOM_OBJECTS[registered_name]
        raise ValueError('%s has already been registered to %s' %
                         (registered_name, previous_object))

    if arg in _GLOBAL_CUSTOM_NAMES:
        previous_name = _GLOBAL_CUSTOM_NAMES[arg]
        raise ValueError('%s has already been registered to %s' %
                         (arg, previous_name))

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg
def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
    """Temporary util to create a variable (relies on `variable_scope.variable`).

    Some reuse-related technicalities prevent us from using
    `variable_scope.get_variable()` directly, so we use a subcomponent
    that has fewer constraints (`variable_scope.variable()`).

    In the longer term, it seems like a similar "default variable creator"
    method should exist in `Trackable` instead. When this happens, we can get
    rid of this temporary solution.

    TODO(fchollet): remove this method when no longer needed.

    Arguments:
      name: Variable name.
      shape: Variable shape. Converted to a `TensorShape`; it is forwarded as
        `shape` only when that `TensorShape` is truthy (presumably: when its
        rank is known), otherwise `None` is passed.
      dtype: The type of the variable. Defaults to `float32`. Anything
        accepted by `dtypes.as_dtype` (e.g. a `DType` or a string such as
        `'float32'`) is normalized before use. Ignored when `initializer` is a
        constant value.
      initializer: Initializer instance (callable), initializer class (which is
        instantiated with no arguments), or a non-callable initial value that
        is used as-is.
      trainable: Whether the variable should be part of the layer's
        "trainable_variables" (e.g. weights, biases) or
        "non_trainable_variables" (e.g. BatchNorm mean, stddev). Passed to
        `VariableV1`. Per the original documentation: if the current variable
        scope is marked as non-trainable this parameter is ignored, and
        `trainable` defaults to `True` unless `synchronization` is `ON_READ`.
      caching_device: Passed to `tf.Variable`.
      validate_shape: Passed to `tf.Variable`.
      constraint: Constraint instance (callable).
      use_resource: Whether to use a `ResourceVariable`. `None` means `True`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: Not handled at this time (accepted and ignored).

    Returns:
      Variable instance.
    """
    # A non-callable initializer is treated as the literal initial value.
    initializing_from_value = (
        initializer is not None and not callable(initializer))

    if initializing_from_value:
        init_val = initializer
        variable_dtype = None
    else:
        # Normalize dtype the same way `_get_single_variable` does, so string
        # or numpy dtypes work with `.base_dtype` below. `None` is left as-is.
        if dtype is not None:
            dtype = dtypes.as_dtype(dtype)
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        init_val = functools.partial(initializer, shape, dtype=dtype)
        variable_dtype = dtype.base_dtype

    if use_resource is None:
        use_resource = True

    # TODO(apassos,rohanj) figure out how to remove collections from here so we
    # can remove the V1.
    variable_shape = tensor_shape.TensorShape(shape)
    return tf_variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        collections=collections,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=variable_shape if variable_shape else None)
def populate_deserializable_objects():
    """Rebuild ``LOCAL.ALL_OBJECTS`` with every built-in initializer.

    Nothing happens when the registry is non-empty and was built under the
    current ``tf2.enabled()`` setting. Otherwise it is rebuilt from scratch:

      * ``<Name>V2`` compatibility aliases (plus a few snake_case ``...V2``
        variants) pointing at ``initializers_v2`` classes are always
        registered, in both V1 and V2 mode;
      * under TF2, every ``initializers_v2.Initializer`` subclass found in
        ``initializers_v2`` is registered; under TF1 a fixed table of
        ``init_ops`` / ``initializers_v1`` classes is used instead. Each class
        is registered under its class name and its snake_case name;
      * finally the short aliases ``normal``, ``uniform``, ``one`` and
        ``zero`` are copied from their long forms.
    """
    v2_enabled = tf2.enabled()

    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Already generated for the active TF version.
        return

    registry = {}
    LOCAL.ALL_OBJECTS = registry
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    # Compatibility aliases (need to exist in both V1 and V2).
    v2_alias_classes = (
        'Constant', 'GlorotNormal', 'GlorotUniform', 'HeNormal', 'HeUniform',
        'Identity', 'LecunNormal', 'LecunUniform', 'Ones', 'Orthogonal',
        'RandomNormal', 'RandomUniform', 'TruncatedNormal', 'VarianceScaling',
        'Zeros',
    )
    for class_name in v2_alias_classes:
        registry[class_name + 'V2'] = getattr(initializers_v2, class_name)

    # Kept out of caution: these names may have been written into saved
    # configs in the past.
    snake_v2_aliases = (
        ('glorot_normalV2', 'GlorotNormal'),
        ('glorot_uniformV2', 'GlorotUniform'),
        ('he_normalV2', 'HeNormal'),
        ('he_uniformV2', 'HeUniform'),
        ('lecun_normalV2', 'LecunNormal'),
        ('lecun_uniformV2', 'LecunUniform'),
    )
    for alias, class_name in snake_v2_aliases:
        registry[alias] = getattr(initializers_v2, class_name)

    if v2_enabled:
        # V2 entries are discovered automatically from initializers_v2.
        initializer_base = initializers_v2.Initializer
        builtins_by_name = {}
        generic_utils.populate_dict_with_module_objects(
            builtins_by_name,
            [initializers_v2],
            obj_filter=lambda obj: (inspect.isclass(obj) and
                                    issubclass(obj, initializer_base)))
    else:
        # V1 initializers.
        builtins_by_name = {
            'Constant': init_ops.Constant,
            'GlorotNormal': init_ops.GlorotNormal,
            'GlorotUniform': init_ops.GlorotUniform,
            'Identity': init_ops.Identity,
            'Ones': init_ops.Ones,
            'Orthogonal': init_ops.Orthogonal,
            'VarianceScaling': init_ops.VarianceScaling,
            'Zeros': init_ops.Zeros,
            'HeNormal': initializers_v1.HeNormal,
            'HeUniform': initializers_v1.HeUniform,
            'LecunNormal': initializers_v1.LecunNormal,
            'LecunUniform': initializers_v1.LecunUniform,
            'RandomNormal': initializers_v1.RandomNormal,
            'RandomUniform': initializers_v1.RandomUniform,
            'TruncatedNormal': initializers_v1.TruncatedNormal,
        }

    for class_name, initializer_cls in builtins_by_name.items():
        registry[class_name] = initializer_cls
        # Functional (snake_case) alias.
        registry[generic_utils.to_snake_case(class_name)] = initializer_cls

    # More compatibility aliases.
    short_aliases = (
        ('normal', 'random_normal'),
        ('uniform', 'random_uniform'),
        ('one', 'ones'),
        ('zero', 'zeros'),
    )
    for short_name, long_name in short_aliases:
        registry[short_name] = registry[long_name]
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
      identifier: the serialized form of the object: `None`, a config dict, a
        string name, or an already-deserialized function.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers. May be `None`, in which case only custom objects are
        searched for string identifiers.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.
      printable_module_name: A human-readable string representing the type of
        the object. Printed in case of exception.

    Returns:
      The deserialized object, or `None` if `identifier` is `None`.

    Raises:
      ValueError: If a string name cannot be found, or if `identifier` is of
        an unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
       return deserialize_keras_object(
         config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name="MyObjectType",
       )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
        if shared_object is not None:
            return shared_object

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # User-supplied entries override global ones on name clashes.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            else:
                with CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as "no built-ins"
            # so an unknown name yields the informative ValueError below
            # instead of an AttributeError on None.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
def _get_single_variable(self,
                         name,
                         shape=None,
                         dtype=dtypes.float32,
                         initializer=None,
                         regularizer=None,
                         partition_info=None,
                         reuse=None,
                         trainable=None,
                         caching_device=None,
                         validate_shape=True,
                         constraint=None,
                         synchronization=vs.VariableSynchronization.AUTO,
                         aggregation=vs.VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    If `name` is already in `self._vars`, the stored variable is returned after
    checking `reuse`, shape compatibility and dtype compatibility. Otherwise a
    new `variables.Variable` is created inside `ops.init_scope()`, stored in
    `self._vars[name]`, and, if `regularizer` is given, passed to
    `self.add_regularizer`.

    Args:
      name: see get_variable.
      shape: see get_variable. Must be `None` when `initializer` is a constant
        (non-callable) value.
      dtype: see get_variable. Normalized with `dtypes.as_dtype`.
      initializer: see get_variable. A class is instantiated with no
        arguments; a non-callable value is used directly as the initial value;
        `None` falls back to `self._get_default_initializer`.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object. Forwarded to the initializer only
        when the initializer's signature has a `partition_info` argument and
        the shape is fully defined.
      reuse: see get_variable. Only the identities `True` and `False` are
        special-cased; any other value (presumably `tf.AUTO_REUSE` or `None`)
        reuses an existing variable or creates a new one.
      trainable: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable.  See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above. Raised here when a
        constant initializer is combined with `shape`, when `reuse is False`
        but the variable exists, when an existing variable's shape or dtype is
        incompatible, or when `reuse is True` but the variable does not exist.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
        initializing_from_value = True
    if shape is not None and initializing_from_value:
        raise ValueError(
            "If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = as_shape(shape)

    if name in self._vars:
        # Here we handle the case when returning an existing variable.
        if reuse is False:  # pylint: disable=g-bool-id-comparison
            err_msg = ("Variable %s already exists, disallowed."
                       " Did you mean to set reuse=True or "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)
            # ResourceVariables have no associated op, so there is no creation
            # traceback to attach to this error.
            raise ValueError(err_msg)
        found_var = self._vars[name]
        if not shape.is_compatible_with(found_var.get_shape()):
            raise ValueError(
                "Trying to share variable %s, but specified shape %s"
                " and found shape %s."
                % (name, shape, found_var.get_shape()))
        if not dtype.is_compatible_with(found_var.dtype):
            dtype_str = dtype.name
            found_type_str = found_var.dtype.name
            raise ValueError(
                "Trying to share variable %s, but specified dtype %s"
                " and found dtype %s."
                % (name, dtype_str, found_type_str))
        return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:  # pylint: disable=g-bool-id-comparison
        raise ValueError(
            "Variable %s does not exist, or was not created with "
            "tf.get_variable(). Did you mean to set "
            "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    # The default-initializer helper also reports whether its result is a
    # constant value rather than a callable.
    if initializer is None:
        initializer, initializing_from_value = self._get_default_initializer(
            name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
        if initializing_from_value:
            init_val = initializer
            variable_dtype = None
        else:
            # Instantiate initializer if provided initializer is a type object.
            if tf_inspect.isclass(initializer):
                initializer = initializer()
            if shape.is_fully_defined():
                # Only pass `partition_info` to initializers that declare it.
                if "partition_info" in tf_inspect.getargspec(
                        initializer).args:
                    init_val = functools.partial(
                        initializer,
                        shape.as_list(),
                        dtype=dtype,
                        partition_info=partition_info)
                else:
                    init_val = functools.partial(initializer,
                                                 shape.as_list(),
                                                 dtype=dtype)
                variable_dtype = dtype.base_dtype
            else:
                # Shape not fully known: the initializer itself is used as the
                # initial value and no explicit dtype is passed.
                init_val = initializer
                variable_dtype = None

    # Create the variable (Always eagerly as a workaround for a strange
    # tpu / funcgraph / keras functional model interaction )
    with ops.init_scope():
        v = variables.Variable(
            initial_value=init_val,
            name=name,
            trainable=trainable,
            caching_device=caching_device,
            dtype=variable_dtype,
            validate_shape=validate_shape,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

    # Record the new variable so later calls with the same name find it.
    self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
        self.add_regularizer(v, regularizer)

    return v