Example #1
0
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
  """Trace the sparse evaluation of ``jaxpr`` into a new closed jaxpr.

  The inputs ``spvalues`` are converted to their array form and flattened; the
  resulting abstract values are used to trace a function that rebuilds the
  spvalues inside ``spenv``, runs ``eval_sparse`` on ``jaxpr`` and flattens
  the outputs back to arrays.

  Args:
    spenv: the sparsify environment holding the buffers of ``spvalues``.
    jaxpr: a ClosedJaxpr to evaluate under sparse semantics.
    *spvalues: the input values for ``jaxpr``.

  Returns:
    A ``(ClosedJaxpr, out_tree)`` pair. The jaxpr operates on the flattened
    arrays, and ``out_tree`` is the pytree structure of its outputs.
  """
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  flat_inputs, in_tree = tree_flatten(spvalues_to_arrays(spenv, spvalues))
  flat_avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_inputs]
  out_tree = None  # set as a side effect of tracing `_sparse_eval_flat`

  def _sparse_eval_flat(*flat_args):
    # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
    # in buffers from the "outer scope" as constants. Is this a
    # problem for primitives like cond and while_loop, which always
    # convert constvars to invars when staging out their subjaxprs?
    nonlocal out_tree
    rebuilt = arrays_to_spvalues(spenv, tree_unflatten(in_tree, flat_args))
    results = eval_sparse(jaxpr.jaxpr, jaxpr.consts, rebuilt, spenv)
    flat_outputs, out_tree = tree_flatten(spvalues_to_arrays(spenv, results))
    return flat_outputs

  traced, _, traced_consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_sparse_eval_flat), flat_avals)
  return pe.ClosedJaxpr(traced, traced_consts), out_tree
Example #2
0
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
    """Trace the sparse evaluation of ``jaxpr`` into a new closed jaxpr.

    The inputs ``argspecs`` are converted to arrays and flattened; their
    abstract values are used to trace a function that rebuilds the argspecs
    inside ``spenv``, runs ``eval_sparse`` on ``jaxpr`` and flattens the
    outputs back to arrays.

    Args:
      spenv: the sparse environment holding the buffers of ``argspecs``.
      jaxpr: a ClosedJaxpr to evaluate under sparse semantics.
      *argspecs: the input values for ``jaxpr``.

    Returns:
      A ``(sp_jaxpr, out_tree)`` pair. ``sp_jaxpr`` is a ClosedJaxpr over the
      flattened array inputs whose constvars are bound to the constants
      captured during tracing. ``out_tree`` is the pytree structure of its
      outputs.
    """
    # TODO(jakevdp): currently this approach discards all information about
    #   shared data & indices when generating the sparsified jaxpr. The
    #   current approach produces valid sparsified while loops, but they
    #   don't work in corner cases (see associated TODO in sparsify_test.py)
    out_tree = None  # set as a side effect of tracing `wrapped`

    @lu.wrap_init
    def wrapped(*args_flat):
        nonlocal out_tree
        args = tree_unflatten(in_tree, args_flat)
        # Distinct name so the outer `argspecs` parameter is not shadowed.
        inner_argspecs = arrays_to_argspecs(spenv, args)
        result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, inner_argspecs, spenv)
        out = argspecs_to_arrays(spenv, result)
        out_flat, out_tree = tree_flatten(out)
        return out_flat

    args = argspecs_to_arrays(spenv, argspecs)
    args_flat, in_tree = tree_flatten(args)
    avals_flat = [
        core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat
    ]
    sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
    # Close over the captured constants via the jaxpr's own constvars.
    # Converting the constvars to invars while still attaching `consts` would
    # leave constants with no matching constvars whenever tracing captured
    # any. It would also require callers to pass those constants as extra
    # leading arguments.
    sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
    return sp_jaxpr, out_tree
Example #3
0
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    argspecs: Sequence[ArgSpec],  # mix of sparse and dense pointers into spenv
    spenv: SparseEnv,
) -> Sequence[ArgSpec]:
    """Interpret ``jaxpr`` on ``ArgSpec`` values whose buffers live in ``spenv``.

    Equations with at least one sparse input are dispatched to the matching
    entry of ``sparse_rules``. All other equations are bound directly on the
    dense buffers, and their outputs are pushed into ``spenv`` as dense
    ``ArgSpec`` values.

    Args:
      jaxpr: the jaxpr to interpret.
      consts: dense constant arrays, one per ``jaxpr.constvars``.
      argspecs: input values, one per ``jaxpr.invars``.
      spenv: environment that stores the underlying buffers.

    Returns:
      The ``ArgSpec`` values bound to ``jaxpr.outvars``.

    Raises:
      NotImplementedError: if an equation has a sparse input and its primitive
        has no entry in ``sparse_rules``.
    """
    env: Dict[core.Var, ArgSpec] = {}

    def read(var: core.Var) -> ArgSpec:
        # All literals are dense; their value is pushed into spenv on read.
        if isinstance(var, core.Literal):
            return ArgSpec(np.shape(var.val), spenv.push(var.val), None)
        return env[var]

    def write_buffer(var: core.Var, a: Array) -> None:
        if var is core.dropvar:
            return
        env[var] = ArgSpec(a.shape, spenv.push(a), None)

    def write(var: core.Var, a: ArgSpec) -> None:
        if var is core.dropvar:
            return
        env[var] = a

    def dense_output(var: core.Var, buf: Array) -> Union[ArgSpec, None]:
        # Outputs bound to dropvar are never read back, so their buffers are
        # not pushed into spenv; `write` ignores the None placeholder.
        if var is core.dropvar:
            return None
        return ArgSpec(buf.shape, spenv.push(buf), None)

    # TODO: decide whether core.unitvar needs an entry in `env`.
    safe_map(write_buffer, jaxpr.constvars, consts)
    safe_map(write, jaxpr.invars, argspecs)

    for eqn in jaxpr.eqns:
        prim = eqn.primitive
        invals = safe_map(read, eqn.invars)

        if any(val.is_sparse() for val in invals):
            if prim not in sparse_rules:
                raise NotImplementedError(f"sparse rule for {prim}")
            out = sparse_rules[prim](spenv, *invals, **eqn.params)
        else:
            if prim is xla.xla_call_p:
                # TODO(vanderplas,frostig): workaround for binding call primitives
                # within a jaxpr interpreter
                params = eqn.params.copy()
                fun = lu.wrap_init(
                    core.jaxpr_as_fun(
                        pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
                out_bufs = prim.bind(fun, *(val.data(spenv) for val in invals),
                                     **params)
            else:
                out_bufs = prim.bind(*(val.data(spenv) for val in invals),
                                     **eqn.params)
            out_bufs = out_bufs if prim.multiple_results else [out_bufs]
            # safe_map also checks there is exactly one buffer per outvar.
            out = safe_map(dense_output, eqn.outvars, out_bufs)
        safe_map(write, eqn.outvars, out)

    return safe_map(read, jaxpr.outvars)
Exemple #4
0
def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
  """Sparse rule for ``xla_call``: sparsify the callee, then bind it densely.

  ``call_jaxpr`` is traced through sparse evaluation with ``_sparsify_jaxpr``.
  The resulting jaxpr is then bound with ``xla_call_p`` on the flattened array
  form of ``spvalues``.

  Raises:
    NotImplementedError: if any entry of ``donated_invars`` is true.
  """
  if any(donated_invars):
    raise NotImplementedError("sparse xla_call with donated_invars")
  closed_call = pe.ClosedJaxpr(call_jaxpr, ())
  sparse_call, result_tree = _sparsify_jaxpr(spenv, closed_call, *spvalues)
  sparse_fun = lu.wrap_init(core.jaxpr_as_fun(sparse_call))
  flat_buffers = tree_flatten(spvalues_to_arrays(spenv, spvalues))[0]
  # Donation was rejected above; mark every flattened buffer as not donated.
  no_donation = (False,) * len(flat_buffers)
  flat_results = xla.xla_call_p.bind(
      sparse_fun, *flat_buffers, donated_invars=no_donation, **params)
  return arrays_to_spvalues(spenv, tree_unflatten(result_tree, flat_results))
Exemple #5
0
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
  """Interpret ``jaxpr`` on sparsify values whose buffers live in ``spenv``.

  Equations with at least one sparse input go to the matching entry of
  ``sparse_rules``. Equations whose inputs are all dense are bound directly on
  the dense buffers, and each non-dropped output is registered with
  ``spenv.dense``.

  Args:
    jaxpr: the jaxpr to interpret.
    consts: dense constant arrays, one per ``jaxpr.constvars``.
    spvalues: input values, one per ``jaxpr.invars``.
    spenv: environment that owns the underlying buffers.

  Returns:
    The values bound to ``jaxpr.outvars``.
  """
  bindings : Dict[core.Var, SparsifyValue] = {}

  def lookup(var: core.Var) -> Union[Array, SparsifyValue]:
    if not isinstance(var, core.Literal):
      return bindings[var]
    # Literals are always treated as dense.
    return spenv.dense(var.val)

  def store_buffer(var: core.Var, buf: Array) -> None:
    if not isinstance(var, core.DropVar):
      bindings[var] = spenv.dense(buf)

  def store(var: core.Var, value: SparsifyValue) -> None:
    if isinstance(var, core.DropVar):
      return
    assert value is not None
    bindings[var] = value

  def bind_dense(eqn, invals):
    # Bind an all-dense equation and wrap its outputs. Dropped outputs are
    # not registered with spenv; they become None placeholders.
    prim = eqn.primitive
    if prim is not xla.xla_call_p:
      results = prim.bind(*(spenv.data(v) for v in invals), **eqn.params)
    else:
      # TODO(vanderplas,frostig): workaround for binding call primitives
      # within a jaxpr interpreter
      call_params = eqn.params.copy()
      callee = pe.ClosedJaxpr(call_params.pop('call_jaxpr'), ())
      call_fun = lu.wrap_init(core.jaxpr_as_fun(callee))
      results = prim.bind(call_fun, *(spenv.data(v) for v in invals), **call_params)
    if not prim.multiple_results:
      results = [results]
    return [None if isinstance(outvar, core.DropVar) else spenv.dense(buf)
            for buf, outvar in safe_zip(results, eqn.outvars)]

  safe_map(store_buffer, jaxpr.constvars, consts)
  safe_map(store, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    invals = safe_map(lookup, eqn.invars)
    if not any(v.is_sparse() for v in invals):
      outvals = bind_dense(eqn, invals)
    else:
      if eqn.primitive not in sparse_rules:
        _raise_unimplemented_primitive(eqn.primitive)
      outvals = sparse_rules[eqn.primitive](spenv, *invals, **eqn.params)
    safe_map(store, eqn.outvars, outvals)

  return safe_map(lookup, jaxpr.outvars)