Code example #1
    def testInitializeFromValue(self):
        with self.test_session() as sess:
            init = constant_op.constant(0.1)
            w = variable_scope.get_variable("v", initializer=init)
            sess.run(variables_lib.initialize_variables([w]))
            self.assertAllClose(w.eval(), 0.1)

            with self.assertRaisesRegexp(ValueError, "shape"):
                # We disallow explicit shape specification when initializer is constant.
                variable_scope.get_variable("u", [1], initializer=init)

            with variable_scope.variable_scope("foo", initializer=init):
                # Constant initializer can be passed through scopes if needed.
                v = variable_scope.get_variable("v")
                sess.run(variables_lib.initialize_variables([v]))
                self.assertAllClose(v.eval(), 0.1)

            # Check that non-float32 initializer creates a non-float32 variable.
            init = constant_op.constant(1, dtype=dtypes.int32)
            t = variable_scope.get_variable("t", initializer=init)
            self.assertEqual(t.dtype.base_dtype, dtypes.int32)

            # Raise error if `initializer` dtype and `dtype` are not identical.
            with self.assertRaisesRegexp(ValueError, "don't match"):
                variable_scope.get_variable("s",
                                            initializer=init,
                                            dtype=dtypes.float64)
Code example #2
  def testInitializeFromValue(self):
    with self.test_session() as sess:
      init = constant_op.constant(0.1)
      w = variable_scope.get_variable("v", initializer=init)
      sess.run(variables_lib.initialize_variables([w]))
      self.assertAllClose(w.eval(), 0.1)

      with self.assertRaisesRegexp(ValueError, "shape"):
        # We disallow explicit shape specification when initializer is constant.
        variable_scope.get_variable("u", [1], initializer=init)

      with variable_scope.variable_scope("foo", initializer=init):
        # Constant initializer can be passed through scopes if needed.
        v = variable_scope.get_variable("v")
        sess.run(variables_lib.initialize_variables([v]))
        self.assertAllClose(v.eval(), 0.1)

      # Check that non-float32 initializer creates a non-float32 variable.
      init = constant_op.constant(1, dtype=dtypes.int32)
      t = variable_scope.get_variable("t", initializer=init)
      self.assertEqual(t.dtype.base_dtype, dtypes.int32)

      # Raise error if `initializer` dtype and `dtype` are not identical.
      with self.assertRaisesRegexp(ValueError, "don't match"):
        variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
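
A minimal standalone sketch of the rule testInitializeFromValue exercises, assuming TensorFlow 1.x graph mode and using the public tf.* aliases in place of the internal modules the test imports: when initializer is a Tensor, the variable's shape and dtype are inferred from that tensor, so an explicit shape or a mismatched dtype raises ValueError.

    import tensorflow as tf

    with tf.Graph().as_default(), tf.Session() as sess:
        init = tf.constant(0.1)
        # Shape and dtype come from the initializer tensor.
        w = tf.get_variable("w", initializer=init)
        sess.run(tf.variables_initializer([w]))
        print(sess.run(w))  # 0.1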
Code example #3
 def testVarScopeInitializer(self):
   with self.test_session() as sess:
     init = init_ops.constant_initializer(0.3)
     with variable_scope.variable_scope("tower") as tower:
       with variable_scope.variable_scope("foo", initializer=init):
         v = variable_scope.get_variable("v", [])
         sess.run(variables_lib.initialize_variables([v]))
         self.assertAllClose(v.eval(), 0.3)
       with variable_scope.variable_scope(tower, initializer=init):
         w = variable_scope.get_variable("w", [])
         sess.run(variables_lib.initialize_variables([w]))
         self.assertAllClose(w.eval(), 0.3)
Code example #4
 def testVarScopeInitializer(self):
     with self.test_session() as sess:
         init = init_ops.constant_initializer(0.3)
         with variable_scope.variable_scope("tower") as tower:
             with variable_scope.variable_scope("foo", initializer=init):
                 v = variable_scope.get_variable("v", [])
                 sess.run(variables_lib.initialize_variables([v]))
                 self.assertAllClose(v.eval(), 0.3)
             with variable_scope.variable_scope(tower, initializer=init):
                 w = variable_scope.get_variable("w", [])
                 sess.run(variables_lib.initialize_variables([w]))
                 self.assertAllClose(w.eval(), 0.3)
Code example #5
 def test_local_variable(self):
   with self.test_session() as sess:
      self.assertEqual([], variables_lib.local_variables())
     value0 = 42
     variables_lib2.local_variable(value0)
     value1 = 43
     variables_lib2.local_variable(value1)
     variables = variables_lib.local_variables()
      self.assertEqual(2, len(variables))
     self.assertRaises(errors_impl.OpError, sess.run, variables)
     variables_lib.initialize_variables(variables).run()
     self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
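
Note that initialize_variables was deprecated in TensorFlow 1.x in favor of variables_initializer; both build a single op that runs the initializers of the listed variables. A minimal sketch of the pattern, assuming TF 1.x graph mode:

    import tensorflow as tf

    with tf.Graph().as_default(), tf.Session() as sess:
        v = tf.Variable(42, name="v")
        w = tf.Variable(43, name="w")
        # variables_initializer is the renamed initialize_variables.
        sess.run(tf.variables_initializer([v, w]))
        print(sess.run([v, w]))  # [42, 43]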
Code example #6
  def testInitFromNonTensorValue(self):
    with self.test_session() as sess:
      v = variable_scope.get_variable("v", initializer=4, dtype=dtypes.int32)
      sess.run(variables_lib.initialize_variables([v]))
      self.assertAllClose(v.eval(), 4)

      w = variable_scope.get_variable(
          "w", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
      sess.run(variables_lib.initialize_variables([w]))
      self.assertAllClose(w.eval(), [1, 2, 3])

      with self.assertRaises(TypeError):
        variable_scope.get_variable("x", initializer={})
Code example #7
  def testConvertVariablesToConstsWithFunctions(self):
    @function.Defun(dtypes.float32)
    def plus_one(x):
      return x + 1.0

    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      _ = variables.Variable(1.0, name="unused_variable_node")
      defun_node = plus_one(variable_node)
      output_node = math_ops_lib.multiply(
          defun_node, 2.0, name="output_node")

      with session.Session() as sess:
        init = variables.initialize_variables([variable_node])
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(4.0, output, 0.00001)
        variable_graph_def = sess.graph.as_graph_def()

        # First get the constant_graph_def with variable_names_whitelist set;
        # note that if variable_names_whitelist is not set, an error will be
        # thrown because unused_variable_node is not initialized.
        constant_graph_def = graph_util.convert_variables_to_constants(
            sess,
            variable_graph_def, ["output_node"],
            variable_names_whitelist=set(["variable_node"]))

        self.assertEqual(variable_graph_def.library,
                         constant_graph_def.library)
Code example #8
  def testConvertVariablesToConsts(self):
    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      _ = variables.Variable(1.0, name="unused_variable_node")
      output_node = math_ops_lib.multiply(
          variable_node, 2.0, name="output_node")
      with session.Session() as sess:
        init = variables.initialize_variables([variable_node])
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        variable_graph_def = sess.graph.as_graph_def()
        # First get the constant_graph_def with variable_names_whitelist set;
        # note that if variable_names_whitelist is not set, an error will be
        # thrown because unused_variable_node is not initialized.
        constant_graph_def = graph_util.convert_variables_to_constants(
            sess,
            variable_graph_def, ["output_node"],
            variable_names_whitelist=set(["variable_node"]))

        # Then initialize the unused variable, and get another
        # constant_graph_def when variable_names_whitelist is not set.
        sess.run(variables.global_variables_initializer())
        constant_graph_def_without_variable_whitelist = (
            graph_util.convert_variables_to_constants(sess, variable_graph_def,
                                                      ["output_node"]))

        # The unused variable should be cleared so the two graphs should be
        # equivalent.
        self.assertEqual(
            str(constant_graph_def),
            str(constant_graph_def_without_variable_whitelist))

        # Test the variable name blacklist. This should result in the variable
        # not being converted to a const.
        sess.run(variables.global_variables_initializer())
        constant_graph_def_with_blacklist = (
            graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def, ["output_node"],
                variable_names_blacklist=set(["variable_node"])))
        variable_node = None
        for node in constant_graph_def_with_blacklist.node:
          if node.name == "variable_node":
            variable_node = node
        self.assertIsNotNone(variable_node)
        self.assertEqual(variable_node.op, "VariableV2")

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def, name="")
      self.assertEqual(4, len(constant_graph_def.node))
      for node in constant_graph_def.node:
        self.assertNotEqual("Variable", node.op)
        self.assertNotEqual("VariableV2", node.op)
      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
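
A minimal sketch of the freezing step these tests build on, assuming TensorFlow 1.x: graph_util.convert_variables_to_constants rewrites every variable reachable from the named output nodes into a Const node holding its current value.

    import tensorflow as tf
    from tensorflow.python.framework import graph_util

    with tf.Graph().as_default():
        v = tf.Variable(1.0, name="v")
        out = tf.multiply(v, 2.0, name="out")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Returns a GraphDef in which node "v" is a Const.
            frozen = graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(), ["out"])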
Code example #9
    def testVarScopeRegularizer(self):
        with self.test_session() as sess:
            init = init_ops.constant_initializer(0.3)

            def regularizer1(v):
                return math_ops.reduce_mean(v) + 0.1

            def regularizer2(v):
                return math_ops.reduce_mean(v) + 0.2

            with variable_scope.variable_scope(
                    "tower", regularizer=regularizer1) as tower:
                with variable_scope.variable_scope("foo", initializer=init):
                    v = variable_scope.get_variable("v", [])
                    sess.run(variables_lib.initialize_variables([v]))
                    losses = ops.get_collection(
                        ops.GraphKeys.REGULARIZATION_LOSSES)
                    self.assertEqual(1, len(losses))
                    self.assertAllClose(losses[0].eval(), 0.4)
                with variable_scope.variable_scope(tower,
                                                   initializer=init) as vs:
                    u = variable_scope.get_variable("u", [])
                    vs.set_regularizer(regularizer2)
                    w = variable_scope.get_variable("w", [])
                    # Next 3 variables are not regularized, to test disabling regularization.
                    x = variable_scope.get_variable(
                        "x", [], regularizer=variable_scope.no_regularizer)
                    with variable_scope.variable_scope(
                            "baz", regularizer=variable_scope.no_regularizer):
                        y = variable_scope.get_variable("y", [])
                    vs.set_regularizer(variable_scope.no_regularizer)
                    z = variable_scope.get_variable("z", [])
                    # Check results.
                    losses = ops.get_collection(
                        ops.GraphKeys.REGULARIZATION_LOSSES)
                    self.assertEqual(3, len(losses))
                    sess.run(
                        variables_lib.initialize_variables([u, w, x, y, z]))
                    self.assertAllClose(losses[0].eval(), 0.4)
                    self.assertAllClose(losses[1].eval(), 0.4)
                    self.assertAllClose(losses[2].eval(), 0.5)
                with variable_scope.variable_scope("foo", reuse=True):
                    v = variable_scope.get_variable(
                        "v", [])  # "v" is already there, reused
                    losses = ops.get_collection(
                        ops.GraphKeys.REGULARIZATION_LOSSES)
                    self.assertEqual(3, len(losses))  # No new loss added.
Code example #10
 def __setitem__(self, index, value):
     for use_gpu in [False, True]:
         with self.test.test_session(use_gpu=use_gpu) as sess:
             var = variables.Variable(self.x)
             sess.run(variables.initialize_variables([var]))
             val = sess.run(var[index].assign(
                 constant_op.constant(value, dtype=self.tensor_type)))
             valnp = np.copy(self.x_np)
             valnp[index] = np.array(value)
             self.test.assertAllEqual(val, valnp)
Code example #11
  def testVarScopeRegularizer(self):
    with self.test_session() as sess:
      init = init_ops.constant_initializer(0.3)

      def regularizer1(v):
        return math_ops.reduce_mean(v) + 0.1

      def regularizer2(v):
        return math_ops.reduce_mean(v) + 0.2

      with variable_scope.variable_scope(
          "tower", regularizer=regularizer1) as tower:
        with variable_scope.variable_scope("foo", initializer=init):
          v = variable_scope.get_variable("v", [])
          sess.run(variables_lib.initialize_variables([v]))
          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
          self.assertEqual(1, len(losses))
          self.assertAllClose(losses[0].eval(), 0.4)
        with variable_scope.variable_scope(tower, initializer=init) as vs:
          u = variable_scope.get_variable("u", [])
          vs.set_regularizer(regularizer2)
          w = variable_scope.get_variable("w", [])
          # Next 3 variables are not regularized, to test disabling regularization.
          x = variable_scope.get_variable(
              "x", [], regularizer=variable_scope.no_regularizer)
          with variable_scope.variable_scope(
              "baz", regularizer=variable_scope.no_regularizer):
            y = variable_scope.get_variable("y", [])
          vs.set_regularizer(variable_scope.no_regularizer)
          z = variable_scope.get_variable("z", [])
          # Check results.
          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
          self.assertEqual(3, len(losses))
          sess.run(variables_lib.initialize_variables([u, w, x, y, z]))
          self.assertAllClose(losses[0].eval(), 0.4)
          self.assertAllClose(losses[1].eval(), 0.4)
          self.assertAllClose(losses[2].eval(), 0.5)
        with variable_scope.variable_scope("foo", reuse=True):
          v = variable_scope.get_variable("v",
                                          [])  # "v" is already there, reused
          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
          self.assertEqual(3, len(losses))  # No new loss added.
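
A minimal sketch of the regularizer mechanics this test exercises, assuming TensorFlow 1.x: a regularizer passed to get_variable (directly or inherited from the enclosing variable_scope) is applied to the new variable, and the resulting tensor is added to the REGULARIZATION_LOSSES collection.

    import tensorflow as tf

    with tf.Graph().as_default(), tf.Session() as sess:
        v = tf.get_variable("v", [],
                            initializer=tf.constant_initializer(0.3),
                            regularizer=lambda t: tf.reduce_mean(t) + 0.1)
        sess.run(tf.variables_initializer([v]))
        losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        print(sess.run(losses[0]))  # 0.3 + 0.1 = 0.4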
Code example #12
File: array_ops_test.py  Project: Immexxx/tensorflow
 def __setitem__(self, index, value):
   for use_gpu in [False, True]:
     with self.test.test_session(use_gpu=use_gpu) as sess:
       var = variables.Variable(self.x)
       sess.run(variables.initialize_variables([var]))
       val = sess.run(var[index].assign(
           constant_op.constant(
               value, dtype=self.tensor_type)))
       valnp = np.copy(self.x_np)
       valnp[index] = np.array(value)
       self.test.assertAllEqual(val, valnp)
Code example #13
File: array_ops_test.py  Project: zmmqq00/tensorflow
    def __setitem__(self, index, value):
        value = np.array(value).astype(self.tensor_type.as_numpy_dtype)
        # Give the value a non-zero imaginary component for complex types.
        if self.tensor_type.is_complex:
            value -= 1j * value

        with self.test.test_session(use_gpu=True) as sess:
            var = variables.Variable(self.x)
            sess.run(variables.initialize_variables([var]))
            val = sess.run(var[index].assign(
                constant_op.constant(value, dtype=self.tensor_type)))
            valnp = np.copy(self.x_np)
            valnp[index] = np.array(value)
            self.test.assertAllEqual(val, valnp)
Code example #14
File: array_ops_test.py  Project: hailingc/tensorflow
  def __setitem__(self, index, value):
    value = np.array(value).astype(self.tensor_type.as_numpy_dtype)
    # Give the value a non-zero imaginary component for complex types.
    if self.tensor_type.is_complex:
      value -= 1j * value

    with self.test.test_session(use_gpu=True) as sess:
      var = variables.Variable(self.x)
      sess.run(variables.initialize_variables([var]))
      val = sess.run(var[index].assign(
          constant_op.constant(
              value, dtype=self.tensor_type)))
      valnp = np.copy(self.x_np)
      valnp[index] = np.array(value)
      self.test.assertAllEqual(val, valnp)
Code example #15
  def __setitem__(self, index, value):
    value = np.array(value).astype(self.tensor_type.as_numpy_dtype)
    # Give the value a non-zero imaginary component for complex types.
    if self.tensor_type.is_complex:
      value -= 1j * value

    with self.test.test_session(use_gpu=True) as sess:
      var = variables.Variable(self.x)
      sess.run(variables.initialize_variables([var]))
      val = sess.run(var[index].assign(value))
      # val_copy is used to check that tf.assign works equivalently to the
      # assign method above.
      val_copy = sess.run(state_ops.assign(var[index], value))
      valnp = np.copy(self.x_np)
      valnp[index] = np.array(value)
      self.test.assertAllEqual(val, valnp)
      self.test.assertAllEqual(val_copy, valnp)
Code example #16
 def testGetVariableScope(self):
   # Test the get_variable_scope() function and setting properties of result.
   with self.test_session() as sess:
     init = init_ops.constant_initializer(0.3)
     with variable_scope.variable_scope("foo"):
       new_init1 = variable_scope.get_variable_scope().initializer
       self.assertEqual(new_init1, None)
       # Check that we can set initializer like this.
       variable_scope.get_variable_scope().set_initializer(init)
       v = variable_scope.get_variable("v", [])
       sess.run(variables_lib.initialize_variables([v]))
       self.assertAllClose(v.eval(), 0.3)
       # Check that we can set reuse.
       variable_scope.get_variable_scope().reuse_variables()
       with self.assertRaises(ValueError):  # Fail, w does not exist yet.
         variable_scope.get_variable("w", [1])
     # Check that the set initializer goes away.
     new_init = variable_scope.get_variable_scope().initializer
     self.assertEqual(new_init, None)
Code example #17
File: thin_stack_test.py  Project: hans/tensorflow
  def _testIntegrated(self, batch_size, model_dim, num_timesteps, ff_fun, sim_fun):
    """
    Test the simplest possible transition sequence on a batch of random inputs.
    """

    tf.reset_default_graph()

    embedding_dim = model_dim
    num_tokens = (num_timesteps + 1) // 2  # floor division; used as a shape dimension

    with self.test_session(use_gpu=self.use_gpu) as s:
      stack = Variable(np.zeros((batch_size * num_timesteps, model_dim), dtype=np.float32), name="stack")
      buffer = Variable(np.random.random((batch_size * num_tokens, embedding_dim)).astype(np.float32), name="buffer")
      queue = Variable(np.zeros((batch_size * num_timesteps,), dtype=np.float32), name="queue")
      cursors = Variable(np.zeros((batch_size,), dtype=np.float32) - 1., name="cursors")
      buffer_cursors = Variable(np.zeros((batch_size,), dtype=np.float32), name="buffer_cursors")

      ######## Fprop test.
      top = ff_fun(batch_size, stack, buffer, queue, cursors, buffer_cursors)
      top_sim = sim_fun(buffer)

      s.run(initialize_variables(tf.all_variables()))

      ######## Bprop test.
      # Get some scalar error signal for grad calculation
      top, top_sim = tf.reduce_sum(top), tf.reduce_sum(top_sim)
      with tf.control_dependencies([top]):
        grad = tf.gradients(top, buffer)[0]
      grad_sim = tf.gradients(top_sim, buffer)[0]

      ######## Run fetches.
      ret = s.run([top, top_sim, grad, grad_sim])
      top_, top_sim_, grad_, grad_sim_ = ret[:4]

    self.assertAllClose(top_, top_sim_)
    self.assertAllClose(grad_, grad_sim_)
Code example #18
  def testIntermediateLookupGrad(self):
    """
    Test the gradient of a standard lookup somewhere in the middle of a stack
    recurrence.
    """

    batch_size = 2
    model_dim = 5
    embedding_dim = 5
    num_timesteps = 5

    num_tokens = (num_timesteps + 1) // 2

    with self.test_session(use_gpu=self.use_gpu) as s:
      # Example 1: S S R S
      # Example 2: S S S R
      #                  ^
      # we are running lookup at the above timestep
      stack = Variable([[-1., -1., -1., -1., -1.],
                        [ 1.,  1.,  1.,  1.,  1.],
                        [-2., -2., -2., -2., -2.],
                        [ 2.,  2.,  2.,  2.,  2.],
                        [-3., -3., -3., -3., -3.],
                        [ 3.,  3.,  3.,  3.,  3.],
                        [ 0.,  0.,  0.,  0.,  0.],
                        [ 0.,  0.,  0.,  0.,  0.],
                        [ 0.,  0.,  0.,  0.,  0.],
                        [ 0.,  0.,  0.,  0.,  0.]])
      buffer = Variable([[-1., -1., -1., -1., -1.],
                         [ 1.,  1.,  1.,  1.,  1.],
                         [-2., -2., -2., -2., -2.],
                         [ 2.,  2.,  2.,  2.,  2.],
                         [-3., -3., -3., -3., -3.],
                         [ 3.,  3.,  3.,  3.,  3.]])
      queue = Variable([2., 0.,
                        0., 1.,
                        0., 2.,
                        0., 0.,
                        0., 0.])
      cursors = Variable([0., 2.])
      buffer_cursors = Variable([2., 3.])

      s.run(initialize_variables([stack, buffer, queue, cursors, buffer_cursors]))

      stack_val = stack.eval()
      buffer_val = buffer.eval()

      lookup = ts.thin_stack_lookup(stack, buffer, queue, cursors, buffer_cursors, timestep=3)

      #### GRADIENT

      stack1_grad = tf.random_uniform((batch_size, model_dim))
      stack2_grad = tf.random_uniform((batch_size, model_dim))
      buf_top_grad = tf.random_uniform((batch_size, model_dim))
      in_grads = (stack1_grad, stack2_grad, buf_top_grad, None)

      # HACK: Zero out stack and buffer before invoking this op.
      # In a real / full bprop, things would have been zeroed out
      # at the start of the bprop algorithm.
      zero_stack = tf.assign(stack, stack * 0.)
      zero_buffer = tf.assign(buffer, buffer * 0.)

      # Enforce computation order: lookup, then zero out, then grad
      with tf.control_dependencies(lookup + (zero_stack, zero_buffer)):
        out_grads = ts._thin_stack_lookup_gradient(lookup[0].op, in_grads)
      out_grads = out_grads[:2]

      fetch = out_grads + (stack1_grad, stack2_grad, buf_top_grad)

      ret = s.run(fetch)

    grad_stack, grad_buffer, stack1_grad, stack2_grad, buf_top_grad = ret

    grad_stack_expected = np.zeros_like(stack_val)
Code example #19
  def testIntermediateUpdate(self):
    """Test a standard update somewhere in the middle of a stack recurrence."""
    batch_size = 2
    model_dim = 5
    embedding_dim = 5
    num_timesteps = 5

    num_tokens = (num_timesteps + 1) // 2

    with self.test_session(use_gpu=self.use_gpu) as s:
      # Example 1: S S R S
      # Example 2: S S S R
      #                  ^
      # we are running the update at the above timestep

      stack = Variable([[-1., -1., -1., -1., -1.],
                        [ 1.,  1.,  1.,  1.,  1.],
                        [-2., -2., -2., -2., -2.],
                        [ 2.,  2.,  2.,  2.,  2.],
                        [-3., -3., -3., -3., -3.],
                        [ 3.,  3.,  3.,  3.,  3.],
                        [ 0.,  0.,  0.,  0.,  0.],
                        [ 0.,  0.,  0.,  0.,  0.],
                        [ 0.,  0.,  0.,  0.,  0.],
                        [ 0.,  0.,  0.,  0.,  0.]])
      buffer = Variable([[-1., -1., -1., -1., -1.],
                         [ 1.,  1.,  1.,  1.,  1.],
                         [-2., -2., -2., -2., -2.],
                         [ 2.,  2.,  2.,  2.,  2.],
                         [-3., -3., -3., -3., -3.],
                         [ 3.,  3.,  3.,  3.,  3.]])
      queue = Variable([2., 0.,
                        0., 1.,
                        0., 2.,
                        0., 0.,
                        0., 0.])
      cursors = Variable([0., 2.])
      buffer_cursors = constant_op.constant([2., 3.])
      t = 3

      s.run(initialize_variables([stack, buffer, queue, cursors]))

      stack_val = stack.eval()
      buffer_val = buffer.eval()

      shift_in = constant_op.constant(np.array([buffer_val[4], buffer_val[5]]))
      reduce_in = constant_op.constant(np.array([stack_val[4] + stack_val[0],
                                                 stack_val[5] + stack_val[3]]))
      transitions = tf.expand_dims(constant_op.constant([0., 1.]), 1)
      input_val = transitions * reduce_in + (1. - transitions) * shift_in

      ret = ts.thin_stack_update(input_val, transitions,
                                 stack, queue, cursors, buffer_cursors, t)
      stack_next, queue_next, cursors_next, buffer_cursors_next = s.run(ret)

    stack_expected = np.copy(stack_val)
    stack_expected[6] = buffer_val[4]
    stack_expected[7] = stack_val[5] + stack_val[3]

    queue_expected = np.array([2., 0.,
                               3., 3.,
                               0., 2., # NB: we didn't erase this, but it's okay
                               0., 0.,
                               0., 0.])
    cursors_expected = np.array([1., 1.])
    buffer_cursors_expected = np.array([3., 3.])

    self.assertAllEqual(stack_next, stack_expected)
    self.assertAllEqual(queue_next, queue_expected)
    self.assertAllEqual(cursors_next, cursors_expected)
    self.assertAllEqual(buffer_cursors_next, buffer_cursors_expected)