Example #1
 def testNamelessStore(self):
   vs = variable_scope._get_default_variable_store()
   vs.get_variable("v1", [2])
   vs.get_variable("v2", [2])
   expected_names = ["%s:0" % name for name in ["v1", "v2"]]
   self.assertEqual(set(expected_names),
                    set([v.name for v in vs._vars.values()]))
Example #2
def _default_getter(name, shape, dtype, initializer=None,
                    partition_info=None, **kwargs):
  """A pared-down version of get_variable which does not reuse variables."""
  dtype = dtypes.as_dtype(dtype)
  shape_object = tensor_shape.as_shape(shape)
  with ops.init_scope():
    if initializer is None:
      initializer, initializing_from_value = (
          variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
              name=name, shape=shape_object, dtype=dtype))
    else:
      initializing_from_value = not callable(initializer)
    # Same logic as get_variable
    variable_dtype = dtype.base_dtype
    if initializing_from_value:
      if shape is not None:
        raise ValueError("If initializer is a constant, do not specify shape.")
      initial_value = initializer
    else:
      # Instantiate initializer if provided initializer is a type object.
      if isinstance(initializer, type(init_ops.Initializer)):
        initializer = initializer(dtype=dtype)
      def initial_value():
        return initializer(
            shape_object.as_list(), dtype=dtype, partition_info=partition_info)
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value,
        name=name,
        dtype=variable_dtype,
        **kwargs
    )
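A short usage sketch for the getter above (not part of the original source): it assumes `_default_getter` is importable together with the TensorFlow modules the snippet already references, and the variable names are purely illustrative.

import tensorflow.compat.v1 as tf

# Callable initializer: the shape is passed explicitly and the initial value
# is built lazily inside ops.init_scope().
w = _default_getter("w", shape=[2, 3], dtype=tf.float32,
                    initializer=tf.glorot_uniform_initializer())

# Constant initial value: shape must be None, otherwise the getter raises the
# ValueError shown above.
b = _default_getter("b", shape=None, dtype=tf.float32,
                    initializer=tf.constant([0.0, 0.0, 0.0]))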
Example #3
 def getvar(self,
            getter,
            name,
            shape=None,
            dtype=None,
            initializer=None,
            trainable=True,
            collections=None,
            **kwargs):
   """A custom variable getter."""
   # Here, we switch the default graph to the outer graph and ask the
   # variable scope in which the function is defined to give us the
   # variable. The variable is stashed in extra_vars and returned to
   # the caller.
   #
   # We capture these variables so that the variable definition is
   # hoisted upward to the outermost graph.
   with self._outer_graph.as_default():
     # pylint: disable=protected-access
     var = self._vscope.get_variable(
         vs._get_default_variable_store(),
         name,
         shape=shape,
         dtype=dtype,
         initializer=initializer,
         trainable=trainable,
         collections=collections)
     self.extra_vars.append(var)
     return var
Example #4
def _default_initializer(name, shape, dtype):
  """The default initializer for variables."""
  # pylint: disable=protected-access
  store = variable_scope._get_default_variable_store()
  initializer = store._get_default_initializer(name, shape=shape, dtype=dtype)
  # pylint: enable=protected-access
  return initializer[0]
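A hypothetical usage sketch (not from the original source), assuming `_default_initializer` and the `dtypes` module used elsewhere on this page are importable; the variable name is illustrative.

from tensorflow.python.framework import dtypes

# For floating-point dtypes the store's default is typically a Glorot-uniform
# initializer instance; calling it produces an initial-value tensor.
init_fn = _default_initializer("dense/kernel", shape=[4, 4],
                               dtype=dtypes.float32)
initial_value = init_fn(shape=[4, 4], dtype=dtypes.float32)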
Example #5
 def __init__(self, variable_scope_name):
   self._variable_scope_name = variable_scope_name
   default = variable_scope._get_default_variable_store()  # pylint: disable=protected-access
   if default._store_eager_variables:  # pylint: disable=protected-access
     self._eager_variable_store = variable_scope.EagerVariableStore(default)
   else:
     self._eager_variable_store = variable_scope.EagerVariableStore()
Example #6
 def testNameExists(self):
   vs = variable_scope._get_default_variable_store()
   # No check by default, so we can both create and get existing names.
   v = vs.get_variable("v", [1])
   v1 = vs.get_variable("v", [1])
   assert v == v1
   # When reuse is False, we fail when variables are already there.
   vs.get_variable("w", [1], reuse=False)  # That's ok.
   with self.assertRaises(ValueError):
     vs.get_variable("v", [1], reuse=False)  # That fails.
   # When reuse is True, we fail when variables are new.
   vs.get_variable("v", [1], reuse=True)  # That's ok.
   with self.assertRaises(ValueError):
     vs.get_variable("u", [1], reuse=True)  # That fails.
Example #7
 def getvar(self, name, shape, dtype, initializer, **kwargs):
   """A custom variable getter."""
   # TODO(zhifengc): We probably need to support other 10-ish options
   # vs.get_variable supports.
   #
   # Here, we switch the default graph to the outer graph and ask the
   # variable scope in which the function is defined to give us the
   # variable. The variable is stashed in extra_vars and returned to
   # the caller.
   #
   # We capture these variables so that the variable definition is
   # hoisted upward to the outermost graph.
   with self._outer_graph.as_default():
     # pylint: disable=protected-access
     var = self._vscope.get_variable(vs._get_default_variable_store(), name,
                                     shape, dtype, initializer)
     self.extra_vars.append(var)
     return var
Example #8
  def __init__(self, name, func, create_scope_now=False, unique_name=None,
               custom_getter=None):
    """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The
        name will be made unique by appending `_N` to it (see how
        `tf.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is passed
        through much lower level code, and you want to be sure of the scope
        name without knowing exactly where it will be first called. If set to
        True, the scope will be created in the constructor, and all subsequent
        times in __call__, leading to a trailing numeral being added to the
        names of all created Tensors. If set to False, the scope will be created
        at the first call location.
      unique_name: When used, it overrides name_ and is not made unique. If a
        template of the same scope/unique_name already exists and reuse is
        false, an error is raised. Defaults to None.
      custom_getter: optional custom getter to pass to variable_scope()

    Raises:
      RuntimeError: if eager mode is not enabled.
      ValueError: if the name is None or unique_name is provided.
    """
    if not context.in_eager_mode():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.Template for graph construction".
          format(type(self)))
    if unique_name:
      raise ValueError("unique_name cannot be used in eager mode.")
    super(EagerTemplate, self).__init__(name, func, create_scope_now,
                                        unique_name, custom_getter)
    # Create an eager variable store only if the current variable store cannot
    # store eager variables. This should allow for correct nesting.
    default_vstore = variable_scope._get_default_variable_store()  # pylint: disable=protected-access
    if default_vstore._store_eager_variables:  # pylint: disable=protected-access
      raise ValueError("Nested EagerTemaplates are not currently supported.")
    else:
      self._eager_variable_store = variable_scope.EagerVariableStore()
Example #9
 def getvar(
     self,
     getter,
     name,
     shape=None,
     dtype=None,
     initializer=None,
     reuse=None,
     trainable=True,
     collections=None,  # pylint: disable=redefined-outer-name
     use_resource=None,
     **kwargs):
   """A custom variable getter."""
   # Here, we switch the default graph to the outer graph and ask the
   # variable scope in which the function is defined to give us the
   # variable. The variable is stashed in extra_vars and returned to
   # the caller.
   #
   # We capture these variables so that the variable definition is
   # hoisted upward to the outermost graph.
   with self._outer_graph.as_default():
     # pylint: disable=protected-access
     var = self._vscope.get_variable(
         vs._get_default_variable_store(),
         name,
         shape=shape,
         dtype=dtype,
         initializer=initializer,
         reuse=reuse,
         trainable=trainable,
         collections=collections,
         use_resource=use_resource)
     self.extra_vars.append(var)
     if (isinstance(var, resource_variable_ops.ResourceVariable) and
         self._capture_resource_var_by_value):
       # For resource-based variables read the variable outside the function
       # and pass in the value. This ensures that the function is pure and
       # differentiable. TODO(apassos) this may have performance problems if
       # the function will only do embedding lookups on the variable.
       return var.value()
     return var
Example #10
 def __init__(self, name=None):
   if isinstance(name, variable_scope.VariableScope):
     raise ValueError("VariableScopes are not valid Network names.")
   if name is not None and "/" in name:
     raise ValueError(
         "Forward slashes ('/') are not allowed in Network names.")
   super(Network, self).__init__(name=name)
   self._layers = []
   self._sub_layer_name_uids = collections.defaultdict(int)
   # Initially None, but set to False for networks which are first built as
   # top-level.
   self._first_parent = None  # A weak reference to our first parent.
   self._non_network_sublayers = []
   self._owned_layers = {}
   # The scope to use if we end up without a parent.
   self._default_parent_variable_scope = variable_scope.get_variable_scope()
   # Hold on to the variable scope counts from init to check whether a scope
   # with the name we want was ever created in our parent scope. Without this
   # check we might have name collisions if the parent scope on init gets
   # closed before build is called.
   self._variable_scope_counts_on_init = (
       variable_scope._get_default_variable_store().variable_scopes_count)
Example #11
  def __init__(self, name=None):
    """Configure the `Network`.

    Args:
      name: The name to use for this `Network`. If specified, it must be unique
        in the context where this `Network` is first
         (1) added to another `Network` (in which case it must not share a name
           with other `Layers` added to that `Network`), or
         (2) built/called (in which case no other 'top-level' `Network`s may
          share this name).
        If unspecified or None, the `Network` will be named using its class
        name, with a number appended if necessary for uniqueness (e.g. MyNetwork
        -> 'my_network_1').

    Raises:
      ValueError: If `name` is not valid. Note that some naming errors will
        instead be raised when the `Network` is called.
    """
    if isinstance(name, variable_scope.VariableScope):
      raise ValueError("VariableScopes are not valid Network names.")
    if name is not None and "/" in name:
      raise ValueError(
          "Forward slashes ('/') are not allowed in Network names.")
    super(Network, self).__init__(name=name)
    self._layers = []
    self._sub_layer_name_uids = collections.defaultdict(int)
    # Initially None, but set to False for networks which are first built as
    # top-level.
    self._first_parent = None  # A weak reference to our first parent.
    self._non_network_sublayers = []
    self._owned_layers = {}
    # The scope to use if we end up without a parent.
    self._default_parent_variable_scope = variable_scope.get_variable_scope()
    # Hold on to the variable scope counts from init to check whether a scope
    # with the name we want was ever created in our parent scope. Without this
    # check we might have name collisions if the parent scope on init gets
    # closed before build is called.
    self._variable_scope_counts_on_init = (
        variable_scope._get_default_variable_store().variable_scopes_count)
Example #12
File: tputil.py Project: shawwn/ml-notes
def tf_varstore():
    #tf.get_collection(('__variable_store',))[0]
    return vs._get_default_variable_store()
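A small hedged sketch of how the returned store might be inspected; `_vars` is the protected name-to-variable mapping that several other examples on this page also use.

store = tf_varstore()
# List every variable tracked by the default store, in name order.
for name, var in sorted(store._vars.items()):  # pylint: disable=protected-access
    print(name, var.shape)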
Example #13
 def testGetVar(self):
   vs = variable_scope._get_default_variable_store()
   v = vs.get_variable("v", [1])
   v1 = vs.get_variable("v", [1])
   assert v == v1
Example #14
def _get_variable_dict_from_varstore():
    var_dict = variable_scope._get_default_variable_store()._vars  # pylint: disable=protected-access
    sorted_var_dict = collections.OrderedDict(
        sorted(var_dict.items(), key=lambda t: t[0]))
    return sorted_var_dict
Example #15
def get_variable(
    name,  # unique
    key_dtype=dtypes.int64,
    value_dtype=dtypes.float32,
    dim=1,
    devices=None,
    partitioner=default_partition_fn,
    shared_name="get_variable",
    initializer=None,
    trainable=True,
    checkpoint=True,
    init_size=0,
    restrict_policy=None,
):
    """Gets an `Variable` object with this name if it exists,
         or create a new one.

    Args:
      name: A unique name for the `Variable`.
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      dim: the length of the value array for each key.
      devices: the list of devices holding the tables.
        One table will be created on each device.
      partitioner: partition function for keys;
        returns the partition index for each key.

      Example partition func:
      ```python
      def default_partition_fn(keys, shard_num):
        return tf.cast(keys % shard_num, dtype=tf.int32)
      ```
      shared_name: Not used.
      initializer: The value to use if a key is missing in the hash table,
        which can be a python number, a numpy array or a `tf.initializer`
        instance. If initializer is `None` (the default), `0` will be used.
      trainable: If True, the variable will be treated as a trainable
        Variable and added to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      checkpoint: if True, the contents of the SparseVariable are
        saved to and restored from checkpoints.
        If `shared_name` is empty for a checkpointed table,
        it is shared using the table node name.
      init_size: initial size for the Variable; the initial size of each
        hash table will be int(init_size / N), where N is the number of
        devices.
      restrict_policy: a restrict policy specifying the rule used to
        restrict the size of the variable. If the variable is updated by an
        optimizer during training, the optimizer's sparse slot variables are
        restricted as well.

    Returns:
      A `Variable` object.
    """
    var_ = None
    scope = variable_scope.get_variable_scope()
    scope_store = variable_scope._get_default_variable_store()
    full_name = scope.name + "/" + name if scope.name else name
    if full_name in scope_store._vars:
        if scope.reuse is False:
            err_msg = ("Variable %s already exists, disallowed."
                       " Did you mean to set reuse=True or "
                       "reuse=tf.AUTO_REUSE in VarScope?" % full_name)

            raise ValueError(err_msg)
    else:
        var_ = Variable(
            key_dtype=key_dtype,
            value_dtype=value_dtype,
            dim=dim,
            devices=devices,
            partitioner=partitioner,
            shared_name=shared_name,
            name=full_name,
            initializer=initializer,
            trainable=trainable,
            checkpoint=checkpoint,
            init_size=init_size,
            restrict_policy=restrict_policy,
        )
        scope_store._vars[full_name] = var_
    return scope_store._vars[full_name]
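A minimal usage sketch, assuming this `get_variable` is the dynamic-embedding variant defined above (e.g. from TensorFlow Recommenders-Addons); the scope name and parameter values are illustrative.

import tensorflow as tf

# Under AUTO_REUSE a second call with the same name returns the existing table
# instead of raising the "already exists" error above.
with tf.compat.v1.variable_scope("embedding", reuse=tf.compat.v1.AUTO_REUSE):
    user_emb = get_variable(
        "user_emb",
        key_dtype=tf.int64,
        value_dtype=tf.float32,
        dim=8,
        initializer=tf.compat.v1.random_normal_initializer(0.0, 0.01),
    )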
Example #16
def get_or_create_layer(name, create_layer_method):
    """Use this method to track nested keras models in a shim-decorated method.

    This method can be used within a `tf.keras.Layer`'s methods decorated by
    the `track_tf1_style_variables` shim, to additionally track inner keras Model
    objects created within the same method. The inner model's variables and
    losses will be accessible via the outer model's `variables` and `losses`
    attributes.

    This enables tracking of inner keras models using TF2 behaviors, with
    minimal changes to existing TF1-style code.

    Example:

    ```python
    class NestedLayer(tf.keras.layers.Layer):

      def __init__(self, units, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.units = units

      def build_model(self):
        inp = tf.keras.Input(shape=(5, 5))
        dense_layer = tf.keras.layers.Dense(
            10, name="dense", kernel_regularizer="l2",
            kernel_initializer=tf.compat.v1.ones_initializer())
        model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
        return model

      @tf.compat.v1.keras.utils.track_tf1_style_variables
      def call(self, inputs):
        model = tf.compat.v1.keras.utils.get_or_create_layer(
            "dense_model", self.build_model)
        return model(inputs)
    ```
    The inner model creation should be confined to its own zero-arg function,
    which should be passed into this method. In TF1, this method will
    immediately create and return the desired model, without any tracking.

    Args:
      name: A name to give the nested layer to track.
      create_layer_method: a Callable that takes no args and returns the nested
      layer.

    Returns:
      The created layer.
    """
    store = vs._get_default_variable_store()
    if not isinstance(store, _EagerVariableStore):
        if not tf.compat.v1.executing_eagerly_outside_functions():
            # tf1 case; just create and return layer
            return create_layer_method()
        else:
            raise ValueError(
                "Tried to call get_or_create_layer in eager mode from a method "
                "notdecorated with "
                "@tf.compat.v1.keras.utils.track_tf1_style_variables."
            )
    vs_name = tf.compat.v1.get_variable_scope().name
    name = f"{vs_name}/{name}"
    return store.get_or_create_layer(name, create_layer_method)
Example #17
 def testGetVar(self):
     vs = variable_scope._get_default_variable_store()
     v = vs.get_variable("v", [1])
     v1 = vs.get_variable("v", [1])
     self.assertEqual(v, v1)
Example #18
 def testResource(self):
     vs = variable_scope._get_default_variable_store()
     v1 = vs.get_variable("v", [1], use_resource=True)
     self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
Example #19
 def testResource(self):
   vs = variable_scope._get_default_variable_store()
   v1 = vs.get_variable("v", [1], use_resource=True)
   self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
Example #20
def init_from_checkpoint(checkpoint_dir, assignment_map):
  """Using assingment map initializes current variables with loaded tensors.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports the following syntax:
    `'scope_name/': 'checkpoint_scope_name/'` - will load all variables in
      current `scope_name` from `checkpoint_scope_name` with matching variable
      names.
    `'scope_name/variable_name': 'checkpoint_scope_name/some_other_variable'` -
      will initialize `scope_name/variable_name` variable
      from `checkpoint_scope_name/some_other_variable`.
    `variable: 'scope_variable_name'` - will initialize given variable with
      variable from the checkpoint.
    `'scope_name/': '/'` - will load all variables in current `scope_name` from
      checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  '<variable>/part_<part #>'.

  Example:
  ```python
    # Create variables.
    with tf.variable_scope('test'):
      m = tf.get_variable('my_var')
    with tf.variable_scope('test2'):
      var2 = tf.get_variable('my_var')
    ...
    # Specify which variables to initialize from checkpoint.
    init_from_checkpoint(checkpoint_dir, {
      'test/my_var': 'some_var',
      'test2/': 'some_scope/'})
    ...
    # Or use `Variable` objects to identify what to initialize.
    init_from_checkpoint(checkpoint_dir, {
      var2: 'some_scope/var2',
    })
    ...
    # Initialize variables as usual.
    session.run(tf.get_all_variables())
  ```

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of current variables
                    (in default graph) and values are names of the variables
                    in the checkpoint.

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  filepattern = _get_checkpoint_filename(checkpoint_dir)
  reader = load_checkpoint(checkpoint_dir)
  variable_map = reader.get_variable_to_shape_map()
  for current_name, tensor_name in six.iteritems(assignment_map):
    scopes = ""
    var = None
    # Check if this is Variable object.
    if isinstance(current_name, variables.Variable):
      var = current_name
    else:
      var_scope = vs._get_default_variable_store()
      # Check if this is variable in var_store.
      var = var_scope._vars.get(current_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        if current_name + "/part_0" in var_scope._vars:
          var = []
          i = 0
          while current_name + "/part_%d" % i in var_scope._vars:
            var.append(var_scope._vars[current_name + "/part_%d" % i])
            i += 1
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the scope.
      if tensor_name not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint" % (
            tensor_name, checkpoint_dir
        ))
      if isinstance(var, variables.Variable):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(variable_map[tensor_name]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name, str(variable_map[tensor_name])
              ))
      _set_variable_or_list_initializer(var, filepattern, tensor_name)
      logging.info("Initialize variable %s from checkpoint %s with %s" % (
          current_name, checkpoint_dir, tensor_name
      ))
    else:
      if "/" in current_name:
        scopes = current_name[:current_name.rindex("/")]
        current_name = current_name[current_name.rindex("/") + 1:]
      if not tensor_name.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name (%s) "
            "should map to scope only (%s). "
            "Should be 'scope/': 'other_scope/'." % (
                scopes, tensor_name
            ))
      # If scope to scope mapping was provided, find all variables in the scope.
      for var_name in var_scope._vars:
        if var_name.startswith(scopes):
          # Lookup name with specified prefix and suffix from current variable.
          # If tensor_name given is '/' (root), don't use it for full name.
          if tensor_name != "/":
            full_tensor_name = tensor_name + var_name[len(scopes) + 1:]
          else:
            full_tensor_name = var_name[len(scopes) + 1:]
          if full_tensor_name not in variable_map:
            raise ValueError(
                "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                    full_tensor_name, var_name[len(scopes) + 1:], tensor_name,
                    checkpoint_dir
                ))
          var = var_scope._vars[var_name]
          _set_variable_or_list_initializer(var, filepattern, full_tensor_name)
          logging.info("Initialize variable %s from checkpoint %s with %s" % (
              var_name, checkpoint_dir, tensor_name
          ))
Example #21
def init_from_checkpoint(checkpoint_dir, assignment_map):
  """Using assingment map initializes current variables with loaded tensors.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports the following syntax:
    `'scope_name/': 'checkpoint_scope_name/'` - will load all variables in
      current `scope_name` from `checkpoint_scope_name` with matching variable
      names.
    `'scope_name/variable_name': 'checkpoint_scope_name/some_other_variable'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.

  Example:
  ```python
    # Create variables.
    with tf.variable_scope('test'):
      m = tf.get_variable('my_var')
    with tf.variable_scope('test2'):
      m = tf.get_variable('my_var')
    ...
    # Specify which variables to initialize from checkpoint.
    init_from_checkpoint(checkpoint_dir, {
      'test/my_var': 'some_var',
      'test2/': 'some_scope/'})
    ...
    # Initialize variables as usual.
    session.run(tf.get_all_variables())
  ```

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of current variables
                    (in default graph) and values are names of the variables
                    in the checkpoint.

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  reader = load_checkpoint(checkpoint_dir)
  variable_map = reader.get_variable_to_shape_map()
  for current_name, tensor_name in six.iteritems(assignment_map):
    scopes = ""
    if "/" in current_name:
      scopes = current_name[:current_name.rindex("/")]
      current_name = current_name[current_name.rindex("/") + 1:]
    if current_name:
      # If 1 to 1 mapping was provided, find variable in the scope.
      if tensor_name not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint" % (
            tensor_name, checkpoint_dir
        ))
      with vs.variable_scope(scopes, reuse=True):
        var = vs.get_variable(current_name)
        var._initializer_op = _checkpoint_initializer(var, reader, tensor_name)  # pylint: disable=protected-access
        logging.info("Initialize variable %s from checkpoint %s with %s" % (
            var.name, checkpoint_dir, tensor_name
        ))
    else:
      if not tensor_name.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name (%s) "
            "should map to scope only (%s). "
            "Should be 'scope/': 'other_scope/'." % (
                scopes, tensor_name
            ))
      # If scope to scope mapping was provided, find all variables in the scope.
      # TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs.
      var_scope = vs._get_default_variable_store()  # pylint: disable=protected-access
      for var_name in var_scope._vars:  # pylint: disable=protected-access
        if var_name.startswith(scopes):
          # Lookup name with specified prefix and suffix from current variable.
          full_tensor_name = tensor_name + var_name[len(scopes) + 1:]
          if full_tensor_name not in variable_map:
            raise ValueError(
                "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                    full_tensor_name, var_name[len(scopes) + 1:], tensor_name,
                    checkpoint_dir
                ))
          var = var_scope._vars[var_name]  # pylint: disable=protected-access
          var._initializer_op = _checkpoint_initializer(  # pylint: disable=protected-access
              var, reader, full_tensor_name)
          logging.info("Initialize variable %s from checkpoint %s with %s" % (
              var_name, checkpoint_dir, tensor_name
          ))
Example #22
    def add_weight(self,
                   name,
                   shape,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   trainable=None,
                   constraint=None,
                   use_resource=None,
                   synchronization=vs.VariableSynchronization.AUTO,
                   aggregation=vs.VariableAggregation.NONE,
                   partitioner=None,
                   **kwargs):
        """Adds a new variable to the layer, or gets an existing one; returns it.

    Args:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable).  If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`.  In this case,
        an instance of `PartitionedVariable` is returned.  Available
        partitioners include `tf.compat.v1.fixed_size_partitioner` and
        `tf.compat.v1.variable_axis_size_partitioner`.  For more details, see
        the documentation of `tf.compat.v1.get_variable` and the  "Variable
        Partitioners and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments.

    Returns:
      The created variable.  Usually either a `Variable` or `ResourceVariable`
      instance.  If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
    """
        for kwarg in kwargs:
            if kwarg != 'experimental_autocast':
                raise TypeError('Unknown keyword argument:', kwarg)
        if self._keras_style:
            return super(Layer, self).add_weight(
                name=name,
                shape=shape,
                dtype=dtype,
                initializer=initializer,
                regularizer=regularizer,
                trainable=trainable and self.trainable,
                constraint=constraint,
                use_resource=use_resource,
                synchronization=vs.VariableSynchronization.AUTO,
                aggregation=vs.VariableAggregation.NONE,
                partitioner=partitioner,
                **kwargs)

        if synchronization == vs.VariableSynchronization.ON_READ:
            if trainable:
                raise ValueError(
                    'Synchronization value can be set to '
                    'VariableSynchronization.ON_READ only for non-trainable variables. '
                    'You have specified trainable=True and '
                    'synchronization=VariableSynchronization.ON_READ.')
            else:
                # Set trainable to be false when variable is to be synced on read.
                trainable = False
        elif trainable is None:
            trainable = True

        def _should_add_regularizer(variable, existing_variable_set):
            if base_layer_utils.is_split_variable(variable):
                for var in variable:
                    if var in existing_variable_set:
                        return False
                return True
            else:
                return variable not in existing_variable_set

        init_graph = None
        if not context.executing_eagerly():
            default_graph = ops.get_default_graph()
            if default_graph.building_function:
                with ops.init_scope():
                    # Retrieve the variables from the graph into which variables
                    # will be lifted; if initialization ops will be lifted into
                    # the eager context, then there is nothing to retrieve, since variable
                    # collections are not supported when eager execution is enabled.
                    if not context.executing_eagerly():
                        init_graph = ops.get_default_graph()
                        existing_variables = set(
                            tf_variables.global_variables())
            else:
                # Initialization ops will not be lifted out of the default graph.
                init_graph = default_graph
                existing_variables = set(tf_variables.global_variables())

        if dtype is None:
            dtype = self.dtype or dtypes.float32

        self._set_scope(None)
        reuse = self.built or self._reuse
        prev_len_trainable = len(self._trainable_weights)
        with vs.variable_scope(self._scope,
                               reuse=reuse,
                               auxiliary_name_scope=False) as scope:
            self._current_scope = scope
            with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
                use_resource = (use_resource or self._use_resource_variables
                                or scope.use_resource)
                if initializer is None:
                    initializer = scope.initializer
                variable = super(Layer, self).add_weight(
                    name,
                    shape,
                    dtype=dtypes.as_dtype(dtype),
                    initializer=initializer,
                    trainable=trainable and self.trainable,
                    constraint=constraint,
                    partitioner=partitioner,
                    use_resource=use_resource,
                    synchronization=synchronization,
                    aggregation=aggregation,
                    getter=vs.get_variable,
                    **kwargs)

                if regularizer:
                    if (ops.executing_eagerly_outside_functions()
                            or _should_add_regularizer(variable,
                                                       existing_variables)):
                        self._handle_weight_regularization(
                            name, variable, regularizer)
                        var_store = vs._get_default_variable_store()  # pylint: disable=protected-access
                        # When the shim to get variable scope working in TF2 is used,
                        # We need to explicitly make the shim track the regularization
                        # losses as the collections will not be accessible.
                        if hasattr(var_store, 'add_regularizer'):
                            var_store.add_regularizer(variable, regularizer)

                if init_graph is not None:
                    # Handle edge case where a custom getter has overridden `trainable`.
                    # There is one known occurrence of this, in unit test
                    # testBasicRNNCellNotTrainable in
                    # contrib.rnn.python.kernel_tests.core_rnn_cell_test
                    with init_graph.as_default():
                        trainable_variables = tf_variables.trainable_variables(
                        )
                    if (trainable and self.trainable
                            and variable not in trainable_variables):
                        # A custom getter / variable scope overrode the trainable flag.
                        extra_trainable_vars = self._trainable_weights[
                            prev_len_trainable:]
                        self._trainable_weights = self._trainable_weights[:
                                                                          prev_len_trainable]
                        self._non_trainable_weights += extra_trainable_vars
        return variable
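A hedged sketch (not from the original source) of a `tf.compat.v1.layers.Layer` subclass exercising `add_weight` as defined above; the layer name, sizes and regularizer are illustrative.

import tensorflow.compat.v1 as tf

class TinyDense(tf.layers.Layer):
  """Minimal layer that creates one regularized kernel via add_weight."""

  def build(self, input_shape):
    self.kernel = self.add_weight(
        "kernel",
        shape=[int(input_shape[-1]), 4],
        initializer=tf.glorot_uniform_initializer(),
        regularizer=tf.keras.regularizers.l2(1e-4))
    super(TinyDense, self).build(input_shape)

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)

# Example usage (e.g. inside a tf.Graph / v1 session):
#   layer = TinyDense()
#   outputs = layer(tf.zeros([3, 5]))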
Example #23
 def testGetVar(self):
   vs = variable_scope._get_default_variable_store()
   v = vs.get_variable("v", [1])
   v1 = vs.get_variable("v", [1])
   self.assertEqual(v, v1)
Example #24
def init_from_checkpoint(checkpoint_dir, assignment_map):
    """Using assingment map initializes current variables with loaded tensors.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports the following syntax:
    `'scope_name/': 'checkpoint_scope_name/'` - will load all variables in
      current `scope_name` from `checkpoint_scope_name` with matching variable
      names.
    `'scope_name/variable_name': 'checkpoint_scope_name/some_other_variable'` -
      will initialize `scope_name/variable_name` variable
      from `checkpoint_scope_name/some_other_variable`.
    `variable: 'scope_variable_name'` - will initialize given variable with
      variable from the checkpoint.
    `'scope_name/': '/'` - will load all variables in current `scope_name` from
      checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  '<variable>/part_<part #>'.

  Example:
  ```python
    # Create variables.
    with tf.variable_scope('test'):
      m = tf.get_variable('my_var')
    with tf.variable_scope('test2'):
      var2 = tf.get_variable('my_var')
    ...
    # Specify which variables to initialize from checkpoint.
    init_from_checkpoint(checkpoint_dir, {
      'test/my_var': 'some_var',
      'test2/': 'some_scope/'})
    ...
    # Or use `Variable` objects to identify what to initialize.
    init_from_checkpoint(checkpoint_dir, {
      var2: 'some_scope/var2',
    })
    ...
    # Initialize variables as usual.
    session.run(tf.get_all_variables())
  ```

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of current variables
                    (in default graph) and values are names of the variables
                    in the checkpoint.

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
    filepattern = _get_checkpoint_filename(checkpoint_dir)
    reader = load_checkpoint(checkpoint_dir)
    variable_map = reader.get_variable_to_shape_map()
    for current_name, tensor_name in six.iteritems(assignment_map):
        scopes = ""
        var = None
        # Check if this is Variable object.
        if isinstance(current_name, variables.Variable):
            var = current_name
        else:
            var_scope = vs._get_default_variable_store()
            # Check if this is variable in var_store.
            var = var_scope._vars.get(current_name, None)
            # Also check if variable is partitioned as list.
            if var is None:
                if current_name + "/part_0" in var_scope._vars:
                    var = []
                    i = 0
                    while current_name + "/part_%d" % i in var_scope._vars:
                        var.append(var_scope._vars[current_name +
                                                   "/part_%d" % i])
                        i += 1
        if var is not None:
            # If 1 to 1 mapping was provided, find variable in the scope.
            if tensor_name not in variable_map:
                raise ValueError("Tensor %s is not found in %s checkpoint" %
                                 (tensor_name, checkpoint_dir))
            if isinstance(var, variables.Variable):
                # Additional at-call-time checks.
                if not var.get_shape().is_compatible_with(
                        variable_map[tensor_name]):
                    raise ValueError(
                        "Shape of variable %s (%s) doesn't match with shape of "
                        "tensor %s (%s) from checkpoint reader." %
                        (var.name, str(var.get_shape()), tensor_name,
                         str(variable_map[tensor_name])))
            _set_variable_or_list_initializer(var, filepattern, tensor_name)
            logging.info("Initialize variable %s from checkpoint %s with %s" %
                         (current_name, checkpoint_dir, tensor_name))
        else:
            if "/" in current_name:
                scopes = current_name[:current_name.rindex("/")]
                current_name = current_name[current_name.rindex("/") + 1:]
            if not tensor_name.endswith("/"):
                raise ValueError("Assignment map with scope only name (%s) "
                                 "should map to scope only (%s). "
                                 "Should be 'scope/': 'other_scope/'." %
                                 (scopes, tensor_name))
            # If scope to scope mapping was provided, find all variables in the scope.
            for var_name in var_scope._vars:
                if var_name.startswith(scopes):
                    # Lookup name with specified prefix and suffix from current variable.
                    # If tensor_name given is '/' (root), don't use it for full name.
                    if tensor_name != "/":
                        full_tensor_name = tensor_name + var_name[len(scopes) +
                                                                  1:]
                    else:
                        full_tensor_name = var_name[len(scopes) + 1:]
                    if full_tensor_name not in variable_map:
                        raise ValueError(
                            "Tensor %s (%s in %s) is not found in %s checkpoint"
                            % (full_tensor_name, var_name[len(scopes) + 1:],
                               tensor_name, checkpoint_dir))
                    var = var_scope._vars[var_name]
                    _set_variable_or_list_initializer(var, filepattern,
                                                      full_tensor_name)
                    logging.info(
                        "Initialize variable %s from checkpoint %s with %s" %
                        (var_name, checkpoint_dir, tensor_name))
Example #25
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Initializes current variables with tensors loaded from given checkpoint.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    is_var = lambda x: isinstance(x, variables.Variable)
    if is_var(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(is_var(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if is_var(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)
Example #26
def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""

  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)
Example #27
 def testGetVar(self):
   vs = variable_scope._get_default_variable_store()
   v = vs.get_variable("v", [1])
   v1 = vs.get_variable("v", [1])
   assert v == v1
Example #28
def _get_variable_dict_from_varstore():
  var_dict = variable_scope._get_default_variable_store()._vars  # pylint: disable=protected-access
  sorted_var_dict = collections.OrderedDict(
      sorted(var_dict.items(), key=lambda t: t[0]))
  return sorted_var_dict
Example #29
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
    """Initializes current variables with tensors loaded from given checkpoint.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
    ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
    reader = load_checkpoint(ckpt_dir_or_file)
    variable_map = reader.get_variable_to_shape_map()
    for tensor_name_in_ckpt, current_var_or_name in six.iteritems(
            assignment_map):
        var = None
        # Check if this is Variable object or list of Variable objects (in case of
        # partitioned variables).
        is_var = lambda x: isinstance(x, variables.Variable)
        if is_var(current_var_or_name) or (
                isinstance(current_var_or_name, list)
                and all(is_var(v) for v in current_var_or_name)):
            var = current_var_or_name
        else:
            store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
            # Check if this variable is in var_store.
            var = store_vars.get(current_var_or_name, None)
            # Also check if variable is partitioned as list.
            if var is None:
                var = _collect_partitioned_variable(current_var_or_name,
                                                    store_vars)
        if var is not None:
            # If 1 to 1 mapping was provided, find variable in the checkpoint.
            if tensor_name_in_ckpt not in variable_map:
                raise ValueError(
                    "Tensor %s is not found in %s checkpoint %s" %
                    (tensor_name_in_ckpt, ckpt_dir_or_file, variable_map))
            if is_var(var):
                # Additional at-call-time checks.
                if not var.get_shape().is_compatible_with(
                        variable_map[tensor_name_in_ckpt]):
                    raise ValueError(
                        "Shape of variable %s (%s) doesn't match with shape of "
                        "tensor %s (%s) from checkpoint reader." %
                        (var.name, str(var.get_shape()), tensor_name_in_ckpt,
                         str(variable_map[tensor_name_in_ckpt])))
                var_name = var.name
            else:
                var_name = ",".join([v.name for v in var])
            _set_variable_or_list_initializer(var, ckpt_file,
                                              tensor_name_in_ckpt)
            logging.info("Initialize variable %s from checkpoint %s with %s",
                         var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
        else:
            scopes = ""
            # TODO(vihanjain): Support list of 'current_var_or_name' here.
            if "/" in current_var_or_name:
                scopes = current_var_or_name[:current_var_or_name.rindex("/")]
            if not tensor_name_in_ckpt.endswith("/"):
                raise ValueError(
                    "Assignment map with scope only name {} should map to scope only "
                    "{}. Should be 'scope/': 'other_scope/'.".format(
                        scopes, tensor_name_in_ckpt))
            # If scope to scope mapping was provided, find all variables in the scope
            # and create variable to variable mapping.
            scope_variables = set()
            for var_name in store_vars:
                if not scopes or var_name.startswith(scopes + "/"):
                    # Consume /part_ if partitioned variable.
                    if "/part_" in var_name:
                        var_name = var_name[:var_name.index("/part_")]
                    scope_variables.add(var_name)
            for var_name in scope_variables:
                # Lookup name with specified prefix and suffix from current variable.
                # If tensor_name given is '/' (root), don't use it for full name.
                full_tensor_name = var_name[len(scopes):]
                if current_var_or_name != "/":
                    full_tensor_name = full_tensor_name[1:]
                if tensor_name_in_ckpt != "/":
                    full_tensor_name = tensor_name_in_ckpt + full_tensor_name
                if full_tensor_name not in variable_map:
                    raise ValueError(
                        "Tensor %s (%s in %s) is not found in %s checkpoint" %
                        (full_tensor_name, var_name[len(scopes) + 1:],
                         tensor_name_in_ckpt, ckpt_dir_or_file))
                var = store_vars.get(var_name, None)
                if var is None:
                    var = _collect_partitioned_variable(var_name, store_vars)
                _set_variable_or_list_initializer(var, ckpt_file,
                                                  full_tensor_name)
                logging.info(
                    "Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, full_tensor_name)
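
The scope-to-scope branch above is mostly string bookkeeping. The following standalone, pure-Python sketch mirrors that name-resolution logic (`resolve_scope_mapping` is a hypothetical helper used only for illustration, not TensorFlow API): it collapses partitioned pieces and rewrites the current-scope prefix into the checkpoint-scope prefix.

```python
def resolve_scope_mapping(store_var_names, tensor_name_in_ckpt,
                          current_var_or_name):
  """Returns {variable name in the store: tensor name in the checkpoint}."""
  scopes = ""
  if "/" in current_var_or_name:
    scopes = current_var_or_name[:current_var_or_name.rindex("/")]
  scope_variables = set()
  for var_name in store_var_names:
    if not scopes or var_name.startswith(scopes + "/"):
      # Collapse partitioned pieces such as 'var3/part_0' into 'var3'.
      if "/part_" in var_name:
        var_name = var_name[:var_name.index("/part_")]
      scope_variables.add(var_name)
  mapping = {}
  for var_name in sorted(scope_variables):
    full_tensor_name = var_name[len(scopes):]
    if current_var_or_name != "/":
      full_tensor_name = full_tensor_name[1:]
    if tensor_name_in_ckpt != "/":
      full_tensor_name = tensor_name_in_ckpt + full_tensor_name
    mapping[var_name] = full_tensor_name
  return mapping

# 'old_scope_2/': 'new_scope_2/' maps the partitioned variable exactly once.
names = ["new_scope_1/var1", "new_scope_2/var3/part_0", "new_scope_2/var3/part_1"]
assert resolve_scope_mapping(names, "old_scope_2/", "new_scope_2/") == {
    "new_scope_2/var3": "old_scope_2/var3"}
```
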
Example #30
0
def get_variable(
    name,  # unique
    key_dtype=dtypes.int64,
    value_dtype=dtypes.float32,
    dim=1,
    devices=None,
    partitioner=default_partition_fn,
    shared_name="get_variable",
    initializer=None,
    trainable=True,
    checkpoint=True,
    init_size=0,
    kv_creator=None,
    restrict_policy=None,
    bp_v2=False,
):
    """Gets an `Variable` object with this name if it exists,
         or create a new one.

    Args:
      name: A unique name for the `Variable`.
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      dim: the length of the value array for each key.
      devices: the list of devices holding the tables.
        One table will be created on each device.
      partitioner: partition function of keys,
        return the partition index for each key.

      Example partition func:
      ```python
      def default_partition_fn(keys, shard_num):
        return tf.cast(keys % shard_num, dtype=tf.int32)
      ```
      shared_name: Not used.
      initializer: The value to use if a key is missing in the hash table,
        which can be a Python number, numpy array or `tf.initializer`
        instance. If initializer is `None` (the default), `0` will be used.
      trainable: Bool. If True, the variable will be treated as trainable.
        Default is True.
      checkpoint: if True, the contents of the SparseVariable are
        saved to and restored from checkpoints.
        If `shared_name` is empty for a checkpointed table,
        it is shared using the table node name.
      init_size: initial size of the Variable; each hash table will be
        initialized with int(init_size / N) entries, where N is the number of
        devices.
      restrict_policy: a policy specifying the rule used to restrict the size
        of the variable. If the variable is updated by an optimizer during
        training, the optimizer's sparse slot variables are restricted as
        well.
      bp_v2: By default (`bp_v2=False`), the optimizer updates dynamic
        embedding values by *setting* (key, value) after
        `optimizer.apply_gradient`. If one key is updated by multiple workers
        at the same time, only one worker's update is kept and the others are
        overwritten. With `bp_v2=True`, the optimizer updates parameters by
        *adding a delta* instead of *setting*, which avoids this race
        condition among workers during backpropagation in large-scale
        distributed asynchronous training.

    Returns:
      A `Variable` object.
    """
    var_ = None
    scope = variable_scope.get_variable_scope()
    scope_store = variable_scope._get_default_variable_store()
    full_name = scope.name + "/" + name if scope.name else name
    if full_name in scope_store._vars:
        if scope.reuse is False:
            err_msg = ("Variable %s already exists, disallowed."
                       " Did you mean to set reuse=True or "
                       "reuse=tf.AUTO_REUSE in VarScope?" % full_name)

            raise ValueError(err_msg)
    else:
        var_ = Variable(
            key_dtype=key_dtype,
            value_dtype=value_dtype,
            dim=dim,
            devices=devices,
            partitioner=partitioner,
            shared_name=shared_name,
            name=full_name,
            initializer=initializer,
            trainable=trainable,
            checkpoint=checkpoint,
            init_size=init_size,
            kv_creator=kv_creator,
            restrict_policy=restrict_policy,
            bp_v2=bp_v2,
        )
        scope_store._vars[full_name] = var_
    return scope_store._vars[full_name]
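
The `bp_v2` distinction described in the docstring above is easiest to see with a toy model of the two update rules. This is a pure-Python illustration only (`apply_updates` is a hypothetical helper, not part of any library): with "set" semantics, concurrent workers that read the same stale value overwrite each other, while "add delta" semantics accumulates every worker's contribution.

```python
def apply_updates(initial_value, deltas, bp_v2):
  """Toy model: each delta comes from a worker that read `initial_value`."""
  value = initial_value
  for delta in deltas:
    if bp_v2:
      value += delta                       # add-delta: applied to current value
    else:
      value = initial_value + delta        # set: computed from the stale read
  return value

# Two workers both read 10 and each computes a step of -1 for the same key.
assert apply_updates(10, [-1, -1], bp_v2=False) == 9   # one update is lost
assert apply_updates(10, [-1, -1], bp_v2=True) == 8    # both deltas are kept
```
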
Example #31
0
def init_from_checkpoint(checkpoint_dir, assignment_map):
    """Using assignment map initializes current variables with loaded tensors.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:
    `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
      current `scope_name` from `checkpoint_scope_name` with matching variable
      names.
    `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
      will initialize `scope_name/variable_name` variable
      from `checkpoint_scope_name/some_other_variable`.
    `'scope_variable_name': variable` - will initialize given `tf.Variable`
      object with variable from the checkpoint.
    `'scope_variable_name': list(variable)` - will initialize list of
      partitioned variables with variable from the checkpoint.
    `'/': 'scope_name/'` - will load all variables in current `scope_name` from
      checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  '<variable>/part_<part #>'.

  Example:
  ```python
    # Create variables.
    with tf.variable_scope('test'):
      m = tf.get_variable('my_var')
    with tf.variable_scope('test2'):
      var2 = tf.get_variable('my_var')
    var3 = tf.get_variable(name="my1", shape=[100, 100],
                           partitioner=lambda shape, dtype: [5, 1])
    ...
    # Specify which variables to initialize from checkpoint.
    init_from_checkpoint(checkpoint_dir, {
      'some_var': 'test/my_var',
      'some_scope/': 'test2/'})
    ...
    # Or use `Variable` objects to identify what to initialize.
    init_from_checkpoint(checkpoint_dir, {
      'some_scope/var2': var2,
    })
    # Initialize partitioned variables
    init_from_checkpoint(checkpoint_dir, {
      'some_var_from_ckpt': 'part_var',
    })
    # Or specifying the list of `Variable` objects.
    init_from_checkpoint(checkpoint_dir, {
      'some_var_from_ckpt': var3._get_variable_list(),
    })
    ...
    # Initialize variables as usual.
    session.run(tf.get_all_variables())
  ```

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
    filepattern = _get_checkpoint_filename(checkpoint_dir)
    reader = load_checkpoint(checkpoint_dir)
    variable_map = reader.get_variable_to_shape_map()
    for tensor_name_in_ckpt, current_var_or_name in six.iteritems(
            assignment_map):
        var = None
        # Check if this is Variable object or list of Variable objects (in case of
        # partitioned variables).
        is_var = lambda x: isinstance(x, variables.Variable)
        if is_var(current_var_or_name) or (
                isinstance(current_var_or_name, list)
                and all(is_var(v) for v in current_var_or_name)):
            var = current_var_or_name
        else:
            var_scope = vs._get_default_variable_store()
            # Check if this variable is in var_store.
            var = var_scope._vars.get(current_var_or_name, None)
            # Also check if variable is partitioned as list.
            if var is None:
                var = _collect_partitioned_variable(current_var_or_name,
                                                    var_scope)
        if var is not None:
            # If 1 to 1 mapping was provided, find variable in the checkpoint.
            if tensor_name_in_ckpt not in variable_map:
                raise ValueError(
                    "Tensor %s is not found in %s checkpoint %s" %
                    (tensor_name_in_ckpt, checkpoint_dir, variable_map))
            if is_var(var):
                # Additional at-call-time checks.
                if not var.get_shape().is_compatible_with(
                        variable_map[tensor_name_in_ckpt]):
                    raise ValueError(
                        "Shape of variable %s (%s) doesn't match with shape of "
                        "tensor %s (%s) from checkpoint reader." %
                        (var.name, str(var.get_shape()), tensor_name_in_ckpt,
                         str(variable_map[tensor_name_in_ckpt])))
                var_name = var.name
            else:
                var_name = ",".join([v.name for v in var])
            _set_variable_or_list_initializer(var, filepattern,
                                              tensor_name_in_ckpt)
            logging.info("Initialize variable %s from checkpoint %s with %s" %
                         (var_name, checkpoint_dir, tensor_name_in_ckpt))
        else:
            scopes = ""
            # TODO(vihanjain): Support list of 'current_var_or_name' here.
            if "/" in current_var_or_name:
                scopes = current_var_or_name[:current_var_or_name.rindex("/")]
            if not tensor_name_in_ckpt.endswith("/"):
                raise ValueError(
                    "Assignment map with scope only name {} should map to scope only "
                    "{}. Should be 'scope/': 'other_scope/'.".format(
                        scopes, tensor_name_in_ckpt))
            # If scope to scope mapping was provided, find all variables in the scope
            # and create variable to variable mapping.
            scope_variables = set()
            for var_name in var_scope._vars:
                if var_name.startswith(scopes):
                    # Consume /part_ if partitioned variable.
                    if "/part_" in var_name:
                        var_name = var_name[:var_name.index("/part_")]
                    scope_variables.add(var_name)
            for var_name in scope_variables:
                # Lookup name with specified prefix and suffix from current variable.
                # If tensor_name given is '/' (root), don't use it for full name.
                full_tensor_name = var_name[len(scopes):]
                if current_var_or_name != "/":
                    full_tensor_name = full_tensor_name[1:]
                if tensor_name_in_ckpt != "/":
                    full_tensor_name = tensor_name_in_ckpt + full_tensor_name
                if full_tensor_name not in variable_map:
                    raise ValueError(
                        "Tensor %s (%s in %s) is not found in %s checkpoint" %
                        (full_tensor_name, var_name[len(scopes) + 1:],
                         tensor_name_in_ckpt, checkpoint_dir))
                var = var_scope._vars.get(var_name, None)
                if var is None:
                    var = _collect_partitioned_variable(var_name, var_scope)
                _set_variable_or_list_initializer(var, filepattern,
                                                  full_tensor_name)
                logging.info(
                    "Initialize variable %s from checkpoint %s with %s" %
                    (var_name, checkpoint_dir, full_tensor_name))
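
Note that this version filters scope members with `startswith(scopes)`, whereas the variants in examples #29 and #32 use `startswith(scopes + "/")`. The stricter form avoids matching a sibling scope whose name merely shares a prefix; a quick pure-Python illustration:

```python
scopes = "encoder"
names = ["encoder/w", "encoder_v2/w"]

# Plain prefix test also picks up the sibling scope 'encoder_v2'.
assert [n for n in names if n.startswith(scopes)] == ["encoder/w", "encoder_v2/w"]
# Adding the '/' restricts the match to variables actually inside 'encoder'.
assert [n for n in names if n.startswith(scopes + "/")] == ["encoder/w"]
```
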
Example #32
0
def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""

  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)
Example #33
0
def init_from_checkpoint(checkpoint_dir, assignment_map):
  """Using assingment map initializes current variables with loaded tensors.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:
    `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
      current `scope_name` from `checkpoint_scope_name` with matching variable
      names.
    `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
      will initialize `scope_name/variable_name` variable
      from `checkpoint_scope_name/some_other_variable`.
    `'scope_variable_name': variable` - will initialize given `tf.Variable`
      object with variable from the checkpoint.
    `'scope_variable_name': list(variable)` - will initialize list of
      partitioned variables with variable from the checkpoint.
    `'/': 'scope_name/'` - will load all variables in current `scope_name` from
      checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  '<variable>/part_<part #>'.

  Example:
  ```python
    # Create variables.
    with tf.variable_scope('test'):
      m = tf.get_variable('my_var')
    with tf.variable_scope('test2'):
      var2 = tf.get_variable('my_var')
    var3 = tf.get_variable(name="my1", shape=[100, 100],
                           partitioner=lambda shape, dtype: [5, 1])
    ...
    # Specify which variables to initialize from checkpoint.
    init_from_checkpoint(checkpoint_dir, {
      'some_var': 'test/my_var',
      'some_scope/': 'test2/'})
    ...
    # Or use `Variable` objects to identify what to initialize.
    init_from_checkpoint(checkpoint_dir, {
      'some_scope/var2': var2,
    })
    # Initialize partitioned variables
    init_from_checkpoint(checkpoint_dir, {
      'some_var_from_ckpt': 'part_var',
    })
    # Or specifying the list of `Variable` objects.
    init_from_checkpoint(checkpoint_dir, {
      'some_var_from_ckpt': var3._get_variable_list(),
    })
    ...
    # Initialize variables as usual.
    session.run(tf.get_all_variables())
  ```

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  filepattern = _get_checkpoint_filename(checkpoint_dir)
  reader = load_checkpoint(checkpoint_dir)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    is_var = lambda x: isinstance(x, variables.Variable)
    if is_var(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(is_var(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      var_scope = vs._get_default_variable_store()
      # Check if this variable is in var_store.
      var = var_scope._vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, var_scope)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, checkpoint_dir, variable_map
        ))
      if is_var(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, filepattern, tensor_name_in_ckpt)
      logging.info("Initialize variable %s from checkpoint %s with %s" % (
          var_name, checkpoint_dir, tensor_name_in_ckpt
      ))
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in var_scope._vars:
        if var_name.startswith(scopes):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in scope_variables:
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + var_name[len(scopes) + 1:]
        else:
          full_tensor_name = var_name[len(scopes) + 1:]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, checkpoint_dir
              ))
        var = var_scope._vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, var_scope)
        _set_variable_or_list_initializer(var, filepattern, full_tensor_name)
        logging.info("Initialize variable %s from checkpoint %s with %s" % (
            var_name, checkpoint_dir, full_tensor_name
        ))
Example #34
0
def get_variable(
        name,  # unique,
        embedding_dim,
        key_dtype=dtypes.int64,
        value_dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=True,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        constraint=None):
    if key_dtype == dtypes.int64 or key_dtype == dtypes.int32:
        invalid_key = -1
    else:
        raise ValueError(
            "Not support key_dtype: %s, only support int64/int32" % key_dtype)
    if initializer is None:
        initializer = init_ops.truncated_normal_initializer()

    scope = variable_scope.get_variable_scope()
    scope_store = variable_scope._get_default_variable_store()

    if regularizer is None:
        regularizer = scope._regularizer
    if caching_device is None:
        caching_device = scope._caching_device
    if partitioner is None:
        partitioner = scope._partitioner
    if not context.executing_eagerly():
        if reuse is None:
            reuse = scope._reuse
    else:
        reuse = AUTO_REUSE

    full_name = scope.name + "/" + name if scope.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None):
        dtype = value_dtype
        # Check that `initializer` dtype and `dtype` are consistent before
        # replacing them with defaults.
        if (dtype is not None and initializer is not None
                and not callable(initializer)):
            init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
            if init_dtype != dtype:
                raise ValueError(
                    "Initializer type '%s' and explicit dtype '%s' "
                    "don't match." % (init_dtype, dtype))
        if initializer is None:
            initializer = scope._initializer
        if constraint is None:
            constraint = scope._constraint
        if dtype is None:
            dtype = scope._dtype
        if invalid_key is None:
            invalid_key = -1
        ev_store = _EmbeddingVariableStore(scope_store)
        return ev_store.get_variable(full_name,
                                     shape=embedding_dim,
                                     dtype=value_dtype,
                                     ktype=key_dtype,
                                     initializer=initializer,
                                     regularizer=regularizer,
                                     reuse=reuse,
                                     trainable=trainable,
                                     collections=collections,
                                     caching_device=caching_device,
                                     partitioner=partitioner,
                                     validate_shape=True,
                                     constraint=None)
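
The `full_name` composition above (`scope.name + "/" + name if scope.name else name`) mirrors how `tf.compat.v1.get_variable` itself prefixes names with the enclosing variable scope. A minimal sketch of that naming rule, assuming a TF 1.x-style graph via `tf.compat.v1`:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

with tf.compat.v1.variable_scope("outer"):
  # The variable name is prefixed by the variable_scope name, matching
  # the full_name construction in the example above.
  v = tf.compat.v1.get_variable("w", shape=[1])

assert v.name == "outer/w:0"
```
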