def stack(list_or_tensor, element_dtype=None, strict=True):
    """Stacks the input, if it admits the notion of stacking.

    For example, a list of tensors can be stacked into a larger tensor. This
    function is similar to tf.stack, but it accepts non-lists and lists of
    non-tensors as arguments. In the latter case, the behavior depends on
    `strict`: either an error is raised or the input is returned unchanged.

    Args:
        list_or_tensor: Any
        element_dtype: tf.DType, optional dtype for the elements in the list.
            Required if the input is stackable, and the list is untyped.
        strict: bool, if True an error is raised if the input is not
            stackable. Otherwise the function is a no-op.

    Returns:
        Any, if the input is stackable, the result will be a tf.Tensor.
        Otherwise, if strict=False, the result will be list_or_tensor.

    Raises:
        ValueError: if strict=True and the input is not stackable.
    """
    # The fallback is handed to list_stack as `original_call`; it is what runs
    # on the input when list_stack does not treat it as stackable.
    if strict:

        def fallback(x):
            # Wrap x in a 1-tuple: with a bare `% x`, a tuple input would be
            # unpacked as the format arguments, raising TypeError (or printing
            # only its single element) instead of the intended ValueError.
            raise ValueError('%s must be stackable when strict=True' % (x,))

    else:

        def fallback(x):
            return x

    opts = data_structures.ListStackOpts(element_dtype=element_dtype,
                                         original_call=fallback)
    return data_structures.list_stack(list_or_tensor, opts)
Beispiel #2
0
    def test_stack_fallback(self):
        """list_stack on a plain Python list yields the original_call result."""

        def double_each(values):
            # Stand-in for a mock: the doubled output makes it observable that
            # this callable was the one producing the result.
            return [value * 2 for value in values]

        fallback_opts = data_structures.ListStackOpts(
            element_dtype=None, original_call=double_each)
        result = data_structures.list_stack([1, 2], fallback_opts)

        self.assertAllEqual(result, [2, 4])
Beispiel #3
0
    def test_stack_tensor_list(self):
        """Stacking a tensor list built from a tensor reproduces that tensor."""
        source_tensor = constant_op.constant([[1, 2], [3, 4]])
        row_shape = constant_op.constant([2])
        tensor_list = list_ops.tensor_list_from_tensor(
            source_tensor, element_shape=row_shape)

        # No fallback is supplied here (original_call=None).
        stack_opts = data_structures.ListStackOpts(
            element_dtype=source_tensor.dtype, original_call=None)

        with self.test_session() as sess:
            stacked = data_structures.list_stack(tensor_list, stack_opts)
            actual = sess.run(stacked)
            expected = sess.run(source_tensor)
            self.assertAllEqual(actual, expected)
Beispiel #4
0
def stack(list_or_tensor, element_dtype=None):
  """Stacks the input when it can be stacked; returns it unchanged otherwise.

  A list of tensors, for instance, can be stacked into one larger tensor. The
  function resembles tf.stack, except that it also accepts non-lists and lists
  of non-tensors, for which it does nothing.

  Args:
    list_or_tensor: Any entity.
    element_dtype: Optional dtype for the elements in the list. Required if the
        input is stackable, and the list is untyped.

  Returns:
    A new object representing the stacked inputs if the input is stackable;
    otherwise list_or_tensor, unchanged.
  """

  def _identity(x):
    # Fallback handed to list_stack: hand the input back untouched.
    return x

  opts = data_structures.ListStackOpts(
      element_dtype=element_dtype, original_call=_identity)
  return data_structures.list_stack(list_or_tensor, opts)