def stack(list_or_tensor, element_dtype=None, strict=True):
  """Stacks the input, if it admits the notion of stacking.

  For example, a list of tensors can be stacked into a larger tensor. This
  function is similar to tf.stack, but it accepts non-lists and lists of
  non-tensors as arguments. In the latter case, the function does nothing.

  Args:
    list_or_tensor: Any
    element_dtype: tf.DType, optional dtype for the elements in the list.
        Required if the input is stackable, and the list is untyped.
    strict: bool, if True an error is raised if the input is not stackable.
        Otherwise the function is a no-op.

  Returns:
    Any, if the input is stackable, the result will be a tf.Tensor. Otherwise,
    if strict=False, the result will be list_or_tensor.

  Raises:
    ValueError: if strict=True and the input is not stackable.
  """
  if strict:
    def raise_error(x):
      # Wrap x in a 1-tuple: if x is itself a tuple, `'%s' % x` would use it
      # as the argument tuple and raise TypeError instead of ValueError.
      raise ValueError('%s must be stackable when strict=True' % (x,))
    original_call = raise_error
  else:
    # Non-strict mode: unstackable inputs are passed through unchanged.
    original_call = lambda x: x
  return data_structures.list_stack(
      list_or_tensor,
      data_structures.ListStackOpts(
          element_dtype=element_dtype, original_call=original_call))
def test_stack_fallback(self):
  # The fallback doubles every element, so the expected [2, 4] can only come
  # from list_stack having delegated to original_call.
  def doubling_fallback(values):
    return list(map(lambda v: v * 2, values))

  stack_opts = data_structures.ListStackOpts(
      element_dtype=None, original_call=doubling_fallback)
  stacked = data_structures.list_stack([1, 2], stack_opts)
  self.assertAllEqual(stacked, [2, 4])
def test_stack_tensor_list(self):
  # Round trip: build a TensorList from a constant, stack it back with
  # list_stack, and compare the evaluated result against the source tensor.
  source = constant_op.constant([[1, 2], [3, 4]])
  shape_of_element = constant_op.constant([2])
  tensor_list = list_ops.tensor_list_from_tensor(
      source, element_shape=shape_of_element)
  stack_opts = data_structures.ListStackOpts(
      element_dtype=source.dtype, original_call=None)
  with self.test_session() as sess:
    stacked = data_structures.list_stack(tensor_list, stack_opts)
    self.assertAllEqual(sess.run(stacked), sess.run(source))
def stack(list_or_tensor, element_dtype=None):
  """Stacks the input when it supports stacking; otherwise returns it as is.

  A list of tensors, for instance, can be stacked into a larger tensor. This
  behaves like tf.stack, except that non-lists and lists of non-tensors are
  also accepted; for those, nothing is done.

  Args:
    list_or_tensor: Any entity.
    element_dtype: Optional dtype for the elements in the list. Required if
        the input is stackable and the list is untyped.

  Returns:
    A new object representing the stacked inputs if the input is stackable;
    otherwise list_or_tensor itself, unchanged.
  """
  def passthrough(x):
    # Fallback used by list_stack for inputs it cannot stack: identity.
    return x

  opts = data_structures.ListStackOpts(
      element_dtype=element_dtype, original_call=passthrough)
  return data_structures.list_stack(list_or_tensor, opts)