Example #1
0
    def test_unimplemented_interpreter_rules(self):
        # A fresh primitive starts with no rules registered; each interpreter
        # that reaches a missing rule must raise a NotImplementedError naming
        # the missing rule. Rules are added step by step below, so the order
        # of these checks matters.
        foo_p = Primitive('foo')

        def foo(x):
            return foo_p.bind(x)

        def expect_missing(thunk, message):
            jtu.check_raises(thunk, NotImplementedError, message)

        # Nothing registered yet: eager eval, jit and grad all fail.
        expect_missing(lambda: foo(1.0),
                       "Evaluation rule for 'foo' not implemented")
        expect_missing(lambda: jit(foo)(1.0),
                       "Abstract evaluation for 'foo' not implemented")
        expect_missing(
            lambda: grad(foo)(1.0),
            "Forward-mode differentiation rule for 'foo' not implemented")

        # With an abstract-eval rule, jit instead stops at the missing XLA
        # translation rule.
        foo_p.def_abstract_eval(lambda x: x)
        expect_missing(lambda: jit(foo)(1.0),
                       "XLA translation rule for 'foo' not implemented")

        # With an impl and a JVP (but no transpose), grad fails on the
        # missing reverse-mode rule.
        foo_p.def_impl(lambda x: x)
        defjvp(foo_p, lambda g, x: foo(g))
        expect_missing(
            lambda: grad(foo)(1.0),
            "Reverse-mode differentiation rule for 'foo' not implemented")
Example #2
0
    return coo_todense(data_dot, row, col, shape=shape)


def _coo_todense_transpose(ct, data, row, col, *, shape):
    """Transpose rule for ``coo_todense`` with respect to ``data``.

    Only ``data`` may be the undefined (linear) primal; the ``row``/``col``
    index arrays must be known. The dense cotangent ``ct`` is mapped back to
    per-entry cotangents with ``_coo_extract`` (presumably a gather at the
    ``(row, col)`` positions).

    Note: we assume that transpose has the same sparsity pattern.
    Can we check this?
    """
    assert ad.is_undefined_primal(data)
    indices_are_linear = (ad.is_undefined_primal(row)
                          or ad.is_undefined_primal(col))
    if indices_are_linear:
        raise ValueError("Cannot transpose with respect to sparse indices")
    assert ct.shape == shape
    assert row.aval.dtype == col.aval.dtype
    assert ct.dtype == data.aval.dtype
    data_ct = _coo_extract(row, col, ct)
    return data_ct, row, col


# Differentiation rules. Only ``data`` gets a JVP rule; the two ``None``
# entries register no tangent rule for the ``row``/``col`` index arguments.
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
# Generic lowering, built from ``_coo_todense_impl`` via ``xla.lower_fun``
# (single output, hence ``multiple_results=False``).
xla.translations[coo_todense_p] = xla.lower_fun(_coo_todense_impl,
                                                multiple_results=False)
# When ``cusparse`` is available (truthy) and reports support, register a
# GPU-specific lowering.
if cusparse and cusparse.is_supported:
    xla.backend_specific_translations['gpu'][
        coo_todense_p] = _coo_todense_gpu_translation_rule

#--------------------------------------------------------------------
# coo_fromdense

# Primitive backing ``coo_fromdense``. ``multiple_results`` is set because
# binding it yields several outputs (presumably data, row and col — confirm
# against the impl/abstract-eval rules).
coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True


def coo_fromdense(mat, *, nnz, index_dtype=jnp.int32):
Example #3
0
    return lax.lgamma(x) + lax.lgamma(y) - lax.lgamma(x + y)


@_wraps(osp_special.betainc)
def betainc(a, b, x):
    # Promote all three inputs to a common inexact dtype, then defer to the
    # lax primitive. (Documentation comes from the wrapped scipy function.)
    pa, pb, px = _promote_args_inexact("betainc", a, b, x)
    return lax.betainc(pa, pb, px)


@_wraps(osp_special.digamma, update_doc=False)
def digamma(x):
    # Promote to an inexact dtype, then defer to the lax primitive.
    (promoted,) = _promote_args_inexact("digamma", x)
    return lax.digamma(promoted)


# JVP of digamma: d/dx digamma(x) = polygamma(1, x), scaled by the tangent g.
ad.defjvp(lax.digamma_p, lambda g, x: lax.mul(g, polygamma(1, x)))


@_wraps(osp_special.gammainc, update_doc=False)
def gammainc(a, x):
    # Promote both arguments to a common inexact dtype; lax.igamma does the
    # actual evaluation.
    promoted_a, promoted_x = _promote_args_inexact("gammainc", a, x)
    return lax.igamma(promoted_a, promoted_x)


@_wraps(osp_special.gammaincc, update_doc=False)
def gammaincc(a, x):
    # Promote both arguments to a common inexact dtype; lax.igammac does the
    # actual evaluation.
    promoted_a, promoted_x = _promote_args_inexact("gammaincc", a, x)
    return lax.igammac(promoted_a, promoted_x)


@_wraps(osp_special.erf)
Example #4
0
    return coo_todense(data_dot, row, col, shape=shape)


def _coo_todense_transpose(ct, data, row, col, *, shape):
    """Transpose rule for ``coo_todense`` with respect to ``data``.

    Only ``data`` may be the undefined (linear) primal; the ``row``/``col``
    index arrays must be known. The dense cotangent ``ct`` is mapped back to
    per-entry cotangents with ``_coo_extract`` (presumably a gather at the
    ``(row, col)`` positions).

    Note: we assume that transpose has the same sparsity pattern.
    Can we check this?
    """
    assert ad.is_undefined_primal(data)
    indices_are_linear = (ad.is_undefined_primal(row)
                          or ad.is_undefined_primal(col))
    if indices_are_linear:
        raise ValueError("Cannot transpose with respect to sparse indices")
    assert ct.shape == shape
    assert row.aval.dtype == col.aval.dtype
    assert ct.dtype == data.aval.dtype
    data_ct = _coo_extract(row, col, ct)
    return data_ct, row, col


# Differentiation rules. Only ``data`` gets a JVP rule; the two ``None``
# entries register no tangent rule for the ``row``/``col`` index arguments.
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
# Generic lowering for all platforms without a more specific rule.
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
# When ``cusparse`` is available (truthy) and reports support, register a
# GPU-specific lowering.
if cusparse and cusparse.is_supported:
    xla.register_translation(coo_todense_p,
                             _coo_todense_gpu_translation_rule,
                             platform='gpu')

#--------------------------------------------------------------------
# coo_fromdense

# Primitive backing ``coo_fromdense``. ``multiple_results`` is set because
# binding it yields several outputs (presumably data, row and col — confirm
# against the impl/abstract-eval rules).
coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True


def coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
Example #5
0
        raise TypeError(msg.format(a.shape, b.shape))
    return b.shape


def triangular_solve_transpose_rule(cotangent, a, b, left_side, lower,
                                    transpose_a, conjugate_a):
    """Transpose ``triangular_solve`` with respect to its right-hand side.

    The solve is linear in ``b``, which must be the undefined (``None``)
    primal, while ``a`` must be known. The cotangent for ``b`` is another
    triangular solve with ``transpose_a`` flipped; ``a`` gets no cotangent.
    """
    assert a is not None and b is None
    flipped_transpose = not transpose_a
    cotangent_b = triangular_solve(a, cotangent, left_side, lower,
                                   flipped_transpose, conjugate_a)
    return [None, cotangent_b]


# Primitive for triangular_solve, built from its shape and dtype rules.
triangular_solve_p = standard_primitive(triangular_solve_shape_rule,
                                        triangular_solve_dtype_rule,
                                        'triangular_solve')
# JVP: the solve is linear in ``b``, so the tangent is the same solve applied
# to ``g_b`` with the same parameters.
# NOTE(review): the rule for ``a`` is ``None``, so no tangent rule is
# registered for the matrix operand — confirm this is intended.
ad.defjvp(triangular_solve_p, None,
          lambda g_b, a, b, **kwargs: triangular_solve(a, g_b, **kwargs))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule


def qr_impl(operand, full_matrices):
    """Eagerly evaluate ``qr_p`` and pack its ``(q, r)`` outputs together."""
    outputs = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
    q, r = outputs
    packed = core.pack((q, r))
    return packed


def qr_translation_rule(c, operand, full_matrices):
    """Lower ``qr_p`` by emitting a QR op on the computation builder ``c``.

    ``full_matrices`` is forwarded unchanged as a keyword argument.
    """
    qr_op = c.QR(operand, full_matrices=full_matrices)
    return qr_op


def qr_abstract_eval(operand, full_matrices):
    if isinstance(operand, ShapedArray):
        if operand.ndim < 2:
Example #6
0
  Transform is written such that it acts as the identity during gradient
  backpropagation.

  Args:
    T: Transformation; ndarray(shape=[spatial_dim, spatial_dim]).
    v: Collection of vectors; ndarray(shape=[..., spatial_dim]).

  Returns:
    Transformed vectors; ndarray(shape=[..., spatial_dim]).
  """
    _check_transform_shapes(T, v)
    return np.dot(v, T)


# JVP rules for ``transform``: the tangent w.r.t. ``T`` is dropped (zero) and
# the tangent w.r.t. ``v`` passes through unchanged, so derivatives treat
# ``transform`` as the identity in ``v``.
ad.defjvp(transform.primitive, lambda g, T, v: ad_util.zero, lambda g, T, v: g)


def pairwise_displacement(Ra, Rb):
    """Compute a matrix of pairwise displacements given two sets of positions.

    Args:
      Ra: Matrix of positions; ndarray(shape=[n, spatial_dim]).
      Rb: Matrix of positions; ndarray(shape=[m, spatial_dim]).

    Returns:
      Matrix of displacements; ndarray(shape=[n, m, spatial_dim]) whose
      entry [i, j] is Ra[i] - Rb[j].
    """
    # Insert singleton axes so (n, 1, d) and (1, m, d) broadcast to (n, m, d).
    rows = Ra[:, np.newaxis, :]
    cols = Rb[np.newaxis, :, :]
    return rows - cols

Example #7
0
    g, y = _promote_args_like(osp_special.xlogy, g, y)
    return lax._safe_mul(lax._brcast(g, y), lax._brcast(lax.log(y), g))


def xlogy_jvp_rhs(g, x, y, jaxpr, aval, consts):
    """JVP of ``xlogy`` with respect to ``y``, using d/dy[x*log(y)] = x/y.

    The Jacobian factor ``x * (1 / y)`` is formed with ``lax._safe_mul`` after
    explicit broadcasting, then multiplied by the broadcast tangent ``g``.
    ``jaxpr``, ``aval`` and ``consts`` are primitive parameters this rule
    does not use.
    """
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    g, x = _promote_args_like(osp_special.xlogy, g, x)
    x_full = lax._brcast(x, y)
    inv_y_full = lax._brcast(lax.reciprocal(y), x)
    jac = lax._safe_mul(x_full, inv_y_full)
    g_full = lax._brcast(g, jac)
    return lax.mul(g_full, jac)


# xlogy primitive: eager evaluation via ``xla.apply_primitive``, abstract
# evaluation via ``xlogy_abstract_eval``, XLA lowering via ``xlogy_translate``,
# and separate JVP rules for the ``x`` (lhs) and ``y`` (rhs) arguments.
xlogy_p = Primitive('xlogy')
xlogy_p.def_impl(partial(xla.apply_primitive, xlogy_p))
xlogy_p.def_abstract_eval(xlogy_abstract_eval)
xla.translations[xlogy_p] = xlogy_translate
ad.defjvp(xlogy_p, xlogy_jvp_lhs, xlogy_jvp_rhs)


def xlog1py(x, y):
    """Bind ``xlog1py_p`` on ``x`` and ``y``.

    ``xlog1py_impl`` is first traced to a jaxpr over the abstract values of
    the inputs; the jaxpr, its output aval and its consts are passed to the
    primitive as parameters.
    """
    in_avals = tuple(map(lax._abstractify, (x, y)))
    jaxpr, (aval, _), consts = partial_eval.trace_unwrapped_to_jaxpr(
        xlog1py_impl, in_avals)
    return xlog1py_p.bind(x, y, jaxpr=jaxpr, aval=aval, consts=consts)


def xlog1py_impl(x, y):
    """Reference computation of ``x * log1p(y)``, taken as 0 where ``x == 0``."""
    # Selecting 0 for the log term where x == 0 keeps an infinite log1p(y)
    # from leaking into the product.
    log_term = np.where(x == 0., 0., np.log1p(y))
    return x * log_term


def xlog1py_jvp_lhs(g, x, y, jaxpr, aval, consts):
    x, y = _promote_args_like(osp_special.xlog1py, x, y)
Example #8
0
    if base_dilation is not None:
        operand_shape = lax._dilate_shape(operand_shape, base_dilation)
    if window_dilation is not None:
        window_dimensions = lax._dilate_shape(window_dimensions,
                                              window_dilation)
    pads_lo, pads_hi = zip(*padding)
    operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
    return core.stride_shape(operand_padded, window_dimensions, window_strides)


# reduce_window_max: the generic "chooser" translation rule specialised to
# ``lax.max_p`` and its identity, a JVP rule specialised to ``lax.max_p``, and
# the generic batching rule specialised to ``_reduce_window_max``.
_reduce_window_max_translation_rule = partial(
    _reduce_window_chooser_translation_rule, lax.max_p, lax._get_max_identity)
reduce_window_max_p = lax.standard_primitive(
    _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max',
    _reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p,
          partial(_reduce_window_chooser_jvp_rule, lax.max_p))
batching.primitive_batchers[reduce_window_max_p] = partial(
    _reduce_window_batch_rule, _reduce_window_max)

# reduce_window_min: the generic "chooser" translation rule specialised to
# ``lax.min_p`` and its identity, plus a JVP rule specialised to ``lax.min_p``.
_reduce_window_min_translation_rule = partial(
    _reduce_window_chooser_translation_rule, lax.min_p, lax._get_min_identity)
reduce_window_min_p = lax.standard_primitive(
    _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min',
    _reduce_window_min_translation_rule)
ad.defjvp(reduce_window_min_p,
          partial(_reduce_window_chooser_jvp_rule, lax.min_p))

# Batching rule: the generic reduce-window batch rule specialised to
# ``_reduce_window_min``. Register this named object itself instead of
# building a second, identical partial, which left the name unused.
_reduce_window_min_batch_rule = partial(_reduce_window_batch_rule,
                                        _reduce_window_min)
batching.primitive_batchers[reduce_window_min_p] = (
    _reduce_window_min_batch_rule)
Example #9
0
def _coo_todense_jvp(data_dot, data, row, col, *, shape):
  """JVP of ``coo_todense`` with respect to ``data``.

  The tangent output is ``coo_todense`` applied to the value tangent
  ``data_dot`` with the same ``row``/``col`` pattern; ``data`` is unused.
  """
  tangent_out = coo_todense(data_dot, row, col, shape=shape)
  return tangent_out

def _coo_todense_transpose(ct, data, row, col, *, shape):
  """Transpose rule for ``coo_todense`` with respect to ``data``.

  Only ``data`` may be the undefined (linear) primal; the ``row``/``col``
  index arrays must be known. The dense cotangent ``ct`` is mapped back to
  per-entry cotangents with ``_coo_extract`` (presumably a gather at the
  ``(row, col)`` positions).

  Note: we assume that transpose has the same sparsity pattern.
  Can we check this?
  """
  assert ad.is_undefined_primal(data)
  indices_are_linear = (ad.is_undefined_primal(row)
                        or ad.is_undefined_primal(col))
  if indices_are_linear:
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == shape
  assert row.aval.dtype == col.aval.dtype
  assert ct.dtype == data.aval.dtype
  data_ct = _coo_extract(row, col, ct)
  return data_ct, row, col

# Differentiation rules. Only ``data`` gets a JVP rule; the two ``None``
# entries register no tangent rule for the ``row``/``col`` index arguments.
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
# Generic lowering, built from ``_coo_todense_impl`` via ``xla.lower_fun``
# (single output, hence ``multiple_results=False``).
xla.translations[coo_todense_p] = xla.lower_fun(
    _coo_todense_impl, multiple_results=False)
# When ``cusparse`` is available (truthy) and reports support, register a
# GPU-specific lowering.
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_todense_p] = _coo_todense_gpu_translation_rule

#--------------------------------------------------------------------
# coo_fromdense

# Primitive backing ``coo_fromdense``. ``multiple_results`` is set because
# binding it yields several outputs (presumably data, row and col — confirm
# against the impl/abstract-eval rules).
coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True

def coo_fromdense(mat, *, nnz, index_dtype=jnp.int32):
  """Create COO-format sparse matrix from a dense matrix.
Example #10
0
def _xlogy_jvp_rhs(g, x, y):
    """JVP of ``xlogy`` with respect to ``y``: ``g * x / y``.

    ``g`` and ``x`` are first broadcast to a common shape; the ``x * (1 / y)``
    factor is formed with ``lax._safe_mul``.
    """
    common = lax.broadcast_shapes(np.shape(g), np.shape(x))
    g, x = np.broadcast_to(g, common), np.broadcast_to(x, common)
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    jac = lax._safe_mul(x, np.reciprocal(y))
    return g * jac


@custom_transforms
def xlogy(x, y):
    # x * log(y) after dtype promotion. The product uses ``lax._safe_mul``,
    # presumably so that x == 0 gives 0 even where log(y) is infinite.
    # Custom JVP rules are attached separately through ``xlogy.primitive``.
    x, y = _promote_args_like(osp_special.xlogy, x, y)
    log_y = np.log(y)
    return lax._safe_mul(x, log_y)


# Attach the JVP rules for x (lhs) and y (rhs) to the custom-transforms primitive.
ad.defjvp(xlogy.primitive, _xlogy_jvp_lhs, _xlogy_jvp_rhs)


def _xlog1py_jvp_lhs(g, x, y):
    """JVP of ``xlog1py`` with respect to ``x``: ``g * log1p(y)``.

    ``g`` and ``y`` are broadcast to a common shape and promoted together;
    ``x`` is not needed because the function is linear in ``x``.
    """
    common = lax.broadcast_shapes(np.shape(g), np.shape(y))
    g, y = np.broadcast_to(g, common), np.broadcast_to(y, common)
    g, y = _promote_args_like(osp_special.xlog1py, g, y)
    log1p_y = np.log1p(y)
    return lax._safe_mul(g, log1p_y)


def _xlog1py_jvp_rhs(g, x, y):
    shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
    g = np.broadcast_to(g, shape)
    x = np.broadcast_to(x, shape)
    x, y = _promote_args_like(osp_special.xlog1py, x, y)
Example #11
0
def mybar_impl(w):
    """Eager implementation of ``mybar_p``.

    ``w[0]`` and ``w[1]`` are passed to ``pymbar.BAR``; the first element of
    its result (the free-energy estimate) is returned and the second is
    discarded.
    """
    w_a, w_b = w[0], w[1]
    estimate, _ = pymbar.BAR(w_a, w_b)
    return estimate


def mybar_jvp(g, w):
    """JVP rule for ``mybar_p``.

    ``mybar`` maps the stacked works ``w`` to a single free-energy value, so
    its tangent is the inner product of the input tangent ``g`` with the
    gradient ``dG/dw``. Assumes ``tmbar.dG_dw(w)`` returns an array shaped
    like ``w`` — TODO confirm.

    Returns:
      A scalar tangent matching the scalar primal output.
    """
    # Sum the elementwise product: a bare ``g * dG_dw`` would produce a
    # tangent shaped like ``w`` rather than like the scalar output. Reverse-mode
    # gradients are unchanged, since transposing the sum broadcasts the
    # cotangent back before the multiply.
    return jnp.sum(g * tmbar.dG_dw(w))


def mybar(x):
    """Apply the ``mybar_p`` primitive to ``x`` and return its result."""
    result = mybar_p.bind(x)
    return result


# Primitive wrapping the pymbar BAR estimate: evaluated eagerly by
# ``mybar_impl`` and differentiated by ``mybar_jvp``. No abstract-eval rule is
# registered here, so presumably it is not usable under ``jit`` — confirm.
mybar_p = core.Primitive("mybar")
mybar_p.def_impl(mybar_impl)
ad.defjvp(mybar_p, mybar_jvp)


def BAR_leg(insertion_du_dls, deletion_du_dls, lambda_schedule):
    """BAR estimate for one leg from insertion and deletion du/dl traces.

    Each trace is integrated over ``lambda_schedule`` with
    ``math_utils.trapz`` (presumably trapezoidal integration) to obtain a work
    value. The [insertion, deletion] works are then stacked and passed to
    ``mybar``.
    """
    works = [math_utils.trapz(du_dls, lambda_schedule)
             for du_dls in (insertion_du_dls, deletion_du_dls)]
    return mybar(jnp.stack(works))


def BAR_loss(
    complex_insertion_du_dls,  # [C, N]
    complex_deletion_du_dls,  # [C, N]
    solvent_insertion_du_dls,  # [C, N]
    solvent_deletion_du_dls,  # [C, N]
    lambda_schedule,
Example #12
0
def _sin_abstract_eval(x):
  """Abstract evaluation for ``sin``: the output has the input's type.

  ``AbsArray`` inputs produce a new ``AbsArray`` with the same shape and
  element type; all other inputs are delegated to ``lax.sin_p``.
  """
  if not isinstance(x, AbsArray):
    return lax.sin_p.abstract_eval(x)
  return AbsArray(x.shape, x._eltTy)

def _sin_typecheck_rule(invar):
  return [invar.aval]
# Register the typecheck rule for sin_p.
typecheck_rules[sin_p] = _sin_typecheck_rule

def _sin_translation_rule(c, dims, avals, operands):
  """Lower ``sin`` by emitting ``xops.Sin`` on its single operand.

  ``operands`` must hold exactly one group containing exactly one op; the
  result is returned in the same nested-list form.
  """
  [[x]] = operands
  return [[xops.Sin(x)]]
# Register the XLA translation rule for sin_p.
translations[sin_p] = _sin_translation_rule

# JVP of sin: d/dx sin(x) = cos(x), scaled by the tangent g.
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))


## cos

def cos(x: Any) -> Any:
  """Apply the ``cos_p`` primitive to ``x`` and return its result."""
  result = cos_p.bind(x)
  return result
cos_p = core.Primitive('cos_p')

@cos_p.def_abstract_eval
def _cos_abstract_eval(x):
  """Abstract evaluation for ``cos``: the output has the input's type.

  ``AbsArray`` inputs produce a new ``AbsArray`` with the same shape and
  element type; all other inputs are delegated to ``lax.cos_p``.
  """
  if not isinstance(x, AbsArray):
    return lax.cos_p.abstract_eval(x)
  return AbsArray(x.shape, x._eltTy)