コード例 #1
0
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    Args:
      identifier: The serialized form of the object. One of:
        - a config dict, resolved through
          `class_and_config_for_serialized_keras_object`;
        - a string name;
        - an already-deserialized function;
        - `None`, which yields `None`.
      module_objects: Optional dict of built-in objects used to resolve string
        names. May be `None`.
      custom_objects: Optional dict of user-supplied objects. When resolving a
        string name it is consulted before the globally registered custom
        objects. It is also made available while `from_config` (or the
        factory function) runs.
      printable_module_name: Human-readable name of the object kind, used in
        error messages.

    Returns:
      The deserialized object. A string that names a class yields a new
      instance constructed with no arguments. A string that names anything
      else yields that object unchanged.

    Raises:
      ValueError: If a string name cannot be resolved, or if `identifier` is
        of an unsupported type.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        custom_objects = custom_objects or {}
        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            if 'custom_objects' in arg_spec.args:
                # `from_config` accepts the lookup table explicitly.
                # User-supplied entries override globally registered ones.
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            # Otherwise expose the custom objects through a scope.
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)

        # Then `cls` may be a function returning a class.
        # In this case, by convention, `config` holds the kwargs of the
        # function.
        with CustomObjectScope(custom_objects):
            return cls(**cls_config)
    elif isinstance(identifier, six.string_types):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None. Treat that as an empty table
            # so an unknown name reaches the ValueError below instead of
            # failing with AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
コード例 #2
0
ファイル: test_utils.py プロジェクト: bhardwajRahul/keras
    def decorator(f):
        """Applies the enclosing skip `condition`/`reason` to a test.

        For a class, this returns `unittest.skipIf(condition, reason)` applied
        to that class. For a method, this returns a wrapper that calls
        `self.skipTest(reason)` when `condition` is truthy, then calls the
        original method.
        """
        if tf_inspect.isclass(f):
            # Wrap `f`, the object this decorator actually received, rather
            # than a name from the enclosing scope. Otherwise the wrong object
            # (or None) could be decorated.
            return unittest.skipIf(condition=condition, reason=reason)(f)

        def decorated(self, *args, **kwargs):
            if condition:
                self.skipTest(reason)
            return f(self, *args, **kwargs)

        return decorated
コード例 #3
0
    def decorator(f):
        """Wraps test method `f` so it is skipped unless TF2 is enabled.

        Raises:
          ValueError: If `f` is a class instead of a test method.
        """
        if not tf_inspect.isclass(f):
            def decorated(self, *args, **kwargs):
                # Mark the test as skipped under non-v2 behavior.
                if not tf.__internal__.tf2.enabled():
                    self.skipTest('Test is only compatible with v2')
                return f(self, *args, **kwargs)

            return decorated

        raise ValueError('`run_v2_only` only supports test methods.')
コード例 #4
0
def get(identifier):
    """Resolves `identifier` into an initializer.

    - `None` passes through unchanged.
    - Dicts and strings are handed to `deserialize`.
    - A class is instantiated with no arguments.
    - Any other callable is returned as-is.

    Raises:
      ValueError: If `identifier` is none of the above.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, six.string_types):
        # Normalize any text type / str subclass to a plain `str` first.
        return deserialize(str(identifier))
    if not callable(identifier):
        raise ValueError('Could not interpret initializer identifier: ' +
                         str(identifier))
    return identifier() if inspect.isclass(identifier) else identifier
コード例 #5
0
ファイル: __init__.py プロジェクト: bhardwajRahul/keras
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the string name of an initializer function or
    class (case-sensitive).

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...keras.initializers.initializers_v2.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Note that the
    `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...keras.initializers.initializers_v2.Ones...>

    If the `identifier` is a class, this method returns a new instance of the
    class, built by calling its constructor with no arguments. Any other
    callable is returned unchanged.

    Args:
      identifier: `None`, a string name, a config dict, an initializer class,
        or a callable initializer.

    Returns:
      The initializer instance (or callable) for `identifier`, or `None` when
      `identifier` is `None`.

    Raises:
      ValueError: If the input identifier is not a supported type or is in a
        bad format.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif isinstance(identifier, str):
        # Normalize str subclasses to a plain `str` before lookup.
        return deserialize(str(identifier))
    elif callable(identifier):
        if inspect.isclass(identifier):
            identifier = identifier()
        return identifier
    else:
        raise ValueError("Could not interpret initializer identifier: " +
                         str(identifier))
コード例 #6
0
ファイル: layers.py プロジェクト: ttigong/keras
def make_variable(name,
                  shape=None,
                  dtype=tf.float32,
                  initializer=None,
                  layout=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf.VariableSynchronization.AUTO,
                  aggregation=tf.VariableAggregation.NONE,
                  partitioner=None):
    """Creates a `dtensor.DVariable`.

    This function is copied from `keras.engine.base_layer_utils`. The only
    change is the variable class: the original builds a
    `tf.compat.v1.Variable` (kept there for estimator backward compatibility).
    This version builds a `dtensor.DVariable`.

    Args:
      name: Variable name.
      shape: Variable shape; also passed positionally to the initializer.
      dtype: Variable dtype. Its `base_dtype` is used, so a `tf.DType` is
        expected. Ignored when `initializer` is a concrete value.
      initializer: A callable initializer, an initializer class (instantiated
        with no arguments), or a non-callable initial value.
      layout: Forwarded to the initializer call as `layout=`.
      trainable, caching_device, validate_shape, constraint, synchronization,
      aggregation: Passed through to `dtensor.DVariable`.
      use_resource, collections, partitioner: Accepted for signature
        compatibility with the original helper; not used here.

    Returns:
      The newly created `dtensor.DVariable`.
    """
    if initializer is not None and not callable(initializer):
        # The initializer *is* the initial value. No dtype is passed, which
        # presumably lets DVariable take it from the value (unverified here).
        init_val, variable_dtype = initializer, None
    else:
        # Accept an initializer class as well as an instance.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        init_val = functools.partial(
            initializer, shape, dtype=dtype, layout=layout)
        variable_dtype = dtype.base_dtype

    variable_shape = tf.TensorShape(shape)

    return dtensor.DVariable(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=variable_shape if variable_shape else None,
    )
コード例 #7
0
ファイル: initializers.py プロジェクト: ttigong/keras
def get(identifier):
    """Retrieve an initializer by the identifier.

    This function is copied from keras. The only intended difference is the
    `deserialize()` it dispatches dicts and strings to.

    - `None` passes through unchanged.
    - A class is instantiated with no arguments.
    - Any other callable is returned as-is.

    Raises:
      ValueError: If `identifier` is not None, a dict, a str, or callable.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, str):
        return deserialize(str(identifier))
    if not callable(identifier):
        raise ValueError('Could not interpret initializer identifier: ' +
                         str(identifier))
    return identifier() if tf_inspect.isclass(identifier) else identifier
コード例 #8
0
ファイル: merging_test.py プロジェクト: bhardwajRahul/keras
    def test_single_element(self, layer):
        # `layer` may arrive as a Layer subclass rather than an instance.
        # In that case, build it with no constructor arguments.
        if tf_inspect.isclass(layer) and issubclass(layer, keras.layers.Layer):
            layer = layer()

        # Merging a one-element list must act as the identity.
        inp = keras.layers.Input(shape=(4, 5))
        merged = layer([inp])
        self.assertListEqual(merged.shape.as_list(), [None, 4, 5])

        model = keras.models.Model(inp, merged)
        model.run_eagerly = test_utils.should_run_eagerly()

        data = np.random.random((2, 4, 5))
        prediction = model.predict(data)
        self.assertEqual(prediction.shape, (2, 4, 5))
        self.assertAllClose(prediction, data)

        # A bare tensor (not wrapped in a list) must be rejected.
        with self.assertRaisesRegex(ValueError, "called on a list"):
            layer(inp)
コード例 #9
0
ファイル: generic_utils.py プロジェクト: ohsdba/keras
  def decorator(arg):
    """Registers a class with the Keras serialization framework.

    `arg` is stored in `_GLOBAL_CUSTOM_OBJECTS` under `package + '>' + name`.
    The reverse mapping is stored in `_GLOBAL_CUSTOM_NAMES`. The name
    defaults to `arg.__name__` when the enclosing `name` is None.

    Returns:
      `arg` itself, so this can be used as a decorator.

    Raises:
      ValueError: If `arg` is a class without a `get_config` attribute, or if
        either the registered name or `arg` is already registered.
    """
    if name is None:
      class_name = arg.__name__
    else:
      class_name = name
    registered_name = package + '>' + class_name

    # Registered classes must be serializable via `get_config()`.
    is_class = tf_inspect.isclass(arg)
    if is_class and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    # Refuse duplicates in either direction of the mapping.
    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      existing = _GLOBAL_CUSTOM_OBJECTS[registered_name]
      raise ValueError('%s has already been registered to %s' %
                       (registered_name, existing))
    if arg in _GLOBAL_CUSTOM_NAMES:
      existing = _GLOBAL_CUSTOM_NAMES[arg]
      raise ValueError('%s has already been registered to %s' %
                       (arg, existing))

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg
コード例 #10
0
ファイル: generic_utils.py プロジェクト: paolodedios/keras
    def decorator(arg):
        """Registers a class with the Keras serialization framework.

        Records `arg` in `_GLOBAL_CUSTOM_OBJECTS` under `package + ">" + name`
        and records the reverse mapping in `_GLOBAL_CUSTOM_NAMES`. When the
        enclosing `name` is None, `arg.__name__` is used instead.

        Returns:
          `arg` unchanged.

        Raises:
          ValueError: If `arg` is a class lacking `get_config`, or if the name
            or `arg` has already been registered.
        """
        class_name = arg.__name__ if name is None else name
        registered_name = package + ">" + class_name

        # Only serializable classes (those exposing get_config) are allowed.
        needs_config = tf_inspect.isclass(arg)
        if needs_config and not hasattr(arg, "get_config"):
            raise ValueError(
                "Cannot register a class that does not have a "
                "get_config() method."
            )

        if registered_name in _GLOBAL_CUSTOM_OBJECTS:
            existing = _GLOBAL_CUSTOM_OBJECTS[registered_name]
            raise ValueError(
                f"{registered_name} has already been registered to {existing}"
            )

        if arg in _GLOBAL_CUSTOM_NAMES:
            existing = _GLOBAL_CUSTOM_NAMES[arg]
            raise ValueError(f"{arg} has already been registered to {existing}")

        _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
        _GLOBAL_CUSTOM_NAMES[arg] = registered_name
        return arg
コード例 #11
0
ファイル: base_layer_utils.py プロジェクト: yi212212/keras
def make_variable(name,
                  shape=None,
                  dtype=tf.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf.VariableSynchronization.AUTO,
                  aggregation=tf.compat.v1.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
    """Temporary util to create a variable (relies on `tf.compat.v1.Variable`).

    Some reuse-related technicalities prevent us from using
    `variable_scope.get_variable()` directly, so we use a subcomponent that
    has fewer constraints (`variable_scope.variable()`).

    In the longer term, a similar "default variable creator" method should
    probably exist in `Trackable` instead. When that happens, this temporary
    solution can be removed.

    TODO(fchollet): remove this method when no longer needed.

    Args:
      name: Variable name.
      shape: Variable shape; also passed positionally to the initializer.
      dtype: The type of the variable (default `tf.float32`). Its `base_dtype`
        is used. Ignored when `initializer` is a concrete value.
      initializer: Initializer instance (callable), an initializer class
        (instantiated with no arguments), or a non-callable initial value.
      trainable: Whether the variable should be part of the layer's
        "trainable_variables" (e.g. weights, biases) or
        "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the
        current variable scope is marked as non-trainable then this parameter
        is ignored and any added variables are also marked as non-trainable.
        `trainable` defaults to `True` unless `synchronization` is set to
        `ON_READ`.
      caching_device: Passed to `tf.compat.v1.Variable`.
      validate_shape: Passed to `tf.compat.v1.Variable`.
      constraint: Constraint instance (callable).
      use_resource: Whether to use a `ResourceVariable`. `None` is treated as
        `True`.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      synchronization: Indicates when a distributed variable will be
        synchronized. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set
        to `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize. If `synchronization` is set to `ON_READ`, `trainable`
        must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: Not handled at this time.

    Returns:
      Variable instance.
    """
    if initializer is not None and not callable(initializer):
        # A non-callable initializer is the initial value itself.
        init_val = initializer
        variable_dtype = None
    else:
        # Accept an initializer class as well as an instance.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        init_val = functools.partial(initializer, shape, dtype=dtype)
        variable_dtype = dtype.base_dtype

    # Resource variables unless the caller explicitly chose otherwise.
    use_resource = True if use_resource is None else use_resource

    # TODO(apassos,rohanj) figure out how to remove collections from here so we
    # can remove the V1.
    variable_shape = tf.TensorShape(shape)
    return tf.compat.v1.Variable(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        collections=collections,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=variable_shape if variable_shape else None,
    )
コード例 #12
0
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in layer.

    The table is cached on `LOCAL`. It is rebuilt only when it is empty or was
    generated under a different `tf.__internal__.tf2.enabled()` setting than
    the current one.
    """
    global LOCAL
    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Already generated for the active TF version.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    base_cls = base_layer.Layer

    def is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, base_cls)

    generic_utils.populate_dict_with_module_objects(
        LOCAL.ALL_OBJECTS, ALL_MODULES, obj_filter=is_layer_class)
    if v2_enabled:
        # Overwrite certain V1 objects with their V2 versions.
        generic_utils.populate_dict_with_module_objects(
            LOCAL.ALL_OBJECTS, ALL_V2_MODULES, obj_filter=is_layer_class)

    # Backward-compatibility aliases: TF 1.13 used "BatchNormalizationV1" and
    # "BatchNormalizationV2" as the class names of the v1 and v2
    # BatchNormalization. Map them to the canonical classes.
    LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = (
        normalization.BatchNormalization)
    LOCAL.ALL_OBJECTS['BatchNormalizationV2'] = (
        normalization_v2.BatchNormalization)

    # Deferred imports prevent circular dependencies.
    from keras import models  # pylint: disable=g-import-not-at-top
    from keras.premade.linear import LinearModel  # pylint: disable=g-import-not-at-top
    from keras.premade.wide_deep import WideDeepModel  # pylint: disable=g-import-not-at-top
    from keras.feature_column.sequence_feature_column import SequenceFeatures  # pylint: disable=g-import-not-at-top

    LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
    LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec
    LOCAL.ALL_OBJECTS['Functional'] = models.Functional
    LOCAL.ALL_OBJECTS['Model'] = models.Model
    LOCAL.ALL_OBJECTS['SequenceFeatures'] = SequenceFeatures
    LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
    LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel
    LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel

    # Pick the DenseFeatures implementation matching the TF version.
    if v2_enabled:
        from keras.feature_column.dense_features_v2 import DenseFeatures  # pylint: disable=g-import-not-at-top
    else:
        from keras.feature_column.dense_features import DenseFeatures  # pylint: disable=g-import-not-at-top
    LOCAL.ALL_OBJECTS['DenseFeatures'] = DenseFeatures

    # Merge layers, function versions (same key as the `merge` attribute).
    for fn_name in ('add', 'subtract', 'multiply', 'average', 'maximum',
                    'minimum', 'concatenate', 'dot'):
        LOCAL.ALL_OBJECTS[fn_name] = getattr(merge, fn_name)
コード例 #13
0
ファイル: generic_utils.py プロジェクト: yi212212/keras
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
      identifier: The serialized form of the object: a config dict, a string
        name, or an already-deserialized function. `None` returns `None`.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers. If `None`, string identifiers are resolved only against
        `custom_objects` and the globally registered custom objects.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user. When
        resolving string names, these entries take precedence over globally
        registered custom objects.
      printable_module_name: A human-readable string representing the type of
        the object. Printed in case of exception.

    Returns:
      The deserialized object. A string naming a class yields a new instance
      constructed with no arguments. A string naming anything else yields
      that object unchanged.

    Raises:
      ValueError: If a string name cannot be resolved, or if `identifier` is
        of an unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
       return deserialize_keras_object(
         config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name="MyObjectType",
       )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
        if shared_object is not None:
            return shared_object

        custom_objects = custom_objects or {}
        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            if 'custom_objects' in arg_spec.args:
                # Pass the lookup table explicitly. User-supplied entries
                # override globally registered ones.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects={**_GLOBAL_CUSTOM_OBJECTS,
                                    **custom_objects})
            else:
                with CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # In this case, by convention, `config` holds the kwargs of the
            # function.
            with CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None. Treat that as an empty table
            # so an unknown name reaches the ValueError below instead of
            # failing with AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
コード例 #14
0
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in layer.

    The table is cached on `LOCAL`. It is rebuilt only when it is empty or was
    generated under a different `tf.__internal__.tf2.enabled()` setting than
    the current one.
    """
    global LOCAL
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Already generated for the active TF version.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    base_cls = base_layer.Layer

    def is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, base_cls)

    generic_utils.populate_dict_with_module_objects(
        LOCAL.ALL_OBJECTS, ALL_MODULES, obj_filter=is_layer_class
    )
    if v2_enabled:
        # Overwrite certain V1 objects with their V2 versions.
        generic_utils.populate_dict_with_module_objects(
            LOCAL.ALL_OBJECTS, ALL_V2_MODULES, obj_filter=is_layer_class
        )

    # Backward-compatibility aliases: TF 1.13 used "BatchNormalizationV1" and
    # "BatchNormalizationV2" as the class names of the v1 and v2
    # BatchNormalization. Map them to the canonical classes.
    LOCAL.ALL_OBJECTS["BatchNormalizationV1"] = (
        batch_normalization_v1.BatchNormalization
    )
    LOCAL.ALL_OBJECTS["BatchNormalizationV2"] = (
        batch_normalization.BatchNormalization
    )

    # Deferred imports prevent circular dependencies.
    from keras import models
    from keras.feature_column.sequence_feature_column import SequenceFeatures
    from keras.premade_models.linear import LinearModel
    from keras.premade_models.wide_deep import WideDeepModel

    LOCAL.ALL_OBJECTS["Input"] = input_layer.Input
    LOCAL.ALL_OBJECTS["InputSpec"] = input_spec.InputSpec
    LOCAL.ALL_OBJECTS["Functional"] = models.Functional
    LOCAL.ALL_OBJECTS["Model"] = models.Model
    LOCAL.ALL_OBJECTS["SequenceFeatures"] = SequenceFeatures
    LOCAL.ALL_OBJECTS["Sequential"] = models.Sequential
    LOCAL.ALL_OBJECTS["LinearModel"] = LinearModel
    LOCAL.ALL_OBJECTS["WideDeepModel"] = WideDeepModel

    # Pick the DenseFeatures implementation matching the TF version.
    if v2_enabled:
        from keras.feature_column.dense_features_v2 import DenseFeatures
    else:
        from keras.feature_column.dense_features import DenseFeatures
    LOCAL.ALL_OBJECTS["DenseFeatures"] = DenseFeatures

    # Merging layers, function versions (key matches the `merging` attribute).
    for fn_name in (
        "add",
        "subtract",
        "multiply",
        "average",
        "maximum",
        "minimum",
        "concatenate",
        "dot",
    ):
        LOCAL.ALL_OBJECTS[fn_name] = getattr(merging, fn_name)
コード例 #15
0
ファイル: __init__.py プロジェクト: bhardwajRahul/keras
def populate_deserializable_objects():
    """Build (or refresh) the `LOCAL.ALL_OBJECTS` initializer registry.

    The registry maps serialized names to initializer classes or callables.
    It is rebuilt only when it is empty, or when it was generated under a
    different TF2-enabled setting than the current one. After a rebuild,
    `LOCAL.GENERATED_WITH_V2` records the setting it was built with.
    """
    global LOCAL
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    # Nothing to do when the cached registry already matches the active
    # TF version.
    if LOCAL.ALL_OBJECTS and (
        LOCAL.GENERATED_WITH_V2 == tf.__internal__.tf2.enabled()
    ):
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = tf.__internal__.tf2.enabled()
    registry = LOCAL.ALL_OBJECTS

    # Compatibility aliases; registered under both V1 and V2.
    registry.update(
        {
            "ConstantV2": initializers_v2.Constant,
            "GlorotNormalV2": initializers_v2.GlorotNormal,
            "GlorotUniformV2": initializers_v2.GlorotUniform,
            "HeNormalV2": initializers_v2.HeNormal,
            "HeUniformV2": initializers_v2.HeUniform,
            "IdentityV2": initializers_v2.Identity,
            "LecunNormalV2": initializers_v2.LecunNormal,
            "LecunUniformV2": initializers_v2.LecunUniform,
            "OnesV2": initializers_v2.Ones,
            "OrthogonalV2": initializers_v2.Orthogonal,
            "RandomNormalV2": initializers_v2.RandomNormal,
            "RandomUniformV2": initializers_v2.RandomUniform,
            "TruncatedNormalV2": initializers_v2.TruncatedNormal,
            "VarianceScalingV2": initializers_v2.VarianceScaling,
            "ZerosV2": initializers_v2.Zeros,
        }
    )

    # Lower-case "...V2" spellings kept out of caution, in case they were
    # ever written into saved configs.
    registry.update(
        {
            "glorot_normalV2": initializers_v2.GlorotNormal,
            "glorot_uniformV2": initializers_v2.GlorotUniform,
            "he_normalV2": initializers_v2.HeNormal,
            "he_uniformV2": initializers_v2.HeUniform,
            "lecun_normalV2": initializers_v2.LecunNormal,
            "lecun_uniformV2": initializers_v2.LecunUniform,
        }
    )

    if tf.__internal__.tf2.enabled():
        # V2: collect every Initializer subclass found in initializers_v2.
        base_cls = initializers_v2.Initializer

        def _is_initializer_class(obj):
            return inspect.isclass(obj) and issubclass(obj, base_cls)

        discovered = {}
        generic_utils.populate_dict_with_module_objects(
            discovered,
            [initializers_v2],
            obj_filter=_is_initializer_class,
        )
    else:
        # V1: a fixed table of tf.compat.v1 initializers plus the
        # initializers_v1 classes.
        discovered = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    # Register each entry under its own name and its snake_case
    # (functional) alias.
    for class_name, initializer in discovered.items():
        registry[class_name] = initializer
        registry[generic_utils.to_snake_case(class_name)] = initializer

    # Short legacy aliases that point at snake_case entries registered above.
    for alias, target in (
        ("normal", "random_normal"),
        ("uniform", "random_uniform"),
        ("one", "ones"),
        ("zero", "zeros"),
    ):
        registry[alias] = registry[target]
コード例 #16
0
    def _get_single_variable(
        self,
        name,
        shape=None,
        dtype=tf.float32,
        initializer=None,
        regularizer=None,
        partition_info=None,
        reuse=None,
        trainable=None,
        caching_device=None,
        validate_shape=True,
        constraint=None,
        synchronization=tf.VariableSynchronization.AUTO,
        aggregation=tf.compat.v1.VariableAggregation.NONE,
    ):
        """Get or create a single Variable (a shard or an entire variable).

        Behaves like get_variable above, minus the partitioning logic.

        If `name` is already stored in `self._vars`, that variable is returned
        once its shape and dtype have been checked against the request.
        Otherwise a new `tf.Variable` is created under `tf.init_scope()` and
        stored in `self._vars`. When `regularizer` is given, it is attached via
        `self.add_regularizer`.

        Args:
          name: see get_variable.
          shape: see get_variable.
          dtype: see get_variable.
          initializer: see get_variable. A non-callable value is used as the
            initial value itself. A class is instantiated with no arguments.
          regularizer: see get_variable.
          partition_info: _PartitionInfo object, forwarded to the initializer
            only if its signature has a `partition_info` argument.
          reuse: see get_variable.
          trainable: see get_variable.
          caching_device: see get_variable.
          validate_shape: see get_variable.
          constraint: see get_variable.
          synchronization: see get_variable.
          aggregation: see get_variable.

        Returns:
          The existing or newly created Variable.

        Raises:
          ValueError: if `initializer` is a constant and `shape` is given; if
            a stored variable's shape or dtype is incompatible with the
            request; or if `reuse` is True and no variable named `name`
            exists.
        """
        # A non-callable initializer is the initial value itself.
        initializing_from_value = (
            initializer is not None and not callable(initializer)
        )
        if initializing_from_value and shape is not None:
            raise ValueError(
                "If initializer is a constant, do not specify shape.")

        dtype = tf.as_dtype(dtype)
        shape = as_shape(shape)

        # Sharing path: return the stored variable if it is compatible.
        if name in self._vars:
            existing = self._vars[name]
            existing_shape = existing.get_shape()
            if not shape.is_compatible_with(existing_shape):
                raise ValueError(
                    "Trying to share variable %s, but specified shape %s"
                    " and found shape %s." %
                    (name, shape, existing_shape))
            if not dtype.is_compatible_with(existing.dtype):
                raise ValueError(
                    "Trying to share variable %s, but specified dtype %s"
                    " and found dtype %s." %
                    (name, dtype.name, existing.dtype.name))
            return existing

        # From here on a brand-new variable has to be created. `reuse` may
        # also be a non-bool sentinel, so an identity check is intended.
        if reuse is True:  # pylint: disable=g-bool-id-comparison
            raise ValueError(
                "Variable %s does not exist, or was not created with "
                "tf.get_variable(). Did you mean to set "
                "reuse=tf.AUTO_REUSE in VarScope?" % name)

        if initializer is None:
            initializer, initializing_from_value = (
                self._get_default_initializer(
                    name=name, shape=shape, dtype=dtype))

        # The initializer is resolved inside an init scope.
        with tf.init_scope():
            if initializing_from_value:
                # Use the value as-is; no explicit dtype is passed to
                # tf.Variable.
                init_val, variable_dtype = initializer, None
            else:
                # Accept an initializer class as well as an instance.
                if tf_inspect.isclass(initializer):
                    initializer = initializer()
                if not shape.is_fully_defined():
                    init_val, variable_dtype = initializer, None
                else:
                    call_kwargs = {"dtype": dtype}
                    if ("partition_info"
                            in tf_inspect.getargspec(initializer).args):
                        call_kwargs["partition_info"] = partition_info
                    init_val = functools.partial(
                        initializer, shape.as_list(), **call_kwargs)
                    variable_dtype = dtype.base_dtype

        # Create the variable under init_scope (always eagerly), as a
        # workaround for a strange tpu / funcgraph / keras functional model
        # interaction.
        with tf.init_scope():
            new_var = tf.Variable(
                initial_value=init_val,
                name=name,
                trainable=trainable,
                caching_device=caching_device,
                dtype=variable_dtype,
                validate_shape=validate_shape,
                constraint=constraint,
                synchronization=synchronization,
                aggregation=aggregation,
            )

        self._vars[name] = new_var
        logging.vlog(
            1,
            "Created variable %s with shape %s and init %s",
            new_var.name,
            format(shape),
            initializer,
        )

        # Attach the regularizer, if one was requested.
        if regularizer:
            self.add_regularizer(new_var, regularizer)

        return new_var