Ejemplo n.º 1
0
def _foldl_jax(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
               back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """tf.foldl, in JAX.

  Folds `fn` over the leading axis of every leaf of `elems`, first to last,
  letting `jax.lax.scan` drive the loop.

  Args:
    fn: Callable `(carry, element) -> new_carry`.
    elems: Nested structure whose leaves are sliced along their leading axis.
    initializer: Optional starting carry. When `None`, element 0 of every
      leaf of `elems` becomes the carry and the fold covers the rest.
    parallel_iterations, back_prop, swap_memory, name: Accepted for
      `tf.foldl` signature parity; ignored.

  Returns:
    The carry produced after the last element.

  Raises:
    ValueError: If the leaves of `elems` do not all share a single length.
  """
  if initializer is None:
    initializer = nest.map_structure(lambda leaf: leaf[0], elems)
    elems = nest.map_structure(lambda leaf: leaf[1:], elems)
  lengths = nest.map_structure(len, elems)
  if len(set(nest.flatten(lengths))) != 1:
    raise ValueError('Mismatched element sizes: {}'.format(lengths))
  from jax import lax  # pylint: disable=g-import-not-at-top

  def step(carry, element):
    # lax.scan wants (new_carry, per_step_output); no per-step output is kept.
    return fn(carry, element), None

  final_carry, _ = lax.scan(step, initializer, elems)
  return final_carry
Ejemplo n.º 2
0
def _foldl(fn, elems, initializer=None, parallel_iterations=10,  # pylint: disable=unused-argument
           back_prop=True, swap_memory=False, name=None):  # pylint: disable=unused-argument
  """tf.foldl, in numpy.

  Applies `fn` cumulatively to the elements of `elems` along their leading
  axis, first to last, with a plain Python loop.

  Args:
    fn: Callable `(carry, element) -> new_carry`; `element` is repacked into
      the same nested structure as `elems`.
    elems: Nested structure whose leaves are sliced along their leading axis.
    initializer: Optional starting carry. When `None`, element 0 of every
      leaf of `elems` becomes the carry and the fold covers the rest.
    parallel_iterations, back_prop, swap_memory, name: Accepted for
      `tf.foldl` signature parity; ignored.

  Returns:
    The carry produced after the last element.

  Raises:
    ValueError: If the leaves of `elems` do not all share a single length.
  """
  leaves = nest.flatten(elems)
  if initializer is None:
    initializer = nest.map_structure(lambda leaf: leaf[0], elems)
    leaves = [leaf[1:] for leaf in leaves]
  distinct_lengths = set(map(len, leaves))
  if len(distinct_lengths) != 1:
    raise ValueError(
        'Mismatched element sizes: {}'.format(nest.map_structure(len, elems)))
  accumulated = initializer
  # zip walks all leaves in lockstep; each tuple is one step's flat element.
  for step_leaves in zip(*leaves):
    accumulated = fn(accumulated, nest.pack_sequence_as(elems, step_leaves))
  return accumulated
Ejemplo n.º 3
0
def _scan(  # pylint: disable=unused-argument
        fn,
        elems,
        initializer=None,
        parallel_iterations=10,
        back_prop=True,
        swap_memory=False,
        infer_shape=True,
        reverse=False,
        name=None):
    """Scan implementation.

    Mirrors `tf.scan`: repeatedly applies `fn(carry, element)` along the
    leading axis of every leaf of `elems` and collects each new carry.

    Args:
      fn: Callable `(carry, element) -> new_carry`; `new_carry` must have the
        same nested structure as `initializer`.
      elems: Nested structure whose leaves are sliced along their leading
        axis.
      initializer: Optional initial carry. When `None`, element 0 of every
        leaf of `elems` seeds the carry and is emitted as the first output.
        This works for nested `elems` as well as a single array.
      parallel_iterations, back_prop, swap_memory, infer_shape, name:
        Accepted for `tf.scan` signature parity; ignored.
      reverse: If True, scan from the last element to the first; the stacked
        outputs are flipped back before returning.

    Returns:
      A structure matching `initializer` whose leaves stack every carry
      along a new leading axis.
    """

    if reverse:
        elems = nest.map_structure(lambda x: x[::-1], elems)

    if initializer is None:
        # Seed the carry with element 0 of every leaf (nested or not), and
        # remember one prefix list per flattened leaf so the seed is also
        # emitted as the first output.
        initializer = nest.map_structure(lambda x: x[0], elems)
        elems = nest.map_structure(lambda x: x[1:], elems)
        prepend = [[leaf] for leaf in nest.flatten(initializer)]
    else:
        prepend = None

    def func(arg, x):
        # Work on flat lists internally; repack only for the user's fn.
        return nest.flatten(
            fn(nest.pack_sequence_as(initializer, arg),
               nest.pack_sequence_as(elems, x)))

    arg = nest.flatten(initializer)
    if JAX_MODE:
        from jax import lax  # pylint: disable=g-import-not-at-top

        def scan_body(arg, x):
            arg = func(arg, x)
            # New carry is both the next state and this step's output.
            return arg, arg

        _, out = lax.scan(scan_body, arg, nest.flatten(elems))
    else:
        out = [[] for _ in range(len(arg))]
        for x in zip(*nest.flatten(elems)):
            arg = func(arg, x)
            for i, z in enumerate(arg):
                out[i].append(z)

    if prepend is not None:
        out = [pre + list(o) for (pre, o) in zip(prepend, out)]

    ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
    return nest.pack_sequence_as(initializer,
                                 [ordering(np.array(o)) for o in out])
Ejemplo n.º 4
0
def _scan(  # pylint: disable=unused-argument
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    reverse=False,
    name=None):
  """Scan implementation.

  Mirrors `tf.scan`: repeatedly applies `fn(carry, element)` along the
  leading axis of every leaf of `elems` and stacks each new carry.

  Args:
    fn: Callable `(carry, element) -> new_carry`; `new_carry` must have the
      same nested structure as `initializer`.
    elems: Nested structure whose leaves are sliced along their leading axis.
    initializer: Optional initial carry. When `None`, element 0 of every leaf
      of `elems` seeds the carry and is emitted as the first output.
    parallel_iterations, back_prop, swap_memory, infer_shape, name: Accepted
      for `tf.scan` signature parity; ignored.
    reverse: If True, scan from the last element to the first; the stacked
      outputs are flipped back before returning.

  Returns:
    A structure matching the carry whose leaves stack every carry along a new
    leading axis. If the scan runs zero steps, the leaves have leading
    dimension 0 (or 1 when the seed from `elems` is prepended).
  """

  if reverse:
    elems = nest.map_structure(lambda x: x[::-1], elems)

  if initializer is None:
    # Element 0 seeds the carry and is also emitted as the first output.
    initializer = nest.map_structure(
        lambda x: x[0], elems, expand_composites=True)
    elems = nest.map_structure(lambda x: x[1:], elems, expand_composites=True)
    prepend = initializer
  else:
    prepend = None

  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top
    def scan_body(arg, x):
      arg = fn(arg, x)
      # New carry is both the next state and this step's output.
      return arg, arg

    _, out = lax.scan(scan_body, initializer, elems)
  else:
    length = len(nest.flatten(elems)[0])
    arg = initializer
    out = []
    for i in range(length):
      arg = fn(arg, nest.map_structure(lambda x: x[i], elems))  # pylint: disable=cell-var-from-loop
      out.append(arg)
    if out:
      out = nest.map_structure(lambda *x: np.stack(x, axis=0), *out)
    else:
      # Zero steps (e.g. a single element and no initializer): there is
      # nothing to stack, and map_structure cannot be called with no
      # structures, so emit length-0 leaves shaped/typed like the carry.
      def _empty_stack(leaf):
        leaf = np.asarray(leaf)
        return np.zeros((0,) + leaf.shape, dtype=leaf.dtype)

      out = nest.map_structure(_empty_stack, initializer)

  if prepend is not None:
    out = nest.map_structure(
        lambda p, o: np.concatenate([p[np.newaxis], o], axis=0), prepend, out)

  ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
  return nest.map_structure(ordering, out, expand_composites=True)
Ejemplo n.º 5
0
 def evaluate(self, x):
   """Converts each leaf of the structure `x` via `onp.array`.

   `None` leaves are passed through unchanged. Composite structures are
   expanded (`expand_composites=True`) before mapping.
   NOTE(review): `onp` is presumably the plain NumPy module — confirm.
   """
   def _to_array(leaf):
     return leaf if leaf is None else onp.array(leaf)

   return nest.map_structure(_to_array, x, expand_composites=True)
Ejemplo n.º 6
0
 def _experimental_parameter_ndims_to_matrix_ndims(self):
     """Maps every entry of `self.operators` to 0 matrix ndims.

     Returns:
       A dict with key "operators" whose value mirrors the structure of
       `self.operators`, with every leaf replaced by 0.
     """
     # No operator contributes to the matrix shape, so each one maps to 0.
     zero_per_operator = nest.map_structure(lambda _: 0, self.operators)
     return {"operators": zero_per_operator}