Пример #1
0
def stack(list_or_tensor, element_dtype=None, strict=True):
  """Stacks the input, if it admits the notion of stacking.

  For example, a list of tensors can be stacked into a larger tensor. This
  function is similar to tf.stack, but it accepts non-lists and lists of
  non-tensors as arguments. In the latter case, the function does nothing.

  Args:
    list_or_tensor: Any
    element_dtype: tf.DType, optional dtype for the elements in the list.
        Required if the input is stackable, and the list is untyped.
    strict: bool, if True an error is raised if the input is not stackable.
        Otherwise the function is a no-op.

  Returns:
    Any, if the input is stackable, the result will be a tf.Tensor. Otherwise,
    if strict=False, the result will be list_or_tensor.

  Raises:
    ValueError: if strict=True and the input is not stackable.
  """
  if strict:
    def raise_error(x):
      # Wrap x in a 1-tuple: with a bare `% x`, a tuple input would be taken
      # as the format-argument tuple and raise TypeError instead of the
      # documented ValueError.
      raise ValueError('%s must be stackable when strict=True' % (x,))
    original_call = raise_error
  else:
    original_call = lambda x: x
  # original_call is the fallback handed to list_stack for inputs it does not
  # stack itself: it either raises (strict) or returns the input unchanged.
  return data_structures.list_stack(
      list_or_tensor,
      data_structures.ListStackOpts(
          element_dtype=element_dtype, original_call=original_call))
def stack(list_or_tensor, element_dtype=None, strict=True):
    """Stacks the input, if it admits the notion of stacking.

    For example, a list of tensors can be stacked into a larger tensor. This
    function is similar to tf.stack, but it accepts non-lists and lists of
    non-tensors as arguments. In the latter case, the function does nothing.

    Args:
      list_or_tensor: Any
      element_dtype: tf.DType, optional dtype for the elements in the list.
          Required if the input is stackable, and the list is untyped.
      strict: bool, if True an error is raised if the input is not stackable.
          Otherwise the function is a no-op.

    Returns:
      Any, if the input is stackable, the result will be a tf.Tensor.
      Otherwise, if strict=False, the result will be list_or_tensor.

    Raises:
      ValueError: if strict=True and the input is not stackable.
    """
    if strict:

        def raise_error(x):
            # Wrap x in a 1-tuple: with a bare `% x`, a tuple input would be
            # taken as the format-argument tuple and raise TypeError instead
            # of the documented ValueError.
            raise ValueError('%s must be stackable when strict=True' % (x,))

        original_call = raise_error
    else:
        original_call = lambda x: x
    # original_call is the fallback handed to list_stack for inputs it does
    # not stack itself: it either raises (strict) or returns the input as-is.
    return data_structures.list_stack(
        list_or_tensor,
        data_structures.ListStackOpts(element_dtype=element_dtype,
                                      original_call=original_call))
Пример #3
0
    def test_stack_fallback(self):
        # Stand-in for the original call. It doubles every element so the
        # assertion can tell that list_stack routed the input through it.
        def double_each(items):
            return [item * 2 for item in items]

        fallback_opts = data_structures.ListStackOpts(
            element_dtype=None, original_call=double_each)
        stacked = data_structures.list_stack([1, 2], fallback_opts)
        self.assertAllEqual(stacked, [2, 4])
  def test_stack_tensor_list(self):
    # Build a TensorList from a 2x2 constant, then check that stacking it
    # reproduces the original tensor.
    source = constant_op.constant([[1, 2], [3, 4]])
    tensor_list = list_ops.tensor_list_from_tensor(
        source, element_shape=constant_op.constant([2]))

    stack_opts = data_structures.ListStackOpts(
        element_dtype=source.dtype, original_call=None)

    with self.test_session() as sess:
      stacked = data_structures.list_stack(tensor_list, stack_opts)
      actual = sess.run(stacked)
      expected = sess.run(source)
      self.assertAllEqual(actual, expected)
  def test_stack_fallback(self):
    # Stand-in for the original call. It doubles every element so the
    # assertion can tell that list_stack routed the input through it.
    def double_each(items):
      return [item * 2 for item in items]

    fallback_opts = data_structures.ListStackOpts(
        element_dtype=None, original_call=double_each)
    stacked = data_structures.list_stack([1, 2], fallback_opts)
    self.assertAllEqual(stacked, [2, 4])
Пример #6
0
    def test_stack_tensor_list(self):
        # Build a TensorList from a 2x2 constant, then check that stacking it
        # reproduces the original tensor.
        source = constant_op.constant([[1, 2], [3, 4]])
        tensor_list = list_ops.tensor_list_from_tensor(
            source, element_shape=constant_op.constant([2]))

        stack_opts = data_structures.ListStackOpts(
            element_dtype=source.dtype, original_call=None)

        with self.test_session() as sess:
            stacked = data_structures.list_stack(tensor_list, stack_opts)
            actual = sess.run(stacked)
            expected = sess.run(source)
            self.assertAllEqual(actual, expected)
Пример #7
0
def stack(list_or_tensor, element_dtype=None):
  """Stacks the input when it supports stacking; otherwise returns it as-is.

  A list of tensors, for instance, can be stacked into a single larger tensor.
  Unlike tf.stack, this function also accepts non-list inputs and lists whose
  elements are not tensors; for those it does nothing.

  Args:
    list_or_tensor: Any entity.
    element_dtype: Optional dtype for the elements in the list. Required if the
        input is stackable and the list is untyped.

  Returns:
    A new object representing the stacked inputs if the input is stackable;
    otherwise list_or_tensor, unchanged.
  """
  # Identity fallback: whatever list_stack does not stack is handed back as-is.
  def passthrough(x):
    return x

  opts = data_structures.ListStackOpts(
      element_dtype=element_dtype, original_call=passthrough)
  return data_structures.list_stack(list_or_tensor, opts)