def testModelOptimizer(self):
    """
    Check that SGD hyper-parameters survive a State save/load round trip.

    A model and its optimizer are saved to ``self.filepath`` and loaded
    again. A fresh optimizer, created with deliberately different
    hyper-parameters, is then restored from the loaded optimizer state and
    must report the originally configured learning rate and momentum.
    """

    original_model = models.LeNet(10, [1, 32, 32])
    original_optimizer = torch.optim.SGD(original_model.parameters(), lr=0.01, momentum=0.9)

    state = common.state.State(original_model, original_optimizer)
    state.save(self.filepath)

    state = common.state.State.load(self.filepath)
    loaded_model = state.model

    # Bind the new optimizer to the *loaded* model's parameters so the
    # restored model actually takes part in the check. The lr/momentum
    # values are intentionally wrong: only load_state_dict can fix them.
    loaded_optimizer = torch.optim.SGD(loaded_model.parameters(), lr=0.99, momentum=0.1)
    # state.optimizer is handed to load_state_dict, i.e. it is expected to
    # be an optimizer state dict rather than an optimizer instance.
    loaded_optimizer.load_state_dict(state.optimizer)

    for param_group in loaded_optimizer.param_groups:
        self.assertEqual(param_group['lr'], 0.01)
        self.assertEqual(param_group['momentum'], 0.9)
def testModelOptimizerScheduler(self):
    """
    Check that StepLR settings survive a State save/load round trip.

    Model, optimizer and scheduler are saved to ``self.filepath`` and loaded
    again. A fresh optimizer and a fresh scheduler are created with
    deliberately different settings and restored from the loaded state; the
    scheduler must then agree with the original one on ``step_size`` and
    ``gamma``.
    """

    original_model = models.LeNet(10, [1, 32, 32])
    original_optimizer = torch.optim.SGD(original_model.parameters(), lr=0.01, momentum=0.9)
    original_scheduler = torch.optim.lr_scheduler.StepLR(original_optimizer, step_size=10, gamma=0.9)

    state = common.state.State(original_model, original_optimizer, original_scheduler)
    state.save(self.filepath)

    state = common.state.State.load(self.filepath)
    loaded_model = state.model

    # Build the optimizer on the loaded model's parameters, not on the
    # original model, so the restored objects are exercised together.
    loaded_optimizer = torch.optim.SGD(loaded_model.parameters(), lr=0.99, momentum=0.1)
    loaded_optimizer.load_state_dict(state.optimizer)

    # Start from settings that differ from the original scheduler. If the
    # scheduler were created with the same step_size/gamma, the assertions
    # below would pass even when load_state_dict restored nothing.
    loaded_scheduler = torch.optim.lr_scheduler.StepLR(loaded_optimizer, step_size=1, gamma=0.1)
    loaded_scheduler.load_state_dict(state.scheduler)

    self.assertEqual(original_scheduler.step_size, loaded_scheduler.step_size)
    self.assertEqual(original_scheduler.gamma, loaded_scheduler.gamma)
def testModelOnly(self):
    """
    Round-trip a bare model (no optimizer, no scheduler) through State.

    All parameters of a LeNet are zeroed before saving to ``self.filepath``.
    The loaded model must have the same class name and resolution, and
    every one of its parameter tensors must sum to zero.
    """

    source_model = models.LeNet(10, [1, 32, 32])
    for tensor in source_model.parameters():
        tensor.data.zero_()

    common.state.State(source_model).save(self.filepath)
    restored_model = common.state.State.load(self.filepath).model

    self.assertEqual(
        restored_model.__class__.__name__,
        source_model.__class__.__name__,
    )
    self.assertListEqual(restored_model.resolution, source_model.resolution)

    for tensor in restored_model.parameters():
        total = torch.sum(tensor).item()
        self.assertEqual(total, 0)
def testModels(self):
    """
    Round-trip each supported model architecture through State.

    For every class name in the list, the class is looked up via
    ``common.utils.get_class('models', name)``, instantiated, zeroed, saved
    to ``self.filepath`` and loaded again. The loaded model must have the
    same class name and resolution, and every parameter tensor must sum to
    zero.
    """

    model_names = ['LeNet', 'MLP', 'ResNet']

    for model_name in model_names:
        # One sub-test per architecture: a failure names the offending
        # model and the remaining architectures are still exercised.
        with self.subTest(model=model_name):
            # Keep the name string and the resolved class in separate
            # variables instead of rebinding the loop variable.
            model_class = common.utils.get_class('models', model_name)
            original_model = model_class(10, [1, 32, 32])
            for parameters in original_model.parameters():
                parameters.data.zero_()

            state = common.state.State(original_model)
            state.save(self.filepath)

            state = common.state.State.load(self.filepath)
            loaded_model = state.model

            self.assertEqual(loaded_model.__class__.__name__, original_model.__class__.__name__)
            self.assertListEqual(loaded_model.resolution, original_model.resolution)

            for parameters in loaded_model.parameters():
                self.assertEqual(torch.sum(parameters).item(), 0)