Esempio n. 1
0
def populate_deserializable_objects():
    """Fill the `LOCAL.ALL_OBJECTS` registry with every built-in layer.

    The registry is regenerated only when it is empty or when it was built
    under a different `tf2.enabled()` setting than the current one; otherwise
    the call returns without touching it.
    """
    global LOCAL
    cached = getattr(LOCAL, 'ALL_OBJECTS', None)
    if cached and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
        # The registry already matches the active TF version.
        return

    # `registry` aliases the dict stored on LOCAL, so every write below is
    # visible through LOCAL.ALL_OBJECTS immediately.
    registry = LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = tf2.enabled()

    layer_base = base_layer.Layer

    def _is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, layer_base)

    generic_utils.populate_dict_with_module_objects(
        registry, ALL_MODULES, obj_filter=_is_layer_class)

    if tf2.enabled():
        # Entries coming from the V2 modules replace their V1 counterparts.
        generic_utils.populate_dict_with_module_objects(
            registry, ALL_V2_MODULES, obj_filter=_is_layer_class)

    # Backward-compatibility aliases: TF 1.13 serialized the v1 and v2
    # BatchNormalization classes under the names "BatchNormalizationV1" and
    # "BatchNormalizationV2"; map both to their canonical classes.
    registry['BatchNormalizationV1'] = (
        batch_normalization_v1.BatchNormalization)
    registry['BatchNormalizationV2'] = batch_normalization.BatchNormalization

    # Imported here rather than at module level to avoid a circular import.
    from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

    registry.update(
        Input=input_layer.Input,
        InputSpec=input_spec.InputSpec,
        Functional=models.Functional,
        Model=models.Model,
        Sequential=models.Sequential,
    )

    # Function versions of the merge layers, registered under their own names.
    for merge_fn_name in ('add', 'subtract', 'multiply', 'average', 'maximum',
                          'minimum', 'concatenate', 'dot'):
        registry[merge_fn_name] = getattr(merge, merge_fn_name)
def populate_deserializable_objects():
    """Fill the `LOCAL.ALL_OBJECTS` registry with every built-in layer.

    Nothing happens when the registry is already populated for the current
    `tf2.enabled()` setting; otherwise it is rebuilt from scratch.
    """
    global LOCAL
    cached = getattr(LOCAL, 'ALL_OBJECTS', None)
    if cached and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
        # The registry already matches the active TF version.
        return

    # `registry` aliases the dict stored on LOCAL, so every write below is
    # visible through LOCAL.ALL_OBJECTS immediately.
    registry = LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = tf2.enabled()

    layer_base = base_layer.Layer

    def _is_layer_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, layer_base)

    generic_utils.populate_dict_with_module_objects(
        registry, ALL_MODULES, obj_filter=_is_layer_class)

    if tf2.enabled():
        # Entries coming from the V2 modules replace their V1 counterparts.
        generic_utils.populate_dict_with_module_objects(
            registry, ALL_V2_MODULES, obj_filter=_is_layer_class)

    # Imported here rather than at module level to avoid a circular import.
    from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

    registry.update(
        Input=input_layer.Input,
        InputSpec=input_spec.InputSpec,
        Functional=models.Functional,
        Model=models.Model,
        Sequential=models.Sequential,
    )

    # Function versions of the merge layers, registered under their own names.
    for merge_fn_name in ('add', 'subtract', 'multiply', 'average', 'maximum',
                          'minimum', 'concatenate', 'dot'):
        registry[merge_fn_name] = getattr(merge, merge_fn_name)
Esempio n. 3
0
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    Args:
      identifier: The serialized form. `None` is returned unchanged, a dict is
        treated as a Keras config dictionary, a string is looked up by name,
        and a function is returned as-is.
      module_objects: Optional dict of built-in objects used for the name
        lookup. When omitted, only `custom_objects` and the global custom
        object registry are searched for string identifiers.
      custom_objects: Optional dict of user-provided objects. It takes
        precedence over the global registry and `module_objects`.
      printable_module_name: Human-readable type name used in error messages.

    Returns:
      The deserialized object, or `None` when `identifier` is `None`.

    Raises:
      ValueError: If a string identifier cannot be resolved, or if
        `identifier` is of an unsupported type.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # `from_config` accepts custom objects explicitly: pass the
                # global registry overlaid with the caller's entries.
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                return cls(**cls_config)
    elif isinstance(identifier, six.string_types):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as "no built-ins"
            # so an unknown name produces the ValueError below rather than an
            # AttributeError on `None.get`.
            obj = (module_objects.get(object_name)
                   if module_objects is not None else None)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
Esempio n. 4
0
  def decorator(f):
    """Wraps test method `f` so that it is skipped unless TF2 is enabled.

    Args:
      f: The test method to wrap. Classes are rejected.

    Returns:
      A method that calls `self.skipTest(...)` when `tf2.enabled()` is false
      and otherwise forwards all arguments to `f`.

    Raises:
      ValueError: If `f` is a class.
    """
    import functools  # pylint: disable=g-import-not-at-top

    if tf_inspect.isclass(f):
      raise ValueError('`run_v2_only` only supports test methods.')

    # Preserve the wrapped method's name and docstring on the wrapper.
    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      if not tf2.enabled():
        self.skipTest('Test is only compatible with v2')

      return f(self, *args, **kwargs)

    return decorated
Esempio n. 5
0
def get(identifier):
    """Resolve `identifier` into an initializer.

    Args:
      identifier: `None`, a config dict, a string name, or a callable.

    Returns:
      `None` for `None`; the result of `deserialize` for dicts and strings;
      a new instance when given a class; any other callable unchanged.

    Raises:
      ValueError: If `identifier` is none of the supported kinds.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, six.string_types):
        return deserialize(str(identifier))
    if not callable(identifier):
        raise ValueError('Could not interpret initializer identifier: ' +
                         str(identifier))
    # Classes are instantiated with no arguments; other callables pass through.
    return identifier() if inspect.isclass(identifier) else identifier
Esempio n. 6
0
def get(identifier):
    """Retrieve a Keras initializer from an identifier.

    A string or dict identifier is handed to `deserialize`, so for those
    inputs the call behaves like the examples below. String names are matched
    case-sensitively.

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.deserialize(identifier)
    <...tensorflow.python.keras.initializers.initializers_v2.Ones...>

    A dict identifier carries `class_name` and `config`; `class_name` must map
    to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.deserialize(cfg)
    <...tensorflow.python.keras.initializers.initializers_v2.Ones...>

    When `identifier` is a class, a new instance is created by calling its
    constructor without arguments. Any other callable is returned unchanged.

    Args:
      identifier: `None`, a string name, a config dict, or a callable.

    Returns:
      The initializer resolved from `identifier`, or `None` if `identifier`
      is `None`.

    Raises:
      ValueError: If `identifier` is not of a supported type.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, str):
        return deserialize(str(identifier))
    if not callable(identifier):
        raise ValueError('Could not interpret initializer identifier: ' +
                         str(identifier))
    # Classes are instantiated with no arguments; other callables pass through.
    return identifier() if inspect.isclass(identifier) else identifier
Esempio n. 7
0
  def decorator(arg):
    """Registers `arg` with the Keras serialization framework.

    The object is recorded under `package + '>' + <name>`, where the name is
    the enclosing `name` if given and `arg.__name__` otherwise.

    Args:
      arg: The class (or other named object) to register.

    Returns:
      `arg`, unchanged.

    Raises:
      ValueError: If `arg` is a class without `get_config`, or if either the
        registered name or `arg` itself has already been registered.
    """
    if name is None:
      class_name = arg.__name__
    else:
      class_name = name
    registered_name = package + '>' + class_name

    lacks_get_config = (
        tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'))
    if lacks_get_config:
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    # Refuse duplicates in either direction: by name, then by object.
    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      existing_obj = _GLOBAL_CUSTOM_OBJECTS[registered_name]
      raise ValueError(
          '%s has already been registered to %s' %
          (registered_name, existing_obj))

    if arg in _GLOBAL_CUSTOM_NAMES:
      existing_name = _GLOBAL_CUSTOM_NAMES[arg]
      raise ValueError('%s has already been registered to %s' %
                       (arg, existing_name))

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg
Esempio n. 8
0
def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
    """Temporary util that creates a variable via `tf_variables.VariableV1`.

    Some reuse-related technicalities prevent us from using
    `variable_scope.get_variable()` directly, so a subcomponent with fewer
    constraints is used instead.

    In the longer term, a similar "default variable creator" method should
    exist in `Trackable`. When that happens this temporary solution can go.

    TODO(fchollet): remove this method when no longer needed.

    Arguments:
      name: Variable name.
      shape: Variable shape.
      dtype: The type of the variable. Defaults to `float32`.
      initializer: Initializer instance (callable), an initializer class
        (instantiated without arguments), or a non-callable value used
        directly as the initial value.
      trainable: Whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      caching_device: Passed to `tf.Variable`.
      validate_shape: Passed to `tf.Variable`.
      constraint: Constraint instance (callable).
      use_resource: Whether to use a `ResourceVariable`. `None` is treated
        as `True`.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set
        to `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize. If `synchronization` is set to `ON_READ`, `trainable`
        must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: Not handled at this time.

    Returns:
      Variable instance.
    """
    # A non-callable, non-None initializer is a concrete initial value.
    initializing_from_value = (initializer is not None
                               and not callable(initializer))

    if initializing_from_value:
        init_val, variable_dtype = initializer, None
    else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
            initializer = initializer()
        init_val = functools.partial(initializer, shape, dtype=dtype)
        variable_dtype = dtype.base_dtype

    use_resource = True if use_resource is None else use_resource

    # TODO(apassos,rohanj) figure out how to remove collections from here so we
    # can remove the V1.
    variable_shape = tensor_shape.TensorShape(shape)
    return tf_variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        collections=collections,
        synchronization=synchronization,
        aggregation=aggregation,
        # A falsy TensorShape is passed on as None.
        shape=variable_shape or None)
Esempio n. 9
0
def populate_deserializable_objects():
    """Fill the `LOCAL.ALL_OBJECTS` registry with every built-in initializer.

    Nothing happens when the registry is already populated for the current
    `tf2.enabled()` setting; otherwise it is rebuilt from scratch.
    """
    global LOCAL
    cached = getattr(LOCAL, 'ALL_OBJECTS', None)
    if cached and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
        # The registry already matches the active TF version.
        return

    # `registry` aliases the dict stored on LOCAL, so every write below is
    # visible through LOCAL.ALL_OBJECTS immediately.
    registry = LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = tf2.enabled()

    # Compatibility aliases (need to exist in both V1 and V2): each V2 class
    # is also reachable under "<ClassName>V2".
    for cls_name in ('Constant', 'GlorotNormal', 'GlorotUniform', 'HeNormal',
                     'HeUniform', 'Identity', 'LecunNormal', 'LecunUniform',
                     'Ones', 'Orthogonal', 'RandomNormal', 'RandomUniform',
                     'TruncatedNormal', 'VarianceScaling', 'Zeros'):
        registry[cls_name + 'V2'] = getattr(initializers_v2, cls_name)

    # Out of an abundance of caution we also include these aliases that have
    # a non-zero probability of having been included in saved configs in the
    # past.
    for alias, cls_name in (('glorot_normalV2', 'GlorotNormal'),
                            ('glorot_uniformV2', 'GlorotUniform'),
                            ('he_normalV2', 'HeNormal'),
                            ('he_uniformV2', 'HeUniform'),
                            ('lecun_normalV2', 'LecunNormal'),
                            ('lecun_uniformV2', 'LecunUniform')):
        registry[alias] = getattr(initializers_v2, cls_name)

    if tf2.enabled():
        # For V2, entries are generated automatically based on the content of
        # initializers_v2.py.
        initializer_base = initializers_v2.Initializer
        versioned_objs = {}
        generic_utils.populate_dict_with_module_objects(
            versioned_objs, [initializers_v2],
            obj_filter=lambda x: inspect.isclass(x) and issubclass(
                x, initializer_base))
    else:
        # V1 initializers.
        versioned_objs = {
            'Constant': init_ops.Constant,
            'GlorotNormal': init_ops.GlorotNormal,
            'GlorotUniform': init_ops.GlorotUniform,
            'Identity': init_ops.Identity,
            'Ones': init_ops.Ones,
            'Orthogonal': init_ops.Orthogonal,
            'VarianceScaling': init_ops.VarianceScaling,
            'Zeros': init_ops.Zeros,
            'HeNormal': initializers_v1.HeNormal,
            'HeUniform': initializers_v1.HeUniform,
            'LecunNormal': initializers_v1.LecunNormal,
            'LecunUniform': initializers_v1.LecunUniform,
            'RandomNormal': initializers_v1.RandomNormal,
            'RandomUniform': initializers_v1.RandomUniform,
            'TruncatedNormal': initializers_v1.TruncatedNormal,
        }

    # Register each class under its own name plus a snake_case functional
    # alias.
    for cls_name, cls in versioned_objs.items():
        registry[cls_name] = cls
        registry[generic_utils.to_snake_case(cls_name)] = cls

    # More compatibility aliases.
    for alias, canonical in (('normal', 'random_normal'),
                             ('uniform', 'random_uniform'),
                             ('one', 'ones'),
                             ('zero', 'zeros')):
        registry[alias] = registry[canonical]
Esempio n. 10
0
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
      identifier: the serialized form of the object.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers. When omitted, string identifiers are resolved only
        through `custom_objects` and the global custom object registry.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.
      printable_module_name: A human-readable string representing the type of
        the object. Printed in case of exception.

    Returns:
      The deserialized object.

    Raises:
      ValueError: If a string identifier cannot be resolved, or if
        `identifier` is of an unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
       return deserialize_keras_object(
         config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name="MyObjectType",
       )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
        if shared_object is not None:
            return shared_object

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # `from_config` accepts custom objects explicitly: pass the
                # global registry overlaid with the caller's entries.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            else:
                with CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as "no built-ins"
            # so an unknown name produces the ValueError below rather than an
            # AttributeError on `None.get`.
            obj = (module_objects.get(object_name)
                   if module_objects is not None else None)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
Esempio n. 11
0
    def _get_single_variable(self,
                             name,
                             shape=None,
                             dtype=dtypes.float32,
                             initializer=None,
                             regularizer=None,
                             partition_info=None,
                             reuse=None,
                             trainable=None,
                             caching_device=None,
                             validate_shape=True,
                             constraint=None,
                             synchronization=vs.VariableSynchronization.AUTO,
                             aggregation=vs.VariableAggregation.NONE):
        """Return the variable stored under `name`, creating it if needed.

        The variable may be a shard or an entire variable. See the
        documentation of get_variable (ignore partitioning components) for
        details.

        Args:
          name: see get_variable.
          shape: see get_variable.
          dtype: see get_variable.
          initializer: see get_variable.
          regularizer: see get_variable.
          partition_info: _PartitionInfo object.
          reuse: see get_variable.
          trainable: see get_variable.
          caching_device: see get_variable.
          validate_shape: see get_variable.
          constraint: see get_variable.
          synchronization: see get_variable.
          aggregation: see get_variable.

        Returns:
          A Variable.  See documentation of get_variable.

        Raises:
          ValueError: If a constant initializer is combined with an explicit
            shape, if the variable exists and `reuse` is `False`, if the
            requested shape or dtype is incompatible with the existing
            variable, or if the variable does not exist and `reuse` is `True`.
        """
        # A non-callable, non-None initializer is treated as a constant value.
        initializing_from_value = (initializer is not None
                                   and not callable(initializer))
        if initializing_from_value and shape is not None:
            raise ValueError(
                "If initializer is a constant, do not specify shape.")

        dtype = dtypes.as_dtype(dtype)
        shape = as_shape(shape)

        if name in self._vars:
            # Existing variable: validate the request against what is stored.
            if reuse is False:  # pylint: disable=g-bool-id-comparison
                # ResourceVariables don't have an op associated with so no
                # traceback is attached to the message.
                raise ValueError("Variable %s already exists, disallowed."
                                 " Did you mean to set reuse=True or "
                                 "reuse=tf.AUTO_REUSE in VarScope?" % name)
            existing = self._vars[name]
            existing_shape = existing.get_shape()
            if not shape.is_compatible_with(existing_shape):
                raise ValueError(
                    "Trying to share variable %s, but specified shape %s"
                    " and found shape %s." % (name, shape, existing_shape))
            if not dtype.is_compatible_with(existing.dtype):
                raise ValueError(
                    "Trying to share variable %s, but specified dtype %s"
                    " and found dtype %s." %
                    (name, dtype.name, existing.dtype.name))
            return existing

        # From here on a new variable is being created.
        if reuse is True:  # pylint: disable=g-bool-id-comparison
            raise ValueError(
                "Variable %s does not exist, or was not created with "
                "tf.get_variable(). Did you mean to set "
                "reuse=tf.AUTO_REUSE in VarScope?" % name)

        # Fall back to the default initializer when none was given.
        if initializer is None:
            initializer, initializing_from_value = (
                self._get_default_initializer(
                    name=name, shape=shape, dtype=dtype))

        # Enter an init scope when creating the initializer.
        with ops.init_scope():
            # Instantiate initializer if provided initializer is a type object.
            if not initializing_from_value and tf_inspect.isclass(initializer):
                initializer = initializer()

            # Default: hand the initializer (or constant) over unchanged.
            init_val, variable_dtype = initializer, None

            if not initializing_from_value and shape.is_fully_defined():
                # Bind shape and dtype now; forward `partition_info` only to
                # initializers that declare a parameter of that name.
                bound_kwargs = {"dtype": dtype}
                if "partition_info" in tf_inspect.getargspec(
                        initializer).args:
                    bound_kwargs["partition_info"] = partition_info
                init_val = functools.partial(initializer, shape.as_list(),
                                             **bound_kwargs)
                variable_dtype = dtype.base_dtype

        # Create the variable (Always eagerly as a workaround for a strange
        # tpu / funcgraph / keras functional model interaction )
        with ops.init_scope():
            new_var = variables.Variable(initial_value=init_val,
                                         name=name,
                                         trainable=trainable,
                                         caching_device=caching_device,
                                         dtype=variable_dtype,
                                         validate_shape=validate_shape,
                                         constraint=constraint,
                                         synchronization=synchronization,
                                         aggregation=aggregation)

        self._vars[name] = new_var
        logging.vlog(1, "Created variable %s with shape %s and init %s",
                     new_var.name, format(shape), initializer)

        # Run the regularizer if requested and save the resulting loss.
        if regularizer:
            self.add_regularizer(new_var, regularizer)

        return new_var