def test_stack_tensor_list_empty(self):
    """list_stack raises ValueError on an empty TensorList.

    The list is created with element_shape=None and element_dtype=variant,
    while the stacking options request int32 elements.
    """
    empty_list = list_ops.empty_tensor_list(
        element_shape=None, element_dtype=dtypes.variant)
    stack_opts = data_structures.ListStackOpts(
        element_dtype=dtypes.int32, original_call=None)

    # TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
    self.assertRaises(
        ValueError, data_structures.list_stack, empty_list, stack_opts)
  def test_stack_tensor_list_empty(self):
    """Checks that stacking an empty TensorList is rejected with ValueError."""
    source_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.variant, element_shape=None)
    int32_opts = data_structures.ListStackOpts(
        original_call=None, element_dtype=dtypes.int32)

    # TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
    self.assertRaises(
        ValueError, data_structures.list_stack, source_list, int32_opts)
Beispiel #3
0
def stack(list_or_tensor, element_dtype=None, strict=True):
  """Stacks the input, if it admits the notion of stacking.

  For example, a list of tensors can be stacked into a larger tensor. This
  function is similar to tf.stack, but it accepts non-lists and lists of
  non-tensors as arguments. For such non-stackable inputs the behavior depends
  on `strict`: with strict=False the input is returned unchanged, and with
  strict=True a ValueError is raised.

  Args:
    list_or_tensor: Any
    element_dtype: tf.DType, optional dtype for the elements in the list.
        Required if the input is stackable, and the list is untyped.
    strict: bool, if True an error is raised if the input is not stackable.
        Otherwise the function is a no-op.

  Returns:
    Any, if the input is stackable, the result will be a tf.Tensor. Otherwise,
    if strict=False, the result will be list_or_tensor.

  Raises:
    ValueError: if strict=True and the input is not stackable.
  """
  if strict:
    def raise_error(x):
      # Wrap x in a 1-tuple: with a bare `% x`, a tuple input would be
      # unpacked as several format arguments and raise TypeError instead of
      # the documented ValueError.
      raise ValueError('%s must be stackable when strict=True' % (x,))
    original_call = raise_error
  else:
    # Non-strict mode: non-stackable inputs pass through unchanged.
    original_call = lambda x: x
  # list_stack receives original_call as its fallback, used for inputs it
  # does not stack itself.
  return data_structures.list_stack(
      list_or_tensor,
      data_structures.ListStackOpts(
          element_dtype=element_dtype, original_call=original_call))
    def test_stack_fallback(self):
        """list_stack delegates a plain Python list to original_call."""

        def _double_all(values):
            # Minimal stand-in for a mock: doubling every element makes it
            # observable in the result that this callback actually ran.
            return [v * 2 for v in values]

        expected = [2, 4]
        fallback_opts = data_structures.ListStackOpts(
            element_dtype=None, original_call=_double_all)
        result = data_structures.list_stack([1, 2], fallback_opts)
        self.assertAllEqual(result, expected)
  def test_stack_tensor_list(self):
    """Stacking a TensorList built from a tensor reproduces that tensor."""
    initial_list = constant_op.constant([[1, 2], [3, 4]])
    elem_shape = constant_op.constant([2])
    tensor_list = list_ops.tensor_list_from_tensor(
        initial_list, element_shape=elem_shape)

    opts = data_structures.ListStackOpts(
        element_dtype=initial_list.dtype, original_call=None)

    # Results are fetched through self.evaluate(), so no session handle needs
    # to be bound here.
    with self.cached_session():
      t = data_structures.list_stack(tensor_list, opts)
      self.assertAllEqual(self.evaluate(t), self.evaluate(initial_list))
  def test_stack_fallback(self):
    """A non-tensor input is routed to the original_call fallback."""

    def times_two(items):
      # Poor man's mock: the doubled output proves the fallback was used.
      return [item * 2 for item in items]

    stack_opts = data_structures.ListStackOpts(
        element_dtype=None, original_call=times_two)
    stacked = data_structures.list_stack([1, 2], stack_opts)

    self.assertAllEqual(stacked, [2, 4])
  def test_stack_tensor_list(self):
    """Stacking a TensorList built from a tensor reproduces that tensor."""
    initial_list = constant_op.constant([[1, 2], [3, 4]])
    elem_shape = constant_op.constant([2])
    tensor_list = list_ops.tensor_list_from_tensor(
        initial_list, element_shape=elem_shape)

    opts = data_structures.ListStackOpts(
        element_dtype=initial_list.dtype, original_call=None)

    # self.evaluate() is used instead of sess.run() so the test does not
    # depend on an explicit session handle (matching the other variant of
    # this test); the unused `sess` binding is dropped accordingly.
    with self.cached_session():
      t = data_structures.list_stack(tensor_list, opts)
      self.assertAllEqual(self.evaluate(t), self.evaluate(initial_list))