Example #1
0
    def setUp(self):
        """Build the reference model, its outputs, and a pickled checkpoint.

        The reference ``MNISTSparseCNN`` (seed 20) has its cnn/linear weights
        doubled so it differs from a freshly initialized model. Its outputs on
        fixed random inputs are stored for the tests to compare against, and
        its serialized state dict is pickled to ``self.checkpoint_path``
        inside a temporary directory.
        """

        set_random_seed(20)
        self.model = MNISTSparseCNN()
        self.model.eval()

        # Make all params twice as large to differentiate it from an init-ed
        # model. Parameters are autograd leaves that require grad, so the
        # in-place update must run under ``torch.no_grad()``; assigning through
        # ``param[:]`` outside of it raises a RuntimeError.
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if ("cnn" in name or "linear" in name) and ("weight" in name):
                    param.mul_(2)

        # ``in_1`` feeds the full and lower networks; ``in_2`` is fed directly
        # to the upper network.
        self.in_1 = torch.rand(2, 1, 28, 28)
        self.in_2 = torch.rand(2, 1024)
        self.out_full = full_forward(self.model, self.in_1)
        self.out_lower = lower_forward(self.model, self.in_1)
        self.out_upper = upper_forward(self.model, self.in_2)

        # Create temporary results directory. Registering the cleanup here
        # ensures it is removed even without a matching tearDown; a second
        # ``cleanup()`` call after an existing tearDown is harmless.
        self.tempdir = tempfile.TemporaryDirectory()
        self.addCleanup(self.tempdir.cleanup)
        self.results_dir = Path(self.tempdir.name) / Path("results")
        self.results_dir.mkdir()

        # Save model state: the serialized state dict is stored under the
        # "model" key of a pickled dict.
        state = {}
        with io.BytesIO() as buffer:
            serialize_state_dict(buffer,
                                 self.model.state_dict(),
                                 compresslevel=-1)
            state["model"] = buffer.getvalue()

        self.checkpoint_path = self.results_dir / Path("mymodel")
        with open(self.checkpoint_path, "wb") as f:
            pickle.dump(state, f)
    def test_load_full(self):
        """Restoring the full model reproduces the reference model exactly.

        A differently-seeded model is first checked to differ from
        ``self.model`` in its parameters and in its outputs. After
        ``load_multi_state`` restores it from ``self.checkpoint_path``, every
        parameter, buffer and sub-network output must match.
        """

        # Initialize model with new random seed.
        set_random_seed(33)
        model = MNISTSparseCNN()
        model.eval()

        # No parameter tensor of the fresh model may fully match the reference.
        for param1, param2 in zip(model.parameters(), self.model.parameters()):
            tot_eq = (param1 == param2).sum().item()
            self.assertNotEqual(tot_eq, np.prod(param1.shape))

        # Before restoring, outputs may agree at some positions but must not
        # agree everywhere. Exact match counts would pin an arbitrary property
        # of the random initialization, so only "not all" is asserted.
        out = full_forward(model, self.in_1)
        num_matches = out.isclose(self.out_full, atol=1e-2).sum().item()
        self.assertLess(num_matches, out.numel())  # not all correct

        # Same check through the lower network.
        out = lower_forward(model, self.in_1)
        num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
        self.assertLess(num_matches, out.numel())  # not all correct

        # Same check through the upper network.
        out = upper_forward(model, self.in_2)
        num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
        self.assertLess(num_matches, out.numel())  # not all correct

        # Restore full model.
        model = load_multi_state(model,
                                 restore_full_model=self.checkpoint_path)
        model.eval()

        # Every parameter must now match the reference exactly.
        for param1, param2 in zip(model.parameters(), self.model.parameters()):
            tot_eq = (param1 == param2).sum().item()
            self.assertEqual(tot_eq, np.prod(param1.shape))

        # ... and so must every buffer. float16 buffers are compared as float32
        # (presumably because `==` on half tensors is unsupported on CPU in
        # some PyTorch versions — TODO confirm).
        for buffer1, buffer2 in zip(model.buffers(), self.model.buffers()):
            if buffer1.dtype == torch.float16:
                buffer1 = buffer1.float()
                buffer2 = buffer2.float()

            tot_eq = (buffer1 == buffer2).sum().item()
            self.assertEqual(tot_eq, np.prod(buffer1.shape))

        # Check output through the full network.
        out = full_forward(model, self.in_1)
        num_matches = out.isclose(self.out_full, atol=1e-2,
                                  rtol=0).sum().item()
        self.assertEqual(num_matches, 20)  # all correct

        # Check output through the lower network.
        out = lower_forward(model, self.in_1)
        num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 2048)  # all correct

        # Check output through the upper network.
        out = upper_forward(model, self.in_2)
        num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 20)  # all correct
Example #3
0
    def test_load_nonlinear(self):
        """Restoring only the nonlinear params reproduces the lower network.

        A model built with a different seed is first checked to differ from
        ``self.model`` in every parameter tensor. After ``load_multi_state``
        restores its nonlinear params from ``self.checkpoint_path``, all 2048
        lower-network outputs must agree with the reference outputs.
        """
        set_random_seed(33)
        fresh = MNISTSparseCNN()
        fresh.eval()

        # Sanity check: no parameter tensor of the fresh model is identical
        # to its counterpart in the reference model.
        for mine, ref in zip(fresh.parameters(), self.model.parameters()):
            n_equal = (mine == ref).sum().item()
            self.assertNotEqual(n_equal, np.prod(mine.shape))

        # Restore only the nonlinear params from the checkpoint.
        restored = load_multi_state(fresh,
                                    restore_nonlinear=self.checkpoint_path)
        restored.eval()

        # Every lower-network output should now match the reference.
        lower_out = lower_forward(restored, self.in_1)
        close_mask = lower_out.isclose(self.out_lower, atol=1e-2)
        n_close = close_mask.sum().item()
        self.assertEqual(n_close, 2048)  # all correct
Example #4
0
class RestoreUtilsTest1(unittest.TestCase):
    """Tests for restoring full, linear-only and nonlinear-only model state.

    ``setUp`` builds a reference ``MNISTSparseCNN`` (seed 20) whose cnn/linear
    weights are doubled and records its outputs through the full, lower and
    upper sub-networks. It then pickles the model's serialized state dict to
    a temporary checkpoint. Each test creates a differently-seeded model and
    checks what ``load_multi_state`` restores from that checkpoint.
    """

    def setUp(self):

        set_random_seed(20)
        self.model = MNISTSparseCNN()
        self.model.eval()

        # Make all params twice as large to differentiate it from an init-ed
        # model. Parameters are autograd leaves that require grad, so the
        # in-place update must run under ``torch.no_grad()``; assigning through
        # ``param[:]`` outside of it raises a RuntimeError.
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if ("cnn" in name or "linear" in name) and ("weight" in name):
                    param.mul_(2)

        # ``in_1`` feeds the full and lower networks; ``in_2`` is fed directly
        # to the upper network.
        self.in_1 = torch.rand(2, 1, 28, 28)
        self.in_2 = torch.rand(2, 1024)
        self.out_full = full_forward(self.model, self.in_1)
        self.out_lower = lower_forward(self.model, self.in_1)
        self.out_upper = upper_forward(self.model, self.in_2)

        # Create temporary results directory (removed in tearDown).
        self.tempdir = tempfile.TemporaryDirectory()
        self.results_dir = Path(self.tempdir.name) / Path("results")
        self.results_dir.mkdir()

        # Save model state: the serialized state dict is stored under the
        # "model" key of a pickled dict.
        state = {}
        with io.BytesIO() as buffer:
            serialize_state_dict(buffer,
                                 self.model.state_dict(),
                                 compresslevel=-1)
            state["model"] = buffer.getvalue()

        self.checkpoint_path = self.results_dir / Path("mymodel")
        with open(self.checkpoint_path, "wb") as f:
            pickle.dump(state, f)

    def tearDown(self):
        self.tempdir.cleanup()

    def _fresh_model(self):
        """Return a new seed-33 model in eval mode that differs from the reference.

        Asserts that no parameter tensor of the new model fully matches its
        counterpart in ``self.model``, so a later match proves restoration.
        """
        set_random_seed(33)
        model = MNISTSparseCNN()
        model.eval()

        for param1, param2 in zip(model.parameters(), self.model.parameters()):
            tot_eq = (param1 == param2).sum().item()
            self.assertNotEqual(tot_eq, np.prod(param1.shape))

        return model

    def test_get_param_names(self):

        linear_params = get_linear_param_names(self.model)
        expected_linear_params = {
            "output.weight", "linear.module.bias", "output.bias",
            "linear.zero_mask", "linear.module.weight"
        }
        self.assertEqual(set(linear_params), expected_linear_params)

        # Every named parameter or buffer that is not linear is nonlinear.
        nonlinear_params = get_nonlinear_param_names(self.model)
        expected_nonlinear_params = {
            name
            for name, _ in itertools.chain(self.model.named_parameters(),
                                           self.model.named_buffers())
            if name not in expected_linear_params
        }
        self.assertEqual(set(nonlinear_params), expected_nonlinear_params)

    def test_load_full(self):

        model = self._fresh_model()

        # Restore full model.
        model = load_multi_state(model,
                                 restore_full_model=self.checkpoint_path)
        model.eval()

        # Every parameter must now match the reference exactly.
        for param1, param2 in zip(model.parameters(), self.model.parameters()):
            tot_eq = (param1 == param2).sum().item()
            self.assertEqual(tot_eq, np.prod(param1.shape))

        # ... and so must every buffer. float16 buffers are compared as float32
        # (presumably because `==` on half tensors is unsupported on CPU in
        # some PyTorch versions — TODO confirm).
        for buffer1, buffer2 in zip(model.buffers(), self.model.buffers()):
            if buffer1.dtype == torch.float16:
                buffer1 = buffer1.float()
                buffer2 = buffer2.float()

            tot_eq = (buffer1 == buffer2).sum().item()
            self.assertEqual(tot_eq, np.prod(buffer1.shape))

        # Check output through the full network.
        out = full_forward(model, self.in_1)
        num_matches = out.isclose(self.out_full, atol=1e-2,
                                  rtol=0).sum().item()
        self.assertEqual(num_matches, 20)  # all correct

        # Check output through the lower network.
        out = lower_forward(model, self.in_1)
        num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 2048)  # all correct

        # Check output through the upper network.
        out = upper_forward(model, self.in_2)
        num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 20)  # all correct

    def test_load_nonlinear(self):

        model = self._fresh_model()

        # Restore only the nonlinear params.
        model = load_multi_state(model, restore_nonlinear=self.checkpoint_path)
        model.eval()

        # Check output through the lower network.
        out = lower_forward(model, self.in_1)
        num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 2048)  # all correct

    def test_load_linear(self):

        model = self._fresh_model()

        # Restore only the linear params.
        model = load_multi_state(model, restore_linear=self.checkpoint_path)
        model.eval()

        # Check output through the upper network.
        out = upper_forward(model, self.in_2)
        num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 20)  # all correct