Ejemplo n.º 1
0
    def test_with_kwargs(self, fake_pmap, fake_jit):
        """Check that a pmapped, jitted function can be called with kwargs.

        The function is built and invoked inside `fake.fake_pmap_and_jit`,
        which receives `fake_pmap` and `fake_jit` unchanged.
        """
        with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
            device_count = len(jax.devices())

            def foo(x, y):
                return (x * 2) + y

            # Same wrapping order as stacked decorators: jit innermost, pmap
            # outermost.
            foo = jax.pmap(jax.jit(foo), axis_size=device_count)

            # Replicate the per-device row once for every available device.
            row = jnp.array([1, 2])
            batched_shape = (device_count, ) + row.shape
            inputs = jnp.broadcast_to(row, batched_shape)
            expected = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))

            actual = foo(x=inputs, y=inputs)
            asserts.assert_tree_all_close(actual, expected)
Ejemplo n.º 2
0
  def test_assert_tree_all_close_fails_values_differ(self):
    """Check tolerance handling when two trees hold different values.

    For both `atol` and `rtol`, a tolerance of 0.1 must accept the trees and a
    tolerance of 0.01 must raise an `AssertionError` whose message matches
    'Values not approximately equal'.
    """
    # The trailing commas matter: parentheses around a single expression do
    # not create a tuple, so without them the "trees" are just bare arrays.
    tree1 = (jnp.array([0.0, 2.0]),)
    tree2 = (jnp.array([0.0, 2.1]),)

    # Absolute tolerance: loose passes, tight fails.
    asserts.assert_tree_all_close(tree1, tree2, atol=0.1)
    with self.assertRaisesRegex(AssertionError,
                                'Values not approximately equal'):
      asserts.assert_tree_all_close(tree1, tree2, atol=0.01)

    # Relative tolerance: loose passes, tight fails.
    asserts.assert_tree_all_close(tree1, tree2, rtol=0.1)
    with self.assertRaisesRegex(AssertionError,
                                'Values not approximately equal'):
      asserts.assert_tree_all_close(tree1, tree2, rtol=0.01)
Ejemplo n.º 3
0
    def test_with_partial(self, fake_pmap, fake_jit):
        """Check pmap + jit on a function with a partially-applied argument.

        Covers the common pattern where a non-parallel argument (`flag`) is
        bound with `functools.partial` before the function is pmapped. The
        wrapped function is then called both positionally and by keyword.
        """
        with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
            device_count = len(jax.devices())

            def foo(x, y, flag):
                if flag:
                    return (x * 2) + y
                return x + y

            # Bind `flag` first, then pmap, then jit the pmapped function.
            foo = jax.jit(
                jax.pmap(functools.partial(foo, flag=True),
                         axis_size=device_count))

            # Replicate the per-device row once for every available device.
            row = jnp.array([1, 2])
            batched_shape = (device_count, ) + row.shape
            inputs = jnp.broadcast_to(row, batched_shape)
            expected = jnp.broadcast_to(jnp.array([3, 6]), (device_count, 2))

            positional_out = foo(inputs, inputs)
            asserts.assert_tree_all_close(positional_out, expected)

            keyword_out = foo(x=inputs, y=inputs)
            asserts.assert_tree_all_close(keyword_out, expected)
Ejemplo n.º 4
0
 def test_tuple_rev_conversion(self, frozen):
     """Check that `from_tuple(to_tuple())` round-trips a dataclass object."""
     original = dummy_dataclass(frozen=frozen)
     rebuild = original.__class__.from_tuple
     rebuilt = rebuild(original.to_tuple())
     asserts.assert_tree_all_close(rebuilt, original)
Ejemplo n.º 5
0
 def test_dataclass_tree_map(self, frozen):
     """Check that `jax.tree_util.tree_map` works on a dataclass instance.

     Every leaf of the default object is multiplied by `factor`, and the
     result is compared against `dummy_dataclass(factor=factor, ...)`.
     """
     factor = 5.

     def scale(leaf):
         return factor * leaf

     source = dummy_dataclass(frozen=frozen)
     expected = dummy_dataclass(factor=factor, frozen=frozen)
     scaled = jax.tree_util.tree_map(scale, source)
     asserts.assert_tree_all_close(scaled, expected)
Ejemplo n.º 6
0
 def test_assert_tree_all_close_nones(self):
   """Check how `ignore_nones` affects a tree that contains a `None` entry."""
   tree = dict(a=[jnp.zeros((1,))], b=None)

   # With ignore_nones=True the comparison must go through without raising.
   asserts.assert_tree_all_close(tree, tree, ignore_nones=True)

   # With ignore_nones=False the `None` entry must be reported.
   with self.assertRaisesRegex(AssertionError, '`None` detected'):
     asserts.assert_tree_all_close(tree, tree, ignore_nones=False)
Ejemplo n.º 7
0
 def test_assert_tree_all_close_passes_values_close(self):
   """Check that trees differing by 1e-9 in one element compare as close."""
   base = jnp.array([1.0, 1.0])
   perturbed = jnp.array([1.0, 1.0 + 1e-9])
   # Default tolerances are used; no rtol/atol is passed.
   asserts.assert_tree_all_close((base,), (perturbed,))
Ejemplo n.º 8
0
 def test_assert_tree_all_close_passes_values_equal(self):
   """Check that two separately built trees with equal values pass."""
   left_leaf = jnp.array([0.0, 0.0])
   right_leaf = jnp.array([0.0, 0.0])
   asserts.assert_tree_all_close((left_leaf,), (right_leaf,))
Ejemplo n.º 9
0
 def test_assert_tree_all_close_passes_same_tree(self):
   """Check that a nested tree compares as close to itself."""
   array_branch = [jnp.zeros((1,))]
   # Mix of list, tuple and bare scalar leaves.
   scalar_branch = ([0], (0,), 0)
   tree = {'a': array_branch, 'b': scalar_branch}
   asserts.assert_tree_all_close(tree, tree)