예제 #1
0
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
    """Runs InverseAndILDJ propagation over a sub-jaxpr.

    The first ``num_consts`` entries of ``incells`` are treated as constant
    cells; the remainder are the cells for ``jaxpr.invars``.

    Returns:
      A ``(incells, outcells, state)`` triple. ``incells`` is the constant
      cells followed by the cells read back for ``jaxpr.invars``,
      ``outcells`` holds the cells read back for ``jaxpr.outvars``, and
      ``state`` is the propagation state returned by ``propagate.propagate``.
    """
    consts, rest = jax_util.split_list(incells, [num_consts])
    env, state = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr,
                                     consts, rest, outcells)  # pytype: disable=wrong-arg-types
    read = env.read
    refined_in = list(map(read, jaxpr.invars))
    refined_out = list(map(read, jaxpr.outvars))
    return consts + refined_in, refined_out, state
예제 #2
0
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
    """Runs InverseAndILDJ propagation over a sub-jaxpr (no state returned).

    The first ``num_consts`` entries of ``incells`` are the constant cells;
    the rest correspond to ``jaxpr.invars``.

    Returns:
      ``(const_cells + invar_cells, outvar_cells, None)`` where the invar and
      outvar cells are read from the environment produced by propagation.
    """
    consts, rest = jax_util.split_list(incells, [num_consts])
    env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr,
                              consts, rest, outcells)
    solved_in, solved_out = ([env.read(v) for v in variables]
                             for variables in (jaxpr.invars, jaxpr.outvars))
    return consts + solved_in, solved_out, None
예제 #3
0
def log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells):
    """Runs log_prob propagation on a Jaxpr.

    Re-uses the InverseAndILDJ propagation with ``log_prob_rules`` and folds a
    running log-probability through the propagation state via a reducer.

    Args:
      jaxpr: the jaxpr to propagate through.
      constcells: cells for the jaxpr's constants.
      flat_incells: cells for the jaxpr's input variables.
      flat_outcells: cells for the jaxpr's output variables.

    Returns:
      The accumulated log-probability (the propagation state, starting at 0.).

    Raises:
      ValueError: if a log_prob could not be computed for the function.
    """
    def reducer(env, eqn, curr_log_prob, new_log_prob):
        if (isinstance(curr_log_prob, FailedLogProb)
                or isinstance(new_log_prob, FailedLogProb)):
            # A log_prob computation already failed (here or in a nested
            # propagation), so the whole propagation has failed; keep
            # propagating the failure marker.
            return failed_log_prob
        if eqn.primitive in log_prob_registry and new_log_prob is None:
            # The primitive has a registered log_prob rule but it produced no
            # value for this equation.
            return failed_log_prob
        if new_log_prob is not None:
            # Add the summed ILDJ of every output cell for which `top()`
            # holds.
            cells = [env.read(var) for var in eqn.outvars]
            ildjs = sum(cell.ildj.sum() for cell in cells if cell.top())
            return curr_log_prob + new_log_prob + ildjs
        return curr_log_prob

    # Re-use the InverseAndILDJ propagation but silently fail instead of
    # erroring when we hit a primitive we can't invert. We accumulate the log
    # probability values using the propagator state.
    _, final_log_prob = propagate.propagate(InverseAndILDJ,
                                            log_prob_rules,
                                            jaxpr,
                                            constcells,
                                            flat_incells,
                                            flat_outcells,
                                            reducer=reducer,
                                            initial_state=0.)
    # Match the reducer's own failure test (isinstance) while still accepting
    # the shared `failed_log_prob` sentinel.
    if (final_log_prob is failed_log_prob
            or isinstance(final_log_prob, FailedLogProb)):
        raise ValueError('Cannot compute log_prob of function.')
    return final_log_prob
예제 #4
0
 def wrapped(*args, **kwargs):
     """Function wrapper that takes in inverse arguments.

     Traces `f` (at `trace_args` if any were supplied, otherwise at the call's
     own `args`), seeds the outputs with `args`, and propagates
     `InverseAndILDJ` cells backwards to recover the inputs.

     Returns:
       A `(vals, ildj_)` pair. `vals` is unflattened with the traced
       `in_tree`. `ildj_` is either a single summed value (`reduce_ildj`) or
       a pytree matching `vals`. Both are unwrapped when exactly one forward
       argument was traced.

     Raises:
       ValueError: if any input cell is not fully known after propagation.
     """
     if len(trace_args):
         forward_args = trace_args
     else:
         forward_args = args
     staged = trace_util.stage(f, dynamic=False)
     jaxpr, (in_tree, _) = staged(*forward_args, **kwargs)
     flat_forward_args = tree_util.tree_flatten(forward_args)[0]
     flat_args = tree_util.tree_flatten(args)[0]

     const_cells = safe_map(InverseAndILDJ.new, jaxpr.literals)
     forward_avals = [trace_util.get_shaped_aval(a) for a in flat_forward_args]
     in_cells = [InverseAndILDJ.unknown(aval) for aval in forward_avals]
     out_cells = safe_map(InverseAndILDJ.new, flat_args)

     env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
                               const_cells, in_cells, out_cells)
     solved = [env.read(var) for var in jaxpr.jaxpr.invars]
     if not all(cell.top() for cell in solved):
         raise ValueError('Cannot invert function.')

     flat_vals, flat_ildjs = jax_util.unzip2(
         [(cell.val, cell.ildj) for cell in solved])
     vals = tree_util.tree_unflatten(in_tree, flat_vals)
     if not reduce_ildj:
         ildj_ = tree_util.tree_unflatten(in_tree, flat_ildjs)
     else:
         ildj_ = sum(np.sum(term) for term in flat_ildjs)

     if len(forward_args) == 1:
         vals = vals[0]
         if not reduce_ildj:
             ildj_ = ildj_[0]
     return vals, ildj_
예제 #5
0
    def test_should_propagate_accumulated_values_in_chain_function(self):
        def f(x):
            return np.exp(x) + 2.

        staged, _ = trace_util.stage(f)(4.)
        jaxpr = staged.jaxpr
        const_cells = [ILDJ.new(c) for c in staged.literals]
        in_cells = [Inverse.unknown(v.aval) for v in jaxpr.invars]
        out_cells = [ILDJ.new(4.)]
        env = propagate(ILDJ, ildj_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        # Output 4. means exp(x) == 2., so the expected ILDJ is -log(2.).
        self.assertEqual(env[jaxpr.invars[0]].ildj, -np.log(2.))
예제 #6
0
    def test_should_propagate_accumulated_values_in_one_op_function(self):
        def f(x):
            return np.exp(x)

        staged, _ = trace_util.stage(f)(2.)
        jaxpr = staged.jaxpr
        env = propagate(
            ILDJ,
            ildj_rules,
            jaxpr,
            [ILDJ.new(c) for c in staged.literals],
            [unknown for _ in jaxpr.invars],
            [ILDJ.new(2.)],
        )
        # Output 2. of exp(x): the expected ILDJ is -log(2.).
        self.assertEqual(env[jaxpr.invars[0]].ildj, -np.log(2.))
예제 #7
0
    def test_should_propagate_to_invars_for_one_op_function(self):
        def f(x):
            return np.exp(x)

        staged, _ = trace_util.stage(f)(1.)
        jaxpr = staged.jaxpr
        const_cells = [Inverse.new(c) for c in staged.literals]
        in_cells = [unknown for _ in jaxpr.invars]
        out_cells = [Inverse.new(1.)]
        env = propagate(Inverse, inverse_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        # exp(x) == 1. only at x == 0.
        self.assertEqual(env[jaxpr.invars[0]].val, 0.)
예제 #8
0
    def test_should_propagate_forward_and_backward(self):
        def f(x, y):
            return x + 1., np.exp(x + 1.) + y

        staged, _ = trace_util.stage(f)(0., 2.)
        jaxpr = staged.jaxpr
        const_cells = [Inverse.new(c) for c in staged.literals]
        in_cells = [Inverse.unknown(v.aval) for v in jaxpr.invars]
        out_cells = [Inverse.new(o) for o in (0., 2.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        # Outputs (0., 2.) give x + 1. == 0. and exp(0.) + y == 2.,
        # i.e. x == -1. and y == 1.
        recovered = [env[v].val for v in jaxpr.invars]
        onp.testing.assert_allclose(recovered, (-1., 1.))
예제 #9
0
    def test_propagation_should_not_reach_invars(self):
        def f(x):
            del x
            return 2.

        staged, _ = trace_util.stage(f)(1.)
        jaxpr = staged.jaxpr
        env, _ = propagate(
            Inverse,
            inverse_rules,
            jaxpr,
            [Inverse.new(c) for c in staged.literals],
            [Inverse.unknown(v.aval) for v in jaxpr.invars],
            [Inverse.new(1.)],
        )
        # `f` ignores its input, so nothing should flow back to the invar.
        input_cell = env.read(jaxpr.invars[0])
        self.assertTrue(input_cell.bottom())
예제 #10
0
    def test_propagate_through_jit(self):
        def f(x):
            return jax.jit(np.exp)(x) + 2.

        staged, _ = trace_util.stage(f)(3.)
        jaxpr = staged.jaxpr
        const_cells = [Inverse.new(c) for c in staged.literals]
        in_cells = [Inverse.unknown(v.aval) for v in jaxpr.invars]
        out_cells = [Inverse.new(3.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        # exp(x) + 2. == 3. only at x == 0.
        self.assertEqual(env[jaxpr.invars[0]].val, 0.)
예제 #11
0
    def test_should_propagate_to_invars_for_chain_function(self):
        def f(x):
            return 2. + np.exp(x)

        staged, _ = trace_util.stage(f)(3.)
        jaxpr = staged.jaxpr
        env, _ = propagate(
            Inverse,
            inverse_rules,
            jaxpr,
            [Inverse.new(c) for c in staged.literals],
            [Inverse.unknown(v.aval) for v in jaxpr.invars],
            [Inverse.new(3.)],
        )
        # 2. + exp(x) == 3. only at x == 0.
        recovered = env[jaxpr.invars[0]]
        self.assertEqual(recovered.val, 0.)
예제 #12
0
    def test_correct_inverse_for_identity_function(self):
        def f(x):
            return x

        staged, _ = trace_util.stage(f)(1.)
        jaxpr = staged.jaxpr
        const_cells = [Inverse.new(c) for c in staged.literals]
        in_cells = [Inverse.unknown(v.aval) for v in jaxpr.invars]
        out_cells = [Inverse.new(1.)]
        env, _ = propagate(Inverse, inverse_rules, jaxpr, const_cells,
                           in_cells, out_cells)
        # The identity's inverse returns the output value unchanged.
        self.assertEqual(env[jaxpr.invars[0]].val, 1.)
예제 #13
0
    def test_propagation_should_not_reach_invars(self):
        def f(x):
            del x
            return 2.

        staged, _ = trace_util.stage(f)(1.)
        jaxpr = staged.jaxpr
        const_cells = [Inverse.new(c) for c in staged.literals]
        in_cells = [unknown for _ in jaxpr.invars]
        out_cells = [Inverse.new(1.)]
        env = propagate(Inverse, inverse_rules, jaxpr, const_cells, in_cells,
                        out_cells)
        # `f` ignores its input, so the invar should never enter the env.
        self.assertNotIn(jaxpr.invars[0], env)
예제 #14
0
    def test_propagate_through_jit(self):
        def f(x):
            return jax.jit(np.exp)(x) + 2.

        staged, _ = trace_util.stage(f)(3.)
        jaxpr = staged.jaxpr
        env = propagate(
            Inverse,
            inverse_rules,
            jaxpr,
            [Inverse.new(c) for c in staged.literals],
            [unknown for _ in jaxpr.invars],
            [Inverse.new(3.)],
        )
        # exp(x) + 2. == 3. only at x == 0.
        self.assertEqual(env[jaxpr.invars[0]].val, 0.)
        # Expect exactly one sub-environment (presumably from the jit call).
        self.assertLen(env.subenvs, 1)
예제 #15
0
    def test_propagate_should_accumulate_state(self):
        def f(x):
            return np.exp(x) + 2.

        def add_step(env, eqn, count, increment):
            # Fold each reducer contribution into the running state.
            return count + increment

        staged, _ = trace_util.stage(f)(4.)
        jaxpr = staged.jaxpr
        const_cells = [ILDJ.new(c) for c in staged.literals]
        in_cells = [Inverse.unknown(v.aval) for v in jaxpr.invars]
        out_cells = [ILDJ.new(4.)]
        _, state = propagate(ILDJ, ildj_rules, jaxpr, const_cells, in_cells,
                             out_cells, reducer=add_step, initial_state=0)
        self.assertEqual(state, 2)
예제 #16
0
 def wrapped(sample, *args, **kwargs):
     """Function wrapper that takes in log_prob arguments.

     Traces `f` with a fixed `random.PRNGKey(0)` as its first argument. It then
     seeds the outputs with `sample` and the non-key inputs with `args`, and
     propagates `InverseAndILDJ` cells using `log_prob_rules`.

     Returns:
       The value produced by `_accumulate_log_probs` over the resulting
       environment.
     """
     # Trace the function using a random seed.
     tracer = trace_util.stage(f)
     jaxpr, _ = tracer(random.PRNGKey(0), *args, **kwargs)
     flat_outargs = tree_util.tree_flatten(sample)[0]
     flat_inargs = tree_util.tree_flatten(args)[0]

     const_cells = list(map(InverseAndILDJ.new, jaxpr.literals))
     # The key input is left unknown; the remaining inputs are known values.
     in_cells = [unknown]
     in_cells.extend(map(InverseAndILDJ.new, flat_inargs))
     out_cells = list(map(InverseAndILDJ.new, flat_outargs))

     # Re-use the InverseAndILDJ propagation but silently fail instead of
     # erroring when we hit a primitive we can't invert.
     env = propagate.propagate(InverseAndILDJ, log_prob_rules, jaxpr.jaxpr,
                               const_cells, in_cells, out_cells)
     # Traverse the resulting environment, looking for primitives that have
     # registered log_probs.
     return _accumulate_log_probs(env)
예제 #17
0
파일: inverse.py 프로젝트: ic/probability
 def wrapped(*args, **kwargs):
   """Function wrapper that takes in inverse arguments.

   Traces `f` at `trace_args` if any were supplied, otherwise at the call's
   own `args`. It then seeds the jaxpr outputs with `args` and propagates
   `InverseAndILDJ` cells backwards to recover the inputs.

   Returns:
     A `(vals, ildjs)` pair, each unflattened with the `in_tree` returned by
     tracing. When exactly one forward argument was traced, both are
     unwrapped to that single element.

   Raises:
     ValueError: if any input cell is still unknown after propagation.
   """
   forward_args = trace_args if len(trace_args) else args
   jaxpr, (in_tree, _) = trace_util.stage(f)(*forward_args, **kwargs)
   flat_forward_args, _ = tree_util.tree_flatten(forward_args)
   flat_args, _ = tree_util.tree_flatten(args)
   flat_constcells = safe_map(InverseAndILDJ.new, jaxpr.literals)
   flat_incells = [unknown] * len(flat_forward_args)
   flat_outcells = safe_map(InverseAndILDJ.new, flat_args)
   env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
                             flat_constcells, flat_incells, flat_outcells)
   flat_incells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
   if any(flat_incell.is_unknown() for flat_incell in flat_incells):
     raise ValueError('Cannot invert function.')
   flat_vals, flat_ildjs = jax_util.unzip2([
       (flat_incell.val, flat_incell.ildj) for flat_incell in flat_incells
   ])
   vals = tree_util.tree_unflatten(in_tree, flat_vals)
   ildjs = tree_util.tree_unflatten(in_tree, flat_ildjs)
   # Unwrap based on the arguments that were actually traced. `trace_args` is
   # empty when the call's own `args` were traced, so testing its length
   # would leave single-argument results wrapped in that case.
   if len(forward_args) == 1:
     vals, ildjs = vals[0], ildjs[0]
   return vals, ildjs
예제 #18
0
파일: core.py 프로젝트: seanmb/probability
def initial_ildj(incells, outcells, *, jaxpr, **_):
  """Propagates InverseAndILDJ cells through `jaxpr` with no constant cells.

  Returns:
    A `(incells, outcells, None)` triple holding the cells read back from the
    propagation environment for `jaxpr.invars` and `jaxpr.outvars`; no
    propagation state is returned.
  """
  env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, [], incells,
                            outcells)

  def _cells_for(variables):
    # Look up the propagated cell for each variable, in order.
    return [env.read(v) for v in variables]

  return _cells_for(jaxpr.invars), _cells_for(jaxpr.outvars), None