Example #1
0
def _swap_pp_rule(eqn, context, settings):
    """Pretty-print rule for a `swap` equation.

    The equation has exactly one output variable and its inputs are unpacked
    as ``(x, v, *idx)``: the ref ``x``, the value ``v`` written into it, and
    the index variables.

    * If the output is a ``core.DropVar`` (``_ = swap x v i``, i.e. a plain
      store), it is printed as ``x[i] <- v``.
    * Otherwise (``y:T = swap x v i``) it is printed as
      ``y:T, x[i] <- x[i], v``.

    Args:
      eqn: the equation to print.
      context: printing context forwarded to ``core.pp_var``/``core.pp_vars``.
      settings: printing settings; only ``settings.print_shapes`` is read.

    Returns:
      A list of pretty-printer docs for the equation.
    """
    y, = eqn.outvars
    x, v, *idx_vars = eqn.invars
    idx = ','.join(core.pp_var(i, context) for i in idx_vars)
    # Build `x[i]` once and style it with `pp_ref` in both branches, so a ref
    # access looks the same whether or not the swapped-out value is used
    # (and the same as in the `get` rule). Previously only the store-only
    # branch applied `pp_ref`.
    x_i = pp_ref(
        pp.concat([
            pp.text(core.pp_var(x, context)),
            pp.text('['),
            pp.text(idx),
            pp.text(']')
        ]))
    if type(y) is core.DropVar:
        # Ignored return value: this is just a store, `x[i] <- v`.
        return [
            x_i,
            pp.text(' <- '),
            pp.text(core.pp_var(v, context))
        ]
    # `y:T, x[i] <- x[i], v`. `y` is printed before `v` to keep the same
    # variable-naming order as before.
    y_doc = core.pp_vars([y], context, print_shapes=settings.print_shapes)
    return [
        y_doc,
        pp.text(', '), x_i,
        pp.text(' <- '), x_i,
        pp.text(', '),
        pp.text(core.pp_var(v, context))
    ]
Example #2
0
File: xla.py Project: wayfeng/jax
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
                 ) -> List[pp.Doc]:
  """Pretty-prints an XLA call equation as `<primitive> [params] <invars>`.

  Only some params are shown, sorted by key: `call_jaxpr` and `name` always;
  `backend` and `device` only when they are not None; `donated_invars` only
  when `any(...)` of it is true.
  """
  def keep(key, value):
    # Always shown.
    if key in ('call_jaxpr', 'name'):
      return True
    # Shown only when set.
    if key in ('backend', 'device'):
      return value is not None
    # Shown only when something is actually donated.
    if key == 'donated_invars':
      return any(value)
    return False

  shown = {key: value for key, value in eqn.params.items() if keep(key, value)}
  head = pp.text(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(sorted(shown.items()), context)
  args_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  return [head, params_doc, args_doc]
Example #3
0
File: xla.py Project: romanngg/jax
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
    """Pretty-prints an XLA call equation as `outs = <primitive> [params] ins`.

    Only some params are shown, sorted by key: `call_jaxpr` and `name` always;
    `backend` and `device` only when they are not None; `donated_invars` only
    when `any(...)` of it is true. When `settings.source_info` is truthy, the
    " = " separator carries a summary of the equation's source info as its
    annotation.
    """
    always_shown = ('call_jaxpr', 'name')
    shown_if_set = ('backend', 'device')

    def _wanted(name, value):
        if name in always_shown:
            return True
        if name in shown_if_set:
            return value is not None
        if name == 'donated_invars':
            return any(value)
        return False

    params = {
        name: value
        for name, value in eqn.params.items()
        if _wanted(name, value)
    }

    annotation = None
    if settings.source_info:
        annotation = source_info_util.summarize(eqn.source_info)

    # Output variables first, then the right-hand side, matching the order
    # in which the variables are printed.
    outs_doc = core.pp_vars(eqn.outvars, context,
                            print_shapes=settings.print_shapes)
    prim_doc = pp.text(eqn.primitive.name)
    params_doc = core.pp_kv_pairs(sorted(params.items()), context, settings)
    ins_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
    eq_doc = pp.text(" = ", annotation=annotation)
    return [outs_doc, eq_doc, prim_doc, params_doc, ins_doc]
Example #4
0
def _get_pp_rule(eqn, context, settings):
    """Pretty-print rule for a `get` equation.

    Prints `a = get x i` as `a <- x[i]`. The output variable is printed first
    (with shapes if `settings.print_shapes`), then an arrow, then the ref `x`
    indexed by the comma-joined index variables. The indexed ref is styled
    with `pp_ref`.

    Args:
      eqn: the equation to print; one output variable, inputs `(x, *idx)`.
      context: printing context forwarded to `core.pp_var`/`core.pp_vars`.
      settings: printing settings; only `settings.print_shapes` is read.

    Returns:
      A list of pretty-printer docs for the equation.
    """
    y, = eqn.outvars
    x, *idx_vars = eqn.invars
    idx = ','.join(core.pp_var(i, context) for i in idx_vars)
    lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
    ref_doc = pp_ref(
        pp.concat([
            pp.text(core.pp_var(x, context)),
            pp.text('['),
            pp.text(idx),
            pp.text(']')
        ]))
    return [lhs, pp.text(' <- '), ref_doc]