Example #1
def test_scaler_cpu_offload_breaks():
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        scaler = ShardedGradScaler()
        model = FullyShardedDataParallel(nn.Linear(5, 5),
                                         cpu_offload=True,
                                         mixed_precision=True)
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)

        input = torch.rand((1, 5), dtype=torch.float).to(device)
        optim.zero_grad()
        with autocast():
            output = model(input)
            loss = F.mse_loss(input, output)

        scaler.scale(loss).backward()
        # TODO (Min): Need to fix. Details in issue #421.
        with pytest.raises(RuntimeError):
            scaler.step(optim)
            scaler.update()

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
Example #2
    def _test_consolidate_weights(self,
                                  config,
                                  rank,
                                  group,
                                  paths=None,
                                  transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config)).cuda()

        optim = Adam(
            fsdp.parameters(),
            lr=0.01,
        )
        optim.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            optim.step()

        # each worker saves a checkpoint with local_state_dict
        cp_data = {
            "weights": {k: v.cpu() for k, v in fsdp.local_state_dict().items()},
            "meta": fsdp.local_metadata_dict(),
        }
        torch.save(cp_data, paths[fsdp.rank])
        full_model_state_dict = fsdp.state_dict()
        torch.distributed.barrier()
        if fsdp.rank > 0:
            return
        all_checkpoints = [torch.load(p) for p in paths]
        consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
            shard_weights=[c["weights"] for c in all_checkpoints],
            shard_metadata=[c["meta"] for c in all_checkpoints],
        )
        full_model_extra = set(full_model_state_dict).difference(set(consolidated_checkpoint))
        consolidated_extra = set(consolidated_checkpoint).difference(set(full_model_state_dict))
        msg = f"full model extra keys: {full_model_extra}, consolidated extra keys: {consolidated_extra}"
        # Check the key sets first so a mismatch produces the informative message above
        # instead of a KeyError from the shape comparison below.
        assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys()), msg
        for k in full_model_state_dict.keys():
            assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape
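Stripped of the assertions, the save-and-consolidate round-trip exercised above looks roughly like this (a sketch rather than fairscale's test code; fsdp_model, path and all_paths are illustrative names):

# Each rank saves its local (sharded) weights plus the metadata needed to stitch them together.
torch.save({"weights": fsdp_model.local_state_dict(),
            "meta": fsdp_model.local_metadata_dict()}, path)

# A single process (e.g. rank 0) can then rebuild the full state dict offline.
shards = [torch.load(p) for p in all_paths]
full_sd = FullyShardedDataParallel.consolidate_shard_weights(
    shard_weights=[s["weights"] for s in shards],
    shard_metadata=[s["meta"] for s in shards],
)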
Example #3
    def _test_consolidated_optimizer(self,
                                     config,
                                     rank,
                                     group,
                                     optim_fn=torch.optim.SGD,
                                     transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            fsdp = self.get_wrapped_model(group, config=config).cuda()
            unwrapped_model = TransformerWithSharedParams(group).cuda()
        else:
            fsdp = FullyShardedDataParallel(
                NestedWrappedModule(group, wrapper_config=config), group,
                **config).cuda()
            unwrapped_model = NestedWrappedModule(group,
                                                  wrapper_config=None).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()

        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        fsdp_optim.step()

        output = unwrapped_model(*x)
        loss = unwrapped_model.get_loss(x, output)
        unwrapped_model.run_backward(loss)
        optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        if fsdp.rank > 0:
            return

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]),
                     len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in unwrapped_sd["state"].items()
            ]),
        )

        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd,
                                                     non_blocking=False,
                                                     device="cpu")

        assert_equal(
            sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in original_shard_sd["state"].items()
            ]),
        )
        assert objects_are_equal(shard_sd, original_shard_sd)
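The helpers above assume they are called with one process per rank inside an already-initialized process group. A minimal, hypothetical launcher sketch (worker and the TCP address are illustrative names, not fairscale helpers):

import torch
import torch.distributed
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Per-process setup; the real tests use fairscale's own spawn/init helpers.
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend="nccl", init_method="tcp://localhost:29501",
        rank=rank, world_size=world_size)
    # ... build the FSDP-wrapped model for this rank and call the test body ...
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)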
Example #4
    def _test_consolidated_optimizer(self,
                                     config,
                                     rank,
                                     group,
                                     optim_fn=torch.optim.SGD,
                                     transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            unwrapped_model = TransformerWithSharedParams(
                group, wrapper_config=config).cuda()
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            unwrapped_model = MixtureOfExperts(group,
                                               wrapper_config=None).cuda()
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config)).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            fsdp_optim.step()

            output = unwrapped_model(*x)
            loss = unwrapped_model.get_loss(x, output)
            unwrapped_model.run_backward(loss)
            optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        if not transformer:
            no_broadcast_children = [
                x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state
            ]
            assert len(no_broadcast_children) == 1
            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
        torch.cuda.empty_cache()
        cuda_gb_before = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        cuda_gb_after = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        mem_usg_gb = cuda_gb_after - cuda_gb_before
        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

        if fsdp.rank > 0:
            assert sd is None
            return

        # assert whole state dict on CPU
        for k, v in sd["state"].items():
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                    assert t.device == torch.device("cpu"), msg

        unflat_state = sd["state"]
        assert "uncollected_local_ids" in sd
        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
        shard_sd = recursive_copy_to_device(shard_sd,
                                            non_blocking=False,
                                            device="cpu")
        state_after_get_shard = sd["state"]
        assert objects_are_equal(unflat_state,
                                 state_after_get_shard)  # no side effects.

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]),
                     len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in unwrapped_sd["state"].items()
            ]),
        )

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd,
                                                     non_blocking=False,
                                                     device="cpu")
        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
        assert_equal(
            [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
            [
                first_tensor_numel(v)
                for k, v in original_shard_sd["state"].items()
            ],
        )
        assert_equal(
            [v for k, v in shard_sd["param_groups"][0].items()],
            [v for k, v in original_shard_sd["param_groups"][0].items()],
        )
        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
        assert objects_are_equal(
            {k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)
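Stripped of the assertions, the optimizer-state round-trip exercised here is roughly the following (a sketch; fsdp_model and optim stand for an already-trained FullyShardedDataParallel instance and its optimizer):

# Gather the full, unflattened optimizer state onto rank 0 (other ranks receive None).
full_osd = fsdp_model.gather_full_optim_state_dict(optim, recipient_rank=0)
if fsdp_model.rank == 0:
    torch.save(full_osd, "full_optim_state.pt")

# Later, each rank recovers its own shard from the consolidated dict and loads it.
full_osd = torch.load("full_optim_state.pt")
shard_osd = fsdp_model.get_shard_from_optim_state_dict(full_osd)
optim.load_state_dict(shard_osd)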