def pp_eqn(eqn: core.JaxprEqn) -> PrettyPrint:
  """Pretty-print one jaxpr equation as ``<outvars> = <primitive><params> <invars>``.

  The output variables are rendered with ``pp_vars``. The params are rendered by
  ``core.pp_kv_pairs`` over the params sorted by key. The input variables are
  stringified with ``str`` and joined by single spaces. The pieces are combined
  with the ``>>`` operator of the ``pp`` documents.
  """
  out_text = pp_vars(eqn.outvars)
  lhs_doc = pp(f'{out_text} =')

  # Build the right-hand side left to right. This keeps the same
  # left-associative grouping as a single chained `a >> b >> c >> d` expression.
  rhs_doc = pp(eqn.primitive.name)
  rhs_doc = rhs_doc >> core.pp_kv_pairs(sorted(eqn.params.items()))
  rhs_doc = rhs_doc >> pp(' ')
  rhs_doc = rhs_doc >> pp(' '.join(str(var) for var in eqn.invars))

  return lhs_doc >> pp(' ') >> rhs_doc
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
                 ) -> List[pp.Doc]:
  """Pretty-print the right-hand side of an XLA call equation.

  Only a subset of ``eqn.params`` is shown:

  * ``call_jaxpr`` and ``name`` are always shown;
  * ``backend`` and ``device`` are shown only when they are not ``None``;
  * ``donated_invars`` is shown only when at least one entry is truthy.

  Returns three docs: the primitive name, the selected params (sorted by key),
  and the input variables prefixed with a single space.
  """
  # NOTE(review): a later definition of `_pp_xla_call` in this file, which takes
  # an extra `settings` argument, rebinds this name. Confirm whether this
  # two-argument version is still needed.
  def _should_print(key, value):
    if key in ('call_jaxpr', 'name'):
      return True
    if key in ('backend', 'device'):
      return value is not None
    if key == 'donated_invars':
      # Hide the param when nothing is donated.
      return any(value)
    return False

  shown = {key: value for key, value in eqn.params.items()
           if _should_print(key, value)}
  name_doc = pp.text(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(sorted(shown.items()), context)
  args_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  return [name_doc, params_doc, args_doc]
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
  """Pretty-print a full XLA call equation as ``lhs = name[params] invars``.

  Params are filtered before printing:

  * ``call_jaxpr`` and ``name`` are always kept;
  * ``backend`` and ``device`` are kept only when they are not ``None``;
  * ``donated_invars`` is kept only when at least one entry is truthy.

  When ``settings.source_info`` is set, a summary of ``eqn.source_info`` is
  attached as the annotation of the ``" = "`` separator. Output variables are
  printed with shapes according to ``settings.print_shapes``. Input variables
  are printed via ``core.pp_vars`` without the ``print_shapes`` argument.

  Returns the list ``[lhs, " = ", name, params, " " + invars]`` of docs.
  """
  printed_params = {}
  for key, value in eqn.params.items():
    if key in ('call_jaxpr', 'name'):
      keep = True
    elif key in ('backend', 'device'):
      keep = value is not None
    elif key == 'donated_invars':
      keep = any(value)
    else:
      keep = False
    if keep:
      printed_params[key] = value

  annotation = None
  if settings.source_info:
    annotation = source_info_util.summarize(eqn.source_info)

  lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  name_doc = pp.text(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(sorted(printed_params.items()), context, settings)
  args_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  return [lhs, pp.text(" = ", annotation=annotation), name_doc, params_doc, args_doc]