Exemplo n.º 1
0
def _assert_pmapped(fn, fn_input, is_pmapped, should_jit=False):
    """Asserts whether a function can be pmapped or not.

    Args:
      fn: The function to be tested.
      fn_input: Input to pass to the function. It is broadcast along a new
        leading axis of size ``len(jax.devices())`` before the call, so one
        copy is fed to each device. Arrays and Python scalars are accepted.
      is_pmapped: Assert that the function can be pmapped with jax.pmap (True)
        or cannot be pmapped (False), i.e. the fake pmap is working correctly.
      should_jit: If True, asserts that the function is jitted, regardless of
        it being pmapped or not (``fn`` is wrapped so that it may be traced at
        most once).

    Raises:
      AssertionError: If the output type does not match the expected type, or
        (presumably, via ``asserts.assert_max_traces``) if ``should_jit`` is
        True and ``fn`` is traced more than once.
    """
    num_devices = len(jax.devices())
    if should_jit:
        # Reset global trace bookkeeping so earlier tests do not count here.
        asserts.clear_trace_counter()
        fn = asserts.assert_max_traces(fn, n=1)
    wrapped_fn = jax.pmap(fn, axis_size=num_devices)

    # `jnp.shape` (unlike `.shape`) also works for Python scalars, and is
    # identical to `.shape` for array inputs.
    fn_input = jnp.broadcast_to(fn_input,
                                (num_devices,) + jnp.shape(fn_input))
    output = wrapped_fn(fn_input)

    # We test whether the function has been pmapped by inspecting the type of
    # the function output: if it is a sharded array type then the function
    # has been pmapped.
    if not is_pmapped and hasattr(jax.interpreters.xla,
                                  'type_is_device_array'):
        expected_type = 'DeviceArray'
        assert_message = f'Output is type {type(output)}, expected {expected_type}'
        assert jax.interpreters.xla.type_is_device_array(
            output), assert_message
    else:
        expected_type = ArraySharded if is_pmapped else jnp.DeviceArray
        assert_message = f'Output is type {type(output)}, expected {expected_type}'
        # We want to check exact types here (subclasses must not pass), hence
        # an identity comparison rather than isinstance.
        assert type(output) is expected_type, assert_message  # pylint: disable=unidiomatic-typecheck
Exemplo n.º 2
0
  def test_redefined_traced_function(self):
    """Re-creating a jitted function on every call accumulates traces.

    `inner_fn` below is a fresh function object each time `outer_fn` runs, so
    jit cannot reuse an earlier trace; the per-function trace budget of 3 is
    therefore exhausted after three calls.
    """

    def outer_fn(x):

      @jax.jit
      @asserts.assert_max_traces(3)
      def inner_fn(y):
        return y.sum()

      return inner_fn(2 * x)

    # The first three calls fit within the trace budget.
    for arg, expected in ((1, 2), (2, 4), (3, 6)):
      self.assertEqual(outer_fn(arg), expected)

    # Fails since the traced inner function is redefined at each call.
    with self.assertRaisesRegex(AssertionError, 'fn.* is traced > .* times!'):
      outer_fn(4)

    # After a reset the budget starts over: three successes, then seven
    # consecutive failures (10 calls in total, same argument every time).
    asserts.clear_trace_counter()
    for _ in range(3):
      outer_fn(1)
    for _ in range(7):
      with self.assertRaisesRegex(AssertionError,
                                  'fn.* is traced > .* times!'):
        outer_fn(1)
Exemplo n.º 3
0
def _assert_jitted(fn, fn_input, is_jitted):
    """Asserts that a function can be jitted or not.

    The function is wrapped with a trace budget and run once under jax.jit:
    a budget of one trace when it is expected to be jitted, zero otherwise.

    Args:
      fn: The function to be tested.
      fn_input: Input to pass to the function.
      is_jitted: Assert that the function can be jitted with jax.jit (True) or
        cannot be jitted (False), i.e. the fake jit is working correctly.
    """
    # Start from a clean global trace count so prior calls do not interfere.
    asserts.clear_trace_counter()
    budgeted_fn = asserts.assert_max_traces(fn, 1 if is_jitted else 0)
    jax.jit(budgeted_fn)(fn_input)
Exemplo n.º 4
0
 def setUp(self):
   """Runs the base fixture, then resets the global trace counter."""
   super().setUp()
   # Trace counts are tracked globally; reset so each test starts from zero.
   asserts.clear_trace_counter()