def test_with_kwargs(self, fake_pmap, fake_jit):
    """Calls a pmapped, jit-compiled function using keyword arguments only.

    The wrapped function computes ``x * 2 + y``. With ``x = y = [1, 2]``
    replicated across every device, each device row should equal ``[3, 6]``.
    """
    with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
        n_dev = len(jax.devices())

        # Decorators are applied inside the context so the (possibly faked)
        # pmap/jit are the ones that wrap ``foo``.
        @functools.partial(jax.pmap, axis_size=n_dev)
        @jax.jit
        def foo(x, y):
            return x * 2 + y

        # One copy of the same row per available device.
        row = jnp.array([1, 2])
        batched = jnp.broadcast_to(row, (n_dev,) + row.shape)
        want = jnp.broadcast_to(jnp.array([3, 6]), (n_dev, 2))

        result = foo(x=batched, y=batched)
        asserts.assert_tree_all_close(result, want)
def test_assert_tree_all_close_fails_values_differ(self):
    """Checks that ``atol`` and ``rtol`` are honoured when leaves differ.

    The second elements of the two trees differ by 0.1. A tolerance of 0.1
    (absolute or relative) accepts that gap. A tolerance of 0.01 must raise
    an ``AssertionError`` mentioning 'Values not approximately equal'.
    """
    # Trailing commas make these genuine 1-tuples, like the other tests;
    # bare parentheses around a single expression are only grouping.
    tree1 = (jnp.array([0.0, 2.0]),)
    tree2 = (jnp.array([0.0, 2.1]),)

    # Absolute tolerance: 0.1 is loose enough, 0.01 is too tight.
    asserts.assert_tree_all_close(tree1, tree2, atol=0.1)
    with self.assertRaisesRegex(AssertionError, 'Values not approximately equal'):
        asserts.assert_tree_all_close(tree1, tree2, atol=0.01)

    # Relative tolerance: same expectation with rtol instead of atol.
    asserts.assert_tree_all_close(tree1, tree2, rtol=0.1)
    with self.assertRaisesRegex(AssertionError, 'Values not approximately equal'):
        asserts.assert_tree_all_close(tree1, tree2, rtol=0.01)
def test_with_partial(self, fake_pmap, fake_jit):
    """Pmaps a function whose non-parallel argument is bound via ``partial``.

    ``flag`` is fixed to ``True`` before pmapping, so the mapped function
    computes ``x * 2 + y``. The result is checked for both positional and
    keyword call styles.
    """
    with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
        n_dev = len(jax.devices())

        # Common pattern: partially apply the non-parallel arguments before
        # handing the function to pmap.
        def raw_fn(x, y, flag):
            if flag:
                return (x * 2) + y
            return x + y

        bound = functools.partial(raw_fn, flag=True)
        foo = jax.jit(jax.pmap(bound, axis_size=n_dev))

        # One copy of the same row per available device.
        row = jnp.array([1, 2])
        batched = jnp.broadcast_to(row, (n_dev,) + row.shape)
        want = jnp.broadcast_to(jnp.array([3, 6]), (n_dev, 2))

        asserts.assert_tree_all_close(foo(batched, batched), want)
        asserts.assert_tree_all_close(foo(x=batched, y=batched), want)
def test_tuple_rev_conversion(self, frozen):
    """Round-trips a dummy dataclass through ``to_tuple`` / ``from_tuple``."""
    original = dummy_dataclass(frozen=frozen)
    cls = original.__class__

    as_tuple = original.to_tuple()
    rebuilt = cls.from_tuple(as_tuple)

    asserts.assert_tree_all_close(rebuilt, original)
def test_dataclass_tree_map(self, frozen):
    """``tree_map`` over a dataclass scales its leaves by ``factor``.

    The expected value is another dummy dataclass built by passing the same
    ``factor`` to ``dummy_dataclass``.
    """
    factor = 5.
    source = dummy_dataclass(frozen=frozen)
    expected = dummy_dataclass(factor=factor, frozen=frozen)

    def scale(leaf):
        return factor * leaf

    mapped = jax.tree_util.tree_map(scale, source)
    asserts.assert_tree_all_close(mapped, expected)
def test_assert_tree_all_close_nones(self):
    """A tree containing ``None`` passes only when ``ignore_nones=True``."""
    tree_with_none = dict(a=[jnp.zeros((1,))], b=None)

    # Self-comparison is accepted when None leaves are ignored...
    asserts.assert_tree_all_close(tree_with_none, tree_with_none,
                                  ignore_nones=True)

    # ...and rejected with a '`None` detected' error when they are not.
    with self.assertRaisesRegex(AssertionError, '`None` detected'):
        asserts.assert_tree_all_close(tree_with_none, tree_with_none,
                                      ignore_nones=False)
def test_assert_tree_all_close_passes_values_close(self):
    """Trees whose leaves differ by a tiny but nonzero amount compare close.

    NOTE(review): assumes the default tolerance of ``assert_tree_all_close``
    is on the order of rtol=1e-6. Confirm if the defaults change.
    """
    tree1 = (jnp.array([1.0, 1.0]),)
    # The perturbation must survive float32 rounding (JAX's default dtype).
    # float32 spacing at 1.0 is ~1.2e-7, so 1.0 + 1e-9 rounds to exactly 1.0,
    # which turns this into an equality test. 1e-7 stays distinct but small.
    tree2 = (jnp.array([1.0, 1.0 + 1e-7]),)

    # Guard: make sure the leaves really differ, so "close" is actually tested.
    self.assertFalse(bool(jnp.array_equal(tree1[0], tree2[0])))

    asserts.assert_tree_all_close(tree1, tree2)
def test_assert_tree_all_close_passes_values_equal(self):
    """Two independently built, identical all-zero trees compare close."""
    # Build two separate (but equal) single-leaf tuples.
    first, second = [(jnp.array([0.0, 0.0]),) for _ in range(2)]
    asserts.assert_tree_all_close(first, second)
def test_assert_tree_all_close_passes_same_tree(self):
    """A nested tree mixing arrays and Python containers/scalars is close to itself."""
    nested = dict(
        a=[jnp.zeros((1,))],
        b=([0], (0,), 0),
    )
    asserts.assert_tree_all_close(nested, nested)