def __repr__(self):
  """Return a ``PRNGKeyArray:`` summary listing the key shape and impl.

  The shape and impl entries are nested two columns under the header,
  separated by pretty-printer breaks inside a single group.
  """
  shape_doc = pp.text('shape = ') + pp.text(str(self._shape))
  impl_doc = pp.text('impl = ') + self.impl.pprint()
  body = pp.brk() + shape_doc + pp.brk() + impl_doc
  header = pp.text('PRNGKeyArray:')
  return str(pp.group(header + pp.nest(2, body)))
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
                 ) -> List[pp.Doc]:
  """Pretty-print the right-hand side of an XLA call equation.

  Only a subset of ``eqn.params`` is shown:
    * ``call_jaxpr`` and ``name`` always;
    * ``backend`` and ``device`` only when they are not None;
    * ``donated_invars`` only when at least one entry is truthy.

  Returns:
    ``[primitive name, sorted key/value params, " " + input vars]``.
  """
  def _is_shown(key, value):
    if key in ('call_jaxpr', 'name'):
      return True
    if key in ('backend', 'device'):
      return value is not None
    if key == 'donated_invars':
      return any(value)
    return False

  shown = {key: value for key, value in eqn.params.items()
           if _is_shown(key, value)}
  # Keep this evaluation order: the params (which may contain a jaxpr) are
  # printed through `context` before the input vars are.
  return [
      pp.text(eqn.primitive.name),
      core.pp_kv_pairs(sorted(shown.items()), context),
      pp.text(" ") + core.pp_vars(eqn.invars, context),
  ]
def _get_pp_rule(eqn, context, settings):
  """Pretty-print a ref ``get`` equation.

  The equation ``a = get x i j`` is rendered as ``a <- x[i,j]``. The single
  output variable (printed with its type when ``settings.print_shapes`` is
  set) comes first, followed by ``<-`` and the indexed ref, which is wrapped
  in ``pp_ref``.
  """
  y, = eqn.outvars
  x, *indices = eqn.invars
  # Vars are printed in the same order as before (indices, output, ref), so
  # the names that `context` assigns stay unchanged.
  index_str = ','.join(core.pp_var(i, context) for i in indices)
  lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
  indexed_ref = pp.concat([
      pp.text(core.pp_var(x, context)),
      pp.text('['),
      pp.text(index_str),
      pp.text(']'),
  ])
  return [lhs, pp.text(' <- '), pp_ref(indexed_ref)]
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
  """Pretty-print a full XLA call equation as ``outs = name[params] ins``.

  Only a subset of ``eqn.params`` is shown: ``call_jaxpr`` and ``name``
  always, ``backend``/``device`` when not None, and ``donated_invars`` when
  any entry is truthy. When ``settings.source_info`` is set, the ``" = "``
  separator carries a summary of ``eqn.source_info`` as its annotation.
  """
  def _is_shown(key, value):
    if key in ('call_jaxpr', 'name'):
      return True
    if key in ('backend', 'device'):
      return value is not None
    if key == 'donated_invars':
      return any(value)
    return False

  shown = {key: value for key, value in eqn.params.items()
           if _is_shown(key, value)}
  if settings.source_info:
    annotation = source_info_util.summarize(eqn.source_info)
  else:
    annotation = None
  # Outputs are printed before the params and inputs; keep that order so the
  # names that `context` assigns are unchanged.
  out_doc = core.pp_vars(eqn.outvars, context,
                         print_shapes=settings.print_shapes)
  return [
      out_doc,
      pp.text(" = ", annotation=annotation),
      pp.text(eqn.primitive.name),
      core.pp_kv_pairs(sorted(shown.items()), context, settings),
      pp.text(" ") + core.pp_vars(eqn.invars, context),
  ]
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
  """Build XLA ``OpMetadata`` for one application of ``primitive``.

  The op name is ``name_stack`` followed by the compact pretty-printed
  equation. When ``source_info`` is available, its traceback is recorded in
  the ``tracebacks`` mapping (defined outside this function) under that op
  name, and the user frame supplies the source file and line.

  Args:
    primitive: the primitive being applied; its name becomes ``op_type``.
    params: the primitive's parameters, included in the printed op name.
    source_info: source information for the equation; may be missing, in
      which case no traceback is recorded and no file/line is attached.
    name_stack: prefix for the op name.

  Returns:
    An ``xc.OpMetadata`` with ``op_type``, ``op_name`` and, when a user frame
    is found, ``source_file``/``source_line``.
  """
  eqn_str = str(pp.text(name_stack) +
                pp_eqn_compact(primitive.name, params, JaxprPpContext()))
  # The traceback write and the frame lookup share one guard. Previously the
  # traceback was read unconditionally while the frame lookup was guarded,
  # so a missing `source_info` failed before the guard could take effect.
  if source_info:
    tracebacks[eqn_str] = source_info.traceback
    frame = source_info_util.user_frame(source_info)
  else:
    frame = None
  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=eqn_str,
      source_file=_get_canonical_source_file(frame) if frame else None,
      source_line=frame.line_num if frame else None)
def _source_info_to_location(
    primitive: core.Primitive, params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "") -> ir.Location:
  """Build an MLIR location describing one application of ``primitive``.

  The result is a named location whose name is ``name_stack`` followed by the
  compact pretty-printed equation. It wraps a file location for the user
  frame (column fixed at 1), or an unknown location when no user frame is
  found.
  """
  op_name = str(
      pp.text(name_stack) +
      core.pp_eqn_compact(primitive.name, params, core.JaxprPpContext()))
  frame = source_info_util.user_frame(source_info)
  if frame is not None:
    child = ir.Location.file(
        xla._get_canonical_source_file(frame), frame.line_num, 1)
  else:
    child = ir.Location.unknown()
  # TODO(phawkins): also include primitive.name as the operator type.
  return ir.Location.name(op_name, childLoc=child)
def _addupdate_pp_rule(eqn, context, settings):
  """Pretty-print a ref ``addupdate`` equation.

  The equation `` = addupdate x v i j`` (operands in ``eqn.invars`` order:
  ref, value, then indices) is rendered as ``x[i,j] += v``, with the indexed
  ref wrapped in ``pp_ref``. ``settings`` is unused here.
  """
  # addupdate has no outputs. Unpacking into `()` raises ValueError otherwise.
  () = eqn.outvars
  x, v, *indices = eqn.invars
  # Vars are printed in the same order as before (indices, ref, value) so
  # the names that `context` assigns stay unchanged.
  index_str = ','.join(core.pp_var(i, context) for i in indices)
  target = pp_ref(pp.concat([
      pp.text(core.pp_var(x, context)),
      pp.text('['),
      pp.text(index_str),
      pp.text(']'),
  ]))
  return [target, pp.text(' += '), pp.text(core.pp_var(v, context))]
def pprint(self):
  """Return a doc of ``ClassName:`` followed by this object's fields.

  Each ``name = value`` entry comes from ``self._asdict()``. The entries are
  separated by breaks inside one group, nested two columns under the header.
  """
  header = pp.text(f"{self.__class__.__name__}:")
  entries = [pp.text(f"{name} = {value}")
             for name, value in self._asdict().items()]
  body = pp.group(pp.brk() + pp.join(pp.brk(), entries))
  return header + pp.nest(2, body)
def _swap_pp_rule(eqn, context, settings):
  """Pretty-print a ref ``swap`` equation (operands: ref, value, indices).

  * Result kept: ``y:T = swap x v i`` is rendered as ``y:T, x[i] <- x[i], v``.
  * Result dropped: ``_ = swap x v i`` is rendered as ``x[i] <- v``, with the
    indexed ref wrapped in ``pp_ref``.
  """
  out, = eqn.outvars
  ref, val, *indices = eqn.invars
  # Vars are printed in the same order as before (indices, ref, [output],
  # value) so the names that `context` assigns stay unchanged.
  index_str = ','.join(core.pp_var(i, context) for i in indices)
  ref_at_index = pp.concat([
      pp.text(core.pp_var(ref, context)),
      pp.text('['),
      pp.text(index_str),
      pp.text(']'),
  ])
  if type(out) is not core.DropVar:
    out_doc = core.pp_vars([out], context, print_shapes=settings.print_shapes)
    return [out_doc, pp.text(', '), ref_at_index, pp.text(' <- '),
            ref_at_index, pp.text(', '), pp.text(core.pp_var(val, context))]
  # The return value is ignored, so this is a plain store.
  return [pp_ref(ref_at_index), pp.text(' <- '),
          pp.text(core.pp_var(val, context))]