Code Example #1
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Create a variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    set_shape: If True, set the shape property of the returned Tensor to
      the shape argument.
    container: An optional string. Defaults to "".
      If non-empty, this variable is placed in the given container.
      Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "".
      If non-empty, this variable is named in the given bucket
      with this shared_name. Otherwise, the node name is used instead.

  Returns:
    A variable tensor.
  """
  ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name,
                                container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  #   wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
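
A minimal usage sketch for variable_op (hedged: it assumes a TF 1.x graph-mode environment in which tensorflow.python.ops.state_ops is importable; the raw variable op holds no value until an assign runs):

# Minimal usage sketch, assuming the TF 1.x graph-mode API.
import tensorflow as tf
from tensorflow.python.ops import state_ops

v = state_ops.variable_op([2, 3], tf.float32, name="my_var")
init = tf.assign(v, tf.zeros([2, 3]))  # the raw op holds no value until assigned
with tf.Session() as sess:
  sess.run(init)
  print(sess.run(v))  # a 2x3 array of zeros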
Code Example #2
 def testUnknown(self):
     tf_val = gen_state_ops._variable(shape=[3, 4, 7],
                                      dtype=tf.float32,
                                      name="tf_val",
                                      container="",
                                      shared_name="")
     self.assertIs(None, tf.contrib.util.constant_value(tf_val))
Code Example #3
File: state_ops.py  Project: brchiu/tensorflow
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Create a variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    set_shape: If True, set the shape property of the returned Tensor to
      the shape argument.
    container: An optional string. Defaults to "".
      If non-empty, this variable is placed in the given container.
      Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "".
      If non-empty, this variable is named in the given bucket
      with this shared_name. Otherwise, the node name is used instead.

  Returns:
    A variable tensor.
  """
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name,
                                container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  #   wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
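
This revision differs from Code Example #1 in one detail: when set_shape is False it registers the op itself with an unknown shape, so static shape inference cannot recover the concrete shape, which is exactly what the testUnknown cases rely on. A hedged sketch of the observable difference, under the same TF 1.x assumptions as above:

import tensorflow as tf
from tensorflow.python.ops import state_ops

v_known = state_ops.variable_op([3, 4], tf.float32, set_shape=True)
v_unknown = state_ops.variable_op([3, 4], tf.float32, set_shape=False)
print(v_known.get_shape())    # (3, 4)
print(v_unknown.get_shape())  # <unknown>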
Code Example #4
 def testUnknown(self):
   tf_val = gen_state_ops._variable(
       shape=[3, 4, 7],
       dtype=dtypes.float32,
       name="tf_val",
       container="",
       shared_name="")
   self.assertIs(None, tensor_util.constant_value(tf_val))
Code Example #5
File: graph_util_test.py  Project: zqsunny/tensorflow
  def testTwoDeviceFunctions(self):
    with ops.Graph().as_default() as g:
      var_0 = gen_state_ops._variable(shape=[1], dtype=dtypes.float32, 
          name="var_0", container="", shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_1 = gen_state_ops._variable(shape=[1], dtype=dtypes.float32, 
            name="var_1", container="", shared_name="")
      var_2 = gen_state_ops._variable(shape=[1], dtype=dtypes.float32, 
          name="var_2", container="", shared_name="")
      var_3 = gen_state_ops._variable(shape=[1], dtype=dtypes.float32, 
          name="var_3", container="", shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_4 = gen_state_ops._variable(shape=[1], dtype=dtypes.float32, 
            name="var_4", container="", shared_name="")
        with g.device("/device:GPU:0"):
          var_5 = gen_state_ops._variable(shape=[1], dtype=dtypes.float32, 
              name="var_5", container="", shared_name="")
        var_6 = gen_state_ops._variable(shape=[1], dtype=dtypes.float32, 
            name="var_6", container="", shared_name="")

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, None)
    self.assertDeviceEqual(var_3.device, None)
    self.assertDeviceEqual(var_4.device, "/device:CPU:0")
    self.assertDeviceEqual(var_5.device, "/device:GPU:0")
    self.assertDeviceEqual(var_6.device, "/device:CPU:0")
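
The helper test_device_func_pin_variable_to_cpu is referenced but not shown. A hypothetical reconstruction consistent with the assertions above (a device function receives each op and returns the device string to pin it to; the real helper lives in graph_util_test.py and may differ):

# Hypothetical reconstruction of the device function used in the test above.
def test_device_func_pin_variable_to_cpu(op):
  # Respect an explicit inner placement such as "/device:GPU:0" (var_5 above).
  if op.device:
    return op.device
  # Pin only stateful Variable ops to the CPU; leave everything else unplaced.
  return "/device:CPU:0" if op.node_def.op == "Variable" else op.device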
Code Example #6
File: state_ops.py  Project: chdinh/tensorflow
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Used variable_op_v2 instead."""
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name,
                                container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  #   wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
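
For reference, the variable_op_v2 replacement that the docstring points to is, in the same era of state_ops.py, a thin wrapper over the VariableV2 kernel, which supports partially defined shapes and therefore drops the set_shape workaround. A sketch, hedged: the signature is reconstructed from the v1 wrapper above, assuming gen_state_ops exposes variable_v2:

def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  # VariableV2 honors partially defined shapes, so no set_shape fixup is needed.
  return gen_state_ops.variable_v2(shape=shape, dtype=dtype, name=name,
                                   container=container, shared_name=shared_name)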
Code Example #7
 def testDecay(self):
   initial_lr = 0.1
   k = 10
   decay_rate = 0.96
   step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
       name="step", container="", shared_name="")
   assign_step = state_ops.assign(step, 0)
   increment_step = state_ops.assign_add(step, 1)
   decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step,
                                                      k, decay_rate)
   with self.test_session():
     assign_step.op.run()
     for i in range(k+1):
       expected = initial_lr * math.exp(-i / k * decay_rate)
       self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
       increment_step.op.run()
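
The expected value inside the loop is the schedule natural_exp_decay computes, decayed_lr = initial_lr * exp(-decay_rate * global_step / decay_steps), with decay_steps = k here. A plain-Python restatement (not the TF implementation):

import math

def natural_exp_decay(initial_lr, global_step, decay_steps, decay_rate):
  return initial_lr * math.exp(-decay_rate * global_step / decay_steps)

print(natural_exp_decay(0.1, 0, 10, 0.96))   # 0.1 at step 0
print(natural_exp_decay(0.1, 10, 10, 0.96))  # ~0.0383 after decay_steps steps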
Code Example #8
 def testAverageVariablesDeviceAssignment(self):
   with tf.device("/job:dev_v0"):
     v0 = tf.Variable(10.0, name="v0")
   with tf.device("/job:dev_v1"):
     v1 = gen_state_ops._variable(shape=[1], dtype=tf.float32, 
         name="v1", container="", shared_name="")
     v1.set_shape([1])
   tensor2 = v0 + v1
   ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
   with tf.device("/job:default"):
     ema.apply([v0, v1, tensor2])
   self.assertDeviceEqual("/job:dev_v0", ema.average(v0).device)
   self.assertDeviceEqual("/job:dev_v1", ema.average(v1).device)
   # However, the colocation property is maintained.
   self.assertEqual([b"loc:@v1"],
                    ema.average(v1).op.colocation_groups())
   self.assertDeviceEqual("/job:default", ema.average(tensor2).device)
Code Example #9
 def testContainer(self):
   with tf.Graph().as_default():
     v0 = tf.Variable([0])
     with tf.container("l1"):
       v1 = tf.Variable([1])
       with tf.container("l2"):
         v2 = tf.Variable([2])
         special_v = gen_state_ops._variable(shape=[1], dtype=tf.float32, 
             name="VariableInL3", container="l3", shared_name="")
       v3 = tf.Variable([3])
     v4 = tf.Variable([4])
   self.assertEqual(tf.compat.as_bytes(""), v0.op.get_attr("container"))
   self.assertEqual(tf.compat.as_bytes("l1"), v1.op.get_attr("container"))
   self.assertEqual(tf.compat.as_bytes("l2"), v2.op.get_attr("container"))
   self.assertEqual(tf.compat.as_bytes("l3"),
                    special_v.op.get_attr("container"))
   self.assertEqual(tf.compat.as_bytes("l1"), v3.op.get_attr("container"))
   self.assertEqual(tf.compat.as_bytes(""), v4.op.get_attr("container"))
Code Example #10
 def testStaircase(self):
   with self.test_session():
     step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
         name="step", container="", shared_name="")
     assign_100 = state_ops.assign(step, 100)
     assign_1 = state_ops.assign(step, 1)
     assign_2 = state_ops.assign(step, 2)
     decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
                                                        staircase=True)
     # No change to learning rate
     assign_1.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     assign_2.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     # Decayed learning rate
     assign_100.op.run()
     expected = .1 * 0.96 ** (100 // 3)
     self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
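
With staircase=True the rate only changes every decay_steps steps: decayed_lr = initial_lr * decay_rate ** floor(global_step / decay_steps). A plain-Python restatement matching the assertions above (not the TF implementation):

def exponential_decay_staircase(initial_lr, global_step, decay_steps, decay_rate):
  return initial_lr * decay_rate ** (global_step // decay_steps)

print(exponential_decay_staircase(0.1, 2, 3, 0.96))    # 0.1 (2 // 3 == 0)
print(exponential_decay_staircase(0.1, 100, 3, 0.96))  # 0.1 * 0.96 ** 33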
Code Example #11
 def testStaircase(self):
   initial_lr = 0.1
   k = 10
   decay_rate = 0.96
   step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
       name="step", container="", shared_name="")
   assign_step = state_ops.assign(step, 0)
   increment_step = state_ops.assign_add(step, 1)
   decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr,
                                                       step,
                                                       k,
                                                       decay_rate,
                                                       staircase=True)
   with self.test_session():
     assign_step.op.run()
     for i in range(k+1):
       expected = initial_lr / (1 + decay_rate * (i // k))
       self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
       increment_step.op.run()
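
Despite the name testStaircase, this case exercises inverse_time_decay, whose staircased form is decayed_lr = initial_lr / (1 + decay_rate * floor(global_step / decay_steps)). A plain-Python restatement (not the TF implementation):

def inverse_time_decay_staircase(initial_lr, global_step, decay_steps, decay_rate):
  return initial_lr / (1 + decay_rate * (global_step // decay_steps))

print(inverse_time_decay_staircase(0.1, 9, 10, 0.96))   # 0.1, still in the first step
print(inverse_time_decay_staircase(0.1, 10, 10, 0.96))  # 0.1 / 1.96, about 0.051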
Code Example #12
    def testSGDR(self):
        k = 100
        initial_lr = 0.5
        t_0 = 2
        mult_factor = 12
        step = gen_state_ops._variable(shape=[],
                                       dtype=dtypes.int32,
                                       name="step",
                                       container="",
                                       shared_name="")
        assign_step = state_ops.assign(step, 0)
        increment_step = state_ops.assign_add(step, 1)
        sgdr_lr = sgdr_decay(initial_lr, step, t_0, mult_factor)

        with self.test_session():
            assign_step.op.run()
            for i in range(k + 1):
                lr = sgdr_lr.eval()
                print(lr)
                increment_step.op.run()
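
sgdr_decay implements SGDR (stochastic gradient descent with warm restarts, Loshchilov & Hutter): cosine annealing within a restart period of initial length t_0, each period mult_factor times longer than the last. A plain-Python sketch of the schedule the loop above prints (hedged: the exact restart bookkeeping in sgdr_decay may differ):

import math

def sgdr(initial_lr, global_step, t_0, mult_factor):
  t_i, t_cur = t_0, global_step
  while t_cur >= t_i:      # walk forward to the current restart period
    t_cur -= t_i
    t_i *= mult_factor
  return 0.5 * initial_lr * (1.0 + math.cos(math.pi * t_cur / t_i))

for step in range(5):
  print(step, sgdr(0.5, step, t_0=2, mult_factor=12))  # restarts to 0.5 at step 2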
Code Example #13
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None):
    """Creates a new variable from arguments.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).
      expected_shape: A TensorShape. If set, initial_value is expected
        to have this shape.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)
    if init_from_fn and dtype is None:
      raise ValueError(
          "dtype must also be specified when initial_value is callable.")

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    expected_shape = tensor_shape.as_shape(expected_shape)
    with ops.control_dependencies(None):
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:

        # Get the initial value from a callable function. The real shape of the
        # variable will be set later, since under the init_from_fn case, the
        # shape won't be known until after the function is invoked.
        #
        # NOTE: The current Variable OpKernel does not support
        # partially defined shapes, so we only set the shape if it is
        # fully defined. For historical reasons, we use the scalar
        # shape (`[]`) to represent an unknown or partially known
        # shape. A future version of the Variable ops will remove this
        # limitation.
        def full_shape_to_list(shape):
          """Returns shape as a list if shape is fully defined."""
          if shape and shape.is_fully_defined():
            return shape.as_list()
          else:
            return []

        def assert_expected_shape():
          """Asserts that the initial value has the expected shape."""
          if expected_shape:
            expected_shape.assert_is_compatible_with(
                self._initial_value.get_shape())

        if init_from_fn:
          expected_shape_list = full_shape_to_list(expected_shape)
          set_shape = validate_shape and expected_shape.is_fully_defined()
          self._variable = gen_state_ops._variable(
              shape=expected_shape_list, 
              dtype=dtype.base_dtype, 
              name=name, 
              container="", 
              shared_name="")
          if set_shape:
            self._variable.set_shape(expected_shape_list)
          with ops.colocate_with(self._variable.op):
            with ops.name_scope("Initializer"):
              # Colocate the tensors created by the initial_value() function
              # with the variable itself.
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
              assert_expected_shape()

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          assert_expected_shape()
          set_shape = (validate_shape
                       and self._initial_value.get_shape().is_fully_defined())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = gen_state_ops._variable(
              shape=full_shape_to_list(self._initial_value.get_shape()),
              dtype=self._initial_value.dtype.base_dtype,
              name=name, 
              container="", 
              shared_name="")
          if set_shape:
            self._variable.set_shape(
                full_shape_to_list(self._initial_value.get_shape()))
        # Manually overrides the variable's shape with the initial value's.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)
          self._variable.set_shape(initial_value_shape)
          # TODO(b/28152992): Remove the below hack modifying the node_def shape
          # directly once set_shape() handles it.
          self._variable.op.node_def.attr["shape"].shape.CopyFrom(
              initial_value_shape.as_proto())

        # Assigns initial value.
        self._initializer_op = state_ops.assign(
            self._variable, self._initial_value,
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")

    ops.add_to_collections(collections, self)
    self._caching_device = caching_device
    self._save_slice_info = None
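
The two branches above correspond to the two ways a Variable can be constructed: from a Tensor (or convertible object), where dtype and shape are inferred after conversion, and from a zero-argument callable, where dtype must be supplied and the shape is only known once the callable runs. A minimal sketch of both paths (assuming the TF 1.x API):

import tensorflow as tf

# Path 1: initial_value is a Tensor-convertible object; shape and dtype are
# inferred from it, so the variable op is created after the conversion.
v_tensor = tf.Variable([1.0, 2.0], name="from_tensor")

# Path 2: initial_value is a zero-argument callable; dtype must be given, and
# the shape is only known once the callable has been invoked.
v_callable = tf.Variable(lambda: tf.ones([2, 3]), dtype=tf.float32,
                         name="from_callable")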
Code Example #14
File: variables.py  Project: DavidNemeskey/tensorflow
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None):
    """Creates a new variable from arguments.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).
      expected_shape: A TensorShape. If set, initial_value is expected
        to have this shape.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)
    if init_from_fn and dtype is None:
      raise ValueError(
          "dtype must also be specified when initial_value is callable.")

    if collections is None:
      collections = [ops.GraphKeys.VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    expected_shape = tensor_shape.as_shape(expected_shape)
    with ops.control_dependencies(None):
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:

        # Get the initial value from a callable function. The real shape of the
        # variable will be set later, since under the init_from_fn case, the
        # shape won't be known until after the function is invoked.
        #
        # NOTE: The current Variable OpKernel does not support
        # partially defined shapes, so we only set the shape if it is
        # fully defined. For historical reasons, we use the scalar
        # shape (`[]`) to represent an unknown or partially known
        # shape. A future version of the Variable ops will remove this
        # limitation.
        def full_shape_to_list(shape):
          """Returns shape as a list if shape is fully defined."""
          if shape and shape.is_fully_defined():
            return shape.as_list()
          else:
            return []

        def assert_expected_shape():
          """Asserts that the initial value has the expected shape."""
          if expected_shape:
            expected_shape.assert_is_compatible_with(
                self._initial_value.get_shape())

        if init_from_fn:
          expected_shape_list = full_shape_to_list(expected_shape)
          set_shape = validate_shape and expected_shape.is_fully_defined()
          self._variable = gen_state_ops._variable(
              shape=expected_shape_list, 
              dtype=dtype.base_dtype, 
              name=name, 
              container="", 
              shared_name="")
          if set_shape:
            self._variable.set_shape(expected_shape_list)
          with ops.colocate_with(self._variable.op):
            with ops.name_scope("Initializer"):
              # Colocate the tensors created by the initial_value() function
              # with the variable itself.
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
              assert_expected_shape()

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          assert_expected_shape()
          set_shape = (validate_shape
                       and self._initial_value.get_shape().is_fully_defined())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = gen_state_ops._variable(
              shape=full_shape_to_list(self._initial_value.get_shape()),
              dtype=self._initial_value.dtype.base_dtype,
              name=name, 
              container="", 
              shared_name="")
          if set_shape:
            self._variable.set_shape(
                full_shape_to_list(self._initial_value.get_shape()))
        # Manually overrides the variable's shape with the initial value's.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)
          self._variable.set_shape(initial_value_shape)
          # TODO(b/28152992): Remove the below hack modifying the node_def shape
          # directly once set_shape() handles it.
          self._variable.op.node_def.attr["shape"].shape.CopyFrom(
              initial_value_shape.as_proto())

        # Assigns initial value.
        self._initializer_op = state_ops.assign(
            self._variable, self._initial_value,
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")

    ops.add_to_collections(collections, self)
    self._caching_device = caching_device
    self._save_slice_info = None