Example no. 1
    def _dist_train(self,
                    wrap_fsdp,
                    cpu_offload=CPUOffload(offload_params=False)):
        # Keep everything deterministic, including the input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, cpu_offload)
        if wrap_fsdp:
            model = FSDP(model, cpu_offload=cpu_offload)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        model.half()
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(16, 2).cuda().half()
        in_data.requires_grad = True
        for _ in range(1):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            get_full_params(model)

        return list(model.parameters())
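
The get_full_params helper used here (and in the later examples) is not shown. Below is a minimal sketch of such a helper, assuming it gathers the unsharded parameters with FSDP's summon_full_params context manager and returns a copy for comparison; the exact implementation in the test suite may differ.

from copy import deepcopy

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def get_full_params(model, recurse=True):
    # Temporarily materialize the full (unsharded) parameters on every rank
    # and return a deep copy so they can be compared across FSDP/DDP runs.
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))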
Example no. 2
    def _dist_train(self, with_nested_trunk, freezing_method,
                    freeze_after_wrap_fsdp, with_fsdp):
        torch.manual_seed(0)
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = self._create_model(with_fsdp, with_nested_trunk,
                                   freeze_after_wrap_fsdp)
        model = model.cuda()

        # Freeze the trunk by disabling requires_grad on its parameters.
        if freezing_method == FreezingMethod.RequiresGrad:
            for param in model.trunk.parameters():
                param.requires_grad = False

        if with_fsdp:
            if not freeze_after_wrap_fsdp:
                model.fsdp_wrap()
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])

        target = torch.tensor([0, 1], dtype=torch.long).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        for iteration in range(3):
            out = model(batch)
            fake_loss = criterion(out, target)
            optimizer.zero_grad()
            fake_loss.backward()
            if freezing_method == FreezingMethod.GradToNone:
                if with_fsdp:
                    for param in model.module.module.trunk.parameters():
                        param.grad = None
                else:
                    for param in model.module.trunk.parameters():
                        param.grad = None
            optimizer.step()

        if with_fsdp:
            get_full_params(model)

        return list(model.parameters())
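
FreezingMethod is assumed to be a small enum naming the two freezing strategies exercised above; a plausible sketch (the member names follow the usage above, the values are arbitrary):

from enum import Enum


class FreezingMethod(Enum):
    # Freeze by setting requires_grad=False on the trunk before training.
    RequiresGrad = "requires_grad"
    # Freeze by dropping the trunk's gradients (param.grad = None) after backward.
    GradToNone = "grad_to_none"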
Example no. 3
    def test_one_iteration(self):
        """Test FSDP with uneven divide of parameter shards."""
        model = Linear(3, 3, bias=False)
        input = torch.rand(8, 3)
        my_lr = 0.1

        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
            model, input, my_lr)

        model.to(self.rank)
        model = FSDP(model)
        optim = SGD(model.parameters(), lr=my_lr)
        self.assertTrue(len(input) >= self.world_size)
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
        get_full_params(model)
        weight_out = model.module.weight.T.clone()

        self.assertEqual(ref_forward_output_my_rank, out)
        self.assertEqual(ref_weight_out, weight_out)
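
_get_ref_results is likewise not shown. The sketch below is a guess at the reference computation: run the unsharded Linear on the first world_size input rows in one process, keep this rank's output row, and apply one SGD step with the gradient averaged over ranks, mirroring FSDP's gradient averaging. Treat it as an illustration, not the actual helper.

    def _get_ref_results(self, model, input, my_lr):
        from copy import deepcopy

        import torch

        ref_model = deepcopy(model).to(self.rank)
        in_data = input[: self.world_size].to(self.rank)
        # Row r of the output is what rank r computes in the sharded run.
        out = ref_model(in_data)
        out.sum().backward()
        with torch.no_grad():
            # FSDP averages gradients over ranks, so scale the full-batch gradient.
            ref_model.weight -= my_lr * ref_model.weight.grad / self.world_size
        ref_forward_output_my_rank = out[self.rank].detach()
        ref_weight_out = ref_model.weight.T.clone().detach()
        return ref_forward_output_my_rank, ref_weight_out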
Example no. 4
    def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())
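
The _state_dict and _load_state_dict helpers are defined elsewhere in the test class. A minimal sketch, written here as free functions and assuming the string arguments map onto FSDP's StateDictType values via the FSDP.state_dict_type context manager (the mapping name _STATE_DICT_TYPES is an assumption):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# Assumed mapping from the string names used in these tests to StateDictType.
_STATE_DICT_TYPES = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}


def _state_dict(model, state_dict_type):
    # Produce the model's state dict under the requested sharding scheme.
    with FSDP.state_dict_type(model, _STATE_DICT_TYPES[state_dict_type]):
        return model.state_dict()


def _load_state_dict(model, state_dict_type, state_dict):
    # Load the state dict back under the same scheme.
    with FSDP.state_dict_type(model, _STATE_DICT_TYPES[state_dict_type]):
        model.load_state_dict(state_dict)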
Example no. 5
    def _dist_train(self,
                    wrap_fsdp: bool,
                    state_dict_type: str = "",
                    with_context: bool = False):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64,
                             4,
                             requires_grad=True,
                             device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            if with_context:
                state_dict_type = {
                    "state_dict": StateDictType.FULL_STATE_DICT,
                    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
                    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
                }[state_dict_type]
                with model.state_dict_type(state_dict_type):
                    state_dict = model.state_dict()
                with blank_model.state_dict_type(state_dict_type):
                    blank_model.load_state_dict(state_dict)
            else:
                state_dict = self._state_dict(model, state_dict_type)
                self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())
Example no. 6
    def test_save_and_load_after_forward_state_dict(
            self, state_dict_type, mixed_precision,
            state_dict_rank0_and_offload):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        torch.cuda.set_device(self.rank)
        mixed_precision = (MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        ) if mixed_precision else None)
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = get_full_params(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = get_full_params(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_state_dict_mgr(model, state_dict_type,
                                           state_dict_rank0_and_offload)
        with fsd_mgr:
            state_dict = model.state_dict()
            if state_dict_type == "state_dict":
                state_dict = {k: v.clone() for k, v in state_dict.items()}
            else:
                for sharded_tensor in state_dict.values():
                    shard = sharded_tensor._local_shards[0]
                    shard.tensor = shard.tensor.clone().detach_()
        self._validate_state_dict_contents(model, state_dict,
                                           state_dict_rank0_and_offload)
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()

        with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
            model.load_state_dict(state_dict, strict=True)
        loaded_params = get_full_params(model)
        self.assertEqual(loaded_params, trained_params)
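
_get_state_dict_mgr is assumed to return the context manager under which the state dict is taken. A minimal sketch, written as a free function, assuming FullStateDictConfig implements the rank0-only / CPU-offload variant and reusing the STATE_DICT_MAPPING from the load path above:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig


def _get_state_dict_mgr(model, state_dict_type, state_dict_rank0_and_offload):
    # Only the full state dict supports gathering on rank 0 with CPU offload.
    config = None
    if state_dict_type == "state_dict":
        config = FullStateDictConfig(
            rank0_only=state_dict_rank0_and_offload,
            offload_to_cpu=state_dict_rank0_and_offload,
        )
    return FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type], config)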