Esempio n. 1
0
def test_forward_passes():
    channels = [8, 16, 32]
    batch_size = 32

    cnn_enc = CNNEncoder(
        input_channels=3,
        channels=channels,
        kernel_size=3,
        conv_dim=2,
        pooling=torch.nn.MaxPool2d(2),
        stride=1,
        padding=1,
    )

    _in = torch.rand((batch_size, 3, 32, 32))
    out = cnn_enc(_in)

    assert out.shape == torch.Size([batch_size, channels[-1], 4, 4])

    cnn_enc = CNNEncoder(
        input_channels=3,
        channels=channels,
        kernel_size=3,
        conv_dim=2,
        pooling=None,
        stride=1,
        padding=1,
    )

    _in = torch.rand((batch_size, 3, 32, 32))
    out = cnn_enc(_in)

    assert out.shape == torch.Size([batch_size, channels[-1], 32, 32])
Esempio n. 2
0
def test_kernel_sizes4():
    channels = [10]
    kernels = [(10, 10)]

    with pytest.raises(ValueError):
        cnn_enc = CNNEncoder(input_channels=3,
                             channels=channels,
                             kernel_size=kernels,
                             conv_dim=3)

    kernels = [(10, 10, 10, 10)]

    with pytest.raises(ValueError):
        cnn_enc = CNNEncoder(input_channels=3,
                             channels=channels,
                             kernel_size=kernels,
                             conv_dim=3)
Esempio n. 3
0
def test_kernel_sizes3():
    channels = [random.randint(5, 32) for _ in range(10)]
    kernels = [random.randint(5, 32) for _ in range(8)]

    with pytest.raises(ValueError):
        cnn_enc = CNNEncoder(input_channels=3,
                             channels=channels,
                             kernel_size=kernels,
                             conv_dim=3)
Esempio n. 4
0
def test_cnn_blocks_filters():
    aux = [random.randint(5, 32) for _ in range(10)]

    cnn_enc = CNNEncoder(input_channels=aux[0], channels=aux[1:])

    for i, l in enumerate(cnn_enc.cnn):
        conv = l[0]
        assert conv.in_channels == aux[i]
        assert conv.out_channels == aux[i + 1]
Esempio n. 5
0
def test_kernel_sizes2():
    channels = [random.randint(5, 32) for _ in range(10)]
    kernels = [random.randint(5, 32) for _ in range(10)]

    cnn_enc = CNNEncoder(input_channels=3,
                         channels=channels,
                         kernel_size=kernels,
                         conv_dim=3)
    for i, b in enumerate(cnn_enc.cnn):
        conv = b[0]
        assert conv.kernel_size == (kernels[i], ) * 3
Esempio n. 6
0
def test_kernel_sizes():
    channels = [random.randint(5, 32) for _ in range(10)]
    kernels = [random.randint(5, 32) for _ in range(10)]

    cnn_enc = CNNEncoder(input_channels=3,
                         channels=channels,
                         kernel_size=5,
                         conv_dim=3)
    for i in cnn_enc.cnn:
        conv = i[0]
        assert conv.kernel_size == (5, 5, 5)
Esempio n. 7
0
def test_invalid_conv_dim():
    with pytest.raises(ValueError):
        cnn_enc = CNNEncoder(input_channels=3, channels=[8, 8, 8], conv_dim=4)
Esempio n. 8
0
def test_cnn_blocks_size():
    for i in range(5):
        cnn_enc = CNNEncoder(input_channels=3, channels=list(range(8, 8 + i)))

        assert len(cnn_enc.cnn) == i
Esempio n. 9
0
def test_invalid_channels_kernels():
    with pytest.raises(ValueError):
        cnn_enc = CNNEncoder(input_channels=3,
                             channels=[8, 8, 8],
                             kernel_size=[3, 3],
                             conv_dim=0)