Example #1
0
def jaxpr_collectives(jaxpr):
    """Yield the primitive of every collective equation in `jaxpr`, including nested ones.

    An equation counts as collective when its primitive is a member of
    `_collective_primitives`. One primitive is yielded per matching equation,
    so the same primitive may appear several times. Equations of `jaxpr`
    itself come first, followed by each sub-jaxpr (from `core.subjaxprs`)
    visited recursively in order.
    """
    yield from (e.primitive for e in jaxpr.eqns
                if e.primitive in _collective_primitives)
    for nested in core.subjaxprs(jaxpr):
        yield from jaxpr_collectives(nested)
Example #2
0
def var_defs_and_refs(jaxpr: core.Jaxpr):
  """Record, for every variable bound in `jaxpr`, its defining equation and its uses.

  The walk visits constvars, then invars, then each equation in order (its
  invars as reads, its outvars as writes), and finally the jaxpr outvars (as
  reads). `None` stands for "bound as a constvar/invar" when it appears as a
  definition. It stands for "read as a jaxpr outvar" when it appears in a
  reference list. Literal operands and `core.DropVar` outputs are not tracked.

  Invariants are enforced with `assert`:
    * every variable is bound at most once;
    * every variable is bound before it is read.
  These checks disappear under `python -O`.

  Returns:
    If `jaxpr` has no sub-jaxprs, the pair ``(jaxpr, res)``. Otherwise, a list
    ``[(jaxpr, res), *subs]`` where each element of ``subs`` is the result of
    calling this function recursively on one sub-jaxpr (so itself either a
    pair or a list). ``res`` is a list of
    ``(var, defining_eqn, referencing_eqns)`` triples in the order the
    variables were bound.
  """
  defs: Dict[core.Var, MaybeEqn] = {}
  refs: Dict[core.Var, List[MaybeEqn]] = {}

  def read(a: core.Atom, eqn: MaybeEqn):
    # Literals are inline constants, not variables, so they get no entry.
    if not isinstance(a, core.Literal):
      # A variable must already be bound when it is read.
      assert a in defs, a
      assert a in refs, a
      refs[a].append(eqn)

  def write(v: core.Var, eqn: MaybeEqn):
    # Each variable may be bound at most once.
    assert v not in defs, v
    assert v not in refs, v
    # DropVar outputs are not recorded.
    if not isinstance(v, core.DropVar):
      defs[v] = eqn
      refs[v] = []

  for v in jaxpr.constvars:
    write(v, None)
  for v in jaxpr.invars:
    write(v, None)

  for eqn in jaxpr.eqns:
    for a in eqn.invars:
      read(a, eqn)
    for v in eqn.outvars:
      write(v, eqn)

  for a in jaxpr.outvars:
    read(a, None)

  res = [(v, defs[v], refs[v]) for v in defs]
  # Build an explicit list. The builtin `map` returns a lazy iterator that
  # is always truthy, which would make the single-pair branch below
  # unreachable.
  subs = [var_defs_and_refs(j) for j in core.subjaxprs(jaxpr)]
  return [(jaxpr, res), *subs] if subs else (jaxpr, res)
Example #3
0
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
  """Return whether some aval in `jaxpr` has a shape dimension that is a `core.Var`.

  Only avals whose exact type is `core.DShapedArray` are inspected. The
  variables examined are:
    * the invars of `jaxpr` itself;
    * the invars and outvars of every equation in `jaxpr` and in each jaxpr
      yielded by `core.subjaxprs(jaxpr)`.
  This function does not itself recurse into sub-jaxprs of sub-jaxprs, and it
  does not check sub-jaxpr invars directly. NOTE(review): whether deeper
  nesting is covered depends on whether `core.subjaxprs` descends
  recursively; confirm.
  """
  def _has_var_dim(var) -> bool:
    aval = var.aval
    if type(aval) is not core.DShapedArray:
      return False
    return any(type(dim) is core.Var for dim in aval.shape)

  if any(_has_var_dim(var) for var in jaxpr.invars):
    return True
  for owner in itertools.chain([jaxpr], core.subjaxprs(jaxpr)):
    for eqn in owner.eqns:
      operands = itertools.chain(eqn.invars, eqn.outvars)
      if any(_has_var_dim(var) for var in operands):
        return True
  return False
Example #4
0
def jaxpr_literals(jaxpr):
  """Yield the `.val` of every `core.Literal` equation operand, including in nested jaxprs.

  Only equation invars are scanned. Values from `jaxpr`'s own equations come
  first, then those of each sub-jaxpr (from `core.subjaxprs`), visited
  recursively in order.
  """
  yield from (operand.val
              for eqn in jaxpr.eqns
              for operand in eqn.invars
              if type(operand) is core.Literal)
  for nested in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(nested)
Example #5
0
def jaxpr_has_pmap(jaxpr):
  """Return True if any equation, at any nesting depth, uses an xla_pmap primitive.

  A primitive matches when the substring 'xla_pmap' occurs in its name.
  `jaxpr`'s own equations are checked first. Then each sub-jaxpr (from
  `core.subjaxprs`) is searched recursively, stopping at the first hit.
  """
  if any('xla_pmap' in eqn.primitive.name for eqn in jaxpr.eqns):
    return True
  return any(jaxpr_has_pmap(nested) for nested in core.subjaxprs(jaxpr))
Example #6
0
def all_eqns(jaxpr: core.Jaxpr):
  """Yield ``(owning_jaxpr, eqn)`` for every equation reachable from `jaxpr`.

  `jaxpr`'s own equations come first, each paired with `jaxpr`. Then each
  sub-jaxpr (from `core.subjaxprs`) is visited recursively, with its
  equations paired with that sub-jaxpr.
  """
  yield from ((jaxpr, eqn) for eqn in jaxpr.eqns)
  for nested in core.subjaxprs(jaxpr):
    yield from all_eqns(nested)
Example #7
0
def iter_eqns(jaxpr):
    """Yield every equation of `jaxpr`, then those of its sub-jaxprs, depth-first.

    Sub-jaxprs are the ones returned by `core.subjaxprs`, each walked
    recursively in order.
    """
    # TODO(necula): why doesn't this search in params?
    yield from jaxpr.eqns
    for nested in core.subjaxprs(jaxpr):
        yield from iter_eqns(nested)