Example #1
0
def pfor(fn, n):
  """Evaluates `fn` at each index in `range(n)` and stacks the outputs.

  Args:
    fn: Callable taking a single index and returning a (possibly nested)
      structure of values.
    n: Number of indices at which to evaluate `fn`.

  Returns:
    In JAX mode, the result of `jax.vmap(fn)` applied to `np.arange(n)`.
    Otherwise, a structure laid out like `fn(0)` in which every leaf is an
    `np.array` built from that leaf's values across all `n` outputs.
  """
  if JAX_MODE:
    import jax  # pylint: disable=g-import-not-at-top
    vectorized_fn = jax.vmap(fn)
    return vectorized_fn(np.arange(n))
  # Run every call first, then regroup the flattened outputs leaf-by-leaf.
  results = list(map(fn, range(n)))
  per_leaf_values = zip(*map(nest.flatten, results))
  stacked_leaves = [np.array(values) for values in per_leaf_values]
  return nest.pack_sequence_as(results[0], stacked_leaves)
Example #2
0
def _while_loop_jax(
        cond,
        body,
        loop_vars,  # pylint: disable=redefined-outer-name
        shape_invariants=None,
        parallel_iterations=10,  # pylint: disable=unused-argument
        back_prop=True,
        swap_memory=False,  # pylint: disable=unused-argument
        maximum_iterations=None,
        name=None):  # pylint: disable=unused-argument
    """Jax implementation of `tf.while_loop`.

    Three strategies are used, depending on the arguments:

    * `maximum_iterations is None`: a plain `lax.while_loop`.
    * `maximum_iterations` given and `back_prop` truthy: a `lax.scan` of
      exactly `maximum_iterations` steps, each of which applies `body` only
      while `cond` holds.
    * `maximum_iterations` given and `back_prop` falsy: a `lax.while_loop`
      carrying an extra iteration counter.

    Args:
      cond: Callable invoked as `cond(*loop_vars)`, returning the
        loop-continuation predicate.
      body: Callable invoked as `body(*loop_vars)`, returning the updated
        loop variables. Its result is re-packed into the structure of
        `loop_vars`.
      loop_vars: Initial loop variables; must be unpackable with `*`.
      shape_invariants: Not referenced by this implementation.
      parallel_iterations: Unused.
      back_prop: Selects between the `lax.scan` and counter-based
        `lax.while_loop` strategies when `maximum_iterations` is given.
      swap_memory: Unused.
      maximum_iterations: Optional bound on the number of iterations.
      name: Unused.

    Returns:
      The final loop variables, in the structure of `loop_vars`.
    """
    from jax import lax  # pylint: disable=g-import-not-at-top

    def pack_body(x):
        # Normalize whatever `body` returned to the structure of `loop_vars`.
        return nest.pack_sequence_as(loop_vars, nest.flatten(x))

    if maximum_iterations is None:

        def override_body_fn(args):
            return pack_body(body(*args))

        def override_cond_fn(args):
            return cond(*args)

        return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
    elif back_prop:
        # NOTE(review): a fixed-length scan is presumably used here because
        # `lax.while_loop` is documented as not reverse-mode differentiable.

        def override_body_fn(args, _):
            c = cond(*args)
            sc = ops.get_static_value(c)
            if sc is None:
                # Predicate is only known at run time. Use the
                # `lax.cond(pred, true_fun, false_fun, *operands)` signature;
                # the legacy five-argument form
                # `(pred, true_operand, true_fun, false_operand, false_fun)`
                # is no longer accepted by JAX.
                args = lax.cond(c,
                                lambda args: pack_body(body(*args)),
                                lambda args: args,
                                args)
            elif sc:
                # Statically true: apply the body unconditionally.
                args = pack_body(body(*args))
            # Statically false: leave `args` untouched.
            return args, ()

        loop_vars, _ = lax.scan(override_body_fn,
                                loop_vars,
                                xs=None,
                                length=maximum_iterations)
        return loop_vars
    else:

        def override_body_fn(args):
            i, args = args
            return i + 1, pack_body(body(*args))

        def override_cond_fn(args):
            i, args = args
            return cond(*args) & (i < maximum_iterations)

        # Element 0 of the carry is the iteration counter; drop it on return.
        return lax.while_loop(override_cond_fn, override_body_fn,
                              (np.array(0), loop_vars))[1]
Example #3
0
def _scan(  # pylint: disable=unused-argument
        fn,
        elems,
        initializer=None,
        parallel_iterations=10,
        back_prop=True,
        swap_memory=False,
        infer_shape=True,
        reverse=False,
        name=None):
    """Scans `fn` over `elems`, collecting every intermediate result.

    Args:
      fn: Callable invoked as `fn(accumulated, elem)`; its result is
        flattened and becomes the next accumulated value.
      elems: Structure of sequences to scan over.
      initializer: Optional initial accumulated value. When `None`, the
        first entry of `elems` is used as the initial value and is also
        emitted as the first output; nested `elems` are not supported in
        that case.
      parallel_iterations: Unused.
      back_prop: Unused.
      swap_memory: Unused.
      infer_shape: Unused.
      reverse: If truthy, `elems` are reversed before scanning and each
        output is reversed again before being returned.
      name: Unused.

    Returns:
      The collected outputs, packed into the structure of `initializer`.

    Raises:
      NotImplementedError: If `initializer` is `None` and `elems` is nested.
    """

    def flip(x):
        return x[::-1]

    if reverse:
        elems = nest.map_structure(flip, elems)

    prepend = None
    if initializer is None:
        if nest.is_nested(elems):
            raise NotImplementedError
        # Seed the scan with the first element and remember to emit it too.
        initializer = elems[0]
        elems = elems[1:]
        prepend = [[initializer]]

    def step(flat_state, flat_elem):
        state = nest.pack_sequence_as(initializer, flat_state)
        elem = nest.pack_sequence_as(elems, flat_elem)
        return nest.flatten(fn(state, elem))

    flat_state = nest.flatten(initializer)
    flat_elems = nest.flatten(elems)
    if JAX_MODE:
        from jax import lax  # pylint: disable=g-import-not-at-top

        def scan_step(carry, x):
            new_carry = step(carry, x)
            # The new carry is also this step's output.
            return new_carry, new_carry

        out = lax.scan(scan_step, flat_state, flat_elems)[1]
    else:
        out = [[] for _ in flat_state]
        for flat_elem in zip(*flat_elems):
            flat_state = step(flat_state, flat_elem)
            for slot, value in enumerate(flat_state):
                out[slot].append(value)

    if prepend is not None:
        out = [head + list(tail) for head, tail in zip(prepend, out)]

    stacked = []
    for collected in out:
        array = np.array(collected)
        stacked.append(flip(array) if reverse else array)
    return nest.pack_sequence_as(initializer, stacked)
Example #4
0
def _foldl(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
           back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """Left fold of `fn` over `elems`, mirroring `tf.foldl` in numpy.

  Args:
    fn: Callable invoked as `fn(accumulated, elem)`, returning the new
      accumulated value.
    elems: Structure of sequences to fold over; all flattened sequences must
      have the same length.
    initializer: Optional starting value. When `None`, the first entry of
      each sequence in `elems` is used and folding starts from the second.
    parallel_iterations: Unused.
    back_prop: Unused.
    swap_memory: Unused.
    name: Unused.

  Returns:
    The accumulated value after the last element has been folded in.

  Raises:
    ValueError: If the flattened sequences do not all share one length.
  """
  flat_sequences = nest.flatten(elems)
  if initializer is None:
    initializer = nest.map_structure(lambda el: el[0], elems)
    flat_sequences = [sequence[1:] for sequence in flat_sequences]
  distinct_lengths = set(map(len, flat_sequences))
  if len(distinct_lengths) != 1:
    raise ValueError(
        'Mismatched element sizes: {}'.format(nest.map_structure(len, elems)))
  accumulated = initializer
  for leaves in zip(*flat_sequences):
    current = nest.pack_sequence_as(elems, leaves)
    accumulated = fn(accumulated, current)
  return accumulated
Example #5
0
 def unflatten_f(*args_flat):
     """Repacks `args_flat` into the `(args, kwargs)` layout and calls `f`."""
     call_structure = (args, kwargs)
     positional, keyword = nest.pack_sequence_as(call_structure, args_flat)
     return f(*positional, **keyword)
Example #6
0
 def func(arg, x):
     """Calls `fn` on the re-packed state and element; returns a flat result."""
     state = nest.pack_sequence_as(initializer, arg)
     elem = nest.pack_sequence_as(elems, x)
     result = fn(state, elem)
     return nest.flatten(result)