def test_sum(self):
    """Exercise ``Tensor.sum`` via ``helper_test_op`` for a full reduction and several axis choices.

    Each case hands ``helper_test_op`` a reference lambda and the matching
    ``Tensor.sum`` call. What ``helper_test_op`` compares them against is
    defined elsewhere (presumably a reference framework -- confirm).
    """
    # Full reduction over a 2-D input; Tensor.sum is passed unbound.
    helper_test_op([(45, 3)], lambda x: x.sum(), Tensor.sum)
    # Axis reductions over a 4-D input, in the same order as the original
    # explicit calls. `ax=ax` pins the current axis into each lambda, so
    # late binding cannot leak a later loop value into it.
    axis_cases = (3, (1, 3), (0, 2), (1, 2), 1)
    for ax in axis_cases:
        helper_test_op(
            [(3, 4, 5, 6)],
            lambda x, ax=ax: x.sum(axis=ax),
            lambda x, ax=ax: Tensor.sum(x, axis=ax),
        )
def test_sum(self):
    """Exercise ``Tensor.sum`` on ``self.device`` for a full reduction and two axis choices."""
    target = self.device
    # Full reduction over a 2-D input; Tensor.sum is passed unbound.
    helper_test_op([(45, 3)], lambda x: x.sum(), Tensor.sum, device=target)
    # Axis reductions over a 4-D input, in the original order: (1, 2), then 1.
    # `ax=ax` pins each axis into its lambda to avoid late binding.
    for ax in ((1, 2), 1):
        helper_test_op(
            [(3, 4, 5, 6)],
            lambda x, ax=ax: x.sum(axis=ax),
            lambda x, ax=ax: Tensor.sum(x, axis=ax),
            device=target,
        )
def test_sum_axis(self):
    """Exercise ``Tensor.sum`` over a tuple of axes on ``self.device``."""
    input_shape = (3, 4, 5, 6)
    reduce_axes = (1, 2)
    helper_test_op(
        [input_shape],
        lambda x: x.sum(axis=reduce_axes),
        lambda x: Tensor.sum(x, axis=reduce_axes),
        device=self.device,
    )