Code example #1
File: test_converter.py Project: f-dangel/backpack
def test_network_diag_ggn(model_and_input):
    """Test whether the given module can compute diag_ggn.

    This test is placed here because some models are too big for a full
    diag_ggn comparison with PyTorch.
    This test just checks whether the model runs with BackPACK without errors.
    Additionally, it checks whether the forward pass is identical to that of the original model.
    Finally, a small number of DiagGGN elements are compared against autograd.

    Args:
        model_and_input: tuple of (model, input, loss function) to test

    Raises:
        NotImplementedError: if loss_fn is not MSELoss or CrossEntropyLoss
    """
    model_original, x, loss_fn = model_and_input
    model_original = model_original.eval()
    output_compare = model_original(x)
    if isinstance(loss_fn, MSELoss):
        y = regression_targets(output_compare.shape)
    elif isinstance(loss_fn, CrossEntropyLoss):
        y = classification_targets(
            (output_compare.shape[0], *output_compare.shape[2:]),
            output_compare.shape[1],
        )
    else:
        raise NotImplementedError(
            f"test cannot handle loss_fn = {type(loss_fn)}")

    num_params = sum(p.numel() for p in model_original.parameters()
                     if p.requires_grad)
    num_to_compare = 10
    idx_to_compare = linspace(0, num_params - 1, num_to_compare, dtype=int32)
    diag_ggn_exact_to_compare = autograd_diag_ggn_exact(x,
                                                        y,
                                                        model_original,
                                                        loss_fn,
                                                        idx=idx_to_compare)

    model_extended = extend(model_original, use_converter=True, debug=True)
    output = model_extended(x)

    assert allclose(output, output_compare)

    loss = extend(loss_fn)(output, y)

    with backpack(DiagGGNExact()):
        loss.backward()

    diag_ggn_exact_vector = cat([
        p.diag_ggn_exact.flatten() for p in model_extended.parameters()
        if p.requires_grad
    ])

    for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
        assert allclose(element, diag_ggn_exact_vector[idx], atol=1e-5)
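
Distilled from the test above, the core extend/backpack pattern is only a few lines. The toy model below is illustrative and not part of the repository:

from torch import rand
from torch.nn import Linear, MSELoss, ReLU, Sequential

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

# Extend both the model and the loss so BackPACK can hook into the backward pass.
model = extend(Sequential(Linear(10, 7), ReLU(), Linear(7, 3)))
loss_fn = extend(MSELoss())

loss = loss_fn(model(rand(8, 10)), rand(8, 3))

# Inside this context, backward() additionally stores the exact GGN diagonal
# in p.diag_ggn_exact for every parameter p.
with backpack(DiagGGNExact()):
    loss.backward()

for p in model.parameters():
    print(p.shape, p.diag_ggn_exact.shape)  # same shape as the parameter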
Code example #2
LOCAL_SETTINGS += [
    # RNN settings
    {
        "input_fn": lambda: rand(8, 5, 6),
        "module_fn": lambda: Sequential(
            RNN(input_size=6, hidden_size=3, batch_first=True),
            ReduceTuple(index=0),
            Permute(0, 2, 1),
            Flatten(),
        ),
        "loss_function_fn": lambda: MSELoss(),
        "target_fn": lambda: regression_targets((8, 3 * 5)),
    },
    {
        "input_fn": lambda: rand(4, 3, 5),
        "module_fn": lambda: Sequential(
            LSTM(input_size=5, hidden_size=4, batch_first=True),
            ReduceTuple(index=0),
            Flatten(),
        ),
        "loss_function_fn": lambda: CrossEntropyLoss(),
        "target_fn": lambda: classification_targets((4,), 4 * 3),
    },
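
Each entry in these settings lists is a recipe of zero-argument factories. Below is a minimal sketch of how a test harness might instantiate one entry; the consume_setting helper is hypothetical and only illustrates the contract:

def consume_setting(setting):
    # Hypothetical helper: call every factory to get fresh objects,
    # then run one forward pass and evaluate the loss.
    module = setting["module_fn"]()
    x = setting["input_fn"]()
    loss_fn = setting["loss_function_fn"]()
    y = setting["target_fn"]()
    return loss_fn(module(x), y)

loss = consume_setting(LOCAL_SETTINGS[-1])  # e.g. the LSTM setting above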
Code example #3
    {
        "input_fn": lambda: torch.rand(3, 10),
        "module_fn": lambda: torch.nn.Sequential(
            torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5)
        ),
        "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
        "target_fn": lambda: classification_targets((3,), 5),
    },
    # Regression
    {
        "input_fn": lambda: torch.rand(3, 10),
        "module_fn": lambda: torch.nn.Sequential(
            torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5)
        ),
        "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 5)),
    },
]
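
The classification_targets and regression_targets helpers come from the project's test utilities. Plausible minimal implementations, assuming they simply draw random targets of the right shape:

import torch

def classification_targets(size, num_classes):
    # Integer class labels drawn uniformly from [0, num_classes).
    return torch.randint(low=0, high=num_classes, size=size)

def regression_targets(size):
    # Real-valued targets matching the model's output shape.
    return torch.rand(size)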

###############################################################################
#                         test setting: Convolutional Layers                  #
###############################################################################

FIRSTORDER_SETTINGS += [
    {
        "input_fn": lambda: torch.rand(3, 3, 7),
        "module_fn": lambda: torch.nn.Sequential(
            torch.nn.Conv1d(3, 2, 2),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(12, 5),
Code example #4
        "target_fn": lambda: classification_targets(size=(2, ), num_classes=2),
    },
    {
        "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
        "input_fn": lambda: torch.rand(size=(8, 4)),
        "target_fn": lambda: classification_targets(size=(8, ), num_classes=2),
    },
    {
        "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="none"),
        "input_fn": lambda: torch.rand(size=(1, 1)),
        "target_fn": lambda: classification_targets(size=(1, ), num_classes=1),
    },
    {
        "module_fn": lambda: torch.nn.MSELoss(reduction="mean"),
        "input_fn": lambda: torch.rand(size=(5, 1)),
        "target_fn": lambda: regression_targets(size=(5, 1)),
    },
    {
        "module_fn": lambda: torch.nn.MSELoss(reduction="sum"),
        "input_fn": lambda: torch.rand(size=(5, 3)),
        "target_fn": lambda: regression_targets(size=(5, 3)),
    },
    {
        "module_fn": lambda: torch.nn.MSELoss(reduction="none"),
        "input_fn": lambda: torch.rand(size=(1, 1)),
        "target_fn": lambda: regression_targets(size=(1, 1)),
    },
]
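
The reduction argument matters here: "mean" and "sum" produce a scalar loss, while "none" yields a per-element tensor, which is why the entries above restrict it to single-element inputs and why LOSS_FAIL_SETTINGS below collects unsupported cases. A quick plain-PyTorch shape check:

import torch

output, target = torch.rand(5, 3), torch.rand(5, 3)
for reduction in ["mean", "sum", "none"]:
    loss = torch.nn.MSELoss(reduction=reduction)(output, target)
    print(reduction, tuple(loss.shape))  # mean/sum -> (), none -> (5, 3)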

LOSS_FAIL_SETTINGS = [
    # non-scalar outputs are not supported
Code example #5
        "loss_function_fn":
        lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
        "target_fn":
        lambda: classification_targets((3, ), 5),
    },
    # Regression
    {
        "input_fn":
        lambda: torch.rand(3, 10),
        "module_fn":
        lambda: torch.nn.Sequential(torch.nn.Linear(10, 7), torch.nn.Sigmoid(),
                                    torch.nn.Linear(7, 5)),
        "loss_function_fn":
        lambda: torch.nn.MSELoss(reduction="mean"),
        "target_fn":
        lambda: regression_targets((3, 5)),
    },
]

###############################################################################
#                         test setting: Convolutional Layers                  #
"""
Syntax with default parameters:
 - `torch.nn.ConvNd(in_channels, out_channels,
    kernel_size, stride=1, padding=0, dilation=1,
    groups=1, bias=True, padding_mode='zeros')`

 - `torch.nn.ConvTransposeNd(in_channels, out_channels,
    kernel_size, stride=1, padding=0, output_padding=0,
    groups=1, bias=True, dilation=1, padding_mode='zeros')`
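
Following the signatures above, a Conv2d entry in the same style as the Conv1d setting shown earlier could look like this (illustrative only, not copied from the repository):

    {
        "input_fn": lambda: torch.rand(3, 3, 7, 7),
        "module_fn": lambda: torch.nn.Sequential(
            torch.nn.Conv2d(3, 2, 2),  # defaults: stride=1, padding=0, dilation=1
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(72, 5),  # 2 channels * 6 * 6 spatial positions
        ),
        "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(),
        "target_fn": lambda: classification_targets((3,), 5),
    },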
Code example #6
                                 LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS)
LOCAL_NOT_SUPPORTED_SETTINGS = []

NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS

BATCH_SIZE_1_SETTINGS = [
    {
        "input_fn": lambda: rand(1, 7),
        "module_fn": lambda: Sequential(
            Linear(7, 3), ReLU(), Flatten(start_dim=1, end_dim=-1), Linear(3, 1)
        ),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((1, 1)),
        "id_prefix": "one-additional",
    },
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(
            Linear(10, 5),
            ReLU(),
            # skip connection
            Parallel(
                Identity(),
                Linear(5, 5),
            ),
Code example #7
        "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
        "target_fn": lambda: classification_targets((3, ), 5),
    },
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)),
        "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
        "target_fn": lambda: classification_targets((3, ), 5),
    },
    # Regression
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn":
        lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 5)),
    },
]

###############################################################################
#                         test setting: Activation Layers                     #
###############################################################################
activations = [ReLU, Sigmoid, Tanh, LeakyReLU, LogSigmoid, ELU, SELU]

for act in activations:
    for bias in [True, False]:
        SECONDORDER_SETTINGS.append(make_simple_act_setting(act, bias=bias))
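
The nested loop above enumerates 7 activations times 2 bias flags, i.e. 14 generated settings. An equivalent comprehension, shown only as a compact alternative:

SECONDORDER_SETTINGS += [
    make_simple_act_setting(act, bias=bias)
    for act in activations
    for bias in [True, False]
]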

###############################################################################
#                         test setting: Pooling Layers                       #
"""
Code example #8
        "loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
        "target_fn": lambda: classification_targets((3, ), 5),
    },
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)),
        "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
        "target_fn": lambda: classification_targets((3, ), 5),
    },
    # regression
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn":
        lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 5)),
    },
]

# linear with additional dimension
FIRSTORDER_SETTINGS += [
    # regression
    {
        "input_fn": lambda: rand(3, 4, 5),
        "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 4, 2)),
        "id_prefix": "one-additional",
    },
    {
        "input_fn": lambda: rand(3, 4, 2, 5),