Exemple #1
0
def omnistaging_disabler() -> None:
  """Re-register the ``psum`` and ``axis_index`` primitive rules.

  Judging by the name, this is the hook run when omnistaging is turned off
  (TODO confirm against the caller). It only mutates module-level primitives
  and rule tables and returns nothing:

  * ``psum_p`` gets the default ``core.Primitive.bind``, an impl that goes
    through ``pxla.apply_parallel_primitive`` and a parallel pure rule.
  * ``axis_index_p`` gets a custom bind, an abstract-eval rule and an XLA
    translation rule.
  """
  # NOTE(review): ``axis_index`` is declared global but never assigned in the
  # visible body -- confirm whether an assignment was lost.
  global axis_index

  # Replace any custom bind on psum_p with the plain Primitive.bind.
  psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
  # Pure rule: every operand is scaled by the product of ``shape``; the result
  # is a generator, not a list.
  pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)  # type: ignore

  def _axis_index_bind(*, axis_name):
    """Build a tracer for ``axis_index_p`` in the pmap trace of ``axis_name``."""
    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = dynamic_axis_env[axis_name]
    # Sizes of all axes up to and including the frame of ``axis_name``.
    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
    nreps = dynamic_axis_env.nreps
    trace = frame.pmap_trace

    # The result is a scalar int32 whose value is unknown at trace time.
    out_aval = ShapedArray((), np.int32)
    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
    # The equation has no inputs; everything it needs is carried in its params.
    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                            source_info_util.current())
    out_tracer.recipe = eqn

    return out_tracer

  def _axis_index_translation_rule(c, nreps, sizes, axis_name):
    """Emit ``(ReplicaId // (nreps // prod(sizes))) % sizes[-1]`` as int32."""
    div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
    mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
    # The arithmetic is done in uint32 and converted to int32 afterwards.
    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

  axis_index_p.def_custom_bind(_axis_index_bind)
  # Abstract evaluation ignores its arguments: always a scalar int32.
  axis_index_p.def_abstract_eval(
      lambda *args, **params: ShapedArray((), np.int32))
  xla.translations[axis_index_p] = _axis_index_translation_rule
Exemple #2
0
def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
                                 axis_index_groups, axis_env, platform):
  """Lower ``all_to_all`` to XLA.

  Three cases, in order:
    * replica groups of size one: the operand is returned unchanged;
    * any platform other than ``'tpu'``: warn and lower the all-gather based
      emulation (``_all_to_all_via_all_gather``);
    * ``'tpu'``: emit ``xops.AllToAll``, going through an expand_dims/squeeze
      pair when the split and concat axes differ.

  Raises:
    ValueError: on TPU, when the replica groups are not all the same size.
  """
  groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  group_size = len(groups[0])

  # Workaround for AllToAll not being implemented on CPU.
  if group_size == 1:
    return x

  if platform != 'tpu':
    warnings.warn("all_to_all (and pswapaxes) are only implemented properly for TPUs. All other "
                  "backends emulate it using a very slow and memory intensive algorithm, so expect "
                  "significant slowdowns.")
    emulated = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False, parallel=True)
    return emulated(c, x,
                    split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name,
                    axis_index_groups=axis_index_groups, axis_env=axis_env, platform=platform)

  if any(len(group) != group_size for group in groups):
    raise ValueError('Replica groups must be equally sized')
  group_protos = xc.make_replica_groups(groups)

  if concat_axis == split_axis:
    return xops.AllToAll(x, split_axis, concat_axis, group_size, group_protos)

  # A fresh axis is inserted at ``concat_axis`` below, so whichever of the two
  # axes comes later has to be shifted by one.
  if concat_axis < split_axis:
    split_axis += 1
  elif split_axis < concat_axis:
    concat_axis += 1
  add_axis = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)), multiple_results=False)
  x = add_axis(c, x)
  x = xops.AllToAll(x, split_axis, concat_axis, group_size, group_protos)
  drop_axis = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)
  return drop_axis(c, x)
Exemple #3
0
  def testSumPool(self):
    """vmap over a conv + sum-pool model must match the batched computation."""
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      unit = (1, 1)
      conv_out = lax.conv_general_dilated(
          x, params, unit, 'SAME', unit, unit, ('NHWC', 'HWIO', 'NHWC'))
      return lax.reduce_window(
          conv_out, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')

    def mean_square(params, x):
      return jnp.mean(f(params, x) ** 2)

    grad_loss = grad(mean_square)

    # Forward pass: map over single-example batches, then drop the extra axis.
    batched_out = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    batched_out = jnp.reshape(batched_out, (10, 5, 5, 5))
    self.assertAllClose(batched_out, f(W, X))

    # Gradients: vmapped per-example gradients vs. an explicit Python loop.
    batched_grads = vmap(partial(grad_loss, W))(
        jnp.reshape(X, (10, 1, 5, 5, 1)))

    def single_example_grad(i):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      return jnp.reshape(g, (1,) + g.shape)

    looped_grads = jnp.concatenate(
        [single_example_grad(i) for i in range(10)], axis=0)
    self.assertAllClose(batched_grads, looped_grads,
                        rtol=3e-2, atol=1e-3)
Exemple #4
0
 def value_and_jacfwd_f(*args, **kwargs):
     """Return ``(fun(*args, **kwargs), forward-mode Jacobian)``.

     ``fun``, ``argnums`` and ``holomorphic`` come from the enclosing scope.
     The Jacobian is obtained by vmapping a JVP over the standard basis of
     the differentiated arguments.
     """
     wrapped = lu.wrap_init(fun, kwargs)
     partial_fun, dyn_args = argnums_partial(
         wrapped, argnums, args, require_static_args_hashable=False)

     check_input = partial(_check_input_dtype_jacfwd, holomorphic)
     tree_map(check_input, dyn_args)

     push_forward = partial(_jvp, partial_fun, dyn_args)
     # The primal output is not batched (None); tangents stack on the last axis.
     batched_push_forward = vmap(push_forward, out_axes=(None, -1))
     y, jac = batched_push_forward(_std_basis(dyn_args))

     check_output = partial(_check_output_dtype_jacfwd, holomorphic)
     tree_map(check_output, y)

     if isinstance(argnums, int):
         example_args = dyn_args[0]
     else:
         example_args = dyn_args
     unravel = partial(_jacfwd_unravel, example_args)
     return y, tree_map(unravel, y, jac)
Exemple #5
0
  def testSort(self):
    """vmap of ``lax.sort`` over different in/out axes of a reversed range."""
    v = np.arange(12)[::-1].reshape(3, 4)

    # (sort dimension, positional vmap axis arguments, expected result)
    cases = [
        (0, ((0,),), v[:, ::-1]),
        (-1, ((0,),), v[:, ::-1]),
        (0, ((1,),), v[::-1, :].T),
        (0, ((1,), 1), v[::-1, :]),
    ]
    for dimension, axes, expected in cases:
      sv = vmap(partial(lax.sort, dimension=dimension), *axes)(v)
      self.assertAllClose(sv, expected)
Exemple #6
0
def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """JVP of ``triangular_solve`` with respect to the matrix operand ``a``.

  Args:
    g_a: tangent of ``a``.
    ans: primal result of the solve.
    a: the triangular matrix operand.
    b: the right-hand side; only its trailing two dimensions ``(m, n)`` are read.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: forwarded
      unchanged to the inner ``triangular_solve`` calls.

  Returns:
    The tangent of ``ans`` induced by ``g_a``.
  """
  m, n = b.shape[-2:]
  # With a unit diagonal, the diagonal of the tangent is dropped as well
  # (the band offset ``k`` moves one step off the diagonal).
  k = 1 if unit_diagonal else 0
  g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
  g_a = lax.neg(g_a)
  g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a
  g_a = jnp.conj(g_a) if conjugate_a else g_a
  # 2-D operands use a plain dot; anything else uses a batched matmul.
  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)

  def a_inverse(rhs):
    # Solve against ``a`` with the same options as the primal call.
    return triangular_solve(a, rhs, left_side, lower, transpose_a, conjugate_a,
                            unit_diagonal)

  # triangular_solve is about the same cost as matrix multiplication (~n^2 FLOPs
  # for matrix/vector inputs). Order these operations in whichever order is
  # cheaper.
  if left_side:
    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
    if m > n:
      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
    else:
      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
  else:
    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
    if m < n:
      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    else:
      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
Exemple #7
0
def ppermute(x, axis_name, perm):
    """Perform a collective permutation according to the permutation ``perm``.

    If ``x`` is a pytree then the result is equivalent to mapping this function
    to each leaf in the tree.

    This function is an analog of the CollectivePermute XLA HLO.

    Args:
      x: array(s) with a mapped axis named ``axis_name``.
      axis_name: hashable Python object used to name a pmapped axis (see the
        :func:`jax.pmap` documentation for more details).
      perm: iterable of pairs of ints, representing
        ``(source_index, destination_index)``
        pairs that encode how the mapped axis named ``axis_name`` should be
        shuffled. The integer values are treated as indices into the mapped
        axis ``axis_name``. Any two pairs should not have the same source index
        or the same destination index. For each index of the axis
        ``axis_name`` that does not correspond to a destination index in
        ``perm``, the corresponding values in the result are filled with zeros
        of the appropriate type. Each pair may be a tuple, a list or any other
        iterable; it is converted to a tuple before binding.

    Returns:
      Array(s) with the same shape as ``x`` with slices along the axis
      ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
    """
    # ``tuple(perm)`` alone would keep inner pairs as lists when the caller
    # passes ``[[0, 1], [1, 0]]``, leaving the primitive parameter unhashable.
    hashable_perm = tuple(tuple(pair) for pair in perm)
    return tree_util.tree_map(
        partial(ppermute_p.bind, axis_name=axis_name, perm=hashable_perm), x)
Exemple #8
0
 def testRandom(self):
   """vmapped ``random.normal`` must match per-seed calls and be collision-free."""
   seeds = vmap(random.PRNGKey)(np.arange(10))
   ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
   expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2))
                        for seed in np.arange(10)])
   self.assertAllClose(ans, expected, check_dtypes=False)
   # Use a unittest assertion rather than a bare ``assert``: the latter is
   # stripped under ``python -O`` and reports no values on failure.
   self.assertEqual(len(np.unique(ans)), 10 * 3 * 2)
Exemple #9
0
  def testPerExampleGradients(self):
    """Per-example gradients from vmap(grad) carry a leading batch axis."""
    def predict(params, inputs):
      activations = inputs
      for W, b in params:
        pre_activation = jnp.dot(W, activations) + b
        activations = jnp.tanh(pre_activation)
      return pre_activation

    def loss(params, data):
      inputs, targets = data
      residual = predict(params, inputs) - targets
      return jnp.sum(residual**2)

    batch_size = 5
    layer_sizes = [3, 2, 4]

    R = np.random.RandomState(0).randn
    fan_in_out = zip(layer_sizes[1:], layer_sizes[:-1])
    params = [(R(m, n), R(m)) for m, n in fan_in_out]

    input_batch = R(batch_size, layer_sizes[0])
    target_batch = R(batch_size, layer_sizes[-1])

    per_example_grad = vmap(partial(grad(loss), params))
    ans = per_example_grad((input_batch, target_batch))

    for (dW, db), (W, b) in zip(ans, params):
      self.assertEqual(dW.shape, (batch_size,) + W.shape)
      self.assertEqual(db.shape, (batch_size,) + b.shape)
Exemple #10
0
def eigh_jvp_rule(primals, tangents, lower):
  """JVP rule for ``eigh_p``.

  Args:
    primals: one-element sequence holding the input matrix ``a``.
    tangents: one-element sequence holding the tangent ``a_dot``.
    lower: forwarded to ``eigh_p.bind``.

  Returns:
    ``((v, w_real), (dv, dw))``: the primal outputs of ``eigh_p`` followed by
    their tangents.
  """
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w_real.astype(a.dtype)
  eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  # Entry (i, j) is 1 / (w[j] - w[i]) off the diagonal; adding the identity
  # before the reciprocal and subtracting it after makes the diagonal 0.
  Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
  # Eigenvalue tangents are the real part of the diagonal of vdag_adot_v.
  dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return (v, w_real), (dv, dw)
Exemple #11
0
def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
  """Emulate an all-gather with ``_expand`` on every leaf followed by ``psum``.

  When ``axis_index_groups`` is given, the axis index is first remapped to the
  position of this participant inside its group.
  """
  position = axis_index(axis_name)
  if axis_index_groups is not None:
    flat_members = np.array(axis_index_groups).flatten()
    group_len = len(axis_index_groups[0])
    # argsort inverts the flattened membership list; the modulo reduces the
    # flat position to an offset within a group.
    position_in_group = flat_members.argsort() % group_len
    position = lax_numpy.array(position_in_group)[position]
  expand_leaf = partial(_expand, all_gather_dimension, axis_size, position)
  expanded = tree_util.tree_map(expand_leaf, x)
  return psum(expanded, axis_name, axis_index_groups=axis_index_groups)
Exemple #12
0
 def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes):
   """vmap of ``lax.gather`` over ``idxs`` must match gathering slice by slice."""
   gather = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
   operand = jtu.rand_default(self.rng())(shape, dtype)
   batched = vmap(gather, (None, axis))(operand, idxs)

   # Full slices on every axis before ``axis``, then one index along ``axis``.
   leading = (slice(None),) * axis
   per_slice = [gather(operand, idxs[leading + (i,)])
                for i in range(idxs.shape[axis])]
   self.assertAllClose(batched, np.stack(per_slice), check_dtypes=False)
Exemple #13
0
 def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes):
   """vmap of a gather gradient over the operand must match a per-slice loop."""
   gather = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)

   def summed_sin_of_gather(x, idx):
     return jnp.sum(jnp.sin(gather(x, idx)))

   gfun = grad(summed_sin_of_gather)
   operand = jtu.rand_default(self.rng())(shape, dtype)
   batched = vmap(gfun, (axis, None))(operand, idxs)

   # Full slices on every axis before ``axis``, then one index along ``axis``.
   leading = (slice(None),) * axis
   per_slice = [gfun(operand[leading + (i,)], idxs)
                for i in range(operand.shape[axis])]
   self.assertAllClose(batched, np.stack(per_slice), check_dtypes=False)
Exemple #14
0
 def test_check_jaxpr_scan_correct(self):
   """A jaxpr traced from ``lax.scan`` must pass ``core.check_jaxpr``."""
   def step(carry, x):
     out = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(carry)))
     return jnp.sin(carry * out), out

   init_carry = jnp.ones(4)
   inputs = jnp.ones((5, 3))
   closed = make_jaxpr(partial(lax.scan, step))(init_carry, inputs)
   core.check_jaxpr(closed.jaxpr)
Exemple #15
0
 def test_vjp(self, f, args):
     """Check ``vjp`` of ``f`` at ``args`` with per-dtype tolerances."""
     relative_tol = {np.float32: 3e-1, np.float64: 1e-5}
     absolute_tol = {np.float32: 1e-2, np.float64: 1e-5}
     jtu.check_vjp(f, partial(vjp, f), args,
                   rtol=relative_tol, atol=absolute_tol)
Exemple #16
0
  def testSortKeyVal(self):
    """vmap of ``lax.sort_key_val`` over several in/out axis combinations."""
    k = np.arange(12)[::-1].reshape(3, 4)
    v = np.random.RandomState(0).permutation(12).reshape(3, 4)

    # (positional vmap axis arguments, operands, expected keys, expected values)
    cases = [
        (((0, 0),), (k, v), k[:, ::-1], v[:, ::-1]),
        (((1, 1), 1), (k, v), k[::-1, :], v[::-1, :]),
        (((0, 1),), (k, v.T), k[:, ::-1], v[:, ::-1]),
        (((1, 0),), (k.T, v), k[:, ::-1], v[:, ::-1]),
        (((None, 0),), (k[0], v),
         np.broadcast_to(k[0, ::-1], (3, 4)), v[:, ::-1]),
        (((1, None),), (k.T, v[0]),
         k[:, ::-1], np.broadcast_to(v[0, ::-1], (3, 4))),
    ]
    for axes, operands, expected_keys, expected_vals in cases:
      sk, sv = vmap(partial(lax.sort_key_val, dimension=0), *axes)(*operands)
      self.assertAllClose(sk, expected_keys)
      self.assertAllClose(sv, expected_vals)
Exemple #17
0
 def value_and_jacrev_f(*args, **kwargs):
     """Return the value of ``fun`` together with its reverse-mode Jacobian.

     ``fun``, ``argnums``, ``holomorphic``, ``allow_int`` and ``has_aux`` come
     from the enclosing scope.

     Returns:
       ``(y, jacobian)`` when ``has_aux`` is false, otherwise
       ``((y, aux), jacobian)``.
     """
     f = lu.wrap_init(fun, kwargs)
     f_partial, dyn_args = argnums_partial(
         f, argnums, args, require_static_args_hashable=False)
     tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
              dyn_args)
     if has_aux:
         y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
     else:
         y, pullback = _vjp(f_partial, *dyn_args)
     tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

     # Pull back every standard basis vector of the output at once.
     jac = vmap(pullback)(_std_basis(y))
     single_argnum = isinstance(argnums, int)
     jac = jac[0] if single_argnum else jac
     example_args = dyn_args[0] if single_argnum else dyn_args
     jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
     # The transposition is the same with or without aux, so do it once.
     jac_tree = tree_transpose(tree_structure(example_args),
                               tree_structure(y), jac_tree)
     if has_aux:
         return (y, aux), jac_tree
     return y, jac_tree
Exemple #18
0
def _contains_query(vals, query):
    if isinstance(query, tuple):
        return map(partial(_contains_query, vals), query)

    if jnp.isnan(query):
        if jnp.any(jnp.isnan(vals)):
            raise FoundValue('NaN')
    elif jnp.isinf(query):
        if jnp.any(jnp.isinf(vals)):
            raise FoundValue('Found Inf')
    elif jnp.isscalar(query):
        if jnp.any(vals == query):
            raise FoundValue(str(query))
    else:
        raise ValueError('Malformed Query: {}'.format(query))
Exemple #19
0
  def test_reference_cycles(self):
    """Calling ``linearize`` must leave nothing for the cycle collector."""
    # Collect first so garbage that already exists is not counted below.
    gc.collect()

    def f(x):
      return x.sum()

    fn = partial(linearize, f)
    params = jnp.zeros([])

    # Remember the current debug flags so ``finally`` can restore them.
    debug = gc.get_debug()
    try:
      fn(params)
      # DEBUG_SAVEALL keeps collected objects in ``gc.garbage`` so they can be
      # shown in the failure message.
      gc.set_debug(gc.DEBUG_SAVEALL)
      # ``gc.collect()`` returns the number of unreachable objects it found.
      self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
    finally:
      gc.set_debug(debug)
Exemple #20
0
 def tree_update(i, grad_tree, opt_state):
   """Apply ``update`` (from the enclosing scope) leaf by leaf.

   Raises:
     TypeError: if the gradient tree does not match the stored parameter tree,
       or if ``update`` returns a state whose structure differs from its input.
   """
   states_flat, tree, subtrees = opt_state
   grad_flat, grad_treedef = tree_flatten(grad_tree)
   if grad_treedef != tree:
     mismatch = ("optimizer update function was passed a gradient tree that did "
                 "not match the parameter tree structure with which it was "
                 "initialized: parameter tree {} and grad tree {}.")
     raise TypeError(mismatch.format(tree, grad_treedef))

   # Rebuild each per-leaf state, update it, and flatten it again.
   leaf_states = map(tree_unflatten, subtrees, states_flat)
   updated_states = map(partial(update, i), grad_flat, leaf_states)
   new_states_flat, new_subtrees = unzip2(map(tree_flatten, updated_states))

   for expected, produced in zip(subtrees, new_subtrees):
     if produced != expected:
       mismatch = ("optimizer update function produced an output structure that "
                   "did not match its input structure: input {} and output {}.")
       raise TypeError(mismatch.format(expected, produced))
   return OptimizerState(new_states_flat, tree, subtrees)
Exemple #21
0
  def testConvGeneralDilatedBatchNotMajor(self):
    """vmap of a conv whose batch dimension is not leading ('HNWC' layout)."""
    W = jnp.array(np.random.randn(3, 3, 1, 4), dtype=np.float32)
    x = jnp.array(np.random.randn(3, 5, 7, 5, 1), dtype=np.float32)

    def f(params, x):
      unit = (1, 1)
      return lax.conv_general_dilated(
          x, params, unit, 'SAME', unit, unit, ('HNWC', 'HWIO', 'HWNC'))

    # Move the mapped axis next to the batch axis and merge the two.
    mapped = vmap(partial(f, W))(x)
    mapped = jnp.transpose(mapped, (1, 2, 0, 3, 4))
    mapped = jnp.reshape(mapped, (5, 5, 21, 4))

    merged_input = jnp.transpose(x, (1, 0, 2, 3, 4))
    merged_input = jnp.reshape(merged_input, (5, 21, 5, 1))
    self.assertAllClose(mapped, f(W, merged_input))
Exemple #22
0
  def testNpMaximumPerExampleGrad(self):
    """Per-example gradient of a squared ReLU layer against a closed form."""
    randn = np.random.RandomState(0).randn
    x = randn(10, 5)
    W = randn(5, 5)

    def fun(W, x):
      return jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = jnp.transpose(W)
    # Only TPU gets a loosened float32 tolerance.
    atol = {np.float32:5e-2} if jtu.device_under_test() == "tpu" else None
    for i in range(10):
      x_ex = x[i:i + 1]
      relu_out = jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0)
      expected_ans = jnp.transpose(2.0 * jnp.dot(relu_out, x_ex))

      self.assertAllClose(
          ans[i], expected_ans, check_dtypes=False, atol=atol)
Exemple #23
0
def _solve(a, b):
  """Solve ``a x = b`` through an LU factorization of ``a``.

  The factorization is computed once on ``lax.stop_gradient(a)`` and reused by
  both the forward and the transposed solve passed to
  ``lax.custom_linear_solve``.

  Args:
    a: coefficient matrix (shape validated by ``_check_solve_shapes``).
    b: right-hand side, either ``[..., m]`` or ``[..., m, k]``.

  Returns:
    The solution, computed directly when ``b`` is a vector and via ``vmap``
    over the last axis of ``b`` otherwise.
  """
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  # Size-1 dimensions of b take the matching dimension of a; the last zipped
  # entry is paired with 1, so b's own size is kept there.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = jnp.broadcast_to(b, out_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  # Both solve callbacks ignore their first argument and use the closed-over
  # factorization instead.
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    # Map over the last axis of b and put the mapped axis last in the output.
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
Exemple #24
0
        thunk_name = 'fwd_jaxpr_thunk'
        params['bwd'] = callback_subtrace(params['bwd'], main)
    else:
        raise NotImplementedError(primitive)

    thunk = params.pop(thunk_name)

    @pe._memoize
    def new_thunk():
        thunk_jaxpr = core.ClosedJaxpr(*thunk())
        closed_jaxpr = callback_jaxpr(thunk_jaxpr, main.callback,
                                      main.strip_calls)
        return closed_jaxpr.jaxpr, closed_jaxpr.literals

    params[thunk_name] = new_thunk
    new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
    closed_fun_jaxpr = core.ClosedJaxpr(
        pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
    new_num_consts = len(new_consts) + num_consts
    out = primitive.bind(*it.chain(new_consts, vals),
                         fun_jaxpr=closed_fun_jaxpr,
                         num_consts=new_num_consts,
                         **params)
    return safe_map(trace.pure, out)


# Both custom-derivative call primitives share one callback rule; the primitive
# itself is bound as the rule's first argument.
custom_callback_rules[cd.custom_jvp_call_jaxpr_p] = partial(
    _custom_derivative_call_jaxpr_callback_rule, cd.custom_jvp_call_jaxpr_p)
custom_callback_rules[cd.custom_vjp_call_jaxpr_p] = partial(
    _custom_derivative_call_jaxpr_callback_rule, cd.custom_vjp_call_jaxpr_p)
Exemple #25
0
  if jnp.issubdtype(dtype, np.complexfloating):
    nan = xb.constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
  else:
    nan = xb.constant(c, np.array(np.nan, dtype=dtype))
  return xops.Broadcast(nan, shape.dimensions())

def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
  """Lower Cholesky through ``potrf_impl``, replacing failed results by NaNs.

  ``potrf_impl`` is called as ``potrf_impl(c, operand, lower=True)`` and must
  return ``(result, info)``. Wherever ``info`` is not zero the corresponding
  result is replaced with ``_nan_like(c, result)``.
  """
  batch_dims = c.get_shape(operand).dimensions()[:-2]
  result, info = potrf_impl(c, operand, lower=True)
  zero = xops.ConstantLiteral(c, np.array(0, np.int32))
  # Give the success flag two trailing unit dims so it broadcasts per matrix.
  succeeded = xops.Reshape(xops.Eq(info, zero), batch_dims + (1, 1))
  return _broadcasting_select(c, succeeded, result, _nan_like(c, result))

# CPU always lowers Cholesky through the LAPACK potrf kernel.
xla.backend_specific_translations['cpu'][cholesky_p] = partial(
  _cholesky_cpu_gpu_translation_rule, lapack.potrf)

# GPU lowering is registered only when a solver module is available.
if cusolver is not None:
  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, cusolver.potrf)

# NOTE(review): if both cusolver and rocsolver are present, this assignment
# overwrites the cusolver registration above -- confirm that is intended.
if rocsolver is not None:
  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, rocsolver.potrf)

# Asymmetric eigendecomposition

def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
  return (
    xla.apply_primitive(eig_p, operand,
                        compute_left_eigenvectors=compute_left_eigenvectors,
Exemple #26
0
 def test_jvp_linearized(self, f, args):
   """Check the linearize-based JVP of ``f`` at ``args``."""
   jvp_under_test = partial(jvp_unlinearized, f)
   float32_rtol = {np.float32: 3e-2}
   jtu.check_jvp(f, jvp_under_test, args, rtol=float32_rtol)
Exemple #27
0
 def test_jvp(self, f, args):
   """Check ``jvp`` of ``f`` at ``args`` with a loosened float32 tolerance."""
   jvp_under_test = partial(jvp, f)
   float32_rtol = {np.float32: 3e-2}
   jtu.check_jvp(f, jvp_under_test, args, rtol=float32_rtol)
Exemple #28
0
    return xops.Tuple(c, outs)


def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
    """Transpose rule for ``psum_p``: bind ``psum_p`` on the cotangents.

    The cotangents are flattened, passed through ``psum_p.bind`` with the same
    ``axis_name`` / ``axis_index_groups``, and restored to their original tree
    structure. ``args`` is accepted but not used.
    """
    flat_cts, treedef = tree_util.tree_flatten(cts)
    summed_cts = psum_p.bind(
        *flat_cts, axis_name=axis_name, axis_index_groups=axis_index_groups)
    return tree_util.tree_unflatten(treedef, summed_cts)


# psum takes any number of operands and returns one result per operand.
psum_p = core.Primitive('psum')
psum_p.multiple_results = True
# Abstract evaluation maps ``raise_to_shaped`` over the operands.
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
# The XLA lowering is the shared all-reduce rule specialised with ``lax.add_p``.
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule,
                                            lax.add_p)  # type: ignore
ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
# Collective batching rule: the first lambda sums ``v`` over dimension ``d``,
# the second scales ``v`` by ``axis_size``.
batching.collective_rules[psum_p] = \
  partial(_batched_reduction_collective,
          psum_p,
          lambda v, d: v.sum(d),
          lambda v, axis_size: axis_size * v)


# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
Exemple #29
0
    CallSpec(fun_with_two_calls, (R(3, 2),)),
    CallSpec(fun_with_call_closure, (R(3, 2),)),
    CallSpec(fun_call_jitted, (R(1,),)),
    CallSpec(fun_with_nested_calls, (R(),)),
    CallSpec(fun_with_nested_calls, (R(3, 2),)),
    CallSpec(fun_with_nested_calls_2, (R(1, 2),)),
]

def jvp_unlinearized(f, primals, tangents):
  """Compute a JVP by linearizing ``f`` at ``primals`` and applying the result.

  Returns:
    ``(primal_out, tangent_out)``, the same layout ``jvp`` uses.
  """
  primal_out, linear_jvp = linearize(f, *primals)
  tangent_out = linear_jvp(*tangents)
  return primal_out, tangent_out

# Every base spec is tested as-is and under jvp, jit, nested jit, and the
# linearize-based jvp.
test_specs = []
for ts in test_specs_base:
  test_specs.extend([
      ts,
      CallSpec(partial(jvp, ts.fun), (ts.args, ts.args)),
      CallSpec(jit(ts.fun), ts.args),
      CallSpec(jit(jit(ts.fun)), ts.args),
      CallSpec(partial(jvp_unlinearized, ts.fun), (ts.args, ts.args)),
  ])


def fwd_deriv(f):
  """Return the derivative of a single-argument ``f`` via one forward-mode JVP."""
  def df(x):
    # Push a unit tangent through f and keep only the tangent output.
    _, tangent_out = jvp(f, (x,), (1.0,))
    return tangent_out

  return df


class CoreTest(jtu.JaxTestCase):
Exemple #30
0
def _allgather(x, dim, size, index, axis_name, axis_index_groups=None):
    """Apply ``_expand(dim, size, index, leaf)`` to every leaf, then ``psum``.

    NOTE(review): presumably ``_expand`` places each leaf at ``index`` along a
    new axis ``dim`` of length ``size`` so that the sum acts as a gather --
    confirm against ``_expand``.
    """
    expand_leaf = partial(_expand, dim, size, index)
    expanded = tree_util.tree_map(expand_leaf, x)
    return psum(expanded, axis_name, axis_index_groups=axis_index_groups)