Code example #1
File: convolution.py    Project: xueeinstein/jax
def _conv_general_dilated_lower(
    ctx, lhs, rhs, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, precision, preferred_element_type,
    expand_complex_convolutions=False, **unused_kwargs):
  lhs_aval, rhs_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  assert isinstance(dimension_numbers, ConvDimensionNumbers)
  dtype = lhs_aval.dtype
  if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
    if preferred_element_type is not None:
      # Convert complex dtype to types used for real and imaginary parts
      assert np.issubdtype(preferred_element_type, np.complexfloating)
      preferred_element_type = _real_dtype(preferred_element_type)
    complex_conv = mlir.lower_fun(
      partial(
        _complex_mul,
        partial(conv_general_dilated, window_strides=window_strides,
                padding=padding, lhs_dilation=lhs_dilation,
                rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
                feature_group_count=feature_group_count,
                batch_group_count=batch_group_count, precision=precision,
                preferred_element_type=preferred_element_type)),
      multiple_results=False)
    return complex_conv(ctx, lhs, rhs)

  lhs_spec, rhs_spec, out_spec = dimension_numbers
  dnums = mhlo.ConvDimensionNumbers.get(
    input_batch_dimension=lhs_spec[0],
    input_feature_dimension=lhs_spec[1],
    input_spatial_dimensions=list(lhs_spec[2:]),
    kernel_output_feature_dimension=rhs_spec[0],
    kernel_input_feature_dimension=rhs_spec[1],
    kernel_spatial_dimensions=list(rhs_spec[2:]),
    output_batch_dimension=out_spec[0],
    output_feature_dimension=out_spec[1],
    output_spatial_dimensions=list(out_spec[2:]))
  num_spatial_dims = len(rhs_spec) - 2
  window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
  return [
      mhlo.ConvOp(
          mlir.aval_to_ir_type(aval_out),
          lhs,
          rhs,
          dimension_numbers=dnums,
          feature_group_count=mlir.i64_attr(feature_group_count),
          batch_group_count=mlir.i64_attr(batch_group_count),
          window_strides=mlir.dense_int_elements(window_strides),
          padding=mlir.dense_int_elements(padding),
          lhs_dilation=mlir.dense_int_elements(lhs_dilation),
          rhs_dilation=mlir.dense_int_elements(rhs_dilation),
          window_reversal=window_reversal,
          precision_config=lax.precision_attr(precision)).result
  ]
Code example #2
File: primitive.py    Project: tensorflow/probability
 def _mlir(c, *mlir_args, **params):
     lowering = mlir.lower_fun(self.impl, multiple_results=True)
     return lowering(c, *mlir_args, **params)
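
Taken together, examples #1 and #2 already show the shape of the API: mlir.lower_fun wraps a Python implementation as a lowering rule, and mlir.register_lowering attaches that rule to a primitive. Below is a minimal self-contained sketch of the recipe; the `double` primitive is hypothetical, and the sketch assumes only the JAX-internal modules used throughout these examples.

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

double_p = core.Primitive('double')  # hypothetical example primitive

@double_p.def_impl
def _double_impl(x):
  # Concrete implementation, used for eager evaluation.
  return x * 2

@double_p.def_abstract_eval
def _double_abstract_eval(x):
  # Shape/dtype rule: the output aval matches the input.
  return core.ShapedArray(x.shape, x.dtype)

# mlir.lower_fun traces the impl and emits the equivalent MLIR ops;
# registering the result makes the primitive lowerable under jit.
mlir.register_lowering(double_p,
                       mlir.lower_fun(_double_impl, multiple_results=False))

print(jax.jit(double_p.bind)(jnp.arange(3.0)))  # [0. 2. 4.]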
Code example #3
@identity_p.def_impl
def _identity_impl(mat):
  return mat

@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
  return AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)

xla.register_translation(
    identity_p, xla.lower_fun(_identity_impl, multiple_results=False,
                              new_style=True))


mlir.register_lowering(
    identity_p, mlir.lower_fun(_identity_impl, multiple_results=False))

def split(x):
  return split_p.bind(x)

split_p = core.Primitive('split')
split_p.multiple_results = True

@split_p.def_impl
def _split_impl(mat):
  return mat, mat

@split_p.def_abstract_eval
def _split_abstract_eval(mat):
  m = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)
  return m, m
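
Note that the snippet ends before split_p gets a lowering. Since split_p.multiple_results is True, the matching registration would presumably pass multiple_results=True (a sketch, not part of the quoted source):

mlir.register_lowering(
    split_p, mlir.lower_fun(_split_impl, multiple_results=True))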
Code example #4
File: windowed_reductions.py    Project: wayfeng/jax
        padding = [(0, 0) for _ in padding]
    out = _select_and_scatter(operand, select, window_dimensions,
                              window_strides, padding, source,
                              lax._zero(operand), scatter)
    if expand_padding:
        start_indices = [lo for (lo, hi) in original_padding]
        stop_indices = [
            lo + d for ((lo, hi), d) in zip(original_padding, operand_shape)
        ]
        out = slicing.slice(out, start_indices, stop_indices)
    return out


mlir.register_lowering(
    select_and_scatter_add_p,
    mlir.lower_fun(partial(_select_and_scatter_add_impl, expand_padding=False),
                   multiple_results=False))
mlir.register_lowering(select_and_scatter_add_p,
                       mlir.lower_fun(partial(_select_and_scatter_add_impl,
                                              expand_padding=True),
                                      multiple_results=False),
                       platform='cpu')
mlir.register_lowering(select_and_scatter_add_p,
                       mlir.lower_fun(partial(_select_and_scatter_add_impl,
                                              expand_padding=True),
                                      multiple_results=False),
                       platform='gpu')


def _select_and_gather_add_shape_rule(tangents, operand, *, select_prim,
                                      window_dimensions, window_strides,
                                      padding, base_dilation, window_dilation):
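
The registrations above illustrate per-platform dispatch: the first call installs the default rule, and the platform='cpu' / platform='gpu' calls override it on those backends only. A minimal sketch of the same precedence, with hypothetical primitive and rules:

mlir.register_lowering(my_p, _generic_rule)              # default, all backends
mlir.register_lowering(my_p, _cpu_rule, platform='cpu')  # overrides on CPU
mlir.register_lowering(my_p, _gpu_rule, platform='gpu')  # overrides on GPU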
Code example #5
File: remat_impl.py    Project: xueeinstein/jax
        elif is_gpu_platform:
            translation_rule = _remat_translation_using_while
        else:
            translation_rule = _remat_translation_using_cond
    else:
        translation_rule = lambda *args, jaxpr: core.eval_jaxpr(
            jaxpr, (), *args)

    return jax.named_call(translation_rule,
                          name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)


for remat_primitive in (pe.remat_call_p,
                        ad_checkpoint.remat_p):  # type: ignore
    mlir.register_lowering(remat_primitive,
                           mlir.lower_fun(remat_impl, multiple_results=True))
    mlir.register_lowering(remat_primitive,
                           mlir.lower_fun(partial(remat_impl,
                                                  is_gpu_platform=True),
                                          multiple_results=True),
                           platform="gpu")


def _optimization_barrier_abstract_eval(*args):
    return args


def _optimization_barrier_lowering_rule(ctx, *args):
    barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
    flat_barrier_types = util.flatten(barrier_types)
Code example #6
def identity(x):
  return identity_p.bind(x)

identity_p = core.Primitive('identity')

@identity_p.def_impl
def _identity_impl(mat):
  return mat

@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
  return AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)

mlir.register_lowering(
    identity_p, mlir.lower_fun(_identity_impl, multiple_results=False))

def split(x):
  return split_p.bind(x)

split_p = core.Primitive('split')
split_p.multiple_results = True

@split_p.def_impl
def _split_impl(mat):
  return mat, mat

@split_p.def_abstract_eval
def _split_abstract_eval(mat):
  m = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)
  return m, m
Code example #7
  del params  # other params ignored because we're just executing the primal fun
  return core.jaxpr_as_fun(fun_jaxpr)(*args)

def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
  del args, params
  if fun_jaxpr.effects:
    raise NotImplementedError('Effects not supported in `custom_jvp`.')
  return fun_jaxpr.out_avals, fun_jaxpr.effects

custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_effectful_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p

mlir.register_lowering(custom_jvp_call_jaxpr_p, mlir.lower_fun(
    _custom_jvp_call_jaxpr_impl, multiple_results=True))


def _custom_jvp_call_jaxpr_jvp(
    primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  _, args = split_list(primals, [num_consts])
  consts_dot, args_dot = split_list(tangents, [num_consts])
  if any(type(t) is not Zero for t in consts_dot):
    raise ad.CustomJVPException()
  jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()  # consts can be tracers!
  args_dot = map(ad.instantiate_zeros, args_dot)
  # Cast float0 to zeros with the primal dtype because custom jvp rules don't
  # currently handle float0s
  args_dot = map(ad.replace_float0s, args, args_dot)
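
Example #7 uses def_effectful_abstract_eval, which returns a (avals, effects) pair so the effects system can track the jaxpr's effects; the older revision in example #8 below uses plain def_abstract_eval returning only the avals. A sketch of the two styles (hypothetical prim and out_avals; the empty-effects value is an assumption):

# Older style (example #8): abstract eval returns only the output avals.
prim.def_abstract_eval(lambda *args, **params: out_avals)

# Newer style (example #7): return (avals, effects) so effects propagate.
prim.def_effectful_abstract_eval(
    lambda *args, **params: (out_avals, set()))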
Code example #8
def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr,
                                         **params):
    del args, params
    return fun_jaxpr.out_avals


custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p

mlir.register_lowering(
    custom_jvp_call_jaxpr_p,
    mlir.lower_fun(_custom_jvp_call_jaxpr_impl, multiple_results=True))


def _custom_jvp_call_jaxpr_jvp(primals, tangents, *,
                               fun_jaxpr: core.ClosedJaxpr,
                               jvp_jaxpr_thunk: Callable[[],
                                                         Tuple[core.Jaxpr,
                                                               Sequence[Any]]],
                               num_consts: int):
    _, args = split_list(primals, [num_consts])
    consts_dot, args_dot = split_list(tangents, [num_consts])
    if any(type(t) is not Zero for t in consts_dot):
        raise ad.CustomJVPException()
    jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()  # consts can be tracers!
    args_dot = map(ad.instantiate_zeros, args_dot)
    # Cast float0 to zeros with the primal dtype because custom jvp rules don't
Code example #9
File: api.py    Project: frederikwilde/jax
  elif isinstance(obj, BCOO):
    _, indices = bufs
    return bcoo_extract(indices, ct), indices
  elif isinstance(obj, COO):
    _, row, col = bufs
    return _coo_extract(row, col, ct), row, col
  else:
    raise NotImplementedError(f"todense_transpose for {type(obj)}")

def _todense_batching_rule(batched_args, batch_dims, *, tree):
  return jax.vmap(partial(_todense_impl, tree=tree), batch_dims)(*batched_args), 0

ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
mlir.register_lowering(todense_p, mlir.lower_fun(
    _todense_impl, multiple_results=False))


def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
  """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    sparse_format: string specifying the matrix format ('bcoo', 'coo', 'csr', or 'csc').
    **kwds: additional keywords passed to the format-specific _empty constructor.
  Returns:
    mat: empty sparse matrix.
  """
  formats = {'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
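
For context, a hedged usage sketch of the empty helper documented above (assuming it is exported as jax.experimental.sparse.empty, as in recent JAX releases):

from jax.experimental import sparse

# An all-zero 3x4 BCOO matrix with no stored elements.
mat = sparse.empty((3, 4), dtype='float32', sparse_format='bcoo')
print(mat.shape)  # (3, 4)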
Code example #10
File: solves.py    Project: xueeinstein/jax
    ]
    # Broadcast out b if necessary
    new_b = [
        batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
        batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
        for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
    ]

    outs = linear_solve_p.bind(*(new_params + new_b),
                               const_lengths=const_lengths,
                               jaxprs=batched_jaxprs)
    out_dims = [
        0 if batched else batching.not_mapped for batched in solve_x_bat
    ]
    return outs, out_dims


linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p,
    mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')
Code example #11
File: custom_transpose.py    Project: wayfeng/jax
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    cts = [
        ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
        for ct, ct_aval in zip(cts, call.out_avals)
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = rule(res_arg, ct_out)
    ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
    check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_abstract_eval(*in_avals, call, **_):
    return call.out_avals


custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
xla.register_translation(custom_transpose_p,
                         xla.lower_fun(custom_transpose_impl,
                                       new_style=True,
                                       multiple_results=True),
                         initial_style=True)
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_impl, multiple_results=True))
Code example #12
File: windowed_reductions.py    Project: cloudhan/jax
                else lax._get_min_identity)
    pads = [(lo, hi, 0) for (lo, hi) in padding]
    operand = lax.pad(operand, identity(dtype), pads)
    padding = [(0, 0) for _ in padding]
  out = _select_and_scatter(
      operand, select, window_dimensions, window_strides, padding, source,
      lax._zero(operand), scatter)
  if expand_padding:
    start_indices = [lo for (lo, hi) in original_padding]
    stop_indices = [lo + d for ((lo, hi), d) in zip(original_padding,
                                                    operand_shape)]
    out = slicing.slice(out, start_indices, stop_indices)
  return out

mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=False),
    multiple_results=False))
# TODO(b/161704903): workaround for XLA/CPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='cpu')
# TODO(b/182390722): workaround for XLA/GPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='gpu')


def _select_and_gather_add_shape_rule(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  if tangents.shape != operand.shape:
Code example #13
File: prng.py    Project: xueeinstein/jax
            mlir.dense_int_elements(range(rank - len(aval.shape),
                                          rank))).result

    return threefry2x32_lowering(
        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
mlir.register_lowering(
    threefry2x32_p,
    mlir.lower_fun(partial(_threefry2x32_lowering, use_rolled_loops=False),
                   multiple_results=True))
mlir.register_lowering(threefry2x32_p,
                       mlir.lower_fun(partial(_threefry2x32_lowering,
                                              use_rolled_loops=True),
                                      multiple_results=True),
                       platform='cpu')

if gpu_prng:
    mlir.register_lowering(threefry2x32_p,
                           partial(_threefry2x32_gpu_lowering,
                                   gpu_prng.cuda_threefry2x32),
                           platform='cuda')
    mlir.register_lowering(threefry2x32_p,
                           partial(_threefry2x32_gpu_lowering,
                                   gpu_prng.rocm_threefry2x32),
                           platform='rocm')
Code example #14
File: prng.py    Project: cloudhan/jax
  def _broadcast(x, aval):
    return mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(aval_out), x,
        mlir.dense_int_elements(range(rank - len(aval.shape), rank))).result
  return threefry2x32_lowering(
          (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
          (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False),
    multiple_results=True))
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True),
    multiple_results=True), platform='cpu')

if gpu_prng:
  mlir.register_lowering(
      threefry2x32_p,
      partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
      platform='cuda')
  mlir.register_lowering(
      threefry2x32_p,
      partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
      platform='rocm')
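
Both prng.py variants above use functools.partial to bake the use_rolled_loops flag into the implementation before handing it to mlir.lower_fun, registering one specialized rule per variant. A minimal sketch of the idiom with a hypothetical flag-specialized impl:

from functools import partial

def _my_impl(x, *, use_rolled_loops):
  ...  # implementation whose lowering depends on the flag

# Bake the flag in with partial; register one lower_fun rule per variant.
mlir.register_lowering(my_p, mlir.lower_fun(
    partial(_my_impl, use_rolled_loops=False), multiple_results=False))
mlir.register_lowering(my_p, mlir.lower_fun(
    partial(_my_impl, use_rolled_loops=True), multiple_results=False),
    platform='cpu')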
Code example #15

def _todense_batching_rule(batched_args, batch_dims, *, tree):
    return jax.vmap(partial(_todense_impl, tree=tree),
                    batch_dims)(*batched_args), 0


ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
xla.register_translation(
    todense_p,
    xla.lower_fun(_todense_impl, multiple_results=False, new_style=True))

mlir.register_lowering(todense_p,
                       mlir.lower_fun(_todense_impl, multiple_results=False))


def empty(shape,
          dtype=None,
          index_dtype='int32',
          sparse_format='bcoo',
          **kwds):
    """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    sparse_format: string specifying the matrix format (e.g. 'bcoo').
    **kwds: additional keywords passed to the format-specific _empty constructor.
Code example #16
File: coo.py    Project: cloudhan/jax
    spinfo : COOInfo object containing matrix metadata

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
  return coo_todense_p.bind(data, row, col, spinfo=spinfo)

@coo_todense_p.def_impl
def _coo_todense_impl(data, row, col, *, spinfo):
  return jnp.zeros(spinfo.shape, data.dtype).at[row, col].add(data)

@coo_todense_p.def_abstract_eval
def _coo_todense_abstract_eval(data, row, col, *, spinfo):
  return core.ShapedArray(spinfo.shape, data.dtype)

_coo_todense_lowering = mlir.lower_fun(
    _coo_todense_impl, multiple_results=False)

def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo):
  data_aval, row_aval, _ = ctx.avals_in
  dtype = data_aval.dtype
  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
    warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)

  if spinfo.rows_sorted:
    shape = spinfo.shape
    transpose = False
  elif spinfo.cols_sorted:
    row, col = col, row
    transpose = True
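
The GPU rule above shows a useful property of mlir.lower_fun: its result is itself an ordinary lowering rule with signature (ctx, *args, **params), so a specialized rule can delegate to it when the fast path does not apply. A sketch of that delegation pattern (hypothetical names throughout):

_default_lowering = mlir.lower_fun(_my_impl, multiple_results=False)

def _my_gpu_lowering(ctx, x):
  dtype = ctx.avals_in[0].dtype
  if not _gpu_kernel_supports(dtype):   # hypothetical predicate
    return _default_lowering(ctx, x)    # fall back to the traced impl
  return _emit_gpu_kernel(ctx, x)       # hypothetical fast path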
Code example #17
File: custom_batching.py    Project: xueeinstein/jax
      *primals, *tangents,
      call=jvp_call, rule=jvp_of_rule_rule,
      in_tree=jvp_in_tree, out_tree=jvp_out_tree)
  assert len(outs) % 2 == 0, len(outs)
  out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
  return out_primals, out_tangents


custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
    custom_vmap_impl, multiple_results=True))


# -- custom vmap applications


def tree_split(mask, tree):
  lhs = tree_map(lambda l, x: x if l else None, mask, tree)
  rhs = tree_map(lambda l, x: None if l else x, mask, tree)
  return lhs, rhs

def tree_merge(mask, lhs_tree, rhs_tree):
  return tree_map(lambda l, x_l, x_r: x_l if l else x_r,
                  mask, lhs_tree, rhs_tree)

def sequential_vmap(f):
Code example #18
    res_arg, lin_arg = tree_unflatten(call_in_tree, args)
    del lin_arg
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    cts = [
        ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
        for ct in cts
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = transpose(res_arg, ct_out)
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
    ct_lin_flat, _ = tree_flatten(tree_broadcast(lin_tree,
                                                 ct_lin,
                                                 is_leaf=lambda x: x is None),
                                  is_leaf=lambda x: x is None)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_lowering(*args, call_jaxpr, **params):
    return core.jaxpr_as_fun(call_jaxpr)(*args)


custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_initial_style_primitive(custom_transpose_p)
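
custom_transpose_lowering above lowers the call primitive by evaluating its stored jaxpr: under lower_fun's tracing, core.jaxpr_as_fun(call_jaxpr)(*args) replays the jaxpr's equations, so the emitted MLIR is simply the inlined body. The same shape reduced to its essentials (hypothetical my_call_p carrying a call_jaxpr parameter):

def _my_call_lowering(*args, call_jaxpr, **params):
  # Replaying the jaxpr under lower_fun's tracer inlines its equations.
  return core.jaxpr_as_fun(call_jaxpr)(*args)

mlir.register_lowering(
    my_call_p, mlir.lower_fun(_my_call_lowering, multiple_results=True))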
Code example #19
                              in_tree=jvp_in_tree,
                              out_tree=jvp_out_tree)
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents


custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p,
                       mlir.lower_fun(custom_vmap_impl, multiple_results=True))

# -- custom vmap applications


def tree_split(mask, tree):
    lhs = tree_map(lambda l, x: x if l else None, mask, tree)
    rhs = tree_map(lambda l, x: None if l else x, mask, tree)
    return lhs, rhs


def tree_merge(mask, lhs_tree, rhs_tree):
    return tree_map(lambda l, x_l, x_r: x_l
                    if l else x_r, mask, lhs_tree, rhs_tree)

Code example #20
    def cond(carry):
        i, _ = carry
        return i < nsteps

    def body(carry):
        i, state = carry
        i_ = nsteps - i - 1 if reverse else i
        next_state = core.eval_jaxpr(discharged_jaxpr, consts, i_, *state)
        return i + 1, next_state

    _, state = lax.while_loop(cond, body, (jnp.int32(0), list(args)))
    return state


mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(partial(xla.apply_primitive, for_p))


def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
    nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
    # We need to find out which `Ref`s have nonzero tangents after running the
    # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
    # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
    # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
    # where the inputs line up with the outputs. We use this discharged jaxpr
    # for the fixed point.
    discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
    for _ in range(len(nonzero_tangents)):
        _, out_nonzero_tangents = ad.jvp_jaxpr(core.ClosedJaxpr(
            discharged_jaxpr, body_consts), [False] + nonzero_tangents,