Ejemplo n.º 1
0
def pp_eqn(eqn: core.JaxprEqn) -> PrettyPrint:
  """Pretty-print one jaxpr equation.

  The result is assembled with ``>>`` from, in order: the output variables
  followed by ``=``, a single space, the primitive name, the equation
  parameters (sorted by key), another space, and the space-separated
  string forms of the input variables.

  Args:
    eqn: The equation to render.

  Returns:
    The combined ``PrettyPrint`` object for the whole equation.
  """
  outvars_text = pp_vars(eqn.outvars)
  left_side = pp(f'{outvars_text} =')

  # Right-hand side: primitive name, its parameters, then the operands.
  primitive_doc = pp(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(sorted(eqn.params.items()))
  operands_text = ' '.join(str(var) for var in eqn.invars)
  right_side = primitive_doc >> params_doc >> pp(' ') >> pp(operands_text)

  return left_side >> pp(' ') >> right_side
Ejemplo n.º 2
0
Archivo: xla.py Proyecto: wayfeng/jax
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
                 ) -> List[pp.Doc]:
  """Build the pretty-printer docs for the right-hand side of an equation.

  Only a subset of the equation parameters is printed:
    * ``call_jaxpr`` and ``name`` are always kept;
    * ``backend`` and ``device`` are kept only when they are not ``None``;
    * ``donated_invars`` is kept only when at least one entry is truthy.

  Args:
    eqn: The equation being printed.
    context: Pretty-printing context forwarded to the ``core`` helpers.

  Returns:
    A three-element list: the primitive name, the filtered parameters
    (sorted by key), and a space followed by the input variables.
  """
  printed_params = {}
  for key, value in eqn.params.items():
    if key in ('call_jaxpr', 'name'):
      keep = True
    elif key in ('backend', 'device'):
      keep = value is not None
    elif key == 'donated_invars':
      keep = any(value)
    else:
      keep = False
    if keep:
      printed_params[key] = value

  name_doc = pp.text(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(sorted(printed_params.items()), context)
  invars_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  return [name_doc, params_doc, invars_doc]
Ejemplo n.º 3
0
Archivo: xla.py Proyecto: romanngg/jax
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
    """Build the pretty-printer docs for a whole equation.

    Only a subset of the equation parameters is printed: ``call_jaxpr`` and
    ``name`` always; ``backend`` and ``device`` only when not ``None``;
    ``donated_invars`` only when at least one entry is truthy.

    Args:
        eqn: The equation being printed.
        context: Pretty-printing context forwarded to the ``core`` helpers.
        settings: Printing settings; ``source_info`` controls whether the
            ``" = "`` separator carries a source-info annotation, and
            ``print_shapes`` is forwarded when printing the output variables.

    Returns:
        A list of docs: the output variables, the ``" = "`` separator, the
        primitive name, the filtered parameters (sorted by key), and a space
        followed by the input variables.
    """

    def _is_printed(key, value):
        # Decide whether a single parameter should appear in the output.
        if key in ('call_jaxpr', 'name'):
            return True
        if key in ('backend', 'device'):
            return value is not None
        return key == 'donated_invars' and any(value)

    printed_params = {
        key: value
        for key, value in eqn.params.items()
        if _is_printed(key, value)
    }

    # The annotation is only computed when source info was requested.
    annotation = None
    if settings.source_info:
        annotation = source_info_util.summarize(eqn.source_info)

    outvars_doc = core.pp_vars(
        eqn.outvars, context, print_shapes=settings.print_shapes)

    name_doc = pp.text(eqn.primitive.name)
    params_doc = core.pp_kv_pairs(
        sorted(printed_params.items()), context, settings)
    invars_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)

    equals_doc = pp.text(" = ", annotation=annotation)
    return [outvars_doc, equals_doc, name_doc, params_doc, invars_doc]