def test_with_default_params(self, fake_pmap, fake_jit):
  """Checks default kwargs and partial-bound overrides under pmap + jit.

  `foo` is wrapped as jit(pmap(foo)) twice:
  - once relying on its definition-time default ``flag=True``;
  - once with ``functools.partial(foo, flag=False)``, which selects the other
    branch.

  Each wrapped function is called with positional and with keyword arguments.
  """
  with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
    device_count = len(jax.devices())

    def per_device(values):
      # Replicate a 1-D array along a new leading axis of length device_count.
      arr = jnp.array(values)
      return jnp.broadcast_to(arr, (device_count,) + arr.shape)

    # The flag default (True) is fixed when `foo` is defined.
    def foo(x, y, flag=True):
      if flag:
        return (x * 2) + y
      return x + y

    with_default = jax.jit(jax.pmap(foo, axis_size=device_count))
    xs = per_device([1, 2])
    want = per_device([3, 6])
    asserts.assert_tree_all_close(with_default(xs, xs), want)
    asserts.assert_tree_all_close(with_default(x=xs, y=xs), want)

    # Overriding the default via functools.partial takes the other branch.
    with_override = jax.jit(
        jax.pmap(functools.partial(foo, flag=False), axis_size=device_count))
    want = per_device([2, 4])
    asserts.assert_tree_all_close(with_override(xs, xs), want)
    asserts.assert_tree_all_close(with_override(x=xs, y=xs), want)
def test_pmap_and_jit(self, fake_kwargs, is_pmapped, is_jitted):
  """Checks fake_pmap_and_jit both as a context manager and via start/stop.

  Args:
    fake_kwargs: keyword arguments forwarded to ``fake.fake_pmap_and_jit``.
    is_pmapped: expected flag passed through to ``_assert_pmapped``.
    is_jitted: expected flag passed through to ``_assert_jitted``.
  """
  fn_input = jnp.ones((4,))

  def foo(x):
    return x * 2

  # Call with context manager.
  with fake.fake_pmap_and_jit(**fake_kwargs):
    _assert_pmapped(foo, fn_input, is_pmapped)
    _assert_jitted(foo, fn_input, is_jitted)

  # Call with start/stop. stop() runs in `finally` so that whatever start()
  # installed is undone even when one of the assertions fails.
  ctx = fake.fake_pmap_and_jit(**fake_kwargs)
  ctx.start()
  try:
    _assert_pmapped(foo, fn_input, is_pmapped)
    _assert_jitted(foo, fn_input, is_jitted)
  finally:
    ctx.stop()
def test_with_kwargs(self, fake_pmap, fake_jit):
  """Checks that a pmap(jit(fn)) wrapper accepts keyword-only call sites."""
  with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
    device_count = len(jax.devices())

    def foo(x, y):
      return (x * 2) + y

    # Same wrapping order as stacking @pmap over @jit: jit is applied first.
    foo = jax.pmap(jax.jit(foo), axis_size=device_count)

    # Replicate the inputs so pmap maps over every available device.
    row = jnp.array([1, 2])
    batched = jnp.broadcast_to(row, (device_count,) + row.shape)
    target = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))
    asserts.assert_tree_all_close(foo(x=batched, y=batched), target)
def test_with_partial(self, fake_pmap, fake_jit):
  """Checks jit(pmap(partial(fn, flag=True))) with positional and kwarg calls."""
  with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
    device_count = len(jax.devices())

    # Common use case: a non-parallel argument is bound with functools.partial
    # before the function is pmapped.
    def foo(x, y, flag):
      if flag:
        return (x * 2) + y
      return x + y

    wrapped = jax.jit(
        jax.pmap(functools.partial(foo, flag=True), axis_size=device_count))

    # Replicate the inputs so pmap maps over every available device.
    row = jnp.array([1, 2])
    batched = jnp.broadcast_to(row, (device_count,) + row.shape)
    target = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))
    asserts.assert_tree_all_close(wrapped(batched, batched), target)
    asserts.assert_tree_all_close(wrapped(x=batched, y=batched), target)