def _conv_general_dilated_lower(
    ctx, lhs, rhs, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, precision, preferred_element_type,
    expand_complex_convolutions=False, **unused_kwargs):
  """Lowering rule that emits an ``mhlo.ConvOp`` for a general dilated conv.

  When ``expand_complex_convolutions`` is true and the lhs dtype is complex,
  the op is not emitted directly. Instead the rule lowers ``_complex_mul``
  applied to a ``conv_general_dilated`` partial (carrying the same
  convolution parameters) through ``mlir.lower_fun``.

  Args:
    ctx: lowering context; ``avals_in`` must unpack to exactly two avals and
      ``avals_out`` to exactly one.
    lhs, rhs: operand values, forwarded unchanged.
    window_strides, padding, lhs_dilation, rhs_dilation: converted with
      ``mlir.dense_int_elements`` and attached to the op.
    dimension_numbers: a ``ConvDimensionNumbers`` that unpacks into
      ``(lhs_spec, rhs_spec, out_spec)``.
    feature_group_count, batch_group_count: attached via ``mlir.i64_attr``.
    precision: passed to ``lax.precision_attr``.
    preferred_element_type: only consulted on the complex-expansion path,
      where a non-None value must be complex and is replaced by
      ``_real_dtype(preferred_element_type)``.
    expand_complex_convolutions: enables the complex-expansion path.
    **unused_kwargs: accepted and ignored.

  Returns:
    A one-element list holding the ``ConvOp`` result, or, on the
    complex-expansion path, whatever the ``lower_fun``-built rule returns.
  """
  lhs_aval, _ = ctx.avals_in
  (aval_out,) = ctx.avals_out
  assert isinstance(dimension_numbers, ConvDimensionNumbers)

  operand_dtype = lhs_aval.dtype
  wants_expansion = (expand_complex_convolutions and
                     np.issubdtype(operand_dtype, np.complexfloating))
  if wants_expansion:
    element_type = preferred_element_type
    if element_type is not None:
      # Convert complex dtype to types used for real and imaginary parts
      assert np.issubdtype(element_type, np.complexfloating)
      element_type = _real_dtype(element_type)
    conv_params = dict(
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        precision=precision,
        preferred_element_type=element_type)
    real_conv = partial(conv_general_dilated, **conv_params)
    expanded_rule = mlir.lower_fun(partial(_complex_mul, real_conv),
                                   multiple_results=False)
    return expanded_rule(ctx, lhs, rhs)

  # Direct path: translate the dimension specs into MHLO dimension numbers.
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  dnums = mhlo.ConvDimensionNumbers.get(
      input_batch_dimension=lhs_spec[0],
      input_feature_dimension=lhs_spec[1],
      input_spatial_dimensions=list(lhs_spec[2:]),
      kernel_output_feature_dimension=rhs_spec[0],
      kernel_input_feature_dimension=rhs_spec[1],
      kernel_spatial_dimensions=list(rhs_spec[2:]),
      output_batch_dimension=out_spec[0],
      output_feature_dimension=out_spec[1],
      output_spatial_dimensions=list(out_spec[2:]))
  # One "False" reversal flag per spatial dimension of the kernel spec.
  spatial_rank = len(rhs_spec) - 2
  no_reversal = mlir.dense_bool_elements([False] * spatial_rank)
  conv_op = mhlo.ConvOp(
      mlir.aval_to_ir_type(aval_out),
      lhs,
      rhs,
      dimension_numbers=dnums,
      feature_group_count=mlir.i64_attr(feature_group_count),
      batch_group_count=mlir.i64_attr(batch_group_count),
      window_strides=mlir.dense_int_elements(window_strides),
      padding=mlir.dense_int_elements(padding),
      lhs_dilation=mlir.dense_int_elements(lhs_dilation),
      rhs_dilation=mlir.dense_int_elements(rhs_dilation),
      window_reversal=no_reversal,
      precision_config=lax.precision_attr(precision))
  return [conv_op.result]
def _mlir(c, *mlir_args, **params):
  """Build a lowering rule from ``self.impl`` and apply it immediately.

  ``self`` is not a parameter of this function; it is read from the
  enclosing scope. The rule is created with ``multiple_results=True`` and
  called with ``c``, the positional ``mlir_args`` and the keyword ``params``.
  """
  return mlir.lower_fun(self.impl, multiple_results=True)(
      c, *mlir_args, **params)
@identity_p.def_impl
def _identity_impl(mat):
  """Implementation of ``identity_p``: hand the argument straight back."""
  return mat


@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
  """Abstract eval of ``identity_p``: a sparse aval copying ``mat``'s metadata."""
  shape, dtype = mat.shape, mat.dtype
  return AbstractSparseArray(shape, dtype, mat.index_dtype, mat.nnz)


# The same implementation backs both the XLA translation and MLIR lowering.
xla.register_translation(
    identity_p,
    xla.lower_fun(_identity_impl, new_style=True, multiple_results=False))
mlir.register_lowering(
    identity_p,
    mlir.lower_fun(_identity_impl, multiple_results=False))


# Primitive that produces two results.
split_p = core.Primitive('split')
split_p.multiple_results = True


def split(x):
  """Bind ``split_p`` to ``x``."""
  return split_p.bind(x)


@split_p.def_impl
def _split_impl(mat):
  """Implementation of ``split_p``: the input, returned twice."""
  return (mat, mat)


@split_p.def_abstract_eval
def _split_abstract_eval(mat):
  """Abstract eval of ``split_p``: one sparse aval, returned twice."""
  out_aval = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype,
                                 mat.nnz)
  return out_aval, out_aval
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
padding = [(0, 0) for _ in padding] out = _select_and_scatter(operand, select, window_dimensions, window_strides, padding, source, lax._zero(operand), scatter) if expand_padding: start_indices = [lo for (lo, hi) in original_padding] stop_indices = [ lo + d for ((lo, hi), d) in zip(original_padding, operand_shape) ] out = slicing.slice(out, start_indices, stop_indices) return out

# Default lowering for select_and_scatter_add_p: the impl with
# expand_padding=False.
mlir.register_lowering(
    select_and_scatter_add_p,
    mlir.lower_fun(partial(_select_and_scatter_add_impl,
                           expand_padding=False),
                   multiple_results=False))
# Platform-specific registrations: 'cpu' and 'gpu' both use the impl with
# expand_padding=True.
mlir.register_lowering(select_and_scatter_add_p,
                       mlir.lower_fun(partial(_select_and_scatter_add_impl,
                                              expand_padding=True),
                                      multiple_results=False),
                       platform='cpu')
mlir.register_lowering(select_and_scatter_add_p,
                       mlir.lower_fun(partial(_select_and_scatter_add_impl,
                                              expand_padding=True),
                                      multiple_results=False),
                       platform='gpu')

def _select_and_gather_add_shape_rule(tangents, operand, *, select_prim,
                                      window_dimensions, window_strides,
                                      padding, base_dilation, window_dilation):
  # NOTE(review): the body of this rule lies outside this view.
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
elif is_gpu_platform: translation_rule = _remat_translation_using_while else: translation_rule = _remat_translation_using_cond else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr( jaxpr, (), *args) return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)

# For each remat primitive register a default lowering built from remat_impl,
# plus a "gpu" lowering that fixes is_gpu_platform=True.
for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p):  # type: ignore
  mlir.register_lowering(remat_primitive,
                         mlir.lower_fun(remat_impl, multiple_results=True))
  mlir.register_lowering(remat_primitive,
                         mlir.lower_fun(partial(remat_impl,
                                                is_gpu_platform=True),
                                        multiple_results=True),
                         platform="gpu")

def _optimization_barrier_abstract_eval(*args):
  """Abstract eval: the outputs are the inputs, returned as the args tuple."""
  return args

def _optimization_barrier_lowering_rule(ctx, *args):
  # NOTE(review): only the opening of this rule is visible; the rest lies
  # outside this view.
  # One entry of IR types per input aval, then flattened into a single list.
  barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_barrier_types = util.flatten(barrier_types)
# Primitive that produces a single result.
identity_p = core.Primitive('identity')


def identity(x):
  """Bind ``identity_p`` to ``x``."""
  return identity_p.bind(x)


@identity_p.def_impl
def _identity_impl(mat):
  """Implementation of ``identity_p``: hand the argument straight back."""
  return mat


@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
  """Abstract eval of ``identity_p``: a sparse aval copying ``mat``'s metadata."""
  shape, dtype = mat.shape, mat.dtype
  return AbstractSparseArray(shape, dtype, mat.index_dtype, mat.nnz)


mlir.register_lowering(
    identity_p,
    mlir.lower_fun(_identity_impl, multiple_results=False))


# Primitive that produces two results.
split_p = core.Primitive('split')
split_p.multiple_results = True


def split(x):
  """Bind ``split_p`` to ``x``."""
  return split_p.bind(x)


@split_p.def_impl
def _split_impl(mat):
  """Implementation of ``split_p``: the input, returned twice."""
  return (mat, mat)


@split_p.def_abstract_eval
def _split_abstract_eval(mat):
  """Abstract eval of ``split_p``: one sparse aval, returned twice."""
  out_aval = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype,
                                 mat.nnz)
  return out_aval, out_aval
del params # other params ignored because we're just executing the primal fun return core.jaxpr_as_fun(fun_jaxpr)(*args) def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params): del args, params if fun_jaxpr.effects: raise NotImplementedError('Effects not supported in `custom_jvp`.') return fun_jaxpr.out_avals, fun_jaxpr.effects custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr') custom_jvp_call_jaxpr_p.multiple_results = True custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl) custom_jvp_call_jaxpr_p.def_effectful_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval) CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p mlir.register_lowering(custom_jvp_call_jaxpr_p, mlir.lower_fun( _custom_jvp_call_jaxpr_impl, multiple_results=True)) def _custom_jvp_call_jaxpr_jvp( primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], num_consts: int): _, args = split_list(primals, [num_consts]) consts_dot, args_dot = split_list(tangents, [num_consts]) if any(type(t) is not Zero for t in consts_dot): raise ad.CustomJVPException() jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk() # consts can be tracers! args_dot = map(ad.instantiate_zeros, args_dot) # Cast float0 to zeros with the primal dtype because custom jvp rules don't # currently handle float0s args_dot = map(ad.replace_float0s, args, args_dot)
def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
  """Abstract eval: the output avals are those of ``fun_jaxpr``."""
  del args, params
  return fun_jaxpr.out_avals

# Primitive definition: multiple results, impl, abstract eval.
custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p

# The MLIR lowering is built from the same impl function.
mlir.register_lowering(
    custom_jvp_call_jaxpr_p,
    mlir.lower_fun(_custom_jvp_call_jaxpr_impl, multiple_results=True))

def _custom_jvp_call_jaxpr_jvp(primals, tangents, *,
                               fun_jaxpr: core.ClosedJaxpr,
                               jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
                               num_consts: int):
  # NOTE(review): only the opening of this rule is visible; the rest lies
  # outside this view.
  # The first num_consts primals/tangents are split off from the rest.
  _, args = split_list(primals, [num_consts])
  consts_dot, args_dot = split_list(tangents, [num_consts])
  # Any tangent among the leading num_consts that is not a Zero is an error.
  if any(type(t) is not Zero for t in consts_dot):
    raise ad.CustomJVPException()
  jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()  # consts can be tracers!
  args_dot = map(ad.instantiate_zeros, args_dot)
  # Cast float0 to zeros with the primal dtype because custom jvp rules don't
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
elif isinstance(obj, BCOO): _, indices = bufs return bcoo_extract(indices, ct), indices elif isinstance(obj, COO): _, row, col = bufs return _coo_extract(row, col, ct), row, col else: raise NotImplementedError(f"todense_transpose for {type(obj)}")

def _todense_batching_rule(batched_args, batch_dims, *, tree):
  """Batching rule: vmap ``_todense_impl`` over ``batch_dims``.

  Returns the vmapped result together with an output batch dimension of 0.
  """
  return jax.vmap(partial(_todense_impl, tree=tree), batch_dims)(*batched_args), 0

# Rule registrations for todense_p: JVP, transpose, batching, MLIR lowering.
ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
mlir.register_lowering(todense_p, mlir.lower_fun(
    _todense_impl, multiple_results=False))

def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
  """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    sparse_format: string specifying the matrix format (e.g. ['bcoo']).
    **kwds: additional keywords passed to the format-specific _empty constructor.

  Returns:
    mat: empty sparse matrix.
  """
  # Maps format names to the corresponding sparse array classes.
  formats = {'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
  # NOTE(review): the remainder of this function lies outside this view.
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
] # Broadcast out b if necessary new_b = [ batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) ] outs = linear_solve_p.bind(*(new_params + new_b), const_lengths=const_lengths, jaxprs=batched_jaxprs) out_dims = [ 0 if batched else batching.not_mapped for batched in solve_x_bat ] return outs, out_dims

# Primitive definition: multiple results, impl, abstract eval.
linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
# Rule registrations: JVP, initial-style marker, MLIR lowering (built from the
# impl function), transpose and axis-aware batching.
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p,
    mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
# The custom partial-eval entry is the "not implemented" rule, tagged with
# the name 'linear_solve'.
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg)) cts = [ ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct for ct, ct_aval in zip(cts, call.out_avals) ] ct_out = tree_unflatten(out_tree, cts) ct_lin = rule(res_arg, ct_out) ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin) check_transpose_rule_trees(rule, lin_tree, ct_lin_tree) return [None] * len(tree_leaves(res_arg)) + ct_lin_flat

def custom_transpose_abstract_eval(*in_avals, call, **_):
  """Abstract eval: the output avals are ``call.out_avals``; inputs are ignored."""
  return call.out_avals

# Primitive definition: multiple results, impl, abstract eval, transpose rule.
custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
# Both the XLA translation and the MLIR lowering are built from the impl.
xla.register_translation(custom_transpose_p,
                         xla.lower_fun(custom_transpose_impl, new_style=True,
                                       multiple_results=True),
                         initial_style=True)
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_impl, multiple_results=True))
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
else lax._get_min_identity) pads = [(lo, hi, 0) for (lo, hi) in padding] operand = lax.pad(operand, identity(dtype), pads) padding = [(0, 0) for _ in padding] out = _select_and_scatter( operand, select, window_dimensions, window_strides, padding, source, lax._zero(operand), scatter) if expand_padding: start_indices = [lo for (lo, hi) in original_padding] stop_indices = [lo + d for ((lo, hi), d) in zip(original_padding, operand_shape)] out = slicing.slice(out, start_indices, stop_indices) return out

# Default lowering: the impl with expand_padding=False.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=False),
    multiple_results=False))
# TODO(b/161704903): workaround for XLA/CPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='cpu')
# TODO(b/182390722): workaround for XLA/GPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='gpu')

def _select_and_gather_add_shape_rule(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  # NOTE(review): only the opening of this rule is visible; the branch body
  # lies outside this view.
  if tangents.shape != operand.shape:
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
mlir.dense_int_elements(range(rank - len(aval.shape), rank))).result return threefry2x32_lowering( (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))

# Primitive definition: multiple results, impl via xla.apply_primitive,
# abstract eval, broadcasting batching rule.
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Default lowering uses use_rolled_loops=False; the 'cpu' registration uses
# use_rolled_loops=True.
mlir.register_lowering(
    threefry2x32_p,
    mlir.lower_fun(partial(_threefry2x32_lowering, use_rolled_loops=False),
                   multiple_results=True))
mlir.register_lowering(threefry2x32_p,
                       mlir.lower_fun(partial(_threefry2x32_lowering,
                                              use_rolled_loops=True),
                                      multiple_results=True),
                       platform='cpu')
# Only when gpu_prng is truthy: dedicated 'cuda' and 'rocm' lowerings that
# wrap the corresponding gpu_prng kernels.
if gpu_prng:
  mlir.register_lowering(threefry2x32_p,
                         partial(_threefry2x32_gpu_lowering,
                                 gpu_prng.cuda_threefry2x32),
                         platform='cuda')
  mlir.register_lowering(threefry2x32_p,
                         partial(_threefry2x32_gpu_lowering,
                                 gpu_prng.rocm_threefry2x32),
                         platform='rocm')
# NOTE(review): the line below is a nested helper plus the closing `return`
# of an enclosing definition whose header lies outside this view; it is left
# exactly as found.
def _broadcast(x, aval): return mhlo.BroadcastInDimOp( mlir.aval_to_ir_type(aval_out), x, mlir.dense_int_elements(range(rank - len(aval.shape), rank))).result return threefry2x32_lowering( (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))

# Primitive definition: multiple results, impl via xla.apply_primitive,
# abstract eval, broadcasting batching rule.
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Default lowering uses use_rolled_loops=False; the 'cpu' registration uses
# use_rolled_loops=True.
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False),
    multiple_results=True))
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True),
    multiple_results=True), platform='cpu')

# Only when gpu_prng is truthy: dedicated 'cuda' and 'rocm' lowerings that
# wrap the corresponding gpu_prng kernels.
if gpu_prng:
  mlir.register_lowering(
      threefry2x32_p,
      partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
      platform='cuda')
  mlir.register_lowering(
      threefry2x32_p,
      partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
      platform='rocm')
def _todense_batching_rule(batched_args, batch_dims, *, tree):
  """Batching rule: vmap ``_todense_impl`` over ``batch_dims``.

  Returns the vmapped result together with an output batch dimension of 0.
  """
  return jax.vmap(partial(_todense_impl, tree=tree), batch_dims)(*batched_args), 0

# Rule registrations for todense_p: JVP, transpose, batching, then the XLA
# translation and MLIR lowering, both built from _todense_impl.
ad.primitive_jvps[todense_p] = _todense_jvp
ad.primitive_transposes[todense_p] = _todense_transpose
batching.primitive_batchers[todense_p] = _todense_batching_rule
xla.register_translation(
    todense_p,
    xla.lower_fun(_todense_impl, multiple_results=False, new_style=True))
mlir.register_lowering(todense_p,
                       mlir.lower_fun(_todense_impl, multiple_results=False))

def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
  """Create an empty sparse array.

  Args:
    shape: sequence of integers giving the array shape.
    dtype: (optional) dtype of the array.
    index_dtype: (optional) dtype of the index arrays.
    sparse_format: string specifying the matrix format (e.g. ['bcoo']).
    **kwds: additional keywords passed to the format-specific _empty constructor.
spinfo : COOInfo object containing matrix metadata Returns: mat : array with specified shape and dtype matching ``data`` """ return coo_todense_p.bind(data, row, col, spinfo=spinfo)
# NOTE(review): the line above is the tail of a definition that starts
# outside this view; it is left exactly as found.

@coo_todense_p.def_impl
def _coo_todense_impl(data, row, col, *, spinfo):
  """Add ``data`` at positions ``[row, col]`` of a zero array of ``spinfo.shape``."""
  return jnp.zeros(spinfo.shape, data.dtype).at[row, col].add(data)

@coo_todense_p.def_abstract_eval
def _coo_todense_abstract_eval(data, row, col, *, spinfo):
  """Abstract eval: a ShapedArray of ``spinfo.shape`` with ``data``'s dtype."""
  return core.ShapedArray(spinfo.shape, data.dtype)

# Generic lowering built from the impl; also used as the fallback below.
_coo_todense_lowering = mlir.lower_fun(
    _coo_todense_impl, multiple_results=False)

def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo):
  # NOTE(review): only the opening of this rule is visible; the rest lies
  # outside this view.
  data_aval, row_aval, _ = ctx.avals_in
  dtype = data_aval.dtype
  # Dtypes that are neither floating nor complex-floating: warn and use the
  # generic lowering.
  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
    warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)

  if spinfo.rows_sorted:
    shape = spinfo.shape
    transpose = False
  elif spinfo.cols_sorted:
    # Swap the roles of row and col and record that a transpose is needed.
    row, col = col, row
    transpose = True
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
*primals, *tangents, call=jvp_call, rule=jvp_of_rule_rule, in_tree=jvp_in_tree, out_tree=jvp_out_tree) assert len(outs) % 2 == 0, len(outs) out_primals, out_tangents = util.split_list(outs, [len(outs) // 2]) return out_primals, out_tangents

# Primitive definition: multiple results, impl, abstract eval.
custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
# Rule registrations: batching, JVP, initial-style marker, and an MLIR
# lowering built from the impl function.
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
    custom_vmap_impl, multiple_results=True))


# -- custom vmap applications

def tree_split(mask, tree):
  """Split ``tree`` in two according to ``mask``.

  Returns ``(lhs, rhs)``: ``lhs`` keeps the entries whose mask value is
  truthy and has None elsewhere; ``rhs`` is the complement.
  """
  lhs = tree_map(lambda l, x: x if l else None, mask, tree)
  rhs = tree_map(lambda l, x: None if l else x, mask, tree)
  return lhs, rhs

def tree_merge(mask, lhs_tree, rhs_tree):
  """Pick from ``lhs_tree`` where ``mask`` is truthy, else from ``rhs_tree``."""
  return tree_map(lambda l, x_l, x_r: x_l if l else x_r,
                  mask, lhs_tree, rhs_tree)

def sequential_vmap(f):
  # NOTE(review): the body of this function lies outside this view.
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
res_arg, lin_arg = tree_unflatten(call_in_tree, args) del lin_arg assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg)) cts = [ ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct for ct in cts ] ct_out = tree_unflatten(out_tree, cts) ct_lin = transpose(res_arg, ct_out) check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin)) ct_lin_flat, _ = tree_flatten(tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None), is_leaf=lambda x: x is None) return [None] * len(tree_leaves(res_arg)) + ct_lin_flat

def custom_transpose_lowering(*args, call_jaxpr, **params):
  """Evaluate ``call_jaxpr`` on ``args``; the remaining params are unused."""
  return core.jaxpr_as_fun(call_jaxpr)(*args)

# Primitive definition and rule registrations: typecheck, transpose, MLIR
# lowering (built from custom_transpose_lowering), initial-style marker.
custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_initial_style_primitive(custom_transpose_p)
# NOTE(review): the line below is the tail of a definition that starts
# outside this view; it is left exactly as found.
in_tree=jvp_in_tree, out_tree=jvp_out_tree) assert len(outs) % 2 == 0, len(outs) out_primals, out_tangents = util.split_list(outs, [len(outs) // 2]) return out_primals, out_tangents

# Primitive definition: multiple results, impl, abstract eval.
custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
# Rule registrations: batching, JVP, initial-style marker, and an MLIR
# lowering built from the impl function.
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p,
                       mlir.lower_fun(custom_vmap_impl, multiple_results=True))


# -- custom vmap applications

def tree_split(mask, tree):
  """Split ``tree`` in two according to ``mask``.

  Returns ``(lhs, rhs)``: ``lhs`` keeps the entries whose mask value is
  truthy and has None elsewhere; ``rhs`` is the complement.
  """
  lhs = tree_map(lambda l, x: x if l else None, mask, tree)
  rhs = tree_map(lambda l, x: None if l else x, mask, tree)
  return lhs, rhs

def tree_merge(mask, lhs_tree, rhs_tree):
  """Pick from ``lhs_tree`` where ``mask`` is truthy, else from ``rhs_tree``."""
  return tree_map(lambda l, x_l, x_r: x_l if l else x_r,
                  mask, lhs_tree, rhs_tree)
# NOTE(review): the line below holds nested helpers and the closing `return`
# of an enclosing definition whose header lies outside this view; it is left
# exactly as found.
def cond(carry): i, _ = carry return i < nsteps def body(carry): i, state = carry i_ = nsteps - i - 1 if reverse else i next_state = core.eval_jaxpr(discharged_jaxpr, consts, i_, *state) return i + 1, next_state _, state = lax.while_loop(cond, body, (jnp.int32(0), list(args))) return state

# The MLIR lowering is built from _for_impl; the primitive's impl goes through
# xla.apply_primitive.
mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(partial(xla.apply_primitive, for_p))

def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
  # NOTE(review): only the opening of this rule is visible; the rest lies
  # outside this view.
  # True for every tangent that is not a symbolic Zero.
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  # We need to find out which `Ref`s have nonzero tangents after running the
  # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
  # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
  # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
  # where the inputs line up with the outputs. We use this discharged jaxpr
  # for the fixed point.
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  # The loop runs at most len(nonzero_tangents) times.
  for _ in range(len(nonzero_tangents)):
    _, out_nonzero_tangents = ad.jvp_jaxpr(core.ClosedJaxpr(
        discharged_jaxpr, body_consts), [False] + nonzero_tangents,