        check_offload = m != fsdp_only_seq and i == 0 and offload_activations
        if m == fsdp_call_checkpoint:
            # _save_on_cpu should not be called yet
            self.assertFalse(_save_on_cpu_called)
            offload_ctx = (
                get_patched_save_on_cpu()(pin_memory=True)
                if offload_activations
                else contextlib.suppress()
            )
            with offload_ctx:
                out = checkpoint(m, inp)
        else:
            # _save_on_cpu should not be called yet
            self.assertFalse(_save_on_cpu_called)
            out = m(inp)
        if check_offload:
            self.assertTrue(_save_on_cpu_called)
        loss = out.sum()
        loss.backward()
        losses.append(loss)
        outputs.append(out)
        _save_on_cpu_called = False

        self._verify_parity(losses, outputs, models)


instantiate_parametrized_tests(TestFSDPCheckpoint)

if __name__ == "__main__":
    run_tests()
        model(*input)
        self.assertTrue(model._register_post_backward_hooks.called)
        self.assertTrue(model._register_pre_backward_hooks.called)


class TestNoGrad(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_transformer_no_grad(self):
        group = dist.distributed_c10d._get_default_group()
        model = self._get_wrapped_model(group, cuda_first=False)
        # Train model for a step
        self._train_for_several_steps(model, num_steps=1, autocast=False)

        model.eval()  # no dropout for this test

        # Eval in standard mode (i.e., without no_grad)
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Eval with no_grad and compare
        with torch.no_grad():
            no_grad_output = model(*input)

        self.assertEqual(ref_output, no_grad_output)


instantiate_parametrized_tests(TestHooks)

if __name__ == "__main__":
    run_tests()
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '6789'
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        # test send
        input_tensor = torch.zeros(2, 2)
        dist.send(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 1)

        # test recv
        input_tensor = torch.zeros(2, 2)
        dist.recv(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 2)

        dist.barrier()
        # intentionally not calling into `destroy_process_group` as not all
        # user applications would explicitly call it.


instantiate_parametrized_tests(CommonDistributedDataParallelTest)

if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
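
# Hedged sketch (not part of the test above): how a custom backend such as
# "dummy" is typically made resolvable by `init_process_group`. The real
# backend lives in a separate extension; `_create_dummy_pg` is a hypothetical
# placeholder, and a real implementation would construct and return a working
# ProcessGroup instead of raising.
import torch.distributed as dist


def _create_dummy_pg(store, rank, world_size, timeout):
    raise NotImplementedError("placeholder; a real backend returns a ProcessGroup")


# After registration, `dist.init_process_group("dummy", ...)` can find it.
dist.Backend.register_backend("dummy", _create_dummy_pg)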
        for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask):
            mt_result = fn(*mt_args, **kwargs)
            t_result = fn(*t_args, **kwargs)
            _compare_mt_t(mt_result, t_result)

    @parametrize("fn_name", ["add", "add_"])
    def test_masks_match(self, fn_name):
        torch.random.manual_seed(0)
        fn = getattr(torch.ops.aten, fn_name)
        data0, data1, mask = self._get_test_data(fn_name)
        mask0 = mask
        mask1 = torch.rand(mask.size()) > 0.5
        mt0 = masked_tensor(data0, mask0)
        mt1 = masked_tensor(data1, mask1)
        try:
            fn(mt0, mt1)
            raise AssertionError()
        except ValueError as e:
            assert (
                "Input masks must match. If you need support for this, "
                "please open an issue on Github."
                == str(e)
            )


instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)

if __name__ == '__main__':
    run_tests()
        num_reduce_scatter = mock_reduce_scatter.call_count
        # previous non-sync iteration does not free full parameters for
        # the root instance.
        if use_no_sync and i == 0:
            expected_num_all_gather_sync_updated = expected_num_all_gather_sync - 1
            # previous non-sync iteration does not free full parameters
            if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                expected_num_all_gather_sync_updated = 0
        else:
            expected_num_all_gather_sync_updated = expected_num_all_gather_sync
            # no need to all_gather shards in the backward pass when in
            # SHARD_GRAD_OP mode
            if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                expected_num_all_gather_sync_updated = num_fsdp
        self.assertEqual(
            num_all_gather,
            expected_num_all_gather_sync_updated,
            f"Expected {expected_num_all_gather_sync_updated} all-gathers "
            f"but saw {num_all_gather} all-gathers when not using "
            "`no_sync()`",
        )
        self.assertEqual(
            num_reduce_scatter,
            expected_num_reduce_scatter_sync,
            f"Expected {expected_num_reduce_scatter_sync} reduce-"
            f"scatters but saw {num_reduce_scatter} reduce-scatters "
            "when not using `no_sync()`",
        )


instantiate_parametrized_tests(TestCommunication)

if __name__ == "__main__":
    run_tests()
        )
        model.register_buffer("buffer", torch.ones(1))
        # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
        # if called on a non-FSDP root module
        fsdp_model = FSDP(
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                deterministic=True,
            ),
            self.process_group,
        )
        fsdp_model.register_buffer("buffer", torch.ones(1))
        with FSDP.summon_full_params(fsdp_model):
            for call in ["named_parameters", "named_buffers"]:
                for (n1, p1), (n2, p2) in itertools.zip_longest(
                    getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                    getattr(model, call)(prefix=prefix, recurse=recurse),
                ):
                    self.assertEqual(n1, n2)
                    self.assertEqual(p1, p2)


instantiate_parametrized_tests(TestSummonFullParams)
instantiate_parametrized_tests(TestSummonFullParamsNoShard)

if __name__ == "__main__":
    run_tests()
        def __init__(self, t) -> None:
            self.tensor: torch.Tensor = t

        __torch_function__ = torch._C._disabled_torch_function_impl

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            def unwrap(e) -> torch.Tensor:
                if isinstance(e, NonRewrappingTensor):
                    t = e.tensor
                    return t
                else:
                    return e

            r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
            # Return an unwrapped tensor no longer of original subclass type.
            return r

    with self.assertRaisesRegex(
        RuntimeError, r"requires that detach\(\) returns an instance of the same type"
    ):
        param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))


instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()
    )
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None],
    )
    def test_no_sync(
        self,
        num_iters_to_acc: int,
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """Tests the ``no_sync()`` context manager."""
        assert num_iters_to_acc >= 2, \
            "Accumulate for at least 2 iterations to be nontrivial"
        self._test_no_sync(
            batch_dim=1,
            num_iters_to_acc=num_iters_to_acc,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )


instantiate_parametrized_tests(TestNoSync)

if __name__ == "__main__":
    run_tests()
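
# Hedged usage sketch of the pattern `_test_no_sync` exercises (illustrative
# only; the batches and optimizer are stand-ins): all but the last backward
# run inside `no_sync()`, so gradients accumulate locally, and the final
# backward outside the context triggers the usual gradient synchronization.
import contextlib

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _accumulate_then_sync(fsdp_model: FSDP, batches, optim) -> None:
    for i, batch in enumerate(batches):
        is_last = i == len(batches) - 1
        ctx = contextlib.nullcontext() if is_last else fsdp_model.no_sync()
        with ctx:
            fsdp_model(batch).sum().backward()
    optim.step()
    optim.zero_grad()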
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload):
        # DDP
        ddp_state = self._dist_train(wrap_fsdp=False)
        # FSDP
        fsdp_state = self._dist_train(wrap_fsdp=True, cpu_offload=cpu_offload)
        self.assertEqual(ddp_state, fsdp_state)


instantiate_parametrized_tests(TestPureFP16)

if __name__ == "__main__":
    run_tests()
        self.assertTrue(new_tensor.reloaded)

    def test_tensor_subclass_deepcopy(self):
        wrapped_tensor = torch.rand(2)
        my_tensor = TestWrapperSubclass(wrapped_tensor)

        foo_val = "bar"
        my_tensor.foo = foo_val
        self.assertEqual(my_tensor.foo, foo_val)

        new_tensor = deepcopy(my_tensor)

        self.assertIsInstance(new_tensor, TestWrapperSubclass)
        self.assertEqual(new_tensor.elem, my_tensor.elem)
        self.assertEqual(new_tensor.foo, foo_val)

    @parametrize('requires_grad', (True, False))
    def test_cloned_deepcopy(self, requires_grad):
        my_tensor = torch.rand(2, requires_grad=requires_grad, device='meta')
        new_tensor = deepcopy(my_tensor)

        self.assertEqual(new_tensor.requires_grad, my_tensor.requires_grad)


instantiate_device_type_tests(TestBothSerialization, globals())
instantiate_parametrized_tests(TestSubclassSerialization)

if __name__ == '__main__':
    run_tests()
        Args:
            pass_ignored_modules_to_root (bool): If ``False``, does not pass
                any ignored modules (including those already ignored in child
                FSDP instances) to the root FSDP instance; if ``True``,
                passes all ignored modules (representing a superset of the
                children's ignored modules) to the root FSDP instance.
        """
        # To exercise different `FlatParameter` enumerations across ranks, we
        # wrap `layer3` with FSDP, where `layer3` is registered as a module
        # after `layer1`, which has a variable number of ignored modules
        model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
        layer1_ignored_modules = [
            m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
        ]
        model.layer1 = FSDP(model.layer1, ignored_modules=layer1_ignored_modules)
        model.layer3 = FSDP(model.layer3)
        model_ignored_modules = [
            m for m in model.modules() if isinstance(m, IgnoredModule)
        ] if pass_ignored_modules_to_root else []
        wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
        self._train_model(wrapped_model, optim, 3)


instantiate_parametrized_tests(TestFSDPIgnoredModules)

if __name__ == "__main__":
    run_tests()
            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()


instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    run_tests()
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests gradient accumulation. This exercises gradient accumulation
        inside and outside the ``no_sync()`` context manager, in particular
        by interleaving the two. It tests both interleaving starting with
        (and ending with, resp.) inside versus outside ``no_sync()`` to
        ensure that initial conditions (and final conditions, resp.) do not
        affect the correctness.

        This test also checks for compatibility with the CPU offload and
        backward prefetch options.

        NOTE: Gradient accumulation without using the ``no_sync()`` context
        manager is not currently compatible with CPU offloading, so those
        tests are vacuous.
        """
        self._test_grad_acc(
            batch_dim=1,
            configs=configs.configs,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )


instantiate_parametrized_tests(TestGradAcc)

if __name__ == "__main__":
    run_tests()
        path = paths[0]
        writer = FileSystemWriter(path)
        reader = FileSystemReader(path)
        with FSDP.state_dict_type(
            model, state_dict_type
        ), FSDP.state_dict_type(new_model, state_dict_type):
            state_dict = model.state_dict()
            save_state_dict(state_dict, writer)

        with FSDP.state_dict_type(
            model, state_dict_type
        ), FSDP.state_dict_type(new_model, state_dict_type):
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
            model
        ), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)

        # TODO: add resharding test case.


instantiate_parametrized_tests(TestDistributedCheckpoint)

if __name__ == "__main__":
    run_tests()
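
# For context, a hedged sketch of the imports this test assumes from the
# (then-experimental) distributed checkpoint package; the module path may
# differ across PyTorch versions:
from torch.distributed._shard.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)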
expected[f"iter {iteration}: after fwd"] = ( 340 + sharded_model_size_mb ) expected[f"iter {iteration}: after loss"] = ( 340 + sharded_model_size_mb ) expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1 # sharded model size + sharded grad size + optimizer states + 1M temp memory expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1 # grad memory is claimed after setting grad = None # sharded model size + optimizer states + 1M temp memory expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1 # Get the fsdp and checkpoint flags. with_ckpt = ckpt == "ckpt" self._dist_train( with_ckpt, expected, model_hidden_dim, iterations, ) instantiate_parametrized_tests(TestFSDPMemory) if __name__ == "__main__": run_tests()
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {
            "module.layer.a": torch.tensor(1),
            "abc.layer.def": torch.tensor(2),
            "module.layer.b": torch.tensor(3),
        }
        _replace_by_prefix(state_dict, "module.layer.", "layer.")
        assert state_dict == original_state_dict

    def test_packed_sequence(self):
        """Test to ensure RNN packed sequences are modified correctly."""
        rnn = nn.RNN(5, 5)

        x = torch.rand((5, 1, 5), dtype=torch.float)
        seq_length = torch.tensor([4], dtype=torch.int)

        def fill_fn(x):
            x.fill_(0)

        x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
        x, h = rnn(x)
        x = _apply_to_tensors(fill_fn, x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        self.assertEqual(torch.sum(x), 0)


instantiate_parametrized_tests(TestUtils)

if __name__ == "__main__":
    run_tests()
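
# Hedged sketch (an assumption, not the actual implementation) of the
# recursion `_apply_to_tensors` performs over the containers exercised in
# `test_packed_sequence`: for a `PackedSequence`, the function is applied to
# its `.data` tensor and the sequence object itself is returned.
import torch
from torch.nn.utils.rnn import PackedSequence


def _apply_to_tensors_sketch(fn, container):
    if torch.is_tensor(container):
        return fn(container)
    if isinstance(container, PackedSequence):
        fn(container.data)
        return container
    if isinstance(container, (list, tuple)):
        return type(container)(_apply_to_tensors_sketch(fn, x) for x in container)
    if isinstance(container, dict):
        return {k: _apply_to_tensors_sketch(fn, v) for k, v in container.items()}
    return container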
        # Shard the non-wrapped model's re-keyed optimizer state dict, which
        # maps back to (flattened) parameter IDs
        sharded_osd = FSDP.shard_full_optim_state_dict(
            rekeyed_osd, model1, optim_input1,
        )
        # Check that this sharded optimizer state dict matches the wrapped
        # model's per-rank optimizer state dict
        osd1 = optim1.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            sharded_osd, osd1, check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd, osd1, check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim1.load_state_dict(sharded_osd)
        self._step_model(model1, optim1, num_iters=NUM_ITERS)


instantiate_parametrized_tests(TestFSDPOptimState)

if __name__ == "__main__":
    run_tests()
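
# Hedged sketch of the end-to-end flow the assertions above verify, using the
# public FSDP optimizer-state APIs (the import location of `OptimStateKeyType`
# may vary across versions; the models and optimizer are stand-ins):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import OptimStateKeyType


def _rekey_then_shard(wrapped_model, wrapped_optim, nonwrapped_model):
    # Consolidate the wrapped model's optimizer state (unflattened, keyed by
    # parameter name).
    full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
    # Re-key from parameter names to parameter IDs so that the state dict
    # matches a non-wrapped model's optimizer format.
    rekeyed_osd = FSDP.rekey_optim_state_dict(
        full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model
    )
    # Shard back so each rank's wrapped optimizer can load it.
    return FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)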
        self.assertEqual(
            execution_info.module_to_execution_infos[model.layer1],
            [(model.layer1, list(model.layer1.named_parameters()))],
        )
        self.assertEqual(
            execution_info.module_to_execution_infos[model.layer2],
            [
                (model.layer2[0], list(model.layer2[0].named_parameters())),
                (model.layer2[2], list(model.layer2[2].named_parameters())),
            ],
        )
        self.assertEqual(execution_info.module_to_execution_infos[model.relu], [])
        # test tracer.param_exec_order
        correct_param_order = [
            model.layer0.weight,
            model.layer0.bias,
            model.layer2[0].weight,
            model.layer2[2].weight,
            model.weight1,
            model.layer1.weight,
            model.weight2,
        ]
        self.assertEqual(execution_info.param_exec_order, correct_param_order)


instantiate_parametrized_tests(TestSymbolicTracing)

if __name__ == "__main__":
    run_tests()
                src_lengths=None,
                with_triangle_mask=False,
                incremental_state=incr_state,
            )
            for i in range(1, seqlen + 1)
        ]
        ref_output = torch.stack(ref_outputs)

        incr_key_lst = []
        incr_value_lst = []
        results = []
        for i in range(1, seqlen + 1):
            res, incr_key_lst, incr_value_lst = better_decoder(
                tokens[:, :i],
                src_mask=None,
                include_padding_mask=False,
                incr_key_lst=incr_key_lst,
                incr_value_lst=incr_value_lst,
                is_incremental_decoding=True,
            )
            results.append(res)
        result = torch.stack(results)

        self.assertEqual(result.shape, ref_output.shape)
        torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)


instantiate_parametrized_tests(TestTransformers)

if __name__ == '__main__':
    run_tests()
            if attr.name in ("then_branch", "else_branch"):
                self.assertEqual(expected_output_type, attr.g.output[0].type)

    def test_uninitialized_optional(self):
        class Module(torch.nn.Module):
            def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
                if y is not None:
                    if y.shape[1] < 5:
                        if y.size(0) == 1:
                            y = y + 4
                        else:
                            return y
                return y

        y = torch.ones((3, 4), dtype=torch.int)
        torch.onnx.export(
            torch.jit.script(Module()),
            y,
            io.BytesIO(),
            opset_version=15,
            dynamic_axes={"y": {0: "y0", 1: "y1"}},
            input_names=["y"],
        )


instantiate_parametrized_tests(TestOptionalOutput)

if __name__ == "__main__":
    unittest.main()
@parametrize("freezing_method", [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]) @parametrize("freeze_after_wrap_fsdp", [True, False]) def test_freezing_weights(self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp): # DDP ddp_state = self._dist_train(with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp=False) # FSDP fsdp_state = self._dist_train(with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp=True) self.assertEqual( ddp_state, fsdp_state, exact_device=True, msg= "FullyShardedDataParallel states didn't match PyTorch DDP states", ) instantiate_parametrized_tests(TestFreezingWeights) if __name__ == "__main__": run_tests()
fsdp_kwargs["mixed_precision"] = None fsdp_model = TransformerWithSharedParams.init( self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_AFTER, fsdp_kwargs, ) self._train_for_several_steps( fsdp_model, num_steps=1, autocast=False, mixed_precision=fsdp_kwargs["mixed_precision"] ) input = fsdp_model.module.get_input(torch.device("cuda")) # Run a forward in eval mode fsdp_model.eval() ref_output = fsdp_model(*input) # Run a forward in `no_grad()` and compare with torch.no_grad(): no_grad_output = fsdp_model(*input) self.assertEqual(ref_output, no_grad_output) instantiate_parametrized_tests(TestHooks) instantiate_parametrized_tests(TestParityWithDDP) instantiate_parametrized_tests(TestNoGrad) instantiate_parametrized_tests(TestParamInit) if __name__ == "__main__": run_tests()
        # buffers still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)

        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(
                tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr()
            )
            self.assertEqual(
                sd1[prefixed_tensor_name].data_ptr(),
                sd2[prefixed_tensor_name].data_ptr(),
            )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_type(self):
        module = SkipModel(double_nest=True)
        with enable_wrap(wrapper_cls=FSDP):
            fsdp = wrap(module)
        with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
            pass
        for module in FSDP.fsdp_modules(fsdp):
            self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)


instantiate_parametrized_tests(TestFSDPStateDict)

if __name__ == "__main__":
    run_tests()
        inp = fsdp_model.module.get_input(self.device)
        output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(self.device)
        fsdp_model.module.run_backward(loss)
        # Match the warning message with the following prefix
        regex = (
            "^(All-gather order differs from that of the first iteration "
            f"on rank {self.rank} -- collectives are unchecked and may give "
            "incorrect results or hang)"
        )
        context = (
            self.assertWarnsRegex(
                expected_warning=UserWarning,
                expected_regex=regex,
            )
            if self.rank != 0
            else suppress()
        )
        if self.rank != 0:
            fsdp_model.flip_path()
        inp = fsdp_model.module.get_input(self.device)
        # Expect a warning for the forward pass all-gather
        with context:
            output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(self.device)
        fsdp_model.module.run_backward(loss)
        # Run an additional iteration to check that there are no more warnings
        inp = fsdp_model.module.get_input(self.device)
        output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(self.device)
        fsdp_model.module.run_backward(loss)


instantiate_parametrized_tests(TestFSDPExecOrder)

if __name__ == "__main__":
    run_tests()
        m = MyModel(self.rank).cuda()
        _validate(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing `sync_module_states=True` makes the model the same across
        # ranks during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _validate(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)

        # `sync_module_states=True` also works with a CPU module when
        # `device_id` is passed in.
        m = MyModel(self.rank)
        _validate(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing `sync_module_states=True` makes the model the same across
        # ranks during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _validate(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)


instantiate_parametrized_tests(TestFSDPMisc)

if __name__ == "__main__":
    run_tests()
    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_no_reshard_after_forward(self):
        # Note that we do not exercise all possible configs here to avoid
        # increasing test time-to-signal (TTS) too much.
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        )

    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_e2e_full_shard(self):
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
        )


instantiate_parametrized_tests(TestFSDPMixedPrecisionSharded)

if __name__ == "__main__":
    run_tests()
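
# Hedged sketch of what a config like `mp_diff_buffer_and_reduce` plausibly
# looks like (the exact dtypes are defined earlier in the file; this only
# illustrates the `MixedPrecision` fields being exercised):
import torch
from torch.distributed.fsdp import MixedPrecision

_example_mp = MixedPrecision(
    param_dtype=torch.float16,    # dtype parameters are cast to for compute
    reduce_dtype=torch.bfloat16,  # gradient reduction dtype (requires NCCL bf16)
    buffer_dtype=torch.float32,   # buffers kept in a separate dtype
)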
        self.assertTrue(model._register_post_backward_hooks.called)
        self.assertTrue(model._register_pre_backward_hooks.called)


class TestNoGrad(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_transformer_no_grad(self):
        group = dist.distributed_c10d._get_default_group()
        model = self._get_wrapped_model(group, cuda_first=False)
        # Train model for a step
        self._train_for_several_steps(model, num_steps=1, autocast=False)

        model.eval()  # no dropout for this test

        # Eval in standard mode (i.e., without no_grad)
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Eval with no_grad and compare
        with torch.no_grad():
            no_grad_output = model(*input)

        self.assertEqual(ref_output, no_grad_output)


instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)

if __name__ == "__main__":
    run_tests()
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),

            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),

            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {'a': 0, 'b': 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),

            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = tree_flatten(to_pytree)
            result = _broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))


instantiate_parametrized_tests(TestPytree)

if __name__ == '__main__':
    run_tests()
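
# A worked example in the spirit of the cases above (same private helpers):
from torch.utils._pytree import tree_flatten, _broadcast_to_and_flatten

_, _spec = tree_flatten((0, [0, 0]))
assert _broadcast_to_and_flatten((1, [2, 3]), _spec) == [1, 2, 3]  # same structure
assert _broadcast_to_and_flatten(1, _spec) == [1, 1, 1]            # scalar broadcast
assert _broadcast_to_and_flatten((1, (2, 3)), _spec) is None       # tuple vs. list mismatch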
os.environ["MASTER_PORT"] = str(find_free_port()) torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1) # NOTE: We move model to CUDA after init with FSDP to simulate real use # cases where full model cannot be loaded onto GPU, but their shards can. cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER try: sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=(not cuda_after_init)) my_auto_wrap_policy = functools.partial( default_auto_wrap_policy, min_num_params=40 ) model = FSDP(sequential, cpu_offload=cpu_offload, fsdp_auto_wrap_policy=my_auto_wrap_policy) TestFSDPWrap.NestedSequentialModel.verify_model(self, model) if cuda_after_init: model = model.cuda() input = torch.rand((1, 5), dtype=torch.float).to(device) output = model(input) loss = F.mse_loss(input, output) loss.backward() finally: torch.distributed.destroy_process_group() del os.environ["MASTER_ADDR"] del os.environ["MASTER_PORT"] instantiate_parametrized_tests(TestFSDPWrap) instantiate_parametrized_tests(TestAutoWrap) if __name__ == "__main__": run_tests()
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_torchdistX,
        )

    def _test_bad_arg(self, meta_module_fn):
        mod = meta_module_fn()
        with self.assertRaisesRegex(ValueError, "to be callable"):
            FSDP(mod, param_init_fn=42)

    @skip_if_lt_x_gpu(2)
    @sandcastle_skip_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_bad_arg_torchdistx(self):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, "cuda")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_bad_arg_meta(self):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_bad_arg(meta_module_fn)


instantiate_parametrized_tests(TestFSDPWithMetaDevice)

if __name__ == "__main__":
    run_tests()