def gather_error_check(error, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
    """Bind the gather primitive and attach an out-of-bounds index check.

    The gather is bound with exactly the parameters received. In addition,
    every entry of ``start_indices`` is required to lie in ``[0, upper_bound]``,
    where ``upper_bound`` is the operand extent minus the slice size, taken
    along the operand dimensions listed in
    ``dimension_numbers.start_index_map``.

    Args:
      error: error value forwarded unchanged to ``assert_func``.
      operand: array passed to the gather; only its ``shape`` is read for the
        bounds computation.
      start_indices: indices passed to the gather and compared against the
        computed bounds.
      dimension_numbers: gather dimension numbers; ``start_index_map`` is the
        only field read here.
      slice_sizes: per-operand-dimension slice sizes.
      unique_indices: forwarded to the gather bind.
      indices_are_sorted: forwarded to the gather bind.
      mode: forwarded to the gather bind.
      fill_value: forwarded to the gather bind.

    Returns:
      A pair ``(out, err)``: the result of the gather bind and the value
      returned by ``assert_func`` for the in-bounds predicate.
    """
    out = slicing.gather_p.bind(
        operand, start_indices, dimension_numbers=dimension_numbers,
        slice_sizes=slice_sizes, unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted, mode=mode,
        fill_value=fill_value)

    # compare to OOB masking logic in lax._gather_translation_rule
    dnums = dimension_numbers
    # An empty ``start_index_map`` would otherwise become a float64 array,
    # which NumPy refuses to use as an index array; force an integer dtype.
    index_map = np.array(dnums.start_index_map, dtype=int)
    operand_dims = np.array(operand.shape)

    # Largest valid start index along each indexed operand dimension.
    upper_bound = operand_dims[index_map]
    upper_bound -= np.array(slice_sizes)[index_map]
    all_inbounds = jnp.all(
        (start_indices >= 0) & (start_indices <= upper_bound))

    summary = source_info_util.summarize(source_info_util.current())
    msg = f"out-of-bounds indexing at {summary}"

    return out, assert_func(error, all_inbounds, msg)
def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
    """Describe the origin of each output of the linearized ``f``'s jaxpr.

    ``f`` is called through a wrapper taking the flattened ``(args, kwargs)``
    leaves. The second element of ``jax.linearize`` applied to that wrapper is
    traced with ``jax.make_jaxpr``; each output variable of the resulting
    jaxpr is reported together with a short description of its origin.

    Args:
      f: the function to inspect.
      *args: positional arguments for ``f``.
      **kwargs: keyword arguments for ``f``.

    Returns:
      A list of ``(aval, description)`` pairs, ordered as: literal outputs,
      outputs that are jaxpr constants, outputs that are jaxpr inputs, and
      finally outputs produced by equations (in equation order).
    """
    flat_args, in_tree = tree_flatten((args, kwargs))

    def call_with_leaves(*leaves):
        pos_args, kw_args = tree_unflatten(in_tree, leaves)
        return f(*pos_args, **kw_args)

    def linear_part(*leaves):
        return jax.linearize(call_with_leaves, *leaves)[1]

    jaxpr = jax.make_jaxpr(linear_part)(*flat_args).jaxpr

    # Split the outputs into literals (kept in order) and variables (a set
    # used for membership tests below).
    literal_outs = []
    var_outs = set()
    for out in jaxpr.outvars:
        if isinstance(out, core.Literal):
            literal_outs.append(out)
        else:
            var_outs.add(out)

    results = [(lit.aval, 'from a literal') for lit in literal_outs]
    results.extend(
        (const.aval, 'from a constant')
        for const in jaxpr.constvars
        if const in var_outs
    )

    assert len(jaxpr.invars) == len(flat_args)
    for position, invar in enumerate(jaxpr.invars):
        if invar not in var_outs:
            continue
        origin = f'from {pe.arg_info_pytree(f, in_tree, True, [position])}'
        results.append((invar.aval, origin))

    for eqn in jaxpr.eqns:
        location = source_info_util.summarize(eqn.source_info)
        for outvar in eqn.outvars:
            if outvar not in var_outs:
                continue
            if eqn.primitive is name_p:
                origin = f"named '{eqn.params['name']}' from {location}"
            else:
                origin = f'from {location}'
            results.append((outvar.aval, origin))

    # Every jaxpr output must have been attributed exactly once.
    assert len(results) == len(jaxpr.outvars)
    return results
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
    """Build the pretty-printer docs for one equation, hiding default params.

    Only a subset of ``eqn.params`` is printed: ``call_jaxpr`` and ``name``
    always, ``backend`` and ``device`` only when not ``None``, and
    ``donated_invars`` only when at least one entry is truthy.

    Args:
      eqn: the equation to print.
      context: pretty-printing context passed to the ``core`` helpers.
      settings: pretty-printing settings; ``source_info`` and ``print_shapes``
        are read here.

    Returns:
      A list of docs: the output variables, the `` = `` separator (annotated
      with the source summary when ``settings.source_info`` is set), the
      primitive name, the filtered parameters, and the input variables.
    """
    def is_printed(param_name, param_value):
        if param_name == 'call_jaxpr' or param_name == 'name':
            return True
        if param_name == 'backend' or param_name == 'device':
            return param_value is not None
        if param_name == 'donated_invars':
            return any(param_value)
        return False

    printed_params = {
        param_name: param_value
        for param_name, param_value in eqn.params.items()
        if is_printed(param_name, param_value)
    }

    if settings.source_info:
        annotation = source_info_util.summarize(eqn.source_info)
    else:
        annotation = None

    lhs = core.pp_vars(eqn.outvars, context,
                       print_shapes=settings.print_shapes)
    name_doc = pp.text(eqn.primitive.name)
    params_doc = core.pp_kv_pairs(
        sorted(printed_params.items()), context, settings)
    invars_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
    separator = pp.text(" = ", annotation=annotation)

    return [lhs, separator, name_doc, params_doc, invars_doc]
def key(eqn):
    """Return the summary of ``eqn``'s source info, for use as a key."""
    source_summary = source_info_util.summarize(eqn.source_info)
    return source_summary
def key(eqn):
    """Return ``(primitive name, source-info summary)`` for ``eqn``."""
    location = source_info_util.summarize(eqn.source_info)
    primitive_name = eqn.primitive.name
    return primitive_name, location
def fmt_key(var, eqn):
    """Format a one-line description of where ``var`` comes from.

    Args:
      var: the variable to describe; interpolated into the result as-is.
      eqn: the equation associated with ``var``, or ``None``, in which case
        ``var`` is labelled as an invar.

    Returns:
      ``'<var> <- invar'`` when ``eqn`` is ``None``; otherwise
      ``'<var> <- <primitive name> @ <source summary>'``.
    """
    if eqn is not None:
        location = source_info_util.summarize(eqn.source_info)
        return f'{var} <- {eqn.primitive.name} @ {location}'
    return f'{var} <- invar'
def summary() -> str:
    """Return the summary of the current source info as a string."""
    current_info = source_info_util.current()
    summarized = source_info_util.summarize(current_info)
    return str(summarized)
def nan_error_check(prim, error, *in_vals, **params):
    """Bind ``prim`` and attach a check that its output contains no NaN.

    Args:
      prim: the primitive to bind; its ``name`` appears in the error message.
      error: error value forwarded unchanged to ``assert_func``.
      *in_vals: positional operands for ``prim.bind``.
      **params: keyword parameters for ``prim.bind``.

    Returns:
      A pair ``(out, err)``: the result of ``prim.bind`` and the value
      returned by ``assert_func`` for the no-NaN predicate.
    """
    out = prim.bind(*in_vals, **params)

    has_nan = jnp.any(jnp.isnan(out))
    no_nans = jnp.logical_not(has_nan)

    location = source_info_util.summarize(source_info_util.current())
    msg = f"nan generated by primitive {prim.name} at {location}"
    return out, assert_func(error, no_nans, msg)