def _dist_train(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
    # keep everything deterministic for input data
    torch.manual_seed(0)

    model = Model(wrap_fsdp, cpu_offload)
    if wrap_fsdp:
        model = FSDP(model, cpu_offload=cpu_offload)
    else:
        model = DistributedDataParallel(model, device_ids=[self.rank])
    model.half()
    optim = SGD(model.parameters(), lr=0.1)

    in_data = torch.rand(16, 2).cuda().half()
    in_data.requires_grad = True
    for _ in range(1):
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if wrap_fsdp:
        get_full_params(model)

    return list(model.parameters())
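# NOTE: `get_full_params` is a shared test helper that is assumed in this
# excerpt rather than shown. A minimal sketch of what these tests rely on:
# gather the full (unsharded) parameters via FSDP.summon_full_params and
# return detached copies for comparison across configurations. The call above
# discards the return value, which assumes the gathered values remain visible
# through `model.parameters()` afterwards; the exact behavior depends on the
# version of the shared helpers.
from copy import deepcopy

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def get_full_params(model, recurse=True):
    # Temporarily materialize the unsharded parameters on every rank, then
    # copy them out so the comparison does not depend on FSDP's sharded state.
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))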
def _dist_train(self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp):
    torch.manual_seed(0)
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = self._create_model(with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        if not freeze_after_wrap_fsdp:
            model.fsdp_wrap()
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[self.rank])

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
            if with_fsdp:
                for param in model.module.module.trunk.parameters():
                    param.grad = None
            else:
                for param in model.module.trunk.parameters():
                    param.grad = None
        optimizer.step()

    if with_fsdp:
        get_full_params(model)

    return list(model.parameters())
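# NOTE: `FreezingMethod` is referenced above but not defined in this excerpt.
# A minimal sketch of the enum the test assumes: the trunk is frozen either by
# turning off `requires_grad` before training, or by resetting each trunk
# parameter's `.grad` to None after every backward pass. The member values are
# illustrative.
from enum import Enum


class FreezingMethod(Enum):
    RequiresGrad = "requires_grad"
    GradToNone = "grad_to_none"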
def test_one_iteration(self):
    """Test FSDP with an uneven division of parameter shards across ranks."""
    model = Linear(3, 3, bias=False)
    input = torch.rand(8, 3)
    my_lr = 0.1

    ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
        model, input, my_lr
    )

    model.to(self.rank)
    model = FSDP(model)
    optim = SGD(model.parameters(), lr=my_lr)
    self.assertTrue(len(input) >= self.world_size)
    in_data = torch.Tensor(input[self.rank]).to(self.rank)
    out = model(in_data)
    out.float().sum().backward()
    optim.step()
    optim.zero_grad()

    get_full_params(model)
    weight_out = model.module.weight.T.clone()
    self.assertEqual(ref_forward_output_my_rank, out)
    self.assertEqual(ref_weight_out, weight_out)
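# NOTE: `_get_ref_results` is assumed but not shown in this excerpt. A minimal
# sketch of the reference computation the assertions above rely on: each rank
# feeds one row of `input` through the unsharded Linear(3, 3, bias=False), and
# the expected weight after one SGD step uses the gradient of sum(out) with
# respect to the weight, averaged across ranks (FSDP averages gradients during
# the reduce).
def _get_ref_results(self, model, input, my_lr):
    with torch.no_grad():
        # Full (unsharded) weight, viewed as (in_features, out_features).
        weight = model.weight.T.clone().to(self.rank)
        # Local forward output for this rank's row of the input.
        v = torch.Tensor(input[self.rank]).to(self.rank)
        ref_forward_output_my_rank = torch.matmul(v, weight)
        # With loss = sum(out), every row of d(loss)/d(model.weight) equals the
        # input row; averaging those per-rank gradients gives the reduced grad.
        v = torch.Tensor(input[: self.world_size]).to(self.rank)
        grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
        ref_weight_out = weight - grad.T * my_lr
    return ref_forward_output_my_rank, ref_weight_out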
def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""): # TODO: Move this test to common_fsdp. model = self._initialize_model(wrap_fsdp) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() if wrap_fsdp: blank_model = FSDP(Model(True).cuda()) _zero_model(blank_model) state_dict = self._state_dict(model, state_dict_type) self._load_state_dict(blank_model, state_dict_type, state_dict) return get_full_params(blank_model) else: return list(model.parameters())
def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = "", with_context: bool = False): # TODO: Move this test to common_fsdp. model = self._initialize_model(wrap_fsdp) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() if wrap_fsdp: blank_model = FSDP(Model(True).cuda()) _zero_model(blank_model) if with_context: state_dict_type = { "state_dict": StateDictType.FULL_STATE_DICT, "local_state_dict": StateDictType.LOCAL_STATE_DICT, "sharded_state_dict": StateDictType.SHARDED_STATE_DICT, }[state_dict_type] with model.state_dict_type(state_dict_type): state_dict = model.state_dict() with blank_model.state_dict_type(state_dict_type): blank_model.load_state_dict(state_dict) else: state_dict = self._state_dict(model, state_dict_type) self._load_state_dict(blank_model, state_dict_type, state_dict) return get_full_params(blank_model) else: return list(model.parameters())
def test_save_and_load_after_forward_state_dict(
    self, state_dict_type, mixed_precision, state_dict_rank0_and_offload
):
    """
    Test that saving after some training results in params being updated as
    expected.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    torch.cuda.set_device(self.rank)
    mixed_precision = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if mixed_precision
        else None
    )
    model = self._get_simple_nested_model(mixed_precision=mixed_precision)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_params = get_full_params(model)
    for _ in range(6):
        inp = torch.randn(1, 10, device=torch.cuda.current_device())
        output = model(*inp)
        loss = output.sum()
        expected_dtype = torch.float32 if mixed_precision is None else torch.float16
        self.assertEqual(expected_dtype, loss.dtype)
        loss.backward()
        optim.step()
    trained_params = get_full_params(model)
    # Ensure some training occurred.
    self.assertNotEqual(initial_params, trained_params)
    # Save a copy of the state_dict.
    fsd_mgr = self._get_state_dict_mgr(
        model, state_dict_type, state_dict_rank0_and_offload
    )
    with fsd_mgr:
        state_dict = model.state_dict()
        if state_dict_type == "state_dict":
            state_dict = {k: v.clone() for k, v in state_dict.items()}
        else:
            for sharded_tensor in state_dict.values():
                shard = sharded_tensor._local_shards[0]
                shard.tensor = shard.tensor.clone().detach_()
    self._validate_state_dict_contents(model, state_dict, state_dict_rank0_and_offload)
    _zero_model(model)

    # Ensure checkpointed params have the full param dtype.
    for tensor in state_dict.values():
        self.assertEqual(tensor.dtype, torch.float32)

    # Load state_dict into the zeroed model.
    if state_dict_rank0_and_offload:
        # Broadcast the state dict and move it back to GPU in preparation
        # for loading.
        state_dict = self._broadcast_state_dict(state_dict)
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cuda()
    with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
        model.load_state_dict(state_dict, strict=True)
    loaded_params = get_full_params(model)
    self.assertEqual(loaded_params, trained_params)
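# NOTE: `_get_state_dict_mgr` and `_broadcast_state_dict` are assumed helpers
# not shown in this excerpt. A minimal sketch of each: the first returns the
# FSDP.state_dict_type context manager for the requested type, enabling
# rank0-only gathering and CPU offload for the full state dict when asked; the
# second broadcasts rank 0's state dict so every rank can load it.
import torch.distributed as dist

from torch.distributed.fsdp import FullStateDictConfig


def _get_state_dict_mgr(self, model, state_dict_type, state_dict_rank0_and_offload):
    config = None
    if state_dict_type == "state_dict":
        # Only the full state dict supports rank0-only gathering / CPU offload.
        config = FullStateDictConfig(
            rank0_only=state_dict_rank0_and_offload,
            offload_to_cpu=state_dict_rank0_and_offload,
        )
    return FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type], config)


def _broadcast_state_dict(self, state_dict):
    # Non-zero ranks hold an empty dict when rank0_only was used; broadcast
    # rank 0's copy so all ranks can call load_state_dict.
    olist = [state_dict if self.rank == 0 else None]
    dist.broadcast_object_list(olist, src=0)
    return olist[0]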