def test_maxpool2d(self):
    # Compare Tensor.max_pool2d with torch.nn.functional.max_pool2d over a set
    # of kernel sizes (default stride), one subTest per kernel size.
    # TODO merge into test_maxpool2d_strided when backward() is implemented
    kernel_sizes = [(2,2), (3,3), (3,2), (5,5), (5,1)]
    for kernel in kernel_sizes:
        with self.subTest(kernel_size=kernel):
            def torch_pool(x):
                return torch.nn.functional.max_pool2d(x, kernel_size=kernel)

            def tiny_pool(x):
                return Tensor.max_pool2d(x, kernel_size=kernel)

            # forward_only tracks self.gpu: presumably the backward comparison
            # is skipped on GPU — confirm against helper_test_op.
            helper_test_op([(32,2,110,28)], torch_pool, tiny_pool,
                           gpu=self.gpu, forward_only=self.gpu)
def test_maxpool_sizes(self):
    # Check Tensor.max_pool2d against torch.nn.functional.max_pool2d for a
    # range of kernel sizes using the default stride.
    for sz in [(2, 2), (3, 3), (3, 2), (5, 5), (5, 1)]:
        # Run each size in its own subTest (as the sibling max-pool tests do) so
        # a failure names the offending kernel size and the remaining sizes
        # still execute instead of the loop aborting on the first mismatch.
        with self.subTest(kernel_size=sz):
            helper_test_op(
                [(32, 2, 110, 28)],
                lambda x: torch.nn.functional.max_pool2d(x, kernel_size=sz),
                lambda x: Tensor.max_pool2d(x, kernel_size=sz),
                gpu=self.gpu,
            )
def test_maxpool2d(self):
    # Compare Tensor.max_pool2d with torch.nn.functional.max_pool2d for each
    # kernel size (default stride), reporting each size as its own subTest.
    for kernel in ((2,2), (3,3), (3,2), (5,5), (5,1)):
        with self.subTest(kernel_size=kernel):
            def reference(x):
                return torch.nn.functional.max_pool2d(x, kernel_size=kernel)

            def candidate(x):
                return Tensor.max_pool2d(x, kernel_size=kernel)

            helper_test_op(
                [(32,2,110,28)],
                reference,
                candidate,
                # TODO: why is this tolerance so high?
                grad_atol=1e-4,
            )
def test_maxpool2d(self):
    # Compare Tensor.max_pool2d with torch.nn.functional.max_pool2d on the
    # device selected by self.device, one subTest per kernel size.
    for kernel_size in [(2, 2), (3, 3), (3, 2), (5, 5), (5, 1)]:
        with self.subTest(kernel_size=kernel_size):
            pool_kwargs = {"kernel_size": kernel_size}
            helper_test_op(
                [(32, 2, 110, 28)],
                lambda x: torch.nn.functional.max_pool2d(x, **pool_kwargs),
                lambda x: Tensor.max_pool2d(x, **pool_kwargs),
                device=self.device,
            )
def test_maxpool2d(self):
    # Compare Tensor.max_pool2d with torch.nn.functional.max_pool2d across every
    # (kernel size, stride) combination, kernel-major order, one subTest each.
    kernel_sizes = [(2,2), (3,3), (3,2), (5,5), (5,1)]
    strides = [(1,1), (2,1), (2,2), (4,2)]
    cases = [(kernel, step) for kernel in kernel_sizes for step in strides]
    for kernel, step in cases:
        # TODO Grad tolerance for CPU implementation needs to be slightly relaxed; why?
        with self.subTest(kernel_size=kernel, stride=step):
            def torch_pool(x):
                return torch.nn.functional.max_pool2d(x, kernel_size=kernel, stride=step)

            def tiny_pool(x):
                return Tensor.max_pool2d(x, kernel_size=kernel, stride=step)

            # forward_only follows self.gpu, so grad_atol presumably only
            # matters for the non-GPU run — confirm against helper_test_op.
            helper_test_op([(32,2,110,28)], torch_pool, tiny_pool,
                           gpu=self.gpu, forward_only=self.gpu, grad_atol=1e-3)
def test_maxpool2d_strided_fwd(self):
    # Forward-only comparison of Tensor.max_pool2d against
    # torch.nn.functional.max_pool2d for every (kernel size, stride) pair.
    kernel_sizes = [(2, 2), (3, 3), (3, 2), (5, 5), (5, 1)]
    strides = [(1, 1), (2, 1), (2, 2), (4, 2)]
    for kernel in kernel_sizes:
        for step in strides:
            with self.subTest(kernel_size=kernel, stride=step):
                def torch_pool(x):
                    return torch.nn.functional.max_pool2d(
                        x, kernel_size=kernel, stride=step)

                def tiny_pool(x):
                    return Tensor.max_pool2d(x, kernel_size=kernel, stride=step)

                helper_test_op([(32, 2, 110, 28)], torch_pool, tiny_pool,
                               gpu=self.gpu, forward_only=True)