Example No. 1
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class InnerModel(Module):
        def __init__(self):
            super().__init__()
            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config), )

        def forward(self, x):
            return self.layers(x)

    inner_model = InnerModel()
    model = FSDP(inner_model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    model.assert_state(TrainingState.IDLE)

    # Re-wrap the inner model a second time and check the output matches.
    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
    rewrapped_output = rewrapped_model(input)

    assert torch.allclose(output, rewrapped_output)
    teardown()
Example No. 2
 def _test_module_state_dict(cls, config, rank, group):
     ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
     autocast = ddp_model.mixed_precision
     cls._train_for_several_steps(ddp_model, 2, autocast)
     state_1 = ddp_model.state_dict()
     # You must make a new FSDP instance to use module.load_state_dict
     unwrapped_model = TransformerWithSharedParams(group)
     unwrapped_model.load_state_dict(state_1)
     new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
     cls._train_for_several_steps(new_ddp_model, 2, autocast)
     try:
         ddp_model.load_state_dict(new_ddp_model.state_dict())
         load_succeeded = True
     except Exception:
         load_succeeded = False
     assert not load_succeeded, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
Example No. 3
    def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
        # Create a nested FSDP-wrapped instance.
        model = NestedWrappedModule(group, config)
        model = FSDP(model, group, **config).cuda()
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

        # Round trip state dict save/load/save.
        ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
        model.load_local_state_dict(ref_state_dict)
        state_dict = model.local_state_dict()

        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
Example No. 4
def fsdp_wrapper(module, **kwargs):
    """
    Custom FSDP wrapper, adding the missing options.
    """
    from vissl.utils.layer_memory_tracking import ProcessGroupTracker

    # Add the global process group to the FSDP config
    fsdp_config = dict(**kwargs)
    fsdp_config["process_group"] = get_global_group()
    if fsdp_config.get("_TRACK_COMMUNICATIONS", False):
        fsdp_config["process_group"] = ProcessGroupTracker(fsdp_config["process_group"])

    # Remove keys that are not supported in FSDP
    for key in {"_TRACK_COMMUNICATIONS", "AUTO_WRAP_THRESHOLD"}:
        fsdp_config.pop(key, None)

    return FSDP(module, **fsdp_config)
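
For context, here is a minimal usage sketch of the wrapper above. It is not part of the original source: it assumes the distributed process group is already initialized, and the keyword arguments shown are only illustrative of the keys the wrapper forwards or strips.

# Illustrative usage sketch (assumes torch.distributed is already initialized;
# the config keys below are examples, not an authoritative list).
import torch.nn as nn

block = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
wrapped = fsdp_wrapper(
    block,
    flatten_parameters=True,       # forwarded to FSDP unchanged
    _TRACK_COMMUNICATIONS=False,   # consumed here, stripped before calling FSDP
    AUTO_WRAP_THRESHOLD=0,         # stripped before calling FSDP
)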
Example No. 5
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    # Note: different ranks may wrap in a different order due to different random
    # seeds, but the results should be the same.
    if random.randint(0, 1) == 0:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        model = auto_wrap_bn(model)
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = auto_wrap_bn(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
Example No. 6
def _freeze_distributed_worker(
    gpu_id,
    world_size,
    tempfile_name,
    unused,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    # The use case for this test is one where the weights in the submodule are
    # not frozen, but the leftover weights (those contained by the root module)
    # are frozen. Refer to issue #758 for a real-world example.
    model = FreezeModel()
    model = model.cuda()

    for param in model.head.parameters():
        param.requires_grad = False

    model = FSDP(model)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()

    teardown()
Example No. 7
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
Example No. 8
    def _distributed_worker(
        gpu_id: int, with_fsdp: bool, sync_file: str, result_file: str
    ):
        torch.cuda.set_device(gpu_id)
        dist.init_process_group(
            backend="nccl", init_method="file://" + sync_file, world_size=2, rank=gpu_id
        )

        # Create the inputs
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True
        batch = torch.randn(size=(8, 3, 224, 224)).cuda()

        # Create a fake model based on SWAV blocks
        config = TestRegnetFSDP._create_config(with_fsdp)
        model = build_model(config["MODEL"], config["OPTIMIZER"])
        model = model.cuda()
        if with_fsdp:
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[gpu_id])
        criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
        optimizer = optim.SGD(model.parameters(), lr=1e-2)

        # Run a few iterations and collect the losses
        losses = []
        for iteration in range(5):
            out = model(batch)
            loss = criterion(out[0], torch.tensor(0.0).cuda())
            if gpu_id == 0:
                losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            if iteration <= 2:
                for name, param in model.named_parameters():
                    if "prototypes" in name:
                        param.grad = None
            optimizer.step()

        # Store the losses in a file to compare several methods
        if gpu_id == 0:
            with open(result_file, "wb") as f:
                pickle.dump(losses, f)
Example No. 9
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = weight - grad.T * my_lr
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example No. 10
def test_pre_backward_hook(temp_files):
    """Test FSDP with a model that triggers a pre_backward hook bug."""

    result = dist_init(rank=0,
                       world_size=1,
                       filename=temp_files[0],
                       filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.l1 = Linear(4, 4).cuda()
            self.l2 = FSDP(Linear(4, 4).cuda())
            self.l3 = Linear(4, 4).cuda()

        def forward(self, x):
            x = self.l1(x)
            x = self.l2(x)
            inner_result = x
            x = self.l3(x)
            return x, inner_result

        def assert_and_clear_grad(self):
            for p in self.parameters():
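                # Expected shapes: (4, 4) / (4,) for the unwrapped l1 and l3, and
                # (4 * 4 + 4,) for the flattened parameter of the inner FSDP-wrapped l2.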
                assert p.shape in [(4, 4), (4, ), (4 * 4 + 4, )], p.shape
                assert p.grad is not None
                p.grad = None

    model = FSDP(Model(), flatten_parameters=False).cuda()
    in_data = torch.rand(1, 4).cuda()
    for _ in range(3):
        out, _ = model(in_data)
        out.sum().backward()
        model.assert_and_clear_grad()

    teardown()
Example No. 11
    def _test_nested_wrapped_model(cls, rank, group, config=None):
        # Get reference state dict without any nested FSDP instances.
        model = NestedWrappedModule(group, None).cuda()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
        ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

        # Create a nested FSDP-wrapped instance.
        if config["mixed_precision"]:
            config["compute_dtype"] = torch.float32
        model = NestedWrappedModule(group, config)
        model = FSDP(model, group, **config).cuda()
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

        # Round-trip state dict save/load/save.
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        model.load_state_dict(state_dict)
        state_dict = model.state_dict()

        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
Example No. 12
 def __init__(self):
     super().__init__()
     self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config), )
Example No. 13
 def __init__(self):
     super().__init__()
     self.l1 = Linear(4, 4).cuda()
     self.l2 = FSDP(Linear(4, 4).cuda())
     self.l3 = Linear(4, 4).cuda()
Example No. 14
    def init_distributed_data_parallel_model(self):
        """
        Initialize FSDP if needed.

        This method overloads the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        # Make sure default cuda device is set. TODO (Min): we should ensure FSDP can
        # be enabled for 1-GPU as well, but the use case there is likely different.
        # I.e. perhaps we use it for cpu_offloading.
        assert get_cuda_device_index() > -1, "Distributed training not setup correctly"

        # The model might already be wrapped by FSDP internally. Check regnet_fsdp.py.
        # Here, we wrap it at the outermost level.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        if is_primary():
            logging.info(f"Using FSDP, config: {fsdp_config}")

        # First, wrap the head's prototype_i layers if it is SWAV.
        # TODO (Min): make this more general for different models, which may have multiple
        #             heads.
        if len(self.base_model.heads) != 1:
            raise ValueError(
                f"FSDP only support 1 head, not {len(self.base_model.heads)} heads"
            )
        head0 = self.base_model.heads[0]
        if isinstance(head0, SwAVPrototypesHead):
            # This is important for convergence!
            #
            # Since we "normalize" this layer in the update hook, we need to keep its
            # weights in full precision. Its output goes into the loss and is used for
            # clustering, so we need that in full precision as well.
            fp_fsdp_config = fsdp_config.copy()
            fp_fsdp_config["flatten_parameters"] = False
            fp_fsdp_config["mixed_precision"] = False
            fp_fsdp_config["fp32_reduce_scatter"] = False
            for j in range(head0.nmb_heads):
                module = getattr(head0, "prototypes" + str(j))
                module = FSDP(module=module, **fp_fsdp_config)
                setattr(head0, "prototypes" + str(j), module)
        head0 = FSDP(module=head0, **fsdp_config)
        self.base_model.heads[0] = head0

        # Init the head properly since the weights are potentially initialized on different
        # ranks with different seeds. We first summon the full params from all workers.
        # Then, within that context, we set a fixed random seed so that all workers init the
        # weights the same way. Finally, we reset the layer's weights using reset_parameters().
        #
        # TODO (Min): This will go away once we have a way to sync from rank 0.
        with head0.summon_full_params():
            with set_torch_seed(self.config["SEED_VALUE"]):
                for m in head0.modules():
                    if isinstance(m, Linear):
                        m.reset_parameters()
        head0._reset_lazy_init()
        head0.prototypes0._reset_lazy_init()

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to true. We need to set it back to None here.
        # Also, right now, the head's weight is only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to re-work the checkpoint logic because we don't have
        # enough memory to load the entire model on one node. We need to use the local_state_dict()
        # API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when the checkpoint is taken.
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
Example No. 15
 def fsdp_wrap(self):
     self.trunk = FSDP(self.trunk)
     self.head = FSDP(self.head)
Example No. 16
def to_fsdp(module, fsdp_config):
    return FSDP(module,
                process_group=get_process_group_cached(),
                **fsdp_config)
Example No. 17
 def fsdp_wrap(self):
     for name, child in self.trunk.named_children():
         wrapped_child = FSDP(child)
         setattr(self.trunk, name, wrapped_child)
     self.trunk = FSDP(self.trunk)
     self.head = FSDP(self.head)
Example No. 18
def _create_model(with_fsdp):
    model = Model()
    if with_fsdp:
        model.trunk = FSDP(model.trunk)
        model.head = FSDP(model.head)
    return model
Example No. 19
def to_fsdp(module):
    return FSDP(module, process_group=get_global_group())
Example No. 20
def fsdp_wrapper(module, **kwargs):
    """Customer wrapper that does FSDP + checkpoint at the same time."""
    # TODO (Min): enable checkpoint_wrapper
    return FSDP(module, **kwargs)
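
As a minimal sketch of what the TODO above suggests (combining activation checkpointing with FSDP), assuming fairscale's checkpoint_wrapper is available in this environment; the helper name below is hypothetical, not the project's actual API:

# Hypothetical sketch of the TODO above: checkpoint activations first, then shard
# the module with FSDP. Assumes fairscale's checkpoint_wrapper is importable.
from fairscale.nn.checkpoint import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def fsdp_checkpoint_wrapper(module, **kwargs):
    return FSDP(checkpoint_wrapper(module), **kwargs)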
Example No. 21
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap in a different order due to different random
        # seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)
    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure the final state equals state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after,
                                     fsdp_state,
                                     raise_exception=True)

    teardown()
Example No. 22
 def __init__(self):
     super().__init__()
     self.inner = FSDP(Linear(4, 4), **fsdp_config)
     self.outer = Linear(4, 5)
Example No. 23
def _distributed_worker(
    gpu_id,
    world_size,
    with_fsdp,
    with_nested_trunk,
    freezing_method,
    tempfile_name,
    unused,
    rank_0_output,
    expected_state,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp, with_nested_trunk)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state,
                                 fsdp_state,
                                 raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
Example No. 24
def _distributed_worker(
    gpu_id,
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    files,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    filename, filename_rpc = files[:2]
    filename_loss = files[2:]

    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    # Use False below to debug, since error messages are not as good with cudnn enabled.
    torch.backends.cudnn.enabled = True

    # these make things deterministic.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure we have multiple forward passes.
    batch = [
        torch.randn(size=(2, 3, 16, 16)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
    ]

    if mixed_precision and not with_fsdp:
        batch = [x.half() for x in batch]

    model = _create_model(
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    model = model.cuda()

    if with_fsdp:
        model = FSDP(
            model,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        )
        model.set_gradient_divide_factors(1.0, 2.0, True)
        no_sync_context = contextlib.suppress()
    else:
        # With DDP, we need no_sync and manual gradient reduction below because
        # it can't handle multiple forward passes + checkpointing otherwise.
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        no_sync_context = model.no_sync()

    mp_context = contextlib.suppress()
    if mixed_precision:
        mp_context = torch.cuda.amp.autocast(enabled=True)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    losses = {}
    i = 0
    with no_sync_context:
        for iteration in range(3):
            with mp_context:
                out = model(batch)
                loss = criterion(out, target)
                print("Loss", iteration, ":", loss.item())
                losses[f"iter_{i}"] = loss
                i += 1
                optimizer.zero_grad()
                loss.backward()
            # Manual grad reduction, no autocast.
            if not with_fsdp:
                for p in model.parameters():
                    dist.all_reduce(p.grad.data)
                    p.grad.data.div_(2.0)
            # Stepping, no autocast
            optimizer.step()

    # Due to the dist.all_reduce code block above used with ddp.no_sync, we seem to hit a bug
    # in DDP where tensor.cpu() and torch.save() calls both hang. FSDP is not affected.
    # Therefore, we have to compare losses here instead of states.
    with open(filename_loss[rank], "wb") as f:
        pickle.dump(losses, f)

    teardown()
Example No. 25
def to_fsdp(module, fsdp_config):
    return FSDP(module, process_group=get_global_group(), **fsdp_config)
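
A short, hypothetical usage sketch of this helper, assuming the global process group is already initialized; the module and config values are illustrative only:

# Illustrative usage only (assumes dist init has been performed so that
# get_global_group() returns a valid process group).
import torch.nn as nn

fsdp_config = {"flatten_parameters": True, "mixed_precision": False}
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
model = to_fsdp(model, fsdp_config)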