Example #1
    def test_grad_consistency(self):
        model, optimizer = self.setUp_init_model(
            private=True,
            state_dict=self.original_model.state_dict(),
            noise_multiplier=0,
            max_grad_norm=999,
        )

        grad_sample_aggregated = {}

        for x, y in self.dl:
            optimizer.zero_grad()
            logits = model(x)
            loss = self.criterion(logits, y)
            loss.backward()

            # collect all per-sample gradients before we take the step
            for _, layer in model.named_modules():
                if get_layer_type(layer) == "SampleConvNet":
                    continue

                grad_sample_aggregated[layer] = {}
                for p in layer.parameters():
                    if p.requires_grad:
                        grad_sample_aggregated[layer][p] = get_grad_sample_aggregated(p)

            optimizer.step()

        for layer_name, layer in model.named_modules():
            if get_layer_type(layer) == "SampleConvNet":
                continue

            for p in layer.parameters():
                if p.requires_grad:
                    self.assertTrue(
                        torch.allclose(
                            p.grad,
                            grad_sample_aggregated[layer][p],
                            atol=1e-4,
                            rtol=1e-1,
                        ),
                        f"grad_sample doesn't match grad. "
                        f"Layer: {layer_name}, Tensor: {p.shape}",
                    )
def _no_lstm(module: nn.Module) -> bool:
    r"""
    Checks if the input module is not LSTM.

    Args:
        module: The input module

    Returns:
        True if the input module is not LSTM
    """
    return get_layer_type(module) != "LSTM"
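
Both predicates here compare against string layer names via ``get_layer_type``, which is also not shown. A plausible one-line implementation, assuming the checks simply match on the module's class name (an assumption based on how the snippets use it, not a confirmed definition):

import torch.nn as nn


def get_layer_type(module: nn.Module) -> str:
    # Assumed implementation: the callers compare the result against
    # class names such as "LSTM" or "InstanceNorm2d", so the module's
    # class name is the natural candidate.
    return module.__class__.__name__


# Usage sketch: an LSTM fails the predicate, a plain Linear passes.
assert not _no_lstm(nn.LSTM(input_size=4, hidden_size=4))
assert _no_lstm(nn.Linear(4, 4))
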
def _no_running_stats_instancenorm_check(module: nn.Module) -> bool:
    r"""
    Checks that an ``InstanceNorm`` layer has ``track_running_stats`` set to False.

    Args:
        module: The input module (layer) for which the check is verified.

    Returns:
        True if the module is not ``InstanceNorm``; otherwise, True if the
        module (layer) has ``track_running_stats`` set to False, and False
        otherwise.

    """
    is_instancenorm = get_layer_type(module) in (
        "InstanceNorm1d",
        "InstanceNorm2d",
        "InstanceNorm3d",
    )

    if is_instancenorm:
        return not module.track_running_stats
    return True
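
A quick usage sketch for the check above, assuming ``get_layer_type`` behaves as sketched earlier: enabling running statistics makes an ``InstanceNorm`` layer fail the check, while the default configuration and non-InstanceNorm modules pass.

import torch.nn as nn

# InstanceNorm with running statistics enabled is rejected...
assert not _no_running_stats_instancenorm_check(
    nn.InstanceNorm2d(8, track_running_stats=True)
)
# ...the default (track_running_stats=False) passes...
assert _no_running_stats_instancenorm_check(nn.InstanceNorm2d(8))
# ...and non-InstanceNorm modules pass trivially.
assert _no_running_stats_instancenorm_check(nn.Linear(8, 8))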