def pfor(fn, n):
  """Evaluates `fn` at every index in `range(n)` and stacks the results.

  Under JAX this is a single `jax.vmap(fn)` applied to `np.arange(n)`.
  Otherwise `fn` is called once per integer index in a Python loop. Each
  result is flattened with `nest`, corresponding leaves are stacked with
  `np.array`, and the stacked leaves are repacked into the structure of the
  first result.

  NOTE: on the non-JAX path `n` must be at least 1, since the output structure
  is taken from the first call's result.

  Args:
    fn: Callable of a single index argument.
    n: Number of indices to evaluate.

  Returns:
    A structure shaped like `fn`'s output whose leaves gain a leading
    dimension of size `n`.
  """
  if not JAX_MODE:
    results = [fn(idx) for idx in range(n)]
    leaves_per_result = [nest.flatten(result) for result in results]
    # zip(*...) regroups the leaves so that each tuple holds one leaf position
    # across all n results; np.array stacks that tuple along a new axis 0.
    stacked = [np.array(column) for column in zip(*leaves_per_result)]
    return nest.pack_sequence_as(results[0], stacked)
  import jax  # pylint: disable=g-import-not-at-top
  return jax.vmap(fn)(np.arange(n))
def _while_loop_jax(
    cond,
    body,
    loop_vars,  # pylint: disable=redefined-outer-name
    shape_invariants=None,
    parallel_iterations=10,  # pylint: disable=unused-argument
    back_prop=True,
    swap_memory=False,  # pylint: disable=unused-argument
    maximum_iterations=None,
    name=None):  # pylint: disable=unused-argument
  """Jax implementation of `tf.while_loop`.

  The arguments select one of three `jax.lax` strategies:

  * `maximum_iterations is None`: a plain `lax.while_loop` that runs until
    `cond` is false.
  * `maximum_iterations` given and `back_prop=True`: a `lax.scan` of exactly
    `maximum_iterations` steps. Each step applies `body` only while `cond`
    holds, so later steps pass the carry through unchanged. Scan is
    presumably chosen because JAX does not support reverse-mode
    differentiation through `lax.while_loop`.
  * `maximum_iterations` given and `back_prop=False`: a `lax.while_loop`
    whose carry is `(counter, loop_vars)`, so it stops at whichever comes
    first: `cond` becoming false or the counter reaching the limit.

  Args:
    cond: Callable that takes the unpacked loop variables (`cond(*vars)`) and
      returns a scalar boolean predicate.
    body: Callable that takes the unpacked loop variables (`body(*vars)`) and
      returns new loop variables. Its output is flattened and repacked into
      the structure of `loop_vars`.
    loop_vars: Initial loop variables. Each is passed to `cond`/`body` as a
      positional argument.
    shape_invariants: Not used by this implementation.
    parallel_iterations: Not used.
    back_prop: Only consulted when `maximum_iterations` is set. It selects the
      scan-based strategy.
    swap_memory: Not used.
    maximum_iterations: Optional iteration limit. With `back_prop=True` it is
      passed as `lax.scan`'s `length`, so it must then be a static Python int.
    name: Not used.

  Returns:
    The final loop variables, with the same structure as `loop_vars`.
  """
  from jax import lax  # pylint: disable=g-import-not-at-top

  # Normalizes whatever container `body` returns back into the exact structure
  # of `loop_vars`, so the carry type is identical across iterations as
  # lax.while_loop / lax.scan / lax.cond require.
  pack_body = lambda x: nest.pack_sequence_as(loop_vars, nest.flatten(x))

  if maximum_iterations is None:
    def override_body_fn(args):
      return pack_body(body(*args))
    def override_cond_fn(args):
      return cond(*args)
    return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
  elif back_prop:
    def override_body_fn(args, _):
      c = cond(*args)
      # `sc is None` means the predicate has no static value (it depends on
      # traced data), so both branches must be staged with lax.cond.
      sc = ops.get_static_value(c)
      if sc is None:
        # Current lax.cond signature: (pred, true_fun, false_fun, *operands).
        # The true branch takes one loop step; the false branch passes the
        # carry through unchanged.
        args = lax.cond(c, lambda a: pack_body(body(*a)), lambda a: a, args)
      elif sc:
        args = pack_body(body(*args))
      # If `sc` is statically False, the carry is returned unchanged and the
      # remaining scan steps are no-ops.
      return args, ()
    loop_vars, _ = lax.scan(override_body_fn, loop_vars, xs=None,
                            length=maximum_iterations)
    return loop_vars
  else:
    def override_body_fn(args):
      i, args = args
      return i + 1, pack_body(body(*args))
    def override_cond_fn(args):
      i, args = args
      return cond(*args) & (i < maximum_iterations)
    # Element [1] drops the iteration counter from the final carry.
    return lax.while_loop(override_cond_fn, override_body_fn,
                          (np.array(0), loop_vars))[1]
def _scan(  # pylint: disable=unused-argument
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    reverse=False,
    name=None):
  """Scan implementation (NumPy loop, or `lax.scan` under JAX).

  Repeatedly applies `fn(carry, x)` along the leading axis of `elems` and
  stacks every intermediate carry.

  Args:
    fn: Callable `(carry, x) -> new_carry`. The carry has the structure of
      `initializer` and `x` has the structure of `elems`.
    elems: Structure of arrays scanned over their leading axis.
    initializer: Initial carry. If None, `elems` must be non-nested. Its
      first slice is then used as the initial carry and is also included as
      the first row of the output.
    parallel_iterations: Not used.
    back_prop: Not used.
    swap_memory: Not used.
    infer_shape: Not used.
    reverse: If True, scan from the end of `elems`. The output is reversed
      back so that row `k` still corresponds to position `k` of `elems`.
    name: Not used.

  Returns:
    A structure like `initializer` whose leaves stack the carries along a new
    leading axis.

  Raises:
    NotImplementedError: If `initializer` is None and `elems` is nested.
  """
  if reverse:
    elems = nest.map_structure(lambda t: t[::-1], elems)

  leading = None
  if initializer is None:
    if nest.is_nested(elems):
      raise NotImplementedError
    initializer, elems = elems[0], elems[1:]
    # The implicit initial carry is emitted as the first output row.
    leading = [[initializer]]

  def step(flat_carry, flat_x):
    # Operates on flat leaf lists so it can serve as a lax.scan body.
    carry = nest.pack_sequence_as(initializer, flat_carry)
    x = nest.pack_sequence_as(elems, flat_x)
    return nest.flatten(fn(carry, x))

  flat_carry = nest.flatten(initializer)
  flat_elems = nest.flatten(elems)
  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top

    def scan_body(c, x):
      c = step(c, x)
      return c, c

    _, columns = lax.scan(scan_body, flat_carry, flat_elems)
  else:
    # One list per carry leaf; each gets one entry per scan step.
    columns = [[] for _ in flat_carry]
    for flat_x in zip(*flat_elems):
      flat_carry = step(flat_carry, flat_x)
      for idx, leaf in enumerate(flat_carry):
        columns[idx].append(leaf)

  if leading is not None:
    columns = [head + list(col) for head, col in zip(leading, columns)]

  def finalize(col):
    stacked = np.array(col)
    return stacked[::-1] if reverse else stacked

  return nest.pack_sequence_as(initializer, [finalize(col) for col in columns])
def _foldl(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
           back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """tf.foldl, in numpy.

  Left-folds `fn` over the leading axis of every leaf in `elems`.

  Args:
    fn: Callable `(accumulator, element) -> accumulator`. `element` has the
      structure of `elems`.
    elems: Structure of sequences folded over their leading axis.
    initializer: Initial accumulator. If None, the first slice of `elems` is
      used and folding starts from the second slice.
    parallel_iterations: Not used.
    back_prop: Not used.
    swap_memory: Not used.
    name: Not used.

  Returns:
    The final accumulator.

  Raises:
    ValueError: If the (remaining) leaves do not all have the same length.
      This also fires when `elems` has no leaves at all.
  """
  leaves = nest.flatten(elems)
  if initializer is None:
    initializer = nest.map_structure(lambda leaf: leaf[0], elems)
    leaves = [leaf[1:] for leaf in leaves]
  distinct_lengths = set(map(len, leaves))
  if len(distinct_lengths) != 1:
    raise ValueError(
        'Mismatched element sizes: {}'.format(nest.map_structure(len, elems)))
  accumulator = initializer
  for step_leaves in zip(*leaves):
    accumulator = fn(accumulator, nest.pack_sequence_as(elems, step_leaves))
  return accumulator
def unflatten_f(*args_flat):
  """Repacks flat arguments into `(args, kwargs)` form and calls `f`.

  `args`, `kwargs` and `f` are free names taken from an enclosing scope
  that is not visible here. `(args, kwargs)` only provides the structure
  that `args_flat` is repacked into.
  """
  structure = (args, kwargs)
  repacked_args, repacked_kwargs = nest.pack_sequence_as(structure, args_flat)
  return f(*repacked_args, **repacked_kwargs)
def func(arg, x):
  """Applies `fn` to flat carry/element leaves and returns flat leaves.

  `fn`, `initializer` and `elems` are free names from an enclosing scope
  that is not visible here. `arg` is repacked into the structure of
  `initializer`, `x` into the structure of `elems`, and `fn`'s result is
  flattened again.
  """
  carry = nest.pack_sequence_as(initializer, arg)
  element = nest.pack_sequence_as(elems, x)
  return nest.flatten(fn(carry, element))