Example #1
0
def custom_transpose_transpose_rule(cts, *args, out_types, res_tree, lin_tree,
                                    out_tree, **params):
    """Transpose a custom-transpose call by invoking its transpose function.

    The transpose function is either built from
    ``params['transpose_jaxpr_thunk']`` (via ``make_transpose_from_thunk``) or
    taken directly from ``params['transpose']``. The flat ``args`` are
    unflattened into a ``(residuals, linear)`` pair using ``res_tree`` and
    ``lin_tree``; only the residuals are passed on.

    Returns:
        A list with one ``None`` per residual leaf, followed by the flattened
        cotangents for the linear inputs (``None`` is kept as a leaf).
    """
    if 'transpose_jaxpr_thunk' not in params:
        assert 'call' in params
        transpose = params['transpose']
    else:
        assert 'call_jaxpr' in params
        transpose = make_transpose_from_thunk(
            params['transpose_jaxpr_thunk'], lin_tree)

    # TODO(frostig,mattjj): the linear arguments (second component of the
    # unflattened pair, discarded below) indicate the inputs with respect to
    # which we are transposing (via `ad.is_undefined_primal`).
    # Consider passing this information to the custom transpose rule?
    res_arg, _ = tree_unflatten(treedef_tuple((res_tree, lin_tree)), args)
    assert not any(
        ad.is_undefined_primal(leaf) for leaf in tree_leaves(res_arg))

    def _instantiate(ct):
        # Replace a symbolic zero cotangent by zeros built from its aval.
        if type(ct) is ad_util.Zero:
            return ad_util.zeros_like_aval(ct.aval)
        return ct

    ct_out = tree_unflatten(out_tree, [_instantiate(ct) for ct in cts])
    ct_lin = transpose(res_arg, ct_out)
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))

    def _is_none(x):
        return x is None

    broadcast_ct = tree_broadcast(lin_tree, ct_lin, is_leaf=_is_none)
    ct_lin_flat, _ = tree_flatten(broadcast_ct, is_leaf=_is_none)
    residual_cts = [None] * len(tree_leaves(res_arg))
    return residual_cts + ct_lin_flat
Example #2
0
File: coo.py Project: 0x0is1/jax
def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
    """Transpose rule for the COO matrix-vector product.

    ``row`` and ``col`` must be defined. When ``v`` is the undefined (linear)
    operand, its cotangent is computed by ``coo_matvec`` with the ``transpose``
    flag flipped. Otherwise the cotangent for ``data`` is computed from ``ct``
    and ``v`` gathered at the stored coordinates.

    Returns:
        A 4-tuple aligned with ``(data, row, col, v)``.
    """
    assert not ad.is_undefined_primal(row)
    assert not ad.is_undefined_primal(col)

    if not ad.is_undefined_primal(v):
        dense_v = jnp.asarray(v)
        # Same values as _coo_extract(row, col, jnp.outer(ct, v)), but without
        # materializing the dense outer product.
        data_ct = ct[row] * dense_v[col]
        return data_ct, row, col, dense_v

    v_ct = coo_matvec(data, row, col, ct, shape=shape,
                      transpose=not transpose)
    return data, row, col, v_ct
Example #3
0
def _coo_fromdense_transpose(ct, M, *, nnz, index_dtype):
    """Transpose rule for the dense-to-COO conversion.

    ``ct`` unpacks to ``(data, row, col)`` cotangents; ``M`` must be an
    undefined primal. The result is ``coo_todense`` applied to the cotangent
    components with the shape taken from ``M.aval``.

    Raises:
        ValueError: if the ``row`` or ``col`` cotangent is an ``ad.Zero``.
    """
    data, row, col = ct
    assert len(data) == nnz
    assert row.dtype == col.dtype == index_dtype
    if any(isinstance(index_ct, ad.Zero) for index_ct in (row, col)):
        raise ValueError("Cannot transpose with respect to sparse indices")
    assert ad.is_undefined_primal(M)
    dense_shape = M.aval.shape
    return coo_todense(data, row, col, shape=dense_shape)
Example #4
0
File: api.py Project: jbampton/jax
def _todense_transpose(ct, *bufs, tree):
  """Transpose rule for ``todense``, dispatching on the container type.

  The first buffer must be the undefined (linear) operand and all remaining
  buffers must be defined. The container type is recovered by unflattening
  ``tree`` with placeholder leaves.

  Returns:
    A tuple with one entry per buffer: the cotangent for the first buffer,
    followed by the remaining buffers unchanged.

  Raises:
    NotImplementedError: if the container is not a bare leaf, BCOO, or COO.
  """
  assert ad.is_undefined_primal(bufs[0])
  assert all(not ad.is_undefined_primal(buf) for buf in bufs[1:])

  # A unique placeholder lets us detect the "tree is a single leaf" case.
  placeholder = object()
  rebuilt = tree_util.tree_unflatten(tree, [placeholder for _ in bufs])
  from jax.experimental.sparse import BCOO, bcoo_extract
  if rebuilt is placeholder:
    return (ct,)
  if isinstance(rebuilt, BCOO):
    _, indices = bufs
    return bcoo_extract(indices, ct), indices
  if isinstance(rebuilt, COO):
    _, row, col = bufs
    return _coo_extract(row, col, ct), row, col
  raise NotImplementedError(f"todense_transpose for {type(rebuilt)}")
Example #5
0
File: csr.py Project: jbampton/jax
def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
    """Transpose rule for the dense-to-CSR conversion.

    ``ct`` unpacks to ``(data, indices, indptr)`` cotangents; ``M`` must be an
    undefined primal. The result is ``csr_todense`` applied to the cotangent
    components with the shape taken from ``M.aval``.

    Raises:
        ValueError: if the ``indices`` or ``indptr`` cotangent is an
            ``ad.Zero``.
    """
    data, indices, indptr = ct
    assert len(data) == nse
    assert indices.dtype == indptr.dtype == index_dtype
    for index_ct in (indices, indptr):
        if isinstance(index_ct, ad.Zero):
            raise ValueError("Cannot transpose with respect to sparse indices")
    assert ad.is_undefined_primal(M)
    dense_shape = M.aval.shape
    return csr_todense(data, indices, indptr, shape=dense_shape)
Example #6
0
File: csr.py Project: jbampton/jax
def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
    """Transpose rule for the CSR matrix-vector product.

    ``indices`` and ``indptr`` must be defined. When ``v`` is the undefined
    (linear) operand, its cotangent is computed by ``csr_matvec`` with the
    ``transpose`` flag flipped. Otherwise the cotangent for ``data`` is
    computed from ``ct`` and ``v`` gathered at COO coordinates derived from
    the CSR index arrays.

    Returns:
        A 4-tuple aligned with ``(data, indices, indptr, v)``.
    """
    assert not ad.is_undefined_primal(indices)
    assert not ad.is_undefined_primal(indptr)

    if not ad.is_undefined_primal(v):
        dense_v = jnp.asarray(v)
        # Same values as
        # _csr_extract(indices, indptr, jnp.outer(ct, v)), but without
        # materializing the dense outer product.
        row, col = _csr_to_coo(indices, indptr)
        data_ct = ct[row] * dense_v[col]
        return data_ct, indices, indptr, dense_v

    v_ct = csr_matvec(data, indices, indptr, ct, shape=shape,
                      transpose=not transpose)
    return data, indices, indptr, v_ct
Example #7
0
def custom_transpose_transpose_rule(cts, *args, call, rule, res_tree, lin_tree,
                                    out_tree):
    """Transpose a custom-transpose call by invoking the user-supplied ``rule``.

    The flat ``args`` are unflattened into a ``(residuals, linear)`` pair.
    Every linear leaf must be an undefined primal and every residual leaf must
    be defined. Symbolic-zero cotangents are replaced by zeros built from the
    matching entry of ``call.out_avals`` before ``rule`` is called.

    Returns:
        A list with one ``None`` per residual leaf, followed by the flattened
        output of ``rule``.
    """
    res_arg, lin_arg = tree_unflatten(
        treedef_tuple((res_tree, lin_tree)), args)
    assert all(ad.is_undefined_primal(x) for x in tree_leaves(lin_arg))
    assert not any(ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    dense_cts = []
    for ct, ct_aval in zip(cts, call.out_avals):
        if type(ct) is ad_util.Zero:
            ct = ad_util.zeros_like_aval(ct_aval)
        dense_cts.append(ct)

    ct_lin = rule(res_arg, tree_unflatten(out_tree, dense_cts))
    ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
    # The rule's output structure must match the structure of the linear inputs.
    check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
    residual_cts = [None] * len(tree_leaves(res_arg))
    return residual_cts + ct_lin_flat
Example #8
0
def _dynamic_xla_call_transpose(cts_in, *args, jaxpr, num_consts):
  """Transpose rule for a dynamic-shape XLA call.

  Builds a mapping from dimension variables to concrete sizes by pairing each
  binder's abstract shape with the shape of the corresponding defined
  argument, then runs ``backward_pass`` on ``jaxpr``.

  Returns:
    A list of the cotangents for the constants followed by those for the
    arguments, as produced by ``backward_pass``.
  """
  # TODO make this a dynamic_xla_call_p bind
  del num_consts
  vars_to_vals = {}
  for binder, arg in zip(jaxpr.in_binders, args):
    if not isinstance(binder.aval, AbsArray) or ad.is_undefined_primal(arg):
      continue
    for dim, size in zip(binder.aval.shape, arg.shape):
      if isinstance(dim, Var):
        vars_to_vals[dim] = size
  dim_args = [vars_to_vals[dim_var] for dim_var in jaxpr.in_dim_binders]
  consts_bar, args_bar = backward_pass(jaxpr, dim_args, args, cts_in)  # type: ignore
  return [*consts_bar, *args_bar]
Example #9
0
def _bcoo_fromdense_transpose(ct, M, *, nse, n_batch, n_dense, index_dtype):
  """Transpose rule for the dense-to-BCOO conversion.

  Args:
    ct: pair ``(data, indices)`` of cotangents for the sparse outputs.
    M: the dense operand; must be an undefined primal.
    nse: size of the stored-elements axis, as checked against the shapes of
      ``data`` and ``indices``.
    n_batch: number of leading batch dimensions of ``M``.
    n_dense: number of trailing dense dimensions of ``M``.
    index_dtype: expected dtype of ``indices``.

  Returns:
    ``bcoo_todense(data, indices, shape=M.aval.shape)``.

  Raises:
    ValueError: if the ``indices`` cotangent is an ``ad.Zero``.
  """
  data, indices = ct
  # M is expected to be an undefined primal (asserted below), so read its
  # shape from its abstract value, as the return statement already does.
  m_shape = M.aval.shape
  # The sparse dimensions are whatever remains after removing the batch and
  # dense dimensions. This was previously written as the chained assignment
  # `n_sparse = M.ndim = n_batch - n_dense`, which computed the wrong value
  # and attempted to overwrite `M.ndim`.
  n_sparse = len(m_shape) - n_batch - n_dense
  assert data.shape == m_shape[:n_batch] + (nse,) + m_shape[n_batch + n_sparse:]
  assert indices.shape == m_shape[:n_batch] + (n_sparse, nse)
  assert indices.dtype == index_dtype
  if isinstance(indices, ad.Zero):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ad.is_undefined_primal(M)
  return bcoo_todense(data, indices, shape=m_shape)
Example #10
0
 def transposed(*args):
   """Compute input cotangents for the enclosing scope's ``jaxpr``.

   ``args`` is unflattened with the enclosing ``treedef`` into
   ``(in_primals, out_cts)``. Returns the flattened input cotangents and
   records their tree structure on ``cell.treedef`` as a side effect.
   """
   in_primals, out_cts = tree_unflatten(treedef, args)
   # Undefined primals are marked unknown (by aval); the rest are known values.
   in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
               pe.PartialVal.known(x) for x in in_primals]
   primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
   # Trace evaluation of `jaxpr` under the partial values above.
   tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
   # Every input of the traced jaxpr is passed to the backward pass as an
   # undefined primal carrying that input's aval.
   dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
   in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)
   # Stash the output tree structure on `cell` (defined in the enclosing scope).
   in_cts, cell.treedef = tree_flatten(in_cts_)
   return in_cts
Example #11
0
def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts,
                                num_transpose_consts, num_res):
  """Transpose ``linear_call`` by binding it again with the roles swapped.

  ``args`` is split into callee constants, transpose constants, residual
  operands and linear operands. The linear operands must be undefined primals
  and the residual operands must be defined. Symbolic-zero cotangents are
  replaced by zeros built from the matching trailing entries of
  ``transpose.in_avals``.

  Returns:
    A list with ``None`` for every constant and residual, followed by the
    outputs of the swapped ``linear_call_p.bind``.
  """
  f_consts, t_consts, operands_res, operands_lin = split_list(
      args, [num_callee_consts, num_transpose_consts, num_res])
  _, _, cts_avals = split_list(
      transpose.in_avals, [num_transpose_consts, num_res])

  assert all(ad.is_undefined_primal(x) for x in operands_lin)
  assert not any(ad.is_undefined_primal(x) for x in operands_res)

  dense_cts = [zeros_like_aval(aval) if type(ct) is Zero else ct
               for ct, aval in zip(cts, cts_avals)]

  # Swap callee/transpose (and their constant counts) for the transposed call.
  swapped_params = dict(
      callee=transpose,
      transpose=callee,
      num_callee_consts=len(t_consts),
      num_transpose_consts=len(f_consts),
      num_res=len(operands_res))
  cts_out = linear_call_p.bind(*t_consts, *f_consts, *operands_res,
                               *dense_cts, **swapped_params)

  n_nonlinear = num_callee_consts + num_transpose_consts + num_res
  return [None] * n_nonlinear + cts_out
Example #12
0
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Transpose rule for ``bcoo_dot_general``.

  ``lhs_indices`` must be defined. If ``lhs_data`` is the undefined operand,
  its cotangent is obtained by a dense ``lax.dot_general`` of ``ct`` with
  ``rhs``, transposed and then extracted at ``lhs_indices``. Otherwise the
  cotangent for ``rhs`` is computed with ``bcoo_dot_general`` of the sparse
  lhs against ``ct``, followed by a transpose.

  Returns:
    ``ad.Zero`` (the class itself) when ``ct`` is a symbolic zero; otherwise a
    3-tuple aligned with ``(lhs_data, lhs_indices, rhs)``.
  """
  assert not ad.is_undefined_primal(lhs_indices)
  if type(ct) is ad.Zero:
    # NOTE(review): this returns the Zero class, not an instance; presumably
    # the caller treats it as a sentinel -- confirm against the backward pass.
    return ad.Zero
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_ndim = len(lhs_shape)
  # An undefined rhs exposes its rank through its aval.
  rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim
  lhs_kept = remaining(range(lhs_ndim), lhs_contract, lhs_batch)
  rhs_kept = remaining(range(rhs_ndim), rhs_contract, rhs_batch)
  ans_batch, ans_lhs, ans_rhs = ranges_like(lhs_batch, lhs_kept, rhs_kept)
  if ad.is_undefined_primal(lhs_data):
    # Cotangent with respect to the sparse lhs data.
    dims = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch))
    lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract)))
    # TODO: extract these sparse indices without constructing the dense matrix.
    out_axes = np.argsort(list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs)
    out_dense = lax.transpose(lax.dot_general(ct, rhs, dimension_numbers=dims), out_axes)
    return bcoo_extract(lhs_indices, out_dense), lhs_indices, rhs
  else:
    # Cotangent with respect to rhs.
    dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch))
    rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
    out_axes = np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept)
    result = bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_shape=lhs_shape, dimension_numbers=dims)
    return lhs_data, lhs_indices, lax.transpose(result, out_axes)
Example #13
0
def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Transpose rule for select-and-gather-add with respect to ``tangents``.

  ``tangents`` must be the undefined operand and ``operand`` must be defined.
  With base dilation, ``operand`` is first interior-padded with the selection
  identity and the scattered result is sliced back with ``base_dilation`` as
  the stride.

  Returns:
    A two-element list: the cotangent for ``tangents`` and ``None`` for
    ``operand``.

  Raises:
    NotImplementedError: if any window dilation differs from 1.
  """
  assert select_prim in (lax.le_p, lax.ge_p)
  assert (ad.is_undefined_primal(tangents) and
          not ad.is_undefined_primal(operand))
  if any(d != 1 for d in window_dilation):
    msg = ("VJP not implemented for select_and_gather (MaxPool) with window "
           "dilation, got window_dilation={}.")
    raise NotImplementedError(msg.format(window_dilation))
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]

  dilated = any(d != 1 for d in base_dilation)
  if dilated:
    if select_prim is lax.ge_p:
      select_identity = lax._get_max_identity
    else:
      select_identity = lax._get_min_identity
    interior_padding = tuple((0, 0, d - 1) for d in base_dilation)
    operand = lax.pad(operand, select_identity(operand.dtype),
                      interior_padding)

  result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
                                   window_strides, padding)
  if dilated:
    start_indices = (0,) * len(result.shape)
    result = slicing.slice(result, start_indices, result.shape, base_dilation)
  return [result, None]
Example #14
0
def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
    """Transpose ``custom_linear_solve`` by binding the transposed solve.

    ``primals`` is split into solver parameters and the right-hand side ``b``,
    whose entries must all be undefined primals. The leading ``len(b)``
    cotangents are fed to ``linear_solve_p.bind`` with transposed parameters,
    constant lengths and jaxprs.

    Returns:
        A list with ``None`` for every constant, followed by the first
        ``len(b)`` outputs of the transposed solve.

    Raises:
        TypeError: if ``jaxprs.transpose_solve`` is ``None``.
    """
    if jaxprs.transpose_solve is None:
        raise TypeError(
            'transpose_solve required for backwards mode automatic '
            'differentiation of custom_linear_solve')

    params, b = _split_linear_solve_args(primals, const_lengths)
    n_b = len(b)
    # split off symbolic zeros in the cotangent if present
    x_cotangent, _ = split_list(cotangent, [n_b])
    assert all(ad.is_undefined_primal(x) for x in b)

    transposed_operands = _flatten(params.transpose()) + x_cotangent
    cotangent_b_full = linear_solve_p.bind(
        *transposed_operands,
        const_lengths=const_lengths.transpose(),
        jaxprs=jaxprs.transpose())
    # drop aux values in cotangent computation
    cotangent_b, _ = split_list(cotangent_b_full, [n_b])
    const_cts = [None] * sum(const_lengths)
    return const_cts + cotangent_b
Example #15
0
def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                      window_strides, padding, base_dilation,
                                      window_dilation):
  """Transpose rule for the windowed sum reduction.

  ``operand`` must be an undefined primal. The cotangent is padded (edge
  padding from ``_conv_general_vjp_lhs_padding``, interior padding of
  ``stride - 1``) and then window-summed, with ``base_dilation`` passed in the
  strides position and a unit base dilation.

  Returns:
    A one-element list holding the operand cotangent, whose shape is asserted
    to equal the operand's shape.
  """
  assert ad.is_undefined_primal(operand)
  input_shape = operand.aval.shape
  rank = len(input_shape)
  edge_pads = convolution._conv_general_vjp_lhs_padding(
      input_shape, window_dimensions, window_strides, cotangent.shape, padding,
      base_dilation, window_dilation)
  unit_dilation = [1] * rank
  padding_config = [(lo, hi, stride - 1)
                    for (lo, hi), stride in zip(edge_pads, window_strides)]
  padded_ct = lax.pad(cotangent, lax._zero(cotangent), padding_config)
  no_padding = [(0, 0)] * rank
  result = _reduce_window_sum(padded_ct, window_dimensions, base_dilation,
                              no_padding,
                              base_dilation=unit_dilation,
                              window_dilation=window_dilation)
  assert result.shape == input_shape, (result.shape, input_shape)
  return [result]
Example #16
0
def _bcoo_extract_transpose(ct, indices, mat):
  """Transpose rule for ``bcoo_extract`` with respect to ``mat``.

  ``mat`` must be an undefined primal and ``indices`` must be defined. The
  cotangent for ``mat`` is ``bcoo_todense`` of ``ct`` at ``indices`` with the
  shape taken from ``mat.aval``.

  Returns:
    The pair ``(indices, cotangent_for_mat)``.

  Raises:
    ValueError: if ``indices`` is an undefined primal.
  """
  assert ad.is_undefined_primal(mat)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  mat_aval = mat.aval
  assert ct.dtype == mat_aval.dtype
  mat_ct = bcoo_todense(ct, indices, shape=mat_aval.shape)
  return indices, mat_ct
Example #17
0
 def write_primal(v, val):
   """Record ``val`` for ``v`` in ``primal_env`` unless it is an undefined primal."""
   if ad.is_undefined_primal(val):
     return
   primal_env[v] = val
Example #18
0
 def abstract(value):
   """Return ``raise_to_shaped`` of the abstract value for ``value``.

   An undefined primal contributes its ``aval``; anything else goes through
   ``get_aval``.
   """
   if ad.is_undefined_primal(value):
     aval = value.aval
   else:
     aval = get_aval(value)
   return raise_to_shaped(aval)