Пример #1
0
  def test_tensor_list_empty_list(self):
    """Stacking a tensor_list built from an empty list or tuple yields []."""
    # Both empty sequence types must be accepted as initializers. They carry
    # no elements, so element_dtype and element_shape are passed explicitly.
    for empty_initializer in ([], ()):
      l = special_functions.tensor_list(empty_initializer,
                                        element_dtype=dtypes.int32,
                                        element_shape=())
      sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
      # Only the session context is needed; the session handle is unused.
      with self.cached_session():
        self.assertAllEqual(self.evaluate(sl), [])
  def test_tensor_list_empty_list(self):
    """Stacking a tensor_list built from an empty list or tuple yields []."""
    # Both empty sequence types must be accepted as initializers. They carry
    # no elements, so element_dtype and element_shape are passed explicitly.
    for empty_initializer in ([], ()):
      l = special_functions.tensor_list(empty_initializer,
                                        element_dtype=dtypes.int32,
                                        element_shape=())
      sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
      # self.evaluate runs the tensor without needing the session handle,
      # matching the other tests in this file.
      with self.cached_session():
        self.assertAllEqual(self.evaluate(sl), [])
Пример #3
0
  def test_tensor_list_array_from_elements(self):
    """tensor_list(use_tensor_array=True) stacks its elements in order."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements, use_tensor_array=True)
    # With use_tensor_array=True the result is stacked with its own .stack()
    # method instead of list_ops.tensor_list_stack.
    sl = l.stack()
    # Only the session context is needed; the session handle is unused.
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
  def test_tensor_list_array_from_elements(self):
    """tensor_list(use_tensor_array=True) stacks its elements in order."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements, use_tensor_array=True)
    # With use_tensor_array=True the result is stacked with its own .stack()
    # method instead of list_ops.tensor_list_stack.
    sl = l.stack()
    # cached_session replaces the deprecated test_session; self.evaluate
    # runs the tensor without needing the session handle.
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
  def test_tensor_list_from_elements(self):
    """A tensor_list built from constant elements stacks them in order."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    # cached_session replaces the deprecated test_session; self.evaluate
    # runs the tensor without needing the session handle.
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
Пример #6
0
  def test_tensor_list_from_elements(self):
    """A tensor_list built from constant elements stacks them in order."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    # Only the session context is needed; the session handle is unused.
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
Пример #7
0
 def test_tensor_list_empty_list_no_type(self):
   """An empty initializer without element_dtype/element_shape is rejected."""
   # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
   # assertRaisesRegex is the supported spelling.
   with self.assertRaisesRegex(
       ValueError, 'element_dtype and element_shape are required'):
     special_functions.tensor_list([])
Пример #8
0
 def test_tensor_list_unsupported_initializer(self):
   """A NumPy array is not an accepted tensor_list initializer."""
   # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
   # assertRaisesRegex is the supported spelling.
   with self.assertRaisesRegex(ValueError, 'unknown type'):
     special_functions.tensor_list(np.array([1, 2, 3]))
Пример #9
0
 def test_tensor_list_tensor(self):
   """A tensor_list built from an empty int32 tensor stacks to []."""
   l = special_functions.tensor_list(
       constant_op.constant([], dtype=dtypes.int32))
   sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
   # Only the session context is needed; the session handle is unused.
   with self.cached_session():
     self.assertAllEqual(self.evaluate(sl), [])
 def test_tensor_list_unsupported_initializer(self):
   """A NumPy array is not an accepted tensor_list initializer."""
   # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
   # assertRaisesRegex is the supported spelling.
   with self.assertRaisesRegex(ValueError, 'unknown type'):
     special_functions.tensor_list(np.array([1, 2, 3]))
Пример #11
0
 def test_fn():
     # Pop one element from a three-element tensor_list and hand back both
     # the popped value and the list itself.
     tensors = special_functions.tensor_list([1, 2, 3])
     popped = tensors.pop()
     return popped, tensors
Пример #12
0
 def test_fn():
     # Seed a tensor_list with a single element, then grow it with two
     # append() calls and return the resulting list.
     tensors = special_functions.tensor_list([1])
     tensors.append(2)
     tensors.append(3)
     return tensors
 def test_tensor_list_tensor(self):
   """A tensor_list built from an empty int32 tensor stacks to []."""
   l = special_functions.tensor_list(
       constant_op.constant([], dtype=dtypes.int32))
   sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
   # self.evaluate runs the tensor without needing the session handle,
   # matching the other tests in this file.
   with self.cached_session():
     self.assertAllEqual(self.evaluate(sl), [])
Пример #14
0
 def f():
     # Build a three-element tensor_list, declare its element dtype/shape
     # through the set_element_type directive, then pop one element.
     tensors = special_functions.tensor_list([1, 2, 3])
     directives.set_element_type(tensors, dtype=dtypes.int32, shape=())
     popped = tensors.pop()
     return popped, tensors
Пример #15
0
 def test_fn():
   # Pop one element from a three-element tensor_list and hand back both
   # the popped value and the list itself.
   tensors = special_functions.tensor_list([1, 2, 3])
   popped = tensors.pop()
   return popped, tensors
Пример #16
0
 def test_fn():
   # Seed a tensor_list with a single element, then grow it with two
   # append() calls and return the resulting list.
   tensors = special_functions.tensor_list([1])
   tensors.append(2)
   tensors.append(3)
   return tensors
 def test_tensor_list_empty_list_no_type(self):
   """An empty initializer without element_dtype/element_shape is rejected."""
   # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
   # assertRaisesRegex is the supported spelling.
   with self.assertRaisesRegex(
       ValueError, 'element_dtype and element_shape are required'):
     special_functions.tensor_list([])