Example #1
    def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
        model = TransformerWithSharedParams(group, **model_kwargs)
        if pure_fp16:
            assert not config["mixed_precision"]
            model = model.half()
        fsdp_model = FSDP(model, group, **config)
        if not config["cpu_offload"]:
            fsdp_model = fsdp_model.cuda()
        autocast = fsdp_model.mixed_precision or pure_fp16
        self._train_for_several_steps(fsdp_model, 1, autocast)

        sd = fsdp_model.state_dict()

        sd_device = config.get("state_dict_device")
        for k, v in sd.items():
            if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
                assert v.device.type == "cpu", v.device.type
            else:
                assert v.device.type == "cuda", v.device.type

        expected_dtype = torch.float16 if pure_fp16 else torch.float32
        for k, v in sd.items():
            if not torch.is_floating_point(v):
                continue
            assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
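For reference, the configuration dictionary this helper consumes only needs the keys it actually reads. A minimal sketch follows (the variable name example_config is ours; only "mixed_precision", "cpu_offload" and the optional "state_dict_device" entry are taken from the code above):

import torch

# Hypothetical config for _test_state_dict_device, limited to the keys the
# helper reads. "state_dict_device" is a torch.device; when it is CPU (or when
# cpu_offload is enabled) the returned state_dict tensors are expected on CPU,
# otherwise on CUDA.
example_config = {
    "mixed_precision": False,          # must be False when pure_fp16=True
    "cpu_offload": False,
    "state_dict_device": torch.device("cpu"),
}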
Example #2
def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method,
                        tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # Freeze the trunk, either by turning off requires_grad up front or by
    # dropping its gradients (grad_to_none) right before each optimizer step.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state,
                                 fsdp_state,
                                 raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
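The two freezing strategies exercised by this worker are easiest to compare on a tiny standalone model. The sketch below is not part of the test; it only illustrates that disabling requires_grad on the trunk and setting the trunk's .grad to None after backward both leave the trunk unchanged after an SGD step:

import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
trunk, head = nn.Linear(4, 4), nn.Linear(4, 2)
model_a = nn.Sequential(copy.deepcopy(trunk), copy.deepcopy(head))
model_b = nn.Sequential(copy.deepcopy(trunk), copy.deepcopy(head))
x = torch.randn(2, 4)

# Method 1 ("requires_grad"): autograd never produces gradients for the
# frozen trunk, so optimizer.step() leaves it untouched.
for p in model_a[0].parameters():
    p.requires_grad = False
opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1, momentum=0.9)
model_a(x).sum().backward()
opt_a.step()

# Method 2 ("grad_to_none"): gradients are computed normally, then dropped
# right before the step; SGD skips parameters whose .grad is None.
opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1, momentum=0.9)
model_b(x).sum().backward()
for p in model_b[0].parameters():
    p.grad = None
opt_b.step()

assert torch.equal(model_a[0].weight, trunk.weight)
assert torch.equal(model_b[0].weight, trunk.weight)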
Example #3
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(
        sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(
            sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data,
                         map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank,
                       world_size=world_size,
                       filename=file1,
                       filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
        # and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=(test_fn == "optim_state"),
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after,
                          fsdp_model.state_dict(),
                          raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()
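This worker expects to be spawned once per rank, with the state dicts and input batch prepared up front. A hedged launch sketch follows (the helper name _launch and the temp-file handling are placeholders showing only the argument order; the real test presumably writes sd_before, sd_after and in_data with torch.save before spawning):

import tempfile
import torch.multiprocessing as mp

def _launch(world_size=2, wrap_middle=True, test_fn="train"):
    # files = [file1, file2, sd_before, sd_after, in_data]; the first two are
    # rendezvous files for dist_init, the last three hold torch.save()d data.
    files = [tempfile.mkstemp()[1] for _ in range(5)]
    # ... a reference run would populate files[2:] here ...
    mp.spawn(
        _dist_worker,                      # mp.spawn prepends the rank
        args=(world_size, files, wrap_middle, test_fn),
        nprocs=world_size,
        join=True,
    )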
Example #4
    def _test_module_state_dict(cls, config, rank, group):
        ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
        autocast = ddp_model.mixed_precision
        cls._train_for_several_steps(ddp_model, 2, autocast)
        state_1 = ddp_model.state_dict()
        # You must make a new FSDP instance to use module.load_state_dict
        unwrapped_model = TransformerWithSharedParams(group)
        unwrapped_model.load_state_dict(state_1)
        new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
        cls._train_for_several_steps(new_ddp_model, 2, autocast)
        try:
            ddp_model.load_state_dict(new_ddp_model.state_dict())
            assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
        except Exception:
            pass
Example #5
    def _test_identical_outputs(
        cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
    ):
        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results as PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        ref_state_dict = model.module.state_dict()
        if config.get("cpu_offload", False):
            for k in ref_state_dict.keys():
                ref_state_dict[k] = ref_state_dict[k].cpu()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if use_cuda:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        shard_state_dict = model.state_dict()

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
            assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
        except (AssertionError, RuntimeError) as e:
            raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)
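_train_for_several_steps is used by several of these examples but is not shown here. Below is a hypothetical reconstruction, based only on how it is called above (model, step count, autocast flag, plus optional lr and norm_type); the real helper in the fairscale test suite builds model-specific inputs and will differ in its loss and clipping details:

import torch

def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=2):
    # Hypothetical sketch; a random batch with a placeholder shape stands in
    # for the real model-specific input.
    device = next(model.parameters()).device
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss = None
    for _ in range(num_steps):
        optim.zero_grad()
        with torch.cuda.amp.autocast(enabled=autocast):
            inp = torch.randn(2, 8, device=device)   # placeholder input shape
            loss = model(inp).float().sum()
        loss.backward()
        if norm_type is not None:
            # FSDP wrappers expose their own clip_grad_norm_ that handles
            # sharded gradients; fall back to the torch utility otherwise.
            if hasattr(model, "clip_grad_norm_"):
                model.clip_grad_norm_(1.0, norm_type=norm_type)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0, norm_type)
        optim.step()
    return loss.detach()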
Example #6
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
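objects_are_equal is a fairscale test utility; the following is a simplified illustration of the kind of check it performs on the two state dicts (not the real implementation, which also handles nested containers and error reporting):

import torch

def _state_dicts_match(a, b, rtol=1e-5, atol=1e-8):
    # Simplified stand-in: same keys, same shapes/dtypes, and numerically
    # close values, compared on CPU so device placement does not matter.
    if a.keys() != b.keys():
        return False
    for k in a:
        ta, tb = a[k].cpu(), b[k].cpu()
        if ta.shape != tb.shape or ta.dtype != tb.dtype:
            return False
        if torch.is_floating_point(ta):
            if not torch.allclose(ta, tb, rtol=rtol, atol=atol):
                return False
        elif not torch.equal(ta, tb):
            return False
    return True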
Example #7
    def _test_state_dict_device(self,
                                config,
                                rank,
                                group,
                                pure_fp16=False,
                                **model_kwargs):
        model = TransformerWithSharedParams(group, **model_kwargs)
        if pure_fp16:
            assert not config["mixed_precision"]
            model = model.half()
        fsdp_model = FullyShardedDataParallel(model, group, **config)
        if not config["cpu_offload"]:
            fsdp_model = fsdp_model.cuda()
        autocast = fsdp_model.mixed_precision or pure_fp16
        self._train_for_several_steps(fsdp_model, 1, autocast)

        sd = fsdp_model.state_dict()

        sd_device = config.get("state_dict_device")
        for k, v in sd.items():
            if config["cpu_offload"] or (sd_device is not None
                                         and sd_device.type == "cpu"):
                assert v.device.type == "cpu", v.device.type
            else:
                assert v.device.type == "cuda", v.device.type

        expected_dtype = torch.float16 if pure_fp16 else torch.float32
        buffers = {
            k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "")
            for k, _ in fsdp_model.named_buffers()
        }
        for k, v in sd.items():
            if not torch.is_floating_point(v):
                continue
            if k in buffers:
                assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
            else:
                assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
Example #8
def main(local_rank, *args):
    torch.backends.cudnn.benchmark = True
    init_method = "tcp://%s:%s" % ("0.0.0.0", "9999")
    torch.distributed.init_process_group(backend="nccl",
                                         rank=local_rank,
                                         world_size=8,
                                         init_method=init_method)
    print("[Train]: Time = %s, Initialized Dist Process for Rank = %s" %
          (get_time_string(), local_rank))
    device = torch.device(
        f'cuda:{local_rank}')  # Unique only on individual node.
    torch.cuda.set_device(device)
    fsdp_params = dict(mixed_precision=True,
                       flatten_parameters=True,
                       bucket_cap_mb=25,
                       reshard_after_forward=False,
                       fp32_reduce_scatter=False,
                       cpu_offload=False,
                       move_grads_to_cpu=False,
                       process_group=torch.distributed.group.WORLD)
    with enable_wrap(wrapper_cls=FullyShardedDDP, **fsdp_params):
        nn_model = nn.Sequential(
            nn.Linear(200, 200),
            wrap(
                checkpoint_wrapper(
                    nn.Sequential(
                        nn.Linear(200, 200),
                        nn.Linear(200, 200),
                        wrap(checkpoint_wrapper(nn.Linear(200, 200), offload_to_cpu=True)),
                        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                        nn.Linear(200, 200),
                    ),
                    offload_to_cpu=True,
                ),
            ),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.LayerNorm(200, eps=1e-7),
            nn.Linear(200, 64),
        ).cuda()

        model = FullyShardedDDP(nn_model, **fsdp_params)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=1e-4,
                                  eps=1e-7,
                                  weight_decay=1e-2,
                                  betas=(0.9, 0.99))
    optimizer.zero_grad(set_to_none=True)

    for i in range(1000):
        optimizer.zero_grad(set_to_none=True)
        fake_inputs = torch.randn(32, 200, device=device)
        fake_labels = torch.randn(32, 64, device=device)
        outputs = model(fake_inputs)
        loss = ((outputs - fake_labels)**2).mean()
        loss.backward()
        model.clip_grad_norm_(1.0)
        optimizer.step()
        if i % 100 == 0:
            print("Loss = %s, rank = %s" % (loss.item(), local_rank))

    state_dict = model.state_dict()
    nn_model = nn.Sequential(
        nn.Linear(200, 200),
        nn.Sequential(nn.Linear(200, 200), nn.Linear(200, 200),
                      nn.Linear(200, 200),
                      checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                      nn.Linear(200, 200)),
        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
        nn.LayerNorm(200, eps=1e-7), nn.Linear(200, 64)).cuda()
    nn_model.load_state_dict(state_dict)
    print("[Train]: Time = %s, Trainable Params = %s" %
          (get_time_string(), numel(nn_model) / 1_000_000))
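Since main takes the local rank as its first positional argument and initializes the process group itself (world size 8, NCCL over tcp://0.0.0.0:9999), it can be launched with one process per GPU, for example:

import torch.multiprocessing as mp

if __name__ == "__main__":
    # mp.spawn passes the process index as the first argument, which main()
    # uses as the local rank / CUDA device index.
    mp.spawn(main, nprocs=8, join=True)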