def setUp(self):
    """Build a reference model, record its outputs, and write a checkpoint.

    The fixture stores:

    * ``self.model`` -- an ``MNISTSparseCNN`` in eval mode whose "cnn" and
      "linear" weights have been doubled.
    * ``self.in_1`` / ``self.in_2`` -- random inputs of shape
      ``(2, 1, 28, 28)`` and ``(2, 1024)``.
    * ``self.out_full`` / ``self.out_lower`` / ``self.out_upper`` -- results of
      ``full_forward``, ``lower_forward`` and ``upper_forward`` on the
      reference model.
    * ``self.checkpoint_path`` -- a pickled ``{"model": <bytes>}`` dict holding
      the serialized state dict, inside a temporary ``results`` directory.
    """
    set_random_seed(20)
    self.model = MNISTSparseCNN()
    self.model.eval()

    # Make all params twice as large to differentiate it from an init-ed model.
    # The in-place write runs under ``torch.no_grad()``: autograd refuses
    # in-place modification of (views of) leaf tensors that require grad.
    with torch.no_grad():
        for name, param in self.model.named_parameters():
            if ("cnn" in name or "linear" in name) and ("weight" in name):
                param[:] = param.data * 2

    self.in_1 = torch.rand(2, 1, 28, 28)
    self.in_2 = torch.rand(2, 1024)
    self.out_full = full_forward(self.model, self.in_1)
    self.out_lower = lower_forward(self.model, self.in_1)
    self.out_upper = upper_forward(self.model, self.in_2)

    # Create temporary results directory.
    self.tempdir = tempfile.TemporaryDirectory()
    self.results_dir = Path(self.tempdir.name) / Path("results")
    self.results_dir.mkdir()

    # Save model state: serialize the state dict into memory, then pickle the
    # resulting bytes under the "model" key.
    state = {}
    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, self.model.state_dict(), compresslevel=-1)
        state["model"] = buffer.getvalue()

    self.checkpoint_path = self.results_dir / Path("mymodel")
    with open(self.checkpoint_path, "wb") as f:
        pickle.dump(state, f)
def test_load_full(self):
    """Check that ``restore_full_model`` reproduces the reference model.

    A freshly seeded model must first differ from ``self.model`` (parameters
    and outputs); after ``load_multi_state(..., restore_full_model=...)`` its
    parameters, buffers and outputs must all match the reference.
    """
    # Initialize model with new random seed.
    set_random_seed(33)
    model = MNISTSparseCNN()
    model.eval()

    # Before restoring, no parameter tensor may be entirely equal to its
    # counterpart in the reference model.
    for param1, param2 in zip(model.parameters(), self.model.parameters()):
        tot_eq = (param1 == param2).sum().item()
        self.assertNotEqual(tot_eq, param1.numel())

    # The match counts asserted below are exact values; presumably they are
    # tied to the seeds used here and when the reference model was built.

    # Check output through the full network.
    out = full_forward(model, self.in_1)
    num_matches = out.isclose(self.out_full, atol=1e-2).sum().item()
    self.assertEqual(num_matches, 1)  # some correct

    # Check output through the lower network.
    out = lower_forward(model, self.in_1)
    num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
    self.assertEqual(num_matches, 1337)  # some correct

    # Check output through the upper network.
    out = upper_forward(model, self.in_2)
    num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
    self.assertEqual(num_matches, 1)  # some correct

    # Restore full model.
    model = load_multi_state(model, restore_full_model=self.checkpoint_path)
    model.eval()

    # After restoring, every parameter must equal the reference element-wise.
    for param1, param2 in zip(model.parameters(), self.model.parameters()):
        tot_eq = (param1 == param2).sum().item()
        self.assertEqual(tot_eq, param1.numel())

    # Buffers must match as well.
    for buffer1, buffer2 in zip(model.buffers(), self.model.buffers()):
        # Upcast half-precision buffers before comparing; presumably ``==`` on
        # float16 is unsupported on some CPU builds -- TODO confirm. The
        # float16 -> float32 conversion is exact, so equality is unaffected.
        if buffer1.dtype == torch.float16:
            buffer1 = buffer1.float()
            buffer2 = buffer2.float()

        tot_eq = (buffer1 == buffer2).sum().item()
        self.assertEqual(tot_eq, buffer1.numel())

    # Check output through the full network.
    out = full_forward(model, self.in_1)
    num_matches = out.isclose(self.out_full, atol=1e-2, rtol=0).sum().item()
    self.assertEqual(num_matches, 20)  # all correct

    # Check output through the lower network.
    out = lower_forward(model, self.in_1)
    num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
    self.assertEqual(num_matches, 2048)  # all correct

    # Check output through the upper network.
    out = upper_forward(model, self.in_2)
    num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
    self.assertEqual(num_matches, 20)  # all correct
def test_load_nonlinear(self):
    """Check ``load_multi_state`` with only ``restore_nonlinear`` given.

    A freshly seeded model must differ from the reference model; once the
    checkpoint is loaded through ``restore_nonlinear`` the ``lower_forward``
    output must match the reference in all 2048 positions.
    """
    # Start from a model built under a different seed than the reference.
    set_random_seed(33)
    fresh_model = MNISTSparseCNN()
    fresh_model.eval()

    # Sanity check: no parameter tensor is entirely equal to its reference.
    param_pairs = zip(fresh_model.parameters(), self.model.parameters())
    for fresh, reference in param_pairs:
        equal_count = (fresh == reference).sum().item()
        self.assertNotEqual(equal_count, np.prod(fresh.shape))

    # Load the checkpoint via the ``restore_nonlinear`` argument.
    fresh_model = load_multi_state(
        fresh_model, restore_nonlinear=self.checkpoint_path
    )
    fresh_model.eval()

    # Check output through the lower network.
    lower_out = lower_forward(fresh_model, self.in_1)
    close_mask = lower_out.isclose(self.out_lower, atol=1e-2)
    self.assertEqual(close_mask.sum().item(), 2048)  # all correct
class RestoreUtilsTest1(unittest.TestCase):
    """Tests for the param-name helpers and ``load_multi_state``.

    ``setUp`` builds a reference ``MNISTSparseCNN``, records its outputs and
    writes a checkpoint; each test then loads that checkpoint into a freshly
    seeded model and compares against the reference.
    """

    def setUp(self):
        """Build the reference model, record its outputs, write a checkpoint."""
        set_random_seed(20)
        self.model = MNISTSparseCNN()
        self.model.eval()

        # Make all params twice as large to differentiate it from an init-ed
        # model. The in-place write runs under ``torch.no_grad()``: autograd
        # refuses in-place modification of (views of) leaf tensors that
        # require grad.
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if ("cnn" in name or "linear" in name) and ("weight" in name):
                    param[:] = param.data * 2

        self.in_1 = torch.rand(2, 1, 28, 28)
        self.in_2 = torch.rand(2, 1024)
        self.out_full = full_forward(self.model, self.in_1)
        self.out_lower = lower_forward(self.model, self.in_1)
        self.out_upper = upper_forward(self.model, self.in_2)

        # Create temporary results directory.
        self.tempdir = tempfile.TemporaryDirectory()
        self.results_dir = Path(self.tempdir.name) / Path("results")
        self.results_dir.mkdir()

        # Save model state: serialize the state dict into memory, then pickle
        # the resulting bytes under the "model" key.
        state = {}
        with io.BytesIO() as buffer:
            serialize_state_dict(
                buffer, self.model.state_dict(), compresslevel=-1
            )
            state["model"] = buffer.getvalue()

        self.checkpoint_path = self.results_dir / Path("mymodel")
        with open(self.checkpoint_path, "wb") as f:
            pickle.dump(state, f)

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        self.tempdir.cleanup()

    def _build_unrestored_model(self):
        """Return a freshly seeded model verified to differ from the reference."""
        # Initialize model with new random seed.
        set_random_seed(33)
        model = MNISTSparseCNN()
        model.eval()

        # No parameter tensor may be entirely equal to its counterpart in the
        # reference model.
        for param1, param2 in zip(model.parameters(), self.model.parameters()):
            tot_eq = (param1 == param2).sum().item()
            self.assertNotEqual(tot_eq, param1.numel())

        return model

    def test_get_param_names(self):
        """Linear / nonlinear name helpers partition the params and buffers."""
        linear_params = get_linear_param_names(self.model)
        expected_linear_params = [
            "output.weight",
            "linear.module.bias",
            "output.bias",
            "linear.zero_mask",
            "linear.module.weight",
        ]
        # Comparing sets with assertEqual reports the differing members.
        self.assertEqual(set(linear_params), set(expected_linear_params))

        # Every other named parameter or buffer is expected to be nonlinear.
        nonlinear_params = get_nonlinear_param_names(self.model)
        expected_nonlinear_params = [
            name
            for name, _ in itertools.chain(
                self.model.named_parameters(), self.model.named_buffers()
            )
            if name not in expected_linear_params
        ]
        self.assertEqual(
            set(nonlinear_params), set(expected_nonlinear_params)
        )

    def test_load_full(self):
        """``restore_full_model`` reproduces params, buffers and outputs."""
        model = self._build_unrestored_model()

        # Restore full model.
        model = load_multi_state(model, restore_full_model=self.checkpoint_path)
        model.eval()

        # Every parameter must now equal the reference element-wise.
        for param1, param2 in zip(model.parameters(), self.model.parameters()):
            tot_eq = (param1 == param2).sum().item()
            self.assertEqual(tot_eq, param1.numel())

        # Buffers must match as well.
        for buffer1, buffer2 in zip(model.buffers(), self.model.buffers()):
            # Half-precision buffers are upcast before the comparison; the
            # float16 -> float32 conversion is exact, so equality is unaffected.
            if buffer1.dtype == torch.float16:
                buffer1 = buffer1.float()
                buffer2 = buffer2.float()

            tot_eq = (buffer1 == buffer2).sum().item()
            self.assertEqual(tot_eq, buffer1.numel())

        # Check output through the full network.
        out = full_forward(model, self.in_1)
        num_matches = out.isclose(self.out_full, atol=1e-2, rtol=0).sum().item()
        self.assertEqual(num_matches, 20)  # all correct

        # Check output through the lower network.
        out = lower_forward(model, self.in_1)
        num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 2048)  # all correct

        # Check output through the upper network.
        out = upper_forward(model, self.in_2)
        num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 20)  # all correct

    def test_load_nonlinear(self):
        """Loading via ``restore_nonlinear`` reproduces the lower output."""
        model = self._build_unrestored_model()

        # Load the checkpoint through the ``restore_nonlinear`` argument.
        model = load_multi_state(model, restore_nonlinear=self.checkpoint_path)
        model.eval()

        # Check output through the lower network.
        out = lower_forward(model, self.in_1)
        num_matches = out.isclose(self.out_lower, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 2048)  # all correct

    def test_load_linear(self):
        """Loading via ``restore_linear`` reproduces the upper output."""
        model = self._build_unrestored_model()

        # Load the checkpoint through the ``restore_linear`` argument.
        model = load_multi_state(model, restore_linear=self.checkpoint_path)
        model.eval()

        # Check output through the upper network.
        out = upper_forward(model, self.in_2)
        num_matches = out.isclose(self.out_upper, atol=1e-2).sum().item()
        self.assertEqual(num_matches, 20)  # all correct