def _assert_pmapped(fn, fn_input, is_pmapped, should_jit=False):
  """Checks whether running `fn` under `jax.pmap` really maps it over devices.

  The input is broadcast along a new leading axis whose size is the number of
  available devices, then fed to `jax.pmap(fn, ...)`. Whether the mapping was
  genuine is decided from the concrete type of the returned value.

  Args:
    fn: Function under test.
    fn_input: Input for `fn`; it must expose a `.shape` attribute, since a
      leading device axis is prepended to that shape.
    is_pmapped: If True, the output must be exactly of type `ArraySharded`
      (real pmap). If False, the output must be a plain device array, meaning
      the fake pmap took effect.
    should_jit: If True, additionally wrap `fn` with
      `asserts.assert_max_traces(..., n=1)` (after clearing the trace counter),
      independently of whether it ends up pmapped.
  """
  device_count = len(jax.devices())

  if should_jit:
    asserts.clear_trace_counter()
    fn = asserts.assert_max_traces(fn, n=1)

  pmapped_fn = jax.pmap(fn, axis_size=device_count)
  batched_input = jnp.broadcast_to(fn_input, (device_count, ) + fn_input.shape)
  result = pmapped_fn(batched_input)

  # The output type reveals whether pmap really ran: a sharded array type
  # means the function was pmapped.
  # NOTE(review): `xla.type_is_device_array` and `jnp.DeviceArray` are legacy
  # JAX APIs -- confirm they exist in the JAX version this file targets.
  if not is_pmapped and hasattr(jax.interpreters.xla, 'type_is_device_array'):
    expected_type = 'DeviceArray'
    message = f'Output is type {type(result)}, expected {expected_type}'
    assert jax.interpreters.xla.type_is_device_array(result), message
    return

  if is_pmapped:
    expected_type = ArraySharded
  else:
    expected_type = jnp.DeviceArray
  message = f'Output is type {type(result)}, expected {expected_type}'
  # An exact type match is intended here (subclasses must not pass).
  assert type(result) == expected_type, message  # pylint: disable=unidiomatic-typecheck
def test_redefined_traced_function(self):
  """A function re-decorated on every call still shares one trace budget."""
  too_many_traces = 'fn.* is traced > .* times!'

  def outer_fn(x):

    @jax.jit
    @asserts.assert_max_traces(3)
    def inner_fn(y):
      return y.sum()

    return inner_fn(2 * x)

  for arg, expected in ((1, 2), (2, 4), (3, 6)):
    self.assertEqual(outer_fn(arg), expected)

  # Fails since the traced inner function is redefined at each call, so the
  # fourth call exceeds the budget of 3 traces.
  with self.assertRaisesRegex(AssertionError, too_many_traces):
    outer_fn(4)

  asserts.clear_trace_counter()
  # After the reset, the first three calls stay within the budget...
  for _ in range(3):
    outer_fn(1)
  # ...and each of the remaining seven calls exceeds it.
  for _ in range(7):
    with self.assertRaisesRegex(AssertionError, too_many_traces):
      outer_fn(1)
def _assert_jitted(fn, fn_input, is_jitted):
  """Runs `fn` through `jax.jit` under a trace budget.

  The trace counter is cleared first, then `fn` is wrapped with
  `asserts.assert_max_traces` and the jitted wrapper is called once.

  Args:
    fn: Function under test.
    fn_input: Argument passed to the jitted function.
    is_jitted: If True, the trace budget is 1 (real `jax.jit` is expected);
      if False, the budget is 0, which checks that the fake jit is in effect.
  """
  asserts.clear_trace_counter()
  trace_budget = 0
  if is_jitted:
    trace_budget = 1
  budgeted_fn = asserts.assert_max_traces(fn, trace_budget)
  jax.jit(budgeted_fn)(fn_input)
def setUp(self):
  """Runs the base-class setup, then resets the trace counter.

  Resetting via `asserts.clear_trace_counter()` keeps trace counts from one
  test from leaking into the next.
  """
  super().setUp()
  asserts.clear_trace_counter()