Пример #1
0
def gather_error_check(error, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Bind ``gather_p`` and attach an out-of-bounds check on the start indices.

  The gather is performed with exactly the parameters passed in. Afterwards,
  every entry of ``start_indices`` is required to lie in
  ``[0, operand_dim - slice_size]`` for the operand dimension it indexes, where
  the dimension is selected through ``dimension_numbers.start_index_map``. The
  resulting predicate is handed to ``assert_func`` together with a message
  naming the current source location.

  Returns:
    ``(out, new_error)``: the gather result and whatever ``assert_func``
    returns for the in-bounds predicate.
  """
  gathered = slicing.gather_p.bind(
      operand, start_indices,
      dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes,
      unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted,
      mode=mode,
      fill_value=fill_value)

  # Mirrors the OOB masking logic in lax._gather_translation_rule: the largest
  # legal start position along an indexed dimension is (dim - slice_size).
  index_map = np.array(dimension_numbers.start_index_map)
  max_start = np.array(operand.shape)[index_map]
  max_start -= np.array(slice_sizes)[index_map]

  # Comparison against max_start broadcasts over the trailing axis of
  # start_indices; presumably that axis is the index vector — TODO confirm.
  not_negative = start_indices >= 0
  not_too_large = start_indices <= max_start
  all_inbounds = jnp.all(not_negative & not_too_large)

  location = source_info_util.summarize(source_info_util.current())
  return gathered, assert_func(error, all_inbounds,
                               f"out-of-bounds indexing at {location}")
Пример #2
0
def saved_residuals(f, *args,
                    **kwargs) -> List[Tuple[core.AbstractValue, str]]:
    """Describe the residuals saved when linearizing ``f(*args, **kwargs)``.

    ``f`` is linearized at the given arguments with ``jax.linearize`` and the
    returned linear function is traced with ``jax.make_jaxpr``. The outputs of
    that jaxpr are treated as the saved residuals. For each one, an
    ``(aval, description)`` pair is produced saying where it comes from:

    * ``'from a literal'`` for literal outputs,
    * ``'from a constant'`` for jaxpr constvars,
    * ``'from <argument info>'`` for (flattened) inputs,
    * ``"named '<name>' from <src>"`` for outputs of a ``name_p`` equation,
    * ``'from <src>'`` for outputs of any other equation,

    where ``<src>`` is the summarized source location of the equation. Results
    are grouped in that order rather than in jaxpr output order. A residual
    variable that appears several times among the jaxpr outputs gets one entry
    per appearance, so there is always exactly one entry per jaxpr output.
    """
    from collections import Counter  # stdlib; counts residual occurrences

    in_leaves, in_tree = tree_flatten((args, kwargs))

    def f_(*leaves):
        # Re-assemble the caller's (args, kwargs) from the flat leaves.
        call_args, call_kwargs = tree_unflatten(in_tree, leaves)
        return f(*call_args, **call_kwargs)

    # The outputs of this jaxpr are what we report as the saved residuals.
    jaxpr = jax.make_jaxpr(lambda *xs: jax.linearize(f_, *xs)[1])(
        *in_leaves).jaxpr
    res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
    # Count occurrences rather than only recording membership: a plain set
    # would collapse a variable returned more than once, leaving fewer results
    # than outputs and tripping the final length assertion.
    res_var_counts = Counter(
        x for x in jaxpr.outvars if not isinstance(x, core.Literal))

    results = []

    def _record(v, src):
        # Append one entry per occurrence of `v` among the jaxpr outputs.
        count = res_var_counts[v]
        if count:
            results.extend([(v.aval, src)] * count)

    for x in res_lits:
        results.append((x.aval, 'from a literal'))

    for v in jaxpr.constvars:
        _record(v, 'from a constant')

    assert len(jaxpr.invars) == len(in_leaves)
    for i, v in enumerate(jaxpr.invars):
        if v in res_var_counts:
            src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
            _record(v, src)

    for eqn in jaxpr.eqns:
        src = source_info_util.summarize(eqn.source_info)
        for v in eqn.outvars:
            if v not in res_var_counts:
                continue
            if eqn.primitive is name_p:
                _record(v, f"named '{eqn.params['name']}' from {src}")
            else:
                _record(v, f'from {src}')

    # Every jaxpr output must have been attributed exactly once.
    assert len(results) == len(jaxpr.outvars)
    return results
Пример #3
0
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
    """Pretty-print an ``xla_call`` equation as a list of doc pieces.

    Only some params are shown:

    * ``call_jaxpr`` and ``name`` always,
    * ``backend`` and ``device`` only when not None,
    * ``donated_invars`` only when at least one entry is truthy.

    Shown params are printed sorted by key. When ``settings.source_info`` is
    set, the summarized source location is attached as the annotation of the
    ``" = "`` separator.

    Returns:
      ``[outvars, " = ", primitive name, params, " " + invars]``.
    """
    def _is_shown(key, val):
        # Decide whether a param is worth printing for this call.
        if key in ('call_jaxpr', 'name'):
            return True
        if key in ('backend', 'device'):
            return val is not None
        if key == 'donated_invars':
            return any(val)
        return False

    shown_params = {key: val for key, val in eqn.params.items()
                    if _is_shown(key, val)}

    if settings.source_info:
        annotation = source_info_util.summarize(eqn.source_info)
    else:
        annotation = None

    # Print order (outvars, params, invars) is kept as-is since these calls
    # share `context`.
    outputs_doc = core.pp_vars(eqn.outvars, context,
                               print_shapes=settings.print_shapes)
    params_doc = core.pp_kv_pairs(sorted(shown_params.items()), context,
                                  settings)
    inputs_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
    return [
        outputs_doc,
        pp.text(" = ", annotation=annotation),
        pp.text(eqn.primitive.name),
        params_doc,
        inputs_doc,
    ]
Пример #4
0
 def key(eqn):
   """Key an equation by the summarized form of its source location."""
   info = eqn.source_info
   return source_info_util.summarize(info)
Пример #5
0
 def key(eqn):
   """Key an equation by ``(primitive name, summarized source location)``."""
   where = source_info_util.summarize(eqn.source_info)
   prim_name = eqn.primitive.name
   return prim_name, where
Пример #6
0
 def fmt_key(var, eqn):
   """Format ``var`` together with its origin.

   Returns ``'<var> <- invar'`` when ``eqn`` is None, and otherwise
   ``'<var> <- <primitive name> @ <summarized source location>'``.
   """
   if eqn is not None:
     where = source_info_util.summarize(eqn.source_info)
     return f'{var} <- {eqn.primitive.name} @ {where}'
   return f'{var} <- invar'
Пример #7
0
def summary() -> str:
    """Return the summarized current source location, as a ``str``."""
    here = source_info_util.current()
    summarized = source_info_util.summarize(here)
    return str(summarized)
Пример #8
0
def nan_error_check(prim, error, *in_vals, **params):
  """Bind ``prim`` and attach a check that its output contains no NaNs.

  Returns:
    ``(out, new_error)``. ``out`` is the primitive's result. ``new_error`` is
    what ``assert_func`` returns for the predicate "no element of ``out`` is
    NaN", with a message naming the primitive and the current source location.
  """
  out = prim.bind(*in_vals, **params)
  has_nan = jnp.any(jnp.isnan(out))
  no_nans = jnp.logical_not(has_nan)
  location = source_info_util.summarize(source_info_util.current())
  message = f"nan generated by primitive {prim.name} at {location}"
  return out, assert_func(error, no_nans, message)