示例#1
0
  def testStackEmptyList(self, max_num_elements):
    """Stacking an empty list is only allowed with a fully defined shape.

    Args:
      max_num_elements: forwarded to `empty_tensor_list` as its capacity.
    """
    # Should be able to stack empty lists with fully defined element_shape.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[1, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t).shape, (0, 1, 2))

    # Should not be able to stack empty lists with partially defined
    # element_shape. (assertRaisesRegex replaces the deprecated
    # assertRaisesRegexp alias, which was removed in Python 3.12.)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[None, 2],
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should not be able to stack empty lists with undefined element_shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=None,
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
示例#2
0
    def testStackEmptyList(self, max_num_elements):
        """Stacking an empty list is only allowed with a fully defined shape.

        This variant marks unknown dimensions / unknown rank with -1.

        Args:
          max_num_elements: forwarded to `empty_tensor_list` as its capacity.
        """
        # Should be able to stack empty lists with fully defined element_shape.
        l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                       element_shape=[1, 2],
                                       max_num_elements=max_num_elements)
        t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
        self.assertAllEqual(self.evaluate(t).shape, (0, 1, 2))

        # Should not be able to stack empty lists with partially defined
        # element_shape. (assertRaisesRegex replaces the deprecated
        # assertRaisesRegexp alias, which was removed in Python 3.12.)
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "non-fully-defined"):
            l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                           element_shape=[-1, 2],
                                           max_num_elements=max_num_elements)
            t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
            self.evaluate(t)

        # Should not be able to stack empty lists with undefined element_shape.
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "non-fully-defined"):
            l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                           element_shape=-1,
                                           max_num_elements=max_num_elements)
            t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
            self.evaluate(t)
示例#3
0
  def testConcat(self):
    """Checks tensor_list_concat_lists on batches of lists, plus its errors."""
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
    l_batch_0 = array_ops.stack([l0, l1])
    l_batch_1 = array_ops.stack([l1, l0])

    l_concat_01 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_1, element_dtype=dtypes.float32)
    l_concat_10 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_0, element_dtype=dtypes.float32)
    l_concat_00 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_0, element_dtype=dtypes.float32)
    l_concat_11 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_1, element_dtype=dtypes.float32)

    expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
    expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
    expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
    expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

    for i, (concat, expected) in enumerate(zip(
        [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
        [expected_00, expected_01, expected_10, expected_11])):
      splitted = array_ops.unstack(concat)
      splitted_stacked_ret = self.evaluate(
          (list_ops.tensor_list_stack(splitted[0], dtypes.float32),
           list_ops.tensor_list_stack(splitted[1], dtypes.float32)))
      # Identify the failing case through the assertion message rather than
      # printing every case to stdout on every run.
      msg = "Test concat %d" % i
      self.assertAllClose(expected[0], splitted_stacked_ret[0], msg=msg)
      self.assertAllClose(expected[1], splitted_stacked_ret[1], msg=msg)

    # Concatenating mismatched shapes fails.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      self.evaluate(
          list_ops.tensor_list_concat_lists(
              l_batch_0,
              list_ops.empty_tensor_list(scalar_shape(), dtypes.float32),
              element_dtype=dtypes.float32))

    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "element shapes are not identical at index 0"):
      l_batch_of_vec_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])] * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_vec_tls,
                                            element_dtype=dtypes.float32))

    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                r"input_b\[0\].dtype != element_dtype."):
      l_batch_of_int_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([1], element_shape=scalar_shape())]
          * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
                                            element_dtype=dtypes.float32))
示例#4
0
 def testAddTensorListsFailsIfLeadingDimsMismatch(self):
   """AddN over TensorLists of lengths 2 and 3 must fail when evaluated."""
   with self.cached_session(), self.test_scope():
     l1 = list_ops.tensor_list_reserve(
         element_shape=[], element_dtype=dtypes.float32, num_elements=2)
     l2 = list_ops.tensor_list_reserve(
         element_shape=[], element_dtype=dtypes.float32, num_elements=3)
     l = math_ops.add_n([l1, l2])
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed in Python 3.12.
     with self.assertRaisesRegex(
         errors.InvalidArgumentError,
         "TensorList arguments to AddN must all have the same shape"):
       list_ops.tensor_list_stack(l, element_dtype=dtypes.float32).eval()
示例#5
0
 def testAddTensorListsFailsIfLeadingDimsMismatch(self):
   """Summing TensorLists whose lengths differ (2 vs 3) must be rejected."""
   with self.session(), self.test_scope():

     def reserve(length):
       # Scalar float32 list with `length` reserved slots.
       return list_ops.tensor_list_reserve(
           element_shape=[], element_dtype=dtypes.float32, num_elements=length)

     summed = math_ops.add_n([reserve(2), reserve(3)])
     expected_error = "TensorList arguments to AddN must all have the same shape"
     with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
       stacked = list_ops.tensor_list_stack(summed, element_dtype=dtypes.float32)
       stacked.eval()
示例#6
0
  def test_tensor_list_empty_list(self):
    """A tensor_list built from an empty list or tuple stacks to []."""
    # Both empty sequence types must be accepted as the initial value. The
    # session handle was never used (self.evaluate runs the tensor), so it is
    # no longer bound.
    for empty in ([], ()):
      l = special_functions.tensor_list(empty,
                                        element_dtype=dtypes.int32,
                                        element_shape=())
      sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
      with self.cached_session():
        self.assertAllEqual(self.evaluate(sl), [])
  def test_tensor_list_empty_list(self):
    """An empty list or tuple initializer yields an empty stacked tensor."""
    for initial in ([], ()):
      tl = special_functions.tensor_list(
          initial, element_dtype=dtypes.int32, element_shape=())
      stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)
      with self.cached_session() as sess:
        self.assertAllEqual(sess.run(stacked), [])
  def testStackWithPartiallyDefinedElementShape(self):
    """Stacking works with element_shape=[-1] as long as elements agree."""
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[-1])
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0]))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1.0], [2.0]])

    # Should raise an error when the element tensors do not all have the same
    # shape. (assertRaisesRegex replaces the deprecated assertRaisesRegexp
    # alias, which was removed in Python 3.12.)
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
 def test_tf_tensor_list_new_empty(self):
   """tf_tensor_list_new([]) stacks to an empty int32 tensor."""
   new_list = data_structures.tf_tensor_list_new(
       [], element_dtype=dtypes.int32, element_shape=())
   stacked = list_ops.tensor_list_stack(new_list, element_dtype=dtypes.int32)
   with self.cached_session() as sess:
     result = sess.run(stacked)
     self.assertAllEqual(result, [])
def _tf_tensor_list_stack(list_, opts):
    """Overload of list_stack that stages a Tensor list write."""
    if opts.element_dtype is None:
        raise ValueError(
            'cannot stack a list without knowing its element type;'
            ' use set_element_type to annotate it')
    return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype)
示例#11
0
 def testStackWithUninitializedTensors(self):
     """Reserved slots that were never set stack as zeros."""
     with self.cached_session(), self.test_scope():
         reserved = list_ops.tensor_list_reserve(element_dtype=dtypes.float32,
                                                 element_shape=[],
                                                 num_elements=3)
         stacked = list_ops.tensor_list_stack(reserved,
                                              element_dtype=dtypes.float32)
         self.assertAllEqual(stacked, [0.] * 3)
 def test_tf_tensor_list_new_empty(self):
     """An empty tf_tensor_list_new list stacks to an empty tensor."""
     new_list = data_structures.tf_tensor_list_new(
         [], element_dtype=dtypes.int32, element_shape=())
     stacked = list_ops.tensor_list_stack(new_list, element_dtype=dtypes.int32)
     with self.cached_session() as sess:
         result = sess.run(stacked)
         self.assertAllEqual(result, [])
示例#13
0
 def testSetStackReservedUnknownElementShape(self):
   """Unset slots of a list reserved with unknown element_shape stack as
   zeros shaped like the element that was set."""
   with self.cached_session(), self.test_scope():
     reserved = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=None, num_elements=2)
     updated = list_ops.tensor_list_set_item(reserved, 0, [3.0, 4.0])
     stacked = list_ops.tensor_list_stack(updated, element_dtype=dtypes.float32)
     self.assertAllEqual(stacked, [[3.0, 4.0], [0., 0.]])
示例#14
0
  def test_tensor_list_from_elements(self):
    """tensor_list(elements) stacks back to the original element values."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    # The session handle was never used (self.evaluate runs the tensor), so
    # it is no longer bound.
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
示例#15
0
 def testAddN(self):
   """add_n over three TensorLists sums them element-wise."""
   lists = [
       list_ops.tensor_list_from_tensor(values, element_shape=[])
       for values in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])
   ]
   summed = math_ops.add_n(tuple(lists))
   stacked = list_ops.tensor_list_stack(summed, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(stacked), [9., 12.])
 def testSetStackReservedUnknownElementShape(self):
   """Setting one slot of a shape-agnostic reserved list fixes the shape
   used for the zero-filled slots when stacking."""
   with self.cached_session(), self.test_scope():
     tl = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=None, num_elements=2)
     tl = list_ops.tensor_list_set_item(tl, 0, [3.0, 4.0])
     self.assertAllEqual(
         list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32),
         [[3.0, 4.0], [0., 0.]])
示例#17
0
 def testStack(self):
   """Pushing two scalars and stacking yields [1.0, 2.0]."""
   tl = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                   element_shape=scalar_shape())
   for scalar in (1.0, 2.0):
     tl = list_ops.tensor_list_push_back(tl, constant_op.constant(scalar))
   stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
   self.assertAllEqual(stacked, [1.0, 2.0])
示例#18
0
  def testPruning(self):
    """The optimized graph keeps an extra Enter node only once the loop's
    TensorList output is added to TRAIN_OP."""
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = while_loop_v1(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    def CountEnterNodes():
      """Optimizes the default graph and returns its number of Enter ops."""
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      rewriter_config = rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      optimized = tf_optimizer.OptimizeGraph(rewriter_config, mg)
      return sum(1 for node in optimized.node if node.op == "Enter")

    # With only the loop counter in TRAIN_OP, one Enter node remains.
    self.assertEqual(CountEnterNodes(), 1)

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    # Fetching the stacked list adds a second Enter node.
    self.assertEqual(CountEnterNodes(), 2)
示例#19
0
  def test_list_pop(self):
    """The `lists` converter turns `l.pop()` into a TensorList pop."""

    def test_fn():
      l = [1, 2, 3]
      utils.set_element_type(l, dtypes.int32, ())
      s = l.pop()
      return s, l

    node = self.parse_and_analyze(
        test_fn,
        {
            'utils': utils,
            'dtypes': dtypes
        },
        include_type_analysis=True,
    )
    node = lists.transform(node, self.ctx)

    with self.compiled(node) as result:
      result.utils = utils
      result.dtypes = dtypes
      # test_session() is deprecated in favor of cached_session().
      with self.cached_session():
        ts, tl = result.test_fn()
        r = list_ops.tensor_list_stack(tl, dtypes.int32)
        self.assertAllEqual(self.evaluate(r), [1, 2])
        self.assertAllEqual(self.evaluate(ts), 3)
示例#20
0
    def testPruning(self):
        """Adding the stacked loop list to TRAIN_OP adds one Enter node to
        the optimized graph (1 before, 2 after)."""
        x = constant_op.constant(1)

        tensor_list = list_ops.empty_tensor_list(element_dtype=x.dtype,
                                                 element_shape=x.shape)

        def Cond(x, tl):
            del tl  # Unused for Cond.
            return x < 5

        def Body(x, tl):
            return x + 1, list_ops.tensor_list_push_back(tl, x)

        outputs = while_loop_v1(Cond, Body, [x, tensor_list])

        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(outputs[0])

        def CountEnterNodes():
            """Runs the graph optimizer and counts Enter ops in the result."""
            mg = meta_graph.create_meta_graph_def(
                graph=ops.get_default_graph())
            rewriter_config = rewriter_config_pb2.RewriterConfig(
                constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
                memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
            optimized = tf_optimizer.OptimizeGraph(rewriter_config, mg)
            return sum(1 for node in optimized.node if node.op == "Enter")

        self.assertEqual(CountEnterNodes(), 1)

        stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
        train_op.append(stack)
        self.assertEqual(CountEnterNodes(), 2)
示例#21
0
 def testAddN(self):
   """Element-wise sum of three scalar TensorLists."""
   operands = tuple(
       list_ops.tensor_list_from_tensor(pair, element_shape=[])
       for pair in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
   total = list_ops.tensor_list_stack(
       math_ops.add_n(operands), element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(total), [9., 12.])
    def build_graph(parameters):
        """Build the TensorListSetItem op testing graph.

        A placeholder ``item`` feeds a while loop that writes
        ``item + item + 1`` into slots 0 .. num_elements-1 of a reserved
        TensorList; the list is then stacked.

        Returns:
          A pair ``([item], [out])`` of input and output tensor lists.
        """
        dtype = parameters["element_dtype"]
        num_elements = parameters["num_elements"]
        item = tf.placeholder(dtype=dtype, shape=parameters["element_shape"])
        tensor_list = list_ops.tensor_list_reserve(element_shape=None,
                                                   num_elements=num_elements,
                                                   element_dtype=dtype)

        def condition(i, _):
            return i < num_elements

        def loop_body(i, tensor_list):
            doubled = tf.add(item, item)
            new_item = tf.add(doubled, tf.constant(value=1, dtype=dtype))
            return i + 1, list_ops.tensor_list_set_item(tensor_list, i,
                                                        new_item)

        _, tensor_list = tf.while_loop(condition, loop_body, (0, tensor_list))
        out = list_ops.tensor_list_stack(tensor_list,
                                         num_elements=num_elements,
                                         element_dtype=dtype)
        return [item], [out]
示例#23
0
    def test_list_pop(self):
        """The `lists` converter turns `l.pop()` into a TensorList pop."""
        def test_fn():
            l = [1, 2, 3]
            utils.set_element_type(l, dtypes.int32, ())
            s = l.pop()
            return s, l

        node = self.parse_and_analyze(
            test_fn,
            {
                'utils': utils,
                'dtypes': dtypes
            },
            include_type_analysis=True,
        )
        node = lists.transform(node, self.ctx)

        with self.compiled(node) as result:
            result.utils = utils
            result.dtypes = dtypes
            # test_session() is deprecated in favor of cached_session().
            with self.cached_session():
                ts, tl = result.test_fn()
                r = list_ops.tensor_list_stack(tl, dtypes.int32)
                self.assertAllEqual(self.evaluate(r), [1, 2])
                self.assertAllEqual(self.evaluate(ts), 3)
 def testStack(self):
     """Two pushed scalars stack into [1.0, 2.0]."""
     tl = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                     element_shape=scalar_shape())
     for scalar in (1.0, 2.0):
         tl = list_ops.tensor_list_push_back(tl, constant_op.constant(scalar))
     stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
     self.assertAllEqual(self.evaluate(stacked), [1.0, 2.0])
  def test_tensor_list_from_elements(self):
    """tensor_list(elements) stacks back to the original element values."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    # test_session() is deprecated in favor of cached_session().
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
示例#26
0
    def testZerosLike(self):
        """zeros_like on a TensorList keeps its length and zeroes values."""
        def stack_floats(handle):
            # Stack a float32 TensorList into a dense tensor.
            return list_ops.tensor_list_stack(handle,
                                              element_dtype=dtypes.float32)

        empty = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                           element_shape=scalar_shape())
        empty_zeros = stack_floats(array_ops.zeros_like(empty))

        full = empty
        for scalar in (1.0, 2.0):
            full = list_ops.tensor_list_push_back(full,
                                                  constant_op.constant(scalar))
        full_zeros = stack_floats(array_ops.zeros_like(full))

        self.assertAllEqual(self.evaluate(empty_zeros), [])
        self.assertAllEqual(self.evaluate(full_zeros), [0.0, 0.0])
 def stack(self, name=None):
   """See TensorArray.

   Stacks the elements of the underlying TensorList (``self._flow``) into one
   tensor. When a per-element shape was recorded and its ``dims`` is not
   None (presumably a TensorShape of known rank), the result's static shape
   is set to ``[None] + dims``.
   """
   with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
     stacked = list_ops.tensor_list_stack(
         input_handle=self._flow, element_dtype=self._dtype)
     shapes = self._element_shape
     if not shapes or shapes[0].dims is None:
       return stacked
     stacked.set_shape([None] + shapes[0].dims)
     return stacked
  def test_append_tensor_list(self):
    """list_append onto a new_list() TensorList stores the appended tensor."""
    l = data_structures.new_list()
    x = constant_op.constant([1, 2, 3])
    l = data_structures.list_append(l, x)

    t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
    # test_session() is deprecated in favor of cached_session().
    with self.cached_session():
      self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
示例#29
0
 def stack(self, name=None):
   """See TensorArray.

   Returns the TensorList in ``self._flow`` stacked into a single tensor and,
   when the first recorded element shape has non-None ``dims``, annotates the
   static shape as ``[None] + dims`` (leading length left unknown).
   """
   with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
     result = list_ops.tensor_list_stack(
         input_handle=self._flow, element_dtype=self._dtype)
     recorded = self._element_shape
     has_static_dims = bool(recorded) and recorded[0].dims is not None
     if has_static_dims:
       result.set_shape([None] + recorded[0].dims)
     return result
示例#30
0
  def testStackWithUnknownElementShape(self, max_num_elements):
    """Stacking works with element_shape=None if all elements agree.

    Args:
      max_num_elements: forwarded to `empty_tensor_list` as its capacity.
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [1.0, 2.0])

    # Should raise an error when the element tensors do not all have the same
    # shape. (assertRaisesRegex replaces the deprecated assertRaisesRegexp
    # alias, which was removed in Python 3.12.)
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
 def testGetSetItem(self):
     """get_item reads element 0; set_item overwrites it."""
     values = constant_op.constant([1.0, 2.0])
     tl = list_ops.tensor_list_from_tensor(values, element_shape=scalar_shape())
     first = list_ops.tensor_list_get_item(tl, 0, element_dtype=dtypes.float32)
     self.assertAllEqual(self.evaluate(first), 1.0)
     tl = list_ops.tensor_list_set_item(tl, 0, 3.0)
     stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
     self.assertAllEqual(self.evaluate(stacked), [3.0, 2.0])
示例#32
0
  def testStackWithUnknownElementShape(self, max_num_elements):
    """Stacking works with element_shape=None if all elements agree.

    Args:
      max_num_elements: forwarded to `empty_tensor_list` as its capacity.
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [1.0, 2.0])

    # Should raise an error when the element tensors do not all have the same
    # shape. (assertRaisesRegex replaces the deprecated assertRaisesRegexp
    # alias, which was removed in Python 3.12.)
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
示例#33
0
 def testGetSetItem(self):
   """Reads element 0, overwrites it with 3.0, and checks the stacked list."""
   source = constant_op.constant([1.0, 2.0])
   handle = list_ops.tensor_list_from_tensor(source, element_shape=[])
   head = list_ops.tensor_list_get_item(handle, 0, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(head), 1.0)
   handle = list_ops.tensor_list_set_item(handle, 0, 3.0)
   self.assertAllEqual(
       self.evaluate(
           list_ops.tensor_list_stack(handle, element_dtype=dtypes.float32)),
       [3.0, 2.0])
示例#34
0
    def test_append_tensor_list(self):
        """list_append onto a new_list() TensorList stores the tensor."""
        l = data_structures.new_list()
        x = constant_op.constant([1, 2, 3])
        l = data_structures.list_append(l, x)

        t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
        # test_session() is deprecated in favor of cached_session().
        with self.cached_session():
            self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
示例#35
0
    def test_initialized_list(self):
        """A list literal converted by `lists` stacks to its values."""
        def test_fn():
            return [1, 2, 3]

        with self.converted(test_fn, lists, {}) as result:
            # test_session() is deprecated in favor of cached_session().
            with self.cached_session():
                tl = result.test_fn()
                r = list_ops.tensor_list_stack(tl, dtypes.int32)
                self.assertAllEqual(self.evaluate(r), [1, 2, 3])
示例#36
0
  def test_set_item_tensor_list(self):
    """slices.set_item replaces row 0 of a TensorList built from a matrix."""
    initial_list = constant_op.constant([[1, 2], [3, 4]])
    elem_shape = constant_op.constant([2])
    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
    l = slices.set_item(l, 0, [5, 6])

    # The session handle was never used (self.evaluate runs the tensor), so
    # it is no longer bound.
    with self.cached_session():
      t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
      self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
示例#37
0
 def testStackFromTensorGradients(self):
   """Gradients flow through tensor_list_from_tensor -> tensor_list_stack."""
   with backprop.GradientTape() as tape:
     source = constant_op.constant([1.0, 2.0])
     tape.watch(source)
     handle = list_ops.tensor_list_from_tensor(
         source, element_shape=scalar_shape())
     restacked = list_ops.tensor_list_stack(
         handle, element_dtype=dtypes.float32)
     doubled = restacked * 2.0
   (grad,) = tape.gradient(doubled, [source])
   self.assertAllEqual(grad, [2.0, 2.0])
示例#38
0
 def testGetSet(self):
   """get_item reads and set_item overwrites a single element."""
   with self.cached_session(), self.test_scope():
     values = constant_op.constant([1.0, 2.0])
     tl = list_ops.tensor_list_from_tensor(values, element_shape=[])
     first = list_ops.tensor_list_get_item(tl, 0, element_dtype=dtypes.float32)
     self.assertAllEqual(first, 1.0)
     tl = list_ops.tensor_list_set_item(tl, 0, 3.0)
     self.assertAllEqual(
         list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32),
         [3.0, 2.0])
 def _testStack(self, max_num_elements):
   """Pushes two scalars into a list with the given capacity and stacks it.

   Args:
     max_num_elements: forwarded to `empty_tensor_list` as its capacity.
   """
   tl = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32,
       element_shape=scalar_shape(),
       max_num_elements=max_num_elements)
   for scalar in (1.0, 2.0):
     tl = list_ops.tensor_list_push_back(tl, constant_op.constant(scalar))
   stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(stacked), [1.0, 2.0])
示例#40
0
 def testGetSetReserved(self):
   """Reserved slots read as 0.0 until set_item writes them."""
   with self.cached_session(), self.test_scope():
     reserved = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=[], num_elements=2)
     slot0 = list_ops.tensor_list_get_item(
         reserved, 0, element_dtype=dtypes.float32)
     self.assertAllEqual(slot0, 0.0)
     written = list_ops.tensor_list_set_item(reserved, 0, 3.0)
     stacked = list_ops.tensor_list_stack(written, element_dtype=dtypes.float32)
     self.assertAllEqual(stacked, [3.0, 0.0])
 def testGetSetReserved(self):
   """A two-slot reserved list reads 0.0 and stacks to [3.0, 0.0] after one set."""
   with self.cached_session(), self.test_scope():
     tl = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=[], num_elements=2)
     self.assertAllEqual(
         list_ops.tensor_list_get_item(tl, 0, element_dtype=dtypes.float32),
         0.0)
     tl = list_ops.tensor_list_set_item(tl, 0, 3.0)
     self.assertAllEqual(
         list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32),
         [3.0, 0.0])
 def testStackFromTensorGradients(self):
     """d(2 * stack(from_tensor(c)))/dc is 2 for every element."""
     with backprop.GradientTape() as tape:
         source = constant_op.constant([1.0, 2.0])
         tape.watch(source)
         handle = list_ops.tensor_list_from_tensor(
             source, element_shape=scalar_shape())
         restacked = list_ops.tensor_list_stack(handle,
                                                element_dtype=dtypes.float32)
         scaled = restacked * 2.0
     gradients = tape.gradient(scaled, [source])
     self.assertAllEqual(gradients[0], [2.0, 2.0])
 def testGetSet(self):
   """Element 0 reads 1.0, then stacks as 3.0 after set_item."""
   with self.cached_session(), self.test_scope():
     handle = list_ops.tensor_list_from_tensor(
         constant_op.constant([1.0, 2.0]), element_shape=[])
     head = list_ops.tensor_list_get_item(
         handle, 0, element_dtype=dtypes.float32)
     self.assertAllEqual(head, 1.0)
     updated = list_ops.tensor_list_set_item(handle, 0, 3.0)
     stacked = list_ops.tensor_list_stack(updated, element_dtype=dtypes.float32)
     self.assertAllEqual(stacked, [3.0, 2.0])
示例#44
0
 def testGraphStack(self):
   """In graph mode, a one-element int32 list stacks to [[1]]."""
   # test_session() is deprecated in favor of cached_session().
   with context.graph_mode(), self.cached_session():
     tl = list_ops.empty_tensor_list(
         element_shape=constant_op.constant([1], dtype=dtypes.int32),
         element_dtype=dtypes.int32)
     tl = list_ops.tensor_list_push_back(tl, [1])
     self.assertAllEqual(
         list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
         [[1]])
示例#45
0
 def testGraphStack(self):
   """In graph mode, a one-element int32 list stacks to [[1]]."""
   # test_session() is deprecated in favor of cached_session().
   with context.graph_mode(), self.cached_session():
     tl = list_ops.empty_tensor_list(
         element_shape=constant_op.constant([1], dtype=dtypes.int32),
         element_dtype=dtypes.int32)
     tl = list_ops.tensor_list_push_back(tl, [1])
     self.assertAllEqual(
         list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
         [[1]])
 def testSetDoesNotUpdatePushIndex(self):
   """set_item must not advance the position used by push_back."""
   with self.cached_session(), self.test_scope():
     tl = list_ops.empty_tensor_list(
         element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
     # SetItem should not change the push index: the expected result
     # [5., 7.] shows the 3. written to slot 1 is overwritten by the pushes.
     tl = list_ops.tensor_list_set_item(tl, 1, 3.)
     for value in (5., 7.):
       tl = list_ops.tensor_list_push_back(tl, value)
     stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
     self.assertAllEqual(stacked, [5., 7.])
示例#47
0
 def testSetDoesNotUpdatePushIndex(self):
   """Pushes after a set_item still start from the front of the list."""
   with self.cached_session(), self.test_scope():
     handle = list_ops.empty_tensor_list(
         element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
     # Write slot 1 directly; this must leave the push index untouched.
     handle = list_ops.tensor_list_set_item(handle, 1, 3.)
     handle = list_ops.tensor_list_push_back(handle, 5.)
     handle = list_ops.tensor_list_push_back(handle, 7.)
     self.assertAllEqual(
         list_ops.tensor_list_stack(handle, element_dtype=dtypes.float32),
         [5., 7.])
示例#48
0
 def testStackFromTensorGradients(self):
   """Gradient of 2 * stack(from_tensor(c)) w.r.t. c is [2.0, 2.0]."""
   with backprop.GradientTape() as tape:
     source = constant_op.constant([1.0, 2.0])
     tape.watch(source)
     handle = list_ops.tensor_list_from_tensor(source, element_shape=[])
     restacked = list_ops.tensor_list_stack(
         handle, element_dtype=dtypes.float32, num_elements=2)
     doubled = restacked * 2.0
   (grad,) = tape.gradient(doubled, [source])
   self.assertAllEqual(self.evaluate(grad), [2.0, 2.0])
示例#49
0
 def testGraphStack(self):
   """A list with element_shape [1] holding one push stacks to [[1]]."""
   with self.cached_session():
     shape = constant_op.constant([1], dtype=dtypes.int32)
     tl = list_ops.empty_tensor_list(element_shape=shape,
                                     element_dtype=dtypes.int32)
     tl = list_ops.tensor_list_push_back(tl, [1])
     stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)
     self.assertAllEqual(self.evaluate(stacked), [[1]])
示例#50
0
    def test_set_item_tensor_list(self):
        """slices.set_item replaces row 0 of a TensorList built from a matrix."""
        initial_list = constant_op.constant([[1, 2], [3, 4]])
        elem_shape = constant_op.constant([2])
        l = list_ops.tensor_list_from_tensor(initial_list,
                                             element_shape=elem_shape)
        l = slices.set_item(l, 0, [5, 6])

        # The session handle was never used (self.evaluate runs the tensor),
        # so it is no longer bound.
        with self.cached_session():
            t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
            self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
示例#51
0
  def test_initialized_list(self):
    """A list literal converted by `lists` stacks to its values."""

    def test_fn():
      return [1, 2, 3]

    with self.converted(test_fn, lists, {}) as result:
      # test_session() is deprecated in favor of cached_session().
      with self.cached_session():
        tl = result.test_fn()
        r = list_ops.tensor_list_stack(tl, dtypes.int32)
        self.assertAllEqual(self.evaluate(r), [1, 2, 3])
 def testZerosLikeForTensorList(self):
     """zeros_like of a one-element list with max_num_elements=2 stacks to
     two zeros."""
     with self.session(), self.test_scope():
         tl = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                         element_shape=[],
                                         max_num_elements=2)
         tl = list_ops.tensor_list_push_back(tl, constant_op.constant(1.0))
         zeros = list_ops.tensor_list_stack(array_ops.zeros_like(tl),
                                            element_dtype=dtypes.float32)
         # Only the rank is known statically; the length is not.
         self.assertAllEqual(zeros.shape.as_list(), [None])
         self.assertAllEqual(zeros, [0.0, 0.0])
示例#53
0
 def testPushInEmptyListWithUnknownElementShape(self):
   """The first push fixes the element shape; a mismatching push must fail."""
   with self.cached_session(), self.test_scope():
     l = list_ops.empty_tensor_list(
         element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
     l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
     # Pushing an element with a different shape should raise an error.
     # (assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed in Python 3.12.)
     with self.assertRaisesRegex(errors.InvalidArgumentError, "Shape"):
       l = list_ops.tensor_list_push_back(l, 5.)
       self.evaluate(
           list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
 def testPushInEmptyListWithUnknownElementShape(self):
   """Pushing a differently-shaped element into a shape-less list fails.

   The list is created with `element_shape=None`. After a length-2 vector
   is pushed, pushing a scalar and then evaluating the stacked list is
   expected to raise `InternalError` whose message matches "shape".
   """
   with self.cached_session(), self.test_scope():
     l = list_ops.empty_tensor_list(
         element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
     l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
     # Pushing an element with a different shape should raise an error.
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
     # alias, which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(errors.InternalError, "shape"):
       l = list_ops.tensor_list_push_back(l, 5.)
       self.evaluate(
           list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
    def testPushBackBatch(self):
        """`tensor_list_push_back_batch` appends one value to each list.

        Two float TensorLists are stacked into a batch and [3.0, 4.0] is
        pushed (one value per list). The test checks three things:
          * the pushed lists are extended;
          * the original batch is left unmodified;
          * an empty value tensor, values with an extra dimension, and int
            values are each rejected.
        """
        c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
        l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
        l1 = list_ops.tensor_list_from_tensor([-1.0],
                                              element_shape=scalar_shape())
        l_batch = array_ops.stack([l0, l1])
        l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])
        l_unstack = array_ops.unstack(l_push)
        l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
        l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
        self.assertAllClose([1.0, 2.0, 3.0], self.evaluate(l0_ret))
        self.assertAllClose([-1.0, 4.0], self.evaluate(l1_ret))

        # Read the original batch only after the push has run, so any
        # in-place modification of the input would be visible here.
        with ops.control_dependencies([l_push]):
            l_unstack_orig = array_ops.unstack(l_batch)
            l0_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[0],
                                                     dtypes.float32)
            l1_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[1],
                                                     dtypes.float32)

        # Check that without aliasing, push_back_batch still works; and
        # that it doesn't modify the input.
        l0_r_v, l1_r_v, l0_orig_v, l1_orig_v = self.evaluate(
            (l0_ret, l1_ret, l0_orig_ret, l1_orig_ret))
        self.assertAllClose([1.0, 2.0, 3.0], l0_r_v)
        self.assertAllClose([-1.0, 4.0], l1_r_v)
        self.assertAllClose([1.0, 2.0], l0_orig_v)
        self.assertAllClose([-1.0], l1_orig_v)

        # Pushing back mismatched shapes fails.
        with self.assertRaises((errors.InvalidArgumentError, ValueError)):
            self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, []))

        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "incompatible shape to a list at index 0"):
            self.evaluate(
                list_ops.tensor_list_push_back_batch(l_batch, [[3.0], [4.0]]))

        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "Invalid data type at index 0"):
            self.evaluate(list_ops.tensor_list_push_back_batch(
                l_batch, [3, 4]))
 def testZerosLikeForTensorList(self):
   """Zeroing a one-element float TensorList and stacking it gives zeros.

   The list is created with max_num_elements=2 and receives one pushed
   scalar. The test expects a stacked result of unknown static length and
   value [0.0, 0.0].
   """
   with self.cached_session(), self.test_scope():
     source = list_ops.empty_tensor_list(
         element_dtype=dtypes.float32, element_shape=[], max_num_elements=2)
     source = list_ops.tensor_list_push_back(
         source, constant_op.constant(1.0))
     zeros_list = array_ops.zeros_like(source)
     zeros = list_ops.tensor_list_stack(
         zeros_list, element_dtype=dtypes.float32)
     expected_static_shape = [None]
     self.assertAllEqual(zeros.shape.as_list(), expected_static_shape)
     self.assertAllEqual(zeros, [0.0, 0.0])
示例#57
0
 def testResourceVariableScatterGather(self):
   """Gather from and scatter into a resource variable of TensorLists.

   The variable holds 10 copies of the list [1.0, 2.0]. The test first
   reads entry 0 by indexing and by `sparse_read`. It then scatter-updates
   entries 3 and 5 with new lists and checks every stacked entry.
   """
   base_values = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
   base_list = list_ops.tensor_list_from_tensor(base_values, element_shape=[])
   variable = vs.get_variable(
       "var", initializer=[base_list] * 10, use_resource=True)

   # Gather entry 0 via plain indexing.
   first_stacked = list_ops.tensor_list_stack(variable[0], dtypes.float32)
   self.evaluate(variable.initializer)
   self.assertAllEqual([1.0, 2.0], self.evaluate(first_stacked))

   # Gather entry 0 via sparse_read.
   sparse_stacked = list_ops.tensor_list_stack(
       variable.sparse_read(0), dtypes.float32)
   self.assertAllEqual([1.0, 2.0], self.evaluate(sparse_stacked))

   # Scatter new lists into entries 3 and 5.
   replacement_3 = list_ops.tensor_list_from_tensor(
       [3.0, 4.0], element_shape=[])
   replacement_5 = list_ops.tensor_list_from_tensor(
       [5.0, 6.0], element_shape=[])
   scattered = state_ops.scatter_update(
       variable, [3, 5], [replacement_3, replacement_5])

   stacked_entries = []
   for entry in array_ops.unstack(scattered):
     stacked_entries.append(list_ops.tensor_list_stack(entry, dtypes.float32))

   want = [[1.0, 2.0]] * 10
   want[3] = [3.0, 4.0]
   want[5] = [5.0, 6.0]
   self.assertAllEqual(self.evaluate(stacked_entries), want)
示例#58
0
 def _testStack(self, max_num_elements):
   """Push two float scalars and check that stacking returns them in order.

   Args:
     max_num_elements: forwarded unchanged to `empty_tensor_list`.
   """
   tensor_list = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32,
       element_shape=[],
       max_num_elements=max_num_elements)
   for value in (1.0, 2.0):
     tensor_list = list_ops.tensor_list_push_back(
         tensor_list, constant_op.constant(value))
   stacked = list_ops.tensor_list_stack(
       tensor_list, element_dtype=dtypes.float32)
   # Outside eager mode the test expects the static length to be unknown.
   if not context.executing_eagerly():
     self.assertAllEqual(stacked.shape.as_list(), [None])
   self.assertAllEqual(self.evaluate(stacked), [1.0, 2.0])