Ejemplo n.º 1
0
    def test_should_not_inline_calls_without_variables(self):
        """Checks `call_p` is present in both staged halves of an unzipped fn."""

        def inline_call(x):
            # One `call` before the tagged variable and one after it.
            before = call(lambda x: x + 1)(x)
            tagged = variable(before, name='x')
            return call(lambda x: x + 1)(tagged)

        init, apply = unzip_variable(inline_call)(1.)
        self.assertDictEqual(init(1.), {'x': 2.})

        init_jaxpr = trace_util.stage(init)(1.)[0]
        init_primitives = set()
        for eqn in init_jaxpr.jaxpr.eqns:
            init_primitives.add(eqn.primitive)
        self.assertIn(call_p, init_primitives)

        apply_jaxpr = trace_util.stage(apply)(init(1.))[0]
        apply_primitives = set()
        for eqn in apply_jaxpr.jaxpr.eqns:
            apply_primitives.add(eqn.primitive)
        self.assertIn(call_p, apply_primitives)
Ejemplo n.º 2
0
 def wrapped(*args, **kwargs):
     """Inverts `f` at `args`, returning input values and their ILDJ terms.

     `f` is staged on `trace_args` when any were supplied, otherwise on `args`.
     The leaves of `args` are then used as known output cells and propagated
     through the staged jaxpr.

     Returns:
       A `(vals, ildj)` pair. `vals` follows the staged input tree; `ildj` is
       a single summed value when `reduce_ildj` is set, otherwise it follows
       the input tree too. With exactly one staged argument the surrounding
       length-one container is stripped.

     Raises:
       ValueError: if any input cell does not report `top()` after
         propagation.
     """
     staging_args = args if not len(trace_args) else trace_args
     closed_jaxpr, (in_tree, _) = trace_util.stage(
         f, dynamic=False)(*staging_args, **kwargs)
     staged_leaves = tree_util.tree_flatten(staging_args)[0]
     output_leaves = tree_util.tree_flatten(args)[0]

     const_cells = safe_map(InverseAndILDJ.new, closed_jaxpr.literals)
     # Inputs start out unknown; only their abstract values are recorded.
     in_cells = [
         InverseAndILDJ.unknown(trace_util.get_shaped_aval(leaf))
         for leaf in staged_leaves
     ]
     out_cells = safe_map(InverseAndILDJ.new, output_leaves)

     env = propagate.propagate(InverseAndILDJ, ildj_registry,
                               closed_jaxpr.jaxpr, const_cells, in_cells,
                               out_cells)
     solved_cells = [env.read(var) for var in closed_jaxpr.jaxpr.invars]
     if not all(cell.top() for cell in solved_cells):
         raise ValueError('Cannot invert function.')

     flat_vals, flat_ildjs = jax_util.unzip2(
         [(cell.val, cell.ildj) for cell in solved_cells])
     vals = tree_util.tree_unflatten(in_tree, flat_vals)
     if reduce_ildj:
         ildj_ = sum(np.sum(term) for term in flat_ildjs)
     else:
         ildj_ = tree_util.tree_unflatten(in_tree, flat_ildjs)

     if len(staging_args) != 1:
         return vals, ildj_
     # Single staged argument: hand back the bare value, not a 1-element tree.
     return vals[0], (ildj_ if reduce_ildj else ildj_[0])
 def wrapped(*args, **kwargs):
     """Stages `f` and packages the result as a bound expression.

     Returns:
       A 3-tuple of: a `BoundExpression` built from the jaxpr's expressions
       and a mapping from stringified constvars to their values, a tuple of
       stringified invars, and the `(in_tree, out_tree)` pair from staging.
     """
     staged, trees = trace_util.stage(f, dynamic=False)(*args, **kwargs)
     in_tree, out_tree = trees
     jaxpr = staged.jaxpr
     expressions = jaxpr_to_expressions(jaxpr)
     const_bindings = {
         str(var): value
         for var, value in zip(jaxpr.constvars, staged.literals)
     }
     input_names = tuple(str(var) for var in jaxpr.invars)
     bound = BoundExpression(expressions, const_bindings)
     return bound, input_names, (in_tree, out_tree)
Ejemplo n.º 4
0
 def wrapped(*args, **kwargs):
   """Stages `f` and binds the staged jaxpr to the call primitive `prim`.

   The staged literals are passed ahead of the flattened positional
   arguments, and their count is reported as `num_consts`.
   """
   closed_jaxpr, (in_tree, out_tree) = trace_util.stage(
       f, dynamic=True)(*args, **kwargs)
   consts = closed_jaxpr.literals
   operands = it.chain(consts, tree_util.tree_leaves(args))
   flat_outs = prim.bind(
       *operands,
       jaxpr=closed_jaxpr.jaxpr,
       in_tree=in_tree,
       out_tree=out_tree,
       num_consts=len(consts),
       **params)
   return tree_util.tree_unflatten(out_tree, flat_outs)
Ejemplo n.º 5
0
 def forward(state: Value, *args: Value, **kwargs) -> Tuple[Value, Value]:
   """Stages `f` and evaluates its jaxpr with `handlers`, threading `state`.

   Returns:
     A pair of the output (restored to the staged output tree) and the state
     returned by `eval_jaxpr_with_state`.
   """
   staged, (_, out_tree) = trace_util.stage(f, dynamic=True)(*args, **kwargs)
   leaves = tree_util.tree_leaves(args)
   # Evaluate the staged jaxpr; `handlers` customizes the interpretation.
   flat_out, new_state = eval_jaxpr_with_state(
       staged.jaxpr, handlers, staged.literals, state, *leaves)
   result = tree_util.tree_unflatten(out_tree, flat_out)
   return result, new_state
Ejemplo n.º 6
0
    def test_should_propagate_accumulated_values_in_one_op_function(self):
        """Propagates ILDJ cells through `exp` and checks the input's ildj."""

        def f(x):
            return np.exp(x)

        closed_jaxpr, _ = trace_util.stage(f)(2.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [ILDJ.new(const) for const in closed_jaxpr.literals]
        in_cells = [unknown for _ in jaxpr.invars]
        out_cells = [ILDJ.new(2.)]
        env = propagate(ILDJ, ildj_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        input_cell = env[jaxpr.invars[0]]
        self.assertEqual(input_cell.ildj, -np.log(2.))
Ejemplo n.º 7
0
    def test_should_propagate_accumulated_values_in_chain_function(self):
        """Checks the ildj propagated to the input of `exp(x) + 2.`."""

        def f(x):
            return np.exp(x) + 2.

        closed_jaxpr, _ = trace_util.stage(f)(4.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [ILDJ.new(const) for const in closed_jaxpr.literals]
        in_cells = [Inverse.unknown(var.aval) for var in jaxpr.invars]
        out_cells = [ILDJ.new(4.)]
        env = propagate(ILDJ, ildj_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        input_cell = env[jaxpr.invars[0]]
        self.assertEqual(input_cell.ildj, -np.log(2.))
Ejemplo n.º 8
0
    def test_should_propagate_to_invars_for_one_op_function(self):
        """An output of 1. for `exp` should propagate back to an input of 0."""

        def f(x):
            return np.exp(x)

        closed_jaxpr, _ = trace_util.stage(f)(1.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [unknown for _ in jaxpr.invars]
        out_cells = [Inverse.new(1.)]
        env = propagate(Inverse, inverse_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        input_cell = env[jaxpr.invars[0]]
        self.assertEqual(input_cell.val, 0.)
Ejemplo n.º 9
0
    def test_should_propagate_forward_and_backward(self):
        """Recovers both inputs of a two-output function from its outputs."""

        def f(x, y):
            return x + 1., np.exp(x + 1.) + y

        closed_jaxpr, _ = trace_util.stage(f)(0., 2.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [Inverse.unknown(var.aval) for var in jaxpr.invars]
        out_cells = [Inverse.new(out) for out in (0., 2.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        recovered = []
        for invar in jaxpr.invars:
            recovered.append(env[invar].val)
        onp.testing.assert_allclose(recovered, (-1., 1.))
Ejemplo n.º 10
0
    def test_propagation_should_not_reach_invars(self):
        """An input the function discards should still read as `bottom()`."""

        def f(x):
            del x
            return 2.

        closed_jaxpr, _ = trace_util.stage(f)(1.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [Inverse.unknown(var.aval) for var in jaxpr.invars]
        out_cells = [Inverse.new(1.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        input_cell = env.read(jaxpr.invars[0])
        self.assertTrue(input_cell.bottom())
Ejemplo n.º 11
0
    def test_propagate_through_jit(self):
        """Propagation should pass through a `jax.jit`-wrapped `exp`."""

        def f(x):
            return jax.jit(np.exp)(x) + 2.

        closed_jaxpr, _ = trace_util.stage(f)(3.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [Inverse.unknown(var.aval) for var in jaxpr.invars]
        out_cells = [Inverse.new(3.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        input_cell = env[jaxpr.invars[0]]
        self.assertEqual(input_cell.val, 0.)
Ejemplo n.º 12
0
    def test_should_propagate_to_invars_for_chain_function(self):
        """An output of 3. for `2. + exp(x)` should give an input of 0."""

        def f(x):
            return 2. + np.exp(x)

        closed_jaxpr, _ = trace_util.stage(f)(3.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [Inverse.unknown(var.aval) for var in jaxpr.invars]
        out_cells = [Inverse.new(3.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        input_cell = env[jaxpr.invars[0]]
        self.assertEqual(input_cell.val, 0.)
Ejemplo n.º 13
0
    def test_correct_inverse_for_identity_function(self):
        """The identity function should propagate its output to its input."""

        def f(x):
            return x

        closed_jaxpr, _ = trace_util.stage(f)(1.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [Inverse.unknown(var.aval) for var in jaxpr.invars]
        out_cells = [Inverse.new(1.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        input_cell = env[jaxpr.invars[0]]
        self.assertEqual(input_cell.val, 1.)
Ejemplo n.º 14
0
    def test_propagation_should_not_reach_invars(self):
        """An input the function discards should be absent from the env."""

        def f(x):
            del x
            return 2.

        closed_jaxpr, _ = trace_util.stage(f)(1.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [unknown for _ in jaxpr.invars]
        out_cells = [Inverse.new(1.)]
        env = propagate(Inverse, inverse_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        self.assertNotIn(jaxpr.invars[0], env)
Ejemplo n.º 15
0
    def test_propagate_through_jit(self):
        """Propagation through `jax.jit` should record exactly one sub-env."""

        def f(x):
            return jax.jit(np.exp)(x) + 2.

        closed_jaxpr, _ = trace_util.stage(f)(3.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [Inverse.new(const) for const in closed_jaxpr.literals]
        in_cells = [unknown for _ in jaxpr.invars]
        out_cells = [Inverse.new(3.)]
        env = propagate(Inverse, inverse_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        input_cell = env[jaxpr.invars[0]]
        self.assertEqual(input_cell.val, 0.)
        self.assertLen(env.subenvs, 1)
Ejemplo n.º 16
0
 def wrapped(sample, *args, **kwargs):
   """Evaluates `log_prob_jaxpr` for `sample` under `f` staged with a dummy key.

   `f` is staged with a fixed `PRNGKey(0)` as its first argument. The key's
   input cell is left unknown, the remaining arguments and the leaves of
   `sample` become known cells.
   """
   # The key only drives staging; its cell carries just the abstract value.
   placeholder_key = random.PRNGKey(0)
   closed_jaxpr, _ = trace_util.stage(
       f, dynamic=False)(placeholder_key, *args, **kwargs)
   sample_leaves = tree_util.tree_flatten(sample)[0]
   arg_leaves = tree_util.tree_flatten(args)[0]

   const_cells = [InverseAndILDJ.new(lit) for lit in closed_jaxpr.literals]
   key_cell = InverseAndILDJ.unknown(
       trace_util.get_shaped_aval(placeholder_key))
   in_cells = [key_cell, *(InverseAndILDJ.new(leaf) for leaf in arg_leaves)]
   out_cells = [InverseAndILDJ.new(leaf) for leaf in sample_leaves]
   return log_prob_jaxpr(closed_jaxpr.jaxpr, const_cells, in_cells, out_cells)
Ejemplo n.º 17
0
    def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
        """Initializes variables for `f` and stages it into a `FunctionModule`.

        When `f` accepts `init_keyword`, `f` is unzipped on the
        `module.VARIABLE` tag into an init half and an apply half; otherwise
        the variables are an empty dict and `f` is called unchanged.

        Returns:
          With `name` unset, a `FunctionModule` whose variables are each
          tagged via `module.variable`. With `name` set, the whole module
          wrapped by `module.variable` under that name.
        """
        if kwargs_util.check_in_kwargs(f, init_keyword):
            with unzip.new_custom_rules(custom_unzip_rules):

                def fun(init_key, *args, **kwargs):
                    # Hand the key to `f` under the configured keyword.
                    return f(*args, **{**kwargs, init_keyword: init_key})

                unzipped = unzip.unzip(fun, tag=module.VARIABLE)
                init_f, apply_f = unzipped(init_key, *args, **kwargs)
            harvested = harvest.harvest(apply_f, tag=module.ASSIGN)
            cau_f = functools.partial(harvested, {})
        else:

            def init_f(init_key, *args, **kwargs):
                del init_key, args, kwargs
                return {}

            def cau_f(variables, *args, **kwargs):
                return f(*args, **kwargs), variables

        if name is not None:
            init_f = harvest.nest(init_f, scope=name)
            cau_f = harvest.nest(cau_f, scope=name)

        variables = init_f(init_key)
        cau_jaxpr, trees = trace_util.stage(cau_f, dynamic=True)(
            variables, *args, **kwargs)
        in_tree, out_tree = trees

        if name is not None:
            mod = FunctionModule(variables,
                                 cau_jaxpr,
                                 in_tree,
                                 out_tree,
                                 name=name)
            return module.variable(mod, name=name, key=init_key)

        tagged = {}
        for k, val in variables.items():
            tagged[k] = module.variable(val, name=k, key=init_key)
        return FunctionModule(tagged,
                              cau_jaxpr,
                              in_tree,
                              out_tree,
                              name=name)
Ejemplo n.º 18
0
    def test_propagate_should_accumulate_state(self):
        """The reducer's running sum should come back as the final state."""

        def f(x):
            return np.exp(x) + 2.

        def add_counts(env, eqn, count, next):  # pylint: disable=redefined-builtin
            return count + next

        closed_jaxpr, _ = trace_util.stage(f)(4.)
        jaxpr = closed_jaxpr.jaxpr
        const_cells = [ILDJ.new(const) for const in closed_jaxpr.literals]
        in_cells = [Inverse.unknown(var.aval) for var in jaxpr.invars]
        out_cells = [ILDJ.new(4.)]
        _, final_state = propagate(ILDJ,
                                   ildj_rules,
                                   jaxpr,
                                   const_cells,
                                   in_cells,
                                   out_cells,
                                   reducer=add_counts,
                                   initial_state=0)
        self.assertEqual(final_state, 2)
Ejemplo n.º 19
0
 def wrapped(sample, *args, **kwargs):
     """Computes a log-probability for `sample` under `f`.

     `f` is staged with a fixed `PRNGKey(0)` as its first argument. That
     input is left `unknown`; the other arguments and the leaves of `sample`
     become known cells. The propagated environment is then passed to
     `_accumulate_log_probs`.
     """
     closed_jaxpr, _ = trace_util.stage(f)(random.PRNGKey(0), *args, **kwargs)
     sample_leaves = tree_util.tree_flatten(sample)[0]
     arg_leaves = tree_util.tree_flatten(args)[0]

     const_cells = [InverseAndILDJ.new(lit) for lit in closed_jaxpr.literals]
     # The leading cell corresponds to the key argument.
     in_cells = [unknown, *(InverseAndILDJ.new(leaf) for leaf in arg_leaves)]
     out_cells = [InverseAndILDJ.new(leaf) for leaf in sample_leaves]

     # Propagate with `log_prob_rules` rather than the plain inverse rules.
     env = propagate.propagate(InverseAndILDJ, log_prob_rules,
                               closed_jaxpr.jaxpr, const_cells, in_cells,
                               out_cells)
     # Collect the log-prob contributions from the resulting environment.
     return _accumulate_log_probs(env)
def random_variable_batching_rule(args, dims, *, num_consts, batch_ndims,
                                  jaxpr, **params):
  """Batching (vmap) rule for the `random_variable` primitive.

  The first `num_consts` entries of `args` are closed over; the primitive's
  impl is vmapped over the remaining arguments, re-staged, and re-bound with
  `batch_ndims` incremented by one.

  Returns:
    The bound outputs and a tuple of zeros (one per output) as their batch
    dimensions.
  """
  closed_consts = args[:num_consts]
  batched_args = args[num_consts:]
  batched_dims = dims[num_consts:]

  def _run(*args):
    operands = it.chain(closed_consts, args)
    return random_variable_p.impl(
        *operands,
        num_consts=len(closed_consts),
        jaxpr=jaxpr,
        batch_ndims=batch_ndims,
        **params)

  vmapped = jax.vmap(_run, in_axes=batched_dims, out_axes=0)
  staged, _ = trace_util.stage(vmapped, dynamic=True)(*batched_args)
  staged_consts = staged.literals
  outputs = random_variable_p.bind(
      *it.chain(staged_consts, batched_args),
      num_consts=len(staged_consts),
      jaxpr=staged.jaxpr,
      batch_ndims=batch_ndims + 1,
      **params)
  # `out_axes=0` above puts every output's batch dimension at axis 0.
  out_dims = (0,) * len(outputs)
  return outputs, out_dims
Ejemplo n.º 21
0
 def wrapped(*args, **kwargs):
   """Inverts `f` at `args`, returning input values and per-input ILDJs.

   `f` is staged on `trace_args` when any were supplied, otherwise on `args`.
   The leaves of `args` are used as known output cells and propagated through
   the staged jaxpr.

   Returns:
     A `(vals, ildjs)` pair, each following the staged input tree. When
     exactly one `trace_args` entry was given, the first element of each is
     returned instead.

   Raises:
     ValueError: if any input cell is still unknown after propagation.
   """
   staging_args = args if not len(trace_args) else trace_args
   closed_jaxpr, (in_tree, _) = trace_util.stage(f)(*staging_args, **kwargs)
   staged_leaves = tree_util.tree_flatten(staging_args)[0]
   output_leaves = tree_util.tree_flatten(args)[0]

   const_cells = safe_map(InverseAndILDJ.new, closed_jaxpr.literals)
   in_cells = [unknown for _ in staged_leaves]
   out_cells = safe_map(InverseAndILDJ.new, output_leaves)

   env = propagate.propagate(InverseAndILDJ, ildj_registry,
                             closed_jaxpr.jaxpr, const_cells, in_cells,
                             out_cells)
   solved_cells = [env.read(var) for var in closed_jaxpr.jaxpr.invars]
   for cell in solved_cells:
     if cell.is_unknown():
       raise ValueError('Cannot invert function.')

   flat_vals, flat_ildjs = jax_util.unzip2(
       [(cell.val, cell.ildj) for cell in solved_cells])
   vals = tree_util.tree_unflatten(in_tree, flat_vals)
   ildjs = tree_util.tree_unflatten(in_tree, flat_ildjs)
   # NOTE(review): this tests `trace_args`, not the arguments actually used
   # for staging, so a single `args` entry with no `trace_args` is not
   # unwrapped -- confirm whether that is intended.
   if len(trace_args) != 1:
     return vals, ildjs
   return vals[0], ildjs[0]
Ejemplo n.º 22
0
  def wrapped(init_key: Key, *args, **kwargs) -> FunctionModule:
    """Initializes variables for `f` and stages it into a `FunctionModule`.

    When `f` accepts the keyword named by `init_keyword`, variables are
    collected by reaping the `module.VARIABLE` tag, and the apply function
    plants them back in. Otherwise the variables are an empty dict and `f`
    is called unchanged.

    Args:
      init_key: key passed to `f` under `init_keyword` (when `f` accepts it)
        and used when tagging the resulting variables.
      *args: positional arguments forwarded to `f`.
      **kwargs: keyword arguments forwarded to `f`.

    Returns:
      With `name` unset, a `FunctionModule` whose variables are each tagged
      via `module.variable`. With `name` set, the module wrapped by
      `module.variable` if it has any variables, else the bare module.
    """
    has_init_key = kwargs_util.check_in_kwargs(f, init_keyword)
    if not has_init_key:

      def init_f(init_key, *args, **kwargs):
        del init_key, args, kwargs
        return {}

      def cau_f(variables, *args, **kwargs):
        return f(*args, **kwargs), variables
    else:

      def f_(init_key, *args, **kwargs):
        # Pass the key under `init_keyword`, the name `f` was checked for
        # above, instead of a hard-coded `init_key=`; the two differ when a
        # custom keyword is configured. An entry already present in
        # `kwargs` under that keyword is overridden.
        return f(*args, **{**kwargs, init_keyword: init_key})

      def init_f(init_key, *args, **kwargs):
        return harvest.reap(
            f_, tag=module.VARIABLE, exclusive=True)(init_key, *args, **kwargs)

      def apply_f(variables, *args, **kwargs):
        # A fixed key is supplied here; the variables are planted rather
        # than drawn from the init key.
        return harvest.plant(
            f_, tag=module.VARIABLE)(variables, random.PRNGKey(0), *args,
                                     **kwargs)

      cau_f = functools.partial(harvest.harvest(apply_f, tag=module.ASSIGN), {})
    variables = init_f(init_key, *args, **kwargs)
    cau_jaxpr, (in_tree, out_tree) = trace_util.stage(
        cau_f, dynamic=True)(variables, *args, **kwargs)
    if name is None:
      variables = {
          k: module.variable(val, name=k, key=init_key)
          for k, val in variables.items()
      }
      return FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    mod = FunctionModule(variables, cau_jaxpr, in_tree, out_tree, name=name)
    if variables:
      return module.variable(mod, name=name, key=init_key)
    return mod