예제 #1
0
    def test_primitives_by_shape(self):
        """Check the primitive/output-shape histogram of a small traced function.

        ``f`` calls a jitted ``sub`` and then combines ``sin`` and ``cos`` of
        the results. The histogram from ``jaxpr_util.primitives_by_shape``
        must contain each expected ``'<primitive> :: <shapes>'`` key exactly
        once.

        The float dtype in the keys is not hard-coded. It is taken from the
        default dtype that ``jnp.array`` gives a Python float. That dtype is
        float32 normally and float64 when 64-bit mode is enabled, so the test
        holds in both configurations.
        """
        def f(x, y):
            def sub(x, y):
                return jnp.sum(jnp.array([x, y])), y

            s, _ = jit(sub)(x, y)
            return jnp.sin(s) + jnp.cos(y)

        hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)

        # The Python-float arguments (1., 1.) are traced with the default
        # float dtype. Derive its name the same way instead of assuming
        # 'float32', which breaks under x64 mode.
        dtype = jnp.array(1.).dtype.name
        shapes = [
            f'add :: {dtype}[]',
            f'sin :: {dtype}[]',
            f'cos :: {dtype}[]',
            f'reduce_sum :: {dtype}[]',
            f'concatenate :: {dtype}[2]',
            # NOTE(review): the trailing ' *' presumably stands for sub's
            # second output, which f discards (`s, _ = ...`) -- confirm.
            f'xla_call :: {dtype}[] *',
        ]
        for k in shapes:
            # Include the whole histogram so a mismatch is easy to diagnose.
            self.assertEqual(hist[k], 1, msg=f'{k!r} in {hist}')
예제 #2
0
    def test_primitives_by_shape(self):
        """Verify per-primitive output-shape counts for a tiny traced program.

        The traced function runs a jitted helper and then combines ``sin``
        and ``cos``. Every listed ``'<primitive> :: <shapes>'`` key must be
        counted exactly once. The float width follows
        ``FLAGS.jax_enable_x64``: 64 when the flag is set, 32 otherwise.
        """
        def fn(a, b):
            def sub(p, q):
                stacked = jnp.array([p, q])
                return jnp.sum(stacked), q

            total, _ = jit(sub)(a, b)
            return jnp.sin(total) + jnp.cos(b)

        traced = make_jaxpr(fn)(1., 1.)
        histogram = jaxpr_util.primitives_by_shape(traced.jaxpr)

        # Choose the float width to match the active precision mode.
        if FLAGS.jax_enable_x64:
            t = '64'
        else:
            t = '32'

        expected_keys = (
            f'add :: float{t}[]',
            f'sin :: float{t}[]',
            f'cos :: float{t}[]',
            f'reduce_sum :: float{t}[]',
            f'concatenate :: float{t}[2]',
            f'xla_call :: float{t}[] *',
        )
        for key in expected_keys:
            self.assertEqual(histogram[key], 1)