def test_network_diag_ggn(model_and_input):
    """Check that BackPACK's DiagGGNExact runs on the given model.

    Some models here are too large for a full PyTorch comparison of the
    GGN diagonal, so this test only verifies that:

    1. the extended model's forward pass matches the original model, and
    2. a small sample of DiagGGN entries agrees with autograd.

    Args:
        model_and_input: tuple ``(model, input, loss_fn)`` to test

    Raises:
        NotImplementedError: if loss_fn is not MSELoss or CrossEntropyLoss
    """
    reference_model, x, loss_fn = model_and_input
    reference_model = reference_model.eval()
    reference_output = reference_model(x)

    # Build targets matching the loss function's expected format.
    if isinstance(loss_fn, MSELoss):
        y = regression_targets(reference_output.shape)
    elif isinstance(loss_fn, CrossEntropyLoss):
        # Class dimension is dim 1; remaining dims are additional axes.
        label_shape = (reference_output.shape[0], *reference_output.shape[2:])
        y = classification_targets(label_shape, reference_output.shape[1])
    else:
        raise NotImplementedError(
            f"test cannot handle loss_fn = {type(loss_fn)}"
        )

    # Autograd reference for a handful of evenly-spaced diagonal entries.
    trainable = [p for p in reference_model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in trainable)
    idx_to_compare = linspace(0, num_params - 1, 10, dtype=int32)
    expected_entries = autograd_diag_ggn_exact(
        x, y, reference_model, loss_fn, idx=idx_to_compare
    )

    extended_model = extend(reference_model, use_converter=True, debug=True)
    output = extended_model(x)
    # Extending must not change the forward pass.
    assert allclose(output, reference_output)

    loss = extend(loss_fn)(output, y)
    with backpack(DiagGGNExact()):
        loss.backward()

    backpack_diag = cat(
        [
            p.diag_ggn_exact.flatten()
            for p in extended_model.parameters()
            if p.requires_grad
        ]
    )
    for idx, expected in zip(idx_to_compare, expected_entries):
        assert allclose(expected, backpack_diag[idx], atol=1e-5)
# Fragment of a module-level settings list. Each dict describes one test case
# through zero-argument factories: "input_fn" builds the input tensor,
# "module_fn" the model, "loss_function_fn" the loss, "target_fn" the labels.
LOCAL_SETTINGS += [
    # RNN settings
    {
        # batch_first=True below suggests shape (batch=8, seq=5, features=6)
        # — NOTE(review): confirm against the RNN's expected layout.
        "input_fn": lambda: rand(8, 5, 6),
        "module_fn": lambda: Sequential(
            RNN(input_size=6, hidden_size=3, batch_first=True),
            # Select element 0 of the RNN's tuple output.
            ReduceTuple(index=0),
            Permute(0, 2, 1),
            Flatten(),
        ),
        "loss_function_fn": lambda: MSELoss(),
        # 3 * 5 = hidden_size * sequence length after flattening.
        "target_fn": lambda: regression_targets((8, 3 * 5)),
    },
    {
        "input_fn": lambda: rand(4, 3, 5),
        "module_fn": lambda: Sequential(
            LSTM(input_size=5, hidden_size=4, batch_first=True),
            # Select element 0 of the LSTM's tuple output.
            ReduceTuple(index=0),
            Flatten(),
        ),
        "loss_function_fn": lambda: CrossEntropyLoss(),
        # 4 * 3 classes — presumably hidden_size * seq_len; verify.
        "target_fn": lambda: classification_targets((4,), 4 * 3),
    },
{ "input_fn": lambda: torch.rand(3, 10), "module_fn": lambda: torch.nn.Sequential( torch.nn.Linear(10, 7), torch.nn.ReLU(), torch.nn.Linear(7, 5) ), "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), }, # Regression { "input_fn": lambda: torch.rand(3, 10), "module_fn": lambda: torch.nn.Sequential( torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5) ), "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"), "target_fn": lambda: regression_targets((3, 5)), }, ] ############################################################################### # test setting: Convolutional Layers # ############################################################################### FIRSTORDER_SETTINGS += [ { "input_fn": lambda: torch.rand(3, 3, 7), "module_fn": lambda: torch.nn.Sequential( torch.nn.Conv1d(3, 2, 2), torch.nn.ReLU(), torch.nn.Flatten(), torch.nn.Linear(12, 5),
"target_fn": lambda: classification_targets(size=(2,), num_classes=2),
    },
    # CrossEntropyLoss with summed reduction.
    {
        "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
        "input_fn": lambda: torch.rand(size=(8, 4)),
        "target_fn": lambda: classification_targets(size=(8,), num_classes=2),
    },
    # Single-sample, single-class edge case with no reduction.
    {
        "module_fn": lambda: torch.nn.CrossEntropyLoss(reduction="none"),
        "input_fn": lambda: torch.rand(size=(1, 1)),
        "target_fn": lambda: classification_targets(size=(1,), num_classes=1),
    },
    # MSELoss with each reduction mode.
    {
        "module_fn": lambda: torch.nn.MSELoss(reduction="mean"),
        "input_fn": lambda: torch.rand(size=(5, 1)),
        "target_fn": lambda: regression_targets(size=(5, 1)),
    },
    {
        "module_fn": lambda: torch.nn.MSELoss(reduction="sum"),
        "input_fn": lambda: torch.rand(size=(5, 3)),
        "target_fn": lambda: regression_targets(size=(5, 3)),
    },
    {
        "module_fn": lambda: torch.nn.MSELoss(reduction="none"),
        "input_fn": lambda: torch.rand(size=(1, 1)),
        "target_fn": lambda: regression_targets(size=(1, 1)),
    },
]

# Settings expected to fail.
LOSS_FAIL_SETTINGS = [
    # non-scalar outputs are not supported
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
        "target_fn": lambda: classification_targets((3,), 5),
    },
    # Regression
    {
        "input_fn": lambda: torch.rand(3, 10),
        "module_fn": lambda: torch.nn.Sequential(
            torch.nn.Linear(10, 7), torch.nn.Sigmoid(), torch.nn.Linear(7, 5)
        ),
        "loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 5)),
    },
]

###############################################################################
# test setting: Convolutional Layers                                          #
# NOTE(review): the bare string below is an unterminated banner continuing
# past this chunk; kept verbatim because its content is a runtime literal.
""" Syntax with default parameters: - `torch.nn.ConvNd(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros)` - `torch.nn.ConvTransposeNd(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros)`
LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS)

# No locally-unsupported cases; everything unsupported comes from the shared list.
LOCAL_NOT_SUPPORTED_SETTINGS = []
NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS

# Settings exercising a batch size of exactly 1.
BATCH_SIZE_1_SETTINGS = [
    {
        "input_fn": lambda: rand(1, 7),
        "module_fn": lambda: Sequential(
            Linear(7, 3), ReLU(), Flatten(start_dim=1, end_dim=-1), Linear(3, 1)
        ),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((1, 1)),
        "id_prefix": "one-additional",
    },
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(
            Linear(10, 5),
            ReLU(),
            # skip connection
            Parallel(
                Identity(),
                Linear(5, 5),
            ),
"loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
        "target_fn": lambda: classification_targets((3,), 5),
    },
    # Same classifier with summed reduction.
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)),
        "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
        "target_fn": lambda: classification_targets((3,), 5),
    },
    # Regression
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 5)),
    },
]

###############################################################################
#                      test setting: Activation Layers                        #
###############################################################################
# Generate one setting per activation, with and without bias.
activations = [ReLU, Sigmoid, Tanh, LeakyReLU, LogSigmoid, ELU, SELU]
for act in activations:
    for bias in [True, False]:
        SECONDORDER_SETTINGS.append(make_simple_act_setting(act, bias=bias))

###############################################################################
# test setting: Pooling Layers                                                #
# NOTE(review): the string opened below is an unterminated banner continuing
# past this chunk.
"""
"loss_function_fn": lambda: CrossEntropyLoss(reduction="mean"),
        "target_fn": lambda: classification_targets((3,), 5),
    },
    # Same classifier with summed reduction.
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(Linear(10, 7), ReLU(), Linear(7, 5)),
        "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
        "target_fn": lambda: classification_targets((3,), 5),
    },
    # regression
    {
        "input_fn": lambda: rand(3, 10),
        "module_fn": lambda: Sequential(Linear(10, 7), Sigmoid(), Linear(7, 5)),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 5)),
    },
]

# linear with additional dimension
FIRSTORDER_SETTINGS += [
    # regression
    {
        # Input has one extra axis (3, 4, 5); Linear applies to the last dim.
        "input_fn": lambda: rand(3, 4, 5),
        "module_fn": lambda: Sequential(Linear(5, 3), Linear(3, 2)),
        "loss_function_fn": lambda: MSELoss(reduction="mean"),
        "target_fn": lambda: regression_targets((3, 4, 2)),
        "id_prefix": "one-additional",
    },
    {
        "input_fn": lambda: rand(3, 4, 2, 5),