def _scan(  # pylint: disable=unused-argument
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    reverse=False,
    name=None):
  """NumPy/JAX implementation of `tf.scan` operating on flattened carries.

  `fn(carry, elem)` is applied along the leading axis of `elems`, and every
  intermediate carry is collected and stacked with `np.array`.

  When `initializer` is `None`, the first element of `elems` becomes the
  initial carry and is also emitted as the first output. In that case `elems`
  must be a single (non-nested) array, otherwise `NotImplementedError` is
  raised.

  `parallel_iterations`, `back_prop`, `swap_memory`, `infer_shape` and `name`
  are accepted for signature compatibility and ignored.
  """
  if reverse:
    elems = nest.map_structure(lambda x: x[::-1], elems)

  prepend = None
  if initializer is None:
    if nest.is_nested(elems):
      raise NotImplementedError
    # Seed the carry with the first element; it is also the first output.
    initializer, elems = elems[0], elems[1:]
    prepend = [[initializer]]

  def step(carry_flat, elem_flat):
    # `fn` works on nested structures, but the scan loop carries flat lists.
    packed_carry = nest.pack_sequence_as(initializer, carry_flat)
    packed_elem = nest.pack_sequence_as(elems, elem_flat)
    return nest.flatten(fn(packed_carry, packed_elem))

  carry = nest.flatten(initializer)
  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top

    def scan_body(state, x):
      state = step(state, x)
      return state, state

    _, out = lax.scan(scan_body, carry, nest.flatten(elems))
  else:
    # One output bucket per flat leaf of the carry.
    out = [[] for _ in carry]
    for elem_flat in zip(*nest.flatten(elems)):
      carry = step(carry, elem_flat)
      for idx, value in enumerate(carry):
        out[idx].append(value)

  if prepend is not None:
    out = [head + list(column) for head, column in zip(prepend, out)]

  stacked = [np.array(column) for column in out]
  if reverse:
    # Undo the input reversal so outputs line up with the original order.
    stacked = [column[::-1] for column in stacked]
  return nest.pack_sequence_as(initializer, stacked)
def _infer_dtype(value, default_dtype):
  """Guesses the dtype of a value that is being converted to a Tensor.

  Args:
    value: The object to inspect. Objects with a `dtype` attribute, Python
      `bool`/`int`/`float`/`complex` scalars, and (possibly nested)
      lists/tuples of these are handled.
    default_dtype: Returned for a list/tuple whose dtype cannot be determined
      from its contents (e.g. an empty list).

  Returns:
    The inferred dtype.

  Raises:
    ValueError: If `value` (or, for a list/tuple, its first element) has an
      unsupported type.
  """
  # Check for a `dtype` attribute first: numpy scalar types subclass Python
  # builtins (e.g. `np.float64` is a `float`) and must keep their own dtype.
  if hasattr(value, 'dtype'):
    return value.dtype
  elif isinstance(value, bool):  # Before ints, since `bool` subclasses `int`.
    # `np.bool_` exists in every numpy release (and in jax.numpy). The
    # `np.bool` alias was removed in numpy 1.24 and only restored in 2.0.
    return np.bool_
  elif isinstance(value, six.integer_types):
    return np.int32
  elif isinstance(value, float):
    return np.float32
  elif isinstance(value, complex):
    return np.complex128
  elif isinstance(value, (tuple, list)):
    # Prefer the dtype of any leaf that carries one.
    for v in nest.flatten(value):
      if hasattr(v, 'dtype'):
        return v.dtype
    try:
      # Otherwise fall back to the Python type of the first element.
      return _infer_dtype(value[0], default_dtype)
    except (IndexError, TypeError):
      return default_dtype
  raise ValueError(('Attempt to convert a value ({})'
                    ' with an unsupported type ({}) to a Tensor.').format(
                        value, type(value)))
def pfor(fn, n):
  """Evaluates `fn` at every index in `range(n)` and stacks the results.

  In JAX mode this is `jax.vmap(fn)` applied to `np.arange(n)`. Otherwise
  `fn` is called once per index in a Python loop, and each flat leaf of the
  (possibly nested) results is stacked with `np.array`, then re-packed into
  the structure of the first result.
  """
  if JAX_MODE:
    import jax  # pylint: disable=g-import-not-at-top
    return jax.vmap(fn)(np.arange(n))
  per_index = [fn(i) for i in range(n)]
  flat = [nest.flatten(result) for result in per_index]
  # `zip(*flat)` groups the k-th leaf of every result together.
  stacked_leaves = [np.array(leaf_group) for leaf_group in zip(*flat)]
  return nest.pack_sequence_as(per_index[0], stacked_leaves)
def _foldl_jax(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
               back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """JAX implementation of `tf.foldl`.

  Folds `fn(carry, elem)` over the leading axis of every leaf of `elems` with
  `lax.scan`, discarding per-step outputs and returning the final carry.
  When `initializer` is `None`, the first element of each leaf seeds the
  carry and the fold runs over the remaining elements.

  Raises:
    ValueError: If the leaves of `elems` (after removing any seed element)
      do not share a single length.
  """
  if initializer is None:
    initializer = nest.map_structure(lambda leaf: leaf[0], elems)
    elems = nest.map_structure(lambda leaf: leaf[1:], elems)

  lengths = nest.map_structure(len, elems)
  if len(set(nest.flatten(lengths))) != 1:
    raise ValueError('Mismatched element sizes: {}'.format(lengths))

  from jax import lax  # pylint: disable=g-import-not-at-top

  def step(carry, elem):
    return fn(carry, elem), None

  final_carry, _ = lax.scan(step, initializer, elems)
  return final_carry
def _while_loop_jax(
    cond, body, loop_vars,  # pylint: disable=redefined-outer-name
    shape_invariants=None, parallel_iterations=10,  # pylint: disable=unused-argument
    back_prop=True, swap_memory=False,  # pylint: disable=unused-argument
    maximum_iterations=None, name=None):  # pylint: disable=unused-argument
  """JAX implementation of `tf.while_loop`.

  Three strategies are used:

  * `maximum_iterations is None`: a plain `lax.while_loop`.
  * `maximum_iterations` set and `back_prop=True`: a fixed-length `lax.scan`
    of `maximum_iterations` steps, where each step applies `body` only while
    `cond` holds. This is presumably chosen because `lax.scan` supports
    reverse-mode differentiation and `lax.while_loop` does not.
  * `maximum_iterations` set and `back_prop=False`: a `lax.while_loop` whose
    state is augmented with an iteration counter.

  `shape_invariants`, `parallel_iterations`, `swap_memory` and `name` are
  accepted for signature compatibility and ignored.

  Returns:
    The final loop variables, in the structure of `loop_vars`.
  """
  from jax import lax  # pylint: disable=g-import-not-at-top

  def pack_body(out):
    # Re-pack `body`'s output into `loop_vars`'s structure so the carry keeps
    # the same nested structure on every iteration.
    return nest.pack_sequence_as(loop_vars, nest.flatten(out))

  if maximum_iterations is None:
    def plain_body(args):
      return pack_body(body(*args))

    def plain_cond(args):
      return cond(*args)

    return lax.while_loop(plain_cond, plain_body, loop_vars)

  if back_prop:
    def scan_step(args, _):
      c = cond(*args)
      static_c = ops.get_static_value(c)
      if static_c is None:
        # The predicate is only known at runtime, so select between one step
        # of `body` and the identity inside the traced computation. This uses
        # the documented `cond(pred, true_fun, false_fun, operand)` form. The
        # legacy per-branch-operand signature is deprecated.
        args = lax.cond(c,
                        lambda operand: pack_body(body(*operand)),
                        lambda operand: operand,
                        args)
      elif static_c:
        args = pack_body(body(*args))
      # Otherwise the predicate is statically False: `args` is unchanged.
      return args, ()

    final_vars, _ = lax.scan(scan_step, loop_vars, xs=None,
                             length=maximum_iterations)
    return final_vars

  def counted_body(state):
    i, args = state
    return i + 1, pack_body(body(*args))

  def counted_cond(state):
    i, args = state
    return cond(*args) & (i < maximum_iterations)

  _, final_vars = lax.while_loop(counted_cond, counted_body,
                                 (np.array(0), loop_vars))
  return final_vars
def _foldl(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
           back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """NumPy implementation of `tf.foldl`.

  Walks the leading axis of every leaf of `elems` in lockstep and threads a
  carry through `fn(carry, elem)`, where `elem` is re-packed into the
  structure of `elems`. When `initializer` is `None`, the first element of
  each leaf seeds the carry and only the remaining elements are folded.

  Raises:
    ValueError: If the (possibly trimmed) leaves do not share a single length.
  """
  leaves = nest.flatten(elems)
  if initializer is None:
    initializer = nest.map_structure(lambda leaf: leaf[0], elems)
    leaves = [leaf[1:] for leaf in leaves]

  distinct_lengths = set(map(len, leaves))
  if len(distinct_lengths) != 1:
    raise ValueError(
        'Mismatched element sizes: {}'.format(nest.map_structure(len, elems)))

  accumulator = initializer
  for step_leaves in zip(*leaves):
    accumulator = fn(accumulator, nest.pack_sequence_as(elems, step_leaves))
  return accumulator
def jit_wrapper(*args, **kwargs):
  """Calls `f` through a cached `jit`, passing non-jittable leaves as static.

  The call's positional and keyword arguments are flattened together with
  `nest.flatten`. Flat positions whose value satisfies `non_jittable` become
  `static_argnums` for `jit`. One jitted function is cached per
  (static positions, number of positional args, keyword names) key.
  `f`, `non_jittable`, `cache` and `jit` are free variables resolved from the
  enclosing scope.
  """
  @functools.wraps(f)
  def unflatten_f(*args_flat):
    # Rebuild the original (args, kwargs) nesting from the flat leaves and
    # forward them to `f`.
    # NOTE(review): the structure template is the `(args, kwargs)` captured
    # from the call that created this cache entry. `cache_key` below does not
    # include the nesting structure, so a later call with the same key but a
    # differently nested argument would be re-packed with the first call's
    # structure (or fail to re-pack). The closure presumably also keeps the
    # first call's arguments alive while the cache entry exists. Confirm
    # whether this is intended.
    unflat_args, unflat_kwargs = nest.pack_sequence_as(
        (args, kwargs), args_flat)
    return f(*unflat_args, **unflat_kwargs)

  args_flat = nest.flatten((args, kwargs))
  # Indices (into the flat leaf list) that must be treated as static by jit.
  static_argnums = tuple(
      i for (i, arg) in enumerate(args_flat) if non_jittable(arg))
  cache_key = (static_argnums, len(args), tuple(kwargs.keys()))
  if cache.get(cache_key, None) is None:
    cache[cache_key] = jit(unflatten_f, static_argnums=static_argnums)
  return cache[cache_key](*args_flat)
def common_dtype(args_list, dtype_hint=None):
  """Returns the explicit dtype shared by `args_list`, else `dtype_hint`.

  Args:
    args_list: A (possibly nested) structure. Only leaves that have a `dtype`
      attribute are considered.
    dtype_hint: Value returned when no leaf has a `dtype` attribute.

  Returns:
    The common dtype of all leaves that have a `dtype` attribute, or
    `dtype_hint` when there are none.

  Raises:
    TypeError: If two leaves report unequal dtypes.
  """
  found = None
  for leaf in nest.flatten(args_list):
    if not hasattr(leaf, 'dtype'):
      continue
    leaf_dtype = leaf.dtype
    if found is None:
      found = leaf_dtype
    elif found != leaf_dtype:
      raise TypeError('Found incompatible dtypes, {} and {}'.format(
          found, leaf_dtype))
  return dtype_hint if found is None else found
def _scan(  # pylint: disable=unused-argument
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    reverse=False,
    name=None):
  """NumPy/JAX implementation of `tf.scan`.

  Args:
    fn: Callable `(carry, elem) -> new_carry`. It is expected to return a
      structure matching `initializer`.
    elems: (Possibly nested) structure of arrays, scanned along axis 0.
    initializer: Initial carry. If `None`, the first element of each leaf of
      `elems` is used as the initial carry and is also emitted as the first
      output.
    parallel_iterations: Ignored; kept for `tf.scan` signature compatibility.
    back_prop: Ignored.
    swap_memory: Ignored.
    infer_shape: Ignored.
    reverse: If `True`, scan from the last element to the first. The stacked
      outputs are reversed back before returning.
    name: Ignored.

  Returns:
    A structure like the carry whose leaves stack every intermediate carry
    along a new leading axis.
  """
  if reverse:
    elems = nest.map_structure(lambda x: x[::-1], elems)
  if initializer is None:
    initializer = nest.map_structure(
        lambda x: x[0], elems, expand_composites=True)
    elems = nest.map_structure(lambda x: x[1:], elems, expand_composites=True)
    prepend = initializer
  else:
    prepend = None

  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top

    def scan_body(arg, x):
      arg = fn(arg, x)
      return arg, arg

    _, out = lax.scan(scan_body, initializer, elems)
  else:
    length = len(nest.flatten(elems)[0])
    if length == 0:
      # Nothing to scan over (e.g. a single element with no initializer).
      # `nest.map_structure` needs at least one structure, so stacking an
      # empty list would fail. Build empty (0, ...) stacks shaped like the
      # carry instead, matching what `lax.scan` produces in JAX mode.
      out = nest.map_structure(
          lambda a: np.asarray(a)[np.newaxis][:0], initializer)
    else:
      arg = initializer
      out = []
      for i in range(length):
        arg = fn(arg, nest.map_structure(lambda x: x[i], elems))  # pylint: disable=cell-var-from-loop
        out.append(arg)
      out = nest.map_structure(lambda *x: np.stack(x, axis=0), *out)

  if prepend is not None:
    # The seed element taken from `elems` is the first output.
    out = nest.map_structure(
        lambda p, o: np.concatenate([p[np.newaxis], o], axis=0), prepend, out)
  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.map_structure(ordering, out, expand_composites=True)
def func(arg, x):
  """Calls `fn` on flat inputs and returns its result flattened.

  `arg` (the flat carry) is re-packed into the structure of `initializer`,
  and `x` (the flat per-step element) into the structure of `elems`.
  `fn`, `initializer` and `elems` are free variables resolved from the
  enclosing scope.
  """
  carry = nest.pack_sequence_as(initializer, arg)
  element = nest.pack_sequence_as(elems, x)
  return nest.flatten(fn(carry, element))