示例#1
0
def _scan(  # pylint: disable=unused-argument
        fn,
        elems,
        initializer=None,
        parallel_iterations=10,
        back_prop=True,
        swap_memory=False,
        infer_shape=True,
        reverse=False,
        name=None):
    """Scan implementation, mirroring `tf.scan`.

    Applies `fn(accumulator, element)` along the leading axis of every leaf of
    `elems` and stacks each intermediate accumulator.

    Args:
      fn: Callable `(acc, x) -> acc`. `acc` has the structure of `initializer`
        and `x` has the structure of `elems`.
      elems: Array or nested structure of arrays, sliced along axis 0.
      initializer: Optional initial accumulator. When `None`, the first slice of
        every leaf of `elems` is used instead (the accumulator then has the
        structure of `elems`). That slice is also emitted as the first step of
        the output.
      parallel_iterations: Unused; kept for `tf.scan` signature compatibility.
      back_prop: Unused; kept for `tf.scan` signature compatibility.
      swap_memory: Unused; kept for `tf.scan` signature compatibility.
      infer_shape: Unused; kept for `tf.scan` signature compatibility.
      reverse: If `True`, scan from the last element towards the first. The
        stacked result is flipped back afterwards.
      name: Unused; kept for `tf.scan` signature compatibility.

    Returns:
      A structure matching `initializer` whose leaves are `np.array`s stacking
      the per-step accumulators along a new leading axis.
    """

    if reverse:
        elems = nest.map_structure(lambda x: x[::-1], elems)

    if initializer is None:
        # Seed the accumulator with the first slice of every leaf. This also
        # covers nested `elems`; for a single array it is just `elems[0]`.
        initializer = nest.map_structure(lambda x: x[0], elems)
        elems = nest.map_structure(lambda x: x[1:], elems)
        # One single-element prefix per flat output leaf, so the seed
        # accumulator appears as the first step of every output.
        prepend = [[leaf] for leaf in nest.flatten(initializer)]
    else:
        prepend = None

    def func(arg, x):
        # Work on flat lists internally; `fn` sees the user's structures.
        return nest.flatten(
            fn(nest.pack_sequence_as(initializer, arg),
               nest.pack_sequence_as(elems, x)))

    arg = nest.flatten(initializer)
    if JAX_MODE:
        from jax import lax  # pylint: disable=g-import-not-at-top

        def scan_body(arg, x):
            arg = func(arg, x)
            # Carry the new accumulator and also emit it as this step's output.
            return arg, arg

        _, out = lax.scan(scan_body, arg, nest.flatten(elems))
    else:
        # out[i] collects the per-step values of the i-th flat accumulator leaf.
        out = [[] for _ in range(len(arg))]
        for x in zip(*nest.flatten(elems)):
            arg = func(arg, x)
            for i, z in enumerate(arg):
                out[i].append(z)

    if prepend is not None:
        out = [pre + list(o) for (pre, o) in zip(prepend, out)]

    # Undo the initial reversal so results line up with the original order.
    ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
    return nest.pack_sequence_as(initializer,
                                 [ordering(np.array(o)) for o in out])
示例#2
0
def _infer_dtype(value, default_dtype):
  """Guesses an object's dtype."""
  # Need to check for onp type first because onp types are subclasses of Python
  # types.
  if hasattr(value, 'dtype'):
    # Duck-typing onp types
    return value.dtype
  elif isinstance(value, bool):
    return np.bool
  elif isinstance(value, six.integer_types):
    return np.int32
  elif isinstance(value, float):
    return np.float32
  elif isinstance(value, complex):
    return np.complex128
  elif isinstance(value, (tuple, list)):
    # Try inferring the type from items in the object if possible.
    for v in nest.flatten(value):
      if hasattr(v, 'dtype'):
        return v.dtype
    try:  # Finally fall back to raw types (int, bool).
      return _infer_dtype(value[0], default_dtype)
    except (IndexError, TypeError):
      return default_dtype
  raise ValueError(('Attempt to convert a value ({})'
                    ' with an unsupported type ({}) to a Tensor.').format(
                        value, type(value)))
示例#3
0
def pfor(fn, n):
  """Evaluates `fn(i)` for every `i` in `range(n)` and stacks the results.

  In JAX mode this is `jax.vmap(fn)` applied to `np.arange(n)`. Otherwise `fn`
  is called once per index, and each flattened output leaf is stacked with
  `np.array` along a new leading axis. The results are then re-packed into
  the structure of the first call's output.
  """
  if not JAX_MODE:
    per_index = [fn(i) for i in range(n)]
    columns = zip(*[nest.flatten(result) for result in per_index])
    stacked = [np.array(column) for column in columns]
    return nest.pack_sequence_as(per_index[0], stacked)
  import jax  # pylint: disable=g-import-not-at-top
  return jax.vmap(fn)(np.arange(n))
示例#4
0
def _foldl_jax(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
               back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """tf.foldl, in JAX.

  Left-folds `fn(carry, element)` over the leading axis of `elems` using
  `lax.scan`, and returns only the final carry. When `initializer` is `None`,
  the first slice of every leaf seeds the carry and the fold runs over the
  remaining slices.

  Raises:
    ValueError: If the leaves of `elems` (after seeding) differ in length.
  """
  if initializer is None:
    # Both right-hand sides are evaluated against the original `elems`.
    initializer, elems = (nest.map_structure(lambda el: el[0], elems),
                          nest.map_structure(lambda el: el[1:], elems))
  lengths = nest.map_structure(len, elems)
  if len(set(nest.flatten(lengths))) != 1:
    raise ValueError('Mismatched element sizes: {}'.format(lengths))
  from jax import lax  # pylint: disable=g-import-not-at-top

  def step(carry, el):
    # No per-step output is collected; only the carry matters.
    return fn(carry, el), None

  final_carry, _ = lax.scan(step, initializer, elems)
  return final_carry
示例#5
0
def _while_loop_jax(
        cond,
        body,
        loop_vars,  # pylint: disable=redefined-outer-name
        shape_invariants=None,
        parallel_iterations=10,  # pylint: disable=unused-argument
        back_prop=True,
        swap_memory=False,  # pylint: disable=unused-argument
        maximum_iterations=None,
        name=None):  # pylint: disable=unused-argument
    """Jax implementation of `tf.while_loop`.

    Repeatedly replaces `loop_vars` with `body(*loop_vars)` while
    `cond(*loop_vars)` holds. One of three lowerings is chosen:

    * `maximum_iterations is None`: a plain `lax.while_loop`.
    * `maximum_iterations` set and `back_prop` true: a `lax.scan` of exactly
      `maximum_iterations` steps. Each step applies `body` only if `cond`
      holds, so steps after `cond` turns false are no-ops. Presumably a scan is
      used because `lax.while_loop` does not support reverse-mode
      differentiation -- TODO confirm.
    * `maximum_iterations` set and `back_prop` false: a `lax.while_loop` that
      also carries an iteration counter and stops after `maximum_iterations`
      iterations.

    `shape_invariants`, `parallel_iterations`, `swap_memory` and `name` are not
    referenced; they are accepted for `tf.while_loop` signature compatibility.

    Args:
      cond: Callable taking the unpacked loop variables and returning a
        boolean (Python or traced).
      body: Callable taking the unpacked loop variables. Its flattened result
        is re-packed into the structure of `loop_vars`.
      loop_vars: Sequence of initial loop variables; it is unpacked with `*`
        when calling `cond` and `body`.
      shape_invariants: Unused.
      parallel_iterations: Unused.
      back_prop: Selects the scan-based lowering when `maximum_iterations` is
        set.
      swap_memory: Unused.
      maximum_iterations: Optional cap on the number of iterations.
      name: Unused.

    Returns:
      The final loop variables, with the structure of `loop_vars`.
    """
    from jax import lax  # pylint: disable=g-import-not-at-top

    # Re-pack `body`'s result into `loop_vars`' structure so the loop carry keeps
    # the same container types on every iteration, which lax control flow
    # requires.
    pack_body = lambda x: nest.pack_sequence_as(loop_vars, nest.flatten(x))

    if maximum_iterations is None:

        def override_body_fn(args):
            return pack_body(body(*args))

        def override_cond_fn(args):
            return cond(*args)

        return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
    elif back_prop:

        def override_body_fn(args, _):
            c = cond(*args)
            # `None` means `c` has no value known at trace time.
            sc = ops.get_static_value(c)
            if sc is None:
                # Runtime branch, using lax.cond's per-branch-operand form:
                # (pred, true_operand, true_fun, false_operand, false_fun).
                args = lax.cond(c, args, lambda args: pack_body(body(*args)),
                                args, lambda args: args)
            elif sc:
                # Statically true: apply `body` directly, without lax.cond.
                args = pack_body(body(*args))
            # Statically false: `args` passes through unchanged.
            # No per-step output is collected.
            return args, ()

        # Always runs exactly `maximum_iterations` steps.
        loop_vars, _ = lax.scan(override_body_fn,
                                loop_vars,
                                xs=None,
                                length=maximum_iterations)
        return loop_vars
    else:

        # The carry is `(i, args)`, where `i` counts completed iterations.
        def override_body_fn(args):
            i, args = args
            return i + 1, pack_body(body(*args))

        def override_cond_fn(args):
            i, args = args
            return cond(*args) & (i < maximum_iterations)

        # `[1]` drops the counter and returns only the loop variables.
        return lax.while_loop(override_cond_fn, override_body_fn,
                              (np.array(0), loop_vars))[1]
示例#6
0
def _foldl(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
           back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """tf.foldl, in numpy.

  Left-folds `fn(carry, element)` over the leading axis of `elems` with a
  Python loop and returns the final carry. When `initializer` is `None`, the
  first slice of every leaf seeds the carry and the fold runs over the
  remaining slices.

  Raises:
    ValueError: If the leaves being folded over differ in length.
  """
  leaves = nest.flatten(elems)
  carry = initializer
  if carry is None:
    carry = nest.map_structure(lambda el: el[0], elems)
    leaves = [leaf[1:] for leaf in leaves]
  if len(set(map(len, leaves))) != 1:
    raise ValueError(
        'Mismatched element sizes: {}'.format(nest.map_structure(len, elems)))
  for row in zip(*leaves):
    # Hand `fn` one slice with the same structure as `elems`.
    carry = fn(carry, nest.pack_sequence_as(elems, row))
  return carry
示例#7
0
        def jit_wrapper(*args, **kwargs):
          """Calls `f` through a `jit`-compiled function cached per call layout.

          `(args, kwargs)` is flattened with `nest`. Leaves for which
          `non_jittable` is true are passed to `jit` as static arguments.

          Each compiled function rebuilds `(args, kwargs)` by packing its flat
          arguments into the nested structure of the call that created it. It
          is therefore reused only by later calls with the same structure, as
          checked by `nest.assert_same_structure`.
          """
          structure = (args, kwargs)

          @functools.wraps(f)
          def unflatten_f(*args_flat):
            unflat_args, unflat_kwargs = nest.pack_sequence_as(
                structure, args_flat)
            return f(*unflat_args, **unflat_kwargs)

          args_flat = nest.flatten(structure)
          static_argnums = tuple(
              i for (i, arg) in enumerate(args_flat) if non_jittable(arg))
          cache_key = (static_argnums, len(args), tuple(kwargs.keys()))
          # `cache_key` does not encode the nesting. For example, f((a, b), c)
          # and f(a, (b, c)) share it. Reusing the first call's compiled
          # function for the other would mis-pack the arguments, so keep one
          # compiled function per distinct structure under each key.
          entries = cache.setdefault(cache_key, [])
          for template, compiled in entries:
            try:
              nest.assert_same_structure(template, structure)
            except (ValueError, TypeError):
              continue
            return compiled(*args_flat)
          compiled = jit(unflatten_f, static_argnums=static_argnums)
          entries.append((structure, compiled))
          return compiled(*args_flat)
示例#8
0
def common_dtype(args_list, dtype_hint=None):
    """Returns the explicit dtype shared by `args_list`, else `dtype_hint`.

    Args:
      args_list: Nested structure whose leaves are inspected. Leaves without
        a `dtype` attribute are ignored.
      dtype_hint: Value returned when no leaf has a `dtype` attribute.

    Returns:
      The `dtype` shared by every leaf that has one, or `dtype_hint` if no leaf
      does.

    Raises:
      TypeError: If two leaves report different dtypes.
    """
    dtype = None
    for a in nest.flatten(args_list):
        if not hasattr(a, 'dtype'):
            continue
        dt = a.dtype
        if dtype is None:
            dtype = dt
        elif dtype != dt:
            raise TypeError('Found incompatible dtypes, {} and {}'.format(
                dtype, dt))
    # With no dtype found this returns `dtype_hint`, which is already `None`
    # when no hint was given, so no separate "both None" case is needed.
    return dtype_hint if dtype is None else dtype
示例#9
0
def _scan(  # pylint: disable=unused-argument
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    reverse=False,
    name=None):
  """Scan implementation, mirroring `tf.scan`.

  Applies `fn(accumulator, element)` along the leading axis of every leaf of
  `elems` and stacks each intermediate accumulator.

  Args:
    fn: Callable `(acc, x) -> acc`.
    elems: Nested structure of arrays, sliced along axis 0.
    initializer: Optional initial accumulator. When `None`, the first slice of
      `elems` is used instead and is also emitted as the first output step.
    parallel_iterations: Unused; kept for `tf.scan` signature compatibility.
    back_prop: Unused; kept for `tf.scan` signature compatibility.
    swap_memory: Unused; kept for `tf.scan` signature compatibility.
    infer_shape: Unused; kept for `tf.scan` signature compatibility.
    reverse: If `True`, scan from the last element towards the first. The
      stacked result is flipped back afterwards.
    name: Unused; kept for `tf.scan` signature compatibility.

  Returns:
    A structure of arrays stacking the per-step accumulators along a new
    leading axis.
  """

  if reverse:
    elems = nest.map_structure(lambda x: x[::-1], elems)

  if initializer is None:
    initializer = nest.map_structure(
        lambda x: x[0], elems, expand_composites=True)
    elems = nest.map_structure(lambda x: x[1:], elems, expand_composites=True)
    prepend = initializer
  else:
    prepend = None

  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top
    def scan_body(arg, x):
      arg = fn(arg, x)
      # Carry the new accumulator and also emit it as this step's output.
      return arg, arg

    _, out = lax.scan(scan_body, initializer, elems)
  else:
    length = len(nest.flatten(elems)[0])
    arg = initializer
    out = []
    for i in range(length):
      arg = fn(arg, nest.map_structure(lambda x: x[i], elems))  # pylint: disable=cell-var-from-loop
      out.append(arg)
    if out:
      out = nest.map_structure(lambda *x: np.stack(x, axis=0), *out)
    else:
      # Zero-length scan (e.g. single-element `elems` with no initializer).
      # `fn` never ran and `map_structure` needs at least one structure, so
      # emit empty leaves with the accumulator's structure, shape and dtype.
      out = nest.map_structure(
          lambda x: np.zeros((0,) + np.shape(x), dtype=np.asarray(x).dtype),
          initializer)

  if prepend is not None:
    out = nest.map_structure(
        lambda p, o: np.concatenate([p[np.newaxis], o], axis=0), prepend, out)

  # Undo the initial reversal so results line up with the original order.
  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.map_structure(ordering, out, expand_composites=True)
示例#10
0
 def func(arg, x):
     """Adapts `fn` to flat inputs and outputs.

     Packs the flat accumulator `arg` into the structure of `initializer` and
     the flat slice `x` into the structure of `elems`, calls `fn`, and returns
     the result flattened again.
     """
     structured_arg = nest.pack_sequence_as(initializer, arg)
     structured_x = nest.pack_sequence_as(elems, x)
     return nest.flatten(fn(structured_arg, structured_x))