コード例 #1
0
ファイル: host_callback.py プロジェクト: yuejiesong1900/jax
 def pp_val(arg) -> ppu.PrettyPrint:
   """Render `arg` as a ``ppu.PrettyPrint``, recursing into containers.

   * tuple / list -> ``[ `` + one rendered element per line + `` ]``
   * dict         -> ``{ `` + one ``key=value`` entry per line + `` }``,
                     entries ordered by ``sorted(arg.items())``
   * np.ndarray   -> ``np.array2string`` using the enclosing ``threshold``
   * anything else -> ``str(arg)``
   """
   if isinstance(arg, (tuple, list)):
     rendered_items = [pp_val(item) for item in arg]
     return ppu.pp("[ ") >> ppu.vcat(rendered_items) >> ppu.pp(" ]")
   if isinstance(arg, dict):
     # sorted() gives a deterministic key order; it assumes the keys are
     # mutually comparable.
     rendered_entries = [
         ppu.pp(f"{key}=") >> pp_val(value)
         for key, value in sorted(arg.items())
     ]
     return ppu.pp("{ ") >> ppu.vcat(rendered_entries) >> ppu.pp(" }")
   if isinstance(arg, np.ndarray):
     # `threshold` is a free variable supplied by the enclosing scope.
     text = np.array2string(arg, threshold=threshold)
   else:
     text = str(arg)
   return ppu.pp(text)
コード例 #2
0
def pp_djaxpr(jaxpr: DJaxpr) -> PrettyPrint:
  """Pretty-print `jaxpr` as ``{ lambda <dims> ; <binders> . let ... in ... }``.

  The first line lists ``jaxpr.in_dim_binders`` and ``jaxpr.in_binders``
  separated by ``;``. It is followed, indented by two columns, by a ``let``
  block of the rendered equations and an ``in`` line. The ``in`` line gives
  ``( out_dims ; outs )`` and then their types ``( out_dim_types ; outs_type )``.
  """
  # A list comprehension ensures vcat receives a list whether `map` here is
  # the builtin (lazy iterator) or a list-returning replacement.
  eqns = [pp_eqn(eqn) for eqn in jaxpr.eqns]
  in_dim_binders = pp_vars(jaxpr.in_dim_binders)
  in_binders = pp_vars(jaxpr.in_binders)
  out_dims = ', '.join(map(str, jaxpr.out_dims))
  outs = ', '.join(map(str, jaxpr.outs))
  out_dim_types = pp_vars(jaxpr.out_dims)
  outs_type = ', '.join(v.aval.str_short() for v in jaxpr.outs)
  # In Python `>>` binds more loosely than `+`. Written inline as
  # `pp('let ') >> vcat(eqns) + in_line`, the `in` line would become part of
  # the right-hand side of the `let ` hang. Grouping explicitly makes `in` a
  # sibling of the `let` block.
  let_block = pp('let ') >> vcat(eqns)
  in_line = pp(f'in ( {out_dims} ; {outs} ) '
               f': ( {out_dim_types} ; {outs_type} ) }}')
  return (pp(f'{{ lambda {in_dim_binders} ; {in_binders} .')
          + (let_block + in_line).indent(2))
コード例 #3
0
 def pprint(self):
     """Return a pretty-printed view of this record.

     The header is the class name joined with ``':'``. Below it comes one
     ``name = value`` line per item of ``self._asdict()``, indented by two
     columns.
     """
     header = pp(self.__class__.__name__) >> pp(':')
     field_lines = []
     for field_name, field_value in self._asdict().items():
         field_lines.append(pp(field_name) >> pp(' = ') >> pp(field_value))
     body = vcat(field_lines).indent(2)
     return header + body