def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
        model = TransformerWithSharedParams(group, **model_kwargs)
        if pure_fp16:
            assert not config["mixed_precision"]
            model = model.half()
        fsdp_model = FSDP(model, group, **config)
        if not config["cpu_offload"]:
            fsdp_model = fsdp_model.cuda()
        autocast = fsdp_model.mixed_precision or pure_fp16
        self._train_for_several_steps(fsdp_model, 1, autocast)

        sd = fsdp_model.state_dict()

        sd_device = config.get("state_dict_device")
        for k, v in sd.items():
            if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
                assert v.device.type == "cpu", v.device.type
            else:
                assert v.device.type == "cuda", v.device.type

        expected_dtype = torch.float16 if pure_fp16 else torch.float32
        for k, v in sd.items():
            if not torch.is_floating_point(v):
                continue
            assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
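For reference, a config of the shape this helper consumes might look like the sketch below. The keys mirror the ones read above (mixed_precision, cpu_offload, state_dict_device); the values and the example_config name are illustrative assumptions, not defaults from the test suite.

import torch

# Illustrative config for the state-dict-device helper above (values assumed).
example_config = {
    "mixed_precision": False,                  # train under autocast when True
    "cpu_offload": False,                      # keep sharded params on CPU when True
    "state_dict_device": torch.device("cpu"),  # device for state_dict() tensors
}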
Example #2
    def _test_identical_outputs_eval(
        cls,
        model_init_fn,
        config,
        rank,
        group,
        num_steps=2,
        use_cuda=True,
        lr=0.01,
        ref_ddp_fn=None,
    ):
        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # results identical to PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank,
                                                        process_group=group)
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._eval_with_config(model, autocast)
        ref_state_dict = model.module.state_dict()
        if config.get("cpu_offload", False):
            for k in ref_state_dict.keys():
                ref_state_dict[k] = ref_state_dict[k].cpu()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        if config.pop("ssd_offload", False):
            config["offload_config"] = OffloadConfig(
                offload_type="ssd_offload")

        model = FullyShardedDataParallel(
            model_init_fn(group=group, wrapper_config=config), group, **config)
        if not model.ssd_offload and not model.move_params_to_cpu:
            if use_cuda:
                model = model.cuda()
            else:
                assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._eval_with_config(model, autocast)

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
        except (AssertionError, RuntimeError) as e:
            raise Exception(
                f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}"
            )
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)
Example #3
    def _test_identical_outputs(
        cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
    ):
        if config["mixed_precision"]:
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # results identical to PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        ref_state_dict = model.module.state_dict()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if use_cuda:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        shard_state_dict = model.state_dict()

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
            assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
        except (AssertionError, RuntimeError) as e:
            raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
Example #4
    def _test_state_dict_device(self,
                                config,
                                rank,
                                group,
                                pure_fp16=False,
                                **model_kwargs):
        model = TransformerWithSharedParams(group, **model_kwargs)
        if pure_fp16:
            assert not config["mixed_precision"]
            model = model.half()
        fsdp_model = FullyShardedDataParallel(model, group, **config)
        if not config["cpu_offload"]:
            fsdp_model = fsdp_model.cuda()
        autocast = fsdp_model.mixed_precision or pure_fp16
        self._train_for_several_steps(fsdp_model, 1, autocast)

        sd = fsdp_model.state_dict()

        sd_device = config.get("state_dict_device")
        for k, v in sd.items():
            if config["cpu_offload"] or (sd_device is not None
                                         and sd_device.type == "cpu"):
                assert v.device.type == "cpu", v.device.type
            else:
                assert v.device.type == "cuda", v.device.type

        expected_dtype = torch.float16 if pure_fp16 else torch.float32
        buffers = {
            k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "")
            for k, _ in fsdp_model.named_buffers()
        }
        for k, v in sd.items():
            if not torch.is_floating_point(v):
                continue
            if k in buffers:
                assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
            else:
                assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
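The set comprehension above strips FSDP's internal wrapper prefixes so that names from named_buffers() line up with the keys in state_dict(). A standalone check of that normalization, using an illustrative buffer name:

raw = "_fsdp_wrapped_module._fpw_module.transformer.pos_embed"
clean = raw.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "")
assert clean == "transformer.pos_embed"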
Example #5
    def _test_named_params(self, rank, group, config):
        # Get the named parameters before wrapping.
        before_wrap_model = TransformerWithSharedParams(group)
        before_wrap_params = before_wrap_model.named_parameters()

        with tempfile.TemporaryDirectory() as current_tempdir:
            if config["ssd_offload"]:
                config["offload_config"] = OffloadConfig(
                    offload_type="ssd_offload",
                    ssd_filepath_dir=current_tempdir)
            del config["ssd_offload"]

            model = FullyShardedDataParallel(before_wrap_model, **config)
            print(f"model.ssd_offload {model.ssd_offload}")
            if not model.ssd_offload and not model.move_params_to_cpu:
                model = model.cuda()

            self._eval_with_config(model, autocast=config["mixed_precision"])

            # Get the named parameters after wrapping to compare.
            after_wrap_params = model.named_parameters()

            if not config.get("flatten_parameters", False):
                for before_nm, after_nm in zip(before_wrap_params,
                                               after_wrap_params):
                    assert before_nm[0] == after_nm[0]
            else:
                # With flattening enabled, parameters collapse into a single flat param.
                first_param_name = next(iter(after_wrap_params))[0]
                assert "flat_param_0" in first_param_name

            after_wrap_params = model.named_parameters()

            for before_nm, after_nm_original in zip(before_wrap_params,
                                                    after_wrap_params):
                assert before_nm[0] == after_nm_original[0]
                torch.testing.assert_allclose(before_nm[1].shape,
                                              after_nm_original[1].shape)
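One caveat about the helper above: named_parameters() returns a one-shot generator, so in the unflattened branch the first comparison loop consumes before_wrap_params, leaving the final zip with nothing to iterate. A minimal illustration of the pitfall:

import torch.nn as nn

m = nn.Linear(2, 2)
params = m.named_parameters()           # one-shot generator
first = [name for name, _ in params]    # ['weight', 'bias']
second = [name for name, _ in params]   # [] -- already exhausted
assert first and not second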