Example 1
    def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
        super().__init__(group, wrapper_config)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms

        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = nn.Linear(d_expert, d_shared)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True

        # everything else is shared
        torch.manual_seed(0)

        shared = nn.Linear(d_shared, d_expert)

        if checkpoint_act:
            expert = checkpoint_wrapper(expert)
            shared = checkpoint_wrapper(shared)

        if wrapper_config is not None:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
            expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)

            shared = FullyShardedDataParallel(shared, group, **wrapper_config)

        self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input))
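Example 2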
def test_checkpoint_disabling():
    """Test to check new disable_checkpoint() API added to checkpoint_wrapper."""
    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            self.cnt += 1
            y = []
            for i in x:
                y.append(self.linear(i))
            return y

    x = torch.rand(4, 2)
    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())

    # Forward. cnt += 1
    y = model1(x)
    y = sum(i.sum() for i in y)
    # Backward. cnt += 1
    y.backward()
    assert model1.cnt == 2

    with disable_checkpointing():
        # Forward. cnt += 1
        y = model2(x)
        y = sum(i.sum() for i in y)
        # Backward. cnt stays the same because checkpointing is disabled
        y.backward()
    assert model2.cnt == 1
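As a standalone illustration of the context manager exercised above, here is a minimal sketch of disabling checkpointing at inference time; the import path for disable_checkpointing is an assumption and may differ between fairscale versions:

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper
# Assumed location of the context manager; adjust to your installed fairscale version.
from fairscale.nn.checkpoint.checkpoint_activations import disable_checkpointing

model = checkpoint_wrapper(nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)))

# Training step: the wrapped block is recomputed during backward.
model(torch.rand(4, 8)).sum().backward()

# Inference: run a plain forward with no checkpointing bookkeeping.
with torch.no_grad(), disable_checkpointing():
    preds = model(torch.rand(4, 8))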
Example 3
    def __init__(self, multiout=False, checkpoint_config=0):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.

        self.multiout = multiout

        self.conv1 = nn.Sequential(nn.Conv2d(1, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))
        self.conv2 = nn.Sequential(nn.Conv2d(3, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))

        assert 0 <= checkpoint_config <= 3
        if checkpoint_config & 1:
            self.conv1 = checkpoint_wrapper(self.conv1)
        if checkpoint_config & (1 << 1):
            self.conv2 = checkpoint_wrapper(self.conv2)
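Example 4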
    def __init__(self, flatten, mixed_precision, fsdp_wrap_ckpt):
        super().__init__()
        if fsdp_wrap_ckpt:
            middle_module = FSDP(checkpoint_wrapper(nn.Linear(3, 3)),
                                 flatten_parameters=flatten,
                                 mixed_precision=mixed_precision)
        else:
            middle_module = checkpoint_wrapper(
                FSDP(nn.Linear(3, 3),
                     flatten_parameters=flatten,
                     mixed_precision=mixed_precision))

        self.ffn = nn.Sequential(nn.Linear(3, 3), middle_module,
                                 nn.Linear(3, 3))
Example 5
def test_list_input():
    """ Test checkpointing with input argument type being a list.

    Note: Testing shows that PyTorch's torch.utils.checkpoint function does not pass this test.
    """
    count = 0

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Linear(2, 2)

        def forward(self, x):
            nonlocal count
            count += 1
            y = []
            for i in x:
                y.append(self.conv(i))
            return y

    model = nn.Sequential(checkpoint_wrapper(Model()), Model()).cuda()
    in_data1 = torch.rand(4, 2).cuda()
    in_data2 = torch.rand(4, 2).cuda()

    # Forward. Count should be 2 for 2 modules.
    out = model([in_data1, in_data2])
    loss = sum(x.sum() for x in out)
    assert count == 2, f"Incorrect count {count}"

    # Backward. Adds 1 more forward call due to checkpoint.
    loss.backward()
    assert count == 3, f"Incorrect count {count}"
Example 6
    def __init__(self,
                 group,
                 wrapper_config,
                 wrap_everything=False,
                 checkpoint=False):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.wrapper_config = wrapper_config

        def _maybe_wrap(layer):
            if wrapper_config is not None:
                return FullyShardedDataParallel(layer, group, **wrapper_config)
            return layer

        torch.manual_seed(0)  # keep everything deterministic
        self.module = nn.Sequential(
            nn.Linear(8, 4),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(nn.Linear(4, 16)),
                    nn.Linear(16, 16),
                )),
            _maybe_wrap(nn.Linear(16, 4)),
            nn.Linear(4, 8),
        )

        # Wrapping all modules triggers a corner case where the root FSDP instance has no
        # params of its own. Test it with checkpoint_wrapper as well, to validate that the
        # final backward callback is queued correctly when the root FSDP has no params and
        # every layer is wrapped as FSDP(checkpoint(module)).
        if wrap_everything:
            if checkpoint:
                self.module = nn.Sequential(
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
                )
            else:
                self.module = nn.Sequential(
                    _maybe_wrap(nn.Linear(8, 4)),
                    _maybe_wrap(nn.Linear(4, 16)),
                    _maybe_wrap(nn.Linear(16, 4)),
                    _maybe_wrap(nn.Linear(4, 8)),
                )
Example 7
    def __init__(self):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(3, 3),
            FullyShardedDataParallel(
                checkpoint_wrapper(nn.Linear(3, 3),
                                   maintain_forward_counter=True)),
            nn.Linear(3, 3),
        )
Example 8
def test_deprecated_path():

    # Check if import works as before.
    # from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
    from fairscale.nn import checkpoint_wrapper

    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
    ffn = checkpoint_wrapper(ffn, {})

    # Check if direct import works as before.
    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
    ffn = deprecated_checkpoint_wrapper(ffn, {})
Example 9
    def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
        assert not (
            use_pytorch_checkpoint and use_fairscale_checkpoint
        ), "Cannot use both pytorch and fairscale checkpointing mechanisms."
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
            # add a Dropout layer to test RNG save/restore
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairscale_checkpoint:
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)
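Only the constructor is shown above. A plausible forward for this module, sketched under the assumption that the use_pytorch_checkpoint flag selects PyTorch's native torch.utils.checkpoint (this is not necessarily the original test's exact code):

    def forward(self, x):
        if self.use_pytorch_checkpoint:
            # PyTorch's native activation checkpointing for comparison
            # (requires `import torch.utils.checkpoint` at module level).
            x = torch.utils.checkpoint.checkpoint(self.ffn, x)
        else:
            # Plain forward, or fairscale checkpointing if self.ffn was wrapped above.
            x = self.ffn(x)
        return self.out(x)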
Example 10
    def __init__(self, enable_checkpoint=False, cpu_offload=False):
        super().__init__()

        torch.manual_seed(0)  # make sure weights are deterministic.

        # These numbers are picked to show cpu_offload memory saving.
        # Inner (recomputed) activation sizes need to be just right
        # to show the benefit.
        self.layers = nn.Sequential(
            nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8)),
            nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4)),
            nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2)),
        )

        if enable_checkpoint:
            for i, layer in enumerate(self.layers):
                # Only the middle layer needs offloading
                self.layers[i] = checkpoint_wrapper(layer, cpu_offload if i == 1 else False)
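As a standalone reference for the offload flag used above, a minimal sketch, assuming the second positional argument of checkpoint_wrapper is fairscale's offload_to_cpu option:

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper

# Activations saved at the wrapper boundary are offloaded to CPU and moved
# back to their original device when the block is recomputed in backward.
block = checkpoint_wrapper(nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 8)), offload_to_cpu=True)
loss = block(torch.rand(2, 8)).sum()
loss.backward()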
def get_model(norm_type, checkpointed, mixed_precision):
    assert norm_type in NORM_TYPES, norm_type
    assert checkpointed in [True, False], checkpointed
    assert mixed_precision in MP_TYPES

    model = Sequential(Linear(3, 2), norm_type(2))

    if mixed_precision == "fp16":
        # Cast param.data and buffers to fp16
        for p in model.parameters():
            p.data = p.data.half()
        for m in model:
            for n, b in m.named_buffers():
                setattr(m, n, b.half())
    elif mixed_precision == "call_half":
        model.half()

    if checkpointed:
        model = checkpoint_wrapper(model)

    return model
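A hypothetical call site for get_model. NORM_TYPES and MP_TYPES are defined elsewhere in the original test file, so the stand-in values below are assumptions for illustration only:

import torch
from torch.nn import LayerNorm

# Stand-ins for constants that the real test defines elsewhere.
NORM_TYPES = (LayerNorm,)
MP_TYPES = ("fp32", "fp16", "call_half")

model = get_model(norm_type=LayerNorm, checkpointed=True, mixed_precision="fp32")
out = model(torch.rand(4, 3))
out.sum().backward()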
def test_checkpoint_requires_grad():
    """Test to check checkpointing when outputs do not require gradient."""
    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            self.cnt += 1
            return self.linear(x)

    x = torch.rand(4, 2)
    model = nn.Sequential(
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
    )
    model[0].requires_grad_(False)
    model[1].requires_grad_(False)
    model[2].requires_grad_(False)

    y = model(x)
    y = y.sum()
    y.backward()

    # Since only the last model needs grad, its forward alone is run twice
    assert model[0].cnt == 1
    assert model[1].cnt == 1
    assert model[2].cnt == 1
    assert model[3].cnt == 2

    # Now test with first model needing grad
    model = nn.Sequential(
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
    )
    model[0].requires_grad_(True)
    model[1].requires_grad_(False)
    model[2].requires_grad_(False)

    y = model(x)
    y = y.sum()
    y.backward()

    # Since the first model needs grad, all downstream inputs require grad too, so every forward runs twice
    assert model[0].cnt == 2
    assert model[1].cnt == 2
    assert model[2].cnt == 2
    assert model[3].cnt == 2

    # Stress test with multiple inputs/outputs, some of which are not Tensors
    class TestModel2(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x, y, z):
            self.cnt += 1
            z = z + [self.cnt]
            return self.linear(x + y), z, ["hi"]

    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())
    model3 = checkpoint_wrapper(TestModel2())
    model4 = checkpoint_wrapper(TestModel())
    model1.requires_grad_(False)
    model2.requires_grad_(False)

    y = model4(model3(model1(x), model2(x), ["bye"])[0])
    y = y.sum()
    y.backward()

    assert model1.cnt == 1
    assert model2.cnt == 1
    assert model3.cnt == 2
    assert model4.cnt == 2

    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())
    model3 = checkpoint_wrapper(TestModel2())
    model4 = checkpoint_wrapper(TestModel())
    model2.requires_grad_(False)

    y = model4(model3(model1(x), model2(x), ["bye"])[0])
    y = y.sum()
    y.backward()

    assert model1.cnt == 2
    assert model2.cnt == 1
    assert model3.cnt == 2
    assert model4.cnt == 2

    # Test flattened parameters
    model = nn.Sequential(
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
    )
    model[0].requires_grad_(False)
    model[1].requires_grad_(False)

    y = model(x)
    y = y.sum()
    y.backward()

    assert model[0].cnt == 1
    assert model[1].cnt == 1
    assert model[2].cnt == 2
    assert model[3].cnt == 2