def _swap_pp_rule(eqn, context, settings):
  """Build the pretty-printer docs for a `swap` equation.

  The equation's inputs are unpacked as the reference `x`, the new value `v`
  and any remaining index variables; its single output is the swapped-out
  value `y`.

  * If the output is a `core.DropVar` (the result is ignored, i.e. a plain
    set), `_ = swap x v i` is printed as `x[i] <- v`.
  * Otherwise `y:T = swap x v i` is printed as `y:T, x[i] <- x[i], v`.

  Returns:
    A list of docs to be laid out in order.
  """
  (out_var,) = eqn.outvars
  ref_var, new_val, *index_vars = eqn.invars
  index_text = ','.join([core.pp_var(ix, context) for ix in index_vars])

  # `x[i]` is needed by both output forms, so build it once up front.
  indexed_ref = pp.concat([
      pp.text(core.pp_var(ref_var, context)),
      pp.text('['),
      pp.text(index_text),
      pp.text(']'),
  ])

  if type(out_var) is not core.DropVar:
    # Value-returning swap: `y:T, x[i] <- x[i], v`.
    out_doc = core.pp_vars(
        [out_var], context, print_shapes=settings.print_shapes)
    return [
        out_doc,
        pp.text(', '),
        indexed_ref,
        pp.text(' <- '),
        indexed_ref,
        pp.text(', '),
        pp.text(core.pp_var(new_val, context)),
    ]

  # Set (ignored return value): `x[i] <- v`, with the reference passed
  # through `pp_ref`.
  return [
      pp_ref(indexed_ref),
      pp.text(' <- '),
      pp.text(core.pp_var(new_val, context)),
  ]
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
                 ) -> List[pp.Doc]:
  """Build the pretty-printer docs for the right-hand side of a call equation.

  Only a subset of the equation's params is printed:
  `call_jaxpr` and `name` always; `backend` and `device` only when they are
  not `None`; `donated_invars` only when at least one entry is truthy.

  Returns:
    A list of docs: the primitive name, the selected params (sorted), and
    the equation's input variables preceded by a space.
  """
  # NOTE(review): a second `_pp_xla_call` that also takes `settings` is visible
  # further down; if both live in the same module the later one shadows this
  # definition -- confirm which one is intended to be registered.

  def _is_printed(key, value):
    # Explicit form of the original `or`/`and` chain (`and` binds tighter).
    if key == 'call_jaxpr' or key == 'name':
      return True
    if key == 'backend' or key == 'device':
      return value is not None
    if key == 'donated_invars':
      return any(value)
    return False

  shown_params = {
      key: value
      for key, value in eqn.params.items()
      if _is_printed(key, value)
  }

  name_doc = pp.text(eqn.primitive.name)
  # Params are printed before the inputs; keep that order of calls on `context`.
  params_doc = core.pp_kv_pairs(sorted(shown_params.items()), context)
  inputs_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  return [name_doc, params_doc, inputs_doc]
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
  """Build the pretty-printer docs for a full call equation (`lhs = rhs`).

  Only a subset of the equation's params is printed:
  `call_jaxpr` and `name` always; `backend` and `device` only when they are
  not `None`; `donated_invars` only when at least one entry is truthy.

  When `settings.source_info` is set, the ` = ` separator carries a summary
  of the equation's source info as its annotation.

  Returns:
    A list of docs: the output variables, the ` = ` separator, the primitive
    name, the selected params (sorted), and the input variables preceded by
    a space.
  """
  shown_params = {}
  for key, value in eqn.params.items():
    if key == 'call_jaxpr' or key == 'name':
      keep = True
    elif key == 'backend' or key == 'device':
      keep = value is not None
    elif key == 'donated_invars':
      keep = any(value)
    else:
      keep = False
    if keep:
      shown_params[key] = value

  annotation = None
  if settings.source_info:
    annotation = source_info_util.summarize(eqn.source_info)

  # Outputs are printed first, then params, then inputs; keep that order of
  # calls on `context`.
  outputs_doc = core.pp_vars(
      eqn.outvars, context, print_shapes=settings.print_shapes)
  name_doc = pp.text(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(
      sorted(shown_params.items()), context, settings)
  inputs_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  equals_doc = pp.text(" = ", annotation=annotation)

  return [outputs_doc, equals_doc, name_doc, params_doc, inputs_doc]
def _get_pp_rule(eqn, context, settings):
  """Build the pretty-printer docs for a `get` equation.

  `a = get x i` is printed as `a <- x[i]`: the output variable (with its
  shape when `settings.print_shapes` is set) comes first, followed by the
  arrow and the indexed reference, which is passed through `pp_ref`.

  Returns:
    A list of docs to be laid out in order.
  """
  (out_var,) = eqn.outvars
  ref_var, *index_vars = eqn.invars
  index_text = ','.join([core.pp_var(ix, context) for ix in index_vars])

  # The output is printed before the reference; keep that order of calls on
  # `context`.
  out_doc = core.pp_vars(
      [out_var], context, print_shapes=settings.print_shapes)
  indexed_ref = pp.concat([
      pp.text(core.pp_var(ref_var, context)),
      pp.text('['),
      pp.text(index_text),
      pp.text(']'),
  ])
  return [out_doc, pp.text(' <- '), pp_ref(indexed_ref)]