    def test_grad_matches_original_per_layer_clipping(self):
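        """Check that with per-layer clipping, a zero noise multiplier, and an
        effectively unbounded per-layer max_grad_norm, the private model's
        gradients match the original model's after several optimization steps."""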
        original_model, original_optimizer = self.setUp_init_model()
        private_model, private_optimizer = self.setUp_init_model(
            private=True,
            state_dict=original_model.state_dict(),
            noise_multiplier=0,
            max_grad_norm=[999] * 18,
            clip_per_layer=True,
        )

        for _ in range(3):
            self.setUp_model_step(original_model, original_optimizer)
            self.setUp_model_step(private_model, private_optimizer)

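        # Compare gradients layer by layer between the original and private models.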
        for layer_name, private_layer in private_model.named_children():
            if not requires_grad(private_layer):
                continue

            original_layer = getattr(original_model, layer_name)

            for original_grad, private_grad in zip(
                [p.grad for p in original_layer.parameters() if p.requires_grad],
                [p.grad for p in private_layer.parameters() if p.requires_grad],
            ):
                self.assertTrue(
                    torch.allclose(original_grad, private_grad, atol=10e-4, rtol=10e-2),
                    f"Layer: {layer_name}. Private gradients with noise 0 don't match the original",
                )

    def capture_activations_hook(
        self,
        module: nn.Module,
        forward_input: List[torch.Tensor],
        _forward_output: torch.Tensor,
    ):
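        """Forward hook that records the module's input activations in
        module.activations while the module is in training mode, skipping
        modules without trainable parameters and respecting hooks_enabled."""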
        if not requires_grad(module) or not module.training:
            return

        if not self.hooks_enabled:
            return

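        # Lazily create the buffer on the first call, then store a detached
        # copy of the input so it does not stay attached to the autograd graph.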
        if not hasattr(module, "activations"):
            module.activations = []
        module.activations.append(forward_input[0].detach())  # pyre-ignore