Exemplo n.º 1
0
def test_res_gca_decoder():
    """Test resnet decoder with shortcut and guided contextual attention.

    Checks that an unsupported block name raises ``NotImplementedError`` and
    that decoding the encoder outputs of a ``(2, 6, 32, 32)`` input gives a
    prediction of shape ``(2, 1, 32, 32)``. The forward check runs on CPU
    and is repeated on GPU when CUDA is available.
    """
    with pytest.raises(NotImplementedError):
        ResGCADecoder('UnknowBlock', [2, 3, 3, 2], 512)

    # The CPU pass always runs; the GPU pass is the very same check, so it
    # shares the loop body instead of being a copy that could drift.
    cuda_flags = [False]
    if torch.cuda.is_available():
        cuda_flags.append(True)

    for use_cuda in cuda_flags:
        model = ResGCADecoder('BasicBlockDec', [2, 3, 3, 2], 512)
        model.init_weights()
        model.train()

        encoder = ResGCAEncoder('BasicBlock', [2, 4, 4, 2], 6)
        img = _demo_inputs((2, 6, 32, 32))

        if use_cuda:
            # test forward with gpu
            model.cuda()
            encoder.cuda()
            img = img.cuda()

        # The decoder consumes the encoder outputs directly.
        outputs = encoder(img)
        prediction = model(outputs)
        assert_tensor_with_shape(prediction, torch.Size([2, 1, 32, 32]))
Exemplo n.º 2
0
def test_res_gca_encoder():
    """Test resnet encoder with shortcut and guided contextual attention.

    Checks that an unsupported block name raises ``NotImplementedError`` and
    that, for ``(2, C, 64, 64)`` inputs, the encoder returns ``out``,
    ``img_feat``, ``unknown`` and ``feat1`` .. ``feat5`` with the expected
    shapes. Three configurations are covered (4 input channels, 6 input
    channels, and 6 input channels with ``late_downsample=True``), on CPU
    and, when CUDA is available, on GPU as well.
    """
    with pytest.raises(NotImplementedError):
        ResGCAEncoder('UnknownBlock', [3, 4, 4, 2], 3)

    target_shape = [(2, 32, 64, 64), (2, 32, 32, 32), (2, 64, 16, 16),
                    (2, 128, 8, 8), (2, 256, 4, 4)]
    # target shape for model with late downsample
    target_late_ds = [(2, 32, 64, 64), (2, 64, 32, 32), (2, 64, 16, 16),
                      (2, 128, 8, 8), (2, 256, 4, 4)]

    # Each case: (input channels, extra encoder kwargs, shapes of feat1..5).
    cases = [
        # trimap has 1 channel
        (4, {}, target_shape),
        # both image and trimap have 3 channels
        (6, {}, target_shape),
        # resnet shortcut encoder with late downsample
        (6, {'late_downsample': True}, target_late_ds),
    ]

    # The CPU pass always runs; every case is repeated on GPU if available.
    cuda_flags = [False]
    if torch.cuda.is_available():
        cuda_flags.append(True)

    for use_cuda in cuda_flags:
        for in_channels, extra_kwargs, feat_shapes in cases:
            model = ResGCAEncoder('BasicBlock', [3, 4, 4, 2], in_channels,
                                  **extra_kwargs)
            model.init_weights()
            model.train()
            img = _demo_inputs((2, in_channels, 64, 64))
            if use_cuda:
                model.cuda()
                img = img.cuda()

            outputs = model(img)
            assert_tensor_with_shape(outputs['out'], (2, 512, 2, 2))
            assert_tensor_with_shape(outputs['img_feat'], (2, 128, 8, 8))
            assert_tensor_with_shape(outputs['unknown'], (2, 1, 8, 8))
            # Intermediate features are keyed feat1 .. feat5.
            for idx, shape in enumerate(feat_shapes, start=1):
                assert_tensor_with_shape(outputs[f'feat{idx}'], shape)