def test_output_shape_gradient(self):
    """
    Compares the output shape and gradients of the converting_relu
    operation to the output of the torch implementation for different
    input dimensions.
    """
    test_inputs = {
        "1-d signed":
        ConvertingReLUInput(torch.arange(-63., 65.).requires_grad_()),
        "1-d":
        ConvertingReLUInput(rand_full((128,), 20.)),
        "2-d":
        ConvertingReLUInput(rand_full((3, 128), 20.)),
        "3-d":
        ConvertingReLUInput(rand_full((2, 3, 128), 20.)),
        "2-d non-contiguous input":
        ConvertingReLUInput(
            rand_full((128, 3), 20.).data.t().requires_grad_()),
    }

    for mode, converting_relu_input in test_inputs.items():
        with self.subTest(mode=mode):
            result = self.converting_relu(
                **converting_relu_input._asdict())
            converting_relu_input_torch = converting_relu_input.duplicate()
            result_torch = torch.div(
                torch.relu(**converting_relu_input_torch._asdict()), 4)
            self.assertEqual(result.size(), result_torch.size())

            # compute gradients
            result.backward(torch.ones_like(result))
            result_torch.backward(torch.ones_like(result_torch))

            self.assertTrue(
                torch.allclose(result, result_torch, atol=1.0),
                f"Result does not match:\n{result}\n!=\n{result_torch}")
            self.assertTrue(
                torch.all(result <= 31.),
                f"Result is not smaller than or equal to 31:\n{result}")

            for name, arg in converting_relu_input._asdict().items():
                if hasattr(arg, "grad"):
                    grad = arg.grad
                    grad_torch = getattr(
                        converting_relu_input_torch, name).grad
                    self.assertTrue(
                        torch.allclose(grad, grad_torch, rtol=0.1),
                        f"{name.capitalize()} gradient does not match:\n"
                        f"{grad}\n!=\n{grad_torch}")
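# The helpers used throughout these tests (rand_full and the *Input named
# tuples with their duplicate() method) are defined elsewhere in the test
# module. The sketch below is only an assumption of their behavior, for
# illustration: rand_full is taken to produce a random, gradient-tracking
# tensor, and duplicate() to return detached clones so that the torch
# reference run does not share .grad with the run under test.
from typing import NamedTuple

import torch


def rand_full(shape, max_value: float) -> torch.Tensor:
    # hypothetical: random values in [0, max_value) with gradient tracking
    return (torch.rand(shape) * max_value).requires_grad_()


class ConvertingReLUInput(NamedTuple):
    input: torch.Tensor

    def duplicate(self) -> "ConvertingReLUInput":
        # fresh leaf tensors for the torch reference computation
        return ConvertingReLUInput(*(
            arg.detach().clone().requires_grad_()
            if isinstance(arg, torch.Tensor) else arg
            for arg in self))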
def test_output_shape_gradient(self):
    """
    Compares the output shape and gradients of the add operation to the
    output of the torch implementation for different input dimensions.
    """
    test_inputs = {
        "1-d broadcast":
        AddInput(rand_full((3, 128), 20.), rand_full((128,), 30.)),
        "1-d":
        AddInput(rand_full((128,), 20.), rand_full((128,), 30.)),
    }

    for mode, add_input in test_inputs.items():
        with self.subTest(mode=mode):
            result = self.add(**add_input._asdict())
            add_input_torch = add_input.duplicate()
            result_torch = self.torch_add(**add_input_torch._asdict())
            self.assertEqual(result.size(), result_torch.size())

            # compute gradients
            result.backward(torch.ones_like(result))
            result_torch.backward(torch.ones_like(result_torch))

            self.assertTrue(
                torch.allclose(result, result_torch, atol=2.0),
                f"Result does not match:\n{result}\n!=\n{result_torch}")

            for name, arg in add_input._asdict().items():
                if hasattr(arg, "grad"):
                    grad = arg.grad
                    grad_torch = getattr(add_input_torch, name).grad
                    self.assertTrue(
                        torch.allclose(grad, grad_torch, rtol=0.1),
                        f"{name.capitalize()} gradient does not match:\n"
                        f"{grad}\n!=\n{grad_torch}")
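# For reference, the "1-d broadcast" case relies on standard torch
# broadcasting semantics: the 1-d second operand is expanded over the
# leading batch dimension of the first.
a = torch.full((3, 128), 20.)
b = torch.full((128,), 30.)
assert torch.add(a, b).size() == torch.Size([3, 128])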
class TestConv1d(TestConv):
    """
    Tests the conv1d operation.
    """
    conv = torch.conv1d
    torch_conv = torch.conv1d

    test_inputs = {
        "batch3_outchannels1_inchannels1_kernel_larger_stride":
        ConvInput(rand_full((3, 1, 30), 25.), rand_full((1, 1, 5), 50.),
                  stride=7),
        "expanded_full_synram":
        ConvInput(rand_full((2, 1, 128), 10.), rand_full((14, 1, 43), 15.),
                  bias=torch.full((14,), 1.).requires_grad_(), stride=5),
        "expanded_overfull_synram":
        ConvInput(rand_full((2, 1, 138), 10.), rand_full((14, 1, 43), 15.),
                  bias=torch.full((14,), 1.).requires_grad_(), stride=5),
    }

    kernel_size = 5
    for n_batches in [2, 4]:
        for n_input_channels in [1, 3, 5]:
            for n_output_channels in [1, 4]:
                for stride in [7, 4, 2]:
                    test_inputs.update({
                        f"batch{n_batches}_outchannels{n_output_channels}_"
                        f"inchannels{n_input_channels}_"
                        f"kernel{kernel_size}_stride{stride}":
                        ConvInput(
                            rand_full(
                                (n_batches, n_input_channels, 30), 10.),
                            rand_full(
                                (n_output_channels, n_input_channels,
                                 kernel_size), 50.),
                            stride=stride)
                    })
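# Quick sanity check of the expected shapes: for an unpadded conv1d the
# output length is (L_in - kernel_size) // stride + 1. For the
# "expanded_full_synram" case above:
x = torch.full((2, 1, 128), 10.)
w = torch.full((14, 1, 43), 15.)
# (128 - 43) // 5 + 1 == 18 output positions per output channel
assert torch.conv1d(x, w, stride=5).size() == torch.Size([2, 14, 18])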
def test_output_shape_gradient(self):
    """
    Compares the output shape and gradients of the matmul operation to
    the output of the torch implementation for different input
    dimensions.
    """
    test_inputs = {
        "1-d x 1-d":
        MatmulInput(rand_full((128,), 12.), rand_full((128,), 15.)),
        "1-d x 2-d":
        MatmulInput(rand_full((128,), 12.), rand_full((128, 5), 15.)),
        # TODO: implement > 2D weights
        # "1-d x 3-d":
        # MatmulInput(rand_full((128,), 12.),
        #             rand_full((2, 128, 5), 15.)),
        # "1-d x 4-d":
        # MatmulInput(rand_full((128,), 12.),
        #             rand_full((4, 2, 128, 5), 15.)),
        "2-d x 1-d":
        MatmulInput(rand_full((3, 128), 12.), rand_full((128,), 15.)),
        "2-d x 2-d":
        MatmulInput(rand_full((3, 128), 12.), rand_full((128, 5), 15.)),
        # "2-d x 3-d":
        # MatmulInput(rand_full((3, 128), 12.),
        #             rand_full((2, 128, 5), 15.)),
        # "2-d x 4-d":
        # MatmulInput(rand_full((3, 128), 12.),
        #             rand_full((4, 2, 128, 5), 15.)),
        "3-d x 1-d":
        MatmulInput(rand_full((2, 3, 128), 12.), rand_full((128,), 15.)),
        "3-d x 2-d":
        MatmulInput(rand_full((2, 3, 128), 12.), rand_full((128, 5), 15.)),
        # TODO: implement batched mode
        # "3-d x 3-d":
        # MatmulInput(rand_full((2, 3, 128), 12.),
        #             rand_full((2, 128, 5), 15.)),
        # "3-d x 4-d":
        # MatmulInput(rand_full((2, 3, 128), 12.),
        #             rand_full((4, 2, 128, 5), 15.)),
        "2-d x 2-d non-contiguous input":
        MatmulInput(
            rand_full((128, 3), 12.).data.t().requires_grad_(),
            rand_full((128, 5), 15.)),
        "2-d x 2-d non-contiguous other":
        MatmulInput(
            rand_full((3, 128), 12.),
            rand_full((5, 128), 15.).data.t().requires_grad_()),
    }

    for mode, matmul_input in test_inputs.items():
        with self.subTest(mode=mode):
            result = self.matmul(**matmul_input._asdict())
            self.assertTrue(result.is_contiguous())
            matmul_input_torch = matmul_input.duplicate()
            result_torch = torch.matmul(**matmul_input_torch._asdict())
            self.assertEqual(result.size(), result_torch.size())

            # compute gradients
            result.backward(torch.ones_like(result))
            result_torch.backward(torch.ones_like(result_torch))

            for name, arg in matmul_input._asdict().items():
                if hasattr(arg, "grad"):
                    grad = arg.grad
                    grad_torch = getattr(matmul_input_torch, name).grad \
                        * self.gain
                    self.assertTrue(
                        torch.allclose(grad, grad_torch, rtol=.001),
                        f"{name.capitalize()} gradient does not match:\n"
                        f"{grad}\n!=\n{grad_torch}")
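# The torch reference gradient is scaled by self.gain before comparison.
# This matches the fact that scaling an op's output by a constant scales
# its input gradients by the same constant; a small illustration with an
# assumed example gain of 0.5 (the real value is a fixture attribute):
gain = 0.5
x = torch.full((3, 128), 12., requires_grad=True)
w = torch.full((128, 5), 15., requires_grad=True)
out = gain * torch.matmul(x, w)
out.backward(torch.ones_like(out))
# d(out)/d(x) = grad_out @ w.T, scaled by the gain
assert torch.allclose(x.grad, torch.ones(3, 5) @ w.detach().t() * gain)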
def test_output_shape_value(self):
    """
    Compares the output shape and value of the argmax operation to the
    output of the torch implementation for different inputs.
    """
    test_inputs = {
        "1d dim=None":
        ArgMaxInput(rand_full((123,), 20.).round(), dim=None,
                    keepdim=False),
        "1d dim=0":
        ArgMaxInput(rand_full((123,), 20.).round(), dim=0, keepdim=False),
        "1d dim=0 keepdim=True":
        ArgMaxInput(rand_full((123,), 20.).round(), dim=0, keepdim=True),
        "2d dim=None":
        ArgMaxInput(rand_full((123, 456), 20.).round(), dim=None,
                    keepdim=False),
        "2d dim=1":
        ArgMaxInput(rand_full((123, 456), 20.).round(), dim=1,
                    keepdim=False),
        "2d dim=0":
        ArgMaxInput(rand_full((123, 456), 20.).round(), dim=0,
                    keepdim=False),
        "2d dim=1 keepdim=True":
        ArgMaxInput(rand_full((123, 456), 20.).round(), dim=1,
                    keepdim=True),
        "2d dim=0 keepdim=True":
        ArgMaxInput(rand_full((123, 456), 20.).round(), dim=0,
                    keepdim=True),
        "3d dim=None":
        ArgMaxInput(rand_full((12, 45, 78), 20.).round(), dim=None,
                    keepdim=False),
        "3d dim=2":
        ArgMaxInput(rand_full((12, 45, 78), 20.).round(), dim=2,
                    keepdim=False),
        "3d dim=1":
        ArgMaxInput(rand_full((12, 45, 78), 20.).round(), dim=1,
                    keepdim=False),
        "3d dim=0":
        ArgMaxInput(rand_full((12, 45, 78), 20.).round(), dim=0,
                    keepdim=False),
        "3d dim=2 keepdim=True":
        ArgMaxInput(rand_full((12, 45, 78), 20.).round(), dim=2,
                    keepdim=True),
        "3d dim=1 keepdim=True":
        ArgMaxInput(rand_full((12, 45, 78), 20.).round(), dim=1,
                    keepdim=True),
        "3d dim=0 keepdim=True":
        ArgMaxInput(rand_full((12, 45, 78), 20.).round(), dim=0,
                    keepdim=True),
    }

    for mode, argmax_input in test_inputs.items():
        with self.subTest(mode=mode):
            result = self.argmax(**argmax_input._asdict())
            argmax_input_torch = argmax_input.duplicate()
            result_torch = torch.argmax(**argmax_input_torch._asdict())
            self.assertTrue(
                torch.equal(result, result_torch),
                f"Result does not match:\n{result}\n!=\n{result_torch}")
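# For reference, dim selects the axis that is reduced (dim=None indexes
# into the flattened tensor) and keepdim retains the reduced axis with
# size one:
x = torch.tensor([[1., 5., 2.],
                  [7., 0., 3.]])
assert torch.argmax(x).item() == 3                 # flattened index of 7.
assert torch.argmax(x, dim=0).tolist() == [1, 0, 1]
assert torch.argmax(x, dim=1, keepdim=True).size() == torch.Size([2, 1])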
class TestConv2d(TestConv):
    """
    Tests the conv2d operation.
    """
    conv = torch.conv2d
    torch_conv = torch.conv2d

    test_inputs = {
        "batch1_outchannels1_inchannels1_kernel_larger_stride":
        ConvInput(rand_full((1, 1, 30, 60), 25.),
                  rand_full((1, 1, 5, 10), 20.), stride=(7, 14)),
        "batch2_outchannels1_inchannels3_kernel_larger_stride":
        ConvInput(rand_full((2, 3, 30, 60), 10.),
                  rand_full((1, 3, 5, 10), 20.), stride=(7, 14)),
        "batch2_outchannels4_inchannels3_kernel_larger_stride":
        ConvInput(rand_full((2, 3, 30, 60), 10.),
                  rand_full((4, 3, 5, 10), 20.), stride=(7, 14)),
        "batch1_outchannels1_inchannels1_kernel_smaller_stride":
        ConvInput(rand_full((1, 1, 30, 60), 25.),
                  rand_full((1, 1, 5, 10), 20.), stride=(4, 8)),
        "batch2_outchannels1_inchannels3_kernel_smaller_stride":
        ConvInput(rand_full((2, 3, 30, 60), 10.),
                  rand_full((1, 3, 5, 10), 20.), stride=(4, 8)),
        "batch2_outchannels4_inchannels3_kernel_smaller_stride":
        ConvInput(rand_full((2, 3, 30, 60), 10.),
                  rand_full((4, 3, 5, 10), 20.), stride=(4, 8)),
        # distinct key: previously this entry reused the key above and
        # silently overwrote the unbiased case
        "batch2_outchannels4_inchannels3_kernel_smaller_stride_bias":
        ConvInput(rand_full((2, 3, 30, 60), 10.),
                  rand_full((4, 3, 5, 10), 20.),
                  bias=torch.full((4,), 0.).requires_grad_(),
                  stride=(4, 8)),
    }
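# Shape sanity check for the "kernel_larger_stride" cases: with no
# padding, each spatial dimension yields (size - kernel) // stride + 1
# output positions.
x = torch.full((1, 1, 30, 60), 25.)
w = torch.full((1, 1, 5, 10), 20.)
# H: (30 - 5) // 7 + 1 == 4, W: (60 - 10) // 14 + 1 == 4
assert torch.conv2d(x, w, stride=(7, 14)).size() == torch.Size([1, 1, 4, 4])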