Пример #1
0
    def test_with_default_params(self, fake_pmap, fake_jit):
        """Check pmap+jit wrapping of a function whose flag has a default.

        Inside the ``fake.fake_pmap_and_jit`` context the wrapped function is
        called both positionally and by keyword, first relying on the default
        ``flag=True`` and then with ``flag=False`` bound through
        ``functools.partial``.

        Args:
            fake_pmap: Forwarded as the first positional argument of
                ``fake.fake_pmap_and_jit``.
            fake_jit: Forwarded as the second positional argument of
                ``fake.fake_pmap_and_jit``.
        """
        with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
            device_count = len(jax.devices())

            # The flag's default value is fixed when the function is defined.
            def foo(x, y, flag=True):
                if flag:
                    return (x * 2) + y
                return x + y

            wrapped_default = jax.jit(jax.pmap(foo, axis_size=device_count))

            # Repeat the [1, 2] row once per device along a new leading axis.
            row = jnp.array([1, 2])
            batch = jnp.broadcast_to(row, (device_count, *row.shape))

            # flag=True branch: (x * 2) + y -> [3, 6] for every row.
            want = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))
            positional_out = wrapped_default(batch, batch)
            asserts.assert_tree_all_close(positional_out, want)
            keyword_out = wrapped_default(x=batch, y=batch)
            asserts.assert_tree_all_close(keyword_out, want)

            # Override the default through partial so the other branch runs.
            wrapped_override = jax.jit(
                jax.pmap(functools.partial(foo, flag=False),
                         axis_size=device_count))

            # flag=False branch: x + y -> [2, 4] for every row.
            want = jnp.broadcast_to(jnp.array([2, 4]), (device_count, 2))
            positional_out = wrapped_override(batch, batch)
            asserts.assert_tree_all_close(positional_out, want)
            keyword_out = wrapped_override(x=batch, y=batch)
            asserts.assert_tree_all_close(keyword_out, want)
Пример #2
0
    def test_pmap_and_jit(self, fake_kwargs, is_pmapped, is_jitted):
        """Exercise ``fake.fake_pmap_and_jit`` via ``with`` and start/stop.

        The same pair of assertions is run twice: once with the object
        returned by ``fake.fake_pmap_and_jit`` used as a context manager, and
        once driving it manually through ``start()`` / ``stop()``.

        Args:
            fake_kwargs: Keyword arguments unpacked into
                ``fake.fake_pmap_and_jit``.
            is_pmapped: Expected value handed to ``_assert_pmapped``.
            is_jitted: Expected value handed to ``_assert_jitted``.
        """
        fn_input = jnp.ones((4, ))

        def foo(x):
            return x * 2

        # Call with context manager
        with fake.fake_pmap_and_jit(**fake_kwargs):
            _assert_pmapped(foo, fn_input, is_pmapped)
            _assert_jitted(foo, fn_input, is_jitted)

        # Call with start/stop
        ctx = fake.fake_pmap_and_jit(**fake_kwargs)
        ctx.start()
        try:
            _assert_pmapped(foo, fn_input, is_pmapped)
            _assert_jitted(foo, fn_input, is_jitted)
        finally:
            # Always undo start(), even when an assertion above fails, so the
            # fake does not stay active for tests that run afterwards.
            ctx.stop()
Пример #3
0
    def test_with_kwargs(self, fake_pmap, fake_jit):
        """Check that a jit-then-pmap wrapped function accepts keyword args.

        Args:
            fake_pmap: Forwarded as the first positional argument of
                ``fake.fake_pmap_and_jit``.
            fake_jit: Forwarded as the second positional argument of
                ``fake.fake_pmap_and_jit``.
        """
        with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
            device_count = len(jax.devices())

            def double_x_plus_y(x, y):
                return (x * 2) + y

            # jit is applied first, then pmap across every available device.
            foo = jax.pmap(jax.jit(double_x_plus_y), axis_size=device_count)

            # Repeat the [1, 2] row once per device along a new leading axis.
            row = jnp.array([1, 2])
            batch = jnp.broadcast_to(row, (device_count, *row.shape))
            want = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))

            keyword_out = foo(x=batch, y=batch)
            asserts.assert_tree_all_close(keyword_out, want)
Пример #4
0
    def test_with_partial(self, fake_pmap, fake_jit):
        """Check pmap+jit wrapping of a ``functools.partial`` object.

        The ``flag`` argument is bound with ``functools.partial`` before the
        function is passed to ``jax.pmap`` and ``jax.jit``; the result is then
        called both positionally and by keyword.

        Args:
            fake_pmap: Forwarded as the first positional argument of
                ``fake.fake_pmap_and_jit``.
            fake_jit: Forwarded as the second positional argument of
                ``fake.fake_pmap_and_jit``.
        """
        with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
            device_count = len(jax.devices())

            # A common use-case: arguments that are not mapped over are bound
            # with partial before the function is pmapped.
            def branchy(x, y, flag):
                if flag:
                    return (x * 2) + y
                return x + y

            bound = functools.partial(branchy, flag=True)
            foo = jax.jit(jax.pmap(bound, axis_size=device_count))

            # Repeat the [1, 2] row once per device along a new leading axis.
            row = jnp.array([1, 2])
            batch = jnp.broadcast_to(row, (device_count, *row.shape))
            want = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))

            positional_out = foo(batch, batch)
            asserts.assert_tree_all_close(positional_out, want)
            keyword_out = foo(x=batch, y=batch)
            asserts.assert_tree_all_close(keyword_out, want)