def fn_static(arg_0, arg_1):
  """Check argument kinds and shapes, then all-reduce `arg_1 - arg_0` over 'i'.

  `arg_0` must be an instance of `arg_0_type` and `arg_1` must not be an
  instance of `arg_1_type` (both types come from the enclosing scope). Each
  argument must have shape `[n_copies]`.

  Returns:
    `jax.lax.psum(arg_1 - arg_0, axis_name='i')`. Presumably this runs under a
    transform that binds the axis name 'i'; the caller is not visible here.
  """
  self.assertIsInstance(arg_0, arg_0_type)
  self.assertNotIsInstance(arg_1, arg_1_type)
  for operand in (arg_0, arg_1):
    asserts.assert_shape(operand, [n_copies])
  return jax.lax.psum(arg_1 - arg_0, axis_name='i')
def fn(arg_0, arg_1):
  """Check argument kinds and shapes, then all-reduce `arg_1 - arg_0` over 'i'.

  Neither argument may be an instance of its corresponding `arg_*_type`
  (types come from the enclosing scope). `arg_0` must have shape
  `[n_copies]` and `arg_1` shape `[n_devices, n_copies]`, so the subtraction
  broadcasts `arg_0` across the leading dimension of `arg_1`.

  Returns:
    `jax.lax.psum(arg_1 - arg_0, axis_name='i')`. Presumably this runs under a
    transform binding axis 'i'; the caller is not visible here.
  """
  self.assertNotIsInstance(arg_0, arg_0_type)
  self.assertNotIsInstance(arg_1, arg_1_type)
  asserts.assert_shape(arg_0, [n_copies])
  asserts.assert_shape(arg_1, [n_devices, n_copies])
  difference = arg_1 - arg_0
  return jax.lax.psum(difference, axis_name='i')
def fn_static(arg_0, arg_1):
  """Check kinds/shapes, sanity-check a psum of `arg_1`, and reduce over 'j'.

  `arg_0` must not be an instance of `arg_0_type`, while `arg_1` must be an
  instance of `arg_1_type` (both types come from the enclosing scope). Both
  arguments must have shape `[n_copies]`.

  `arg_1` is copied into a NumPy array first. The body then asserts that the
  total of its psum over axis 'j' equals
  `arg_1[0] * (n_copies * n_devices)`.

  Returns:
    `jax.lax.psum(arg_1 - arg_0, axis_name='j')`, computed with the NumPy
    copy of `arg_1`.
  """
  self.assertNotIsInstance(arg_0, arg_0_type)
  self.assertIsInstance(arg_1, arg_1_type)
  asserts.assert_shape(arg_0, [n_copies])
  asserts.assert_shape(arg_1, [n_copies])
  # Work on a concrete NumPy copy so operations on arg_1 are not staged out.
  host_arg_1 = np.array(arg_1)
  expected_total = host_arg_1[0] * (n_copies * n_devices)
  reduced_total = np.sum(jax.lax.psum(host_arg_1, axis_name='j'))
  self.assertEqual(reduced_total, expected_total)
  return jax.lax.psum(host_arg_1 - arg_0, axis_name='j')
def vmapped_fn(arg_0, arg_1):
  """Check both inputs are `ArrayBatched` of the per-item shape; return the difference.

  The expected per-item shape is `actual_shape[1:]`, where `actual_shape`
  comes from the enclosing scope (presumably the full batched shape, with
  the leading axis removed by the mapping — confirm against the caller).

  Returns:
    `arg_1 - arg_0`.
  """
  operands = (arg_0, arg_1)
  # All type checks run before any shape check, matching the original order.
  for operand in operands:
    self.assertIsInstance(operand, ArrayBatched)
  for operand in operands:
    asserts.assert_shape(operand, actual_shape[1:])
  return arg_1 - arg_0
def test_pytypes_pass(self):
  """`assert_shape` accepts two converted nested lists under spec `(2, None)`.

  The inputs have inner widths of 2 and 1. Presumably `None` acts as a
  wildcard dimension in `assert_shape`, which lets both inputs pass.
  """
  nested_lists = [[[1, 2], [3, 4]], [[1], [3]]]
  asserts.assert_shape(as_arrays(nested_lists), (2, None))
def test_shape_should_pass(self, arrays, shapes):
  """Converted `arrays` satisfy `shapes`; `assert_shape` must not raise."""
  asserts.assert_shape(as_arrays(arrays), shapes)
def test_shape_should_fail_wrong_length(self, arrays, shapes):
  """`assert_shape` rejects a mismatch between input count and shape count.

  Expects an `AssertionError` whose message matches the length-mismatch text.
  """
  # Convert outside the context manager so conversion errors are not masked.
  converted = as_arrays(arrays)
  length_mismatch = 'Length of `inputs` and `expected_shapes` must match'
  with self.assertRaisesRegex(AssertionError, length_mismatch):
    asserts.assert_shape(converted, shapes)
def test_shape_should_fail(self, arrays, shapes):
  """`assert_shape` raises when some converted input has the wrong shape.

  Expects an `AssertionError` whose message matches the shape-mismatch regex.
  """
  # Convert outside the context manager so conversion errors are not masked.
  converted = as_arrays(arrays)
  shape_mismatch = 'input .+ has shape .+ but expected .+'
  with self.assertRaisesRegex(AssertionError, shape_mismatch):
    asserts.assert_shape(converted, shapes)