Exemplo n.º 1
0
 def test_var_ordering(self):
     # Vars minted by one gensym, presented in any permutation, must sort
     # back into the order in which they were created.
     fresh = core.gensym()
     created = [fresh(core.ShapedArray((), np.dtype('int32')))
                for _ in range(3)]
     for shuffled in it.permutations(created):
         assert sorted(shuffled) == created
Exemplo n.º 2
0
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], in_tree,
                                             in_avals, primitive_name: str):
    """Stage each of `funs` into a jaxpr, giving all of them one common signature.

    When staging the branches of a conditional into jaxprs, constants are
    extracted from each branch and converted to jaxpr arguments. To use the
    staged jaxprs as the branches to a conditional *primitive*, their (input)
    signatures must match. This function "joins" the staged jaxprs: each one is
    rewritten to accept *all* constants (every branch's, in branch order) while
    only using the ones it needs, and dropping the rest.

    Returns:
      A tuple `(closed_jaxprs, consts, all_out_trees)`:
        * closed_jaxprs: one ClosedJaxpr per function, whose padded constvars
          have been turned into invars by `pe.convert_constvars_jaxpr`;
        * consts: the constants of all branches concatenated in branch order;
        * all_out_trees: the output trees reported by the staging step.
    """
    staged = [_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
              for fun in funs]
    jaxprs, all_consts, all_out_trees = unzip3(staged)

    newvar = core.gensym(jaxprs, suffix='_')
    all_const_avals = [[_abstractify(c) for c in consts] for consts in all_consts]
    # Placeholder vars, one per constant of every branch. Branch i keeps its own
    # constvars in its slot and is padded with the other branches' placeholders.
    unused_const_vars = [[newvar(aval) for aval in avals]
                         for avals in all_const_avals]

    def pad_jaxpr_constvars(i, jaxpr):
        before = util.concatenate(unused_const_vars[:i])
        after = util.concatenate(unused_const_vars[i + 1:])
        return jaxpr.replace(constvars=[*before, *jaxpr.constvars, *after])

    consts = util.concatenate(all_consts)
    padded = [pad_jaxpr_constvars(i, j) for i, j in enumerate(jaxprs)]
    closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(j), ())
                     for j in padded]
    return closed_jaxprs, consts, all_out_trees
Exemplo n.º 3
0
 def test_var_ordering(self):
     # Sorting any permutation of gensym'd vars must recover creation order.
     gen = core.gensym()
     vars_in_order = [gen(core.abstract_unit) for _ in range(3)]
     for perm in it.permutations(vars_in_order):
         assert sorted(perm) == vars_in_order
Exemplo n.º 4
0
 def test_var_tree_flatten(self):
   # Vars serve as both keys and values of a dict pytree. The expected 'bd'
   # means tree_leaves yields the values b then d, i.e. following the key
   # order (a before c) rather than insertion order.
   newsym = core.gensym()
   a, b, c, d = [newsym(core.abstract_unit) for _ in range(4)]
   leaf_names = [str(leaf) for leaf in tree_leaves({c: d, a: b})]
   assert ''.join(leaf_names) == 'bd'
Exemplo n.º 5
0
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Rewrite a Jaxpr to thread the token, if needed.

  Args:
    jaxpr: the jaxpr to rewrite.
    has_input_token: if True, a fresh token var is appended to the invars of
      the result; otherwise a token is created by a `create_token` equation at
      the start of the body.
    has_output_token: if True, the last token var is appended to the outvars.
      Requires `has_input_token`.

  Returns:
    `jaxpr` itself when there is no input token and
    `xla.jaxpr_uses_outfeed(jaxpr)` is false. Otherwise it returns a new Jaxpr
    in which every equation satisfying `xla.primitive_uses_outfeed` is
    rewritten by `_rewrite_eqn` to consume the current token and produce the
    next one.
  """
  assert has_input_token or not has_output_token

  if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
    return jaxpr

  mk_new_var = core.gensym([jaxpr])

  eqns: List[core.JaxprEqn] = []
  last_token_var = mk_new_var(core.abstract_token)  # store the incoming token
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    # No token was passed in, so create one. The first invar, if any, is its
    # operand. Slicing rather than indexing keeps this working for jaxprs with
    # no inputs, where `jaxpr.invars[0]` would raise IndexError.
    eqns.append(
        core.new_jaxpr_eqn(list(jaxpr.invars[:1]), [last_token_var],
                           lax.create_token_p, {}, source_info_util.current()))

  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
    else:
      # Chain tokens: each outfeed-using eqn consumes the previous token and
      # defines a new one that later eqns will consume.
      output_token_var = mk_new_var(core.abstract_token)
      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, mk_new_var)
      last_token_var = output_token_var

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
  return new_jaxpr
Exemplo n.º 6
0
 def test_jaxpr_undefined_eqn_invar(self):
   # Point the `cos` equation at a freshly minted var that nothing binds;
   # check_jaxpr must then reject the jaxpr as using an undefined variable.
   jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
   cos_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cos')
   fresh_var = core.gensym([jaxpr], suffix='_test')
   cos_eqn.invars[0] = fresh_var(cos_eqn.invars[0].aval)
   with self.assertRaisesRegex(
       core.JaxprTypeError,
       r"Variable '.+_test' not defined\n\nin equation:"):
     core.check_jaxpr(jaxpr)
Exemplo n.º 7
0
 def test_comparing_var(self):
     # Successive vars from one gensym are strictly increasing and pairwise
     # unequal.
     gen = core.gensym('')
     x, y, z = gen(), gen(), gen()
     assert x < y < z
     assert z > y > x
     assert x != y and y != z and x != z
Exemplo n.º 8
0
 def test_comparing_var(self):
     # Three vars minted in sequence compare in creation order and are
     # mutually distinct.
     fresh = core.gensym()
     first, second, third = [fresh(core.abstract_unit) for _ in range(3)]
     assert first < second < third
     assert third > second > first
     assert first != second and second != third and first != third
Exemplo n.º 9
0
 def test_comparing_var(self):
     # Vars minted in sequence by one gensym are ordered by creation and are
     # all distinct from one another.
     fresh = core.gensym()
     first, second, third = [fresh(core.ShapedArray((), np.dtype('int32')))
                             for _ in range(3)]
     assert first < second < third
     assert third > second > first
     assert first != second and second != third and first != third
Exemplo n.º 10
0
def ignore_errors_jaxpr(jaxpr, error):
    """Return a ClosedJaxpr that takes two extra leading args and ignores them.

    The two new invars are given the shaped avals of `error.err` and
    `error.code`. They are fresh vars that no equation reads, so their values
    have no effect. Constvars, outvars, equations and consts are carried over
    from `jaxpr` unchanged.
    """
    inner = jaxpr.jaxpr
    err_aval = core.raise_to_shaped(core.get_aval(error.err))
    code_aval = core.raise_to_shaped(core.get_aval(error.code))
    fresh = core.gensym([inner])
    err_var, code_var = fresh(err_aval), fresh(code_aval)
    widened = core.Jaxpr(inner.constvars, (err_var, code_var, *inner.invars),
                         inner.outvars, inner.eqns)
    return core.ClosedJaxpr(widened, jaxpr.consts)
Exemplo n.º 11
0
def _nonzero_typecheck_rule(invar):
  """Compute the output types of `nonzero` from its input var.

  Returns a pair `(out_dim_var, out_val_aval)`:
    * out_dim_var: a fresh var whose aval is an AbsArray over all but the last
      axis of the input, with element type BoundedIntTy bounded by the size of
      the input's last axis;
    * out_val_aval: an int32 AbsArray. For a 1-D input its only dimension is
      `out_dim_var` itself; otherwise it keeps the leading axes and its last
      dimension is a DimIndexingExpr indexing into `out_dim_var`.
  """
  shape = invar.aval.shape
  last_dim = shape[-1]
  # The last axis may be a plain int or an object carrying a `_bound`.
  bound = last_dim if isinstance(last_dim, int) else last_dim._bound
  out_dim_var = core.gensym()(AbsArray(shape[:-1], BoundedIntTy(bound)))
  if len(shape) != 1:
    rank = len(out_dim_var.aval.shape)  # pytype: disable=attribute-error
    last = DimIndexingExpr(out_dim_var, tuple(range(rank)))
    out_val_aval = AbsArray((*shape[:-1], last), BaseType(np.dtype('int32')))
  else:
    out_val_aval = AbsArray((out_dim_var,), BaseType(np.dtype('int32')))
  return out_dim_var, out_val_aval
Exemplo n.º 12
0
def refresh_names(jaxpr):
    """Return an equivalent Jaxpr whose variables are all freshly generated.

    Every variable of `jaxpr` is replaced by a new Var with the same aval,
    drawn from a single `jc.gensym()`. This covers constvars, invars, equation
    inputs and outputs, and outvars. A given old variable always maps to the
    same new Var, so the dataflow is preserved. Literals are kept as-is
    wherever they appear, including among the outvars.
    """
    renamed = {}
    fresh = jc.gensym()

    def rename(v):
        # Literals are inline constants, not variables, so they must not be
        # renamed. NOTE(review): jax Literals may also not support hashing,
        # which rules them out as dict keys; confirm for this jax version.
        if isinstance(v, Literal):
            return v
        new_v = renamed.get(v)
        if new_v is None:
            new_v = renamed[v] = fresh(v.aval)
        return new_v

    # Build every list eagerly and in program order, so fresh names are
    # handed out deterministically: constvars, invars, eqns, then outvars.
    constvars = [rename(v) for v in jaxpr.constvars]
    invars = [rename(v) for v in jaxpr.invars]
    new_eqns = [
        JaxprEqn([rename(v) for v in eqn.invars],
                 [rename(v) for v in eqn.outvars],
                 eqn.primitive, eqn.params, eqn.source_info)
        for eqn in jaxpr.eqns
    ]
    outvars = [rename(v) for v in jaxpr.outvars]
    return Jaxpr(constvars, invars, outvars, new_eqns)
Exemplo n.º 13
0
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
    """Give every staged branch jaxpr the same full list of residual invars.

    One fresh placeholder var is minted per entry of `all_res_avals`. Each
    ClosedJaxpr in `jaxprs` is assumed to start its invars with its own
    residuals, one per index in the matching entry of
    `res_aval_indices_per_jaxpr`. Those residuals are merged into the shared
    placeholder list, and the jaxpr's remaining invars follow unchanged.
    Returns a tuple of ClosedJaxprs, one per input jaxpr, keeping each one's
    constvars, outvars, eqns, effects and consts.
    """
    newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
    all_res_vars = _map(newvar, all_res_avals)

    def augment_jaxpr(closed, res_indices):
        inner = closed.jaxpr
        num_res = len(res_indices)
        own_res = inner.invars[:num_res]
        rest = inner.invars[num_res:]
        # util.subvals receives (index, var) pairs, so each of this branch's
        # residual vars presumably replaces the placeholder at that index.
        merged_res = list(util.subvals(all_res_vars, zip(res_indices, own_res)))
        joined = core.Jaxpr(inner.constvars, merged_res + rest, inner.outvars,
                            inner.eqns, inner.effects)
        return core.ClosedJaxpr(joined, closed.consts)

    return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
Exemplo n.º 14
0
def _append_invars(jaxpr, avals):
  """Return a new Jaxpr like `jaxpr` with one extra trailing invar per aval.

  The added vars are freshly generated, so no equation refers to them.
  Constvars, outvars and eqns are reused from `jaxpr`.
  """
  fresh = core.gensym([jaxpr])
  extra_invars = [fresh(aval) for aval in avals]
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars + extra_invars,
                    jaxpr.outvars, jaxpr.eqns)
Exemplo n.º 15
0
 def test_var_compared_by_identity(self):
     # Vars from two independent gensyms print identically, yet they are
     # distinct objects and must not compare equal.
     lhs, rhs = [core.gensym()(core.abstract_unit) for _ in range(2)]
     assert str(lhs) == str(rhs)
     assert lhs != rhs
Exemplo n.º 16
0
 def test_var_compared_by_identity(self):
     # Same printed name from two independent gensyms, but the vars compare
     # unequal.
     lhs, rhs = [core.gensym()(core.ShapedArray((), np.dtype('int32')))
                 for _ in range(2)]
     assert str(lhs) == str(rhs)
     assert lhs != rhs
Exemplo n.º 17
0
 def test_var_tree_flatten(self):
     # With vars as dict keys and values, the expected 'bd' means the leaves
     # come out ordered by key (a before c) rather than by insertion.
     gen = core.gensym('')
     a, b, c, d = [gen() for _ in range(4)]
     flat = tree_leaves({c: d, a: b})
     assert ''.join(str(v) for v in flat) == 'bd'
Exemplo n.º 18
0
 def test_var_compared_by_identity(self):
     # Two vars from separate gensyms print alike but are different objects,
     # so they compare unequal.
     lhs, rhs = [core.gensym('')() for _ in range(2)]
     assert str(lhs) == str(rhs)
     assert lhs != rhs
Exemplo n.º 19
0
 def test_var_tree_flatten(self):
     # A dict keyed by vars flattens to its values ordered by key: the
     # expected 'bd' is b (value of a) followed by d (value of c).
     aval = core.ShapedArray((), np.dtype('int32'))
     newsym = core.gensym()
     a, b, c, d = [newsym(aval) for _ in range(4)]
     assert ''.join(map(str, tree_leaves({c: d, a: b}))) == 'bd'
Exemplo n.º 20
0
def _tracers_to_jaxpr(in_tracers, out_tracers):
    """Constructs Jaxpr given tracers for inputs and outputs.

    Copied from jax.interpreters.partial_eval.tracers_to_jaxpr but modified to
    raise a VariableError when unknown in_tracers are found, rather than the
    default AssertionError.

    Args:
      in_tracers: the tracers that were created for the function inputs
      out_tracers: the tracers that were output by the function.

    Returns:
      a triple of a `Jaxpr`, a sequence of constant values corresponding to
      the `constvars` in the returned Jaxpr, and a sequence of environment
      values. The vars for the environment values have been pre-pended to the
      Jaxpr's `invars`.

    Raises:
      VariableError: if an unknown input tracer is found
      TypeError: if a tracer carries a recipe of an unrecognized kind
    """
    newvar = jax_core.gensym(None)
    t_to_var = {}

    def getvar(t):
        # Memoized by tracer identity (id(t)): a tracer object gets one var.
        var = t_to_var.get(id(t))
        if var is None:
            var = newvar(t.pval.get_aval())
            t_to_var[id(t)] = var
        return var

    sorted_tracers = jax_util.toposort(out_tracers)
    invars = safe_map(getvar, in_tracers)
    # Identity set of the input tracers. It gives O(1) membership tests in the
    # loop below instead of an O(len(in_tracers)) scan per LambdaBinding tracer.
    in_tracer_ids = {id(in_tracer) for in_tracer in in_tracers}
    eqns = []
    env = {}
    consts = {}
    const_to_var = {}

    def getconstvar(c):
        # Equal-identity constants share a single constvar.
        var = const_to_var.get(id(c))
        if var is None:
            var = newvar(jax_core.get_aval(c))
            const_to_var[id(c)] = var
        return var

    processed_eqn_ids = set()
    for t in sorted_tracers:
        recipe = t.recipe
        if isinstance(recipe, pe.JaxprEqnRecipe):
            # Several output tracers can share one eqn recipe; emit it once.
            if recipe.eqn_id not in processed_eqn_ids:
                eqns.append(pe.recipe_to_eqn(getvar, recipe))
                processed_eqn_ids.add(recipe.eqn_id)
        elif isinstance(recipe, pe.LambdaBinding):
            # Lambda-bound tracers must be among the declared inputs. This
            # also rejects any LambdaBinding when in_tracers is empty.
            if id(t) not in in_tracer_ids:
                raise VariableError('Found unknown input tracer.')
        elif isinstance(recipe, pe.FreeVar):
            env[getvar(t)] = recipe.val
        elif isinstance(recipe, pe.ConstVar):
            v = t_to_var[id(t)] = getconstvar(recipe.val)
            consts[v] = recipe.val
        elif isinstance(recipe, jax_core.Literal):
            t_to_var[id(t)] = recipe
        elif recipe is jax_core.unit:
            t_to_var[id(t)] = jax_core.unitvar
        else:
            raise TypeError(recipe)

    env_vars, env_vals = jax_util.unzip2(env.items())
    const_vars, const_vals = jax_util.unzip2(consts.items())
    # The env_vars are pre-pended to the invars
    jaxpr = jax_core.Jaxpr(const_vars, list(it.chain(env_vars, invars)),
                           safe_map(getvar, out_tracers), eqns)
    return jaxpr, const_vals, env_vals