def test_broadcast_in_dim(self, device, dtype):
    """Exercise ``prims.broadcast_in_dim`` through the traced wrapper.

    Cases covered: identity broadcast, rejected dimension reordering,
    adding outermost dimensions, expanding size-1 dimensions, unsqueezing,
    and a combination of all three.

    Args:
        device: device on which the input tensors are created.
        dtype: dtype of the input tensors.
    """
    def _wrapper(a, shape, broadcast_dimensions):
        return prims.broadcast_in_dim(a, shape, broadcast_dimensions)

    traced = make_traced(_wrapper)
    make_arg = partial(make_tensor, device=device, dtype=dtype)

    # TODO: FIXME:
    # for executor in ('aten', 'nvfuser'):
    for executor in ("aten", ):
        fn = partial(traced, executor=executor)

        # Same shape
        shape = (5, 5)
        a = make_arg(shape)
        result = fn(a, shape, (0, 1))

        self.assertEqual(result.shape, a.shape)
        # is_contiguous is a method: it has to be called, otherwise the
        # bound-method object is always truthy and nothing is checked.
        self.assertTrue(result.is_contiguous())
        self.assertEqual(a, result)

        # Error input: reordering dims
        with self.assertRaises(Exception):
            fn(a, shape, (1, 0))

        # Adding outermost dimensions
        a = make_arg((5, 5))
        target_shape = (3, 3, 5, 5)
        result = fn(a, target_shape, (2, 3))

        self.assertEqual(result.shape, target_shape)
        self.assertEqual(a.broadcast_to(target_shape), result)

        # Expands
        a = make_arg((1, 5, 1))
        target_shape = (3, 5, 7)
        result = fn(a, target_shape, (0, 1, 2))

        self.assertEqual(result.shape, target_shape)
        self.assertEqual(a.expand_as(result), result)

        # Unsqueezes
        a = make_arg((1, 2, 3))
        target_shape = (1, 2, 1, 3)
        result = fn(a, target_shape, (0, 1, 3))

        self.assertEqual(result.shape, target_shape)
        self.assertEqual(a.unsqueeze(2), result)

        # Adds outermost, expands, and unsqueezes
        a = make_arg((1, 2, 3))
        target_shape = (4, 1, 7, 2, 3, 3)
        result = fn(a, target_shape, (1, 3, 4))

        self.assertEqual(result.shape, target_shape)
        # Build the reference without mutating the input:
        # (1, 2, 3) -> (1, 2, 3, 1) -> (1, 1, 2, 3, 1) -> (1, 1, 1, 2, 3, 1)
        expected = a.unsqueeze(3).unsqueeze(1).unsqueeze(0)
        self.assertEqual(expected.expand_as(result), result)
def test_broadcast_in_dim(self, device, dtype):
    """Exercise ``prims.broadcast_in_dim`` followed by ``prims.add``.

    The input ``a`` is broadcast to the shape of ``b`` and the two are
    added; each case compares the result against ``a`` broadcast with the
    regular tensor API. Runs with both the 'aten' and 'nvfuser' executors.

    Args:
        device: device on which the input tensors are created.
        dtype: dtype of the input tensors.
    """
    # nvfuser is not currently capable of realizing a broadcasted tensor
    # when the broadcast is the only operation. Another op is needed.
    def _wrapper(a, b, broadcast_dimensions):
        a_bc = prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
        return prims.add(a_bc, b)

    traced = make_traced(_wrapper)
    make_arg = partial(make_tensor, device=device, dtype=dtype)

    for executor in ('aten', 'nvfuser'):
        fn = partial(traced, executor=executor)

        # Same shape
        shape = (5, 5)
        a = make_arg(shape)
        # b is requested with low == high == 0.0, presumably so that the
        # extra add leaves the broadcast values unchanged.
        b = make_arg(shape, low=0.0, high=0.0)
        result = fn(a, b, (0, 1))

        self.assertEqual(result.shape, a.shape)
        # is_contiguous is a method: it has to be called, otherwise the
        # bound-method object is always truthy and nothing is checked.
        self.assertTrue(result.is_contiguous())
        self.assertEqual(a, result)

        # Error input: reordering dims
        with self.assertRaises(Exception):
            fn(a, b, (1, 0))

        # Adding outermost dimensions
        a = make_arg((5, 5))
        b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
        result = fn(a, b, (2, 3))

        self.assertEqual(result.shape, b.shape)
        self.assertEqual(a.broadcast_to(b.shape), result)

        # Expands
        a = make_arg((1, 5, 1))
        b = make_arg((3, 5, 7), low=0.0, high=0.0)
        result = fn(a, b, (0, 1, 2))

        self.assertEqual(result.shape, b.shape)
        self.assertEqual(a.expand_as(result), result)

        # Unsqueezes
        a = make_arg((1, 2, 3))
        b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
        result = fn(a, b, (0, 1, 3))

        self.assertEqual(result.shape, b.shape)
        self.assertEqual(a.unsqueeze(2), result)

        # FIXME: This test exposes an issue in nvfuser
        # Adds outermost, expands, and unsqueezes
        # (this case is currently disabled)
def test_var(self, device, dtype, correction):
    """Compare traced ``prims.var`` over dims [0, 1] with the untraced call.

    Args:
        device: device on which the input tensor is created.
        dtype: dtype of the input tensor.
        correction: value forwarded to ``prims.var`` as ``correction``.
    """
    def _wrapper(a):
        return prims.var(a, [0, 1], correction=correction)

    traced = make_traced(_wrapper)
    make_arg = partial(make_tensor, device=device, dtype=dtype)

    for executor in ('aten', 'nvfuser'):
        fn = partial(traced, executor=executor)
        shape = (5, 5)
        a = make_arg(shape)
        result = fn(a)

        # Both dimensions are reduced, so the result is 0-dimensional.
        self.assertEqual(result.shape, ())
        # is_contiguous is a method: it has to be called, otherwise the
        # bound-method object is always truthy and nothing is checked.
        self.assertTrue(result.is_contiguous())
        # The untraced wrapper serves as the reference value.
        self.assertEqual(_wrapper(a), result)
def test_broadcast_in_dim_sum(self, device, dtype):
    """Check a full ``prims.sum`` followed by a no-op ``broadcast_in_dim``.

    The input is summed over dims [0, 1] and the result is passed to
    ``prims.broadcast_in_dim`` with an empty shape and empty broadcast
    dimensions. The traced output is compared with the untraced wrapper.

    Args:
        device: device on which the input tensor is created.
        dtype: dtype of the input tensor.
    """
    def _wrapper(a):
        a_sum = prims.sum(a, [0, 1])
        a_bc = prims.broadcast_in_dim(a_sum, [], [])
        return a_bc

    traced = make_traced(_wrapper)
    make_arg = partial(make_tensor, device=device, dtype=dtype)

    for executor in ('aten', 'nvfuser'):
        fn = partial(traced, executor=executor)
        shape = (5, 5)
        a = make_arg(shape)
        result = fn(a)

        # Both dimensions are reduced, so the result is 0-dimensional.
        self.assertEqual(result.shape, ())
        # is_contiguous is a method: it has to be called, otherwise the
        # bound-method object is always truthy and nothing is checked.
        self.assertTrue(result.is_contiguous())
        # The untraced wrapper serves as the reference value.
        self.assertEqual(_wrapper(a), result)