Example #1
0
def test_concatdownsample2d():
    """Exercise concat_downsample2d via its functional, module and Jit APIs.

    A (2, 4, 4, 4) integer ramp tensor is downsampled with factor 2.

    The test checks that:
    * calling the functional API with factor 3 raises AssertionError;
    * the output shape is (batch, channels * factor**2, H // factor, W // factor);
    * two hand-computed output planes hold the expected values;
    * both module wrappers reproduce the functional result exactly.
    """
    batch, channels, factor = 2, 4, 2
    height = width = 4
    ramp = torch.arange(batch * channels * height * width)
    x = ramp.view(batch, channels, height, width)

    # Functional API: factor 3 must be rejected, factor 2 must succeed.
    with pytest.raises(AssertionError):
        F.concat_downsample2d(x, 3)
    out = F.concat_downsample2d(x, factor)

    expected_shape = (batch, channels * factor**2,
                      x.shape[2] // factor, x.shape[3] // factor)
    assert out.shape == expected_shape

    # Channel 0 of x is arange(16).view(4, 4), so the literals below are:
    #   [[0, 2], [8, 10]]  == x[0, 0, ::2, ::2]    (even rows, even cols)
    #   [[5, 7], [13, 15]] == x[0, 0, 1::2, 1::2]  (odd rows, odd cols)
    # The first output plane, and the plane `channels` positions from the
    # end, are expected to hold these two sub-grids respectively.
    first_plane = torch.tensor([[0, 2], [8, 10]])
    odd_offset_plane = torch.tensor([[5, 7], [13, 15]])
    assert torch.equal(out[0][0], first_plane)
    assert torch.equal(out[0][-channels], odd_offset_plane)

    # The plain module and its `Jit` variant must both agree exactly with
    # the functional output.
    for wrapper_cls in (downsample.ConcatDownsample2d,
                        downsample.ConcatDownsample2dJit):
        assert torch.equal(wrapper_cls(factor)(x), out)
Example #2
0
    def test_concatdownsample2d(self):
        """Compare concat_downsample2d with a reference built from strided slices.

        A random (2, 4, 4, 4) tensor is downsampled with factor 2.

        The test checks that:
        * factor 3 raises AssertionError;
        * the output shape is (batch, channels * factor**2, H // factor, W // factor);
        * the values equal the stacked offset sub-grids;
        * the module wrapper matches the functional output.
        """
        batch, channels, factor = 2, 4, 2
        x = torch.rand(batch, channels, 4, 4)
        out_h, out_w = x.shape[2] // factor, x.shape[3] // factor

        # Functional API: factor 3 must raise, factor 2 must succeed.
        self.assertRaises(AssertionError, F.concat_downsample2d, x, 3)
        out = F.concat_downsample2d(x, factor)
        self.assertEqual(out.shape,
                         (batch, channels * factor**2, out_h, out_w))

        # Reference: every (row, col) offset sub-grid, with the row offset
        # varying slowest, stacked right after the channel axis and then
        # flattened into it. Output channel c * factor**2 + k therefore holds
        # sub-grid k of input channel c.
        sub_grids = [x[..., row::factor, col::factor]
                     for row in range(factor) for col in range(factor)]
        reference = torch.stack(sub_grids, dim=2).view(batch, -1, out_h, out_w)
        self.assertTrue(torch.equal(out, reference))

        # Module API must reproduce the functional result exactly.
        module = downsample.ConcatDownsample2d(factor)
        self.assertTrue(torch.equal(module(x), out))