def test_concatdownsample2d(self):
    """Check ``concat_downsample2d`` and its module wrappers on a known input.

    An ``arange`` input is used so individual output values can be traced
    back to input positions. The functional API is checked for:

    * output shape;
    * raising ``AssertionError`` for a scale factor of 3;
    * two spot-checked output channels.

    ``ConcatDownsample2d`` and ``ConcatDownsample2dJit`` must then produce
    the same tensor as the functional call.
    """
    # NOTE(review): this method name appears more than once in this chunk.
    # If both definitions share a class, only the last one is collected and
    # run, so this one would be dead. Confirm which version is intended.
    num_batches = 2
    num_chan = 4
    scale_factor = 2
    height = width = 4
    x = torch.arange(num_batches * num_chan * height * width).view(
        num_batches, num_chan, height, width)

    # Test functional API.
    # A scale factor of 3 is expected to be rejected (presumably because it
    # does not evenly divide the 4x4 spatial size -- confirm in the impl).
    self.assertRaises(AssertionError, F.concat_downsample2d, x, 3)
    out = F.concat_downsample2d(x, scale_factor)
    self.assertEqual(
        out.shape,
        (num_batches, num_chan * scale_factor**2,
         x.shape[2] // scale_factor, x.shape[3] // scale_factor))

    # Spot-check two output channels of the first batch element against
    # strided slices of the input.
    # Output channel 0 is input channel 0 sampled at row/col offset (0, 0);
    # for this input that is [[0, 2], [8, 10]].
    self.assertTrue(
        torch.equal(out[0][0], x[0, 0, ::scale_factor, ::scale_factor]))
    # Output channel -num_chan (the first channel of the final group of
    # num_chan channels, not the last channel) is input channel 0 sampled at
    # offset (1, 1); for this input that is [[5, 7], [13, 15]].
    self.assertTrue(
        torch.equal(out[0][-num_chan],
                    x[0, 0, 1::scale_factor, 1::scale_factor]))

    # Test module: must agree exactly with the functional result.
    mod = downsample.ConcatDownsample2d(scale_factor)
    self.assertTrue(torch.equal(mod(x), out))

    # Test JIT module: must also agree exactly with the functional result.
    mod = downsample.ConcatDownsample2dJit(scale_factor)
    self.assertTrue(torch.equal(mod(x), out))
def test_concatdownsample2d(self):
    """Compare ``concat_downsample2d`` against a slicing-based reference.

    On a random input, the functional API is checked for:

    * output shape;
    * raising ``AssertionError`` for a scale factor of 3;
    * full output equal to a reference built from strided slices.

    ``ConcatDownsample2d`` must then produce the same tensor.
    """
    # NOTE(review): this method name appears more than once in this chunk.
    # If both definitions share a class, only the last one is collected and
    # run. Confirm which version (and which channel layout) is intended.
    num_batches = 2
    num_chan = 4
    scale_factor = 2
    x = torch.rand(num_batches, num_chan, 4, 4)
    out_h = x.shape[2] // scale_factor
    out_w = x.shape[3] // scale_factor

    # Test functional API.
    # A scale factor of 3 is expected to be rejected for this 4x4 input.
    self.assertRaises(AssertionError, F.concat_downsample2d, x, 3)
    out = F.concat_downsample2d(x, scale_factor)
    self.assertEqual(
        out.shape,
        (num_batches, num_chan * scale_factor**2, out_h, out_w))

    # Reference: one strided slice per (row, col) offset, row-major over
    # offsets. For scale_factor == 2 this is (0,0), (0,1), (1,0), (1,1).
    # Stacking on dim=2 gives (B, C, scale_factor**2, out_h, out_w).
    # Flattening dims 1-2 then places each input channel's scale_factor**2
    # slices next to each other in the output channel dimension.
    offsets = [(i, j)
               for i in range(scale_factor)
               for j in range(scale_factor)]
    expected = torch.stack(
        [x[..., i::scale_factor, j::scale_factor] for i, j in offsets],
        dim=2).view(num_batches, -1, out_h, out_w)
    self.assertTrue(torch.equal(out, expected))

    # Test module: must agree exactly with the functional result.
    mod = downsample.ConcatDownsample2d(scale_factor)
    self.assertTrue(torch.equal(mod(x), out))