def test_grad_consistency(self):
    # With zero noise and an effectively unbounded clipping threshold, the
    # private optimizer's step reduces to plain gradient aggregation, so
    # p.grad must match the aggregated per-sample gradients.
    model, optimizer = self.setUp_init_model(
        private=True,
        state_dict=self.original_model.state_dict(),
        noise_multiplier=0,
        max_grad_norm=999,
    )
    grad_sample_aggregated = {}
    for x, y in self.dl:
        optimizer.zero_grad()
        logits = model(x)
        loss = self.criterion(logits, y)
        loss.backward()

        # collect all per-sample gradients before we take the step
        for _, layer in model.named_modules():
            if get_layer_type(layer) == "SampleConvNet":
                continue
            grad_sample_aggregated[layer] = {}
            for p in layer.parameters():
                if p.requires_grad:
                    grad_sample_aggregated[layer][p] = get_grad_sample_aggregated(p)

        optimizer.step()

        # after the step, p.grad holds the processed gradient; it should
        # agree with the per-sample aggregation collected above
        for layer_name, layer in model.named_modules():
            if get_layer_type(layer) == "SampleConvNet":
                continue
            for p in layer.parameters():
                if p.requires_grad:
                    self.assertTrue(
                        torch.allclose(
                            p.grad,
                            grad_sample_aggregated[layer][p],
                            atol=1e-4,
                            rtol=1e-1,
                        ),
                        f"grad_sample doesn't match grad. "
                        f"Layer: {layer_name}, Tensor: {p.shape}",
                    )
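# For context: the invariant this test relies on can be reproduced in a few
# lines of plain PyTorch. The snippet below is a self-contained sketch, not
# part of the test suite, and all names in it are local to the sketch. With
# zero noise and a clipping bound that never binds (as configured above),
# the aggregated per-sample gradients must equal the ordinary gradient of
# the summed loss.
import torch

w = torch.randn(3, requires_grad=True)
xs = torch.randn(5, 3)

# Per-sample gradients of loss_i = (x_i . w)^2, accumulated one sample at a time.
grad_sample_sum = torch.zeros_like(w)
for x in xs:
    (x @ w).pow(2).backward()
    grad_sample_sum += w.grad
    w.grad = None

# Ordinary batch gradient of the summed loss.
(xs @ w).pow(2).sum().backward()
assert torch.allclose(w.grad, grad_sample_sum, atol=1e-5)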
def _no_lstm(module: nn.Module) -> bool:
    r"""
    Checks that the input module is not an LSTM.

    Args:
        module: The input module

    Returns:
        True if the input module is not an LSTM
    """
    return get_layer_type(module) != "LSTM"
def _no_running_stats_instancenorm_check(module: nn.Module) -> bool:
    r"""
    Checks that an ``InstanceNorm`` layer has ``track_running_stats`` set to False.

    Args:
        module: The input module (layer) for which the check is verified.

    Returns:
        True if the module is not an ``InstanceNorm`` layer. If it is an
        ``InstanceNorm`` layer, returns True when ``track_running_stats`` is
        set to False, and False otherwise.
    """
    is_instancenorm = get_layer_type(module) in (
        "InstanceNorm1d",
        "InstanceNorm2d",
        "InstanceNorm3d",
    )
    if is_instancenorm:
        return not module.track_running_stats
    return True
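# These predicates are plain ``Module -> bool`` functions, so they compose
# naturally into a validation pass over a model. Below is a minimal sketch of
# such a pass; ``collect_check_failures`` and the toy model are hypothetical,
# not part of this codebase, and the sketch assumes ``_no_lstm``,
# ``_no_running_stats_instancenorm_check``, and the ``get_layer_type`` helper
# they rely on (which matches layers by class name) are in scope.
import torch.nn as nn

def collect_check_failures(model: nn.Module, checks) -> list:
    """Walk every submodule and record which checks it fails."""
    failures = []
    for name, module in model.named_modules():
        for check in checks:
            if not check(module):
                failures.append((name, get_layer_type(module), check.__name__))
    return failures

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.InstanceNorm2d(8, track_running_stats=True),  # fails the InstanceNorm check
    nn.LSTM(8, 8),                                   # fails the LSTM check
)
for name, layer, check in collect_check_failures(
    model, [_no_lstm, _no_running_stats_instancenorm_check]
):
    print(f"{name} ({layer}) failed {check}")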