Example #1
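The examples in this section are test methods for an EMA (exponential moving average) tracker. They rely on scaffolding that is not shown; a minimal sketch of what they assume is below. The import paths, the DummyModule definition, and the assertTorchAllClose helper are assumptions made for illustration, not the original file's code.

import unittest
from copy import deepcopy
from unittest.mock import patch

import torch
import torch.nn as nn

# Assumed import paths; EMA and EMAConfig come from the library under test
# (in fairseq they live under fairseq.models.ema and fairseq.dataclass.configs).
from fairseq.models.ema import EMA
from fairseq.dataclass.configs import EMAConfig


class DummyModule(nn.Module):
    # Hypothetical stand-in: any small module whose forward accepts the
    # 32-dimensional inputs used in the tests will do.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 2)
        # A buffer named "version", so the tests' "version" skip has a target
        # (an assumption about the original DummyModule).
        self.register_buffer("version", torch.tensor(1.0))

    def forward(self, x):
        return self.fc(x)


class TestEMA(unittest.TestCase):
    def assertTorchAllClose(self, x, y, atol=1e-6):
        # Assumed helper: element-wise closeness check used by the tests.
        self.assertTrue(torch.allclose(x, y, atol=atol))

The test methods below would then sit inside TestEMA.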
    def test_ema(self):
        model = DummyModule()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        state = deepcopy(model.state_dict())
        config = EMAConfig()
        ema = EMA(model, config)

        # set decay
        ema._set_decay(config.ema_decay)
        self.assertEqual(ema.get_decay(), config.ema_decay)

        # get model
        self.assertEqual(ema.get_model(), ema.model)

        # Since fp32 params are not used, fp32_params should be empty (size 0)
        self.assertEqual(len(ema.fp32_params), 0)

        # EMA step
        x = torch.randn(32)
        y = model(x)
        loss = y.sum()
        loss.backward()
        optimizer.step()

        ema.step(model)

        ema_state_dict = ema.get_model().state_dict()

        for key, param in model.state_dict().items():
            prev_param = state[key]
            ema_param = ema_state_dict[key]

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue
            self.assertTorchAllClose(
                ema_param,
                config.ema_decay * prev_param + (1 - config.ema_decay) * param,
            )

        # Since fp32 params are not used, fp32_params should still be empty (size 0)
        self.assertEqual(len(ema.fp32_params), 0)

        # Load the EMA parameters back into a fresh model
        model2 = DummyModule()
        ema.reverse(model2)

        for key, param in model2.state_dict().items():
            ema_param = ema_state_dict[key]
            self.assertTrue(torch.allclose(ema_param, param))

        # Check that _step_internal is called once
        with patch.object(
            ema, "_step_internal", return_value=None
        ) as mock_method:
            ema.step(model)
            mock_method.assert_called_once_with(model, None)
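For reference, the rule the assertions above encode is the standard exponential moving average: new_ema = decay * old_ema + (1 - decay) * param. A standalone sketch of one such step over plain state dicts (not the library's _step_internal) might look like this:

def ema_update(ema_state, model_state, decay):
    # One EMA step over state dicts: a sketch of the rule the assertions
    # above check, not the library's implementation.
    for key, param in model_state.items():
        if "version" in key:
            # Mirrors the tests: the model.version entry is not decayed.
            continue
        ema_state[key] = decay * ema_state[key] + (1 - decay) * param
    return ema_state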
Example #2
    def test_ema_fp16(self):
        model = DummyModule().half()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        state = deepcopy(model.state_dict())
        config = EMAConfig(ema_fp32=False)
        ema = EMA(model, config)

        # Since fp32 params are not used, fp32_params should be empty (size 0)
        self.assertEqual(len(ema.fp32_params), 0)

        x = torch.randn(32)
        y = model(x.half())
        loss = y.sum()
        loss.backward()
        optimizer.step()

        ema.step(model)

        for key, param in model.state_dict().items():
            prev_param = state[key]
            ema_param = ema.get_model().state_dict()[key]

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            # EMA update is done in fp16, and hence the EMA param must be
            # closer to the EMA update done in fp16 than in fp32.
            self.assertLessEqual(
                torch.norm(
                    ema_param.float()
                    - (
                        config.ema_decay * prev_param + (1 - config.ema_decay) * param
                    ).float()
                ),
                torch.norm(
                    ema_param.float()
                    - (
                        config.ema_decay * prev_param.float()
                        + (1 - config.ema_decay) * param.float()
                    )
                    .half()
                    .float()
                ),
            )
            self.assertTorchAllClose(
                ema_param,
                config.ema_decay * prev_param + (1 - config.ema_decay) * param,
            )

        # Since fp32 params are not used, fp32_params should still be empty (size 0)
        self.assertEqual(len(ema.fp32_params), 0)
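The norm comparison in this test works because, with ema_fp32=False, every arithmetic step of the update is rounded to fp16, so the stored EMA parameter matches the fp16-computed reference and can only be at least as close to it as to the compute-in-fp32-then-round reference. A small self-contained illustration of the two reference values (the tensors and decay value here are hypothetical):

decay = 0.9999                  # hypothetical decay value
prev = torch.randn(8).half()    # stands in for the stored EMA copy
cur = torch.randn(8).half()     # stands in for the updated model param

# Reference computed entirely in fp16: rounding after every operation.
ref_fp16 = decay * prev + (1 - decay) * cur
# Reference computed in fp32 and rounded to fp16 once at the end.
ref_fp32 = (decay * prev.float() + (1 - decay) * cur.float()).half()

# An EMA step whose arithmetic is carried out in fp16 matches ref_fp16 up to
# the same fp16 roundings, so its distance to ref_fp16 is no larger than its
# distance to ref_fp32, which is what the assertLessEqual above encodes.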
Example #3
    def _test_ema_start_update(self, updates):
        model = DummyModule()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        state = deepcopy(model.state_dict())
        config = EMAConfig(ema_start_update=1)
        ema = EMA(model, config)

        # EMA step
        x = torch.randn(32)
        y = model(x)
        loss = y.sum()
        loss.backward()
        optimizer.step()

        ema.step(model, updates=updates)
        ema_state_dict = ema.get_model().state_dict()

        self.assertEqual(ema.get_decay(),
                         0 if updates == 0 else config.ema_decay)

        for key, param in model.state_dict().items():
            ema_param = ema_state_dict[key]
            prev_param = state[key]

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue
            if updates == 0:
                self.assertTorchAllClose(
                    ema_param,
                    param,
                )
            else:
                self.assertTorchAllClose(
                    ema_param,
                    config.ema_decay * prev_param +
                    (1 - config.ema_decay) * param,
                )

        # Check that _step_internal is called once
        with patch.object(ema, "_step_internal",
                          return_value=None) as mock_method:
            ema.step(model, updates=updates)
            mock_method.assert_called_once_with(model, updates)
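With ema_start_update=1, the tracker does not start averaging until the update counter reaches 1: for updates == 0 the effective decay is 0, so decay * prev + (1 - decay) * param collapses to param and the EMA is just a copy of the current model, which is exactly what the two assertion branches above check. A sketch of that gating (an assumed schedule, not the library's code):

def effective_decay(ema_decay, ema_start_update, updates):
    # Before ema_start_update, use decay 0 so the EMA tracks the model exactly.
    if updates is not None and updates < ema_start_update:
        return 0.0
    return ema_decay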
Example #4
    def test_ema_fp32(self):
        # CPU no longer supports Linear in half precision
        dtype = torch.half if torch.cuda.is_available() else torch.float

        model = DummyModule().to(dtype)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        state = deepcopy(model.state_dict())
        config = EMAConfig(ema_fp32=True)
        ema = EMA(model, config)

        x = torch.randn(32)
        y = model(x.to(dtype))
        loss = y.sum()
        loss.backward()
        optimizer.step()

        ema.step(model)

        for key, param in model.state_dict().items():
            prev_param = state[key]
            ema_param = ema.get_model().state_dict()[key]

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue
            self.assertIn(key, ema.fp32_params)

            # EMA update is done in fp32, and hence the EMA param must be
            # closer to the EMA update done in fp32 than in fp16.
            self.assertLessEqual(
                torch.norm(ema_param.float() - (
                    config.ema_decay * prev_param.float() +
                    (1 - config.ema_decay) * param.float()).to(dtype).float()),
                torch.norm(ema_param.float() -
                           (config.ema_decay * prev_param +
                            (1 - config.ema_decay) * param).float()),
            )
            self.assertTorchAllClose(
                ema_param,
                (config.ema_decay * prev_param.float() +
                 (1 - config.ema_decay) * param.float()).to(dtype),
            )
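With ema_fp32=True, the averaged copy is kept in float32 (hence the assertIn(key, ema.fp32_params) check), updated in full precision, and only cast back to the model dtype at the end, which is what the closeness assertions above verify. A hypothetical sketch of such an fp32-accumulated step, not the library's implementation:

def ema_step_fp32(fp32_params, model_state, decay, dtype):
    # Sketch only: keep a float32 shadow copy per key, update it in full
    # precision, then cast the result back to the model dtype.
    new_state = {}
    for key, param in model_state.items():
        shadow = fp32_params[key]                           # float32 tensor
        shadow.mul_(decay).add_(param.float(), alpha=1 - decay)
        new_state[key] = shadow.to(dtype)
    return new_state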