Example 1
                    check_offload = m != fsdp_only_seq and i == 0 and offload_activations
                    if m == fsdp_call_checkpoint:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.suppress()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)

instantiate_parametrized_tests(TestFSDPCheckpoint)

if __name__ == "__main__":
    run_tests()
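The test above patches `save_on_cpu` so it can record that the context was entered. In isolation, the offload context it exercises can be sketched as follows (a minimal sketch assuming a CUDA device, independent of FSDP and of the test's `get_patched_save_on_cpu` helper):

import torch
import torch.nn as nn

# Minimal sketch of the offload context: tensors saved for backward inside
# `save_on_cpu` are moved to (pinned) host memory and copied back to the GPU
# on demand during the backward pass. Assumes a CUDA device is available.
block = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).cuda()
inp = torch.randn(4, 8, device="cuda", requires_grad=True)

with torch.autograd.graph.save_on_cpu(pin_memory=True):
    out = block(inp)

out.sum().backward()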
Example 2
        model(*input)
        self.assertTrue(model._register_post_backward_hooks.called)
        self.assertTrue(model._register_pre_backward_hooks.called)


class TestNoGrad(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_transformer_no_grad(self):
        group = dist.distributed_c10d._get_default_group()
        model = self._get_wrapped_model(group, cuda_first=False)
        # Train model for a step
        self._train_for_several_steps(model, num_steps=1, autocast=False)

        model.eval()  # no dropout for this test

        # Eval in standard mode (i.e., without no_grad)
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Eval with no_grad and compare
        with torch.no_grad():
            no_grad_output = model(*input)

        self.assertEqual(ref_output, no_grad_output)


instantiate_parametrized_tests(TestHooks)

if __name__ == "__main__":
    run_tests()
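The parity check above needs FSDP and at least two GPUs, but the underlying invariant holds for any module: in eval mode, running under `torch.no_grad()` must not change the output; it only skips building the autograd graph. A minimal single-process sketch:

import torch
import torch.nn as nn

# Eval-mode parity check without FSDP or CUDA: dropout is disabled by eval(),
# so the two forward passes must produce identical outputs.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5)).eval()
x = torch.randn(2, 4)

ref_output = model(x)
with torch.no_grad():
    no_grad_output = model(x)

torch.testing.assert_close(ref_output, no_grad_output)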
Example 3
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '6789'
        dist.init_process_group("dummy",
                                rank=self.rank,
                                world_size=self.world_size)

        # test send
        input_tensor = torch.zeros(2, 2)
        dist.send(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 1)

        # test recv
        input_tensor = torch.zeros(2, 2)
        dist.recv(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 2)

        dist.barrier()
        # intentionally not calling into `destroy_process_group` as not all
        # user applications would explicitly call it.


instantiate_parametrized_tests(CommonDistributedDataParallelTest)

if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
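The snippet above registers a synthetic "dummy" backend whose `send`/`recv` mutate the tensor in place (hence the `+ 1` / `+ 2` assertions). For comparison, a hypothetical two-rank sketch of the same point-to-point calls against a real backend (gloo), launched with `torchrun --nproc_per_node=2`:

import torch
import torch.distributed as dist

# Rank and world size come from the environment variables set by torchrun.
dist.init_process_group("gloo")
rank = dist.get_rank()

t = torch.full((2, 2), float(rank))
if rank == 0:
    dist.send(t, dst=1)   # blocking send to rank 1
else:
    dist.recv(t, src=0)   # blocking receive from rank 0
    assert torch.equal(t, torch.zeros(2, 2))

dist.barrier()
dist.destroy_process_group()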
Example 4
        for (t_args,
             mt_args) in self._yield_sample_args(fn_name, data0, data1, mask):
            mt_result = fn(*mt_args, **kwargs)
            t_result = fn(*t_args, **kwargs)
            _compare_mt_t(mt_result, t_result)

    @parametrize("fn_name", ["add", "add_"])
    def test_masks_match(self, fn_name):
        torch.random.manual_seed(0)
        fn = getattr(torch.ops.aten, fn_name)
        data0, data1, mask = self._get_test_data(fn_name)
        mask0 = mask
        mask1 = torch.rand(mask.size()) > 0.5
        mt0 = masked_tensor(data0, mask0)
        mt1 = masked_tensor(data1, mask1)
        try:
            fn(mt0, mt1)
            raise AssertionError("expected ValueError for mismatched masks")
        except ValueError as e:
            assert (
                "Input masks must match. If you need support for this, please open an issue on Github."
                == str(e))


instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)

if __name__ == '__main__':
    run_tests()
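The `try`/`except` block above asserts on the exact error text. The same check is often written with `unittest`'s context-manager form; a self-contained sketch using a stand-in callable (the `fn` below is hypothetical, not the aten op from the test):

import unittest

class MaskMatchExample(unittest.TestCase):
    def test_error_message(self):
        def fn():
            # Stand-in for the aten op called on masked tensors with
            # mismatched masks.
            raise ValueError("Input masks must match. If you need support for "
                             "this, please open an issue on Github.")

        # Equivalent to the try/except/assert pattern above, but fails with a
        # clearer message if no exception (or the wrong one) is raised.
        with self.assertRaisesRegex(ValueError, "Input masks must match"):
            fn()

if __name__ == "__main__":
    unittest.main()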
Example 5
                num_reduce_scatter = mock_reduce_scatter.call_count
                # previous non-sync iteration does not free full parameters for
                # the root instance.
                if use_no_sync and i == 0:
                    expected_num_all_gather_sync_updated = expected_num_all_gather_sync - 1
                    # previous non-sync iteration does not free full parameters
                    if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                        expected_num_all_gather_sync_updated = 0
                else:
                    expected_num_all_gather_sync_updated = expected_num_all_gather_sync
                    # no need to all_gather shards in the backward pass when in
                    # SHARD_GRAD_OP mode
                    if sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                        expected_num_all_gather_sync_updated = num_fsdp
                self.assertEqual(
                    num_all_gather, expected_num_all_gather_sync_updated,
                    f"Expected {expected_num_all_gather_sync_updated} all-gathers "
                    f"but saw {num_all_gather} all-gathers when not using "
                    "`no_sync()`")
                self.assertEqual(
                    num_reduce_scatter, expected_num_reduce_scatter_sync,
                    f"Expected {expected_num_reduce_scatter_sync} reduce-"
                    f"scatters but saw {num_reduce_scatter} reduce-scatters "
                    "when not using `no_sync()`")


instantiate_parametrized_tests(TestCommunication)

if __name__ == "__main__":
    run_tests()
Example 6
        )
        model.register_buffer("buffer", torch.ones(1))
        # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
        # if called on a non-FSDP root module
        fsdp_model = FSDP(
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                deterministic=True,
            ),
            self.process_group,
        )
        fsdp_model.register_buffer("buffer", torch.ones(1))
        with FSDP.summon_full_params(fsdp_model):
            for call in ["named_parameters", "named_buffers"]:
                for (n1, p1), (n2, p2) in itertools.zip_longest(
                        getattr(fsdp_model, call)(prefix=prefix,
                                                  recurse=recurse),
                        getattr(model, call)(prefix=prefix, recurse=recurse),
                ):
                    self.assertEqual(n1, n2)
                    self.assertEqual(p1, p2)


instantiate_parametrized_tests(TestSummonFullParams)
instantiate_parametrized_tests(TestSummonFullParamsNoShard)

if __name__ == "__main__":
    run_tests()
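The loop above compares a wrapped and an unwrapped model entry by entry. The same comparison idiom, sketched on two plain (non-FSDP) modules, shows why `zip_longest` is used rather than `zip`: a length mismatch surfaces as an error instead of being silently truncated.

import itertools
import torch
import torch.nn as nn

# Two identically initialized modules: resetting the seed before each
# construction makes their parameters equal.
torch.manual_seed(0)
m1 = nn.Linear(4, 4)
torch.manual_seed(0)
m2 = nn.Linear(4, 4)

for (n1, p1), (n2, p2) in itertools.zip_longest(
        m1.named_parameters(prefix="model", recurse=True),
        m2.named_parameters(prefix="model", recurse=True)):
    assert n1 == n2
    torch.testing.assert_close(p1, p2)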
Example 7
            def __init__(self, t) -> None:
                self.tensor: torch.Tensor = t

            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e) -> torch.Tensor:
                    if isinstance(e, NonRewrappingTensor):
                        t = e.tensor
                        return t
                    else:
                        return e

                r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                # Return an unwrapped tensor no longer of original subclass type.
                return r

        with self.assertRaisesRegex(
                RuntimeError,
                r"requires that detach\(\) returns an instance of the same type"
        ):
            param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))


instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()
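The `__torch_dispatch__` handler above unwraps every tensor hidden inside `args`/`kwargs` with `tree_map`. A standalone sketch of that unwrap step, using a hypothetical `Box` wrapper (note that `torch.utils._pytree` is an internal module and may move between releases):

import torch
from torch.utils._pytree import tree_map

class Box:
    """Hypothetical wrapper standing in for a tensor subclass."""
    def __init__(self, t):
        self.tensor = t

def unwrap(e):
    # Applied to every leaf; non-Box leaves pass through unchanged.
    return e.tensor if isinstance(e, Box) else e

args = (Box(torch.ones(2)), 3.0, [Box(torch.zeros(2))])
unwrapped = tree_map(unwrap, args)
print([type(a) for a in unwrapped])  # [Tensor, float, list]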
Example 8
    )
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False),
         CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None])
    def test_no_sync(
        self,
        num_iters_to_acc: int,
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """Tests the ``no_sync()`` context manager."""
        assert num_iters_to_acc >= 2, \
            "Accumulate for at least 2 iterations to be nontrivial"
        self._test_no_sync(
            batch_dim=1,
            num_iters_to_acc=num_iters_to_acc,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )


instantiate_parametrized_tests(TestNoSync)

if __name__ == "__main__":
    run_tests()
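Most examples in this collection rely on the same parametrization machinery. A minimal sketch of how `@parametrize` and `instantiate_parametrized_tests` fit together (these live in `torch.testing._internal.common_utils`, an internal module that may change between releases):

from torch.testing._internal.common_utils import (
    TestCase, instantiate_parametrized_tests, parametrize, run_tests,
)

class TestAccumulation(TestCase):
    @parametrize("num_iters_to_acc", [2, 4])
    def test_nontrivial(self, num_iters_to_acc):
        # Mirrors the guard in test_no_sync above.
        self.assertGreaterEqual(num_iters_to_acc, 2)

# Expands the decorated method into one concrete test per parameter value.
instantiate_parametrized_tests(TestAccumulation)

if __name__ == "__main__":
    run_tests()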
Example 9
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload):
        # DDP
        ddp_state = self._dist_train(wrap_fsdp=False)

        # FSDP
        fsdp_state = self._dist_train(wrap_fsdp=True, cpu_offload=cpu_offload)

        self.assertEqual(ddp_state, fsdp_state)


instantiate_parametrized_tests(TestPureFP16)

if __name__ == "__main__":
    run_tests()
Example 10
        self.assertTrue(new_tensor.reloaded)

    def test_tensor_subclass_deepcopy(self):
        wrapped_tensor = torch.rand(2)
        my_tensor = TestWrapperSubclass(wrapped_tensor)

        foo_val = "bar"
        my_tensor.foo = foo_val
        self.assertEqual(my_tensor.foo, foo_val)

        new_tensor = deepcopy(my_tensor)

        self.assertIsInstance(new_tensor, TestWrapperSubclass)
        self.assertEqual(new_tensor.elem, my_tensor.elem)
        self.assertEqual(new_tensor.foo, foo_val)

    @parametrize('requires_grad', (True, False))
    def test_cloned_deepcopy(self, requires_grad):
        my_tensor = torch.rand(2, requires_grad=requires_grad, device='meta')

        new_tensor = deepcopy(my_tensor)

        self.assertEqual(new_tensor.requires_grad, my_tensor.requires_grad)


instantiate_device_type_tests(TestBothSerialization, globals())
instantiate_parametrized_tests(TestSubclassSerialization)

if __name__ == '__main__':
    run_tests()
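The meta-device case above transfers directly to ordinary tensors; a standalone sketch of the same `deepcopy` property on CPU:

from copy import deepcopy
import torch

# deepcopy of a leaf tensor preserves requires_grad and allocates new storage.
t = torch.rand(2, requires_grad=True)
t_copy = deepcopy(t)
assert t_copy.requires_grad == t.requires_grad
assert t_copy.data_ptr() != t.data_ptr()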
Example 11
        Args:
            pass_ignored_modules_to_root (bool): If ``False``, does not pass
                any ignored modules (including those already ignored in child
                FSDP instances) to the root FSDP instance; if ``True``, passes
                all ignored modules (representing a superset of the children's
                ignored modules) to the root FSDP instance.
        """
        # To exercise different `FlatParameter` enumerations across ranks,
        # we wrap `layer3` with FSDP, where `layer3` is registered as a module
        # after `layer1`, which has the variable number of ignored modules
        model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
        layer1_ignored_modules = [
            m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
        ]
        model.layer1 = FSDP(model.layer1,
                            ignored_modules=layer1_ignored_modules)
        model.layer3 = FSDP(model.layer3)
        model_ignored_modules = [
            m for m in model.modules() if isinstance(m, IgnoredModule)
        ] if pass_ignored_modules_to_root else []
        wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
        self._train_model(wrapped_model, optim, 3)


instantiate_parametrized_tests(TestFSDPIgnoredModules)

if __name__ == "__main__":
    run_tests()
Example 12
            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()


instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    run_tests()
Example 13
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests gradient accumulation.

        This exercises gradient accumulation inside and outside the
        ``no_sync()`` context manager, in particular by interleaving the two.
        It tests interleavings that start either inside or outside
        ``no_sync()`` and that end either inside or outside it, to ensure
        that neither the initial nor the final conditions affect the
        correctness. This test
        also checks for compatibility with the CPU offload and backward
        prefetch options.

        NOTE: Gradient accumulation without using the ``no_sync()`` context
        manager is not currently compatible with CPU offloading, so those tests
        are vacuous.
        """
        self._test_grad_acc(
            batch_dim=1,
            configs=configs.configs,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )


instantiate_parametrized_tests(TestGradAcc)

if __name__ == "__main__":
    run_tests()
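Stripped of DDP/FSDP, the accumulation pattern being tested looks like the single-process sketch below; under `no_sync()`, the wrappers additionally skip gradient communication for the non-final micro-batches.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
num_iters_to_acc = 2

for step, batch in enumerate(torch.randn(8, 4).split(4)):
    loss = model(batch).sum() / num_iters_to_acc  # scale so the sum averages
    loss.backward()                               # accumulates into .grad
    if (step + 1) % num_iters_to_acc == 0:
        optim.step()
        optim.zero_grad()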
Example 14
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)
            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)

        # TODO: add resharding test case.


instantiate_parametrized_tests(TestDistributedCheckpoint)

if __name__ == "__main__":
    run_tests()
Example 15
                    expected[f"iter {iteration}: after fwd"] = (
                        340 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        340 + sharded_model_size_mb
                    )
                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1

            # sharded model size + sharded grad size + optimizer states + 1M temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed after setting grad = None
            # sharded model size + optimizer states + 1M temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        # Get the fsdp and checkpoint flags.
        with_ckpt = ckpt == "ckpt"

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


instantiate_parametrized_tests(TestFSDPMemory)


if __name__ == "__main__":
    run_tests()
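The `expected[...]` entries above are MB-level memory budgets checked at fixed points in the training loop. A minimal sketch of that kind of bookkeeping with the public CUDA memory API (assumes a CUDA device; the model and sizes here are arbitrary):

import torch

def allocated_mb() -> int:
    # Currently allocated CUDA memory, rounded to whole MB for comparison
    # against coarse model-size estimates.
    return round(torch.cuda.memory_allocated() / 2**20)

model = torch.nn.Linear(1024, 1024, device="cuda")
after_init = allocated_mb()
out = model(torch.randn(64, 1024, device="cuda"))
after_fwd = allocated_mb()
print(f"after init: {after_init} MB, after fwd: {after_fwd} MB")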
Example 16
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {
            "module.layer.a": torch.tensor(1),
            "abc.layer.def": torch.tensor(2),
            "module.layer.b": torch.tensor(3),
        }
        _replace_by_prefix(state_dict, "module.layer.", "layer.")
        assert state_dict == original_state_dict

    def test_packed_sequence(self):
        """Test to ensure RNN packed sequences are modified correctly."""
        rnn = nn.RNN(5, 5)

        x = torch.rand((5, 1, 5), dtype=torch.float)
        seq_length = torch.tensor([4], dtype=torch.int)

        def fill_fn(x):
            x.fill_(0)

        x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
        x, h = rnn(x)
        x = _apply_to_tensors(fill_fn, x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        self.assertEqual(torch.sum(x), 0)


instantiate_parametrized_tests(TestUtils)

if __name__ == "__main__":
    run_tests()
Example 17
        # Shard the non-wrapped model's re-keyed optimizer state dict, which
        # maps back to (flattened) parameter IDs
        sharded_osd = FSDP.shard_full_optim_state_dict(
            rekeyed_osd,
            model1,
            optim_input1,
        )
        # Check that this sharded optimizer state dict matches the wrapped
        # model's per-rank optimizer state dict
        osd1 = optim1.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            sharded_osd,
            osd1,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd,
            osd1,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim1.load_state_dict(sharded_osd)
        self._step_model(model1, optim1, num_iters=NUM_ITERS)


instantiate_parametrized_tests(TestFSDPOptimState)

if __name__ == "__main__":
    run_tests()
Example 18
        self.assertEqual(
            execution_info.module_to_execution_infos[model.layer1],
            [(model.layer1, list(model.layer1.named_parameters()))],
        )
        self.assertEqual(
            execution_info.module_to_execution_infos[model.layer2],
            [
                (model.layer2[0], list(model.layer2[0].named_parameters())),
                (model.layer2[2], list(model.layer2[2].named_parameters())),
            ],
        )
        self.assertEqual(execution_info.module_to_execution_infos[model.relu],
                         [])
        # test tracer.param_exec_order
        correct_param_order = [
            model.layer0.weight,
            model.layer0.bias,
            model.layer2[0].weight,
            model.layer2[2].weight,
            model.weight1,
            model.layer1.weight,
            model.weight2,
        ]
        self.assertEqual(execution_info.param_exec_order, correct_param_order)


instantiate_parametrized_tests(TestSymbolicTracing)

if __name__ == "__main__":
    run_tests()
Example 19
                src_lengths=None,
                with_triangle_mask=False,
                incremental_state=incr_state,
            ) for i in range(1, seqlen + 1)
        ]
        ref_output = torch.stack(ref_outputs)

        incr_key_lst = []
        incr_value_lst = []
        results = []
        for i in range(1, seqlen + 1):
            res, incr_key_lst, incr_value_lst = better_decoder(
                tokens[:, :i],
                src_mask=None,
                include_padding_mask=False,
                incr_key_lst=incr_key_lst,
                incr_value_lst=incr_value_lst,
                is_incremental_decoding=True,
            )
            results.append(res)
        result = torch.stack(results)

        self.assertEqual(result.shape, ref_output.shape)
        torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)


instantiate_parametrized_tests(TestTransformers)

if __name__ == '__main__':
    run_tests()
Example 20
                    if attr.name in ("then_branch", "else_branch"):
                        self.assertEqual(expected_output_type, attr.g.output[0].type)

    def test_uninitialized_optional(self):
        class Module(torch.nn.Module):
            def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
                if y is not None:
                    if y.shape[1] < 5:
                        if y.size(0) == 1:
                            y = y + 4
                        else:
                            return y
                return y

        y = torch.ones((3, 4), dtype=torch.int)
        torch.onnx.export(
            torch.jit.script(Module()),
            y,
            io.BytesIO(),
            opset_version=15,
            dynamic_axes={"y": {0: "y0", 1: "y1"}},
            input_names=["y"],
        )


instantiate_parametrized_tests(TestOptionalOutput)


if __name__ == "__main__":
    unittest.main()
Example 21
    @parametrize("freezing_method",
                 [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone])
    @parametrize("freeze_after_wrap_fsdp", [True, False])
    def test_freezing_weights(self, with_nested_trunk, freezing_method,
                              freeze_after_wrap_fsdp):
        # DDP
        ddp_state = self._dist_train(with_nested_trunk,
                                     freezing_method,
                                     freeze_after_wrap_fsdp,
                                     with_fsdp=False)

        # FSDP
        fsdp_state = self._dist_train(with_nested_trunk,
                                      freezing_method,
                                      freeze_after_wrap_fsdp,
                                      with_fsdp=True)

        self.assertEqual(
            ddp_state,
            fsdp_state,
            exact_device=True,
            msg=
            "FullyShardedDataParallel states didn't match PyTorch DDP states",
        )


instantiate_parametrized_tests(TestFreezingWeights)

if __name__ == "__main__":
    run_tests()
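The two freezing strategies being compared against DDP can be sketched in a single process (the `FreezingMethod` names below refer to the test's own enum):

import torch
import torch.nn as nn

def make_model():
    return nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))  # [trunk, head]

# FreezingMethod.RequiresGrad: no gradients are computed for the trunk at all.
model = make_model()
for p in model[0].parameters():
    p.requires_grad_(False)
model(torch.randn(2, 4)).sum().backward()
assert all(p.grad is None for p in model[0].parameters())

# FreezingMethod.GradToNone: gradients are computed, then dropped before the
# optimizer step, so the trunk parameters still do not change.
model = make_model()
model(torch.randn(2, 4)).sum().backward()
for p in model[0].parameters():
    p.grad = None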
Example 22
            fsdp_kwargs["mixed_precision"] = None
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
            fsdp_kwargs,
        )
        self._train_for_several_steps(
            fsdp_model,
            num_steps=1,
            autocast=False,
            mixed_precision=fsdp_kwargs["mixed_precision"]
        )
        input = fsdp_model.module.get_input(torch.device("cuda"))
        # Run a forward in eval mode
        fsdp_model.eval()
        ref_output = fsdp_model(*input)
        # Run a forward in `no_grad()` and compare
        with torch.no_grad():
            no_grad_output = fsdp_model(*input)
        self.assertEqual(ref_output, no_grad_output)


instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)
instantiate_parametrized_tests(TestNoGrad)
instantiate_parametrized_tests(TestParamInit)

if __name__ == "__main__":
    run_tests()
Example 23
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(sd1[prefixed_tensor_name].data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())

    @skip_if_lt_x_gpu(2)
    def test_state_dict_type(self):
        module = SkipModel(double_nest=True)
        with enable_wrap(wrapper_cls=FSDP):
            fsdp = wrap(module)
        with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
            pass
        for module in FSDP.fsdp_modules(fsdp):
            self.assertEqual(module._state_dict_type,
                             StateDictType.FULL_STATE_DICT)


instantiate_parametrized_tests(TestFSDPStateDict)

if __name__ == "__main__":
    run_tests()
Example 24
            inp = fsdp_model.module.get_input(self.device)
            output = fsdp_model(*inp)
            loss = fsdp_model.module.get_loss(inp, output).to(self.device)
            fsdp_model.module.run_backward(loss)
        # Match the warning message with the following prefix
        regex = "^(All-gather order differs from that of the first iteration " \
            f"on rank {self.rank} -- collectives are unchecked and may give " \
            "incorrect results or hang)"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex,
        ) if self.rank != 0 else suppress()
        if self.rank != 0:
            fsdp_model.flip_path()
        inp = fsdp_model.module.get_input(self.device)
        # Expect a warning for the forward pass all-gather
        with context:
            output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(self.device)
        fsdp_model.module.run_backward(loss)
        # Run an additional iteration to check that there are no more warnings
        inp = fsdp_model.module.get_input(self.device)
        output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(self.device)
        fsdp_model.module.run_backward(loss)


instantiate_parametrized_tests(TestFSDPExecOrder)

if __name__ == "__main__":
    run_tests()
Example 25
        m = MyModel(self.rank).cuda()
        _validate(m,
                  process_group=self.process_group,
                  assert_fn=self.assertNotEqual)
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _validate(fsdp,
                      process_group=self.process_group,
                      assert_fn=self.assertEqual)

        # sync_module_states also works with a CPU module when device_id is passed in
        m = MyModel(self.rank)
        _validate(m,
                  process_group=self.process_group,
                  assert_fn=self.assertNotEqual)
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m,
                    device_id=torch.cuda.current_device(),
                    sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _validate(fsdp,
                      process_group=self.process_group,
                      assert_fn=self.assertEqual)


instantiate_parametrized_tests(TestFSDPMisc)

if __name__ == "__main__":
    run_tests()
Example 26
    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_no_reshard_after_forward(self):
        # Note that we don't exercise all possible configs so as not to
        # increase test time-to-signal (TTS) too much.
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        )

    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_e2e_full_shard(self):
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
        )


instantiate_parametrized_tests(TestFSDPMixedPrecisionSharded)

if __name__ == "__main__":
    run_tests()
Example 27
        self.assertTrue(model._register_post_backward_hooks.called)
        self.assertTrue(model._register_pre_backward_hooks.called)


class TestNoGrad(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_transformer_no_grad(self):
        group = dist.distributed_c10d._get_default_group()
        model = self._get_wrapped_model(group, cuda_first=False)
        # Train model for a step
        self._train_for_several_steps(model, num_steps=1, autocast=False)

        model.eval()  # no dropout for this test

        # Eval in standard mode (i.e., without no_grad)
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Eval with no_grad and compare
        with torch.no_grad():
            no_grad_output = model(*input)

        self.assertEqual(ref_output, no_grad_output)


instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)

if __name__ == "__main__":
    run_tests()
Example 28
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),

            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),

            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {'a': 0, 'b': 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),

            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = tree_flatten(to_pytree)
            result = _broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))


instantiate_parametrized_tests(TestPytree)

if __name__ == '__main__':
    run_tests()
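`_broadcast_to_and_flatten` builds on the pytree flatten/unflatten primitives. A short sketch of those primitives (note that `torch.utils._pytree` is internal and may change between releases):

from torch.utils._pytree import tree_flatten, tree_unflatten

# A nested structure is flattened to a list of leaves plus a spec, and the
# spec can rebuild the same structure from (possibly transformed) leaves.
leaves, spec = tree_flatten((1, [2, 3], {"a": 4}))
assert leaves == [1, 2, 3, 4]
rebuilt = tree_unflatten([x * 10 for x in leaves], spec)
assert rebuilt == (10, [20, 30], {"a": 40})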
Example 29
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        # NOTE: We move the model to CUDA after init with FSDP to simulate real
        # use cases where the full model cannot be loaded onto GPU, but its
        # shards can.
        cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(sequential, cpu_offload=cpu_offload, fsdp_auto_wrap_policy=my_auto_wrap_policy)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]


instantiate_parametrized_tests(TestFSDPWrap)
instantiate_parametrized_tests(TestAutoWrap)

if __name__ == "__main__":
    run_tests()
Example 30
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn, init_fn=_init_with_torchdistX,
        )

    def _test_bad_arg(self, meta_module_fn):
        mod = meta_module_fn()
        with self.assertRaisesRegex(ValueError, "to be callable"):
            FSDP(mod, param_init_fn=42)

    @skip_if_lt_x_gpu(2)
    @sandcastle_skip_if(
        not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX"
    )
    def test_bad_arg_torchdistx(self):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, "cuda")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_bad_arg_meta(self):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_bad_arg(meta_module_fn)


instantiate_parametrized_tests(TestFSDPWithMetaDevice)

if __name__ == "__main__":
    run_tests()