Exemple #1
0
    def define_local(self, name, value):
        """Binds `name` to `value` in the local scope and returns the bound value.

        When `self._wrap_locals_in_vars` is set and `value` is a `tf.Tensor`
        (or a `RetvalBag` of length 1 whose single entry is a `tf.Tensor`),
        the tensor is routed through a freshly created variable op: the graph
        built here yields the variable if it is already initialized and
        otherwise assigns `value` to it. The resulting tensor is what gets
        stored and returned.

        Args:
          name: Key under which the value is recorded in `self._locals`.
          value: Object to bind. Only tensors (or single-tensor `RetvalBag`s)
            are transformed, and only when wrapping is enabled.

        Returns:
          The object stored under `name`; this is the wrapping tensor rather
          than the original `value` when wrapping took place.

        Raises:
          Exception: If `name` is already present in `self._locals`.
        """
        if name in self._locals:
            raise Exception("Local already defined: %s" % name)

        should_wrap_in_var = False
        if self._wrap_locals_in_vars:
            if isinstance(value, tf.Tensor):
                should_wrap_in_var = True

            # HACK(adamb) Unwrapping in here really isn't great, since
            #     auto-unwrapping can create unexpected behavior.
            if isinstance(value, RetvalBag) and value.len() == 1:
                unwrapped = value.get(None)
                if isinstance(unwrapped, tf.Tensor):
                    should_wrap_in_var = True
                    value = unwrapped

        if should_wrap_in_var:
            variable = state_ops.variable_op_v2(value.get_shape(),
                                                value.dtype.base_dtype)

            # Clear pending control dependencies so the caching subgraph is
            # not tied to whatever scope the caller happens to be in.
            with tf.control_dependencies(None):
                value = tf.identity(
                    tf.cond(tf.is_variable_initialized(variable),
                            lambda: variable,
                            lambda: tf.assign(variable, value)))

        self._locals[name] = value
        return value
Exemple #2
0
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None):
    """Creates a new variable from arguments.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions  from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      expected_shape: Deprecated. Ignored.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    _ = expected_shape
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.control_dependencies(None):
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:

        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          true_name = ops._name_from_scope_name(name)
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          # pylint: disable=protected-access
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"),  ops.device(None):
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
              shape = (self._initial_value.get_shape()
                       if validate_shape else tensor_shape.unknown_shape())
            self._variable = state_ops.variable_op_v2(
                shape,
                self._initial_value.dtype.base_dtype,
                name=name)

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          shape = (self._initial_value.get_shape()
                   if validate_shape else tensor_shape.unknown_shape())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = state_ops.variable_op_v2(
              shape,
              self._initial_value.dtype.base_dtype,
              name=name)

        # Manually overrides the variable's shape with the initial value's.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)

        # Assigns initial value.
        self._initializer_op = state_ops.assign(
            self._variable, self._initial_value,
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")

    ops.add_to_collections(collections, self)
    self._caching_device = caching_device
    self._save_slice_info = None
    def create_op(self, *args, **kwargs):
        """Creates an `Operation`, caching its outputs in temporary variables.

        For operations of the following form

          orig_value = op(*args, **kwargs)

        this function constructs the following subgraph :

          v = Variable()
          if v is not initialized:
            orig_value = op(*args, **kwargs)
            v.assign(orig_value) # Initializes v
            return orig_value
          else:
            return v

        The above transformation is not performed and the original op is
        returned as is if any of the following is true:
        * `_return_as_is` flag is set to true.
        * op_type is listed in _PASS_THROUGH_OPS
        * op has no outputs.
        * One of the op's return value has a ref type.

        Args:
          *args: Arguments for create_op()
          **kwargs: Keyword arguments for create_op(). Refer to
            tensorflow.python.framework.ops.Graph.create_op() for the mandatory
            and optional arguments.

        Returns:
          On the pass-through paths, the result of `self._wrap()` applied to
          the op created by the parent class. Otherwise a
          `MultiOutputOperation` built from the merged (fresh-or-cached)
          output values.

        Raises:
          UnimplementedError: if output type is a reference and the op's type
            is not one of the supported types in `_REF_OPS_WHITELIST`.
        """
        # `op_type` and `dtypes` may arrive by keyword or positionally (as the
        # first and third positional arguments respectively).
        op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
        output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
        output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

        # Pass-through: caching disabled for this scope, or op type exempt.
        if self._return_as_is or op_type in _PASS_THROUGH_OPS:
            return self._wrap(
                super(ImperativeGraph, self).create_op(*args, **kwargs))

        # Pass-through: nothing to cache for an op without outputs.
        if not output_dtypes:
            return self._wrap(
                super(ImperativeGraph, self).create_op(*args, **kwargs))

        output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

        # Ref-typed outputs are never cached; only whitelisted op types are
        # allowed to produce them at all.
        if output_has_ref:
            if op_type not in _REF_OPS_WHITELIST:
                raise errors.UnimplementedError(
                    None, None, op_type + ' op not supported in '
                    'imperative graph')

            ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

            # An 'Assign' created while a variable is being created is
            # recorded through add_pending_init().
            if self._in_variable_creation:
                if op_type == 'Assign':
                    self.add_pending_init(ret)

            return self._wrap(ret)

        # Every helper op created below goes through this same create_op();
        # the return_as_is() context is entered so they are not themselves
        # transformed (presumably by setting `_return_as_is` -- confirm
        # against its definition).
        with self.return_as_is():
            # Declares the variables to hold the output values of this op.
            op_output_var = [
                state_ops.variable_op_v2(tensor_shape.TensorShape(None),
                                         dtype,
                                         container=self._name)
                for dtype in output_dtypes
            ]
            # Ops to free the resources used by the temporary cache variables.
            # The following two ops are created for each cache variable,
            # having no control dependencies on any other ops :
            # var_handle_op ----> destroy_resource_op
            for dtype, v in zip(output_dtypes, op_output_var):
                with ops.control_dependencies(None):
                    self._variable_cleanup_ops += [
                        gen_resource_variable_ops.destroy_resource_op(
                            gen_resource_variable_ops.var_handle_op(
                                dtype,
                                tensor_shape.TensorShape(None),
                                container=self._name,
                                shared_name=v.op.name),
                            ignore_lookup_error=True)
                    ]

            # Create the conditional to run the original op only when the variable
            # corresponding to the first output is not initialized.
            # `v_f` gates the "compute" path and `v_t` gates the "read cached
            # value" path, as the control dependencies below show.
            inited = state_ops.is_variable_initialized(op_output_var[0])
            v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
            # pylint: disable=protected-access
            v_f_op = gen_array_ops._ref_identity(v_f)
            v_t_op = gen_array_ops._ref_identity(v_t)
            # pylint: enable=protected-access

            with ops.control_dependencies([v_f_op.op]):
                # Create the original op
                orig_op = self._wrap(
                    super(ImperativeGraph, self).create_op(*args, **kwargs))
            # Remember the statically known output shapes; they are re-applied
            # to the merged values further down.
            shapes = [val.get_shape() for val in orig_op.outputs]

            # Store each output in its cache variable. Outputs whose shape is
            # fully defined and has zero elements are skipped.
            controls = []
            for var, val in zip(op_output_var, orig_op.outputs):
                if (not val.get_shape().is_fully_defined()
                        or val.get_shape().num_elements() > 0):
                    assign_op = state_ops.assign(var,
                                                 val,
                                                 validate_shape=False)
                    assign_op.set_shape(val.get_shape())
                    controls.append(assign_op)

            values = []
            # With more than one assign, group them into a single set of
            # control inputs; a Switch op gets a ref merge instead of a tuple.
            if len(controls) > 1:
                if control_flow_ops.IsSwitch(orig_op):
                    # pylint: disable=protected-access
                    controls = gen_control_flow_ops._ref_merge(controls)
                    # pylint: enable=protected-access
                else:
                    controls = control_flow_ops.tuple(controls)

            # For every output, merge the freshly computed value (available
            # after the assigns) with the value read back from the cache
            # variable.
            for var, val in zip(op_output_var, orig_op.outputs):
                with ops.control_dependencies(controls):
                    with self.colocate_with(v_f_op):
                        real_val = array_ops.identity(val)
                with ops.control_dependencies([v_t_op.op]):
                    with self.colocate_with(v_t_op):
                        stored_val = array_ops.identity(var)
                    stored_val.set_shape(val.get_shape())
                    real_val, _ = control_flow_ops.merge(
                        [real_val, stored_val])
                # Tag the merge op with this graph's merge gradient op type.
                real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        s=compat.as_bytes(self._merge_op_type)))
                values.append(real_val)

            # Restore the original op's static shapes on the merged outputs.
            for i, _ in enumerate(shapes):
                values[i].set_shape(shapes[i])
            self._outputs_map[orig_op.name] = values
            # Record the op's gradient function when one can be looked up, and
            # only then retag the op with the imperative gradient op type. A
            # failed lookup leaves the op untouched.
            try:
                self._gradient_function_map[
                    orig_op.name] = ops.get_gradient_function(orig_op)
            except (KeyError, LookupError):
                pass
            else:
                orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        s=compat.as_bytes(self._imperative_op_type)))

            return MultiOutputOperation(values)
  def create_op(self, *args, **kwargs):
    """Creates an `Operation`, caching its outputs in temporary variables.

    For operations of the following form

      orig_value = op(*args, **kwargs)

    this function constructs the following subgraph :

      v = Variable()
      if v is not initialized:
        orig_value = op(*args, **kwargs)
        v.assign(orig_value) # Initializes v
        return orig_value
      else:
        return v

    The above transformation is not performed and the original op is returned
    as is if any of the following is true:
    * `_return_as_is` flag is set to true.
    * op_type is listed in _PASS_THROUGH_OPS
    * op has no outputs.
    * One of the op's return value has a ref type.

    Args:
      *args: Arguments for create_op()
      **kwargs: Keyword arguments for create_op(). Refer to
        tensorflow.python.framework.ops.Graph.create_op() for the mandatory
        and optional arguments.

    Returns:
      On the pass-through paths, the result of `self._wrap()` applied to the
      op created by the parent class. Otherwise a `MultiOutputOperation`
      built from the merged (fresh-or-cached) output values and the original
      op.

    Raises:
      UnimplementedError: if output type is a reference and the op's type
        is not one of the supported types in `_REF_OPS_WHITELIST`.
    """
    # `op_type` and `dtypes` may arrive by keyword or positionally (as the
    # first and third positional arguments respectively).
    op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
    output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
    output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

    # Pass-through: caching disabled for this scope, or op type exempt.
    if self._return_as_is or op_type in _PASS_THROUGH_OPS:
      return self._wrap(super(ImperativeGraph, self).create_op(*args, **kwargs))

    # Pass-through: nothing to cache for an op without outputs.
    if not output_dtypes:
      return self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))

    output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

    # Ref-typed outputs are never cached; only whitelisted op types are
    # allowed to produce them at all.
    if output_has_ref:
      if op_type not in _REF_OPS_WHITELIST:
        raise errors.UnimplementedError(None, None,
                                        op_type + ' op not supported in '
                                        'imperative graph')

      ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

      # An 'Assign' created while a variable is being created is recorded
      # through add_pending_init().
      if self._in_variable_creation:
        if op_type == 'Assign':
          self.add_pending_init(ret)

      return self._wrap(ret)

    # Every helper op created below goes through this same create_op(); the
    # return_as_is() context is entered so they are not themselves
    # transformed (presumably by setting `_return_as_is` -- confirm against
    # its definition).
    with self.return_as_is():
      # Declares the variables to hold the output values of this op.
      op_output_var = [state_ops.variable_op_v2(
          tensor_shape.TensorShape(None), dtype, container=self._name)
                       for dtype in output_dtypes]
      # Ops to free the resources used by the temporary cache variables.
      # The following two ops are created for each cache variable,
      # having no control dependencies on any other ops :
      # var_handle_op ----> destroy_resource_op
      for dtype, v in zip(output_dtypes, op_output_var):
        with ops.control_dependencies(None):
          self._variable_cleanup_ops += [
              gen_resource_variable_ops.destroy_resource_op(
                  gen_resource_variable_ops.var_handle_op(
                      dtype, tensor_shape.TensorShape(None),
                      container=self._name, shared_name=v.op.name),
                  ignore_lookup_error=True)]

      # Create the conditional to run the original op only when the variable
      # corresponding to the first output is not initialized.
      # `v_f` gates the "compute" path and `v_t` gates the "read cached value"
      # path, as the control dependencies below show.
      inited = state_ops.is_variable_initialized(op_output_var[0])
      v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
      # pylint: disable=protected-access
      v_f_op = gen_array_ops._ref_identity(v_f)
      v_t_op = gen_array_ops._ref_identity(v_t)
      # pylint: enable=protected-access

      with ops.control_dependencies([v_f_op.op]):
        # Create the original op
        orig_op = self._wrap(
            super(ImperativeGraph, self).create_op(*args, **kwargs))
      # Remember the statically known output shapes; they are re-applied to
      # the merged values further down.
      shapes = [val.get_shape() for val in orig_op.outputs]

      # Store each output in its cache variable. Outputs whose shape is fully
      # defined and has zero elements are skipped.
      controls = []
      for var, val in zip(op_output_var, orig_op.outputs):
        if (not val.get_shape().is_fully_defined() or
            val.get_shape().num_elements() > 0):
          assign_op = state_ops.assign(var, val, validate_shape=False)
          assign_op.set_shape(val.get_shape())
          controls.append(assign_op)

      values = []
      # With more than one assign, group them into a single set of control
      # inputs; a Switch op gets a ref merge instead of a tuple.
      if len(controls) > 1:
        if control_flow_ops.IsSwitch(orig_op):
          # pylint: disable=protected-access
          controls = gen_control_flow_ops._ref_merge(controls)
          # pylint: enable=protected-access
        else:
          controls = control_flow_ops.tuple(controls)

      # For every output, merge the freshly computed value (available after
      # the assigns) with the value read back from the cache variable.
      for var, val in zip(op_output_var, orig_op.outputs):
        with ops.control_dependencies(controls):
          with self.colocate_with(v_f_op):
            real_val = array_ops.identity(val)
        with ops.control_dependencies([v_t_op.op]):
          with self.colocate_with(v_t_op):
            stored_val = array_ops.identity(var)
          stored_val.set_shape(val.get_shape())
          real_val, _ = control_flow_ops.merge([real_val, stored_val])
        # Tag the merge op with this graph's merge gradient op type.
        real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(s=compat.as_bytes(self._merge_op_type)))
        values.append(real_val)

      # Restore the original op's static shapes on the merged outputs.
      for i, _ in enumerate(shapes):
        values[i].set_shape(shapes[i])
      self._outputs_map[orig_op.name] = values
      # Record the op's gradient function when one can be looked up, and only
      # then retag the op with the imperative gradient op type. A failed
      # lookup leaves the op untouched.
      try:
        self._gradient_function_map[orig_op.name] = ops.get_gradient_function(
            orig_op)
      except (KeyError, LookupError):
        pass
      else:
        orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(
                s=compat.as_bytes(self._imperative_op_type)))

      return MultiOutputOperation(values, orig_op)
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None):
    """Creates a new variable from arguments.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions  from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      expected_shape: Deprecated. Ignored.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    _ = expected_shape
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.control_dependencies(None):
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:

        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          true_name = ops._name_from_scope_name(name)
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          # pylint: disable=protected-access
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"),  ops.device(None):
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
              shape = (self._initial_value.get_shape()
                       if validate_shape else tensor_shape.unknown_shape())
            self._variable = state_ops.variable_op_v2(
                shape,
                self._initial_value.dtype.base_dtype,
                name=name)

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          shape = (self._initial_value.get_shape()
                   if validate_shape else tensor_shape.unknown_shape())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = state_ops.variable_op_v2(
              shape,
              self._initial_value.dtype.base_dtype,
              name=name)

        # Manually overrides the variable's shape with the initial value's.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)

        # Assigns initial value.
        self._initializer_op = state_ops.assign(
            self._variable, self._initial_value,
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")

    ops.add_to_collections(collections, self)
    self._caching_device = caching_device
    self._save_slice_info = None
Exemple #6
0
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None,
                      constraint=None):
    _ = expected_shape
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
      self._maybe_initialize_checkpointable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
      # Ensure that we weren't lifted into the eager context.
      if context.executing_eagerly():
        raise RuntimeError(
            "tf.Variable not supported when eager execution is enabled. "
            "Please use tf.contrib.eager.Variable instead")
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:

        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          true_name = ops._name_from_scope_name(name)  # pylint: disable=protected-access
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          # pylint: disable=protected-access
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
              shape = (self._initial_value.get_shape()
                       if validate_shape else tensor_shape.unknown_shape())
            self._variable = state_ops.variable_op_v2(
                shape,
                self._initial_value.dtype.base_dtype,
                name=name)
          # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if self._initial_value.op._get_control_flow_context() is not None:
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          shape = (self._initial_value.get_shape()
                   if validate_shape else tensor_shape.unknown_shape())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = state_ops.variable_op_v2(
              shape,
              self._initial_value.dtype.base_dtype,
              name=name)

        # Manually overrides the variable's shape with the initial value's.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)

        # If 'initial_value' makes use of other variables, make sure we don't
        # have an issue if these other variables aren't initialized first by
        # using their initialized_value() method.
        self._initializer_op = state_ops.assign(
            self._variable,
            self._try_guard_against_uninitialized_dependencies(
                self._initial_value),
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")
      ops.add_to_collections(collections, self)

    self._caching_device = caching_device
    self._save_slice_info = None
    self._constraint = constraint