def test_concatdownsample2d():
    """Check ``F.concat_downsample2d`` and its module wrappers on known data.

    The input is ``arange`` data, so every output value identifies exactly
    which input element it was taken from.
    """
    num_batches = 2
    num_chan = 4
    scale_factor = 2
    x = torch.arange(num_batches * num_chan * 4**2).view(
        num_batches, num_chan, 4, 4)

    # Test functional API.
    # A scale factor of 3 must raise AssertionError (presumably because it
    # does not divide the 4x4 spatial size -- confirm against the
    # implementation).
    with pytest.raises(AssertionError):
        F.concat_downsample2d(x, 3)

    out = F.concat_downsample2d(x, scale_factor)
    assert out.shape == (num_batches, num_chan * scale_factor**2,
                         x.shape[2] // scale_factor,
                         x.shape[3] // scale_factor)

    # Spot-check the channel layout. These assertions expect the output
    # channels to be grouped by sampling phase, with num_chan channels per
    # phase group:
    #   - channel 0 is input channel 0 sampled at offset (0, 0);
    #   - channel -num_chan (the first channel of the LAST group) is input
    #     channel 0 sampled at offset (1, 1).
    assert torch.equal(out[0][0], torch.tensor([[0, 2], [8, 10]]))
    assert torch.equal(out[0][-num_chan], torch.tensor([[5, 7], [13, 15]]))
    # The actual last output value: last batch, last channel, which under the
    # same layout is input channel num_chan - 1 of batch 1 at offset (1, 1),
    # i.e. x[1, 3, 1::2, 1::2].
    assert torch.equal(out[-1][-1], torch.tensor([[117, 119], [125, 127]]))

    # Test module: must match the functional result exactly.
    mod = downsample.ConcatDownsample2d(scale_factor)
    assert torch.equal(mod(x), out)

    # Test the Jit-suffixed module variant: must also match exactly.
    mod = downsample.ConcatDownsample2dJit(scale_factor)
    assert torch.equal(mod(x), out)
def test_concatdownsample2d(self):
    """Compare ``F.concat_downsample2d`` against a slicing-based reference.

    The expected output is built by taking the four stride-2 sub-grids of
    the input and flattening them into the channel axis. The module wrapper
    is then checked to reproduce the functional result exactly.
    """
    num_batches = 2
    num_chan = 4
    factor = 2
    x = torch.rand(num_batches, num_chan, 4, 4)
    out_h = x.shape[2] // factor
    out_w = x.shape[3] // factor

    # Test functional API: a factor of 3 must be rejected.
    with self.assertRaises(AssertionError):
        F.concat_downsample2d(x, 3)

    out = F.concat_downsample2d(x, factor)
    self.assertEqual(
        out.shape, (num_batches, num_chan * factor**2, out_h, out_w))

    # Reference: the strided sub-grids at (row, col) offsets (0, 0), (0, 1),
    # (1, 0), (1, 1), stacked on a new axis right after the channel axis.
    # After the view, the phases of each input channel sit in adjacent output
    # channels (channel-major layout).
    # NOTE(review): the sibling pytest-style test_concatdownsample2d asserts a
    # phase-grouped layout instead (out[0][-num_chan] is channel 0 at offset
    # (1, 1)). Both expectations cannot hold for one implementation, so
    # confirm which layout is current.
    phases = [x[..., row::factor, col::factor]
              for row in range(factor)
              for col in range(factor)]
    expected = torch.stack(phases, dim=2).view(num_batches, -1, out_h, out_w)
    self.assertTrue(torch.equal(out, expected))

    # Test module: must match the functional result exactly.
    mod = downsample.ConcatDownsample2d(factor)
    self.assertTrue(torch.equal(mod(x), out))