def test_norm(device, norm_type, mixed_precision):
    """Compare a checkpoint_wrapper model against a plain one for a given norm layer type.

    Both models start from identical weights, take one SGD step on the same input,
    and must end with equal state dicts (checked only on torch >= 1.6.0).
    On CPU with a non-fp32 precision, the forward pass is expected to raise
    RuntimeError and the test ends right there.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Skip due to lack of GPU")

    # Draw the input first, then build the reference / checkpointed pair and sync weights.
    inputs = torch.rand(2, 2, 3, 3).to(device)
    ref_model = get_model(norm_type, False, mixed_precision).to(device)
    ckpt_model = get_model(norm_type, True, mixed_precision).to(device)
    ckpt_model.load_state_dict(ref_model.state_dict())
    if torch_version() >= (1, 6, 0):
        # This assert fails on 1.5.1.
        assert objects_are_equal(ref_model.state_dict(), ckpt_model.state_dict())

    if mixed_precision != "fp32":
        inputs = inputs.half()
    # Checkpointing needs an input that requires grad.
    inputs.requires_grad = True

    half_on_cpu = device == "cpu" and mixed_precision != "fp32"
    for model in (ref_model, ckpt_model):
        optim = SGD(model.parameters(), lr=0.1)
        if half_on_cpu:
            # Got: RuntimeError: "batch_norm"/"layer_norm" not implemented for 'Half'.
            with pytest.raises(RuntimeError):
                model(inputs)
            return
        # Everything else works.
        out = model(inputs)
        out.sum().backward()
        optim.step()

    if torch_version() >= (1, 6, 0):
        assert objects_are_equal(ref_model.state_dict(), ckpt_model.state_dict())
def rpc_worker(rank, world_size, init_file, func, *args):
    """Join a TensorPipe RPC group as ``"worker" + str(rank)``; rank 0 runs ``func(*args)``.

    On torch 1.8.0 exactly, transport/channel overrides work around upstream
    TensorPipe issues. Every rank calls ``rpc.shutdown()`` before returning.
    """
    init_method = "file://" + init_file
    if torch_version() != (1, 8, 0):
        options = rpc.TensorPipeRpcBackendOptions(init_method=init_method)
    elif torch.cuda.is_available():
        # Workaround for https://github.com/pytorch/pytorch/issues/53844
        options = rpc.TensorPipeRpcBackendOptions(init_method=init_method, _transports=["ibv", "uv"])
    else:
        # Workaround for https://github.com/pytorch/pytorch/issues/54266
        channels = ["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"]
        options = rpc.TensorPipeRpcBackendOptions(init_method=init_method, _channels=channels)

    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
    if rank == 0:
        func(*args)
    rpc.shutdown()
def test_smaller_than_world_size(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards.

    The last layer, ``Linear(1, 1, bias=False)``, holds a single weight, so when
    unflattened it has fewer elements than there are ranks (for world_size > 1).
    """
    import contextlib
    import os

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    # mkstemp() returns an open OS-level file descriptor along with the path. Only the
    # path is handed to the workers, so close each descriptor instead of leaking it.
    temp_paths = []
    for _ in range(2):
        fd, path = tempfile.mkstemp()
        os.close(fd)
        temp_paths.append(path)
    temp_file_name, unused = temp_paths

    try:
        model = Sequential(
            Linear(3, 3, bias=False),
            Linear(3, 4, bias=False),
            Linear(4, 5, bias=False),
            Linear(5, 4, bias=False),
            Linear(4, 3, bias=False),
            Linear(3, 1, bias=False),
            Linear(1, 1, bias=False),  # param here is smaller than world_size if unflattened.
        )
        mp.spawn(
            _test_func,
            args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
            nprocs=world_size,
            join=True,
        )
    finally:
        # Remove the files once all ranks have exited; tolerate them being gone already.
        for path in temp_paths:
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
def setUp(self):
    """Skip unless this environment can run NCCL-based tests on 2+ GPUs."""
    reason = None
    if torch_version() < (1, 6, 0):
        reason = "Need pytorch version >= 1.6 due to lack of reduce_scatter"
    elif not torch.cuda.is_available():
        reason = "CUDA not available, skipping test"
    elif sys.platform == "win32":
        reason = "NCCL doesn't support Windows, skipping test"
    elif torch.cuda.device_count() < 2:
        reason = "distributed tests require 2+ GPUs, skipping"
    if reason is not None:
        raise unittest.SkipTest(reason)
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
    """Run FSDP on the regnet model and compare against DDP reference results from ``ddp_ref``.

    ``ddp_ref`` unpacks to (state_before, inputs, conv_bias, linear_bias, state_after),
    where ``state_after`` is indexed by ``(precision, sync_bn)``.
    """
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, conv_bias, linear_bias, states_after = ddp_ref
    state_after = states_after[(precision, sync_bn)]

    mixed = precision == "mixed"
    fsdp_config = {
        "mixed_precision": mixed,
        "flatten_parameters": flatten == "flatten",
    }
    # With a linear bias, DDP's AMP O1 and FSDP's default AMP O1.5 are different,
    # so force FSDP into AMP O1 by setting compute_dtype to float32.
    if linear_bias:
        fsdp_config["compute_dtype"] = torch.float32
    if mixed and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN at random half of the time (the RNG is consumed unconditionally)...
    wrap_bn = random.randint(0, 1) != 0
    # ...except that mixed precision + sync_bn always wraps BN, due to an error in
    # sync_bn wrapping, regardless of compute_dtype.
    if mixed and sync_bn != "none":
        wrap_bn = True
    # Unwrapped BN (i.e. not kept in full precision) needs an fp32 compute_dtype to
    # match DDP; otherwise BN's running_mean/running_var buffers drift numerically.
    if mixed and not wrap_bn:
        fsdp_config["compute_dtype"] = torch.float32

    world_size = _world_size
    mp.spawn(
        _distributed_worker,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,
            state_after,
            sync_bn,
            conv_bias,
            linear_bias,
        ),
        nprocs=world_size,
        join=True,
    )
def test_basic(device):
    """Checkpointing variants must match plain training, and on CUDA also expected memory stats.

    Compares no checkpoint, PyTorch checkpoint, fairscale checkpoint and fairscale
    checkpoint with CPU offload on the same input. "loss" and "gnorm" must be exactly
    equal across all four; the remaining memory entries are checked only on CUDA.
    """
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    # Named ``input_tensor`` rather than ``input`` to avoid shadowing the builtin.
    input_tensor = torch.rand(2, 16, 32).requires_grad_(True)

    model = BasicModel().to(device)
    no_cpt = get_loss_and_gnorm(model, input_tensor.to(device))

    model = BasicModel(use_pytorch_checkpoint=True).to(device)
    pyt_cpt = get_loss_and_gnorm(model, input_tensor.to(device))

    model = BasicModel(use_fairscale_checkpoint=True).to(device)
    fairscale_cpt = get_loss_and_gnorm(model, input_tensor.to(device))

    model = BasicModel(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
    fairscale_cpt_offload = get_loss_and_gnorm(model, input_tensor.to(device))

    # Check for correctness. The failure message carries all four result dicts, so
    # no separate print is needed before failing.
    all_stats = (no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload)
    for key in ("loss", "gnorm"):
        assert no_cpt[key] == pyt_cpt[key] == fairscale_cpt[key] == fairscale_cpt_offload[key], (
            f"{key} mismatch: {no_cpt} {pyt_cpt} {fairscale_cpt} {fairscale_cpt_offload}"
        )
        # Drop the compared entry so only memory stats remain for the checks below.
        for stats in all_stats:
            del stats[key]

    # Check for memory usage for cuda only.
    if "cpu" in device:
        return

    mem_peaks = [98816, 103424, 103424, 107520]
    if torch_version() < (1, 7, 0):
        # Older torch behaves slightly differently
        mem_peaks = [102400, 103424, 103424, 107520]

    def expected_mem(mem_peak, mem_after_fwd):
        # All variants share the same starting and post-backward memory.
        return {"mem_0": 38912, "mem_peak": mem_peak, "mem_after_fwd": mem_after_fwd, "mem_after_bwd": 74240}

    assert no_cpt == expected_mem(mem_peaks[0], 64000), no_cpt
    assert pyt_cpt == expected_mem(mem_peaks[1], 43520), pyt_cpt
    assert fairscale_cpt == expected_mem(mem_peaks[2], 43520), fairscale_cpt
    assert fairscale_cpt_offload == expected_mem(mem_peaks[3], 43520), fairscale_cpt_offload
def test_it(fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU.

    Creates a 1-rank NCCL process group, runs 5 SGD steps on a wrapped
    ``Linear(4, 4)``, then tears the group down. MASTER_ADDR/MASTER_PORT are
    restored to their previous values (or removed if they were unset) even when
    process-group initialization fails.
    """
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # Remember any pre-existing rendezvous settings so they are restored rather than
    # deleted out from under the caller.
    saved_env = {key: os.environ.get(key) for key in ("MASTER_ADDR", "MASTER_PORT")}
    # Random port in case the next test runs quickly; the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    try:
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
        try:
            assert isinstance(fsdp_config, dict), str(fsdp_config)

            class Model(Module):
                def __init__(self):
                    super().__init__()
                    self.layer = Linear(4, 4)

                def forward(self, x):
                    if isinstance(x, list):
                        x = x[0]
                    else:
                        assert isinstance(x, dict), x
                        x = x["in"]
                    return self.layer(x)

            model = FSDP(Model(), **fsdp_config).cuda()
            optim = SGD(model.parameters(), lr=0.1)

            for _ in range(5):
                in_data = torch.rand(64, 4).cuda()
                in_data.requires_grad = True
                if input_cls is list:
                    in_data = [in_data]
                else:
                    assert input_cls is dict
                    in_data = {"in": in_data}

                out = model(in_data)
                out.sum().backward()
                optim.step()
                optim.zero_grad()

            model.assert_state(TrainingState.IDLE)
        finally:
            # Clean-up is important or the next test in this file may fail to init the PG.
            torch.distributed.destroy_process_group()
    finally:
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected):
    """Per-GPU worker: train for 3 iterations and assert the recorded memory stats equal ``expected``.

    The process rank equals ``gpu_id``. The model is wrapped with ``to_fsdp`` when
    ``with_fsdp`` is set, otherwise with DistributedDataParallel.
    """
    torch.cuda.set_device(gpu_id)
    rank = gpu_id
    init_ok = dist_init(rank, world_size, filename, filename_rpc)
    assert init_ok, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    net = create_model(with_fsdp, with_checkpoint).cuda()
    if not with_fsdp:
        net = DistributedDataParallel(net, device_ids=[gpu_id], bucket_cap_mb=500)
    else:
        net = to_fsdp(net)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=1e-4)

    results = {}
    for iteration in range(3):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        prediction = net(batch)
        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

        prediction = sum(o.sum() for o in prediction[0])
        fake_loss = criterion(prediction, torch.tensor(0.0).cuda())
        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # Drop gradients with `set_to_none` (or by hand on old torch) instead of
        # optimizer.zero_grad(), so their memory is actually reclaimed.
        if torch_version() < (1, 7, 0):
            for param in net.parameters():
                param.grad = None
        else:
            net.zero_grad(set_to_none=True)
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    assert results == expected, f"{results} but expected {expected}"
    teardown()
def test1(precision, flatten):
    """Run ``_test_func`` on 2 spawned ranks for the given precision/flatten setting.

    Some bugs only show up with world_size > 1, because sharding changes the
    tensor dimensions.
    """
    import contextlib
    import os

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # mkstemp() returns an open OS-level file descriptor along with the path. Only the
    # path is handed to the workers, so close each descriptor instead of leaking it.
    temp_paths = []
    for _ in range(2):
        fd, path = tempfile.mkstemp()
        os.close(fd)
        temp_paths.append(path)
    temp_file_name, unused = temp_paths

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    # Some bugs only show up when we are in world_size > 1 due to sharding changing
    # the tensor dimensions.
    world_size = 2
    try:
        mp.spawn(
            _test_func,
            args=(world_size, fsdp_config, temp_file_name, unused),
            nprocs=world_size,
            join=True,
        )
    finally:
        # Remove the files once all ranks have exited; tolerate them being gone already.
        for path in temp_paths:
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
    """Check that ddp, ddp+ckpt, fsdp and fsdp+ckpt all produce the same final losses.

    Each configuration is trained in ``world_size`` spawned processes. Every rank
    pickles its final loss into its own temp file, and the dict of per-rank losses
    from the first configuration is the reference for all later ones.
    """
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    # Use fp32 reduce-scatter only in mixed precision; otherwise leave it unset (None).
    fp32_reduce_scatter = mixed_precision or None

    if flatten and torch_version() < (1, 8, 0):
        # 1.6 and 1.7 throws this error:
        # RuntimeError: Trying to backward through the graph a second time, but the saved
        # intermediate results have already been freed. Specify retain_graph=True when calling
        # backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    for with_fsdp in (False, True):
        for with_checkpoint in (False, True):
            # Two files for dist_init, plus one per rank for its pickled losses.
            with temp_files_ctx(num=2 + world_size) as temp_files:
                worker_args = (
                    world_size,
                    with_fsdp,
                    with_checkpoint,
                    temp_files,
                    mixed_precision,
                    flatten,
                    wrap_bn,
                    fp32_reduce_scatter,
                )
                mp.spawn(_distributed_worker, worker_args, nprocs=world_size)

                final_losses = {}
                for rank in range(world_size):
                    with open(temp_files[2 + rank], "rb") as f:
                        final_losses[f"rank_{rank}"] = pickle.load(f)

            if expected_losses is None:
                expected_losses = final_losses
                continue
            print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
            assert objects_are_equal(expected_losses, final_losses, raise_exception=True)
def check(exp, res):
    """Assert ``res`` has the same keys (same order) as ``exp`` and numerically close values.

    Reads ``with_model2`` and ``mixed_precision`` from the enclosing scope to decide
    whether to loosen the tolerances.
    """
    exp_keys = list(exp.keys())
    res_keys = list(res.keys())
    assert exp_keys == res_keys, f"{exp_keys} vs. {res_keys}"

    # On CI, with the longer model2, mixed precision and torch >= 1.9, even ddp vs.
    # ddp+ckpt has larger errors, so the tolerances are loosened in that case.
    loose = with_model2 and mixed_precision and torch_version() >= (1, 9, 0)
    rtol, atol = (1e-3, 1e-4) if loose else (1e-4, 1e-5)

    for key, exp_loss in exp.items():
        torch.testing.assert_allclose(exp_loss, res[key], rtol=rtol, atol=atol)
def test1(temp_files, ddp_ref, precision, flatten):
    """Run FSDP and compare against DDP reference states from ``ddp_ref``.

    ``ddp_ref`` unpacks to (state_before, inputs, state_after_fp, state_after_mp);
    the full-precision or mixed-precision post-training state is picked from ``precision``.
    """
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, state_after_fp, state_after_mp = ddp_ref
    state_after = state_after_fp if precision == "full" else state_after_mp

    mixed = precision == "mixed"
    fsdp_config = {
        "mixed_precision": mixed,
        "flatten_parameters": flatten == "flatten",
    }
    if mixed and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # In full precision BN is wrapped at random half of the time (the RNG is always
    # drawn); in mixed precision it is always wrapped.
    wrap_bn = random.randint(0, 1) != 0 or mixed

    world_size = _world_size
    worker_args = (
        world_size,
        fsdp_config,
        wrap_bn,
        None,
        temp_files[0],
        temp_files[1],
        state_before,
        inputs,
        None,
        state_after,
    )
    mp.spawn(_test_func, args=worker_args, nprocs=world_size, join=True)
def test_one_iteration(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards.

    Spawns ``world_size`` ranks running ``_test_func`` on a single
    ``Linear(3, 3, bias=False)`` layer.
    """
    import contextlib
    import os

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    # mkstemp() returns an open OS-level file descriptor along with the path. Only the
    # path is handed to the workers, so close each descriptor instead of leaking it.
    temp_paths = []
    for _ in range(2):
        fd, path = tempfile.mkstemp()
        os.close(fd)
        temp_paths.append(path)
    temp_file_name, unused = temp_paths

    try:
        # TODO (Min): we may want to extend this to a simple 2 layer model so that it covers
        #             more cases in FSDP. Also, assert_ref_out can be extended to multiple
        #             iterations. This could be a good bootcamp task. I should file a github
        #             issue once we merge.
        model = Linear(3, 3, bias=False)
        mp.spawn(
            _test_func,
            args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
            nprocs=world_size,
            join=True,
        )
    finally:
        # Remove the files once all ranks have exited; tolerate them being gone already.
        for path in temp_paths:
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
def test(world_size, precision, flatten):
    """
    This test simulates wrapping the module after training to run inference.
    This is required in cases where later in a session, the model is wrapped again in FSDP but
    contains nested FSDP wrappers within the module.
    """
    import contextlib
    import os

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # mkstemp() returns an open OS-level file descriptor along with the path. Only the
    # path is handed to the workers, so close each descriptor instead of leaking it.
    temp_paths = []
    for _ in range(2):
        fd, path = tempfile.mkstemp()
        os.close(fd)
        temp_paths.append(path)
    temp_file_name, unused = temp_paths

    fsdp_config = {
        "mixed_precision": precision == "mixed",
        "flatten_parameters": flatten == "flatten",
    }
    try:
        mp.spawn(
            _test_func,
            args=(world_size, fsdp_config, temp_file_name, unused),
            nprocs=world_size,
            join=True,
        )
    finally:
        # Remove the files once all ranks have exited; tolerate them being gone already.
        for path in temp_paths:
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
return True return False model.train() train_output = model(input) assert find_grad_fn(train_output.grad_fn, "CheckpointBackward") assert find_grad_fn(train_output.grad_fn, "RecomputeBackward") model.eval() eval_output = model(input) assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward") assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") @torch_spawn([2]) @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) def checkpoint_non_float_input(pipe_class): class ForkNonFloat(nn.Module): def forward(self, input): return (input * 2, torch.tensor([False])) class JoinNonFloat(nn.Module): def forward(self, input): return input[0] * 2 model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) model = pipe_class( model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always", pipelined_backward=False, )
# limitations under the License. import pytest import torch from torch import nn from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader, Dataset from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version # Current on CI, there appears to be a bug with torch 1.8 # See: # https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112 # So we skip this file in that case until it is fixed. if torch_version() >= (1, 8, 0): pytestmark = pytest.mark.skip class MySGD(Optimizer): r""" Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float): learning rate (required) """ def __init__(self, params, lr=0.01): defaults = dict(lr=lr) super(MySGD, self).__init__(params, defaults)
import torch.multiprocessing as mp import torch.nn as nn from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe from fairscale.utils.testing import torch_version BOUNCE_TENSORS = True CPU_DEVICES = ["worker0/cpu", "worker1/cpu"] GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"] if torch.cuda.is_available(): DEVICES = [CPU_DEVICES, GPU_DEVICES] else: DEVICES = [CPU_DEVICES] pytestmark = pytest.mark.skipif(torch_version() < (1, 8, 0), reason="requires torch version >= 1.8.0") def rpc_worker(rank, world_size, init_file, func, *args): if torch_version() == (1, 8, 0): if torch.cuda.is_available(): # Workaround for https://github.com/pytorch/pytorch/issues/53844 options = rpc.TensorPipeRpcBackendOptions( init_method="file://" + init_file, _transports=["ibv", "uv"]) else: # Workaround for https://github.com/pytorch/pytorch/issues/54266 options = rpc.TensorPipeRpcBackendOptions( init_method="file://" + init_file, _channels=[ "mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth",
return True return False model.train() train_output = model(input) assert find_grad_fn(train_output.grad_fn, "CheckpointBackward") assert find_grad_fn(train_output.grad_fn, "RecomputeBackward") model.eval() eval_output = model(input) assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward") assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") @torch_spawn([2]) @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True) @pytest.mark.parametrize( "pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) def checkpoint_non_float_input(pipeline_style): class ForkNonFloat(nn.Module): def forward(self, input): return (input * 2, torch.tensor([False])) class JoinNonFloat(nn.Module): def forward(self, input): return input[0] * 2 model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected, model_hidden_dim, fsdp_config):
    """Per-GPU worker that trains for 4 iterations and checks memory usage against ``expected``.

    Memory snapshots are recorded with ``get_cur_mem`` into ``results`` at fixed
    points of every iteration. Each entry must then be within 1 of the matching
    ``expected`` entry (the nested ``cmp``; its comment calls this 1MB rounding).

    Args:
        gpu_id: CUDA device index; also used as this process's rank.
        world_size: number of participating processes.
        with_fsdp: wrap the model with ``to_fsdp`` if True, otherwise with DistributedDataParallel.
        with_checkpoint: forwarded to ``create_model``.
        filename, filename_rpc: rendezvous files passed to ``dist_init``.
        expected: dict of expected memory values, keyed exactly like ``results``.
        model_hidden_dim: forwarded to ``create_model``.
        fsdp_config: FSDP options; a truthy ``"mixed_precision"`` also enables CUDA autocast here.
    """
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    # Note that FSDP auto-casts the input in AMP mode, so we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set AMP context if needed. `contextlib.suppress()` with no exception types is used
    # as a no-op context manager for the non-AMP case.
    context = contextlib.suppress()
    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
        context = torch.cuda.amp.autocast(enabled=True)

    # We have observed that sometimes after 3rd iteration, 4th one can fail (not on this
    # test but on much bigger scale tests). We run 4 iterations here just in case it happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def cmp(results, expected):
        # Return a newline-separated report of entries differing by more than 1;
        # an empty string means everything matched.
        ret = ""
        assert results.keys() == expected.keys(), f"{list(results.keys())} vs. {list(expected.keys())}"
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    assert not output, output

    teardown()
import torch
from torch import nn

from fairscale.nn.model_parallel.initialize import (
    destroy_model_parallel,
    get_pipeline_parallel_group,
    initialize_model_parallel,
)
from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version

# Currently on CI, there appears to be a bug with torch 1.8
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if torch_version() >= (1, 8, 0):
    pytestmark = pytest.mark.skip


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def parameters(pipe_class):
    """With a single-stage pipeline, rank 0 must report parameters and every other rank none."""
    pipe = pipe_class(nn.Sequential(nn.Linear(1, 1)), balance=[1], worker_map=get_worker_map(), chunks=1)
    is_rank_zero = torch.distributed.get_rank() == 0
    owned = list(pipe.parameters())
    if is_rank_zero:
        assert owned != []
    else:
        assert owned == []
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer], change_train_graph: bool = False):
    """Train one model with OSS + DDP and a copy with the plain optimizer + DDP; they must stay identical.

    Uses ``rank``, ``device``, ``world_size``, ``in_channels``, ``hidden``,
    ``out_channels``, ``batch`` and ``RECIPIENT_RANK`` from the surrounding scope.

    Args:
        optimizer: the optimizer *class* (not an instance) to compare, e.g. ``torch.optim.SGD``.
        change_train_graph: if True, flip ``requires_grad`` on the first parameter of
            both models after the initial training steps.
    """
    # Any model works. Add one different buffer per rank
    trunk = torch.nn.Sequential(
        torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)
    )
    trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
    trunk.to(device)

    head = torch.nn.Linear(hidden, out_channels).to(device)

    # Define a model to be trained by OSS
    oss_module = torch.nn.Sequential(trunk, head)

    # Make sure that the param groups are interleaved, to catch an ordering bug in the state dict
    oss_trainable_params = [
        {"params": list(trunk.parameters())[:-1] + list(head.parameters()), "lr": 1e-5},
        {"params": list(trunk.parameters())[-1], "lr": 1e-4},
    ]

    optimizer_settings: Dict[Any, Any] = {}
    # `optimizer` is a class, so the previous `isinstance(optimizer, torch.optim.SGD)` was
    # always False and SGD momentum was never exercised. Test the class itself instead.
    if isinstance(optimizer, type) and issubclass(optimizer, torch.optim.SGD):
        optimizer_settings["momentum"] = 0.9

    sharded_optimizer = optim.OSS(
        params=oss_trainable_params,
        optim=optimizer,
        group=None,
        broadcast_buffer_size=2**10,
        **optimizer_settings,
    )

    oss_ddp_model = DDP(module=oss_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    # Define a model to be trained by normal pytorch + DDP
    ddp_trunk = copy.deepcopy(trunk)
    ddp_head = copy.deepcopy(head)
    ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)

    ddp_trainable_params = [
        {"params": list(ddp_trunk.parameters())[:-1] + list(ddp_head.parameters()), "lr": 1e-5},
        {"params": list(ddp_trunk.parameters())[-1], "lr": 1e-4},
    ]
    ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings)  # type: ignore
    ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    def check_step():
        # One optimizer step on both setups with the same random input; losses and
        # parameters must match afterwards.
        input_tensor = torch.rand((batch, in_channels)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = oss_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim, rtol=1e-3
        ), f"Losses differ in between Pytorch optim and OSS\n {loss_ddp.item()} - {loss_sharded_optim.item()} - world size {world_size}"

        check_same_model_params(oss_ddp_model, ddp_model)

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(oss_ddp_model, ddp_model)

    # The models should stay the same in between ddp and sharded optimizer
    for _ in range(5):
        check_step()

    # Check that altering the trainable parameters does not cause DDP and OSS to diverge
    if change_train_graph:
        # Flip the first parameter from trainable to non-trainable and vice-versa
        next(ddp_module.parameters()).requires_grad = not next(ddp_module.parameters()).requires_grad
        next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad

    # Check that the checkpoints are compatible (post pytorch 1.5). Compare the full version
    # tuple: looking only at the minor number would wrongly skip this on e.g. torch 2.x.
    if torch_version() >= (1, 6, 0):
        # - get states
        ddp_state_dict = ddp_optimizer.state_dict()
        sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
        sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
        sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)

        # - cross load the states
        # run one step and check that the models are still the same
        ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
        ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
        sharded_optimizer.load_state_dict(ddp_state_dict)
        check_step()

        # - self load, rewind, check no problem
        # run one step and check that the models are still the same
        ddp_optimizer.load_state_dict(ddp_state_dict_ref)
        sharded_optimizer.load_state_dict(sharded_optim_state_dict)
        check_step()
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn import MOELayer, Top2Gate
from fairscale.utils.testing import torch_version

pytestmark = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required"
)

devices = ["cuda"]


def pg_worker(rank, world_size, init_file, func, *args):
    """Run ``func(*args)`` inside an NCCL process group rendezvousing through ``init_file``.

    The process is bound to CUDA device ``rank``. A one-element all_reduce runs
    before ``func`` (presumably to bring up the NCCL communicator up front — confirm),
    and the process group is destroyed afterwards.
    """
    init_method = "file://" + init_file
    dist.init_process_group(
        backend=dist.Backend.NCCL,
        rank=rank,
        world_size=world_size,
        init_method=init_method,
    )
    torch.cuda.set_device(rank)
    warmup = torch.zeros(1).cuda()
    dist.all_reduce(warmup)
    func(*args)
    dist.destroy_process_group()
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn_type):
    """Check that fsdp and fsdp+ckpt final losses match the ddp (no checkpoint) baseline.

    ddp+ckpt is not run. Results come from ``_get_cached_results``, and the first
    configuration computed (ddp, no checkpoint) is the reference.
    """
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None
    with_model2 = model_type == "model2"
    with_sync_bn = bn_type == "sync_bn"

    if (1, 7, 0) <= torch_version() < (1, 8, 0) and with_sync_bn:
        # SyncBN is buggy in 1.7, errors like:
        # E         File "/home/circleci/venv/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 13, in forward
        # E           dtype=running_mean.dtype,
        # E       AttributeError: 'NoneType' object has no attribute 'dtype'
        pytest.skip("SyncBatchNorm in 1.7 is buggy")

    if with_sync_bn and not wrap_bn:
        pytest.skip("SyncBatchNorm requires auto_wrap_bn")

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throws this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    def check(exp, res):
        # Same keys in the same order, and values close within the tolerances.
        exp_keys, res_keys = list(exp.keys()), list(res.keys())
        assert exp_keys == res_keys, f"{exp_keys} vs. {res_keys}"
        # On CI, with the longer model2, mixed precision and torch >= 1.9, even ddp vs.
        # ddp+ckpt has larger errors, so loosen the tolerances there.
        if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
            rtol, atol = 1e-3, 1e-4
        else:
            rtol, atol = 1e-4, 1e-5
        for key, exp_loss in exp.items():
            torch.testing.assert_allclose(exp_loss, res[key], rtol=rtol, atol=atol)

    world_size = 2
    expected_losses = None
    # Ensure ddp == fsdp == fsdp+ckpt (the ddp+ckpt combination is skipped).
    for with_fsdp in (False, True):
        for with_checkpoint in (False, True):
            if with_checkpoint and not with_fsdp:
                continue
            final_losses = _get_cached_results(
                world_size,
                with_model2,
                with_sync_bn,
                with_fsdp,
                with_checkpoint,
                mixed_precision,
                flatten,
                wrap_bn,
                fp32_reduce_scatter,
            )
            if expected_losses is None:
                expected_losses = final_losses
                continue
            print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt")
            check(expected_losses, final_losses)
def run(compute_cycles, all_gather_cycles):
    """Time 20 forward/backward passes of a synthetic FSDP model with a delayed all_gather.

    ``torch.distributed.all_gather`` is patched during the forward pass. The patch
    sleeps ``all_gather_cycles`` GPU cycles (``torch.cuda._sleep``) and then calls
    ``orig_all_gather``. Uses ``fsdp_config``, ``world_size`` and ``orig_all_gather``
    from the surrounding scope.

    Args:
        compute_cycles: forwarded to ``_create_model``; presumably the fake compute
            cost per layer — confirm against ``_create_model``.
        all_gather_cycles: GPU sleep cycles injected per all_gather call. When 0, the
            model is built without params (``has_params`` is False).

    Returns:
        dict with the ``.avg()`` of the ``Min10`` trackers under "cpu_iter", "cpu_wait",
        "gpu_compute" and "gpu_total". The cpu_* values come from ``time.process_time``
        differences. "gpu_total" is ``e1.elapsed_time(e2)``. "gpu_compute" is the sum of
        ``Layer.get_time()`` values (units not visible here).
    """
    has_params = all_gather_cycles > 0
    model = _create_model(fsdp_config, compute_cycles, has_params)

    # Create the input and set its requires_grad to True because
    # we have a fake compute in the forward pass.
    batch = torch.rand(1).cuda()
    batch.requires_grad = True

    # We run 20 iterations but only collect timing data from the minimal 10
    # data points because nondeterministic system events can disturb the timing.
    cpu_iter = Min10()
    cpu_wait = Min10()
    gpu_compute = Min10()
    gpu_total = Min10()
    for _ in range(20):
        # Get two events for measuring the overall time.
        e1 = Event(enable_timing=True)
        e2 = Event(enable_timing=True)

        cpu_start = time.process_time()

        all_gather_called = False

        def _delayed_all_gather(*args, **kwargs):
            # Record the call and delay the GPU stream before doing the real all_gather.
            nonlocal all_gather_called
            all_gather_called = True
            torch.cuda._sleep(all_gather_cycles)
            return orig_all_gather(*args, **kwargs)

        # forward pass
        #
        # Even though both e1 & e2 are on the compute stream, since
        # compute depends on all_gather, e2-e1 includes all_gather time.
        e1.record()
        with patch("torch.distributed.all_gather", _delayed_all_gather):
            out = model(batch)
            # all_gather should only be hit when there are params to gather across >1 rank.
            if has_params and world_size > 1:
                assert all_gather_called
            else:
                assert not all_gather_called
        e2.record()

        # backward pass
        out.backward()
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None

        cpu_iter_time = time.process_time() - cpu_start

        # wait for gpu
        out.item()
        cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

        # get sum of the compute time
        times = []
        for mod in model.modules():
            if not isinstance(mod, Layer):
                continue
            times.append(mod.get_time())

        # get gpu compute + all_gather time
        overall_gpu_time = e1.elapsed_time(e2)

        cpu_iter.add(cpu_iter_time)
        cpu_wait.add(cpu_wait_for_gpu_time)
        gpu_compute.add(sum(times))
        gpu_total.add(overall_gpu_time)

    del model
    return {
        "cpu_iter": cpu_iter.avg(),
        "cpu_wait": cpu_wait.avg(),
        "gpu_compute": gpu_compute.avg(),
        "gpu_total": gpu_total.avg(),
    }
import torch.distributed.rpc as rpc import torch.multiprocessing as mp import torch.nn as nn from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph from fairscale.utils.testing import torch_version CPU_DEVICES = ["worker0/cpu", "worker1/cpu"] GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"] if torch.cuda.is_available(): DEVICES = [CPU_DEVICES, GPU_DEVICES] else: DEVICES = [CPU_DEVICES] pytestmark = pytest.mark.skipif(torch_version() < (1, 9, 0), reason="requires torch version >= 1.9.0") def rpc_worker(rank, world_size, init_file, func, *args): options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file) for i in range(world_size): options.set_device_map("worker" + str(i), {rank: i}) rpc.init_rpc( "worker" + str(rank), rank=rank, world_size=world_size, backend=rpc.BackendType.TENSORPIPE, rpc_backend_options=options, ) if rank == 0: func(*args)
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
):
    """Check that ShardedDataParallel + OSS trains identically to PyTorch DDP + SGD.

    This runs inside one process of a ``world_size`` job. It builds one model and
    wraps it in ``ShardedDataParallel`` with an ``OSS`` optimizer. A deep copy of
    the model is wrapped in ``DDP`` with a plain SGD optimizer. Both are trained
    for ``NUMBER_BATCHS`` steps on the same inputs, and after every step the
    function asserts that their parameters still match.

    Args:
        rank: this process' rank; also used as its CUDA device index.
        world_size: number of processes in the group.
        backend: ``torch.distributed`` backend name.
        temp_file_name: path of the file used for the ``file://`` rendezvous.
        reduce_buffer_size: forwarded to ``ShardedDataParallel``.
        grad_accumulation: when true, each step accumulates 3 micro-batches.
        change_train_graph: when true, flip ``requires_grad`` of the first
            parameter after the first step.
        fp16_reduction: reduce gradients in fp16 (``reduce_fp16`` for ShardedDDP,
            ``fp16_compress_hook`` for DDP).
        clip_grad_norm: also compare the gradient norms computed by both stacks.
        amp: run forward under ``torch.cuda.amp.autocast`` with gradient scalers.
        manual_reduction: for ShardedDDP, run the last micro-batch under
            ``no_sync()`` and call ``reduce()`` explicitly.
    """
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5
    BATCH_SIZE = 8

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}"
        + f" - change train graph {change_train_graph}"
        + f" - amp {amp}"
        + f" - manual reduction {manual_reduction}"
        + f" - buffers {reduce_buffer_size}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1

        model.zero_grad()

        def step():
            if scaler is not None:
                # Only the forward pass and loss go under autocast; backward passes
                # under autocast are not recommended by the PyTorch AMP docs.
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        # All but the last micro-batch accumulate locally without gradient sync.
        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        # `dist` is only a local alias of torch.distributed and cannot be imported
        # from; the hooks live in `ddp_comm_hooks` (note the double "m").
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

        # Default arguments bind this iteration's input into each closure.
        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norm are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0, error_if_nonfinite=False
                )  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(zip(ddp_model.parameters(), sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank, idx, torch.norm(p_ddp.grad), torch.norm(p_sdp.grad), flush=True)
                            else:
                                print(rank, idx, torch.norm(p_ddp.grad), "not owned", flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(
                sharded_ddp_model.parameters()
            ).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()