def jaxpr_collectives(jaxpr):
  """Yield the primitive of every collective equation in ``jaxpr``.

  The equations of ``jaxpr`` itself are visited first, in order. Then each
  jaxpr produced by ``core.subjaxprs(jaxpr)`` is searched recursively. One
  primitive is yielded per matching equation, so the same primitive may
  appear more than once.
  """
  yield from (e.primitive for e in jaxpr.eqns
              if e.primitive in _collective_primitives)
  for inner in core.subjaxprs(jaxpr):
    yield from jaxpr_collectives(inner)
def var_defs_and_refs(jaxpr: core.Jaxpr):
  """Record, for every variable bound in ``jaxpr``, where it is defined and used.

  Constvars and invars are recorded with a ``None`` definition site. Equation
  outvars are recorded with the equation that binds them. Each read of a
  variable appends the reading equation (or ``None`` for a read from
  ``jaxpr.outvars``) to that variable's reference list. Literals and
  ``core.DropVar`` outputs are not tracked.

  Assertions fail if a variable is bound twice or read before it is bound.

  Returns:
    ``(jaxpr, res)`` when ``jaxpr`` has no sub-jaxprs. Otherwise it returns the
    list ``[(jaxpr, res), *subresults]``, where each ``subresults`` entry is the
    recursive result for one sub-jaxpr. ``res`` is a list of
    ``(var, defining_eqn_or_None, [referencing_eqn_or_None, ...])`` triples.
  """
  defs: Dict[core.Var, MaybeEqn] = {}
  refs: Dict[core.Var, List[MaybeEqn]] = {}

  def read(a: core.Atom, eqn: MaybeEqn):
    # Literals have no binding site, so there is nothing to record.
    if not isinstance(a, core.Literal):
      assert a in defs, a
      assert a in refs, a
      refs[a].append(eqn)

  def write(v: core.Var, eqn: MaybeEqn):
    # Every variable must be bound exactly once.
    assert v not in defs, v
    assert v not in refs, v
    # Dropped outputs are intentionally left out of the tables.
    if not isinstance(v, core.DropVar):
      defs[v] = eqn
      refs[v] = []

  for v in jaxpr.constvars:
    write(v, None)
  for v in jaxpr.invars:
    write(v, None)
  for eqn in jaxpr.eqns:
    for a in eqn.invars:
      read(a, eqn)
    for v in eqn.outvars:
      write(v, eqn)
  for a in jaxpr.outvars:
    read(a, None)

  res = [(v, defs[v], refs[v]) for v in defs]
  # Build a real list. A lazy ``map`` object is always truthy, which would make
  # the ``if subs`` test below unable to select the no-sub-jaxpr branch.
  subs = [var_defs_and_refs(sub) for sub in core.subjaxprs(jaxpr)]
  return [(jaxpr, res), *subs] if subs else (jaxpr, res)
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
  """Return whether ``jaxpr`` uses a variable as an array dimension.

  A dimension counts when it is a ``core.Var`` appearing in the shape of a
  ``core.DShapedArray`` aval. The check covers the invars of ``jaxpr``, plus
  the invars and outvars of every equation in ``jaxpr`` and in all of its
  nested sub-jaxprs, at any depth.
  """
  def has_var_dim(v) -> bool:
    return (type(v.aval) is core.DShapedArray and
            any(type(d) is core.Var for d in v.aval.shape))

  def all_jaxprs(j):
    # The other walkers in this module recurse on ``core.subjaxprs``, which
    # indicates it yields only direct children. Recurse here as well so that
    # deeply nested sub-jaxprs are inspected too.
    yield j
    for sub in core.subjaxprs(j):
      yield from all_jaxprs(sub)

  return (any(has_var_dim(v) for v in jaxpr.invars) or
          any(has_var_dim(v)
              for j in all_jaxprs(jaxpr)
              for e in j.eqns
              for v in itertools.chain(e.invars, e.outvars)))
def jaxpr_literals(jaxpr):
  """Yield the ``.val`` of every literal operand in ``jaxpr`` and its sub-jaxprs.

  Only equation inputs are scanned. A literal appearing directly in
  ``jaxpr.outvars``, if any, is not yielded. Values from ``jaxpr``'s own
  equations come first, in order, followed by those of each sub-jaxpr
  (searched recursively).
  """
  for eqn in jaxpr.eqns:
    yield from (atom.val for atom in eqn.invars if type(atom) is core.Literal)
  for inner in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(inner)
def jaxpr_has_pmap(jaxpr):
  """Return True if any equation in ``jaxpr`` or its nested sub-jaxprs has a
  primitive whose name contains ``'xla_pmap'``; otherwise return False.

  The top-level equations are checked first. Sub-jaxprs are searched
  recursively, and only when no top-level equation matches.
  """
  return (any('xla_pmap' in eqn.primitive.name for eqn in jaxpr.eqns) or
          any(jaxpr_has_pmap(inner) for inner in core.subjaxprs(jaxpr)))
def all_eqns(jaxpr: core.Jaxpr):
  """Yield ``(owner, eqn)`` for every equation in ``jaxpr`` and its sub-jaxprs.

  ``owner`` is the jaxpr whose ``eqns`` list contains ``eqn``. Equations of
  ``jaxpr`` come first, followed by those of each sub-jaxpr (recursively).
  """
  yield from ((jaxpr, e) for e in jaxpr.eqns)
  for inner in core.subjaxprs(jaxpr):
    yield from all_eqns(inner)
def iter_eqns(jaxpr):
  """Yield every equation of ``jaxpr``, then those of each sub-jaxpr recursively."""
  # TODO(necula): why doesn't this search in params?
  yield from jaxpr.eqns
  for inner in core.subjaxprs(jaxpr):
    yield from iter_eqns(inner)