예제 #1
0
파일: prng.py 프로젝트: Jakob-Unfried/jax
 def __repr__(self):
   """Return a pretty-printed summary of this key array.

   The text is a ``PRNGKeyArray:`` heading followed by two fields,
   ``shape = <self._shape>`` and ``impl = <self.impl.pprint()>``. The fields
   are separated by ``pp.brk()`` breaks, nested by 2 and wrapped in a single
   ``pp.group``, then rendered with ``str``.
   """
   heading = pp.text('PRNGKeyArray:')
   shape_field = pp.text('shape = ') + pp.text(str(self._shape))
   impl_field = pp.text('impl = ') + self.impl.pprint()
   fields = pp.brk() + shape_field + pp.brk() + impl_field
   return str(pp.group(heading + pp.nest(2, fields)))
예제 #2
0
파일: xla.py 프로젝트: wayfeng/jax
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext
                 ) -> List[pp.Doc]:
  """Pretty-print an XLA call equation's primitive, params and inputs.

  Only some params are shown: ``call_jaxpr`` and ``name`` always,
  ``backend`` and ``device`` when their value is not ``None``, and
  ``donated_invars`` when ``any`` of its entries is truthy.

  Returns:
    A list of three docs: the primitive name, the selected params as
    key/value pairs sorted by key, and a space followed by the input vars.
  """
  def _shown(key, value):
    # Decide whether a param is worth printing; unknown keys are hidden.
    if key == 'call_jaxpr' or key == 'name':
      return True
    if key == 'backend' or key == 'device':
      return value is not None
    if key == 'donated_invars':
      return any(value)
    return False

  selected = {key: value for key, value in eqn.params.items()
              if _shown(key, value)}
  name_doc = pp.text(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(sorted(selected.items()), context)
  inputs_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  return [name_doc, params_doc, inputs_doc]
예제 #3
0
def _get_pp_rule(eqn, context, settings):
    """Pretty-print a ``get`` equation as a read from a ref.

    ``a = get x i`` is rendered as ``a <- x[i]``: the output binder comes
    first, then `` <- ``, then ``x`` indexed by the comma-joined index
    variables, wrapped in ``pp_ref``.

    Args:
      eqn: the ``get`` equation. Unpacking requires exactly one outvar and
        at least one invar (the ref ``x``, followed by index vars).
      context: pretty-printing context used to render variables.
      settings: pretty-printing settings; only ``print_shapes`` is read.

    Returns:
      A list of docs: binder, arrow, indexed ref.
    """
    y, = eqn.outvars
    x, *idx = eqn.invars
    # Keep the joined string under its own name rather than rebinding the
    # list of index vars.
    idx_str = ','.join(core.pp_var(i, context) for i in idx)
    lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
    return [
        lhs,
        pp.text(' <- '),
        pp_ref(
            pp.concat([
                pp.text(core.pp_var(x, context)),
                pp.text('['),
                pp.text(idx_str),
                pp.text(']')
            ]))
    ]
예제 #4
0
파일: xla.py 프로젝트: romanngg/jax
def _pp_xla_call(
    eqn: core.JaxprEqn,
    context: core.JaxprPpContext,
    settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
    """Pretty-print an XLA call equation as ``outs = name[params] ins``.

    Params shown: ``call_jaxpr`` and ``name`` always; ``backend`` and
    ``device`` when not ``None``; ``donated_invars`` when ``any`` entry is
    truthy. When ``settings.source_info`` is set, the `` = `` doc carries a
    summary of the equation's source info as its annotation.

    Returns:
        A list of docs: output binders (with shapes if
        ``settings.print_shapes``), `` = ``, the primitive name, the sorted
        key/value params, and a space followed by the input vars.
    """
    shown = {}
    for key, value in eqn.params.items():
        if key in ('call_jaxpr', 'name'):
            shown[key] = value
        elif key in ('backend', 'device') and value is not None:
            shown[key] = value
        elif key == 'donated_invars' and any(value):
            shown[key] = value

    if settings.source_info:
        annotation = source_info_util.summarize(eqn.source_info)
    else:
        annotation = None

    # Render outputs, then params, then inputs, in that order, because the
    # context may assign variable names on first use — TODO confirm.
    outs_doc = core.pp_vars(eqn.outvars, context,
                            print_shapes=settings.print_shapes)
    name_doc = pp.text(eqn.primitive.name)
    params_doc = core.pp_kv_pairs(sorted(shown.items()), context, settings)
    ins_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
    equals_doc = pp.text(" = ", annotation=annotation)
    return [outs_doc, equals_doc, name_doc, params_doc, ins_doc]
예제 #5
0
파일: xla.py 프로젝트: wayfeng/jax
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
  """Build XLA op metadata for one primitive application.

  The op name is ``name_stack`` followed by a compact rendering of the
  primitive and its params. When ``source_info`` is truthy, its traceback is
  stored in ``tracebacks`` under that op name, and the user frame (if any)
  supplies the source file and line.

  Args:
    primitive: the primitive being lowered; its name becomes ``op_type``.
    params: the primitive's params, included in the rendered op name.
    source_info: source information for the equation. A falsy value (e.g.
      ``None``) skips both the traceback record and the frame lookup.
    name_stack: text prepended to the rendered equation.

  Returns:
    An ``xc.OpMetadata`` with ``op_type``, ``op_name`` and, when a user frame
    was found, ``source_file``/``source_line`` (otherwise ``None``).
  """
  eqn_str = str(pp.text(name_stack) +
                pp_eqn_compact(primitive.name, params, JaxprPpContext()))
  if source_info:
    # Both uses of source_info share one guard. Previously the traceback was
    # read unconditionally, so a falsy source_info raised AttributeError
    # before the frame lookup's own check was reached.
    tracebacks[eqn_str] = source_info.traceback
    frame = source_info_util.user_frame(source_info)
  else:
    frame = None
  return xc.OpMetadata(
        op_type=primitive.name,
        op_name=eqn_str,
        source_file=_get_canonical_source_file(frame) if frame else None,
        source_line=frame.line_num if frame else None)
예제 #6
0
파일: mlir.py 프로젝트: ahoenselaar/jax
def _source_info_to_location(primitive: core.Primitive,
                             params: Dict,
                             source_info: source_info_util.SourceInfo,
                             name_stack: str = "") -> ir.Location:
    """Build an MLIR location describing one primitive application.

    The result is a named location whose name is ``name_stack`` followed by
    a compact rendering of the equation. Its child is a file location
    (canonical source file, user frame line, column 1) when
    ``source_info_util.user_frame`` finds a frame, else
    ``ir.Location.unknown()``.
    """
    rendered_eqn = core.pp_eqn_compact(primitive.name, params,
                                       core.JaxprPpContext())
    eqn_str = str(pp.text(name_stack) + rendered_eqn)

    frame = source_info_util.user_frame(source_info)
    if frame is not None:
        child_loc = ir.Location.file(xla._get_canonical_source_file(frame),
                                     frame.line_num, 1)
    else:
        child_loc = ir.Location.unknown()

    # TODO(phawkins): also include primitive.name as the operator type.
    return ir.Location.name(eqn_str, childLoc=child_loc)
예제 #7
0
def _addupdate_pp_rule(eqn, context, settings):
    """Pretty-print an ``addupdate`` equation as an in-place ref update.

    The equation has no outputs and invars ``(x, v, *idx)``, so
    `` = addupdate x v i`` is rendered as ``x[i] += v``, with the indexed
    ref wrapped in ``pp_ref``.

    Args:
      eqn: the ``addupdate`` equation. Unpacking requires zero outvars and
        at least two invars (ref ``x``, value ``v``, then index vars).
      context: pretty-printing context used to render variables.
      settings: accepted for the pp-rule signature; not read here.

    Returns:
      A list of docs: indexed ref, `` += ``, value.
    """
    # Unpacking into an empty target asserts the equation has no outputs.
    () = eqn.outvars
    x, v, *idx = eqn.invars
    idx_str = ','.join(core.pp_var(i, context) for i in idx)
    return [
        pp_ref(
            pp.concat([
                pp.text(core.pp_var(x, context)),
                pp.text('['),
                pp.text(idx_str),
                pp.text(']')
            ])),
        pp.text(' += '),
        pp.text(core.pp_var(v, context))
    ]
예제 #8
0
 def pprint(self):
     """Return a pretty-printer doc listing this object's fields.

     The doc is ``<ClassName>:`` followed by one ``name = value`` entry per
     item of ``self._asdict()``. The entries are separated by ``pp.brk()``
     breaks, grouped, and nested by 2.
     """
     heading = pp.text(f"{self.__class__.__name__}:")
     entries = [pp.text(f"{name} = {value}")
                for name, value in self._asdict().items()]
     body = pp.group(pp.brk() + pp.join(pp.brk(), entries))
     return heading + pp.nest(2, body)
예제 #9
0
def _swap_pp_rule(eqn, context, settings):
    """Pretty-print a ``swap`` equation as a ref write, with or without a read.

    Invars are ``(x, v, *idx)`` and there is exactly one outvar.

    * If the outvar is a ``core.DropVar`` (result ignored, i.e. a plain
      set), ``_ = swap x v i`` prints as ``x[i] <- v``, with the indexed ref
      wrapped in ``pp_ref``.
    * Otherwise ``y:T = swap x v i`` prints as ``y:T, x[i] <- x[i], v``. The
      binder's shape is shown when ``settings.print_shapes`` is set.

    Returns:
      A list of docs for the chosen form.
    """
    out, = eqn.outvars
    x, v, *idx = eqn.invars
    indices = ','.join(core.pp_var(i, context) for i in idx)
    ref_at_idx = pp.concat([
        pp.text(core.pp_var(x, context)),
        pp.text('['),
        pp.text(indices),
        pp.text(']')
    ])

    if type(out) is core.DropVar:
        # Ignored result: render as a plain assignment into the ref.
        return [pp_ref(ref_at_idx),
                pp.text(' <- '),
                pp.text(core.pp_var(v, context))]

    # Used result: the binder is rendered before `v`, matching the order in
    # which variables are handed to the context.
    binder = core.pp_vars([out], context, print_shapes=settings.print_shapes)
    return [binder,
            pp.text(', '), ref_at_idx,
            pp.text(' <- '), ref_at_idx,
            pp.text(', '),
            pp.text(core.pp_var(v, context))]