def testVMapShardingConstraint(self):
  """vmap over a pjit'd sharding constraint yields an unsharded leading axis."""
  constrained = pjit(
      lambda arr: with_sharding_constraint(arr, P('x')),
      in_axis_resources=P(),
      out_axis_resources=P('x'))
  operand = jnp.arange(5 * 4).reshape((5, 4))
  closed = jax.make_jaxpr(jax.vmap(constrained))(operand)

  # Exactly one top-level equation (the pjit) wrapping exactly one
  # sharding-constraint equation.
  (outer_eqn,) = closed.eqns
  (inner_eqn,) = outer_eqn.params['jaxpr'].eqns

  resources = inner_eqn.params['axis_resources']
  self.assertEqual(resources.partitions, ((), ('x',)))
  self.assertEqual(resources.sync, SpecSync.DIM_PERMUTE)
def test_check_jaxpr_cond_invalid(self):
  """check_jaxpr rejects a cond whose branches disagree on input arity."""
  def switch_fn(x):
    return lax.switch(0, [jnp.sin, jnp.cos], x)

  jaxpr = make_jaxpr(switch_fn)(1.).jaxpr
  cond_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cond')
  # Corrupt branch 0 so it takes no inputs while branch 1 still takes one.
  cond_eqn.params['branches'][0].jaxpr.invars = ()
  with self.assertRaisesRegex(
      core.JaxprTypeError, 'cond branch 0 takes 0 inputs, branch 1 takes 1'):
    core.check_jaxpr(jaxpr)
def test_nested_name_stack(self):
  """Nested extend_name_stack scopes are joined with '/' on every equation."""
  @extend_name_stack('foo')
  def f(x):
    with extend_name_stack('bar'):
      return x + 1

  traced_eqns = jax.make_jaxpr(f)(2).jaxpr.eqns
  for traced_eqn in traced_eqns:
    stack_text = str(traced_eqn.source_info.name_stack)
    self.assertEqual(stack_text, 'foo/bar')
def testEvalJaxpr(self):
  """Evaluating a pjit jaxpr via jaxpr_as_fun matches the direct computation."""
  x = jnp.arange(4)
  y = jnp.arange(5)

  def body(a, b):
    return a.sum() + jnp.sin(b)

  sharded = pjit(body,
                 in_axis_resources=(P('x'), P('y')),
                 out_axis_resources=P('y'))
  as_fun = jax.core.jaxpr_as_fun(jax.make_jaxpr(sharded)(x, y))
  (result,) = as_fun(x, y)
  self.assertAllClose(result, x.sum() + jnp.sin(y))
def testLowerToNothing(self):
  """A jitted identity on an Empty value type-checks, and a compiled fn runs."""
  empty = Empty(AbstractEmpty())
  core.check_jaxpr(make_jaxpr(jit(lambda e: e))(empty).jaxpr)

  # A unit cannot be returned here because CompileAndCheck assumes array
  # output, so the compiled function ignores its argument and returns None.
  def testfunc(e):
    return None

  def args_maker():
    return [empty]

  self._CompileAndCheck(testfunc, args_maker)
def test_vmap_should_transform_name_stack(self):
  """vmap rewrites an inner scope as 'vmap(<scope>)' in the name stack."""
  def f(x):
    with extend_name_stack('foo'):
      return x + 1

  f = jax.vmap(f)
  first_eqn = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr.eqns[0]
  self.assertEqual(str(first_eqn.source_info.name_stack), 'vmap(foo)')
def _invertible_jaxpr_and_constants(fun): """Returns a transformation from function invocation to invertible jaxpr.""" jaxpr_maker = jax.make_jaxpr(fun) @jax.api.wraps(fun) # pylint: disable=no-value-for-parameter def jaxpr_const_maker(*args, **kwargs): typed_jaxpr = jaxpr_maker(*args, **kwargs) return typed_jaxpr.jaxpr, typed_jaxpr.literals return jaxpr_const_maker
def _check_numerical_stability(fn):
  """Logs a warning if numerically unstable operations are requested.

  Traces `fn` on the scalar 0.0. For every top-level equation whose primitive
  is a key of `_potentially_unstable_primitives`, it logs a warning naming that
  primitive and the alternative stored as the mapping's value.

  NOTE(review): only `jaxpr.eqns` is scanned. Primitives nested inside
  sub-jaxprs (e.g. under jit/cond/scan) are not inspected.

  Args:
    fn: Function to trace; it must accept a single scalar float argument.
  """
  jaxpr = jax.make_jaxpr(fn)(0.0).jaxpr
  for eqn in jaxpr.eqns:
    if eqn.primitive in _potentially_unstable_primitives:
      # `warning` instead of the deprecated `warn` alias.
      logging.warning("[Distrax]: the '%s' primitive can exhibit unstable "
                      "numerical behavior under certain circumstances. Consider "
                      "using the %s bijector instead if possible.",
                      eqn.primitive,
                      _potentially_unstable_primitives[eqn.primitive])
def test_jaxpr_undefined_eqn_invar(self):
  """check_jaxpr reports an equation input that nothing binds."""
  def g(x):
    return jnp.sin(x) + jnp.cos(x)

  jaxpr = make_jaxpr(g)(1.).jaxpr
  cos_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cos')
  # Swap cos's operand for a fresh, never-bound variable of the same type.
  fresh_var = core.gensym([jaxpr], suffix='_test')(cos_eqn.invars[0].aval)
  cos_eqn.invars[0] = fresh_var
  with self.assertRaisesRegex(
      core.JaxprTypeError, r"Variable '.+_test' not defined\n\nin equation:"):
    core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_cond(self):
  """An unused cond output is bound to core.dropvar and the jaxpr checks."""
  def f(x):
    _, y = lax.cond(x < 0.,
                    lambda x: (jnp.sin(x), x + 1.),
                    lambda x: (jnp.cos(x), x + 2.),
                    x)
    return y

  jaxpr = make_jaxpr(f)(1.).jaxpr
  # The last equation is the cond, and its first output is discarded by `f`.
  # A TestCase assertion (not a bare `assert`) still runs under `python -O`
  # and reports the actual value on failure.
  self.assertIs(jaxpr.eqns[-1].outvars[0], core.dropvar)
  core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_loop(self):
  """An unused while_loop output is bound to core.dropvar and the jaxpr checks."""
  def f(x):
    _, y = lax.while_loop(lambda s: s[0] < 0.,
                          lambda s: (jnp.sin(s[0]), jnp.cos(s[1])),
                          (x, x))
    return y + 1.

  jaxpr = make_jaxpr(f)(1.).jaxpr
  # The first equation is the while loop, and its first output is discarded.
  # A TestCase assertion survives `python -O` and reports the actual value.
  self.assertIs(jaxpr.eqns[0].outvars[0], core.dropvar)
  core.check_jaxpr(jaxpr)
def test_xmap_inherits_effects(self):
  """The jaxpr of an xmap'd function carries the effects bound inside it."""
  def f(x):
    for effect_name in ('foo', 'bar'):
      effect_p.bind(effect=effect_name)
    return x

  mapped = maps.xmap(f, in_axes=['a'], out_axes=['a'])
  closed = jax.make_jaxpr(mapped)(jnp.arange(jax.local_device_count()))
  self.assertSetEqual(closed.effects, {"foo", "bar"})
def test_prune_jit_args(self):
  """_prune_unused_inputs keeps only the one input the function reads."""
  def f(*args):
    return args[0]

  closed_jaxpr = jax.make_jaxpr(f)(*range(10))
  pruned_jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(
      closed_jaxpr.jaxpr)
  # TestCase assertions instead of bare `assert`: they still run under
  # `python -O` and report the mismatching values on failure.
  self.assertEqual(len(pruned_jaxpr.invars), 1)
  self.assertEqual(kept_const_idx, set())
  self.assertEqual(kept_var_idx, {0})
def test_jaxpr_doesnt_include_trivial_operations(self):
  """A masked sum over a padded input emits no mul or add in its jaxpr."""
  def foo(x):
    return np.sum(x)

  foo = partial(mask, in_shapes=['n'], out_shape='')(foo)
  padded_x = np.array([0, 1, 2, 3, 999, 999])
  jaxpr_text = str(make_jaxpr(foo)([padded_x], dict(n=3)))
  for prim_name in ('mul', 'add'):
    self.assertNotIn(prim_name, jaxpr_text)
def test_readme_example(self):
  """Some of the examples from the README."""
  def image_mask_jax(images, mask):
    # images: f32[B, W, W] and mask: f32[W, W]
    return images * mask

  ones_images = np.ones((1024, 28, 28))
  ones_mask = np.ones((28, 28))
  print(jax.make_jaxpr(image_mask_jax)(ones_images, ones_mask))

  # Will invoke broadcast_in_dim with shape=(1, w, w).
  jax2tf.convert(image_mask_jax, polymorphic_shapes=["(b, w, w)", "(w, w)"])
def test_different_effects_in_jaxpr(self):
  """Each equation records its own effect; the jaxpr records their union."""
  def f(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x + 1.

  closed = jax.make_jaxpr(f)(2.)
  eqns = closed.jaxpr.eqns
  self.assertEqual({'foo'}, eqns[0].effects)
  self.assertEqual({'bar'}, eqns[1].effects)
  self.assertEqual({'foo', 'bar'}, closed.effects)
def test_primitives_by_source(self):
  """Each sin call site gets its own histogram key, each counted once."""
  def f(x, y):
    s = jnp.sin(x)
    return jnp.sin(s) + jnp.cos(y)

  hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr)
  n_sin_sites = sum(1 for key in hist.keys() if key.startswith('sin @ '))
  self.assertEqual(n_sin_sites, 2)
  self.assertTrue(all(n == 1 for n in hist.values()))
def test_check_jaxpr_scan_correct(self):
  """A well-formed scan jaxpr passes check_jaxpr."""
  def step(carry, x):
    out = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(carry)))
    new_carry = jnp.sin(carry * out)
    return new_carry, out

  xs = jnp.ones((5, 3))
  init = jnp.ones(4)
  scanned = partial(lax.scan, step)
  core.check_jaxpr(make_jaxpr(scanned)(init, xs).jaxpr)
def test_jvp_should_transform_stacks(self):
  """jvp wraps the first scope beneath it as 'jvp(...)' in the name stack."""
  def f(x):
    with jax.named_scope('bar'), jax.named_scope('baz'):
      return jnp.square(x)

  def primal_and_tangent(x, t):
    return jax.jvp(f, (x,), (t,))

  g = jax.named_scope('foo')(primal_and_tangent)
  first_eqn = jax.make_jaxpr(g)(1., 1.).jaxpr.eqns[0]
  self.assertEqual(str(first_eqn.source_info.name_stack), 'foo/jvp(bar)/baz')
def test_jaxpr_dropvar_from_jit_call(self):
  """An unused jit-call output is bound to a DropVar and the jaxpr checks."""
  def inner(x):
    return x + 1, x + 2

  def f(x):
    _, y = jit(inner)(x)
    return y + 3

  jaxpr = make_jaxpr(f)(1).jaxpr
  # The first equation is the jit call, and its first output is discarded.
  # A TestCase assertion survives `python -O` and reports the actual type.
  self.assertIsInstance(jaxpr.eqns[0].outvars[0], core.DropVar)
  core.check_jaxpr(jaxpr)
def test_multiple_name_stack(self):
  """Sibling named scopes give each equation its own name stack."""
  def f(x):
    with jax.named_scope('foo'):
      y = x + 1
    with jax.named_scope('bar'), jax.named_scope('baz'):
      return y + 1

  eqns = jax.make_jaxpr(f)(2).jaxpr.eqns
  first_stack = str(eqns[0].source_info.name_stack)
  self.assertEqual(first_stack, 'foo')
  second_stack = str(eqns[1].source_info.name_stack)
  self.assertEqual(second_stack, 'bar/baz')
def test_vmap_should_transform_inner_name_stacks(self):
  """Scopes inside vmap nest as 'vmap(...)' beneath an outer named scope."""
  @jax.named_scope('foo')
  @jax.vmap
  def f(x):
    with jax.named_scope('bar'), jax.named_scope('baz'):
      return x + 1

  first_eqn = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr.eqns[0]
  self.assertEqual(str(first_eqn.source_info.name_stack), 'foo/vmap(bar)/baz')
def test_primitives(self):
  """The primitive histogram counts sin twice and every other primitive once."""
  def f(x, y):
    s = jit(jnp.sin)(x)
    return jnp.sin(s) + jnp.cos(y)

  hist = jaxpr_util.primitives(make_jaxpr(f)(1., 1.).jaxpr)
  # assertIn instead of a bare `assert`: it still runs under `python -O` and
  # names the missing primitive on failure.
  for k in ['add', 'sin', 'cos', 'xla_call']:
    self.assertIn(k, hist)
  self.assertEqual(hist['sin'], 2)
  self.assertTrue(
      all(count == 1 for k, count in hist.items() if k != 'sin'))
def test_primitives_by_source(self):
  """sin is counted twice in total across its keys; other keys once each."""
  def f(x, y):
    s = jnp.sin(x)
    return jnp.sin(s) + jnp.cos(y)

  hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr)
  sin_total = 0
  other_counts = []
  for key, count in hist.items():
    if key.startswith('sin @ '):
      sin_total += count
    else:
      other_counts.append(count)
  self.assertEqual(sin_total, 2)
  self.assertTrue(all(c == 1 for c in other_counts))
def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
  """The pmap's inner call_jaxpr keeps only scopes opened inside the pmap."""
  @jax.named_scope('foo')
  @jax.pmap
  def f(x):
    with jax.named_scope('bar'):
      return x + 1

  outer_eqn = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr.eqns[0]
  self.assertEqual(str(outer_eqn.source_info.name_stack), 'foo')
  inner_eqn = outer_eqn.params['call_jaxpr'].eqns[0]
  self.assertEqual(str(inner_eqn.source_info.name_stack), 'bar')
def test_scan_partial_eval(self):
  """The gradient of a scan that captures a constant converts and matches."""
  def f_jax(xs, ys):
    # Captured by `body` below to exercise constant capture.
    body_const = np.ones((2,), dtype=np.float32)

    def body(res0, inputs):
      x, y = inputs
      return res0 + x * y, body_const

    c_out, _ = lax.scan(body, 0., (xs, ys))
    return c_out

  arg = np.arange(10, dtype=np.float32)
  grad_f = jax.grad(f_jax)
  print(jax.make_jaxpr(grad_f)(arg, arg))
  self.ConvertAndCompare(grad_f, arg, arg)
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                   dimension_numbers, bdims):
  """Checks dot_general under batching and scans its jaxpr for reshuffles."""
  rng = jtu.rand_small(self.rng())
  dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
  self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                      rng)

  # Checks that no transposes or broadcasts were introduced.
  # `eqn.primitive` is a Primitive object, so it has to be compared by its
  # `.name`. Testing the object itself for membership in a list of strings
  # never matches, which made the original assertion vacuous. Current JAX
  # expresses broadcasting with the `broadcast_in_dim` primitive, so that
  # name is listed too.
  # NOTE(review): this traces the unbatched `dot`. Tracing
  # jax.vmap(dot, bdims) on batched operands would check the batching rule
  # itself. TODO: confirm the intent before tightening.
  forbidden = {"transpose", "broadcast", "broadcast_in_dim"}
  jaxpr = jax.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                              np.zeros(rhs_shape, dtype))
  for eqn in jtu.iter_eqns(jaxpr.jaxpr):
    self.assertNotIn(eqn.primitive.name, forbidden)
def test_recurrent01(self):
  """recurrent_func and recurrent_static produce the same loss."""
  theta = NestedMap(proj=np.random.uniform(size=[3, 4]))
  inputs = NestedMap(x=np.random.uniform(size=[5, 3]))
  state0 = NestedMap(y=np.zeros([4]))
  prng_key = jnp.array([21230, 90230], dtype=jnp.uint32)
  global_step = jnp.array(0, dtype=jnp.uint64)

  def cell_fn(theta, state0, inputs_t):
    del state0
    y = jnp.einsum('x,xy->y', inputs_t.x, theta.proj)
    return NestedMap(y=y)

  def make_loss_fn(recurrent_impl):
    # Builds a loss that runs `recurrent_impl` inside a fresh JaxContext and
    # sums both the final state and the accumulated states.
    def loss_fn(theta, state0, inputs):
      with base_layer.JaxContext.new_context(prng_key=prng_key,
                                             global_step=global_step):
        final_state, cum_states = recurrent_impl(theta, state0, inputs,
                                                 cell_fn)
        loss = jnp.sum(final_state.y) + jnp.sum(cum_states.y)
      return loss
    return loss_fn

  comp01 = make_loss_fn(recurrent.recurrent_func)
  comp02 = make_loss_fn(recurrent.recurrent_static)

  logging.info('comp01_jaxpr: %s',
               jax.make_jaxpr(comp01)(theta, state0, inputs))
  logging.info('comp02_jaxpr: %s',
               jax.make_jaxpr(comp02)(theta, state0, inputs))

  loss1 = comp01(theta, state0, inputs)
  loss2 = comp02(theta, state0, inputs)
  self.assertAllClose(np.asarray(loss1, dtype=np.float32),
                      np.asarray(loss2, dtype=np.float32))
def bound_propagation(
    prop_alg: PropagationAlgorithm[Repr],
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput],
    graph_simplifier=synthetic_primitives.default_simplifier,
) -> Tuple[Nest[Union[Repr, Tensor]],
           Dict[jax.core.Var, Union[Repr, Tensor, Bound]]]:
  """Runs bound propagation through the computation implemented by `function`.

  Args:
    prop_alg: Algorithm that decides how the graph is traversed and how each
      node is transformed.
    function: Pure function mapping inputs to outputs. A function with a more
      complex signature can be adapted with `functools.partial`.
    *bounds: Nests of `IntervalBound` objects holding lower and upper bounds
      on the inputs, or `Tensor`s holding inputs whose values are known.
    graph_simplifier: Function that rewrites the JaxPR graph into a simpler
      one. The default identifies specific activation functions, then groups
      linear sequences and quadratic forms.

  Returns:
    bounds: Bounds on every output of the function, in the same structure as
      the output of `function`.
    env: Mapping from each node of the computation to its representation.
  """
  # Swap any jittable bound wrappers for standard Bound objects.
  bounds = unjit_inputs(*bounds)

  # Trace `function` on stand-in values: the lower bound for each Bound input,
  # and the input itself otherwise.
  def _to_placeholder(b):
    return b.lower if isinstance(b, Bound) else b

  placeholder_inputs = jax.tree_util.tree_map(_to_placeholder, bounds)
  parsed = jax.make_jaxpr(function)(*placeholder_inputs)
  output_shapes = jax.eval_shape(function, *placeholder_inputs)

  # For each flattened jaxpr input variable, record whether it carries a Bound.
  is_bound_tree = jax.tree_util.tree_map(lambda b: isinstance(b, Bound), bounds)
  flat_is_bound = jax.tree_util.tree_leaves(is_bound_tree)
  inp_is_bound = dict(zip(parsed.jaxpr.invars, flat_is_bound))

  simplified_graph = synthetic_primitives.simplify_graph(
      graph_simplifier, parsed.jaxpr, inp_is_bound)
  graph = PropagationGraph(simplified_graph, parsed.literals)
  flat_outvals, env = prop_alg.propagate(graph, bounds)

  # Give the outputs the same tree structure as the output of `function`.
  out_treedef = jax.tree_util.tree_structure(output_shapes)
  return jax.tree_util.tree_unflatten(out_treedef, flat_outvals), env
def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self):
  """check_jaxpr rejects a jaxpr whose effects omit an equation's effect."""
  def f(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x + 1.

  original = jax.make_jaxpr(f)(2.).jaxpr
  # Drop 'bar' from the jaxpr-level effects so its type is deliberately wrong.
  mistyped = original.replace(effects={'foo'})
  with self.assertRaisesRegex(
      core.JaxprTypeError,
      'Equation effects are not subset of Jaxpr effects.'):
    core.check_jaxpr(mistyped)