Exemplo n.º 1
0
def test_unresample_bilinear_integer_sampling_cpu():
    """Check a BILINEAR ``Unresample`` layer on the CPU.

    Runs a forward and backward pass through ``utils.mapped_resample_test``
    using ``params.input_3x4()`` tiled to ``(bs, channels, 1, 1)`` and the
    map from ``params.sample_map6()``. It then requires the gradient check
    to pass and the output to match a manually computed reference.
    """
    unresample = Unresample(InterpolationType.BILINEAR)

    # NOTE(review): `bs` and `channels` are not defined in this block;
    # presumably module-level test settings -- confirm.
    batched_input = params.input_3x4().repeat(bs, channels, 1, 1)
    result = utils.mapped_resample_test(
        unresample,
        input=batched_input,
        sample_map=params.sample_map6(),
        cuda=False,
    )
    # The helper yields (output, forward_time, backward_time, gradcheck_res);
    # the two timings are not checked by this test.
    output, _, _, grad_ok = result

    expected = torch.tensor([[0, 5], [9, 11], [3, 10]]).double()

    assert grad_ok
    testing.assert_allclose(output, expected)
Exemplo n.º 2
0
def test_unresample_bispherical_out_of_bounds_sampling_cuda():
    """Check a BISPHERICAL ``Unresample`` layer on CUDA.

    Per the test name, this covers out-of-bounds sample locations. It runs a
    forward and backward pass through ``utils.mapped_resample_test`` using
    ``params.input_3x4()`` tiled to ``(bs, channels, 1, 1)`` and the map from
    ``params.sample_map8()``. It then requires the gradient check to pass and
    the output to match a manually computed reference on the GPU.
    """
    unresample = Unresample(InterpolationType.BISPHERICAL)

    # NOTE(review): `bs` and `channels` come from outside this block --
    # presumably module-level test settings; confirm.
    batched_input = params.input_3x4().repeat(bs, channels, 1, 1)
    result = utils.mapped_resample_test(
        unresample,
        input=batched_input,
        sample_map=params.sample_map8(),
        cuda=True,
    )
    # Only the output and the gradcheck flag matter; timings are ignored.
    output, _, _, grad_ok = result

    expected = torch.tensor([[0, 0], [0, 0], [4.75, 11]]).double().cuda()

    assert grad_ok
    testing.assert_allclose(output, expected)
Exemplo n.º 3
0
def test_unresample_weighted_sampling_cuda():
    """Check a BISPHERICAL ``Unresample`` layer with explicit weights on CUDA.

    It runs a forward and backward pass through
    ``utils.mapped_resample_test``. The inputs are ``params.input_4x7()``
    tiled to ``(bs, channels, 1, 1)``, the map from ``params.sample_map5()``
    and the weights from ``params.interp_weights0()``. It then requires the
    gradient check to pass and the output to match a manually computed
    reference on the GPU.
    """
    unresample = Unresample(InterpolationType.BISPHERICAL)

    # NOTE(review): `bs` and `channels` are defined elsewhere (presumably
    # module-level test settings) -- confirm.
    batched_input = params.input_4x7().repeat(bs, channels, 1, 1)
    result = utils.mapped_resample_test(
        unresample,
        input=batched_input,
        sample_map=params.sample_map5(),
        interp_weights=params.interp_weights0(),
        cuda=True,
    )
    # Discard forward/backward timings; keep output and gradcheck result.
    output, _, _, grad_ok = result

    expected = torch.tensor([[14, 15.8], [9.6, 19.1]]).double().cuda()

    assert grad_ok
    testing.assert_allclose(output, expected)