def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
  """Trace ``jaxpr`` through sparse evaluation and return a dense-only jaxpr.

  The inputs described by ``spvalues`` are converted to arrays and flattened.
  A function is traced over the avals of those flat buffers. That function
  rebuilds sparsify values from its inputs, runs ``eval_sparse`` on
  ``jaxpr``, and flattens the results.

  Args:
    spenv: the sparsify environment that ``spvalues`` point into.
    jaxpr: a closed jaxpr; its ``.jaxpr`` and ``.consts`` are evaluated.
    *spvalues: sparsify values describing the inputs to ``jaxpr``.

  Returns:
    A ``(ClosedJaxpr, out_tree)`` pair. The closed jaxpr is the traced
    function closed over the constants produced by tracing. ``out_tree`` is
    the pytree structure of the outputs, recorded while tracing.
  """
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  flat_inputs, in_tree = tree_flatten(spvalues_to_arrays(spenv, spvalues))

  @lu.wrap_init
  def traced_fun(*flat_args):
    # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
    # in buffers from the "outer scope" as constants. Is this a
    # problem for primitives like cond and while_loop, which always
    # convert constvars to invars when staging out their subjaxprs?
    nonlocal out_tree
    inner_spvalues = arrays_to_spvalues(spenv, tree_unflatten(in_tree, flat_args))
    sparse_out = eval_sparse(jaxpr.jaxpr, jaxpr.consts, inner_spvalues, spenv)
    flat_out, out_tree = tree_flatten(spvalues_to_arrays(spenv, sparse_out))
    return flat_out

  input_avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_inputs]
  traced_jaxpr, _, traced_consts = pe.trace_to_jaxpr_dynamic(traced_fun, input_avals)
  # `out_tree` was filled in by `traced_fun` while tracing ran above.
  return pe.ClosedJaxpr(traced_jaxpr, traced_consts), out_tree
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
  """Trace ``jaxpr`` through sparse evaluation and return a dense-only jaxpr.

  The inputs described by ``argspecs`` are converted to arrays and flattened.
  A function is traced over the avals of those flat buffers. That function
  rebuilds ArgSpecs from its inputs, runs ``eval_sparse`` on ``jaxpr``, and
  flattens the results.

  Args:
    spenv: the sparse environment that ``argspecs`` point into.
    jaxpr: a closed jaxpr; its ``.jaxpr`` and ``.consts`` are evaluated.
    *argspecs: ArgSpecs describing the inputs to ``jaxpr``.

  Returns:
    A ``(ClosedJaxpr, out_tree)`` pair. The closed jaxpr takes exactly the
    flattened input arrays as arguments, with tracing constants kept as
    constvars. ``out_tree`` is the pytree structure of the outputs, recorded
    while tracing.
  """
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    # NOTE(review): this closes over `spenv`, so buffers from the outer
    # scope may be captured as constants during tracing.
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    traced_argspecs = arrays_to_argspecs(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, traced_argspecs, spenv)
    out = argspecs_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = argspecs_to_arrays(spenv, argspecs)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  # Close over the traced constvars directly. Passing the jaxpr through
  # `pe.convert_constvars_jaxpr` first would move the constvars into the
  # leading invars. The ClosedJaxpr would then have no constvars for its
  # `consts`, and it would expect extra leading arguments that callers
  # (which pass only `args_flat`) never supply.
  sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
  return sp_jaxpr, out_tree
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    argspecs: Sequence[ArgSpec],  # mix of sparse and dense pointers into spenv
    spenv: SparseEnv,
) -> Sequence[ArgSpec]:
  """Interpret ``jaxpr`` over ArgSpecs, routing sparse inputs to sparse rules.

  Constants and literals are pushed into ``spenv`` and wrapped as dense
  ArgSpecs. For each equation:

  * If any input is sparse, the rule registered in ``sparse_rules`` is called.
  * Otherwise, the primitive is bound on the dense buffers and every result
    buffer is pushed into ``spenv`` as a dense ArgSpec.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: dense values bound to ``jaxpr.constvars``.
    argspecs: values bound to ``jaxpr.invars``.
    spenv: the sparse environment the ArgSpecs point into.

  Returns:
    The ArgSpecs for ``jaxpr.outvars``.

  Raises:
    NotImplementedError: if an equation has a sparse input and its primitive
      has no entry in ``sparse_rules``.
  """
  env: Dict[core.Var, ArgSpec] = {}

  def read(var: core.Var) -> ArgSpec:
    # All literals are dense. The third ArgSpec field is None for dense
    # values (presumably an indices reference for sparse ones; verify).
    if isinstance(var, core.Literal):
      return ArgSpec(np.shape(var.val), spenv.push(var.val), None)
    else:
      return env[var]

  def write_buffer(var: core.Var, a: Array) -> None:
    # Use isinstance rather than identity with the `core.dropvar` singleton,
    # so that any DropVar instance is recognized and skipped.
    if isinstance(var, core.DropVar):
      return
    env[var] = ArgSpec(a.shape, spenv.push(a), None)

  def write(var: core.Var, a: ArgSpec) -> None:
    if isinstance(var, core.DropVar):
      return
    env[var] = a

  # TODO: decide whether core.unitvar needs to be bound in `env` here.
  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, argspecs)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)

    if any(val.is_sparse() for val in invals):
      if prim not in sparse_rules:
        raise NotImplementedError(f"sparse rule for {prim}")
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        fun = lu.wrap_init(
            core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
        out_bufs = prim.bind(fun, *(val.data(spenv) for val in invals), **params)
      else:
        out_bufs = prim.bind(*(val.data(spenv) for val in invals), **eqn.params)
      out_bufs = out_bufs if prim.multiple_results else [out_bufs]
      out = [ArgSpec(buf.shape, spenv.push(buf), None) for buf in out_bufs]

    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
  """Sparse rule for ``xla_call``.

  The call body is sparsified with ``_sparsify_jaxpr``. The resulting jaxpr is
  bound through ``xla_call_p`` on the flattened input arrays, and the flat
  outputs are converted back to sparsify values.

  Raises:
    NotImplementedError: if any input is marked as donated.
  """
  if any(donated_invars):
    raise NotImplementedError("sparse xla_call with donated_invars")

  closed_call = pe.ClosedJaxpr(call_jaxpr, ())
  sparse_call, result_tree = _sparsify_jaxpr(spenv, closed_call, *spvalues)
  call_fun = lu.wrap_init(core.jaxpr_as_fun(sparse_call))

  flat_inputs = tree_flatten(spvalues_to_arrays(spenv, spvalues))[0]
  # The flattened inputs no longer line up with the original donation flags,
  # so nothing is donated to the sparsified call.
  no_donation = (False,) * len(flat_inputs)

  flat_outputs = xla.xla_call_p.bind(
      call_fun, *flat_inputs, donated_invars=no_donation, **params)
  return arrays_to_spvalues(spenv, tree_unflatten(result_tree, flat_outputs))
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
  """Evaluate ``jaxpr`` over sparsify values stored in ``spenv``.

  Constants and literals are registered as dense values. Equations with at
  least one sparse input are dispatched to ``sparse_rules``. All other
  equations are bound on the underlying dense buffers. Their results are
  registered as dense values, except outputs written to a DropVar, which
  are discarded.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: dense values bound to ``jaxpr.constvars``.
    spvalues: values bound to ``jaxpr.invars``.
    spenv: the sparsify environment the values point into.

  Returns:
    The sparsify values for ``jaxpr.outvars``.
  """
  env : Dict[core.Var, SparsifyValue] = {}

  def read(var: core.Var) -> Union[Array, SparsifyValue]:
    # Literals are always dense; every other variable was written earlier.
    return spenv.dense(var.val) if isinstance(var, core.Literal) else env[var]

  def write_buffer(var: core.Var, a: Array) -> None:
    if not isinstance(var, core.DropVar):
      env[var] = spenv.dense(a)

  def write(var: core.Var, a: SparsifyValue) -> None:
    # Dropped outputs carry None and are skipped before the check below.
    if not isinstance(var, core.DropVar):
      assert a is not None
      env[var] = a

  def eval_dense_eqn(eqn, invals):
    # Bind `eqn` on dense buffers and wrap each result as a dense value.
    prim = eqn.primitive
    if prim is xla.xla_call_p:
      # TODO(vanderplas,frostig): workaround for binding call primitives
      # within a jaxpr interpreter
      call_params = eqn.params.copy()
      call_fun = lu.wrap_init(
          core.jaxpr_as_fun(pe.ClosedJaxpr(call_params.pop('call_jaxpr'), ())))
      results = prim.bind(call_fun, *(spenv.data(v) for v in invals), **call_params)
    else:
      results = prim.bind(*(spenv.data(v) for v in invals), **eqn.params)
    if not prim.multiple_results:
      results = [results]
    return [None if isinstance(outvar, core.DropVar) else spenv.dense(buf)
            for buf, outvar in safe_zip(results, eqn.outvars)]

  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    invals = safe_map(read, eqn.invars)
    if not any(val.is_sparse() for val in invals):
      outvals = eval_dense_eqn(eqn, invals)
    else:
      prim = eqn.primitive
      if prim not in sparse_rules:
        _raise_unimplemented_primitive(prim)
      outvals = sparse_rules[prim](spenv, *invals, **eqn.params)
    safe_map(write, eqn.outvars, outvals)

  return safe_map(read, jaxpr.outvars)