from collections import OrderedDict
from typing import Iterable, Tuple, Union

import pytest
import torch
from torch import nn
from torch.nn import Module
from torch.optim import SGD, Optimizer

from pytorch_lightning import LightningModule, seed_everything
from pytorch_lightning.callbacks import BaseFinetuning


def test_freeze_unfreeze_function(tmpdir):
    """Test that ``freeze`` properly sets ``requires_grad`` on the modules."""
    seed_everything(42)

    class FreezeModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))

    model = FreezeModel()
    BaseFinetuning.freeze(model, train_bn=True)
    assert not model.backbone[0].weight.requires_grad
    assert model.backbone[1].weight.requires_grad
    assert not model.backbone[3].weight.requires_grad

    BaseFinetuning.freeze(model, train_bn=False)
    assert not model.backbone[0].weight.requires_grad
    assert not model.backbone[1].weight.requires_grad
    assert not model.backbone[3].weight.requires_grad

    BaseFinetuning.make_trainable(model)
    assert model.backbone[0].weight.requires_grad
    assert model.backbone[1].weight.requires_grad
    assert model.backbone[3].weight.requires_grad

    BaseFinetuning.freeze(model.backbone[0], train_bn=False)
    assert not model.backbone[0].weight.requires_grad

    BaseFinetuning.freeze([model.backbone[1], [model.backbone[3]]], train_bn=True)
    assert model.backbone[1].weight.requires_grad
    assert not model.backbone[3].weight.requires_grad
# Milestone-based unfreezing helper. This is a method of a ``BaseFinetuning`` subclass
# (it relies on ``self._get_modules_to_freeze``, ``self.unfreeze_and_add_param_group``
# and ``self.train_bn``); it is shown here standalone.
def _unfreeze_milestones_function(
    self,
    pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
    epoch: int,
    optimizer: Optimizer,
    opt_idx: int,
    strategy_metadata: Tuple[Tuple[int, int], int],
):
    unfreeze_milestones: Tuple[int, int] = strategy_metadata[0]
    num_layers: int = strategy_metadata[1]

    modules = self._get_modules_to_freeze(pl_module=pl_module)
    if modules is not None:
        if epoch == unfreeze_milestones[0]:
            # unfreeze the last ``num_layers`` layers
            backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[-num_layers:]
            self.unfreeze_and_add_param_group(
                modules=backbone_modules,
                optimizer=optimizer,
                train_bn=self.train_bn,
            )
        elif epoch == unfreeze_milestones[1]:
            # unfreeze the remaining layers
            backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[:-num_layers]
            self.unfreeze_and_add_param_group(
                modules=backbone_modules,
                optimizer=optimizer,
                train_bn=self.train_bn,
            )
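# For context, a minimal sketch of how such milestone-based unfreezing is typically
# wired into a ``BaseFinetuning`` subclass. The ``MilestonesFinetuning`` name, the
# ``backbone`` attribute on the LightningModule, and the default milestones are
# illustrative assumptions, not part of the code under test; the hook signatures
# follow the ``freeze_before_training`` / ``finetune_function`` API used above.
class MilestonesFinetuning(BaseFinetuning):
    def __init__(self, unfreeze_milestones: Tuple[int, int] = (5, 10), num_layers: int = 2, train_bn: bool = True):
        super().__init__()
        self.unfreeze_milestones = unfreeze_milestones
        self.num_layers = num_layers
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module):
        # Freeze the whole backbone before training starts (assumes ``pl_module.backbone``).
        self.freeze(pl_module.backbone, train_bn=self.train_bn)

    def finetune_function(self, pl_module, epoch, optimizer, opt_idx):
        backbone_modules = self.flatten_modules(pl_module.backbone)
        if epoch == self.unfreeze_milestones[0]:
            # At the first milestone, unfreeze the last ``num_layers`` layers.
            self.unfreeze_and_add_param_group(
                modules=backbone_modules[-self.num_layers:],
                optimizer=optimizer,
                train_bn=self.train_bn,
            )
        elif epoch == self.unfreeze_milestones[1]:
            # At the second milestone, unfreeze the remaining layers.
            self.unfreeze_and_add_param_group(
                modules=backbone_modules[:-self.num_layers],
                optimizer=optimizer,
                train_bn=self.train_bn,
            )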
def test_deep_nested_model():
    class ConvBlock(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3)
            self.act = nn.ReLU()
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            x = self.conv(x)
            x = self.act(x)
            return self.bn(x)

    model = nn.Sequential(
        OrderedDict([
            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
            ("decoder", ConvBlock(128, 10)),
        ])
    )

    # There are 9 leaf layers in that model
    assert len(BaseFinetuning.flatten_modules(model)) == 9

    BaseFinetuning.freeze(model.encoder, train_bn=True)
    assert not model.encoder[0].conv.weight.requires_grad
    assert model.encoder[0].bn.weight.requires_grad

    BaseFinetuning.make_trainable(model)
    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
    # The 8 parameters of the encoder are:
    # conv0.weight, conv0.bias, bn0.weight, bn0.bias,
    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
    assert len(encoder_params) == 8
def test_complex_nested_model():
    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
    directly themselves rather than exclusively their submodules containing parameters."""

    class ConvBlock(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3)
            self.act = nn.ReLU()
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            x = self.conv(x)
            x = self.act(x)
            return self.bn(x)

    class ConvBlockParam(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.module_dict = nn.ModuleDict({
                "conv": nn.Conv2d(in_channels, out_channels, 3),
                "act": nn.ReLU(),
            })
            # add a trivial test parameter to the conv block to validate parent (non-leaf) module parameter handling
            self.parent_param = nn.Parameter(torch.zeros(1, dtype=torch.float))
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            x = self.module_dict["conv"](x)
            x = self.module_dict["act"](x)
            return self.bn(x)

    model = nn.Sequential(
        OrderedDict([
            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
            ("decoder", ConvBlock(128, 10)),
        ])
    )

    # There are 10 leaf modules or parent modules with parameters in the test model
    assert len(BaseFinetuning.flatten_modules(model)) == 10

    BaseFinetuning.freeze(model.encoder, train_bn=True)
    assert not model.encoder[0].module_dict["conv"].weight.requires_grad  # validate a leaf module parameter is frozen
    assert not model.encoder[0].parent_param.requires_grad  # validate the parent module parameter is frozen
    assert model.encoder[0].bn.weight.requires_grad

    BaseFinetuning.make_trainable(model)
    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
    # The 9 parameters of the encoder are:
    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param,
    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
    assert len(encoder_params) == 9
def test_unfreeze_and_add_param_group_function(tmpdir):
    """Test that ``unfreeze_and_add_param_group`` properly unfreezes parameters and adds them to the correct
    param_group."""
    seed_everything(42)

    class FreezeModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.BatchNorm1d(32),
            )

    model = FreezeModel()
    optimizer = SGD(model.backbone[0].parameters(), lr=0.01)

    with pytest.warns(UserWarning, match="The provided params to be freezed already"):
        BaseFinetuning.unfreeze_and_add_param_group(model.backbone[0], optimizer=optimizer)
    assert optimizer.param_groups[0]["lr"] == 0.01

    model.backbone[1].weight.requires_grad = False
    BaseFinetuning.unfreeze_and_add_param_group(model.backbone[1], optimizer=optimizer)
    assert len(optimizer.param_groups) == 2
    assert optimizer.param_groups[1]["lr"] == 0.001
    assert torch.equal(optimizer.param_groups[1]["params"][0], model.backbone[1].weight)
    assert model.backbone[1].weight.requires_grad

    with pytest.warns(UserWarning, match="The provided params to be freezed already"):
        BaseFinetuning.unfreeze_and_add_param_group(model, optimizer=optimizer, lr=100, train_bn=False)
    assert len(optimizer.param_groups) == 3
    assert optimizer.param_groups[2]["lr"] == 100
    assert len(optimizer.param_groups[2]["params"]) == 3
    for group_idx, group in enumerate(optimizer.param_groups):
        if group_idx == 0:
            assert torch.equal(optimizer.param_groups[0]["params"][0], model.backbone[0].weight)
        if group_idx == 2:
            assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight)
            assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight)
            assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)
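# A short illustrative sketch of why the second param group above ends up with lr == 0.001:
# when no explicit ``lr`` is passed, ``unfreeze_and_add_param_group`` derives the new group's
# learning rate from the optimizer's current lr divided by ``initial_denom_lr`` (default 10.0).
# The module and optimizer names below are placeholders, not part of the tests above; this
# assumes the imports at the top of the file.
def _sketch_unfreeze_lr_behavior():
    layer = nn.Linear(32, 32, bias=False)
    opt = SGD(layer.parameters(), lr=0.01)

    frozen_layer = nn.Linear(32, 32, bias=False)
    BaseFinetuning.freeze(frozen_layer)

    # No ``lr`` given: the new group's lr is 0.01 / initial_denom_lr (default 10.0) == 0.001.
    BaseFinetuning.unfreeze_and_add_param_group(frozen_layer, optimizer=opt)
    assert opt.param_groups[1]["lr"] == 0.001

    # An explicit ``lr`` is used as-is for the new group.
    BaseFinetuning.unfreeze_and_add_param_group(nn.Linear(32, 2, bias=False), optimizer=opt, lr=100)
    assert opt.param_groups[2]["lr"] == 100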