Example #1
def _test_func(rank, world_size, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {
        name: tuple(param.shape)
        for name, param in model.named_parameters()
    }

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    assert expected_param_shapes == {
        "_fsdp_wrapped_module.flat_param_0": (12, ),
        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0":
        (6, ),
    }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes)

    teardown()
Example #2
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_out = torch.matmul(v, weight)
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=0.1)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_out, out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example #3
    def __init__(self, with_fsdp=False, wrap_middle="none"):
        super().__init__()
        self.l0 = nn.Embedding(VOCAB, D_MODEL).cuda().half()
        nn.init.uniform_(self.l0.weight, -1.0e-1, 1.0e-1)
        self.l1 = MEVO(self.l0.weight, tile_factor=TILE, reduction="sum")
        self.middle = nn.Linear(D_MODEL, D_MODEL).cuda().half()
        # LNs are not strictly needed for this test, but they help reduce the loss quickly
        # and improve the numerical stability.
        self.ln1 = nn.LayerNorm(D_MODEL).cuda().half()
        self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()

        if with_fsdp:
            # Shared layers must not be flattened.
            self.l0 = FSDP(self.l0,
                           flatten_parameters=False,
                           mixed_precision=False,
                           compute_dtype=torch.float16)
            self.l1 = FSDP(self.l1,
                           flatten_parameters=False,
                           mixed_precision=False,
                           compute_dtype=torch.float16)
            self.l1.append_shared_param(self.l0.module.weight)
            # These are for debugging.
            # print(id(self.l0), "is emb")
            # print(id(self.l1), "is out")
            assert wrap_middle in ["none", "flat", "nonflat"]
            if wrap_middle != "none":
                self.middle = FSDP(
                    self.middle,
                    flatten_parameters=wrap_middle == "flat",
                    mixed_precision=False,
                    compute_dtype=torch.float16,
                )
Example #4
    def __init__(self, with_fsdp=False, inner_flat=False, sharing=None):
        super().__init__()
        self.l0 = Linear(4, 4, bias=True).cuda()
        self.l1 = Linear(4, 4, bias=True).cuda()
        self.l2 = Linear(4, 4, bias=True).cuda()
        self.l3 = Linear(4, 4, bias=True).cuda()

        # Share the weights. Each layer must have at least one param that is not
        # shared. Therefore, we use bias=True and test sharing either the
        # weight or the bias.
        if sharing == "share_only_weights":
            self.l1.weight = self.l3.weight
        elif sharing == "share_only_bias":
            self.l1.bias = self.l3.bias
        else:
            assert sharing is None or sharing == "share_none"

        if with_fsdp:
            # Shared layers must not be flattened.
            self.l1 = FSDP(self.l1, flatten_parameters=False)
            self.l2 = FSDP(self.l2, flatten_parameters=inner_flat)
            self.l3 = FSDP(self.l3, flatten_parameters=False)

            if sharing in ["share_only_weights"]:
                self.l3.append_shared_param(self.l1.module.weight)
            if sharing in ["share_only_bias"]:
                self.l3.append_shared_param(self.l1.module.bias)
Example #5
    def _test_dtypes(cfg: Dict,
                     autocast,
                     in_dtype,
                     p_dtype,
                     loss_dtype,
                     reduce_dtype,
                     rank,
                     group,
                     expected_buffer_type=None):
        # Patch torch.distributed.reduce_scatter to check the dtype of the reduction
        orig_reduce_scatter = torch.distributed.reduce_scatter

        model: nn.Module = DeviceAndTypeCheckModule(
            expected_input_dtype=in_dtype,
            expected_param_dtype=p_dtype,
            expected_loss_dtype=loss_dtype,
            expected_buffer_dtype=expected_buffer_type,
        )

        def _reduce_scatter(output, input_list, **kwargs):
            for tensor in input_list:
                model._check("reduce_scatter.dtype",
                             tensor.dtype,
                             expected=reduce_dtype)
            return orig_reduce_scatter(output, input_list, **kwargs)

        with mock.patch("torch.distributed.reduce_scatter",
                        new=_reduce_scatter):
            model = FullyShardedDataParallel(model, group, **cfg).cuda()
            device = next(model.parameters()).device
            x = torch.rand(2, 5).to(device)
            with torch.cuda.amp.autocast(enabled=autocast):
                loss = model(x)
            loss.backward()
Example #6
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.inner = FSDP(Linear(4, 4), **fsdp_config)
            self.outer = Linear(4, 5)

        def forward(self, x):
            # Forward twice.
            i = self.inner(x)
            j = self.inner(x)
            return self.outer(i + j)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
Example #7
    def _test_identical_outputs_eval(
        cls,
        model_init_fn,
        config,
        rank,
        group,
        num_steps=2,
        use_cuda=True,
        lr=0.01,
        ref_ddp_fn=None,
    ):
        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results as PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank,
                                                        process_group=group)
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._eval_with_config(model, autocast)
        ref_state_dict = model.module.state_dict()
        if config.get("cpu_offload", False):
            for k in ref_state_dict.keys():
                ref_state_dict[k] = ref_state_dict[k].cpu()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        if config.get("ssd_offload", False):
            config["offload_config"] = OffloadConfig(
                offload_type="ssd_offload")

        config.pop("ssd_offload", None)  # drop the flag so it is not passed through **config to FSDP
        model = FullyShardedDataParallel(
            model_init_fn(group=group, wrapper_config=config), group, **config)
        if not model.ssd_offload and not model.move_params_to_cpu:
            if use_cuda:
                model = model.cuda()
            else:
                assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._eval_with_config(model, autocast)

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
        except (AssertionError, RuntimeError) as e:
            raise Exception(
                f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}"
            )
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)
Example #8
    def __init__(self, group, wrapper_config, checkpoint_act=False):
        super().__init__(group, wrapper_config)
        self.group = group

        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        expert = nn.Linear(16, 4)
        for p in expert.parameters():
            p.expert = True

        # everything else is shared
        torch.manual_seed(0)
        shared = nn.Linear(4, 16)

        if checkpoint_act:
            expert = checkpoint_wrapper(expert)
            shared = checkpoint_wrapper(shared)

        if wrapper_config is not None:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group([group.rank()])
            expert = FullyShardedDataParallel(expert, expert_group,
                                              **wrapper_config)

            shared = FullyShardedDataParallel(shared, group, **wrapper_config)

        self.module = nn.Sequential(nn.Linear(8, 4), shared, expert,
                                    nn.Linear(4, 8))
Example #9
    def _test_ssd_offload_eval(self, rank, group, config):
        model = TransformerWithSharedParams(group)
        state_dict = model.state_dict()

        nested_wrapping = config["nested_wrapping"]
        del config["nested_wrapping"]

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(
                offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
            if nested_wrapping:
                model = FullyShardedDataParallel(
                    NestedWrappedModule(group,
                                        wrap_everything=True,
                                        wrapper_config=config))
            else:
                model = FullyShardedDataParallel(model, **config)

            self._eval_with_config(model, autocast=config["mixed_precision"])

            # With SSD offload only local_state_dict will work. We can support global
            # state dict if we think it is necessary.
            state_dict = model.local_state_dict()
            model.load_local_state_dict(state_dict)

            self._eval_with_config(model, config["mixed_precision"])
Example #10
    def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
        model = TransformerWithSharedParams(group, **model_kwargs)
        if pure_fp16:
            assert not config["mixed_precision"]
            model = model.half()
        fsdp_model = FSDP(model, group, **config)
        if not config["cpu_offload"]:
            fsdp_model = fsdp_model.cuda()
        autocast = fsdp_model.mixed_precision or pure_fp16
        self._train_for_several_steps(fsdp_model, 1, autocast)

        sd = fsdp_model.state_dict()

        sd_device = config.get("state_dict_device")
        for k, v in sd.items():
            if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
                assert v.device.type == "cpu", v.device.type
            else:
                assert v.device.type == "cuda", v.device.type

        expected_dtype = torch.float16 if pure_fp16 else torch.float32
        for k, v in sd.items():
            if not torch.is_floating_point(v):
                continue
            assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
Example #11
    def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
        super().__init__(group, wrapper_config)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms

        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = nn.Linear(d_expert, d_shared)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True

        # everything else is shared
        torch.manual_seed(0)

        shared = nn.Linear(d_shared, d_expert)

        if checkpoint_act:
            expert = checkpoint_wrapper(expert)
            shared = checkpoint_wrapper(shared)

        if wrapper_config is not None:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
            expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)

            shared = FullyShardedDataParallel(shared, group, **wrapper_config)

        self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input))
Example #12
def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method,
                        tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state,
                                 fsdp_state,
                                 raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
Example #13
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    device = torch.device("cuda")
    if fsdp_config.get("mixed_precision", False):
        dtype = torch.float16
        fsdp_config["fp32_reduce_scatter"] = True
    else:
        dtype = torch.float32

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            fp32_weight = model.weight.T.clone().to(device)
            weight = fp32_weight.to(dtype)
            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).to(
                device, dtype)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = fp32_weight - grad.T * my_lr
            assert ref_weight_out.dtype == torch.float32
    model.to(
        device)  # not dtype, since FSDP will manage mixed precision internally
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).to(device, dtype)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
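The reference update in the example above relies on the fact that, for out = v @ weight with loss out.sum(), the gradient w.r.t. the weight is v.sum(0) broadcast across the output dimension (the test then divides by world_size to account for gradient averaging across ranks). A minimal, self-contained sanity check of that identity (illustrative only, not part of the test):

import torch

v = torch.randn(4, 3)
W = torch.randn(3, 5, requires_grad=True)
(v @ W).sum().backward()

# d(sum(v @ W)) / dW[j, k] = sum_i v[i, j], independent of k.
manual_grad = v.sum(0).unsqueeze(1).repeat(1, W.shape[1])
assert torch.allclose(W.grad, manual_grad)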
Example #14
def _test_func(rank, world_size, tempfile_name, unused, flatten,
               mixed_precision, amp_context, half_input, fsdp_wrap_ckpt):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FSDP(
        SimpleModuleWithCheckpointing(flatten, mixed_precision,
                                      fsdp_wrap_ckpt).cuda(),
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
    )
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {
        name: tuple(param.shape)
        for name, param in model.named_parameters()
    }

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    if not flatten:
        assert expected_param_shapes == {
            "ffn.0.weight": (5, ),
            "ffn.0.bias": (2, ),
            "ffn.1.weight": (5, ),
            "ffn.1.bias": (2, ),
            "ffn.2.weight": (5, ),
            "ffn.2.bias": (2, ),
        }
    else:
        assert expected_param_shapes == {
            "_fsdp_wrapped_module.flat_param_0": (12, ),
            "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0":
            (6, ),
        }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes, amp_context,
                mixed_precision, half_input)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes, amp_context,
               mixed_precision, half_input)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes, amp_context,
                mixed_precision, half_input)

    teardown()
Example #15
def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""

    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throw an error like this:
        #     RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip(
            "older pytorch doesn't work well with single process dist_init multiple times"
        )

    result = dist_init(rank=0,
                       world_size=1,
                       filename=temp_files[0],
                       filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            assert input_cls is dict
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)

    teardown()
Example #16
def get_wrapped_model(group,
                      cuda_first=False,
                      config={},
                      **model_kwargs) -> FullyShardedDataParallel:
    if cuda_first:
        model = FullyShardedDataParallel(
            TransformerWithSharedParams(group, **model_kwargs).cuda(),
            group, **config)
    else:
        model = FullyShardedDataParallel(
            TransformerWithSharedParams(group, **model_kwargs), group,
            **config).cuda()
    return model
Example #17
def test_it(fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # Random port in case the next test runs quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        assert isinstance(fsdp_config, dict), str(fsdp_config)

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model(), **fsdp_config).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                assert input_cls is dict
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        model.assert_state(TrainingState.IDLE)

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
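The random MASTER_PORT picked above can still collide with a port that is already in use. A common alternative (a sketch, not part of the original test) is to ask the OS for a free port by binding to port 0:

import socket

def find_free_port() -> int:
    # Bind to port 0 so the OS assigns an unused port, then hand it to MASTER_PORT.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]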
Example #18
    def _test_module_state_dict(cls, config, rank, group):
        ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
        autocast = ddp_model.mixed_precision
        cls._train_for_several_steps(ddp_model, 2, autocast)
        state_1 = ddp_model.state_dict()
        # You must make a new FSDP instance to use module.load_state_dict
        unwrapped_model = TransformerWithSharedParams(group)
        unwrapped_model.load_state_dict(state_1)
        new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
        cls._train_for_several_steps(new_ddp_model, 2, autocast)
        try:
            ddp_model.load_state_dict(new_ddp_model.state_dict())
            assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
        except Exception:
            pass
Example #19
    def init_components(
        self,
        model_fn=None,
        criterion_fn=None,
        optimizer_fn=None,
        scheduler_fn=None,
    ):
        """Inits the runs components."""
        model = model_fn()
        model = self.sync_device(model)
        if self._sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

        model = FullyShardedDataParallel(model, **self.ddp_kwargs)

        criterion = criterion_fn()
        criterion = self.sync_device(criterion)

        optimizer = optimizer_fn(model)
        optimizer = self.sync_device(optimizer)

        scheduler = scheduler_fn(optimizer)
        scheduler = self.sync_device(scheduler)

        return model, criterion, optimizer, scheduler
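Note the ordering above: the model is wrapped with FullyShardedDataParallel before optimizer_fn(model) is called, so the optimizer is built over the flattened, sharded FSDP parameters. A minimal sketch of that ordering outside the runner (the helper name is illustrative, and it assumes torch.distributed is already initialized):

import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel

def build_model_and_optimizer(module: nn.Module, lr: float = 0.01):
    # Wrap first so the optimizer sees the sharded parameters, not the originals.
    model = FullyShardedDataParallel(module.cuda())
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    return model, optimizer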
Example #20
    def _test_nested_wrapped_model_local_state_dict(cls,
                                                    rank,
                                                    group,
                                                    config=None,
                                                    local=None):
        # Create a nested FSDP-wrapped instance.
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model,
                                     2,
                                     autocast=config["mixed_precision"])

        # Round trip state dict save/load/save.
        ref_state_dict = {
            k: v.clone()
            for k, v in model.local_state_dict().items()
        }
        model.load_local_state_dict(ref_state_dict)
        state_dict = model.local_state_dict()

        assert ref_state_dict.keys() == state_dict.keys(
        ), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
Example #21
    def _test_nested_wrapped_model(cls, rank, group, config=None):
        # Get reference state dict without any nested FSDP instances.
        model = NestedWrappedModule(group, None).cuda()
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[rank],
                                                    output_device=rank,
                                                    process_group=group)
        cls._train_for_several_steps(model,
                                     2,
                                     autocast=config["mixed_precision"])
        ref_state_dict = {
            k: v.clone()
            for k, v in model.module.state_dict().items()
        }

        # Create a nested FSDP-wrapped instance.
        if config["mixed_precision"]:
            config["compute_dtype"] = torch.float32
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model,
                                     2,
                                     autocast=config["mixed_precision"])

        # Round-trip state dict save/load/save.
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        model.load_state_dict(state_dict)
        state_dict = model.state_dict()

        assert ref_state_dict.keys() == state_dict.keys(
        ), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
Example #22
    def _test_memory_benchmark(self, rank, group, config):
        time_keeper = TimeKeeper()

        SIZE = 8 * 8
        time_keeper.print_time("START", 1.0)
        a = torch.empty(1)
        b = a.cuda()
        # wait for cuda to fully load
        time.sleep(1)
        time_keeper.print_time("INIT_CUDA", 1.0)
        model = SimpleLinear(group,
                             input_size=SIZE,
                             output_size=SIZE,
                             layers=4)
        time_keeper.print_time("CPU_MODEL", 1.0)

        with tempfile.TemporaryDirectory() as current_tempdir:
            config["offload_config"] = OffloadConfig(
                offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)

            model = FullyShardedDataParallel(model, **config)
            time_keeper.print_time("FSDP_MODEL", 1.0)

            self._eval_for_several_steps(model, 1, autocast=False)
            time_keeper.print_time("EVAL")
Example #23
    def init_distributed_data_parallel_model(self):
        """
        Initialize FSDP if needed.

        This method overloads the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        # Make sure the default cuda device is set. TODO (Min): we should enable FSDP
        # for 1-GPU as well, but the use case there is likely different.
        # I.e. perhaps we use it for cpu_offloading.
        assert get_cuda_device_index(
        ) > -1, "Distributed training not setup correctly"

        # The model might already be wrapped by FSDP internally. Check regnet_fsdp.py.
        # Here, we wrap it at the outermost level.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        if is_primary():
            logging.info(f"Using FSDP, config: {fsdp_config}")

        # First, wrap the head's prototypes{j} layers if it is a SwAV head.
        # TODO (Min): make this more general for different models, which may have multiple
        #             heads.
        head0 = self.base_model.heads[0]
        if isinstance(head0, SwAVPrototypesHead):
            for j in range(head0.nmb_heads):
                module = getattr(head0, "prototypes" + str(j))
                module = FSDP(module=module, **fsdp_config)
                setattr(head0, "prototypes" + str(j), module)

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to true. We need to set it back to None here.
        # Also, right now, the head's weights are only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to re-work the checkpoint logic because we don't have
        # enough memory to load the entire model on one node. We need to use the
        # local_state_dict() API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when the checkpoint is taken.
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
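The method above follows the usual FSDP composition pattern: wrap selected inner modules first, then wrap the whole model so the outer instance becomes the root. A stripped-down sketch of that pattern (the child-selection rule is only an example, and an initialized torch.distributed process group is assumed):

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

def wrap_inside_out(model: nn.Module, **fsdp_config) -> FSDP:
    # Wrap chosen children first so their parameters are sharded separately.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):  # example selection rule only
            setattr(model, name, FSDP(child, **fsdp_config))
    # Then wrap the whole model; this outermost instance becomes the FSDP root.
    return FSDP(model, **fsdp_config)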
Example #24
    def __init__(self):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(3, 3),
            FullyShardedDataParallel(
                checkpoint_wrapper(nn.Linear(3, 3),
                                   maintain_forward_counter=True)),
            nn.Linear(3, 3),
        )
Example #25
    def _distributed_worker(
        gpu_id: int, with_fsdp: bool, sync_file: str, result_file: str
    ):
        torch.cuda.set_device(gpu_id)
        dist.init_process_group(
            backend="nccl", init_method="file://" + sync_file, world_size=2, rank=gpu_id
        )

        # Create the inputs
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True
        batch = torch.randn(size=(8, 3, 224, 224)).cuda()

        # Create a fake model based on SWAV blocks
        config = TestRegnetFSDP._create_config(with_fsdp)
        model = build_model(config["MODEL"], config["OPTIMIZER"])
        model = model.cuda()
        if with_fsdp:
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[gpu_id])
        criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
        optimizer = optim.SGD(model.parameters(), lr=1e-2)

        # Run a few iterations and collect the losses
        losses = []
        for iteration in range(5):
            out = model(batch)
            loss = criterion(out[0], torch.tensor(0.0).cuda())
            if gpu_id == 0:
                losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            if iteration <= 2:
                for name, param in model.named_parameters():
                    if "prototypes" in name:
                        param.grad = None
            optimizer.step()

        # Store the losses in a file to compare several methods
        if gpu_id == 0:
            with open(result_file, "wb") as f:
                pickle.dump(losses, f)
Example #26
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
Example #27
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = weight - grad.T * my_lr
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example #28
def is_valid_fsdp_model(model: FSDP) -> bool:
    """
    Checks if an FSDP model is valid by looking at the sub-FSDP modules
    and ensuring that they do not think they are the root FSDP model.
    """
    for n, m in model.named_modules():
        if isinstance(m, FSDP):
            if n != "" and m._is_root is not None:
                return False
    return True
Example #29
def fsdp_recursive_reset_lazy_init(fsdp_module: FSDP):
    """
    Before the first forward pass, an FSDP module might have been initialized
    for instance by calling load_state_dict or load_local_state_dict to
    reload a previous training checkpoint.

    This function will recursively walk through the sub-FSDP modules and
    call _reset_lazy_init on each of them.
    """
    for module in fsdp_module.modules():
        if isinstance(module, FSDP) and module._is_root is not None:
            module._reset_lazy_init()
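A sketch of how the two helpers above might be used together after reloading a checkpoint into a nested FSDP model (the function name and arguments are placeholders):

def reload_checkpoint(fsdp_model: FSDP, state_dict) -> None:
    # Loading a state dict can lazily initialize sub-modules and mark them as root.
    fsdp_model.load_state_dict(state_dict)
    # Reset the lazy-init state so only the outermost wrapper acts as the root...
    fsdp_recursive_reset_lazy_init(fsdp_model)
    # ...and verify that no inner wrapper still believes it is the root.
    assert is_valid_fsdp_model(fsdp_model)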
Example #30
    def _test_identical_outputs(
        cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
    ):
        if config["mixed_precision"]:
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results as PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        ref_state_dict = model.module.state_dict()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if use_cuda:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        shard_state_dict = model.state_dict()

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
            assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
        except (AssertionError, RuntimeError) as e:
            raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")