def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class InnerModel(Module):
        def __init__(self):
            super().__init__()
            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config))

        def forward(self, x):
            return self.layers(x)

    inner_model = InnerModel()
    model = FSDP(inner_model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()

    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    model.assert_state(TrainingState.IDLE)

    # Second time, to rewrap the inner model.
    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
    rewrapped_output = rewrapped_model(input)

    assert torch.allclose(output, rewrapped_output)

    teardown()

def _test_module_state_dict(cls, config, rank, group):
    ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
    autocast = ddp_model.mixed_precision
    cls._train_for_several_steps(ddp_model, 2, autocast)
    state_1 = ddp_model.state_dict()
    # You must make a new FSDP instance to use module.load_state_dict.
    unwrapped_model = TransformerWithSharedParams(group)
    unwrapped_model.load_state_dict(state_1)
    new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
    cls._train_for_several_steps(new_ddp_model, 2, autocast)
    try:
        ddp_model.load_state_dict(new_ddp_model.state_dict())
        assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
    except Exception:
        pass

def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
    # Create a nested FSDP-wrapped instance.
    model = NestedWrappedModule(group, config)
    model = FSDP(model, group, **config).cuda()
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

    # Round-trip state dict save/load/save.
    ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
    model.load_local_state_dict(ref_state_dict)
    state_dict = model.local_state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
    for key in ref_state_dict.keys():
        assert objects_are_equal(
            ref_state_dict[key], state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"

def fsdp_wrapper(module, **kwargs):
    """Custom FSDP wrapper, adding the missing options."""
    from vissl.utils.layer_memory_tracking import ProcessGroupTracker

    # Add the global process group to the config.
    fsdp_config = dict(**kwargs)
    fsdp_config["process_group"] = get_global_group()
    if fsdp_config.get("_TRACK_COMMUNICATIONS", False):
        fsdp_config["process_group"] = ProcessGroupTracker(fsdp_config["process_group"])

    # Remove keys that are not supported by FSDP.
    for key in {"_TRACK_COMMUNICATIONS", "AUTO_WRAP_THRESHOLD"}:
        fsdp_config.pop(key, None)

    return FSDP(module, **fsdp_config)

def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             This will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    # Note: different ranks may wrap in a different order due to different random
    # seeds, but the results should be the same.
    if random.randint(0, 1) == 0:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        model = auto_wrap_bn(model)
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = auto_wrap_bn(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()

def _freeze_distributed_worker(
    gpu_id,
    world_size,
    tempfile_name,
    unused,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    # The use case for this test is where the weights in the submodule
    # are not frozen but the leftover weights or those contained by the
    # root module are frozen. Refer to issue #758 for a real world example.
    model = FreezeModel()
    model = model.cuda()
    for param in model.head.parameters():
        param.requires_grad = False
    model = FSDP(model)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()

    teardown()

def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing),
        flatten_parameters=outer_flat,
    )
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()

def _distributed_worker(
    gpu_id: int, with_fsdp: bool, sync_file: str, result_file: str
):
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl", init_method="file://" + sync_file, world_size=2, rank=gpu_id
    )

    # Create the inputs.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(8, 3, 224, 224)).cuda()

    # Create a fake model based on SWAV blocks.
    config = TestRegnetFSDP._create_config(with_fsdp)
    model = build_model(config["MODEL"], config["OPTIMIZER"])
    model = model.cuda()
    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])
    criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    # Run a few iterations and collect the losses.
    losses = []
    for iteration in range(5):
        out = model(batch)
        loss = criterion(out[0], torch.tensor(0.0).cuda())
        if gpu_id == 0:
            losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        if iteration <= 2:
            for name, param in model.named_parameters():
                if "prototypes" in name:
                    param.grad = None
        optimizer.step()

    # Store the losses in a file to compare several methods.
    if gpu_id == 0:
        with open(result_file, "wb") as f:
            pickle.dump(losses, f)

def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = weight - grad.T * my_lr

    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size

    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if test_case["assert_ref_out"]:
        with model.summon_full_params():
            weight_out = model.module.weight.data.T.clone()
        # Make sure we can do more fwd/bwd.
        loss = model(in_data)
        loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()

def test_pre_backward_hook(temp_files):
    """Test FSDP with a model that triggers a pre_backward hook bug."""
    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.l1 = Linear(4, 4).cuda()
            self.l2 = FSDP(Linear(4, 4).cuda())
            self.l3 = Linear(4, 4).cuda()

        def forward(self, x):
            x = self.l1(x)
            x = self.l2(x)
            inner_result = x
            x = self.l3(x)
            return x, inner_result

        def assert_and_clear_grad(self):
            for p in self.parameters():
                assert p.shape in [(4, 4), (4,), (4 * 4 + 4,)], p.shape
                assert p.grad is not None
                p.grad = None

    model = FSDP(Model(), flatten_parameters=False).cuda()
    in_data = torch.rand(1, 4).cuda()
    for _ in range(3):
        out, _ = model(in_data)
        out.sum().backward()
        model.assert_and_clear_grad()

    teardown()

def _test_nested_wrapped_model(cls, rank, group, config=None):
    # Get reference state dict without any nested FSDP instances.
    model = NestedWrappedModule(group, None).cuda()
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[rank], output_device=rank, process_group=group
    )
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
    ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

    # Create a nested FSDP-wrapped instance.
    if config["mixed_precision"]:
        config["compute_dtype"] = torch.float32
    model = NestedWrappedModule(group, config)
    model = FSDP(model, group, **config).cuda()
    cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

    # Round-trip state dict save/load/save.
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    model.load_state_dict(state_dict)
    state_dict = model.state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
    for key in ref_state_dict.keys():
        assert objects_are_equal(
            ref_state_dict[key], state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"

def __init__(self):
    super().__init__()
    self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config))

def __init__(self):
    super().__init__()
    self.l1 = Linear(4, 4).cuda()
    self.l2 = FSDP(Linear(4, 4).cuda())
    self.l3 = Linear(4, 4).cuda()

def init_distributed_data_parallel_model(self):
    """
    Initialize FSDP if needed.

    This method overrides the ClassificationTask class's method from ClassyVision.
    """
    if not is_distributed_training_run():
        return

    # Make sure the default cuda device is set. TODO (Min): we should ensure FSDP can
    # be enabled for 1-GPU as well, but the use case there is likely different.
    # I.e. perhaps we use it for cpu_offloading.
    assert get_cuda_device_index() > -1, "Distributed training not setup correctly"

    # The model might be already wrapped by FSDP internally. Check regnet_fsdp.py.
    # Here, we wrap it at the outermost level.
    fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
    if is_primary():
        logging.info(f"Using FSDP, config: {fsdp_config}")

    # First, wrap the head's prototype_i layers if it is SWAV.
    # TODO (Min): make this more general for different models, which may have
    #             multiple heads.
    if len(self.base_model.heads) != 1:
        raise ValueError(
            f"FSDP only supports 1 head, not {len(self.base_model.heads)} heads"
        )
    head0 = self.base_model.heads[0]
    if isinstance(head0, SwAVPrototypesHead):
        # This is important for convergence!
        #
        # Since we "normalize" this layer in the update hook, we need to keep its
        # weights in full precision. Its output goes into the loss and is used
        # for clustering, so we need to have that in full precision as well.
        fp_fsdp_config = fsdp_config.copy()
        fp_fsdp_config["flatten_parameters"] = False
        fp_fsdp_config["mixed_precision"] = False
        fp_fsdp_config["fp32_reduce_scatter"] = False
        for j in range(head0.nmb_heads):
            module = getattr(head0, "prototypes" + str(j))
            module = FSDP(module=module, **fp_fsdp_config)
            setattr(head0, "prototypes" + str(j), module)
    head0 = FSDP(module=head0, **fsdp_config)
    self.base_model.heads[0] = head0

    # Init the head properly since the weights are potentially initialized on different
    # ranks with different seeds. We first summon the full params from all workers.
    # Then, within that context, we set a fixed random seed so that all workers init the
    # weights the same way. Finally, we reset the layer's weights using reset_parameters().
    #
    # TODO (Min): This will go away once we have a way to sync from rank 0.
    with head0.summon_full_params():
        with set_torch_seed(self.config["SEED_VALUE"]):
            for m in head0.modules():
                if isinstance(m, Linear):
                    m.reset_parameters()
    head0._reset_lazy_init()
    head0.prototypes0._reset_lazy_init()

    # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
    # flag to true. We need to set it back to None here.
    # Also, right now, the head's weights are only partially loaded from the checkpoint
    # because we dump the checkpoint after the head is wrapped, but load it before
    # it is wrapped.
    # For very big models, we need to re-work the checkpoint logic because we don't have
    # enough memory to load the entire model on one node. We need to use the
    # local_state_dict() API to load checkpoint shards.
    for module in self.base_model.trunk.modules():
        if isinstance(module, FSDP):
            module._is_root = None

    # Then, wrap the whole model. We replace the base_model since it is used
    # when the checkpoint is taken.
    self.base_model = FSDP(module=self.base_model, **fsdp_config)
    self.distributed_model = self.base_model

def fsdp_wrap(self):
    self.trunk = FSDP(self.trunk)
    self.head = FSDP(self.head)

def to_fsdp(module, fsdp_config):
    return FSDP(module, process_group=get_process_group_cached(), **fsdp_config)

def fsdp_wrap(self):
    for name, child in self.trunk.named_children():
        wrapped_child = FSDP(child)
        setattr(self.trunk, name, wrapped_child)
    self.trunk = FSDP(self.trunk)
    self.head = FSDP(self.head)

def _create_model(with_fsdp):
    model = Model()
    if with_fsdp:
        model.trunk = FSDP(model.trunk)
        model.head = FSDP(model.head)
    return model

def to_fsdp(module):
    return FSDP(module, process_group=get_global_group())

def fsdp_wrapper(module, **kwargs):
    """Custom wrapper that does FSDP + checkpoint at the same time."""
    # TODO (Min): enable checkpoint_wrapper
    return FSDP(module, **kwargs)

def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap in a different order due to different random
        # seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()

    # Print the model for verification.
    if rank == 0:
        print(model)

    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure the final state equals state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # the rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()

def __init__(self):
    super().__init__()
    self.inner = FSDP(Linear(4, 4), **fsdp_config)
    self.outer = Linear(4, 5)

def _distributed_worker(
    gpu_id,
    world_size,
    with_fsdp,
    with_nested_trunk,
    freezing_method,
    tempfile_name,
    unused,
    rank_0_output,
    expected_state,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp, with_nested_trunk)
    model = model.cuda()

    # Freeze the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()

def _distributed_worker(
    gpu_id,
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    files,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    filename, filename_rpc = files[:2]
    filename_loss = files[2:]

    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    # Use False below to debug, since the error msg is not as good with cudnn.
    torch.backends.cudnn.enabled = True

    # These make things deterministic.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure we have multiple forward passes.
    batch = [
        torch.randn(size=(2, 3, 16, 16)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
    ]

    if mixed_precision and not with_fsdp:
        batch = [x.half() for x in batch]

    model = _create_model(
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    model = model.cuda()

    if with_fsdp:
        model = FSDP(
            model,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        )
        model.set_gradient_divide_factors(1.0, 2.0, True)
        no_sync_context = contextlib.suppress()
    else:
        # With DDP, we need no_sync and manual gradient reduction below because
        # it can't handle multiple forward pass + checkpointing otherwise.
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        no_sync_context = model.no_sync()

    mp_context = contextlib.suppress()
    if mixed_precision:
        mp_context = torch.cuda.amp.autocast(enabled=True)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    losses = {}
    i = 0
    with no_sync_context:
        for iteration in range(3):
            with mp_context:
                out = model(batch)
                loss = criterion(out, target)
                print("Loss", iteration, ":", loss.item())
                losses[f"iter_{i}"] = loss
                i += 1
                optimizer.zero_grad()
                loss.backward()
            # Manual grad reduction, no autocast.
            if not with_fsdp:
                for p in model.parameters():
                    dist.all_reduce(p.grad.data)
                    p.grad.data.div_(2.0)
            # Stepping, no autocast.
            optimizer.step()

    # Due to the dist.all_reduce code block above with ddp.no_sync, we seem to hit a bug
    # in DDP where tensor.cpu() and torch.save() calls both hang. FSDP is not affected.
    # Therefore, we have to compare losses here instead of states.
    with open(filename_loss[rank], "wb") as f:
        pickle.dump(losses, f)

    teardown()

def to_fsdp(module, fsdp_config):
    return FSDP(module, process_group=get_global_group(), **fsdp_config)