def _test_func(rank, world_size, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    assert expected_param_shapes == {
        "_fsdp_wrapped_module.flat_param_0": (12,),
        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
    }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes)

    teardown()
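# A minimal launch sketch (an assumption, not part of the original test file): workers
# like `_test_func` above are typically started with `torch.multiprocessing.spawn`, one
# process per rank, passing the two temp file names that `dist_init` uses for rendezvous.
if __name__ == "__main__":
    import tempfile

    import torch.multiprocessing as mp

    world_size = 2  # assumed GPU count, for illustration only
    with tempfile.NamedTemporaryFile() as f1, tempfile.NamedTemporaryFile() as f2:
        mp.spawn(_test_func, args=(world_size, f1.name, f2.name), nprocs=world_size)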
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_out = torch.matmul(v, weight)
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=0.1)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_out, out)

    model.assert_state(TrainingState.IDLE)
    teardown()
def __init__(self, with_fsdp=False, wrap_middle="none"):
    super().__init__()
    self.l0 = nn.Embedding(VOCAB, D_MODEL).cuda().half()
    nn.init.uniform_(self.l0.weight, -1.0e-1, 1.0e-1)
    self.l1 = MEVO(self.l0.weight, tile_factor=TILE, reduction="sum")
    self.middle = nn.Linear(D_MODEL, D_MODEL).cuda().half()
    # LNs are not strictly needed for this test, but they help reduce the loss quickly
    # and improve numerical stability.
    self.ln1 = nn.LayerNorm(D_MODEL).cuda().half()
    self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()

    if with_fsdp:
        # Shared layers must be unflattened.
        self.l0 = FSDP(self.l0, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
        self.l1 = FSDP(self.l1, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
        self.l1.append_shared_param(self.l0.module.weight)

        # These are for debugging.
        # print(id(self.l0), "is emb")
        # print(id(self.l1), "is out")

        assert wrap_middle in ["none", "flat", "nonflat"]
        if wrap_middle != "none":
            self.middle = FSDP(
                self.middle,
                flatten_parameters=wrap_middle == "flat",
                mixed_precision=False,
                compute_dtype=torch.float16,
            )
def __init__(self, with_fsdp=False, inner_flat=False, sharing=None):
    super().__init__()
    self.l0 = Linear(4, 4, bias=True).cuda()
    self.l1 = Linear(4, 4, bias=True).cuda()
    self.l2 = Linear(4, 4, bias=True).cuda()
    self.l3 = Linear(4, 4, bias=True).cuda()

    # Share the weights. The layer must have at least 1 param that's not
    # shared. Therefore, we have bias=True and test sharing either the
    # weight or the bias.
    if sharing == "share_only_weights":
        self.l1.weight = self.l3.weight
    elif sharing == "share_only_bias":
        self.l1.bias = self.l3.bias
    else:
        assert sharing is None or sharing == "share_none"

    if with_fsdp:
        # Shared layers must be unflattened.
        self.l1 = FSDP(self.l1, flatten_parameters=False)
        self.l2 = FSDP(self.l2, flatten_parameters=inner_flat)
        self.l3 = FSDP(self.l3, flatten_parameters=False)

        if sharing in ["share_only_weights"]:
            self.l3.append_shared_param(self.l1.module.weight)
        if sharing in ["share_only_bias"]:
            self.l3.append_shared_param(self.l1.module.bias)
def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group, expected_buffer_type=None):
    # Patch torch.distributed.reduce_scatter to check the dtype of the reduction.
    orig_reduce_scatter = torch.distributed.reduce_scatter

    model: nn.Module = DeviceAndTypeCheckModule(
        expected_input_dtype=in_dtype,
        expected_param_dtype=p_dtype,
        expected_loss_dtype=loss_dtype,
        expected_buffer_dtype=expected_buffer_type,
    )

    def _reduce_scatter(output, input_list, **kwargs):
        for tensor in input_list:
            model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype)
        return orig_reduce_scatter(output, input_list, **kwargs)

    with mock.patch("torch.distributed.reduce_scatter", new=_reduce_scatter):
        model = FullyShardedDataParallel(model, group, **cfg).cuda()
        device = next(model.parameters()).device
        x = torch.rand(2, 5).to(device)
        with torch.cuda.amp.autocast(enabled=autocast):
            loss = model(x)
        loss.backward()
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.inner = FSDP(Linear(4, 4), **fsdp_config)
            self.outer = Linear(4, 5)

        def forward(self, x):
            # Forward twice.
            i = self.inner(x)
            j = self.inner(x)
            return self.outer(i + j)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
def _test_identical_outputs_eval(
    cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None,
):
    if config.get("mixed_precision", False):
        autocast = True
        # Force the compute dtype to be torch.float32 so that we get
        # identical results as PyTorch DDP when using autocast. Note that
        # this will cause the all-gather to happen in FP32, which is slower
        # than necessary in most cases.
        config["compute_dtype"] = torch.float32
    else:
        autocast = False

    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrapper_config=None).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
    else:
        model = ref_ddp_fn(model, group)
    ref_loss = cls._eval_with_config(model, autocast)
    ref_state_dict = model.module.state_dict()
    if config.get("cpu_offload", False):
        for k in ref_state_dict.keys():
            ref_state_dict[k] = ref_state_dict[k].cpu()

    # Confirm we get the same behavior using FullyShardedDataParallel.
    if config.get("ssd_offload", False):
        config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
        del config["ssd_offload"]

    model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
    if not model.ssd_offload and not model.move_params_to_cpu:
        if use_cuda:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device("cpu")
    shard_loss = cls._eval_with_config(model, autocast)

    try:
        torch.testing.assert_allclose(ref_loss, shard_loss)
    except (AssertionError, RuntimeError) as e:
        raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")

    if config.get("flatten_parameters", True):
        metadata = model.local_metadata_dict()
        assert isinstance(metadata, dict)
def __init__(self, group, wrapper_config, checkpoint_act=False):
    super().__init__(group, wrapper_config)
    self.group = group

    # "expert" params are different on each rank
    torch.manual_seed(42 + group.rank())
    expert = nn.Linear(16, 4)
    for p in expert.parameters():
        p.expert = True

    # everything else is shared
    torch.manual_seed(0)
    shared = nn.Linear(4, 16)

    if checkpoint_act:
        expert = checkpoint_wrapper(expert)
        shared = checkpoint_wrapper(shared)

    if wrapper_config is not None:
        # we create a process group of size 1 for the expert params
        expert_group = torch.distributed.new_group([group.rank()])
        expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
        shared = FullyShardedDataParallel(shared, group, **wrapper_config)

    self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))
def _test_ssd_offload_eval(self, rank, group, config):
    model = TransformerWithSharedParams(group)
    state_dict = model.state_dict()

    nested_wrapping = config["nested_wrapping"]
    del config["nested_wrapping"]

    with tempfile.TemporaryDirectory() as current_tempdir:
        config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)

        if nested_wrapping:
            model = FullyShardedDataParallel(NestedWrappedModule(group, wrap_everything=True, wrapper_config=config))
        else:
            model = FullyShardedDataParallel(model, **config)

        self._eval_with_config(model, autocast=config["mixed_precision"])

        # With SSD offload only local_state_dict will work. We can support global
        # state dict if we think it is necessary.
        state_dict = model.local_state_dict()
        model.load_local_state_dict(state_dict)

        self._eval_with_config(model, config["mixed_precision"])
def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
    model = TransformerWithSharedParams(group, **model_kwargs)
    if pure_fp16:
        assert not config["mixed_precision"]
        model = model.half()
    fsdp_model = FSDP(model, group, **config)
    if not config["cpu_offload"]:
        fsdp_model = fsdp_model.cuda()
    autocast = fsdp_model.mixed_precision or pure_fp16
    self._train_for_several_steps(fsdp_model, 1, autocast)

    sd = fsdp_model.state_dict()

    sd_device = config.get("state_dict_device")
    for k, v in sd.items():
        if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
            assert v.device.type == "cpu", v.device.type
        else:
            assert v.device.type == "cuda", v.device.type

    expected_dtype = torch.float16 if pure_fp16 else torch.float32
    for k, v in sd.items():
        if not torch.is_floating_point(v):
            continue
        assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
    super().__init__(group, wrapper_config)
    self.group = group
    self.delay_before_free_ms = delay_before_free_ms

    # "expert" params are different on each rank
    torch.manual_seed(42 + group.rank())
    d_expert = 23
    d_shared = 12
    d_input = 8
    expert = nn.Linear(d_expert, d_shared)

    self.num_expert_params = sum([p.numel() for p in expert.parameters()])
    for p in expert.parameters():
        p.expert = True

    # everything else is shared
    torch.manual_seed(0)

    shared = nn.Linear(d_shared, d_expert)

    if checkpoint_act:
        expert = checkpoint_wrapper(expert)
        shared = checkpoint_wrapper(shared)

    if wrapper_config is not None:
        # we create a process group of size 1 for the expert params
        expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
        expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
        shared = FullyShardedDataParallel(shared, group, **wrapper_config)

    self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input))
def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case): result = dist_init(rank, world_size, tempfile_name, unused) assert result, "Dist init failed" my_lr = 0.1 device = torch.device("cuda") if fsdp_config.get("mixed_precision", False): dtype = torch.float16 fsdp_config["fp32_reduce_scatter"] = True else: dtype = torch.float32 if test_case["assert_ref_out"]: with torch.no_grad(): # Compute one iteration local output. fp32_weight = model.weight.T.clone().to(device) weight = fp32_weight.to(dtype) v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype) ref_forward_output_my_rank = torch.matmul(v, weight) # Compute one iteration global weight update. v = torch.Tensor(test_case["inputs"][0][:world_size]).to( device, dtype) grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size) ref_weight_out = fp32_weight - grad.T * my_lr assert ref_weight_out.dtype == torch.float32 model.to( device) # not dtype, since FSDP will manage mixed precision internally assert isinstance(fsdp_config, dict), str(fsdp_config) model = FSDP(model, **fsdp_config) optim = SGD(model.parameters(), lr=my_lr) inputs = test_case["inputs"] assert len(inputs) == 1 or not test_case["assert_ref_out"] assert len(inputs[0]) >= world_size for in_data in inputs: in_data = Tensor(in_data[rank]).to(device, dtype) out = model(in_data) out.float().sum().backward() optim.step() optim.zero_grad() if test_case["assert_ref_out"]: with model.summon_full_params(): weight_out = model.module.weight.data.T.clone() # make sure we can do more fwd/bwd loss = model(in_data) loss.sum().backward() if test_case["assert_ref_out"]: torch.testing.assert_allclose(ref_forward_output_my_rank, out) torch.testing.assert_allclose(ref_weight_out, weight_out) model.assert_state(TrainingState.IDLE) teardown()
def _test_func(rank, world_size, tempfile_name, unused, flatten, mixed_precision, amp_context, half_input, fsdp_wrap_ckpt):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FSDP(
        SimpleModuleWithCheckpointing(flatten, mixed_precision, fsdp_wrap_ckpt).cuda(),
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
    )
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    if not flatten:
        assert expected_param_shapes == {
            "ffn.0.weight": (5,),
            "ffn.0.bias": (2,),
            "ffn.1.weight": (5,),
            "ffn.1.bias": (2,),
            "ffn.2.weight": (5,),
            "ffn.2.bias": (2,),
        }
    else:
        assert expected_param_shapes == {
            "_fsdp_wrapped_module.flat_param_0": (12,),
            "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
        }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    teardown()
def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throws an error like this:
        #   RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip("older pytorch doesn't work well with single process dist_init multiple times")

    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            assert input_cls is dict
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> FullyShardedDataParallel:
    if cuda_first:
        model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs).cuda(), group, **config)
    else:
        model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda()
    return model
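# Hedged usage sketch (the config values below are assumptions, for illustration only):
# a test body would typically call the helper above to get an FSDP-wrapped transformer,
# optionally moving the module to GPU before wrapping it.
def _example_get_model(group):
    return get_wrapped_model(group, cuda_first=True, config={"mixed_precision": True})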
def test_it(fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        assert isinstance(fsdp_config, dict), str(fsdp_config)

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model(), **fsdp_config).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                assert input_cls is dict
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        model.assert_state(TrainingState.IDLE)
    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
def _test_module_state_dict(cls, config, rank, group):
    ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
    autocast = ddp_model.mixed_precision
    cls._train_for_several_steps(ddp_model, 2, autocast)
    state_1 = ddp_model.state_dict()
    # You must make a new FSDP instance to use module.load_state_dict
    unwrapped_model = TransformerWithSharedParams(group)
    unwrapped_model.load_state_dict(state_1)
    new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
    cls._train_for_several_steps(new_ddp_model, 2, autocast)
    try:
        ddp_model.load_state_dict(new_ddp_model.state_dict())
        assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
    except Exception:
        pass
def init_components(
    self,
    model_fn=None,
    criterion_fn=None,
    optimizer_fn=None,
    scheduler_fn=None,
):
    """Inits the run's components."""
    model = model_fn()
    model = self.sync_device(model)

    if self._sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = FullyShardedDataParallel(model, **self.ddp_kwargs)

    criterion = criterion_fn()
    criterion = self.sync_device(criterion)

    optimizer = optimizer_fn(model)
    optimizer = self.sync_device(optimizer)

    scheduler = scheduler_fn(optimizer)
    scheduler = self.sync_device(scheduler)

    return model, criterion, optimizer, scheduler
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
    # Create a nested FSDP-wrapped instance.
    model = NestedWrappedModule(group, config)
    model = FullyShardedDataParallel(model, group, **config).cuda()
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

    # Round trip state dict save/load/save.
    ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
    model.load_local_state_dict(ref_state_dict)
    state_dict = model.local_state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"

    for key in ref_state_dict.keys():
        assert objects_are_equal(
            ref_state_dict[key], state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
def _test_nested_wrapped_model(cls, rank, group, config=None):
    # Get reference state dict without any nested FSDP instances.
    model = NestedWrappedModule(group, None).cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
    ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

    # Create a nested FSDP-wrapped instance.
    if config["mixed_precision"]:
        config["compute_dtype"] = torch.float32
    model = NestedWrappedModule(group, config)
    model = FullyShardedDataParallel(model, group, **config).cuda()
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

    # Round-trip state dict save/load/save.
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    model.load_state_dict(state_dict)
    state_dict = model.state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"

    for key in ref_state_dict.keys():
        assert objects_are_equal(
            ref_state_dict[key], state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
def _test_memory_benchmark(self, rank, group, config):
    time_keeper = TimeKeeper()

    SIZE = 8 * 8
    time_keeper.print_time("START", 1.0)
    a = torch.empty(1)
    b = a.cuda()
    # wait for cuda to fully load
    time.sleep(1)
    time_keeper.print_time("INIT_CUDA", 1.0)
    model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
    time_keeper.print_time("CPU_MODEL", 1.0)

    with tempfile.TemporaryDirectory() as current_tempdir:
        config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)

        model = FullyShardedDataParallel(model, **config)
        time_keeper.print_time("FSDP_MODEL", 1.0)

        self._eval_for_several_steps(model, 1, autocast=False)
        time_keeper.print_time("EVAL")
def init_distributed_data_parallel_model(self):
    """
    Initialize FSDP if needed.

    This method overloads the ClassificationTask class's method from ClassyVision.
    """
    if not is_distributed_training_run():
        return

    # Make sure the default cuda device is set. TODO (Min): we should ensure FSDP can
    # be enabled for 1-GPU as well, but the use case there is likely different.
    # I.e. perhaps we use it for cpu_offloading.
    assert get_cuda_device_index() > -1, "Distributed training not setup correctly"

    # The model might be already wrapped by FSDP internally. Check regnet_fsdp.py.
    # Here, we wrap it at the outer most level.
    fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
    if is_primary():
        logging.info(f"Using FSDP, config: {fsdp_config}")

    # First, wrap the head's prototype_i layers if it is SWAV.
    # TODO (Min): make this more general for different models, which may have multiple
    #             heads.
    head0 = self.base_model.heads[0]
    if isinstance(head0, SwAVPrototypesHead):
        for j in range(head0.nmb_heads):
            module = getattr(head0, "prototypes" + str(j))
            module = FSDP(module=module, **fsdp_config)
            setattr(head0, "prototypes" + str(j), module)

    # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
    #             flag to true. We need to set it back to None here.
    #             Also, right now, the head's weight is only partially loaded from the
    #             checkpoint because we dump the checkpoint after the head is wrapped,
    #             but load it before it is wrapped.
    #             For very big models, we need to rework the checkpoint logic because we
    #             don't have enough memory to load the entire model on one node. We need
    #             to use the local_state_dict() API to load checkpoint shards.
    for module in self.base_model.trunk.modules():
        if isinstance(module, FSDP):
            module._is_root = None

    # Then, wrap the whole model. We replace the base_model since it is used
    # when the checkpoint is taken.
    self.base_model = FSDP(module=self.base_model, **fsdp_config)
    self.distributed_model = self.base_model
def __init__(self):
    super().__init__()
    self.ffn = nn.Sequential(
        nn.Linear(3, 3),
        FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3), maintain_forward_counter=True)),
        nn.Linear(3, 3),
    )
def _distributed_worker(
    gpu_id: int, with_fsdp: bool, sync_file: str, result_file: str
):
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl", init_method="file://" + sync_file, world_size=2, rank=gpu_id
    )

    # Create the inputs
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(8, 3, 224, 224)).cuda()

    # Create a fake model based on SWAV blocks
    config = TestRegnetFSDP._create_config(with_fsdp)
    model = build_model(config["MODEL"], config["OPTIMIZER"])
    model = model.cuda()
    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])
    criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    # Run a few iterations and collect the losses
    losses = []
    for iteration in range(5):
        out = model(batch)
        loss = criterion(out[0], torch.tensor(0.0).cuda())
        if gpu_id == 0:
            losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        if iteration <= 2:
            for name, param in model.named_parameters():
                if "prototypes" in name:
                    param.grad = None
        optimizer.step()

    # Store the losses in a file to compare several methods
    if gpu_id == 0:
        with open(result_file, "wb") as f:
            pickle.dump(losses, f)
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case): result = dist_init(rank, world_size, tempfile_name, unused) assert result, "Dist init failed" my_lr = 0.1 if test_case["assert_ref_out"]: with torch.no_grad(): # Compute one iteration local output. weight = model.weight.T.clone().cuda() v = torch.Tensor(test_case["inputs"][0][rank]).cuda() ref_forward_output_my_rank = torch.matmul(v, weight) # Compute one iteration global weight update. v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda() grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size) ref_weight_out = weight - grad.T * my_lr model.to("cuda") assert isinstance(fsdp_config, dict), str(fsdp_config) model = FSDP(model, **fsdp_config) optim = SGD(model.parameters(), lr=my_lr) inputs = test_case["inputs"] assert len(inputs) == 1 or not test_case["assert_ref_out"] assert len(inputs[0]) >= world_size for in_data in inputs: in_data = Tensor(in_data[rank]).cuda() out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() if test_case["assert_ref_out"]: with model.summon_full_params(): weight_out = model.module.weight.data.T.clone() # make sure we can do more fwd/bwd loss = model(in_data) loss.sum().backward() if test_case["assert_ref_out"]: torch.testing.assert_allclose(ref_forward_output_my_rank, out) torch.testing.assert_allclose(ref_weight_out, weight_out) model.assert_state(TrainingState.IDLE) teardown()
def is_valid_fsdp_model(model: FSDP) -> bool:
    """
    Checks if an FSDP model is valid by looking at the sub-FSDP modules
    and ensuring that they do not think they are the root FSDP model.
    """
    for n, m in model.named_modules():
        if isinstance(m, FSDP):
            if n != "" and m._is_root is not None:
                return False
    return True
def fsdp_recursive_reset_lazy_init(fsdp_module: FSDP):
    """
    Before the first forward pass, an FSDP module might have been initialized, for
    instance by calling load_state_dict or load_local_state_dict to reload a previous
    training checkpoint.

    This function will recursively walk through the sub-FSDP modules and call
    _reset_lazy_init on each of them.
    """
    for module in fsdp_module.modules():
        if isinstance(module, FSDP) and module._is_root is not None:
            module._reset_lazy_init()
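# A hedged sketch combining the two helpers above (the checkpoint-loading flow is an
# assumption, for illustration only): after reloading a sharded checkpoint, reset the
# lazy-init state and verify that no inner wrapper believes it is the root FSDP module.
def _load_checkpoint_and_validate(fsdp_model: FSDP, state_dict) -> None:
    fsdp_model.load_local_state_dict(state_dict)
    fsdp_recursive_reset_lazy_init(fsdp_model)
    assert is_valid_fsdp_model(fsdp_model), "an inner FSDP wrapper thinks it is the root"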
def _test_identical_outputs(
    cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
    if config["mixed_precision"]:
        autocast = True
        # Force the compute dtype to be torch.float32 so that we get
        # identical results as PyTorch DDP when using autocast. Note that
        # this will cause the all-gather to happen in FP32, which is slower
        # than necessary in most cases.
        config["compute_dtype"] = torch.float32
    else:
        autocast = False

    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrapper_config=None).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank, process_group=group
        )
    else:
        model = ref_ddp_fn(model, group)
    ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
    ref_state_dict = model.module.state_dict()

    # Confirm we get the same behavior using FullyShardedDataParallel.
    model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
    if use_cuda:
        model = model.cuda()
    else:
        assert next(model.parameters()).device == torch.device("cpu")
    shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
    shard_state_dict = model.state_dict()

    try:
        torch.testing.assert_allclose(ref_loss, shard_loss)
        assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
    except (AssertionError, RuntimeError) as e:
        raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")