def test_norm(device, norm_type, mixed_precision):
    """Test checkpoint_wrapper with different norm layers."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Skip due to lack of GPU")

    # Get input, ref, checkpoint models and make them equal.
    in_data = torch.rand(2, 2, 3, 3).to(device)
    m_ref = get_model(norm_type, False, mixed_precision).to(device)
    m_cpt = get_model(norm_type, True, mixed_precision).to(device)
    m_cpt.load_state_dict(m_ref.state_dict())

    if torch_version() >= (1, 6, 0):
        # This assert fails on 1.5.1.
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())

    if mixed_precision != "fp32":
        in_data = in_data.half()

    # Needed due to checkpointing.
    in_data.requires_grad = True
    for model in (m_ref, m_cpt):
        optim = SGD(model.parameters(), lr=0.1)
        if device == "cpu" and mixed_precision != "fp32":
            # Got: RuntimeError: "batch_norm"/"layer_norm" not implemented for 'Half'.
            with pytest.raises(RuntimeError):
                out = model(in_data)
            return
        else:
            # Everything else works.
            out = model(in_data)
        out.sum().backward()
        optim.step()

    if torch_version() >= (1, 6, 0):
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())
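
# Aside: a minimal sketch (not part of the test above; it only uses the public
# torch.utils.checkpoint API) of why the test sets `in_data.requires_grad = True`.
# With the classic reentrant checkpoint, the wrapped segment is recomputed during
# backward only when at least one input requires grad; otherwise no gradient flows
# through the checkpointed block and the parameter grads stay None.
import torch
from torch.utils.checkpoint import checkpoint

def _checkpoint_requires_grad_sketch():
    layer = torch.nn.Linear(4, 4)
    x = torch.rand(2, 4, requires_grad=True)  # mirrors `in_data.requires_grad = True`
    out = checkpoint(layer, x)
    out.sum().backward()
    assert layer.weight.grad is not None
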
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(
        sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(
            sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data,
                         map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank,
                       world_size=world_size,
                       filename=file1,
                       filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
        # and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=test_fn == "optim_state",
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after,
                          fsdp_model.state_dict(),
                          raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()
    def test_named_params_ordering(self):
        """Test assumption of consolidate_optimizer_state_dict."""
        group = DummyProcessGroup(0, 1)
        model = TransformerWithSharedParams(group)
        named_pars = [p for n, p in model.named_parameters()]
        for i, p in enumerate(model.parameters()):
            assert objects_are_equal(p, named_pars[i])
def test_shared_weight_mevo(temp_files, wrap_middle, test_fn):
    """Test FSDP with a model with shared weights."""
    if test_fn == "optim_state" and wrap_middle != "flat":
        pytest.skip(
            "optim_state is only supported when the root and middle parts are flat")

    world_size = 2

    # Get ref.
    model = Model()
    sd_before = deepcopy(model.state_dict())
    in_data = (torch.rand(BS, SEQ) * (VOCAB - 1)).cuda().long()
    if test_fn == "train":
        _train(model, in_data, world_size)
        sd_after = deepcopy(model.state_dict())
        # Before and after state should not be equal.
        assert not objects_are_equal(sd_before, sd_after)

    # Save data
    torch.save(sd_before, temp_files[2])
    if test_fn == "train":
        torch.save(sd_after, temp_files[3])
    torch.save(in_data, temp_files[4])

    # Run FSDP
    mp.spawn(
        _dist_worker,
        (world_size, temp_files, wrap_middle, test_fn),
        nprocs=world_size,
    )
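
# Aside: a minimal sketch (hypothetical worker and file names) of the parent-to-worker
# handoff pattern used above. Reference tensors are written with torch.save in the
# parent process, and each spawned rank reloads them with a map_location that places
# them on that rank's GPU, which avoids pushing CUDA tensors through mp.spawn's
# argument tuple.
import tempfile
import torch
import torch.multiprocessing as mp

def _sketch_worker(rank, world_size, ref_path):
    ref = torch.load(ref_path, map_location=lambda storage, loc: storage.cuda(rank))
    # ... dist_init(rank, world_size, ...), build the model, compare against ref ...

def _sketch_spawn(world_size=2):
    ref = torch.rand(4, 4)
    with tempfile.NamedTemporaryFile(suffix=".pt") as f:
        torch.save(ref, f.name)
        mp.spawn(_sketch_worker, (world_size, f.name), nprocs=world_size)
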
    def _test_nested_wrapped_model(cls, rank, group, config=None):
        # Get reference state dict without any nested FSDP instances.
        model = NestedWrappedModule(group, None).cuda()
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[rank],
                                                    output_device=rank,
                                                    process_group=group)
        cls._train_for_several_steps(model,
                                     2,
                                     autocast=config["mixed_precision"])
        ref_state_dict = {
            k: v.clone()
            for k, v in model.module.state_dict().items()
        }

        # Create a nested FSDP-wrapped instance.
        if config["mixed_precision"]:
            config["compute_dtype"] = torch.float32
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model,
                                     2,
                                     autocast=config["mixed_precision"])

        # Round-trip state dict save/load/save.
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        model.load_state_dict(state_dict)
        state_dict = model.state_dict()

        assert ref_state_dict.keys() == state_dict.keys(
        ), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
def test_shared_weight(temp_files, outer_flat, inner_flat, sharing):
    """Test FSDP with a model with shared weights."""

    outer_flat = outer_flat == "outer_flat"
    inner_flat = inner_flat == "inner_flat"
    world_size = 2

    # Get reference results.
    model = Model(sharing=sharing)
    sd_before = deepcopy(model.state_dict())
    in_data = torch.rand(1, 4).cuda()
    _train(model, in_data, world_size)
    sd_after = deepcopy(model.state_dict())
    # Before and after state should not be equal.
    assert not objects_are_equal(sd_before, sd_after)

    # Save data
    torch.save(sd_before, temp_files[2])
    torch.save(sd_after, temp_files[3])
    torch.save(in_data, temp_files[4])

    # Run FSDP
    mp.spawn(
        _dist_worker,
        (world_size, temp_files, outer_flat, inner_flat, sharing),
        nprocs=world_size,
    )
    def test_unflatten_params(self):
        for module_init_fn in self._get_module_init_fns():
            module = FlattenParamsWrapper(module_init_fn())
            buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}

            def clone_state_dict():
                return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())

            ref_flat_param = module.flat_param.clone()
            with module.unflatten_params():
                ref_state_dict = clone_state_dict()
            assert not torch.all(ref_flat_param == 0)

            # confirm that unflatten_params reflects values from new_flat_param
            new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
            with module.unflatten_params(flat_param=new_flat_param):
                new_state_dict = clone_state_dict()
                assert new_state_dict.keys() == ref_state_dict.keys()
                for k, v in new_state_dict.items():
                    if k in buffers:  # buffers are not changed
                        torch.testing.assert_allclose(v, ref_state_dict[k])
                    else:  # params reflect new_flat_param value
                        assert torch.all(v == 42.0)

            # after context manager exits, we go back to previous (reference) state
            torch.testing.assert_allclose(module.flat_param, ref_flat_param)
            with module.unflatten_params():
                ref_state_dict2 = clone_state_dict()
                assert objects_are_equal(ref_state_dict, ref_state_dict2)

            # if we load the new_state_dict, then the flat param should match new_flat_param
            module.load_state_dict(new_state_dict)
            torch.testing.assert_allclose(module.flat_param, new_flat_param)
    def _test_nested_wrapped_model_local_state_dict(cls,
                                                    rank,
                                                    group,
                                                    config=None,
                                                    local=None):
        # Create a nested FSDP-wrapped instance.
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model,
                                     2,
                                     autocast=config["mixed_precision"])

        # Round trip state dict save/load/save.
        ref_state_dict = {
            k: v.clone()
            for k, v in model.local_state_dict().items()
        }
        model.load_local_state_dict(ref_state_dict)
        state_dict = model.local_state_dict()

        assert ref_state_dict.keys() == state_dict.keys(
        ), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
    def test_state_dict_equality(self):
        module = self._get_shared_params_transformer()
        ref_state_dict = module.state_dict()

        flat_module = FlattenParamsWrapper(module)
        flat_state_dict = flat_module.state_dict()

        assert objects_are_equal(ref_state_dict, flat_state_dict)
def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method,
                        tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # Freeze the trunk via requires_grad; the grad_to_none variant is applied in the training loop below.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state,
                                 fsdp_state,
                                 raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
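
# Aside: a single-process sketch (hypothetical two-layer model) of the two freezing
# methods exercised above. Setting requires_grad=False stops gradient computation for
# the trunk entirely; the "grad_to_none" variant still computes trunk gradients during
# backward but discards them before optimizer.step(), so in both cases the trunk
# weights stay fixed while the head keeps training.
import torch
import torch.nn as nn

def _sketch_freezing(freezing_method="requires_grad"):
    trunk, head = nn.Linear(8, 8), nn.Linear(8, 2)
    model = nn.Sequential(trunk, head)
    if freezing_method == "requires_grad":
        for p in trunk.parameters():
            p.requires_grad = False
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = model(torch.rand(2, 8)).sum()
    loss.backward()
    if freezing_method == "grad_to_none":
        for p in trunk.parameters():
            p.grad = None
    optimizer.step()  # params whose .grad is None are left untouched by SGD
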
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
    def test_load_state_dict(self):
        """Test that original (unwrapped) state_dict can be loaded in wrapped module."""
        for module_init_fn in self._get_module_init_fns():
            module = module_init_fn()
            ref_state_dict = module.state_dict()
            ref_output = self._get_output(module)

            module = module_init_fn(seed=1234)
            flat_module = FlattenParamsWrapper(module)

            # This should work without the unflatten_params context manager
            flat_module.load_state_dict(ref_state_dict)
            flat_output = self._get_output(flat_module)
            assert objects_are_equal(ref_output, flat_output)

            # And it should work with the context manager too
            with flat_module.unflatten_params():
                flat_module.load_state_dict(ref_state_dict)
            flat_output = self._get_output(flat_module)
            assert objects_are_equal(ref_output, flat_output)
    def test_load_state_dict(self):
        module = self._get_shared_params_transformer()
        ref_state_dict = module.state_dict()
        ref_output = self._get_output(module)

        module = self._get_shared_params_transformer(seed=1234)
        flat_module = FlattenParamsWrapper(module)
        flat_module.load_state_dict(ref_state_dict)
        flat_output = self._get_output(flat_module)

        assert objects_are_equal(ref_output, flat_output)
    def test_state_dict_equality(self):
        """Test that unflattened state dict matches original (unwrapped) one."""
        modules_to_test = [init_fn() for init_fn in self._get_module_init_fns()]
        for module in modules_to_test:
            ref_state_dict = module.state_dict()

            flat_module = FlattenParamsWrapper(module)
            flat_state_dict = flat_module.state_dict()

            assert (
                ref_state_dict.keys() == flat_state_dict.keys()
            ), f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"
            assert objects_are_equal(ref_state_dict, flat_state_dict), f"{ref_state_dict} != {flat_state_dict}"
    def test_flat_state_dict(self):
        flat_module = self._get_shared_params_transformer()
        flat_module = FlattenParamsWrapper(flat_module)
        ref_output = self._get_output(flat_module)

        flat_state_dict = flat_module.flat_state_dict()

        new_module = self._get_shared_params_transformer(seed=1234)
        new_module = FlattenParamsWrapper(new_module)
        new_module.load_state_dict(flat_state_dict)
        new_output = self._get_output(new_module)

        assert objects_are_equal(ref_output, new_output)
    def test_flat_state_dict(self):
        """Test that flat state dict can be reloaded and produces the same results."""
        for module_init_fn in self._get_module_init_fns():
            flat_module = FlattenParamsWrapper(module_init_fn())
            ref_output = self._get_output(flat_module)

            flat_state_dict = flat_module.flat_state_dict()

            new_module = FlattenParamsWrapper(module_init_fn(seed=1234))
            new_module.load_state_dict(flat_state_dict)
            new_output = self._get_output(new_module)

            assert objects_are_equal(ref_output, new_output)
    def _test_param_change_after_init(self, rank, group, config):
        # Establish reference behavior.
        model = self.get_wrapped_model(group, cuda_first=False, config=config)
        model.eval()  # no dropout for this test
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Change the weights in place.
        model = self.get_wrapped_model(group, cuda_first=False, config=config)
        model.eval()  # no dropout for this test
        first_param = next(model.parameters())
        nn.init.normal_(first_param.data)
        new_output = model(*input)

        assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init"
    def _test_grad_acc(self, model, batch_dim, use_no_sync_context=True):
        # Generate two input batches. We'll test that we get the same grads if
        # we train on them sequentially while accumulating grads (with no_sync
        # or without no_sync) vs. concatenating the batches and training in one go.
        #
        # The difference between using no_sync and not using it is a GPU memory
        # vs. network bandwidth tradeoff.
        batch1 = model.module.get_input(torch.device("cuda"))
        assert isinstance(batch1, tuple)
        batch2 = tuple(
            # This randomly permutes the values in a multi-dim tensor.
            x.view(-1)[torch.randperm(x.numel())].view_as(x) for x in batch1)
        for x, y in zip(batch1, batch2):
            assert not torch.all(x == y)

        # Concat the batches along batch dimension.
        concat_batch = tuple(
            torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

        # Establish reference behavior on the concat batch.
        model.zero_grad()
        output = model(*concat_batch)
        ref_loss = model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [p.grad.detach().clone() for p in model.parameters()]

        # Test that we get the same results by accumulating grads.
        model.zero_grad()
        context = contextlib.suppress()
        if use_no_sync_context:
            context = model.no_sync()
        with context:  # accumulate gradients from the first batch
            output = model(*batch1)
            loss1 = model.module.get_loss(batch1, output)
            loss1.backward()
        output = model(*batch2)
        loss2 = model.module.get_loss(batch2, output)
        loss2.backward()
        accumulated_loss = loss1 + loss2
        accumulated_grads = [
            p.grad.detach().clone() for p in model.parameters()
        ]

        torch.testing.assert_allclose(ref_loss, accumulated_loss)
        assert objects_are_equal(ref_grads,
                                 accumulated_grads,
                                 raise_exception=True)
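
# Aside: a condensed sketch (hypothetical ddp_model and loss_fn) of the no_sync()
# accumulation pattern the comment above describes. Gradients from the first batch are
# accumulated locally with the cross-rank reduction skipped, and only the backward()
# issued outside the context triggers synchronization; this trades extra GPU memory
# (unreduced grads are kept around) for fewer reduction rounds over the network.
def _sketch_no_sync_accumulation(ddp_model, batch1, batch2, loss_fn):
    ddp_model.zero_grad()
    with ddp_model.no_sync():                  # skip gradient reduction for this backward
        loss_fn(ddp_model(batch1)).backward()
    loss_fn(ddp_model(batch2)).backward()      # reduction happens on this backward
    # optimizer.step() would follow here
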
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None

    if torch_version() < (1, 8, 0) and flatten:
        # PyTorch 1.6 and 1.7 throw this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            # Get 2 + world_size files: 2 for dist_init and one per rank to save the losses.
            with temp_files_ctx(num=2 + world_size) as temp_files:
                mp.spawn(
                    _distributed_worker,
                    (
                        world_size,
                        with_fsdp,
                        with_checkpoint,
                        temp_files,
                        mixed_precision,
                        flatten,
                        wrap_bn,
                        fp32_reduce_scatter,
                    ),
                    nprocs=world_size,
                )
                final_losses = {}
                for rank in range(world_size):
                    with open(temp_files[2 + rank], "rb") as f:
                        final_losses[f"rank_{rank}"] = pickle.load(f)
                if expected_losses is None:
                    expected_losses = final_losses
                else:
                    print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
                    assert objects_are_equal(expected_losses,
                                             final_losses,
                                             raise_exception=True)
    def _test_transformer(self, rank, group, config):
        autocast = config["mixed_precision"]

        # Train model for a step
        model = self.get_wrapped_model(group, cuda_first=False, config=config)
        self._train_for_several_steps(model, 1, autocast)

        model.eval()  # no dropout for this test

        # Eval in standard mode (i.e., without no_grad)
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Eval with no_grad and compare
        with torch.no_grad():
            no_grad_output = model(*input)

        assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)
    def _test_identical_outputs(
        cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
    ):
        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results as PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        ref_state_dict = model.module.state_dict()
        if config.get("cpu_offload", False):
            for k in ref_state_dict.keys():
                ref_state_dict[k] = ref_state_dict[k].cpu()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if use_cuda:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        shard_state_dict = model.state_dict()

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
            assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
        except (AssertionError, RuntimeError) as e:
            raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap in a different order due to different
        # random seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)
    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure the final state equals state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare
        # against the saved rank 0 state on every rank. Otherwise, only compare on rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after,
                                     fsdp_state,
                                     raise_exception=True)

    teardown()
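
# Aside: a minimal sketch (assuming the standard PyTorch API; the pytorch_bn_converter
# used above presumably wraps nn.SyncBatchNorm.convert_sync_batchnorm) of the
# BatchNorm -> SyncBatchNorm conversion performed by the sync_bn == "pytorch" branches
# before the model is wrapped in DDP or FSDP.
import torch.nn as nn

def _sketch_sync_bn_conversion():
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    return nn.SyncBatchNorm.convert_sync_batchnorm(model)
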
    def _test_consolidated_optimizer(self,
                                     config,
                                     rank,
                                     group,
                                     optim_fn=torch.optim.SGD,
                                     transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            fsdp = self.get_wrapped_model(group, config=config).cuda()
            unwrapped_model = TransformerWithSharedParams(group).cuda()
        else:
            fsdp = FullyShardedDataParallel(
                NestedWrappedModule(group, wrapper_config=config), group,
                **config).cuda()
            unwrapped_model = NestedWrappedModule(group,
                                                  wrapper_config=None).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()

        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        fsdp_optim.step()

        output = unwrapped_model(*x)
        loss = unwrapped_model.get_loss(x, output)
        unwrapped_model.run_backward(loss)
        optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        if fsdp.rank > 0:
            return

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]),
                     len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in unwrapped_sd["state"].items()
            ]),
        )

        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd,
                                                     non_blocking=False,
                                                     device="cpu")

        assert_equal(
            sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in original_shard_sd["state"].items()
            ]),
        )
        assert objects_are_equal(shard_sd, original_shard_sd)
    def _test_output(self, module):
        ref_output = self._get_output(module)

        flat_module = FlattenParamsWrapper(module)
        flat_output = self._get_output(flat_module)
        assert objects_are_equal(ref_output, flat_output)
    def _test_consolidated_optimizer(self,
                                     config,
                                     rank,
                                     group,
                                     optim_fn=torch.optim.SGD,
                                     transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.

        if transformer:
            unwrapped_model = TransformerWithSharedParams(
                group, wrapper_config=config).cuda()
            fsdp = self.get_wrapped_model(group, config=config).cuda()
        else:
            unwrapped_model = MixtureOfExperts(group,
                                               wrapper_config=None).cuda()
            fsdp = FullyShardedDataParallel(
                MixtureOfExperts(group, wrapper_config=config)).cuda()

        try:
            fsdp_optim = optim_fn(
                fsdp.parameters(),
                lr=0.01,
            )
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x = fsdp.module.get_input(torch.device("cuda"))
            output = fsdp(*x)
            loss = fsdp.module.get_loss(x, output).to("cuda")
            fsdp.module.run_backward(loss)
            fsdp_optim.step()

            output = unwrapped_model(*x)
            loss = unwrapped_model.get_loss(x, output)
            unwrapped_model.run_backward(loss)
            optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        if not transformer:
            no_broadcast_children = [
                x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state
            ]
            assert len(no_broadcast_children) == 1
            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
        torch.cuda.empty_cache()
        cuda_gb_before = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        cuda_gb_after = torch.cuda.memory_stats(
            fsdp.rank)["allocated_bytes.all.current"] / 1024**3
        mem_usg_gb = cuda_gb_after - cuda_gb_before
        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

        if fsdp.rank > 0:
            assert sd is None
            return

        # Assert that the whole state dict is on CPU.
        for k, v in sd["state"].items():
            for buffer_name, t in v.items():
                if torch.is_tensor(t):
                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                    assert t.device == torch.device("cpu"), msg

        unflat_state = sd["state"]
        assert "uncollected_local_ids" in sd
        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
        shard_sd = recursive_copy_to_device(shard_sd,
                                            non_blocking=False,
                                            device="cpu")
        state_after_get_shard = sd["state"]
        assert objects_are_equal(unflat_state,
                                 state_after_get_shard)  # no side effects.

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]),
                     len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([
                first_tensor_numel(v)
                for k, v in unwrapped_sd["state"].items()
            ]),
        )

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd,
                                                     non_blocking=False,
                                                     device="cpu")
        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
        assert_equal(
            [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
            [
                first_tensor_numel(v)
                for k, v in original_shard_sd["state"].items()
            ],
        )
        assert_equal(
            [v for k, v in shard_sd["param_groups"][0].items()],
            [v for k, v in original_shard_sd["param_groups"][0].items()],
        )
        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
        assert objects_are_equal({k: shard_sd[k]
                                  for k in original_shard_sd},
                                 original_shard_sd)