Ejemplo n.º 1
0
    def test_permuted_model_loading(self):
        """Smoke-test ``load_multi_state`` with a parameter-name map.

        A two-layer ``Sequential`` (KWinners at index 0, Linear at index 1)
        is passed through ``load_multi_state`` twice with the same
        ``param_map``: once via ``restore_linear`` and once via
        ``restore_full_model``, both pointing at ``self.checkpoint_path``.
        There are no assertions; the test passes when neither call raises.
        """
        layers = [
            KWinners(8, percent_on=0.1),
            torch.nn.Linear(8, 8),
        ]
        model = torch.nn.Sequential(*layers)

        # Every entry swaps the leading module index ("0." <-> "1.").
        # The values use this model's layout (Linear at 1, KWinners at 0).
        # NOTE(review): presumably the keys are the names stored in the
        # checkpoint -- confirm the direction against load_multi_state.
        name_map = {
            "0.weight": "1.weight",
            "0.bias": "1.bias",
            "1.boost_strength": "0.boost_strength",
            "1.duty_cycle": "0.duty_cycle",
        }

        # Same call twice, differing only in which restore keyword receives
        # the checkpoint path; order matches the original (linear, then full).
        for restore_kwarg in ("restore_linear", "restore_full_model"):
            model = load_multi_state(
                model,
                param_map=name_map,
                **{restore_kwarg: self.checkpoint_path},
            )
    def test_load_linear(self):
        """Check the effect of ``load_multi_state(..., restore_linear=...)``.

        A fresh ``MNISTSparseCNN`` is built under seed 33 and compared with
        the fixture ``self.model`` and its recorded outputs, before and after
        restoring from ``self.checkpoint_path`` via ``restore_linear``.
        The asserted match counts show the ``lower_forward`` output unchanged
        by the restore (1337 both times) while the ``upper_forward`` output
        goes from 1 match to 20.
        """

        def count_close(actual, expected):
            # Number of elements within atol=1e-2 (default rtol) of expected.
            return actual.isclose(expected, atol=1e-2).sum().item()

        # Build a fresh model under seed 33.
        set_random_seed(33)
        model = MNISTSparseCNN()
        model.eval()

        # No parameter tensor may be element-for-element equal to the
        # corresponding parameter of the fixture model.
        for fresh, reference in zip(model.parameters(),
                                    self.model.parameters()):
            self.assertNotEqual((fresh == reference).sum().item(),
                                np.prod(fresh.shape))

        # Before restoring: output of lower_forward.
        self.assertEqual(
            count_close(lower_forward(model, self.in_1), self.out_lower),
            1337,
        )  # some correct

        # Before restoring: output of upper_forward.
        self.assertEqual(
            count_close(upper_forward(model, self.in_2), self.out_upper),
            1,
        )  # some correct

        # Restore through the `restore_linear` argument only.
        model = load_multi_state(model, restore_linear=self.checkpoint_path)
        model.eval()

        # After restoring: lower_forward gives the same count as before.
        self.assertEqual(
            count_close(lower_forward(model, self.in_1), self.out_lower),
            1337,
        )  # some correct

        # After restoring: upper_forward now matches the recorded output.
        self.assertEqual(
            count_close(upper_forward(model, self.in_2), self.out_upper),
            20,
        )  # all correct
Ejemplo n.º 3
0
    def test_load_full(self):
        """Check ``load_multi_state(..., restore_full_model=...)``.

        A fresh ``MNISTSparseCNN`` built under seed 33 must differ from the
        fixture ``self.model`` in every parameter tensor. After restoring
        from ``self.checkpoint_path`` via ``restore_full_model``, every
        parameter and buffer must equal the fixture's element for element,
        and the outputs of ``full_forward``, ``lower_forward`` and
        ``upper_forward`` must match the recorded outputs.
        """

        def count_equal(left, right):
            # Number of positions where the two tensors are exactly equal.
            return (left == right).sum().item()

        def count_close(actual, expected, **tolerances):
            # Number of elements within atol=1e-2 of expected; extra
            # keyword arguments (e.g. rtol) are forwarded to isclose.
            return actual.isclose(expected, atol=1e-2, **tolerances).sum().item()

        # Build a fresh model under seed 33.
        set_random_seed(33)
        model = MNISTSparseCNN()
        model.eval()

        # Before restoring: no parameter tensor is fully equal to the
        # corresponding parameter of the fixture model.
        for fresh, reference in zip(model.parameters(),
                                    self.model.parameters()):
            self.assertNotEqual(count_equal(fresh, reference),
                                np.prod(fresh.shape))

        # Restore the full model from the checkpoint.
        model = load_multi_state(model,
                                 restore_full_model=self.checkpoint_path)
        model.eval()

        # After restoring: every parameter tensor is fully equal.
        for restored, reference in zip(model.parameters(),
                                       self.model.parameters()):
            self.assertEqual(count_equal(restored, reference),
                             np.prod(restored.shape))

        # After restoring: every buffer is fully equal as well.
        for restored, reference in zip(model.buffers(), self.model.buffers()):
            if restored.dtype == torch.float16:
                # Half-precision buffers are upcast before comparing.
                restored, reference = restored.float(), reference.float()

            self.assertEqual(count_equal(restored, reference),
                             np.prod(restored.shape))

        # Output of full_forward; this check alone uses rtol=0.
        self.assertEqual(
            count_close(full_forward(model, self.in_1), self.out_full, rtol=0),
            20,
        )  # all correct

        # Output of lower_forward.
        self.assertEqual(
            count_close(lower_forward(model, self.in_1), self.out_lower),
            2048,
        )  # all correct

        # Output of upper_forward.
        self.assertEqual(
            count_close(upper_forward(model, self.in_2), self.out_upper),
            20,
        )  # all correct