Code example #1
  def testStackEmptyList(self, max_num_elements):
    # Should be able to stack empty lists with fully defined element_shape.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[1, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t).shape, (0, 1, 2))

    # Should not be able to stack empty lists with partially defined
    # element_shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[-1, 2],
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should not be able to stack empty lists with undefined element_shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=-1,
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
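
A minimal sketch of the same rule via the public tf.raw_ops surface (an assumption: TF 2.x eager mode, with tf.raw_ops mirroring the internal list_ops used in these tests; the shapes and values are illustrative):

import tensorflow as tf

# Stacking an empty list works only because element_shape [1, 2] is fully
# defined; with element_shape [-1, 2] or -1 the TensorListStack op raises
# InvalidArgumentError ("non-fully-defined"), as the test above asserts.
l = tf.raw_ops.EmptyTensorList(
    element_shape=tf.constant([1, 2], dtype=tf.int32),
    max_num_elements=-1,
    element_dtype=tf.float32)
t = tf.raw_ops.TensorListStack(
    input_handle=l, element_shape=[1, 2], element_dtype=tf.float32)
assert t.shape == (0, 1, 2)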
Code example #2
File: list_ops_test.py Project: aeverall/tensorflow
  def testGatherEmptyList(self, max_num_elements):
    # Should be able to gather from empty lists with fully defined
    # element_shape.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[1, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
    self.assertAllEqual((0, 1, 2), self.evaluate(t).shape)

    # Should not be able to gather from empty lists with partially defined
    # element_shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[None, 2],
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should not be able to gather from empty lists with undefined
    # element_shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=None,
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
      self.evaluate(t)
Code example #3
File: list_ops_test.py Project: aeverall/tensorflow
 def testConcatEmptyListWithFullyDefinedElementShape(self):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=[5, 2])
   t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(t).shape, (0, 2))
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=[None, 2])
   t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(t).shape, (0, 2))
Code example #4
File: while_v2.py Project: becster/tensorflow
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could have
  # None incoming gradients for the TensorLists. If we pass None's through, the
  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
  # the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as zero
  # gradient in certain cases. Consider replacing None gradients with Zeros
  # for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
Code example #5
  def testDefaultGradYs(self):
    with ops.Graph().as_default():
      tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      a = constant(1.0)
      tl = list_ops.tensor_list_push_back(tl, a)

      grad_tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))

      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(grad), 5.)
Code example #6
 def testStack(self):
   l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                  element_shape=scalar_shape())
   l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
   l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
   t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
   self.assertAllEqual(t, [1.0, 2.0])
Code example #7
File: while_v2_test.py Project: ThunderQi/tensorflow
  def testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = while_loop_v1(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    def GetOptimizedGraph():
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      rewriter_config = rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      return tf_optimizer.OptimizeGraph(rewriter_config, mg)

    g = GetOptimizedGraph()
    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
Code example #8
File: list_ops_test.py Project: aeverall/tensorflow
 def testConcatWithNonFullyDefinedElementShape(self):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=[None, 2])
   l = list_ops.tensor_list_push_back(l, [[0., 1.]])
   l = list_ops.tensor_list_push_back(l, [[2., 3.], [4., 5.]])
   t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(t), [[0., 1.], [2., 3.], [4., 5.]])
Code example #9
File: while_v2_test.py Project: ThunderQi/tensorflow
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(Cond, Body, [x, tensor_list])

    for op in ops.get_default_graph().get_operations():
      if op.type == "While":
        while_op = op

    body_graph = while_v2._get_body_graph(while_op)
    # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
    x_input_t = body_graph.inputs[1]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])
Code example #10
 def testUnknownShape(self):
   l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                  element_shape=-1)
   l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
   l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0, 2.0]))
   _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
   self.assertAllEqual(e, [1.0, 2.0])
Code example #11
File: list_ops_test.py Project: aeverall/tensorflow
 def testSetOnEmptyListWithMaxNumElementsFails(self):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=[], max_num_elements=3)
   with self.assertRaisesRegexp(
       errors.InvalidArgumentError,
       "Trying to modify element 0 in a list with 0 elements."):
     l = list_ops.tensor_list_set_item(l, 0, 1.)
     self.evaluate(l)
Code example #12
File: list_ops_test.py Project: aeverall/tensorflow
 def testPushInFullListFails(self):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=[], max_num_elements=1)
   l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
   with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                "Tried to push item into a full list"):
     l = list_ops.tensor_list_push_back(l, 2.)
     self.evaluate(l)
Code example #13
File: list_ops_test.py Project: aeverall/tensorflow
      def body(i, m, t1):
        t1 = control_flow_ops.cond(
            math_ops.equal(list_ops.tensor_list_length(t1), 0),
            lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: t1)

        t1 = list_ops.tensor_list_push_back(t1, m * i)
        i += 1.0
        return i, m, t1
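
This fragment is only the loop body; a sketch of the enclosing graph, assuming i, m and t1 are initialized as in the full version of this test (code example #49 below), would be:

i, m, t1 = control_flow_ops.while_loop(
    lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1])
s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32)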
Code example #14
File: list_ops_test.py Project: aeverall/tensorflow
 def _testPushPop(self, max_num_elements):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32,
       element_shape=[],
       max_num_elements=max_num_elements)
   l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
   l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(e), 1.0)
Code example #15
  def testConcat(self):
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
    l_batch_0 = array_ops.stack([l0, l1])
    l_batch_1 = array_ops.stack([l1, l0])

    l_concat_01 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_1, element_dtype=dtypes.float32)
    l_concat_10 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_0, element_dtype=dtypes.float32)
    l_concat_00 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_0, element_dtype=dtypes.float32)
    l_concat_11 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_1, element_dtype=dtypes.float32)

    expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
    expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
    expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
    expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

    for i, (concat, expected) in enumerate(zip(
        [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
        [expected_00, expected_01, expected_10, expected_11])):
      splitted = array_ops.unstack(concat)
      splitted_stacked_ret = self.evaluate(
          (list_ops.tensor_list_stack(splitted[0], dtypes.float32),
           list_ops.tensor_list_stack(splitted[1], dtypes.float32)))
      print("Test concat %d: %s, %s, %s, %s"
            % (i, expected[0], splitted_stacked_ret[0],
               expected[1], splitted_stacked_ret[1]))
      self.assertAllClose(expected[0], splitted_stacked_ret[0])
      self.assertAllClose(expected[1], splitted_stacked_ret[1])

    # Concatenating mismatched shapes fails.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      self.evaluate(
          list_ops.tensor_list_concat_lists(
              l_batch_0,
              list_ops.empty_tensor_list(scalar_shape(), dtypes.float32),
              element_dtype=dtypes.float32))

    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "element shapes are not identical at index 0"):
      l_batch_of_vec_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])] * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_vec_tls,
                                            element_dtype=dtypes.float32))

    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 r"input_b\[0\].dtype != element_dtype."):
      l_batch_of_int_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([1], element_shape=scalar_shape())]
          * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
                                            element_dtype=dtypes.float32))
Code example #16
File: list_ops_test.py Project: aeverall/tensorflow
 def testPopFromEmptyTensorListFails(self, max_num_elements):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32,
       element_shape=[],
       max_num_elements=max_num_elements)
   with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                "Trying to pop from an empty list"):
     l = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     self.evaluate(l)
Code example #17
File: list_ops_test.py Project: aeverall/tensorflow
 def testConcatListWithScalarElementShapeFails(self):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=tensor_shape.scalar())
   with self.assertRaisesRegexp(
       errors.InvalidArgumentError,
       "Concat requires elements to be at least vectors, "
       "found scalars instead"):
     t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
     self.evaluate(t)
Code example #18
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation."""
  if tensor_util.is_tensor(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    element_shape = array_ops.shape(elements)[1:]
    l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
    return l

  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype = tuple(all_dtypes)[0]
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif all_dtypes:
    # Heterogeneous lists are ok.
    if element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant
  else:
    inferred_dtype = dtypes.variant

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape = array_ops.shape(elements[0])
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  elif all_shapes:
    # Heterogeneous lists are ok.
    if element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention
  else:
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l
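
A hedged usage sketch for the staging helper above (the element values are illustrative; the behavior described follows the branches in the code):

# Homogeneous elements: dtype and shape are inferred from the elements
# themselves, so neither argument needs to be passed.
l = tf_tensor_list_new([constant_op.constant(1.0), constant_op.constant(2.0)])

# Mixed dtypes (with element_dtype left unset) fall back to dtypes.variant,
# and mixed shapes fall back to the scalar -1, the "unknown shape, by
# convention" marker used above.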
Code example #19
File: list_ops_test.py Project: DILASSS/tensorflow
 def testGraphStack(self):
   with context.graph_mode(), self.test_session():
     tl = list_ops.empty_tensor_list(
         element_shape=constant_op.constant([1], dtype=dtypes.int32),
         element_dtype=dtypes.int32)
     tl = list_ops.tensor_list_push_back(tl, [1])
     self.assertAllEqual(
         list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
         [[1]])
Code example #20
 def _testStack(self, max_num_elements):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32,
       element_shape=scalar_shape(),
       max_num_elements=max_num_elements)
   l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
   l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
   t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
Code example #21
File: list_ops_test.py Project: aeverall/tensorflow
 def testConcatEmptyListWithPartiallyDefinedElementShapeFails(self):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=[2, None])
   with self.assertRaisesRegexp(
       errors.InvalidArgumentError,
       "All except the first dimension must be fully"
       " defined when concating an empty tensor list"):
     t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
     self.evaluate(t)
Code example #22
 def testEmptyTensorListMax(self):
   with self.cached_session() as sess, self.test_scope():
     l = list_ops.empty_tensor_list(
         element_shape=(10, 15), element_dtype=dtypes.float32,
         max_num_elements=2)
     l = list_ops.tensor_list_push_back(
         l, array_ops.fill(value=3.0, dims=(10, 15)))
     _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15)))
Code example #23
 def testEmptyTensorListNoMax(self):
   with self.cached_session() as sess, self.test_scope():
     l = list_ops.empty_tensor_list(
         element_shape=(7, 15), element_dtype=dtypes.float32)
     l = list_ops.tensor_list_push_back(
         l, constant_op.constant(1.0, shape=(7, 15)))
     _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "Set the max number of elements"):
       self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))
Code example #24
 def testSetDoesNotUpdatePushIndex(self):
   with self.cached_session(), self.test_scope():
     l = list_ops.empty_tensor_list(
         element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
     # SetItem should not change the push index.
     l = list_ops.tensor_list_set_item(l, 1, 3.)
     l = list_ops.tensor_list_push_back(l, 5.)
     l = list_ops.tensor_list_push_back(l, 7.)
     t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
     self.assertAllEqual(t, [5., 7.])
Code example #25
File: list_ops_test.py Project: aeverall/tensorflow
 def testGraphStack(self):
   with self.cached_session():
     tl = list_ops.empty_tensor_list(
         element_shape=constant_op.constant([1], dtype=dtypes.int32),
         element_dtype=dtypes.int32)
     tl = list_ops.tensor_list_push_back(tl, [1])
     self.assertAllEqual(
         self.evaluate(
             list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
         [[1]])
Code example #26
  def test_stack_tensor_list_empty(self):
    l = list_ops.empty_tensor_list(
        element_shape=None, element_dtype=dtypes.variant)

    opts = data_structures.ListStackOpts(
        element_dtype=dtypes.int32, original_call=None)

    # TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
    with self.assertRaises(ValueError):
      data_structures.list_stack(l, opts)
Code example #27
File: list_ops_test.py Project: aeverall/tensorflow
 def testPushPopGradients(self):
   with backprop.GradientTape() as tape:
     l = list_ops.empty_tensor_list(
         element_dtype=dtypes.float32, element_shape=[])
     c = constant_op.constant(1.0)
     tape.watch(c)
     l = list_ops.tensor_list_push_back(l, c)
     l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     e = 2 * e
   self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)
Code example #28
 def testPushInEmptyListWithUnknownElementShape(self):
   with self.cached_session(), self.test_scope():
     l = list_ops.empty_tensor_list(
         element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
     l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
     # Pushing an element with a different shape should raise an error.
     with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"):
       l = list_ops.tensor_list_push_back(l, 5.)
       self.evaluate(
           list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
Code example #29
File: while_v2.py Project: ziky90/tensorflow
  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      with self._forward_graph.outer_graph.as_default():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype, element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations)
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[tensor] = captured_tensor
    self.popped_tensor_lists[captured_accumulator] = new_tensor_list
    return captured_tensor
Code example #30
File: list_ops_test.py Project: aeverall/tensorflow
 def testConcatWithMismatchingTensorShapesFails(self):
   l = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32, element_shape=None)
   l = list_ops.tensor_list_push_back(l, [[0., 1.]])
   l = list_ops.tensor_list_push_back(l, [[2.], [4.]])
   with self.assertRaisesRegexp(
       errors.InvalidArgumentError,
       r"Tried to concat tensors with unequal shapes: "
       r"\[2\] vs \[1\]"):
     t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
     self.evaluate(t)
Code example #31
 def testPushPop(self):
     with self.session() as sess, self.test_scope():
         l = list_ops.empty_tensor_list(element_shape=(7, 15),
                                        element_dtype=dtypes.float32,
                                        max_num_elements=10)
         l = list_ops.tensor_list_push_back(
             l, constant_op.constant(1.0, shape=(7, 15)))
         l = list_ops.tensor_list_push_back(
             l, constant_op.constant(2.0, shape=(7, 15)))
         l, e2 = list_ops.tensor_list_pop_back(l,
                                               element_dtype=dtypes.float32)
         _, e1 = list_ops.tensor_list_pop_back(l,
                                               element_dtype=dtypes.float32)
         self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15)))
         self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15)))
Code example #32
 def testGatherGrad(self, max_num_elements):
     with backprop.GradientTape() as tape:
         l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                        element_shape=[],
                                        max_num_elements=max_num_elements)
         c0 = constant_op.constant(1.0)
         tape.watch(c0)
         l = list_ops.tensor_list_push_back(l, c0)
         l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
         t = list_ops.tensor_list_gather(l, [1, 0],
                                         element_dtype=dtypes.float32)
         self.assertAllEqual(self.evaluate(t), [2.0, 1.0])
         s = (t[0] + t[1]) * (t[0] + t[1])
     dt = tape.gradient(s, c0)
     self.assertAllEqual(self.evaluate(dt), 6.0)
Code example #33
  def testStackWithPartiallyDefinedElementShape(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[-1])
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0]))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1.0], [2.0]])

    # Should raise an error when the element tensors do not all have the same
    # shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
Code example #34
    def testGraphStackInLoop(self):
        with context.graph_mode(), self.test_session():
            t1 = list_ops.empty_tensor_list(element_shape=constant_op.constant(
                [], dtype=dtypes.int32),
                                            element_dtype=dtypes.int32)
            i = constant_op.constant(0, dtype=dtypes.int32)

            def body(i, t1):
                t1 = list_ops.tensor_list_push_back(t1, i)
                i += 1
                return i, t1

            i, t1 = control_flow_ops.while_loop(
                lambda i, t1: math_ops.less(i, 4), body, [i, t1])
            s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
            self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
Code example #35
 def testPushPopSeparateLists(self):
   with self.cached_session() as sess, self.test_scope():
     l = list_ops.empty_tensor_list(
         element_shape=[],
         element_dtype=dtypes.float32,
         max_num_elements=20)
     l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
     l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
     l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
     _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
     l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
     l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
     l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
     result = sess.run([e11, [e21, e22], [e31, e32]])
     self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
Code example #36
  def testStackWithUnknownElementShape(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [1.0, 2.0])

    # Should raise an error when the element tensors do not all have the same
    # shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
Code example #37
File: list_ops_test.py Project: psfoley/tensorflow
    def testZerosLike(self):
        l_empty = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                             element_shape=scalar_shape())
        l_empty_zeros = array_ops.zeros_like(l_empty)
        t_empty_zeros = list_ops.tensor_list_stack(
            l_empty_zeros, element_dtype=dtypes.float32)

        l_full = list_ops.tensor_list_push_back(l_empty,
                                                constant_op.constant(1.0))
        l_full = list_ops.tensor_list_push_back(l_full,
                                                constant_op.constant(2.0))
        l_full_zeros = array_ops.zeros_like(l_full)
        t_full_zeros = list_ops.tensor_list_stack(l_full_zeros,
                                                  element_dtype=dtypes.float32)

        self.assertAllEqual(self.evaluate(t_empty_zeros), [])
        self.assertAllEqual(self.evaluate(t_full_zeros), [0.0, 0.0])
Code example #38
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
    """Overload of new_list that stages a Tensor list creation."""
    elements = tuple(ops.convert_to_tensor(el) for el in elements)

    all_dtypes = set(el.dtype for el in elements)
    if len(all_dtypes) == 1:
        inferred_dtype = tuple(all_dtypes)[0]
        if element_dtype is not None and element_dtype != inferred_dtype:
            raise ValueError(
                'incompatible dtype; specified: {}, inferred from {}: {}'.
                format(element_dtype, elements, inferred_dtype))
    else:
        # Heterogeneous lists are ok.
        if element_dtype is not None:
            raise ValueError(
                'specified dtype {} is inconsistent with that of elements {}'.
                format(element_dtype, elements))
        inferred_dtype = dtypes.variant

    all_shapes = set(tuple(el.shape.as_list()) for el in elements)
    if len(all_shapes) == 1:
        inferred_shape = array_ops.shape(elements[0])
        if element_shape is not None and element_shape != inferred_shape:
            raise ValueError(
                'incompatible shape; specified: {}, inferred from {}: {}'.
                format(element_shape, elements, inferred_shape))
    else:
        # Heterogeneous lists are ok.
        if element_shape is not None:
            raise ValueError(
                'specified shape {} is inconsistent with that of elements {}'.
                format(element_shape, elements))
        inferred_shape = constant_op.constant(
            -1)  # unknown shape, by convention

    if element_dtype is None:
        element_dtype = inferred_dtype
    if element_shape is None:
        element_shape = inferred_shape

    l = list_ops.empty_tensor_list(element_shape=element_shape,
                                   element_dtype=element_dtype)
    for el in elements:
        l = list_ops.tensor_list_push_back(l, el)
    return l
Code example #39
  def testGatherWithUnknownElementShape(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=-1)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))

    t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [2.0, 1.0])

    t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])

    # Should raise an error when the requested tensors do not all have the same
    # shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
      t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
      self.evaluate(t)
Code example #40
 def testDoNotConstantFoldVariants(self):
   with self.cached_session() as sess, self.test_scope():
     val = array_ops.placeholder(dtype=dtypes.float32)
     l = list_ops.empty_tensor_list(
         element_shape=(7, 15),
         element_dtype=dtypes.float32,
         max_num_elements=10)
     # Note: Pushing a Placeholder will force the constant folding code
     # to build a Const node with a DT_VARIANT output. This tests that XLA
     # passes a cf_consider_fn which prevent folding such nodes.
     l = list_ops.tensor_list_push_back(
         l, array_ops.fill(value=val, dims=(7, 15)))
     l = list_ops.tensor_list_push_back(
         l, constant_op.constant(2.0, shape=(7, 15)))
     l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15)))
     self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15)))
Code example #41
    def _testPruning(self):
        x = constant_op.constant(1)

        tensor_list = list_ops.empty_tensor_list(element_dtype=x.dtype,
                                                 element_shape=x.shape)

        def Cond(x, tl):
            del tl  # Unused for Cond.
            return x < 5

        def Body(x, tl):
            return x + 1, list_ops.tensor_list_push_back(tl, x)

        outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(outputs[0])

        g = GetOptimizedGraph()
        # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
        # away, causing an extra Enter node.
        enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
        self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
        # Test that the TensorList is pruned out.
        self.assertEmpty([
            n for n in g.node if n.op == "Enter"
            and n.attr["T"].type == dtypes.variant.as_datatype_enum
        ])
        self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])

        stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
        train_op.append(stack)
        g = GetOptimizedGraph()
        # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
        # away, causing an extra Enter node.
        enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
        self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
        # Test that the TensorList is not pruned out.
        self.assertNotEmpty([
            n for n in g.node if n.op == "Enter"
            and n.attr["T"].type == dtypes.variant.as_datatype_enum
        ])
        self.assertNotEmpty(
            [n for n in g.node if n.op == "TensorListPushBack"])
Code example #42
    def test_dynamic_list_append(self):
        l = []
        l = tl.dynamic_list_append(l, 1)
        self.assertListEqual(l, [1])

        l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
        l = tl.dynamic_list_append(l, 1)
        s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
        self.assertAllEqual(s, [1])

        l = tensor_array_ops.TensorArray(dtypes.int32,
                                         size=0,
                                         dynamic_size=True)
        l = tl.dynamic_list_append(l, 1)
        s = l.stack()
        self.assertAllEqual(s, [1])

        l = tl.TensorList(self._shape(()), dtypes.int32)
        l = tl.dynamic_list_append(l, 1)
        self.assertAllEqual(l[0], 1)
Code example #43
  def testGatherWithPartiallyDefinedElementShape(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[None],
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([4.0, 5.0]))

    t = list_ops.tensor_list_gather(l, [0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1.0]])

    t = list_ops.tensor_list_gather(l, [1, 2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[2.0, 3.0], [4.0, 5.0]])

    # Should raise an error when the requested tensors do not all have the same
    # shape.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "unequal shapes"):
      t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
      self.evaluate(t)
Code example #44
    def testLoopWithTensorListPushBack(self):
        x = constant_op.constant(2.)

        tensor_list = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                                 element_shape=ScalarShape())

        def Cond(x, tl):
            del tl  # Unused for Cond.
            return x < 5.

        def Body(x, tl):
            tl = list_ops.tensor_list_push_back(tl, x)
            tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
            return x**2., tl

        ret = while_loop_v2(Cond, Body, [x, tensor_list])
        grad = gradients_impl.gradients(ret[0], x)
        with self.cached_session() as sess:
            self.assertEqual(sess.run(ret[0]), 16.)
            self.assertSequenceEqual(self.evaluate(grad), [32.])
Code example #45
  def testGraphStackSwitchDtype(self):
    with context.graph_mode(), self.test_session():
      list_ = list_ops.empty_tensor_list(
          element_shape=constant_op.constant([], dtype=dtypes.int32),
          element_dtype=dtypes.int32)
      m = constant_op.constant([1, 2, 3], dtype=dtypes.float32)

      def body(list_, m):
        list_ = control_flow_ops.cond(
            math_ops.equal(list_ops.tensor_list_length(list_), 0),
            lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: list_)
        list_ = list_ops.tensor_list_push_back(list_, m)
        return list_, m

      for _ in range(2):
        list_, m = body(list_, m)

      s1 = list_ops.tensor_list_stack(list_, element_dtype=dtypes.float32)
      np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
      self.assertAllEqual(self.evaluate(s1), np_s1)
Code example #46
 def testSerializeListWithMaxNumElements(self):
     if context.num_gpus():
         # TODO(b/119151861): Enable on GPU.
         return
     worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0]
     with ops.Graph().as_default(), session.Session(target=worker.target):
         with ops.device("/job:worker"):
             l = list_ops.empty_tensor_list(element_shape=-1,
                                            element_dtype=dtypes.float32,
                                            max_num_elements=2)
             l = list_ops.tensor_list_push_back(l, 1.)
         with ops.device("/job:ps"):
             l_ps = array_ops.identity(l)
             l_ps = list_ops.tensor_list_push_back(l_ps, 2.)
         with self.assertRaisesRegexp(
                 errors.InvalidArgumentError,
                 "Tried to push item into a full list"):
             with ops.device("/job:worker"):
                 l_worker = array_ops.identity(l_ps)
                 l_worker = list_ops.tensor_list_push_back(l_worker, 3.0)
                 self.evaluate(l_worker)
Code example #47
  def test_dynamic_list_append(self):
    l = []
    l = tl.dynamic_list_append(l, 1)
    self.assertListEqual(l, [1])

    l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
    l = tl.dynamic_list_append(l, 1)
    s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    with self.cached_session() as sess:
      self.assertAllEqual(self.evaluate(s), [1])

    l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
    l = tl.dynamic_list_append(l, 1)
    s = l.stack()
    with self.cached_session() as sess:
      self.assertAllEqual(self.evaluate(s), [1])

    l = tl.TensorList(self._shape(()), dtypes.int32)
    l = tl.dynamic_list_append(l, 1)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(l[0]), 1)
Code example #48
    def testZerosLike(self):
        for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
                      dtypes.int32, dtypes.int64, dtypes.float16,
                      dtypes.float32, dtypes.float64, dtypes.complex64,
                      dtypes.complex128, dtypes.bool):
            l_empty = list_ops.empty_tensor_list(element_dtype=dtype,
                                                 element_shape=scalar_shape())
            l_empty_zeros = array_ops.zeros_like(l_empty)
            t_empty_zeros = list_ops.tensor_list_stack(l_empty_zeros,
                                                       element_dtype=dtype)

            l_full = list_ops.tensor_list_push_back(
                l_empty, math_ops.cast(0, dtype=dtype))
            l_full = list_ops.tensor_list_push_back(
                l_full, math_ops.cast(1, dtype=dtype))
            l_full_zeros = array_ops.zeros_like(l_full)
            t_full_zeros = list_ops.tensor_list_stack(l_full_zeros,
                                                      element_dtype=dtype)

            self.assertAllEqual(self.evaluate(t_empty_zeros), [])
            self.assertAllEqual(self.evaluate(t_full_zeros),
                                np.zeros((2, ), dtype=dtype.as_numpy_dtype))
Code example #49
  def testGraphStackInLoopSwitchDtype(self):
    with context.graph_mode(), self.test_session():
      t1 = list_ops.empty_tensor_list(
          element_shape=constant_op.constant([], dtype=dtypes.int32),
          element_dtype=dtypes.int32)
      i = constant_op.constant(0, dtype=dtypes.float32)
      m = constant_op.constant([1, 2, 3], dtype=dtypes.float32)

      def body(i, m, t1):
        t1 = control_flow_ops.cond(
            math_ops.equal(list_ops.tensor_list_length(t1), 0),
            lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: t1)

        t1 = list_ops.tensor_list_push_back(t1, m * i)
        i += 1.0
        return i, m, t1

      i, m, t1 = control_flow_ops.while_loop(
          lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1])
      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32)
      np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
      self.assertAllEqual(self.evaluate(s1), np_s1)
Code example #50
    def testDuplicateAccumulator(self):
        x = constant_op.constant(2.)

        tensor_list = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                                 element_shape=ScalarShape())

        def Cond(x, tl):
            del tl  # Unused for Cond.
            return x < 5.

        def Body(x, tl):
            # There is an accumulator in the loop already so we should not add
            # another.
            tl = list_ops.tensor_list_push_back(tl, x)
            return x**2., tl

        ret = while_loop_v2(Cond,
                            Body, [x, tensor_list],
                            return_same_structure=False)

        for op in ops.get_default_graph().get_operations():
            if op.type == "While" or op.type == "StatelessWhile":
                while_op = op

        body_graph = while_v2._get_graph(while_op, "body")
        x_input_index = [
            i for i, inp in enumerate(while_op.inputs) if inp == x
        ][0]
        x_input_t = body_graph.inputs[x_input_index]
        accumulator_count = len([
            c for c in x_input_t.consumers() if c.type == "TensorListPushBack"
        ])
        self.assertEqual(accumulator_count, 1)

        grad = gradients_impl.gradients(ret[0], x)
        with self.cached_session() as sess:
            self.assertEqual(sess.run(ret[0]), 16.)
            self.assertSequenceEqual(self.evaluate(grad), [32.])
Code example #51
File: data_structures.py Project: PavelToropynya/my
def _tf_tensor_list_new(elements):
  """Overload of new_list that stages a Tensor list creation."""
  elements = tuple(ops.convert_to_tensor(el) for el in elements)
  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    element_dtype = tuple(all_dtypes)[0]
  else:
    # Heterogeneous lists are ok.
    element_dtype = dtypes.variant

  # TODO(mdan): This may fail for elements of variable shapes.
  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    element_shape = array_ops.shape(elements[0])
  else:
    # Heterogeneous lists are ok.
    element_shape = constant_op.constant(-1)  # unknown shape, by convention

  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l
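
A short sketch of what the scalar -1 "unknown shape" convention above allows (same internal modules assumed; the element values are illustrative):

# Same dtype, different shapes: element_shape becomes constant(-1), so
# elements of varying shape can still be pushed onto the staged list
# (compare testUnknownShape in code example #10 above).
l = _tf_tensor_list_new(
    [constant_op.constant(1.0), constant_op.constant([2.0, 3.0])])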
Code example #52
    def testPruning(self):
        x = constant_op.constant(1)

        tensor_list = list_ops.empty_tensor_list(element_dtype=x.dtype,
                                                 element_shape=x.shape)

        def Cond(x, tl):
            del tl  # Unused for Cond.
            return x < 5

        def Body(x, tl):
            return x + 1, list_ops.tensor_list_push_back(tl, x)

        outputs = while_loop_v1(Cond, Body, [x, tensor_list])

        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(outputs[0])

        def GetOptimizedGraph():
            mg = meta_graph.create_meta_graph_def(
                graph=ops.get_default_graph())
            config = config_pb2.ConfigProto()
            config.graph_options.rewrite_options.CopyFrom(
                rewriter_config_pb2.RewriterConfig(
                    constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
                    memory_optimization=rewriter_config_pb2.RewriterConfig.
                    MANUAL))
            return tf_optimizer.OptimizeGraph(config, mg)

        g = GetOptimizedGraph()
        self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)

        stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
        train_op.append(stack)
        g = GetOptimizedGraph()
        self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
Code example #53
 def testCPUGPUCopyNested(self):
   if not context.num_gpus():
     return
   t = constant_op.constant([1.0, 2.0])
   child_l = list_ops.tensor_list_from_tensor(t, element_shape=[])
   l = list_ops.empty_tensor_list(
       element_shape=constant_op.constant([], dtype=dtypes.int32),
       element_dtype=dtypes.variant)
   l = list_ops.tensor_list_push_back(l, child_l)
   with context.device("gpu:0"):
     l_gpu = array_ops.identity(l)
     _, child_l_gpu = list_ops.tensor_list_pop_back(
         l_gpu, element_dtype=dtypes.variant)
     self.assertAllEqual(
         self.evaluate(
             list_ops.tensor_list_pop_back(
                 child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
   l_cpu = array_ops.identity(l_gpu)
   _, child_l_cpu = list_ops.tensor_list_pop_back(
       l_cpu, element_dtype=dtypes.variant)
   self.assertAllEqual(
       self.evaluate(
           list_ops.tensor_list_pop_back(
               child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
Code example #54
 def testPushPop(self):
     l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                    element_shape=scalar_shape())
     l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
     l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
     self.assertAllEqual(self.evaluate(e), 1.0)
Code example #55
    def testConcat(self):
        c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
        l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
        l1 = list_ops.tensor_list_from_tensor([-1.0],
                                              element_shape=scalar_shape())
        l_batch_0 = array_ops.stack([l0, l1])
        l_batch_1 = array_ops.stack([l1, l0])

        l_concat_01 = list_ops.tensor_list_concat_lists(
            l_batch_0, l_batch_1, element_dtype=dtypes.float32)
        l_concat_10 = list_ops.tensor_list_concat_lists(
            l_batch_1, l_batch_0, element_dtype=dtypes.float32)
        l_concat_00 = list_ops.tensor_list_concat_lists(
            l_batch_0, l_batch_0, element_dtype=dtypes.float32)
        l_concat_11 = list_ops.tensor_list_concat_lists(
            l_batch_1, l_batch_1, element_dtype=dtypes.float32)

        expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
        expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
        expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
        expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

        for i, (concat, expected) in enumerate(
                zip([l_concat_00, l_concat_01, l_concat_10, l_concat_11],
                    [expected_00, expected_01, expected_10, expected_11])):
            splitted = array_ops.unstack(concat)
            splitted_stacked_ret = self.evaluate(
                (list_ops.tensor_list_stack(splitted[0], dtypes.float32),
                 list_ops.tensor_list_stack(splitted[1], dtypes.float32)))
            print("Test concat %d: %s, %s, %s, %s" %
                  (i, expected[0], splitted_stacked_ret[0], expected[1],
                   splitted_stacked_ret[1]))
            self.assertAllClose(expected[0], splitted_stacked_ret[0])
            self.assertAllClose(expected[1], splitted_stacked_ret[1])

        # Concatenating mismatched shapes fails.
        with self.assertRaises((errors.InvalidArgumentError, ValueError)):
            self.evaluate(
                list_ops.tensor_list_concat_lists(
                    l_batch_0,
                    list_ops.empty_tensor_list(scalar_shape(), dtypes.float32),
                    element_dtype=dtypes.float32))

        with self.assertRaisesRegexp(
                errors.InvalidArgumentError,
                "element shapes are not identical at index 0"):
            l_batch_of_vec_tls = array_ops.stack(
                [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])
                 ] * 2)
            self.evaluate(
                list_ops.tensor_list_concat_lists(
                    l_batch_0,
                    l_batch_of_vec_tls,
                    element_dtype=dtypes.float32))

        with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                     r"input_b\[0\].dtype != element_dtype."):
            l_batch_of_int_tls = array_ops.stack([
                list_ops.tensor_list_from_tensor([1],
                                                 element_shape=scalar_shape())
            ] * 2)
            self.evaluate(
                list_ops.tensor_list_concat_lists(
                    l_batch_0,
                    l_batch_of_int_tls,
                    element_dtype=dtypes.float32))
Code example #56
File: while_v2.py Project: zwang269/tensorflow
  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # Do not accumulate loop invariants.
    if (any(tensor is t for t in self._forward_graph.inputs) and
        any(tensor is t for t in self._forward_graph.outputs)):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `popped_tensor_lists` so that this gets added to the list of
      # outputs.
      # TODO(srbs): Rename popped_tensor_lists.
      self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # No need to accumulate loop invariants. Capture them directly.
    # The captured tensor gets resolved to the corresponding while output in
    # `_resolve_grad_captures`.
    if _is_loop_invariant(tensor, self._forward_graph_inputs,
                          self._forward_graph_outputs):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      return captured_tensor

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the  forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.popped_tensor_lists[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor
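The helper above is the heart of while_v2's accumulator scheme: the forward body pushes each iteration's intermediate onto a TensorList, and the gradient body pops it back off in reverse. A standalone sketch of that round trip using the same list_ops calls, assuming eager TF 2.x and scalar float32 elements:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops

# Forward pass: accumulate one intermediate value per iteration.
acc = list_ops.empty_tensor_list(
    element_dtype=dtypes.float32, element_shape=[], max_num_elements=3)
for v in [1.0, 2.0, 3.0]:
  acc = list_ops.tensor_list_push_back(acc, v)

# Backward pass: pop the most recent intermediate back out, exactly as the
# gradient graph does with tensor_list_pop_back above.
acc, last = list_ops.tensor_list_pop_back(acc, element_dtype=dtypes.float32)
# last == 3.0; the shortened list is threaded onward via popped_tensor_lists.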
Code example #57
File: while_v2.py Project: zwang269/tensorflow
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)
  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args.

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
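A minimal usage sketch of this entry point through the public tf.while_loop API, which dispatches to this implementation when v2 control flow is enabled (the summation example is made up):

import tensorflow as tf

def cond(i, total):
  return i < 10

def body(i, total):
  return i + 1, total + i

i, total = tf.while_loop(cond, body, loop_vars=(tf.constant(0), tf.constant(0)))
# total == 45. Inside a tf.function this lowers to a single While op, as the
# docstring above notes, rather than v1's Enter/Merge/Switch/Exit subgraph.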
Code example #58
File: while_v2.py Project: zhao181/tensorflow
    def _capture_helper(self, tensor, name):
        if tensor.graph is not self._forward_graph:
            return super(_WhileBodyGradFuncGraph,
                         self)._capture_helper(tensor, name)

        while tensor.op.type == "Identity":
            # We do not accumulate the output of identity nodes so we try to capture
            # the input of the Identity node instead.
            tensor = tensor.op.inputs[0]

        captured_tensor = self._indirect_captures.get(tensor)
        if captured_tensor is not None:
            return captured_tensor

        if tensor.dtype == dtypes.resource:
            # Resource-type tensors are not accumulated.
            # If a resource tensor exists in the loop body it must either be a loop
            # input or an output of a nested While op inside the loop body which
            # had captured the external resource.
            if tensor in self._forward_graph.inputs:
                index = self._forward_graph.inputs.index(tensor)
            elif tensor.op.type == "While":
                # Captured resources occur at the same index in the lists of inputs
                # and outputs of a while op, so we look up the input of `tensor.op`
                # at the same index as the index of `tensor` in `tensor.op.outputs`.
                index = self._forward_graph.inputs.index(
                    tensor.op.inputs[tensor.value_index])
            else:
                raise ValueError(
                    "Taking gradient of a while loop which creates"
                    " a resource in its body is not supported: %s" %
                    str(tensor))
            # This must be a loop invariant.
            assert (self._forward_graph.inputs[index] ==
                    self._forward_graph.outputs[index]), (
                        "Resource tensors must be loop invariants %s." %
                        str(self._forward_graph._while.inputs[index]))
            tensor_in_outer_graph = self._forward_graph._while.inputs[index]
            self._indirect_captures[tensor] = self.capture(
                tensor_in_outer_graph, whitelisted=True)
            return self._indirect_captures[tensor]

        # Create or find an existing accumulator output for `tensor` in the forward
        # graph, and fetch from this accumulator in the gradient graph to get the
        # raw intermediate value.
        accumulator = _get_accumulator(tensor)
        if accumulator is None:
            # Create the initial empty tensor list.
            with self._forward_graph.outer_graph.as_default():
                tensor_list = list_ops.empty_tensor_list(
                    element_dtype=tensor.dtype,
                    element_shape=tensor.shape,
                    max_num_elements=self._maximum_iterations)
            self.empty_tensor_lists.append(tensor_list)

            # Push the intermediate tensor to the tensor list. This captures
            # `tensor_list`.
            with self._forward_graph.as_default():
                accumulator = list_ops.tensor_list_push_back(
                    tensor_list, tensor)
            # Add the modified tensor list to the list of outputs. This output will be
            # all the accumulated values.
            self._forward_graph.outputs.append(accumulator)

            # Capture in the cond graph as well so the forward cond and body inputs
            # match.
            with self._forward_cond_graph.as_default():
                self._forward_cond_graph.capture(tensor_list)

        # Capture the accumulator tensor list in the gradient graph directly from
        # the forward graph -- we'll later modify this to capture the final list
        # output by the forward While op instead.
        captured_accumulator = super(_WhileBodyGradFuncGraph,
                                     self)._capture_helper(accumulator, name)

        # Pop the intermediate value from the tensor list in the gradient graph.
        new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
            captured_accumulator, element_dtype=tensor.dtype)

        self._indirect_captures[tensor] = captured_tensor
        self.popped_tensor_lists[captured_accumulator] = new_tensor_list
        return captured_tensor
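All of this capture logic exists so gradients can flow through a While op. A hedged sketch of what that looks like from user code, assuming TF 2.x; tf.function is used so that a genuine While op, not an eager Python loop, is traced and differentiated:

import tensorflow as tf

@tf.function
def cube(x):
  # acc * x re-captures the loop-invariant x, the case handled above.
  _, y = tf.while_loop(lambda i, acc: i < 3,
                       lambda i, acc: (i + 1, acc * x),
                       (tf.constant(0), tf.constant(1.0)))
  return y

x = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = cube(x)
grad = tape.gradient(y, x)  # d(x**3)/dx at x == 2.0 -> 12.0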
Code example #59
  def _simple_tensor_list(self):
    return list_ops.empty_tensor_list(
        element_shape=constant_op.constant([1]),
        element_dtype=dtypes.int32)
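A sketch of how a list built by a helper like this is typically consumed (assuming eager TF 2.x; the pushed value is made up):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops

l = list_ops.empty_tensor_list(
    element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)
l = list_ops.tensor_list_push_back(l, constant_op.constant([7]))
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
# t has shape (1, 1) and value [[7]].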
Code example #60
  def clear(self):
    self.list_ = list_ops.empty_tensor_list(self.shape, self.dtype)
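This clear method presumably lives in a small wrapper class. A possible minimal context is sketched below; the class name, constructor, and append helper are hypothetical inferences, and only the clear body comes from the example itself:

from tensorflow.python.ops import list_ops

class DynamicList(object):  # hypothetical name
  """Thin mutable wrapper over a TensorList handle."""

  def __init__(self, shape, dtype):
    self.shape = shape  # element shape, e.g. [] for scalar elements
    self.dtype = dtype  # element dtype, e.g. dtypes.float32
    self.clear()

  def append(self, value):  # hypothetical helper
    self.list_ = list_ops.tensor_list_push_back(self.list_, value)

  def clear(self):
    # empty_tensor_list takes (element_shape, element_dtype) positionally.
    self.list_ = list_ops.empty_tensor_list(self.shape, self.dtype)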