Example #1
0
 def testVMapShardingConstraint(self):
   """vmap over a pjit'd sharding constraint adds an unpartitioned batch dim.

   The constraint's partitions gain a leading empty entry for the mapped
   dimension while the original 'x' partitioning moves to the second entry.
   """
   sharded = pjit(lambda v: with_sharding_constraint(v, P('x')),
                  in_axis_resources=P(), out_axis_resources=P('x'))
   operand = jnp.arange(5 * 4).reshape((5, 4))
   closed = jax.make_jaxpr(jax.vmap(sharded))(operand)
   (outer_eqn,) = closed.eqns
   (inner_eqn,) = outer_eqn.params['jaxpr'].eqns
   spec = inner_eqn.params['axis_resources']
   self.assertEqual(spec.partitions, ((), ('x',)))
   self.assertEqual(spec.sync, SpecSync.DIM_PERMUTE)
Example #2
0
 def test_check_jaxpr_cond_invalid(self):
     """check_jaxpr rejects a cond whose branches disagree on input arity."""
     closed = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.)
     jaxpr = closed.jaxpr
     cond_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cond')
     # Corrupt the first branch so that it declares no inputs at all.
     first_branch = cond_eqn.params['branches'][0]
     first_branch.jaxpr.invars = ()
     with self.assertRaisesRegex(
             core.JaxprTypeError,
             'cond branch 0 takes 0 inputs, branch 1 takes 1'):
         core.check_jaxpr(jaxpr)
Example #3
0
    def test_nested_name_stack(self):
        """Equations traced inside nested name-stack scopes record both names.

        The decorator form of ``extend_name_stack`` supplies the outer 'foo'
        scope and the ``with`` statement supplies the inner 'bar' scope.
        """
        @extend_name_stack('foo')
        def f(x):
            with extend_name_stack('bar'):
                return x + 1

        jaxpr = jax.make_jaxpr(f)(2).jaxpr
        # Guard against a vacuous pass: the loop below asserts nothing if
        # tracing produced no equations.
        self.assertGreater(len(jaxpr.eqns), 0)
        for eqn in jaxpr.eqns:
            self.assertEqual(str(eqn.source_info.name_stack), 'foo/bar')
Example #4
0
 def testEvalJaxpr(self):
   """A pjit jaxpr evaluated through jaxpr_as_fun matches eager evaluation."""
   a, b = jnp.arange(4), jnp.arange(5)
   f = pjit(lambda a, b: a.sum() + jnp.sin(b),
            in_axis_resources=(P('x'), P('y')),
            out_axis_resources=P('y'))
   evaluate = jax.core.jaxpr_as_fun(jax.make_jaxpr(f)(a, b))
   outputs = evaluate(a, b)
   (result,) = outputs
   self.assertAllClose(result, a.sum() + jnp.sin(b))
Example #5
0
  def testLowerToNothing(self):
    """An Empty value traces through jit, type-checks, and compiles."""
    empty = Empty(AbstractEmpty())
    identity_jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
    core.check_jaxpr(identity_jaxpr)

    # CompileAndCheck assumes array outputs, so the function under test
    # returns None instead of a unit.
    self._CompileAndCheck(lambda e: None, lambda: [empty])
Example #6
0
    def test_vmap_should_transform_name_stack(self):
        """vmap wraps the name-stack entry of the equation it traces."""
        def f(x):
            with extend_name_stack('foo'):
                return x + 1
        f = jax.vmap(f)

        first_eqn = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr.eqns[0]
        self.assertEqual(str(first_eqn.source_info.name_stack), 'vmap(foo)')
Example #7
0
def _invertible_jaxpr_and_constants(fun):
  """Returns a transformation from function invocation to invertible jaxpr.

  Args:
    fun: The function to trace.

  Returns:
    A function taking the same arguments as ``fun``. It traces ``fun`` with
    ``jax.make_jaxpr`` and returns a ``(jaxpr, consts)`` pair: the open jaxpr
    and the list of constants it closes over.
  """
  import functools  # Local import keeps this block self-contained.

  jaxpr_maker = jax.make_jaxpr(fun)

  # `jax.api` (and its `wraps`) has been removed from the public JAX
  # namespace; the standard-library decorator copies the wrapped function's
  # metadata onto the returned function.
  @functools.wraps(fun)
  def jaxpr_const_maker(*args, **kwargs):
    closed_jaxpr = jaxpr_maker(*args, **kwargs)
    # `consts` is the canonical ClosedJaxpr field; `literals` is only a
    # backwards-compatible alias for it.
    return closed_jaxpr.jaxpr, closed_jaxpr.consts
  return jaxpr_const_maker
Example #8
0
def _check_numerical_stability(fn):
  """Logs a warning if numerically unstable operations are requested.

  Traces `fn` on the scalar `0.0` and emits one warning for every top-level
  equation whose primitive is a key of `_potentially_unstable_primitives`.
  The value stored for that primitive is named in the message as the
  bijector to prefer instead.

  Args:
    fn: A function that can be called with a single scalar argument.
  """
  jaxpr = jax.make_jaxpr(fn)(0.0).jaxpr
  # NOTE(review): only top-level equations are inspected; primitives nested
  # inside sub-jaxprs (e.g. of a jit or scan call) are not visited.
  for eqn in jaxpr.eqns:
    if eqn.primitive in _potentially_unstable_primitives:
      # `warning` rather than the deprecated `warn` alias.
      logging.warning("[Distrax]: the '%s' primitive can exhibit unstable "
                      "numerical behavior under certain circumstances. "
                      "Consider using the %s bijector instead if possible.",
                      eqn.primitive,
                      _potentially_unstable_primitives[eqn.primitive])
Example #9
0
 def test_jaxpr_undefined_eqn_invar(self):
     """check_jaxpr flags an equation input that nothing defines."""
     jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
     cos_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cos')
     # Swap in a fresh, never-bound variable with the same abstract value.
     fresh_var = core.gensym([jaxpr], suffix='_test')
     cos_eqn.invars[0] = fresh_var(cos_eqn.invars[0].aval)
     with self.assertRaisesRegex(
             core.JaxprTypeError,
             r"Variable '.+_test' not defined\n\nin equation:"):
         core.check_jaxpr(jaxpr)
Example #10
0
    def test_jaxpr_dropvar_from_cond(self):
        """An unused cond output is bound to a DropVar and still type-checks."""
        def f(x):
            _, y = lax.cond(x < 0., lambda x: (jnp.sin(x), x + 1.), lambda x:
                            (jnp.cos(x), x + 2.), x)
            return y

        jaxpr = make_jaxpr(f)(1.).jaxpr
        # Match by type rather than identity with the legacy `core.dropvar`
        # singleton, which not every JAX version provides. A unittest
        # assertion is used because a bare `assert` is stripped under -O.
        self.assertIsInstance(jaxpr.eqns[-1].outvars[0], core.DropVar)
        core.check_jaxpr(jaxpr)
Example #11
0
    def test_jaxpr_dropvar_from_loop(self):
        """An unused while_loop output is bound to a DropVar and type-checks."""
        def f(x):
            _, y = lax.while_loop(lambda s: s[0] < 0., lambda s:
                                  (jnp.sin(s[0]), jnp.cos(s[1])), (x, x))
            return y + 1.

        jaxpr = make_jaxpr(f)(1.).jaxpr
        # Match by type rather than identity with the legacy `core.dropvar`
        # singleton, which not every JAX version provides. A unittest
        # assertion is used because a bare `assert` is stripped under -O.
        self.assertIsInstance(jaxpr.eqns[0].outvars[0], core.DropVar)
        core.check_jaxpr(jaxpr)
Example #12
0
    def test_xmap_inherits_effects(self):
        """An xmap'd function's jaxpr carries the effects bound in its body."""
        def f(x):
            for effect_name in ('foo', 'bar'):
                effect_p.bind(effect=effect_name)
            return x

        mapped = maps.xmap(f, in_axes=['a'], out_axes=['a'])
        operand = jnp.arange(jax.local_device_count())
        closed = jax.make_jaxpr(mapped)(operand)
        self.assertSetEqual(closed.effects, {"foo", "bar"})
Example #13
0
    def test_prune_jit_args(self):
        """_prune_unused_inputs keeps only the input the jaxpr actually reads."""
        def f(*args):
            return args[0]

        closed_jaxpr = jax.make_jaxpr(f)(*range(10))
        pruned_jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(
            closed_jaxpr.jaxpr)
        # unittest assertions instead of bare `assert`, which -O strips.
        self.assertEqual(len(pruned_jaxpr.invars), 1)
        self.assertEqual(kept_const_idx, set())
        self.assertEqual(kept_var_idx, {0})
Example #14
0
  def test_jaxpr_doesnt_include_trivial_operations(self):
    """A masked sum over padded input emits no mul/add in its jaxpr text."""
    @partial(mask, in_shapes=['n'], out_shape='')
    def foo(x):
      return np.sum(x)

    padded_x = np.array([0, 1, 2, 3, 999, 999])
    jaxpr_text = str(make_jaxpr(foo)([padded_x], dict(n=3)))
    for op_name in ('mul', 'add'):
      self.assertNotIn(op_name, jaxpr_text)
Example #15
0
  def test_readme_example(self):
    """Runs a README example: print a jaxpr, then jax2tf-convert it."""
    def image_mask_jax(images, mask):
      # images: f32[B, W, W]  and mask: f32[W, W]
      return images * mask

    sample_images = np.ones((1024, 28, 28))
    sample_mask = np.ones((28, 28))
    print(jax.make_jaxpr(image_mask_jax)(sample_images, sample_mask))

    # will invoke broadcast_in_dim with shape=(1, w, w)
    jax2tf.convert(image_mask_jax, polymorphic_shapes=["(b, w, w)", "(w, w)"])
Example #16
0
    def test_different_effects_in_jaxpr(self):
        """Each equation records its own effect; the jaxpr records their union."""
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x + 1.

        closed = jax.make_jaxpr(f)(2.)
        eqns = closed.jaxpr.eqns
        self.assertEqual({'foo'}, eqns[0].effects)
        self.assertEqual({'bar'}, eqns[1].effects)
        self.assertEqual({'foo', 'bar'}, closed.effects)
Example #17
0
    def test_primitives_by_source(self):
        """The two sin calls get separate 'sin @ ...' keys, each counted once."""
        def f(x, y):
            s = jnp.sin(x)
            return jnp.sin(s) + jnp.cos(y)

        hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr)

        n_sin_keys = sum(1 for key in hist if key.startswith('sin @ '))
        self.assertEqual(n_sin_keys, 2)
        self.assertTrue(all(n == 1 for n in hist.values()))
Example #18
0
    def test_check_jaxpr_scan_correct(self):
        """A well-formed scan jaxpr passes check_jaxpr."""
        def f(c, x):
            b = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(c)))
            c = jnp.sin(c * b)
            return c, b

        init = jnp.ones(4)
        xs = jnp.ones((5, 3))
        scan_fn = partial(lax.scan, f)
        core.check_jaxpr(make_jaxpr(scan_fn)(init, xs).jaxpr)
Example #19
0
    def test_jvp_should_transform_stacks(self):
        """jvp wraps the scope directly inside it, under the outer 'foo' scope."""
        def f(x):
            with jax.named_scope('bar'), jax.named_scope('baz'):
                return jnp.square(x)

        def jvp_of_f(x, t):
            return jax.jvp(f, (x,), (t,))

        g = jax.named_scope('foo')(jvp_of_f)
        first_eqn = jax.make_jaxpr(g)(1., 1.).jaxpr.eqns[0]
        self.assertEqual(str(first_eqn.source_info.name_stack),
                         'foo/jvp(bar)/baz')
Example #20
0
  def test_jaxpr_dropvar_from_jit_call(self):
    """An unused output of a jit call is bound to a DropVar and type-checks."""
    def inner(x):
      return x + 1, x + 2

    def f(x):
      _, y = jit(inner)(x)
      return y + 3

    jaxpr = make_jaxpr(f)(1).jaxpr
    # assertIsInstance instead of a bare `assert`, which -O strips.
    self.assertIsInstance(jaxpr.eqns[0].outvars[0], core.DropVar)
    core.check_jaxpr(jaxpr)
Example #21
0
    def test_multiple_name_stack(self):
        """Sibling and nested named scopes produce distinct name stacks."""
        def f(x):
            with jax.named_scope('foo'):
                y = x + 1
            with jax.named_scope('bar'), jax.named_scope('baz'):
                return y + 1

        eqns = jax.make_jaxpr(f)(2).jaxpr.eqns
        self.assertEqual(str(eqns[0].source_info.name_stack), 'foo')
        self.assertEqual(str(eqns[1].source_info.name_stack), 'bar/baz')
Example #22
0
    def test_vmap_should_transform_inner_name_stacks(self):
        """vmap wraps the scope directly inside it, under the outer 'foo' scope."""
        def f(x):
            with jax.named_scope('bar'), jax.named_scope('baz'):
                return x + 1
        f = jax.named_scope('foo')(jax.vmap(f))

        first_eqn = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr.eqns[0]
        self.assertEqual(str(first_eqn.source_info.name_stack),
                         'foo/vmap(bar)/baz')
Example #23
0
    def test_primitives(self):
        """primitives() counts each primitive, including one inside a jit call."""
        def f(x, y):
            s = jit(jnp.sin)(x)
            return jnp.sin(s) + jnp.cos(y)

        hist = jaxpr_util.primitives(make_jaxpr(f)(1., 1.).jaxpr)

        # unittest assertions instead of bare `assert`, which -O strips.
        for k in ['add', 'sin', 'cos', 'xla_call']:
            self.assertIn(k, hist)
        self.assertEqual(hist['sin'], 2)
        self.assertTrue(
            all(count == 1 for k, count in hist.items() if k != 'sin'))
Example #24
0
    def test_primitives_by_source(self):
        """'sin @ ...' counts total 2; every other source key appears once."""
        def f(x, y):
            s = jnp.sin(x)
            return jnp.sin(s) + jnp.cos(y)

        hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr)

        def is_sin(key):
            return key.startswith('sin @ ')

        sin_total = sum(n for key, n in hist.items() if is_sin(key))
        self.assertEqual(sin_total, 2)
        self.assertTrue(
            all(n == 1 for key, n in hist.items() if not is_sin(key)))
Example #25
0
    def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
        """The pmap eqn keeps 'foo'; its call_jaxpr holds only the inner 'bar'."""
        @jax.named_scope('foo')
        @jax.pmap
        def f(x):
            with jax.named_scope('bar'):
                return x + 1

        pmap_eqn = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr.eqns[0]
        self.assertEqual(str(pmap_eqn.source_info.name_stack), 'foo')
        inner_eqn = pmap_eqn.params['call_jaxpr'].eqns[0]
        self.assertEqual(str(inner_eqn.source_info.name_stack), 'bar')
Example #26
0
  def test_scan_partial_eval(self):
    """grad through a scan that captures a constant converts and matches."""
    def f_jax(xs, ys):
      body_const = np.ones((2, ), dtype=np.float32)  # Test constant capture
      def body(carry, xy):
        x, y = xy
        return carry + x * y, body_const
      total, _ = lax.scan(body, 0., (xs, ys))
      return total

    grad_fn = jax.grad(f_jax)
    arg = np.arange(10, dtype=np.float32)
    print(jax.make_jaxpr(grad_fn)(arg, arg))
    self.ConvertAndCompare(grad_fn, arg, arg)
Example #27
0
    def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                       dimension_numbers, bdims):
        """Checks dot_general batching and the absence of transposes/broadcasts."""
        rng = jtu.rand_small(self.rng())
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

        # Checks that batching didn't introduce any transposes or broadcasts.
        # NOTE(review): this jaxpr is traced from the *unbatched* `dot`, so it
        # does not observe what batching emits; consider tracing
        # `jax.vmap(dot, in_axes=bdims)` on batched operands instead.
        jaxpr = jax.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                    np.zeros(rhs_shape, dtype))
        # Compare primitive *names*: a Primitive object never equals a str, so
        # the previous `eqn.primitive in [...]` check could never fail.
        # `broadcast_in_dim` is the broadcast primitive in current JAX.
        forbidden = ("transpose", "broadcast", "broadcast_in_dim")
        for eqn in jtu.iter_eqns(jaxpr.jaxpr):
            self.assertNotIn(eqn.primitive.name, forbidden)
Example #28
0
    def test_recurrent01(self):
        """recurrent_func and recurrent_static give the same loss on one cell."""
        theta = NestedMap(proj=np.random.uniform(size=[3, 4]))
        inputs = NestedMap(x=np.random.uniform(size=[5, 3]))
        state0 = NestedMap(y=np.zeros([4]))

        prng_key = jnp.array([21230, 90230], dtype=jnp.uint32)
        global_step = jnp.array(0, dtype=jnp.uint64)

        def cell_fn(theta, state0, inputs_t):
            del state0
            return NestedMap(y=jnp.einsum('x,xy->y', inputs_t.x, theta.proj))

        def make_loss_fn(recurrent_impl):
            # Builds a scalar loss that runs `recurrent_impl` inside a fresh
            # JaxContext and sums the final and accumulated `y` states.
            def loss_fn(theta, state0, inputs):
                with base_layer.JaxContext.new_context(prng_key=prng_key,
                                                       global_step=global_step):
                    final_state, cum_states = recurrent_impl(
                        theta, state0, inputs, cell_fn)
                    return jnp.sum(final_state.y) + jnp.sum(cum_states.y)
            return loss_fn

        comp01 = make_loss_fn(recurrent.recurrent_func)
        comp02 = make_loss_fn(recurrent.recurrent_static)

        logging.info('comp01_jaxpr: %s',
                     jax.make_jaxpr(comp01)(theta, state0, inputs))
        logging.info('comp02_jaxpr: %s',
                     jax.make_jaxpr(comp02)(theta, state0, inputs))
        loss1, loss2 = (
            np.asarray(comp(theta, state0, inputs), dtype=np.float32)
            for comp in (comp01, comp02))

        self.assertAllClose(loss1, loss2)
Example #29
0
def bound_propagation(
    prop_alg: PropagationAlgorithm[Repr],
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput],
    graph_simplifier=synthetic_primitives.default_simplifier,
) -> Tuple[Nest[Union[Repr, Tensor]], Dict[jax.core.Var, Union[Repr, Tensor,
                                                               Bound]]]:
    """Performs Bound Propagation on the model implemented by `function`.

    Args:
      prop_alg: Algorithm specifying how to traverse the graph and how to
        transform each node.
      function: Pure function inputs -> outputs. If the function to propagate
        through has a more complex signature, the use of `functools.partial`
        can solve that problem.
      *bounds: Nest of `IntervalBound` objects containing the lower and upper
        bounds on all the inputs, or `Tensor`s containing known inputs
        directly.
      graph_simplifier: Function transforming the JaxPR graph into a simpler
        graph. Default value is a function identifying specific activation
        functions, followed by grouping of linear sequences and quadratic
        forms.
    Returns:
      bounds: Bounds over all the outputs of the function, with the same
        structure as the output of `function`.
      env: Mapping from the node of the computations to their representation.
    """
    # Replace all the jittable bounds by standard bound object.
    bounds = unjit_inputs(*bounds)

    # Parse the computation graph. Each `Bound` leaf is replaced by its
    # `lower` attribute so that `function` can be traced on plain values.
    placeholder_inputs = jax.tree_util.tree_map(
        lambda b: b.lower if isinstance(b, Bound) else b, bounds)
    parsed = jax.make_jaxpr(function)(*placeholder_inputs)
    output_shapes = jax.eval_shape(function, *placeholder_inputs)

    # Flag, for each jaxpr input variable, whether it came from a Bound or
    # from a known input tensor.
    flat_is_bound = jax.tree_util.tree_leaves(
        jax.tree_util.tree_map(lambda b: isinstance(b, Bound), bounds))
    inp_is_bound = dict(zip(parsed.jaxpr.invars, flat_is_bound))
    simplified_graph = synthetic_primitives.simplify_graph(
        graph_simplifier, parsed.jaxpr, inp_is_bound)
    # `consts` is the canonical ClosedJaxpr field; `literals` is only a
    # backwards-compatible alias for it.
    graph = PropagationGraph(simplified_graph, parsed.consts)

    outvals, env = prop_alg.propagate(graph, bounds)

    # Give outvals the same tree structure as the output of the function.
    tree_structure = jax.tree_util.tree_structure(output_shapes)
    outvals = jax.tree_util.tree_unflatten(tree_structure, outvals)

    return outvals, env
Example #30
0
  def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self):
    """check_jaxpr rejects a jaxpr whose effects omit an equation's effect."""
    def f(x):
      effect_p.bind(effect='foo')
      effect_p.bind(effect='bar')
      return x + 1.

    # Trace, then rewrite the declared effects so that 'bar' is missing.
    traced = jax.make_jaxpr(f)(2.).jaxpr
    bad_jaxpr = traced.replace(effects={'foo'})

    with self.assertRaisesRegex(
        core.JaxprTypeError,
        'Equation effects are not subset of Jaxpr effects.'):
      core.check_jaxpr(bad_jaxpr)