예제 #1
0
    def __init__(self, group, wrapper_config, checkpoint_act=False):
        """Build a small model mixing a per-rank "expert" layer and a shared layer.

        Layer chain: ``Linear(8, 4) -> shared Linear(4, 16) -> expert
        Linear(16, 4) -> Linear(4, 8)``.

        Args:
            group: process group used to FSDP-wrap the shared layer. Its
                ``rank()`` also seeds the expert weights, so they differ
                per rank.
            wrapper_config: keyword arguments forwarded to
                ``FullyShardedDataParallel``. When ``None``, no FSDP wrapping
                is applied.
            checkpoint_act: when true, both the expert and the shared layer
                are wrapped with ``checkpoint_wrapper``.
        """
        super().__init__(group, wrapper_config)
        self.group = group

        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        expert = nn.Linear(16, 4)
        for p in expert.parameters():
            # Tag the expert params; presumably read elsewhere to treat them
            # differently from shared params (not visible in this block).
            p.expert = True

        # everything else is shared
        torch.manual_seed(0)
        shared = nn.Linear(4, 16)

        if checkpoint_act:
            expert = checkpoint_wrapper(expert)
            shared = checkpoint_wrapper(shared)

        if wrapper_config is not None:
            # We create a process group of size 1 for the expert params.
            # ``new_group`` takes *global* ranks. ``group.rank()`` is only the
            # rank within ``group`` and differs from the global rank whenever
            # ``group`` is a sub-group, so ask for the global rank explicitly.
            expert_group = torch.distributed.new_group([torch.distributed.get_rank()])
            expert = FullyShardedDataParallel(expert, expert_group,
                                              **wrapper_config)

            shared = FullyShardedDataParallel(shared, group, **wrapper_config)

        self.module = nn.Sequential(nn.Linear(8, 4), shared, expert,
                                    nn.Linear(4, 8))
    def __init__(self, multiout=False, checkpoint_config=0):
        """Create two small conv stacks, optionally wrapped for checkpointing.

        Args:
            multiout: stored as ``self.multiout``; not otherwise used here.
            checkpoint_config: bit mask that must lie in ``[0, 3]``. Bit 0
                wraps ``self.conv1`` with ``checkpoint_wrapper``; bit 1 wraps
                ``self.conv2``.
        """
        super().__init__()
        # Reseed before building layers so their initial weights are reproducible.
        torch.manual_seed(0)

        self.multiout = multiout

        def make_stack(in_channels):
            # in_channels -> 5 -> 5 channels, 3x3 kernels, ReLU in between.
            return nn.Sequential(
                nn.Conv2d(in_channels, 5, 3),
                nn.ReLU(),
                nn.Conv2d(5, 5, 3),
            )

        # Build order (conv1 first) matters: it fixes which RNG draws each gets.
        self.conv1 = make_stack(1)
        self.conv2 = make_stack(3)

        assert 0 <= checkpoint_config <= 3
        for bit, attr in ((0b01, "conv1"), (0b10, "conv2")):
            if checkpoint_config & bit:
                setattr(self, attr, checkpoint_wrapper(getattr(self, attr)))
예제 #3
0
    def __init__(self,
                 group,
                 wrapper_config,
                 checkpoint_act=False,
                 delay_before_free_ms=0):
        """Build a model mixing a per-rank "expert" layer and a shared layer.

        Layer chain: ``Linear(d_input, d_shared) -> shared Linear(d_shared,
        d_expert) -> expert Linear(d_expert, d_shared) -> Linear(d_shared,
        d_input)``, with ``d_input=8``, ``d_shared=12`` and ``d_expert=23``.

        Args:
            group: process group used to FSDP-wrap the shared layer. Its
                ``rank()`` also seeds the expert weights, so they differ
                per rank.
            wrapper_config: keyword arguments forwarded to
                ``FullyShardedDataParallel``. When ``None``, no FSDP wrapping
                is applied.
            checkpoint_act: when true, both the expert and the shared layer
                are wrapped with ``checkpoint_wrapper``.
            delay_before_free_ms: stored on ``self.delay_before_free_ms``;
                it is consumed outside this block.

        Sets ``self.num_expert_params`` to the element count of the expert
        layer's parameters, measured before any wrapping.
        """
        super().__init__(group, wrapper_config)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms

        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = nn.Linear(d_expert, d_shared)

        self.num_expert_params = sum(p.numel() for p in expert.parameters())
        for p in expert.parameters():
            # Tag the expert params; presumably read elsewhere to treat them
            # differently from shared params (not visible in this block).
            p.expert = True

        # everything else is shared
        torch.manual_seed(0)

        shared = nn.Linear(d_shared, d_expert)

        if checkpoint_act:
            expert = checkpoint_wrapper(expert)
            shared = checkpoint_wrapper(shared)

        if wrapper_config is not None:
            # We create a process group of size 1 for the expert params.
            # ``new_group`` takes *global* ranks. ``group.rank()`` is only the
            # rank within ``group`` and differs from the global rank whenever
            # ``group`` is a sub-group, so ask for the global rank explicitly.
            expert_group = torch.distributed.new_group(
                [torch.distributed.get_rank()])  # world size 1 means no shard
            expert = FullyShardedDataParallel(expert, expert_group,
                                              **wrapper_config)

            shared = FullyShardedDataParallel(shared, group, **wrapper_config)

        self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared,
                                    expert, nn.Linear(d_shared, d_input))
 def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
     """Small FFN plus output head, used to compare checkpointing mechanisms.

     Args:
         use_pytorch_checkpoint: stored on ``self.use_pytorch_checkpoint``;
             the code that reads it is outside this block.
         use_fairscale_checkpoint: when true, ``self.ffn`` is wrapped as
             ``checkpoint_wrapper(ffn, **kwargs)``.
         **kwargs: forwarded to ``checkpoint_wrapper`` only.

     The two flags are mutually exclusive; enabling both fails an assertion.
     """
     super().__init__()
     # Reseed before building layers so their initial weights are reproducible.
     torch.manual_seed(0)
     both_requested = use_pytorch_checkpoint and use_fairscale_checkpoint
     assert not both_requested, "Cannot use both pytorch and fairscale checkpointing mechanisms."
     self.use_pytorch_checkpoint = use_pytorch_checkpoint

     ffn_layers = [
         nn.Linear(32, 128),
         # Dropout makes the output RNG-dependent, which exercises RNG save/restore.
         nn.Dropout(p=0.5),
         nn.Linear(128, 32),
     ]
     ffn = nn.Sequential(*ffn_layers)
     if use_fairscale_checkpoint:
         ffn = checkpoint_wrapper(ffn, **kwargs)
     self.ffn = ffn
     # Created after the FFN so it receives the same RNG draws as before.
     self.out = nn.Linear(32, 1)
예제 #5
0
 def __init__(self,
              use_pytorch_checkpoint=False,
              use_fairseq_checkpoint=False,
              **kwargs):
     """Small FFN plus output head, used to compare checkpointing mechanisms.

     Args:
         use_pytorch_checkpoint: stored on ``self.use_pytorch_checkpoint``;
             the code that reads it is outside this block.
         use_fairseq_checkpoint: when true, ``self.ffn`` is wrapped as
             ``checkpoint_wrapper(self.ffn, **kwargs)``.
         **kwargs: forwarded to ``checkpoint_wrapper`` only.

     The two flags are mutually exclusive; enabling both fails an assertion.
     """
     super().__init__()
     torch.manual_seed(0)  # make sure weights are deterministic.
     # NOTE(review): presumably the forward pass applies PyTorch checkpointing
     # to self.ffn when use_pytorch_checkpoint is set (not visible here).
     # Enabling both flags would then checkpoint the same FFN twice, so reject
     # that combination up front.
     assert not (
         use_pytorch_checkpoint and use_fairseq_checkpoint
     ), "Cannot use both pytorch and fairseq checkpointing mechanisms."
     self.use_pytorch_checkpoint = use_pytorch_checkpoint
     self.ffn = nn.Sequential(
         nn.Linear(32, 128),
         # add a Dropout layer to test RNG save/restore
         nn.Dropout(p=0.5),
         nn.Linear(128, 32),
     )
     if use_fairseq_checkpoint:
         self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
     self.out = nn.Linear(32, 1)
    def __init__(self, enable_checkpoint=False, cpu_offload=False):
        """Three stacks of small Linear layers, optionally checkpoint-wrapped.

        Args:
            enable_checkpoint: when true, each stack in ``self.layers`` is
                replaced by ``checkpoint_wrapper(stack, offload)``.
            cpu_offload: passed as ``offload`` for the middle stack (index 1)
                only; the other stacks get ``False``. It is ignored unless
                ``enable_checkpoint`` is true.
        """
        super().__init__()

        # Reseed before building layers so their initial weights are reproducible.
        torch.manual_seed(0)

        # These numbers are picked to show cpu_offload memory saving.
        # Inner (recomputed) activation sizes need to be just right
        # to show the benefit.
        first = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8))
        middle = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4))
        last = nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2))
        self.layers = nn.Sequential(first, middle, last)

        if not enable_checkpoint:
            return

        # Walk a snapshot so that replacing entries never touches the container being iterated.
        for idx, stack in enumerate(list(self.layers)):
            # Only middle layer needs to have offloading
            offload = cpu_offload if idx == 1 else False
            self.layers[idx] = checkpoint_wrapper(stack, offload)
def get_model(norm_type, checkpointed, mixed_precision):
    """Build a tiny ``Linear(3, 2) -> norm_type(2)`` model in the requested precision.

    Args:
        norm_type: a module class that must be in ``NORM_TYPES``. It is
            instantiated as ``norm_type(2)``.
        checkpointed: must be ``True`` or ``False``. When true, the model is
            wrapped with ``checkpoint_wrapper``.
        mixed_precision: must be in ``MP_TYPES``.
            - ``"fp16"`` casts parameter data and every buffer to half by hand.
            - ``"call_half"`` calls ``model.half()``.
            - Any other value leaves the model as built.

    Returns:
        The model, checkpoint-wrapped if requested.
    """
    assert norm_type in NORM_TYPES, norm_type
    assert checkpointed in [True, False], checkpointed
    assert mixed_precision in MP_TYPES, mixed_precision

    model = Sequential(Linear(3, 2), norm_type(2))

    if mixed_precision == "fp16":
        # Set param.data and buffers as fp16
        for p in model.parameters():
            p.data = p.data.half()
        # Visit every module, not just direct children, and touch only each
        # module's own buffers. A recursive ``named_buffers()`` yields dotted
        # names such as "sub.buf". ``setattr`` would store those as a new
        # attribute instead of replacing the nested buffer.
        # NOTE(review): this casts every buffer regardless of dtype, so any
        # integer buffers become half too. ``Module.half()`` only casts
        # floating-point tensors. The cast is kept as-is to preserve
        # existing behavior.
        for m in model.modules():
            for n, b in m.named_buffers(recurse=False):
                setattr(m, n, b.half())
    elif mixed_precision == "call_half":
        model.half()

    if checkpointed:
        model = checkpoint_wrapper(model)

    return model