def _map_fn(  # pylint: disable=unused-argument
    fn,
    elems,
    dtype=None,
    parallel_iterations=None,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None,
    fn_output_signature=None):
  """Numpy implementation of tf.map_fn."""
  if fn_output_signature is not None and nest.is_nested(fn_output_signature):
    # If fn returns a tuple, then map_fn returns a tuple as well; and similarly
    # for lists and more complex nestings. We do not support this behavior at
    # this time, so we raise an error explicitly instead of silently doing the
    # wrong thing.
    raise NotImplementedError
  if JAX_MODE:
    from jax import tree_util  # pylint: disable=g-import-not-at-top
    elems_flat, in_tree = tree_util.tree_flatten(elems)
    elems_zipped = zip(*elems_flat)

    def func(flat_args):
      unflat_args = tree_util.tree_unflatten(in_tree, flat_args)
      return fn(unflat_args)

    return np.stack([func(x) for x in elems_zipped])
  if isinstance(elems, np.ndarray):
    return np.array([fn(x) for x in elems])
  # In the NumPy backend, we do not yet support map_fn over lists, tuples, or
  # other structures.
  raise NotImplementedError
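
# Illustrative sketch (hypothetical helper, not part of the original module):
# the JAX branch of `_map_fn` relies on a flatten -> zip -> unflatten pattern
# to slice a structured `elems` along its leading axis before calling `fn`.
# The standalone reproduction below assumes JAX is installed; it is never
# invoked by library code.
def _example_map_fn_jax_pattern():
  import numpy as onp  # plain NumPy, independent of the module-level `np`
  from jax import tree_util  # pylint: disable=g-import-not-at-top

  elems = (onp.arange(3.), onp.arange(3.) * 10.)  # tuple of equal-length arrays
  elems_flat, in_tree = tree_util.tree_flatten(elems)

  def fn(pair):
    x, y = pair
    return x + y

  rows = []
  for flat_args in zip(*elems_flat):  # one leading-axis slice per iteration
    rows.append(fn(tree_util.tree_unflatten(in_tree, flat_args)))
  return onp.stack(rows)  # => array([ 0., 11., 22.])
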
def _scan(  # pylint: disable=unused-argument
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    reverse=False,
    name=None):
  """Scan implementation."""
  if reverse:
    elems = nest.map_structure(lambda x: x[::-1], elems)

  if initializer is None:
    if nest.is_nested(elems):
      raise NotImplementedError
    initializer = elems[0]
    elems = elems[1:]
    prepend = [[initializer]]
  else:
    prepend = None

  def func(arg, x):
    return nest.flatten(
        fn(nest.pack_sequence_as(initializer, arg),
           nest.pack_sequence_as(elems, x)))

  arg = nest.flatten(initializer)
  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top

    def scan_body(arg, x):
      arg = func(arg, x)
      return arg, arg

    _, out = lax.scan(scan_body, arg, nest.flatten(elems))
  else:
    out = [[] for _ in range(len(arg))]
    for x in zip(*nest.flatten(elems)):
      arg = func(arg, x)
      for i, z in enumerate(arg):
        out[i].append(z)

  if prepend is not None:
    out = [pre + list(o) for (pre, o) in zip(prepend, out)]

  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.pack_sequence_as(
      initializer, [ordering(np.array(o)) for o in out])
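
# Illustrative sketch (hypothetical helper, not part of the original module):
# for the simplest case -- no initializer and a flat `elems` sequence -- `_scan`
# reduces to a cumulative fold that records every intermediate carry, mirroring
# `tf.scan`. The reference below uses plain NumPy only and is never invoked by
# library code.
def _example_scan_reference(fn, elems):
  import numpy as onp  # plain NumPy, independent of the module-level `np`
  carry = elems[0]
  out = [carry]
  for x in elems[1:]:
    carry = fn(carry, x)
    out.append(carry)
  return onp.array(out)

# _example_scan_reference(lambda a, x: a + x, [1., 2., 3., 4.])
# => array([ 1.,  3.,  6., 10.]), matching _scan(lambda a, x: a + x, elems).
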
def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
  """Detect if input should be interpreted as a list of blocks."""
  # Tuples and lists of length equal to the number of operators may be
  # blockwise.
  if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
    # If the elements of the iterable are not nested, interpret the input as
    # blockwise.
    if not any(nest.is_nested(x) for x in arg):
      return True
    else:
      arg_dims = [
          ops.convert_to_tensor(x).shape[arg_split_dim] for x in arg]
      self_dims = [dim.value for dim in block_dimensions]

      # If none of the operator dimensions are known, interpret the input as
      # blockwise if its matching dimensions are unequal.
      if all(self_d is None for self_d in self_dims):
        # A nested tuple/list with a single outermost element is not blockwise.
        if len(arg_dims) == 1:
          return False
        elif any(dim != arg_dims[0] for dim in arg_dims):
          return True
        else:
          raise ValueError(
              "Parsing of the input structure is ambiguous. Please input "
              "a blockwise iterable of `Tensor`s or a single `Tensor`.")

      # If input dimensions equal the respective (known) blockwise operator
      # dimensions, then the input is blockwise.
      if all(self_d == arg_d or self_d is None
             for self_d, arg_d in zip(self_dims, arg_dims)):
        return True

      # If the input dimensions are all equal and at least as large as the sum
      # of the known operator dimensions, interpret the input as a single
      # concatenated `Tensor`; it is not blockwise.
      self_dim = sum(self_d for self_d in self_dims if self_d is not None)
      if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
        return False

      # If none of these conditions is met, the input shape is mismatched.
      raise ValueError("Input dimension does not match operator dimension.")
  else:
    return False
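
# Illustrative sketch (assuming TensorFlow is available; hypothetical helper,
# never invoked by library code): the distinction drawn by `arg_is_blockwise`
# surfaces in the public blockwise-operator API. For a block operator with
# block domain dimensions [2, 3], a list of per-block arguments is treated as
# blockwise, while a single concatenated `Tensor` of matching total dimension
# is not; the exact internal call sites may differ from this sketch.
def _example_blockwise_vs_dense_args():
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  op = tf.linalg.LinearOperatorBlockDiag([
      tf.linalg.LinearOperatorFullMatrix(tf.eye(2)),
      tf.linalg.LinearOperatorFullMatrix(tf.eye(3)),
  ])

  x_blockwise = [tf.ones([2, 1]), tf.ones([3, 1])]  # per-block arguments
  x_dense = tf.ones([5, 1])                         # single concatenated arg

  # `matmul` consults a blockwise check like `arg_is_blockwise` to decide
  # between the two conventions: a list in yields per-block results, a single
  # Tensor in yields a single [5, 1] Tensor.
  return op.matmul(x_blockwise), op.matmul(x_dense)
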