Example #1
0
        def fn_static(arg_0, arg_1):
            """Validate argument types/shapes, then psum (arg_1 - arg_0) over 'i'.

            arg_0 must be an instance of `arg_0_type` and arg_1 must NOT be an
            instance of `arg_1_type`. Both must have shape [n_copies].
            """
            self.assertIsInstance(arg_0, arg_0_type)
            self.assertNotIsInstance(arg_1, arg_1_type)
            # Same shape requirement for both inputs, checked arg_0 first.
            for value in (arg_0, arg_1):
                asserts.assert_shape(value, [n_copies])

            difference = arg_1 - arg_0
            summed = jax.lax.psum(difference, axis_name='i')
            return summed
Example #2
0
        def fn(arg_0, arg_1):
            """Validate argument types/shapes, then psum (arg_1 - arg_0) over 'i'.

            Neither argument may be an instance of its corresponding
            `arg_0_type` / `arg_1_type`. arg_0 has shape [n_copies] while arg_1
            carries an extra leading axis: [n_devices, n_copies].
            """
            pairs = ((arg_0, arg_0_type), (arg_1, arg_1_type))
            for value, rejected_type in pairs:
                self.assertNotIsInstance(value, rejected_type)
            asserts.assert_shape(arg_0, [n_copies])
            asserts.assert_shape(arg_1, [n_devices, n_copies])

            # arg_0 broadcasts against the leading n_devices axis of arg_1.
            return jax.lax.psum(arg_1 - arg_0, axis_name='i')
Example #3
0
        def fn_static(arg_0, arg_1):
            """Validate types/shapes, sanity-check a psum of arg_1, then psum the difference.

            arg_0 must NOT be an instance of `arg_0_type`; arg_1 must be an
            instance of `arg_1_type`. Both must have shape [n_copies].

            Returns:
              jax.lax.psum(arg_1 - arg_0, axis_name='j'), where arg_1 has been
              converted with np.array first.
            """
            self.assertNotIsInstance(arg_0, arg_0_type)
            self.assertIsInstance(arg_1, arg_1_type)
            asserts.assert_shape(arg_0, [n_copies])
            asserts.assert_shape(arg_1, [n_copies])

            arg_1 = np.array(arg_1)  # don't stage out operations on arg_1
            # Sum of the 'j'-axis psum over all elements of arg_1.
            psum_arg_1 = np.sum(jax.lax.psum(arg_1, axis_name='j'))
            # NOTE(review): this equality presumably relies on every element of
            # arg_1 being equal and on axis 'j' spanning n_devices members —
            # confirm against the enclosing test setup.
            self.assertEqual(psum_arg_1, arg_1[0] * (n_copies * n_devices))
            # Uses the np.array-converted arg_1 from above.
            res = arg_1 - arg_0
            psum_res = jax.lax.psum(res, axis_name='j')
            return psum_res
Example #4
0
 def vmapped_fn(arg_0, arg_1):
     """Check both inputs are ArrayBatched of shape actual_shape[1:]; return arg_1 - arg_0."""
     inputs = (arg_0, arg_1)
     for batched in inputs:
         self.assertIsInstance(batched, ArrayBatched)
     # Per-element shape excludes the leading (mapped) dimension of actual_shape.
     for batched in inputs:
         asserts.assert_shape(batched, actual_shape[1:])
     return arg_1 - arg_0
Example #5
0
 def test_pytypes_pass(self):
   """Check that assert_shape accepts a tuple shape containing a None wildcard."""
   # Two inputs with shapes (2, 2) and (2, 1); both should match (2, None).
   nested = [[[1, 2], [3, 4]], [[1], [3]]]
   asserts.assert_shape(as_arrays(nested), (2, None))
Example #6
0
 def test_shape_should_pass(self, arrays, shapes):
   """Check that assert_shape does not raise for the given arrays and shapes."""
   asserts.assert_shape(as_arrays(arrays), shapes)
Example #7
0
 def test_shape_should_fail_wrong_length(self, arrays, shapes):
   """Check that assert_shape rejects mismatched input/shape counts."""
   converted = as_arrays(arrays)
   expected_message = 'Length of `inputs` and `expected_shapes` must match'
   with self.assertRaisesRegex(AssertionError, expected_message):
     asserts.assert_shape(converted, shapes)
Example #8
0
 def test_shape_should_fail(self, arrays, shapes):
   """Check that assert_shape raises an AssertionError describing the shape mismatch."""
   converted = as_arrays(arrays)
   mismatch_pattern = 'input .+ has shape .+ but expected .+'
   with self.assertRaisesRegex(AssertionError, mismatch_pattern):
     asserts.assert_shape(converted, shapes)