def test_mixed_precision_resnet(self):
        """
        End-to-end test to ensure that mixed precision + auto_wrap works
        for a ResNet model.
        """
        resnet_model = torchvision.models.resnet50().cuda()
        resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
            resnet_model,
            process_group=dist.distributed_c10d._get_default_group())
        n_bn = sum(1 if isinstance(x, _BatchNorm) else 0
                   for x in resnet_model.modules())
        inp = torch.ones(1, 3, 1000, 1000, device='cuda')
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        fsdp = FSDP(resnet_model,
                    auto_wrap_policy=size_based_auto_wrap_policy,
                    mixed_precision=mp_config)
        # BatchNorm units should be wrapped individually. Validate this by
        # ensuring that the number of FSDP units wrapping a BN module equals
        # the number of BN units in the original ResNet model.
        fsdp_bn = 0
        for module in fsdp.fsdp_modules(fsdp):
            wrapped_module = module.module.module
            if isinstance(wrapped_module, _BatchNorm):
                fsdp_bn += 1

        self.assertEqual(fsdp_bn, n_bn)
        # Without mixed precision auto-wrapping, this would raise a dtype mismatch error.
        loss = fsdp(inp).sum()
        loss.backward()
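The test above combines a MixedPrecision config with an auto-wrap policy. As a rough standalone sketch of that pattern outside the test harness (the single-GPU process group, toy model, address/port, and size threshold are illustrative assumptions, not taken from the test suite):

import functools

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def main():
    # Single-process group purely for illustration.
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    torch.cuda.set_device(0)
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)).cuda()
    mp = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    # Auto-wrap every parameterized submodule above a (tiny) size threshold.
    policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1)
    fsdp_model = FSDP(model, auto_wrap_policy=policy, mixed_precision=mp)
    loss = fsdp_model(torch.randn(4, 8, device="cuda")).sum()
    loss.backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()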
Example No. 2
    def test_save_and_load_after_forward_state_dict(self, mixed_precision):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        torch.cuda.set_device(self.rank)
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)
Example No. 3
 def test_transformer_no_grad(self, mixed_precision):
     """Tests that for an FSDP-wrapped transformer model with shared
     parameters, after training for one iteration, running a forward pass in
     ``eval()`` mode gives the same output as running a forward pass in
     ``torch.no_grad()``."""
     fsdp_kwargs = {}
     if mixed_precision:
         fsdp_kwargs["mixed_precision"] = MixedPrecision(
             param_dtype=torch.float16,
             reduce_dtype=torch.float16,
             buffer_dtype=torch.float16,
         )
     else:
         fsdp_kwargs["mixed_precision"] = None
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
     )
     self._train_for_several_steps(
         fsdp_model,
         num_steps=1,
         autocast=False,
         mixed_precision=fsdp_kwargs["mixed_precision"]
     )
     input = fsdp_model.module.get_input(torch.device("cuda"))
     # Run a forward in eval mode
     fsdp_model.eval()
     ref_output = fsdp_model(*input)
     # Run a forward in `no_grad()` and compare
     with torch.no_grad():
         no_grad_output = fsdp_model(*input)
     self.assertEqual(ref_output, no_grad_output)
Example No. 4
 def test_param_change_after_init(self, mixed_precision):
     """
      Tests that in-place changes to FSDP model parameter values after FSDP
      initialization persist.
     """
     # Establish reference behavior
     fsdp_kwargs = {}
     if mixed_precision:
         fsdp_kwargs["mixed_precision"] = MixedPrecision()
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
         deterministic=True,
     )
     input = fsdp_model.module.get_input(torch.device("cuda"))
     ref_output = fsdp_model(*input)
     # Initialize the same model but change its first parameter value
     # in-place after FSDP initialization
     new_fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
         deterministic=True,
     )
     first_param = next(new_fsdp_model.parameters())
     nn.init.normal_(first_param.data)
     new_output = new_fsdp_model(*input)
     self.assertNotEqual(
         ref_output,
         new_output,
         msg="new_output did not reflect change to param after init",
     )
Example No. 5
    def test_summon_full_param_recursive(self, recurse, summon_outer, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # The outer flat param is summoned if summon_full_params() is called on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # The inner flat param is summoned if summon_full_params() is called with recursion or on the inner FSDP module
        expected_inner_numel = (
            global_inner_numel if recurse or not summon_outer else shard_inner_numel
        )

        with model_to_summon.summon_full_params(model_to_summon, recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())
Example No. 6
 def test_mp_embedding_default(self):
     default_mp_config = MixedPrecision(
         param_dtype=torch.float16,
         buffer_dtype=torch.float16,
         reduce_dtype=torch.float16,
     )
     self._test_mixed_precision_embedding_table(mp_config=default_mp_config)
Example No. 7
    def test_transformer_no_grad(self, mixed_precision):
        group = dist.distributed_c10d._get_default_group()
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        ) if mixed_precision else None
        config = {"mixed_precision": mixed_precision}
        model = self._get_wrapped_model(group, config=config, cuda_first=False)
        # Train model for a step
        self._train_for_several_steps(
            model,
            num_steps=1,
            autocast=False,
            mixed_precision=config["mixed_precision"]
        )

        model.eval()  # no dropout for this test

        # Eval in standard mode (i.e., without no_grad)
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Eval with no_grad and compare
        with torch.no_grad():
            no_grad_output = model(*input)

        self.assertEqual(ref_output, no_grad_output)
Example No. 8
    def test_params_count_and_value(self, rank0_only, offload_to_cpu,
                                    mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                mixed_precision=mixed_precision,
            ),
            mixed_precision=mixed_precision,
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with fsdp_model.summon_full_params(fsdp_model,
                                           rank0_only=rank0_only,
                                           writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
Example No. 9
 def test_mp_embedding_only_params_and_bufs(self):
     self._test_mixed_precision_embedding_table(
         mp_config=MixedPrecision(
             param_dtype=torch.float16,
             buffer_dtype=torch.float16,
         )
     )
Example No. 10
 def test_mp_embedding_params_and_reduce_diff(self):
     params_and_reduce_different = MixedPrecision(
         param_dtype=torch.float16,
         reduce_dtype=torch.float32,
         buffer_dtype=torch.float16)
     self._test_mixed_precision_embedding_table(
         mp_config=params_and_reduce_different)
Example No. 11
    def test_summon_full_param_shard_value(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank),
                     mixed_precision=mixed_precision)
        self.assertEqual(expected_shard_size,
                         self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model.summon_full_params(model):
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParamHandle.flatten_params(parameters,
                                                        requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # Shards are padded, but the full_param tensor is not.
            a, b = my_shard[0:my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))
Example No. 12
    def test_summon_full_params_respects_reshard_after_forward(
            self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Similarly, summon_full_params() should show the same behavior
        with model.summon_full_params(model):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
Example No. 13
    def test_params_count_and_value(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: bool,
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))
        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with FSDP.summon_full_params(fsdp_model,
                                     rank0_only=rank0_only,
                                     writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
Example No. 14
 def test_nested_wrapped_model_single_iteration_mixed_precision(
     self,
     cpu_offload: CPUOffload,
     sharding_strategy: Optional[ShardingStrategy],
     mixed_precision: bool,
 ):
     init_modes = self._get_init_modes_for_test(cpu_offload)
     mixed_precision = MixedPrecision(
         param_dtype=torch.float16,
         buffer_dtype=torch.float16,
         reduce_dtype=torch.float16,
     ) if mixed_precision else None
     for cuda_init_mode in init_modes:
         with self.subTest(cuda_init_mode=cuda_init_mode):
             self._test_fsdp_parity(
                 NestedWrappedModule,
                 FSDPInitMode.RECURSIVE,
                  # Only run one step for comparison, since a grad scaler is
                  # usually needed to avoid NaN after the first step.
                 num_iters=1,
                 cuda_init_mode=cuda_init_mode,
                 cpu_offload=cpu_offload,
                 sharding_strategy=sharding_strategy,
                 mixed_precision=mixed_precision,
             )
Example No. 15
    def test_reshard_outside_forward_backward_iteration(
            self, rank0_only, offload_to_cpu, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 1, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding.

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Now let's repeat it with a summon_full_params() call in between.

        output = model(torch.zeros(5).cuda(self.rank))
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
Example No. 16
 def test_summon_full_param_writeback(self, writeback, cpu_offload,
                                      mixed_precision, modify_outer):
     mixed_precision = MixedPrecision() if mixed_precision else None
     return _run_test_summon_full_param_writeback(
         self,
         writeback,
         modify_outer,
         cpu_offload=cpu_offload,
         mixed_precision=mixed_precision,
     )
Example No. 17
 def test_summon_full_param_writeback(self, writeback, modify_outer,
                                      mixed_precision):
     mixed_precision = MixedPrecision() if mixed_precision else None
     return _run_test_summon_full_param_writeback(
         self,
         writeback,
         modify_outer=modify_outer,
         cpu_offload=CPUOffload(offload_params=False),
         mixed_precision=mixed_precision,
     )
Example No. 18
    def test_mp_batchnorm(self, convert_sync_bn):
        class BatchNormNet(nn.Module):
            def __init__(self, affine=True):
                super(BatchNormNet, self).__init__()
                self.fc1 = nn.Linear(2, 40, bias=False)
                self.bn = nn.BatchNorm1d(4, affine=affine)
                self.fc2 = nn.Linear(40, 4, bias=False)

            def forward(self, x):
                x = torch.reshape(self.fc1(x), (-1, 4, 10))
                x = self.bn(x)
                x = torch.reshape(x, (-1, 40))
                x = self.fc2(x)
                return F.softmax(x, dim=1)

        def never_wrap_policy(*args, **kwargs):
            return False

        net = BatchNormNet().cuda()
        if convert_sync_bn:
            net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
        # FSDP detects that mixed precision + batchnorm will cause issues
        # and thus wraps batchnorm in a distinct FSDP unit that does not
        # use mixed precision.
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        with self.assertWarnsRegex(
                expected_warning=UserWarning,
                expected_regex="BatchNorm units will be wrapped as a separate"
        ):
            model = FSDP(
                net,
                mixed_precision=mp_config,
                auto_wrap_policy=never_wrap_policy,
            )

        bn = model.bn
        self.assertTrue(isinstance(bn, FSDP))
        # policy should not have wrapped any other submodules
        self.assertFalse(isinstance(model.fc1, FSDP))
        self.assertFalse(isinstance(model.fc2, FSDP))
        self.assertEqual(None, bn.mixed_precision)
        self.assertNotEqual(None, model.mixed_precision)

        inp = torch.randn((1, 2), device='cuda')
        # Without the FSDP BN mixed precision fix, this would result in
        # RuntimeError: Expected counts to have type Half but got Float
        # for syncBN
        model(inp).sum().backward()
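A small hypothetical inspection helper (not part of the test suite) that generalizes the assertions above: it walks the FSDP units with the same fsdp_modules traversal used in the ResNet test and prints each unit's mixed_precision attribute. Applied to the model built in test_mp_batchnorm, the BatchNorm unit would be expected to report None while the root reports the fp16 config.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def report_mixed_precision(fsdp_root: FSDP) -> None:
    # Walk every FSDP unit under (and including) the root and print whether
    # a MixedPrecision config is attached to it.
    for unit in FSDP.fsdp_modules(fsdp_root):
        # As in the ResNet test above, ``unit.module.module`` reaches the
        # original module wrapped by this FSDP unit.
        inner = unit.module.module
        print(f"{type(inner).__name__}: mixed_precision={unit.mixed_precision}")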
Example No. 19
    def test_save_and_load_after_forward_state_dict(
            self, mixed_precision, state_dict_rank0_and_offload):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        torch.cuda.set_device(self.rank)
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        ) if mixed_precision else None
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_full_state_dict_mgr(model,
                                                state_dict_rank0_and_offload)
        with fsd_mgr:
            state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        self._validate_state_dict_contents(state_dict,
                                           state_dict_rank0_and_offload)
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()

        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)
Example No. 20
 def test_nested_wrapped_model_single_iteration_mixed_precision(
         self, cpu_offload, sharding_strategy, mixed_precision):
     init_modes = self._get_init_modes_for_test(cpu_offload)
     mixed_precision = MixedPrecision() if mixed_precision else None
     for fsdp_init_mode in init_modes:
         with self.subTest(fsdp_init_mode=fsdp_init_mode):
             self._test_identical_outputs(
                 NestedWrappedModule,
                  # Only run one step for comparison, since a grad scaler is
                  # usually needed to avoid NaN after the first step.
                 num_steps=1,
                 fsdp_init_mode=fsdp_init_mode,
                 cpu_offload=cpu_offload,
                 sharding_strategy=sharding_strategy,
                 mixed_precision=mixed_precision,
             )
Example No. 21
 def test_register_functions_called(self, cuda_first, mixed_precision):
     """Tests that _register_{pre|post}_backward_hooks called during forward."""
     group = dist.distributed_c10d._get_default_group()
     mixed_precision = MixedPrecision() if mixed_precision else None
     model = self._get_wrapped_model(
         group, mixed_precision=mixed_precision, cuda_first=cuda_first
     )
     input = model.module.get_input(torch.device("cuda"))
     model._register_post_backward_hooks = mock.MagicMock(return_value=None)
     model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
     self.assertFalse(model._register_post_backward_hooks.called)
     self.assertFalse(model._register_pre_backward_hooks.called)
     model(*input)
     self.assertTrue(model._register_post_backward_hooks.called)
     self.assertTrue(model._register_pre_backward_hooks.called)
Example No. 22
 def test_scaler_enabled(self, cpu_offload, sharding_strategy,
                         mixed_precision):
     init_modes = self._get_init_modes_for_test(cpu_offload)
     mp = MixedPrecision(
         param_dtype=torch.float16,
         reduce_dtype=torch.float16,
         buffer_dtype=torch.float16,
     ) if mixed_precision else None
     for fsdp_init_mode in init_modes:
         self._test_identical_outputs(
             NestedWrappedModule,
             fsdp_init_mode=fsdp_init_mode,
             cpu_offload=cpu_offload,
             sharding_strategy=sharding_strategy,
             mixed_precision=mp,
             enable_sharded_grad_scaler=True,
         )
Example No. 23
 def test_mixed_precision_embedding_table(self):
      # Basic test to ensure int inputs are not cast, which would break
      # modules such as embedding tables.
     mp_config = MixedPrecision()
     model = self._get_wrapped_model(
         group=torch.distributed.distributed_c10d._get_default_group(),
         config={"mixed_precision": mp_config})
     optim = torch.optim.SGD(model.parameters(), lr=0.1)
     for _ in range(6):
         inp = model.module.get_input(torch.device("cuda"))
          # This would fail if we cast integer module inputs, such as those
          # for embedding tables.
         output = model(*inp)
         loss = model.module.get_loss(inp, output).cuda()
         self.assertEqual(loss.dtype, mp_config.param_dtype)
         model.module.run_backward(loss)
         optim.step()
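The embedding-table tests hinge on FSDP leaving integer inputs uncast while still casting floating-point activations. A minimal standalone sketch of that behavior (the single-GPU process group, toy sizes, and address/port are illustrative assumptions, not taken from the test suite):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def main():
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    torch.cuda.set_device(0)
    model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 4)).cuda()
    mp = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    fsdp_model = FSDP(model, mixed_precision=mp)
    # int64 token ids must pass through uncast; only floating-point inputs
    # are converted to the low-precision param_dtype.
    tokens = torch.randint(0, 100, (8,), device="cuda")
    out = fsdp_model(tokens)
    assert out.dtype == torch.float16
    out.sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()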
Example No. 24
 def test_nested_wrapped_model_single_iteration_mixed_precision(
     self,
     cpu_offload: CPUOffload,
     sharding_strategy: Optional[ShardingStrategy],
 ):
     mixed_precision = MixedPrecision(
         param_dtype=torch.float16,
         buffer_dtype=torch.float16,
         reduce_dtype=torch.float16,
     )
     self.run_subtests(
         self._get_subtest_config(cpu_offload),
         self._test_fsdp_parity,
         NestedWrappedModule,
         FSDPInitMode.RECURSIVE,
         cpu_offload=cpu_offload,
         sharding_strategy=sharding_strategy,
         num_iters=1,
         mixed_precision=mixed_precision,
     )
Example No. 25
 def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool):
     """Tests that ``_register_{pre|post}_backward_hooks()`` are called
     during the FSDP forward."""
     fsdp_kwargs = {}
     if mixed_precision:
         fsdp_kwargs["mixed_precision"] = MixedPrecision()
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER,
         fsdp_kwargs,
     )
     input = fsdp_model.module.get_input(torch.device("cuda"))
     fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
     fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None)
     self.assertFalse(fsdp_model._register_post_backward_hooks.called)
     self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
     fsdp_model(*input)
     self.assertTrue(fsdp_model._register_post_backward_hooks.called)
     self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
Example No. 26
    def test_params_are_unflattenned(self, rank0_only, offload_to_cpu, mixed_precision):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(deepcopy(model), mixed_precision=mixed_precision).cuda(
            self.rank
        )

        def _get_flat_param():
            return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

        flattened_param = _get_flat_param()
        self.assertEqual(layer_shape[0] * layer_shape[1] / 2, flattened_param.numel())

        with fsdp_model.summon_full_params(
            fsdp_model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            if self.rank == 0 or not rank0_only:
                self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
                expected_device = (
                    torch.device("cpu")
                    if offload_to_cpu
                    else torch.device("cuda", torch.cuda.current_device())
                )
                self.assertTrue(expected_device == fsdp_model.weight.device)
            else:
                # Nonzero rank with rank0_only maintains original params.
                flat_within_ctx = _get_flat_param()
                self.assertEqual(flat_within_ctx, flattened_param)
                self.assertEqual(
                    flat_within_ctx.device, torch.device(torch.cuda.current_device())
                )

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device())
        )
Example No. 27
 def test_fsdp_ddp_parity_with_grad_scaler(
     self,
     cpu_offload: CPUOffload,
     sharding_strategy: Optional[ShardingStrategy],
     mixed_precision: Optional[str],
 ):
     init_modes = self._get_init_modes_for_test(cpu_offload)
     mp = MixedPrecision(
         param_dtype=torch.float16,
         reduce_dtype=torch.float16,
         buffer_dtype=torch.float16,
     ) if mixed_precision is not None else None
     for cuda_init_mode in init_modes:
         self._test_fsdp_parity(
             NestedWrappedModule,
             FSDPInitMode.RECURSIVE,
             cuda_init_mode=cuda_init_mode,
             cpu_offload=cpu_offload,
             sharding_strategy=sharding_strategy,
             mixed_precision=mp,
             enable_sharded_grad_scaler=True,
         )
Example No. 28
    def _check_low_precision_hook(self, state, hook, sharding_strategy, dtype, has_wrapping):
        # keep everything deterministic for input data
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        fsdp_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy
        )
        fsdp_with_hook.register_comm_hook(state, hook)

        mp_only_grad = MixedPrecision(reduce_dtype=dtype)
        fsdp_with_mp = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy, mixed_precision=mp_only_grad),
            sharding_strategy=sharding_strategy,
            mixed_precision=mp_only_grad
        )

        optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
        optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

        in_data = torch.rand(16, 8).cuda()
        fsdp_with_hook.train()
        fsdp_with_mp.train()
        loss_hook = fsdp_with_hook(in_data).sum()
        loss_mp = fsdp_with_mp(in_data).sum()
        loss_hook.backward()
        # Make sure grads were cast to the parameter's precision
        self.assertEqual(fsdp_with_hook.params[0].dtype, state.parameter_type)
        loss_mp.backward()
        optim_hook.step()
        optim_mp.step()

        dist.barrier()

        for hook_param, mp_param in zip(fsdp_with_hook.parameters(), fsdp_with_mp.parameters()):
            self.assertEqual(hook_param.grad, mp_param.grad)
Example No. 29
    def test_param_change_after_init(self, mixed_precision):
        group = dist.distributed_c10d._get_default_group()
        # Establish reference behavior.
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = self._get_wrapped_model(
            group, mixed_precision=mixed_precision, cuda_first=False
        )
        model.eval()  # no dropout for this test
        input = model.module.get_input(torch.device("cuda"))
        ref_output = model(*input)

        # Change the weights in place.
        model = self._get_wrapped_model(group, cuda_first=False)
        model.eval()  # no dropout for this test
        first_param = next(model.parameters())
        nn.init.normal_(first_param.data)
        new_output = model(*input)

        self.assertNotEqual(
            ref_output,
            new_output,
            msg="new_output did not reflect change to param after init",
        )
Example No. 30
 def test_mp_embedding_reduce(self):
     self._test_mixed_precision_embedding_table(mp_config=MixedPrecision(
         reduce_dtype=torch.float16))
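The reduce_dtype knob only controls the precision of gradient communication. As a rough sketch of the params-vs-reduce split exercised in the params_and_reduce_different config above (the single-GPU group and toy model are illustrative assumptions; the float32 reduction itself happens inside FSDP and is only described in the comment):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def main():
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:29502", rank=0, world_size=1
    )
    torch.cuda.set_device(0)
    # Compute in float16 but keep the gradient reduction in float32.
    mp = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float16,
    )
    fsdp_model = FSDP(nn.Linear(8, 4).cuda(), mixed_precision=mp)
    out = fsdp_model(torch.randn(2, 8, device="cuda"))
    assert out.dtype == torch.float16  # forward ran in the low-precision dtype
    out.sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()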