def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
    super().__init__(group, wrapper_config)
    self.group = group
    self.delay_before_free_ms = delay_before_free_ms

    # "expert" params are different on each rank
    torch.manual_seed(42 + group.rank())
    d_expert = 23
    d_shared = 12
    d_input = 8
    expert = nn.Linear(d_expert, d_shared)

    self.num_expert_params = sum([p.numel() for p in expert.parameters()])
    for p in expert.parameters():
        p.expert = True

    # everything else is shared
    torch.manual_seed(0)

    shared = nn.Linear(d_shared, d_expert)

    if checkpoint_act:
        expert = checkpoint_wrapper(expert)
        shared = checkpoint_wrapper(shared)

    if wrapper_config is not None:
        # we create a process group of size 1 for the expert params
        expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
        expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
        shared = FullyShardedDataParallel(shared, group, **wrapper_config)

    self.module = nn.Sequential(
        nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input)
    )
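The `p.expert = True` markers tag the per-rank expert parameters so that later code can distinguish them from shared ones. A minimal sketch of one way such a flag could be consumed, e.g. to build separate optimizer parameter groups; the `model` instance and the learning rates below are assumptions for illustration, not part of the original snippet:

# Sketch only: `model` is assumed to be an instance of the module built above.
expert_params = [p for p in model.parameters() if getattr(p, "expert", False)]
shared_params = [p for p in model.parameters() if not getattr(p, "expert", False)]
optimizer = torch.optim.SGD(
    [
        {"params": expert_params, "lr": 0.01},  # per-rank expert params
        {"params": shared_params, "lr": 0.1},   # params shared across ranks
    ]
)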
def test_checkpoint_disabling():
    """Test the disable_checkpointing() context manager added alongside checkpoint_wrapper."""

    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            self.cnt += 1
            y = []
            for i in x:
                y.append(self.linear(i))
            return y

    x = torch.rand(4, 2)
    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())

    # Forward. cnt += 1
    y = model1(x)
    y = sum(i.sum() for i in y)
    # Backward recomputes the forward. cnt += 1
    y.backward()
    assert model1.cnt == 2

    with disable_checkpointing():
        # Forward. cnt += 1
        y = model2(x)
        y = sum(i.sum() for i in y)
        # Backward. cnt remains the same since checkpointing is disabled.
        y.backward()
    assert model2.cnt == 1
def __init__(self, multiout=False, checkpoint_config=0):
    super().__init__()
    torch.manual_seed(0)  # make sure weights are deterministic.
    self.multiout = multiout

    self.conv1 = nn.Sequential(nn.Conv2d(1, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))
    self.conv2 = nn.Sequential(nn.Conv2d(3, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))

    assert 0 <= checkpoint_config <= 3
    if checkpoint_config & 1:
        self.conv1 = checkpoint_wrapper(self.conv1)
    if checkpoint_config & (1 << 1):
        self.conv2 = checkpoint_wrapper(self.conv2)
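The checkpoint_config argument is read as a two-bit mask. A minimal sketch of the four possible configurations; `Model` is a hypothetical name for the surrounding class, which is not shown in this snippet:

# 0 -> no checkpointing, 1 -> conv1 only, 2 -> conv2 only, 3 -> both conv blocks
for cfg in range(4):
    m = Model(checkpoint_config=cfg)  # `Model` is a hypothetical class name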
def __init__(self, flatten, mixed_precision, fsdp_wrap_ckpt):
    super().__init__()
    if fsdp_wrap_ckpt:
        middle_module = FSDP(
            checkpoint_wrapper(nn.Linear(3, 3)),
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
        )
    else:
        middle_module = checkpoint_wrapper(
            FSDP(nn.Linear(3, 3), flatten_parameters=flatten, mixed_precision=mixed_precision)
        )
    self.ffn = nn.Sequential(nn.Linear(3, 3), middle_module, nn.Linear(3, 3))
def test_list_input():
    """Test checkpointing with input argument type being a list.

    Note: Testing shows that PyTorch's torch.utils.checkpoint function
    does not pass this test.
    """
    count = 0

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Linear(2, 2)

        def forward(self, x):
            nonlocal count
            count += 1
            y = []
            for i in x:
                y.append(self.conv(i))
            return y

    model = nn.Sequential(checkpoint_wrapper(Model()), Model()).cuda()
    in_data1 = torch.rand(4, 2).cuda()
    in_data2 = torch.rand(4, 2).cuda()

    # Forward. Count should be 2 for 2 modules.
    out = model([in_data1, in_data2])
    loss = sum(x.sum() for x in out)
    assert count == 2, f"Incorrect count {count}"

    # Backward. Adds 1 more forward call due to checkpoint.
    loss.backward()
    assert count == 3, f"Incorrect count {count}"
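For contrast with the note in the docstring above, a minimal sketch (not part of the original test) of the equivalent call through torch.utils.checkpoint; the exact behavior depends on the PyTorch version, so any failure is only reported, not asserted:

from torch.utils.checkpoint import checkpoint as torch_checkpoint

inner = nn.Linear(2, 2)
inputs = [torch.rand(4, 2), torch.rand(4, 2)]
try:
    # The list is passed as a single non-tensor argument, which is the case the
    # fairscale wrapper handles and plain torch.utils.checkpoint reportedly does not.
    out = torch_checkpoint(lambda xs: [inner(i) for i in xs], inputs)
    sum(o.sum() for o in out).backward()
except Exception as exc:
    print(f"torch.utils.checkpoint did not handle list I/O: {exc}")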
def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
    super().__init__()
    self.rank = group.rank()
    self.world_size = group.size()
    self.wrapper_config = wrapper_config

    def _maybe_wrap(layer):
        if wrapper_config is not None:
            return FullyShardedDataParallel(layer, group, **wrapper_config)
        return layer

    torch.manual_seed(0)  # keep everything deterministic
    self.module = nn.Sequential(
        nn.Linear(8, 4),
        _maybe_wrap(
            nn.Sequential(
                _maybe_wrap(nn.Linear(4, 16)),
                nn.Linear(16, 16),
            )
        ),
        _maybe_wrap(nn.Linear(16, 4)),
        nn.Linear(4, 8),
    )

    # Wrapping all modules triggers a corner case where the root FSDP doesn't have any params.
    # Test it with checkpoint_wrapper as well to validate that the final backward callback
    # is queued correctly when the root FSDP has no params and every layer is
    # wrapped as FSDP(checkpoint(module)).
    if wrap_everything:
        if checkpoint:
            self.module = nn.Sequential(
                _maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
                _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
                _maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
                _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
            )
        else:
            self.module = nn.Sequential(
                _maybe_wrap(nn.Linear(8, 4)),
                _maybe_wrap(nn.Linear(4, 16)),
                _maybe_wrap(nn.Linear(16, 4)),
                _maybe_wrap(nn.Linear(4, 8)),
            )
def __init__(self):
    super().__init__()
    self.ffn = nn.Sequential(
        nn.Linear(3, 3),
        FullyShardedDataParallel(
            checkpoint_wrapper(nn.Linear(3, 3), maintain_forward_counter=True)
        ),
        nn.Linear(3, 3),
    )
def test_deprecated_path():
    # Check if import works as before.
    # from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
    from fairscale.nn import checkpoint_wrapper

    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32))
    ffn = checkpoint_wrapper(ffn, {})

    # Check if direct import works as before.
    ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32))
    ffn = deprecated_checkpoint_wrapper(ffn, {})
def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
    super().__init__()
    torch.manual_seed(0)  # make sure weights are deterministic.
    assert not (
        use_pytorch_checkpoint and use_fairscale_checkpoint
    ), "Cannot use both pytorch and fairscale checkpointing mechanisms."
    self.use_pytorch_checkpoint = use_pytorch_checkpoint
    self.ffn = nn.Sequential(
        nn.Linear(32, 128),
        # add a Dropout layer to test RNG save/restore
        nn.Dropout(p=0.5),
        nn.Linear(128, 32),
    )
    if use_fairscale_checkpoint:
        self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
    self.out = nn.Linear(32, 1)
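The forward pass is not part of this snippet. A minimal sketch, assuming the stored use_pytorch_checkpoint flag is meant to route the ffn block through PyTorch's torch.utils.checkpoint (the fairscale path is already baked into self.ffn by checkpoint_wrapper above):

import torch.utils.checkpoint

def forward(self, x):
    # Assumption: when use_pytorch_checkpoint is set, apply PyTorch's checkpoint API;
    # otherwise self.ffn may already be a fairscale checkpoint_wrapper.
    if self.use_pytorch_checkpoint:
        x = torch.utils.checkpoint.checkpoint(self.ffn, x)
    else:
        x = self.ffn(x)
    return self.out(x)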
def __init__(self, enable_checkpoint=False, cpu_offload=False):
    super().__init__()
    torch.manual_seed(0)  # make sure weights are deterministic.
    # These numbers are picked to show cpu_offload memory saving.
    # Inner (recomputed) activation sizes need to be just right
    # to show the benefit.
    self.layers = nn.Sequential(
        nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8)),
        nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4)),
        nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2)),
    )
    if enable_checkpoint:
        for i, layer in enumerate(self.layers):
            # Only the middle layer needs to have offloading
            self.layers[i] = checkpoint_wrapper(layer, cpu_offload if i == 1 else False)
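A minimal usage sketch. The class name `Model` and its forward pass (assumed to simply apply self.layers) are assumptions about the surrounding class, which is not shown in this snippet:

# Sketch only: `Model` is a hypothetical name for the class containing this __init__.
baseline = Model(enable_checkpoint=False)
offloaded = Model(enable_checkpoint=True, cpu_offload=True)

x = torch.rand(2, 4)
loss = offloaded(x).sum()
loss.backward()  # checkpointed blocks recompute here; the middle block's saved inputs come back from CPU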
def get_model(norm_type, checkpointed, mixed_precision):
    assert norm_type in NORM_TYPES, norm_type
    assert checkpointed in [True, False], checkpointed
    assert mixed_precision in MP_TYPES

    model = Sequential(Linear(3, 2), norm_type(2))

    if mixed_precision == "fp16":
        # Set param.data and buffers as fp16
        for p in model.parameters():
            p.data = p.data.half()
        for m in model:
            for n, b in m.named_buffers():
                setattr(m, n, b.half())
    elif mixed_precision == "call_half":
        model.half()

    if checkpointed:
        model = checkpoint_wrapper(model)

    return model
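NORM_TYPES and MP_TYPES are module-level constants not shown in this snippet; assuming they include, for example, torch.nn.LayerNorm and the string "fp16", a call might look like the following:

# Assumed values for illustration only; the real NORM_TYPES / MP_TYPES are defined
# elsewhere in the test module.
model = get_model(nn.LayerNorm, checkpointed=True, mixed_precision="fp16")
# Parameters are now fp16 and the whole Sequential is wrapped for activation checkpointing.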
def test_checkpoint_requires_grad():
    """Test to check checkpointing when outputs do not require gradient."""

    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            self.cnt += 1
            return self.linear(x)

    x = torch.rand(4, 2)
    model = nn.Sequential(
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
    )
    model[0].requires_grad_(False)
    model[1].requires_grad_(False)
    model[2].requires_grad_(False)

    y = model(x)
    y = y.sum()
    y.backward()

    # Since only the last model needs grad, we only run the forward twice for it.
    assert model[0].cnt == 1
    assert model[1].cnt == 1
    assert model[2].cnt == 1
    assert model[3].cnt == 2

    # Now test with the first model needing grad.
    model = nn.Sequential(
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
    )
    model[0].requires_grad_(True)
    model[1].requires_grad_(False)
    model[2].requires_grad_(False)

    y = model(x)
    y = y.sum()
    y.backward()

    # Since the first model needs grad, all models need grad, so we run the forward twice for all.
    assert model[0].cnt == 2
    assert model[1].cnt == 2
    assert model[2].cnt == 2
    assert model[3].cnt == 2

    # Stress test with multiple inputs/outputs, some of which are not Tensors.
    class TestModel2(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x, y, z):
            self.cnt += 1
            z = z + [self.cnt]
            return self.linear(x + y), z, ["hi"]

    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())
    model3 = checkpoint_wrapper(TestModel2())
    model4 = checkpoint_wrapper(TestModel())
    model1.requires_grad_(False)
    model2.requires_grad_(False)

    y = model4(model3(model1(x), model2(x), ["bye"])[0])
    y = y.sum()
    y.backward()

    assert model1.cnt == 1
    assert model2.cnt == 1
    assert model3.cnt == 2
    assert model4.cnt == 2

    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())
    model3 = checkpoint_wrapper(TestModel2())
    model4 = checkpoint_wrapper(TestModel())
    model2.requires_grad_(False)

    y = model4(model3(model1(x), model2(x), ["bye"])[0])
    y = y.sum()
    y.backward()

    assert model1.cnt == 2
    assert model2.cnt == 1
    assert model3.cnt == 2
    assert model4.cnt == 2

    # Test flattened parameters.
    model = nn.Sequential(
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
    )
    model[0].requires_grad_(False)
    model[1].requires_grad_(False)

    y = model(x)
    y = y.sum()
    y.backward()

    assert model[0].cnt == 1
    assert model[1].cnt == 1
    assert model[2].cnt == 2
    assert model[3].cnt == 2