def test_freeze_unfreeze_function(tmpdir):
    """Check that ``freeze``/``make_trainable`` toggle ``requires_grad`` as expected.

    The scenarios cover freezing a whole model with and without trainable
    BatchNorm, thawing it again, and freezing a single module as well as a
    nested list of modules.
    """
    seed_everything(42)

    class FreezeModel(LightningModule):
        def __init__(self):
            super().__init__()
            layers = [
                nn.Linear(32, 32),
                nn.BatchNorm1d(32),
                nn.ReLU(),
                nn.Linear(32, 2),
            ]
            self.backbone = nn.Sequential(*layers)

    model = FreezeModel()
    # The ReLU carries no weight, so only the three parametrised layers are tracked.
    first_linear, batch_norm, _, last_linear = model.backbone

    def grad_flags():
        """Return ``requires_grad`` of (first linear, batch norm, last linear)."""
        tracked = (first_linear, batch_norm, last_linear)
        return tuple(layer.weight.requires_grad for layer in tracked)

    # With train_bn=True only the BatchNorm weight is expected to stay trainable.
    BaseFinetuning.freeze(model, train_bn=True)
    assert grad_flags() == (False, True, False)

    # With train_bn=False every layer is expected to be frozen.
    BaseFinetuning.freeze(model, train_bn=False)
    assert grad_flags() == (False, False, False)

    # make_trainable is expected to thaw every layer again.
    BaseFinetuning.make_trainable(model)
    assert grad_flags() == (True, True, True)

    # A single module can be passed instead of the whole model.
    BaseFinetuning.freeze(first_linear, train_bn=False)
    assert not first_linear.weight.requires_grad

    # Arbitrarily nested lists of modules are accepted as well.
    nested_modules = [batch_norm, [last_linear]]
    BaseFinetuning.freeze(nested_modules, train_bn=True)
    assert batch_norm.weight.requires_grad
    assert not last_linear.weight.requires_grad
Exemplo n.º 2
0
    def _unfreeze_milestones_function(
        self,
        pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
        epoch: int,
        optimizer: Optimizer,
        opt_idx: int,
        strategy_metadata: Tuple[Tuple[int, int], int],
    ):
        """Unfreeze the modules to fine-tune in two stages, driven by epoch milestones.

        When ``epoch`` equals the first milestone, the last ``num_layers``
        flattened modules are unfrozen and added to ``optimizer``. When it
        equals the second milestone, the remaining (leading) modules are
        unfrozen and added. Any other epoch is a no-op.

        Args:
            pl_module: Forwarded to ``self._get_modules_to_freeze`` to obtain
                the modules this strategy operates on.
            epoch: Epoch index compared against the two milestones.
            optimizer: Optimizer that receives the newly unfrozen parameters.
            opt_idx: Not used by this strategy.
            strategy_metadata: ``((first_milestone, second_milestone), num_layers)``.
                ``num_layers`` is expected to be non-negative; it is not
                validated here.
        """
        unfreeze_milestones: Tuple[int, int] = strategy_metadata[0]
        num_layers: int = strategy_metadata[1]

        modules = self._get_modules_to_freeze(pl_module=pl_module)
        if modules is None:
            return

        if epoch == unfreeze_milestones[0]:
            # First stage: the last ``num_layers`` layers.
            unfreeze_tail = True
        elif epoch == unfreeze_milestones[1]:
            # Second stage: everything the first stage left frozen.
            unfreeze_tail = False
        else:
            return

        flat_modules = BaseFinetuning.flatten_modules(modules=modules)
        # Use an explicit split index rather than negative slicing:
        # ``seq[-0:]`` is the whole sequence and ``seq[:-0]`` is empty, so
        # ``num_layers == 0`` would otherwise swap the two stages.
        split = max(len(flat_modules) - num_layers, 0)
        backbone_modules = flat_modules[split:] if unfreeze_tail else flat_modules[:split]

        self.unfreeze_and_add_param_group(
            modules=backbone_modules,
            optimizer=optimizer,
            train_bn=self.train_bn,
        )
Exemplo n.º 3
0
def test_complex_nested_model():
    """Exercise flatten/freeze/thaw on a model whose parent modules own parameters.

    ``ConvBlockParam`` holds a parameter directly on a non-leaf module, so the
    checks below confirm that such parameters are handled alongside those of
    leaf submodules.
    """
    encoder = nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))
    decoder = ConvBlock(128, 10)
    model = nn.Sequential(OrderedDict([("encoder", encoder), ("decoder", decoder)]))

    # Expect 10 entries: leaf modules plus parent modules holding their own parameters.
    flattened = BaseFinetuning.flatten_modules(model)
    assert len(flattened) == 10

    BaseFinetuning.freeze(model.encoder, train_bn=True)
    param_block = model.encoder[0]
    # A parameter of a leaf module must be frozen ...
    assert not param_block.module_dict["conv"].weight.requires_grad
    # ... and so must the parameter owned directly by the parent module ...
    assert not param_block.parent_param.requires_grad
    # ... while BatchNorm stays trainable because train_bn=True.
    assert param_block.bn.weight.requires_grad

    BaseFinetuning.make_trainable(model)
    # Encoder parameters expected after thawing (9 in total):
    #   block 0: conv.weight, conv.bias, bn.weight, bn.bias, parent_param
    #   block 1: conv.weight, conv.bias, bn.weight, bn.bias
    trainable_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
    assert len(trainable_params) == 9
def test_deep_nested_model():
    """Exercise flatten/freeze/thaw on a model nested two levels deep."""

    class ConvBlock(nn.Module):
        """Conv -> ReLU -> BatchNorm block used to build the nested model."""

        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3)
            self.act = nn.ReLU()
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            return self.bn(self.act(self.conv(x)))

    encoder = nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))
    decoder = ConvBlock(128, 10)
    model = nn.Sequential(OrderedDict([("encoder", encoder), ("decoder", decoder)]))

    # Three blocks with three leaf layers each give 9 flattened modules.
    flattened = BaseFinetuning.flatten_modules(model)
    assert len(flattened) == 9

    BaseFinetuning.freeze(model.encoder, train_bn=True)
    first_block = model.encoder[0]
    # The convolution is frozen while BatchNorm stays trainable (train_bn=True).
    assert not first_block.conv.weight.requires_grad
    assert first_block.bn.weight.requires_grad

    BaseFinetuning.make_trainable(model)
    # Encoder parameters expected after thawing (8 in total):
    #   block 0: conv.weight, conv.bias, bn.weight, bn.bias
    #   block 1: conv.weight, conv.bias, bn.weight, bn.bias
    trainable_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
    assert len(trainable_params) == 8
def test_complex_nested_model():
    """Exercise flatten/freeze/thaw on a model whose parent modules own parameters.

    One of the blocks registers a parameter directly on a non-leaf module, so
    the checks confirm such parameters are handled alongside those that live
    in leaf submodules.
    """

    class ConvBlock(nn.Module):
        """Conv -> ReLU -> BatchNorm block with parameters only in leaf modules."""

        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3)
            self.act = nn.ReLU()
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            return self.bn(self.act(self.conv(x)))

    class ConvBlockParam(nn.Module):
        """Same pipeline as ``ConvBlock`` plus a parameter owned by the block itself."""

        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.module_dict = nn.ModuleDict(
                {
                    "conv": nn.Conv2d(in_channels, out_channels, 3),
                    "act": nn.ReLU(),
                }
            )
            # Trivial parameter registered on the parent (non-leaf) module so
            # the test can check how parent-level parameters are treated.
            self.parent_param = nn.Parameter(torch.zeros(1, dtype=torch.float))
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            stages = self.module_dict
            return self.bn(stages["act"](stages["conv"](x)))

    encoder = nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))
    decoder = ConvBlock(128, 10)
    model = nn.Sequential(OrderedDict([("encoder", encoder), ("decoder", decoder)]))

    # Expect 10 entries: leaf modules plus parent modules holding their own parameters.
    flattened = BaseFinetuning.flatten_modules(model)
    assert len(flattened) == 10

    BaseFinetuning.freeze(model.encoder, train_bn=True)
    param_block = model.encoder[0]
    # A parameter of a leaf module must be frozen ...
    assert not param_block.module_dict["conv"].weight.requires_grad
    # ... and so must the parameter owned directly by the parent module ...
    assert not param_block.parent_param.requires_grad
    # ... while BatchNorm stays trainable because train_bn=True.
    assert param_block.bn.weight.requires_grad

    BaseFinetuning.make_trainable(model)
    # Encoder parameters expected after thawing (9 in total):
    #   block 0: conv.weight, conv.bias, bn.weight, bn.bias, parent_param
    #   block 1: conv.weight, conv.bias, bn.weight, bn.bias
    trainable_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
    assert len(trainable_params) == 9
def test_unfreeze_and_add_param_group_function(tmpdir):
    """Check that ``unfreeze_and_add_param_group`` thaws parameters and groups them correctly.

    Three calls are made against one optimizer:

    1. a module whose parameters the optimizer already holds (warning expected,
       learning rate of the existing group unchanged);
    2. a frozen module (a new param group is expected and the weight is thawed);
    3. the whole model with an explicit ``lr`` and ``train_bn=False`` (warning
       expected, a third group holding only the three remaining linear weights).
    """
    seed_everything(42)

    class FreezeModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.Linear(32, 32, bias=False),
                nn.BatchNorm1d(32),
            )

    model = FreezeModel()
    backbone = model.backbone
    optimizer = SGD(backbone[0].parameters(), lr=0.01)
    already_present = "The provided params to be freezed already"

    # 1) Parameters already in the optimizer: warn and leave the group alone.
    with pytest.warns(UserWarning, match=already_present):
        BaseFinetuning.unfreeze_and_add_param_group(backbone[0], optimizer=optimizer)
    assert optimizer.param_groups[0]["lr"] == 0.01

    # 2) A frozen layer gets thawed and placed in a new group.
    backbone[1].weight.requires_grad = False
    BaseFinetuning.unfreeze_and_add_param_group(backbone[1], optimizer=optimizer)
    assert len(optimizer.param_groups) == 2
    # NOTE(review): 0.001 is presumably the initial lr (0.01) scaled by the
    # library's default factor -- confirm against BaseFinetuning.
    assert optimizer.param_groups[1]["lr"] == 0.001
    assert torch.equal(optimizer.param_groups[1]["params"][0], backbone[1].weight)
    assert backbone[1].weight.requires_grad

    # 3) Whole model with an explicit lr; BatchNorm is excluded (train_bn=False)
    #    and the two layers already in the optimizer trigger the warning.
    with pytest.warns(UserWarning, match=already_present):
        BaseFinetuning.unfreeze_and_add_param_group(
            model,
            optimizer=optimizer,
            lr=100,
            train_bn=False,
        )
    groups = optimizer.param_groups
    assert len(groups) == 3
    assert groups[2]["lr"] == 100
    assert len(groups[2]["params"]) == 3

    # The group count was asserted above, so check the relevant groups directly
    # instead of looping over every group and branching on its index.
    assert torch.equal(groups[0]["params"][0], backbone[0].weight)
    for offset, layer_idx in enumerate((2, 3, 4)):
        assert torch.equal(groups[2]["params"][offset], backbone[layer_idx].weight)