import random
import unittest
from unittest.mock import Mock

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.hub import load_state_dict_from_url
from torch.nn import CrossEntropyLoss
from torchvision.models import vgg16
from tqdm import tqdm

from baal.active import get_heuristic
from baal.active.active_loop import ActiveLearningLoop
from baal.bayesian.dropout import patch_module
from baal.modelwrapper import ModelWrapper
from baal.utils.metrics import ClassificationReport


class ModelWrapperTest(unittest.TestCase):
    def setUp(self):
        # DummyModel and DummyDataset are small test fixtures; sketches of
        # what they might look like follow this class.
        self.model = DummyModel()
        self.criterion = nn.BCEWithLogitsLoss()
        self.wrapper = ModelWrapper(self.model, self.criterion)
        self.optim = torch.optim.SGD(self.wrapper.get_params(), 0.01)
        self.dataset = DummyDataset()

    def test_train_on_batch(self):
        self.wrapper.train()
        old_param = [p.clone() for p in self.model.parameters()]
        input, target = torch.randn([1, 3, 10, 10]), torch.randn(1, 1)
        self.wrapper.train_on_batch(input, target, self.optim)
        new_param = [p.clone() for p in self.model.parameters()]
        # At least one parameter must have moved after an SGD step.
        assert any(not torch.allclose(i, j) for i, j in zip(old_param, new_param))

        # Test the weight-reset helpers: reset_fcs() should reinitialize the
        # linear layer but leave the conv layer untouched.
        linear_weights = list(self.wrapper.model.named_children())[3][1].weight.clone()
        conv_weights = list(self.wrapper.model.named_children())[0][1].weight.clone()
        self.wrapper.reset_fcs()
        linear_new_weights = list(self.wrapper.model.named_children())[3][1].weight.clone()
        conv_new_weights = list(self.wrapper.model.named_children())[0][1].weight.clone()
        assert all(
            not torch.allclose(i, j)
            for i, j in zip(linear_new_weights, linear_weights)
        )
        assert all(
            torch.allclose(i, j)
            for i, j in zip(conv_new_weights, conv_weights)
        )

        # reset_all() reinitializes every layer, including the conv layer.
        self.wrapper.reset_all()
        conv_next_new_weights = list(self.wrapper.model.named_children())[0][1].weight.clone()
        assert all(
            not torch.allclose(i, j)
            for i, j in zip(conv_new_weights, conv_next_new_weights)
        )

    def test_test_on_batch(self):
        self.wrapper.eval()
        input, target = torch.randn([1, 3, 10, 10]), torch.randn(1, 1)
        # In eval mode the loss is deterministic, so repeated calls agree.
        preds = torch.stack(
            [self.wrapper.test_on_batch(input, target, cuda=False) for _ in range(10)]
        ).view(10, -1)
        assert torch.allclose(torch.mean(preds, 0), preds[0])

        # The same holds when averaging multiple predictions per call.
        preds = torch.stack(
            [
                self.wrapper.test_on_batch(input, target, cuda=False, average_predictions=10)
                for _ in range(10)
            ]
        ).view(10, -1)
        assert torch.allclose(torch.mean(preds, 0), preds[0])

    def test_predict_on_batch(self):
        self.wrapper.eval()
        input = torch.randn([2, 3, 10, 10])

        # iterations == 1
        pred = self.wrapper.predict_on_batch(input, 1, False)
        assert pred.size() == (2, 1, 1)

        # iterations > 1
        pred = self.wrapper.predict_on_batch(input, 10, False)
        assert pred.size() == (2, 1, 10)

        # Same checks without replicating the input in memory.
        self.wrapper = ModelWrapper(self.model, self.criterion, replicate_in_memory=False)
        pred = self.wrapper.predict_on_batch(input, 1, False)
        assert pred.size() == (2, 1, 1)
        pred = self.wrapper.predict_on_batch(input, 10, False)
        assert pred.size() == (2, 1, 10)

    def test_train(self):
        history = self.wrapper.train_on_dataset(
            self.dataset, self.optim, 10, 2, use_cuda=False, workers=0
        )
        assert len(history) == 2

    def test_test(self):
        loss = self.wrapper.test_on_dataset(self.dataset, 10, use_cuda=False, workers=0)
        assert np.isfinite(loss)
        loss = self.wrapper.test_on_dataset(
            self.dataset, 10, use_cuda=False, workers=0, average_predictions=10
        )
        assert np.isfinite(loss)

    def test_predict(self):
        preds = self.wrapper.predict_on_dataset(
            self.dataset, 10, 20, use_cuda=False, workers=0
        )
        self.wrapper.eval()
        # Dataset-level predictions must match per-batch predictions.
        assert np.allclose(
            self.wrapper.predict_on_batch(self.dataset[0][0].unsqueeze(0), 20)[0]
            .detach().numpy(),
            preds[0],
        )
        assert np.allclose(
            self.wrapper.predict_on_batch(self.dataset[19][0].unsqueeze(0), 20)[0]
            .detach().numpy(),
            preds[19],
        )
        assert preds.shape == (len(self.dataset), 1, 20)

        # The generator variant yields the same predictions batch by batch.
        pred_gen = self.wrapper.predict_on_dataset_generator(
            self.dataset, 10, 20, use_cuda=False, workers=0
        )
        assert np.allclose(next(pred_gen)[0], preds[0])
        for last in pred_gen:
            pass  # Exhaust the generator to get the last batch.
        assert np.allclose(last[-1], preds[-1])

        # half=True should yield float16 predictions.
        pred_gen = self.wrapper.predict_on_dataset_generator(
            self.dataset, 10, 20, use_cuda=False, workers=0, half=True
        )
        preds = self.wrapper.predict_on_dataset(
            self.dataset, 10, 20, use_cuda=False, workers=0, half=True
        )
        assert next(pred_gen).dtype == np.float16
        assert preds.dtype == np.float16

    def test_states(self):
        input = torch.randn([1, 3, 10, 10])

        def pred_with_dropout(replicate_in_memory):
            self.wrapper = ModelWrapper(
                self.model, self.criterion, replicate_in_memory=replicate_in_memory
            )
            self.wrapper.train()
            # Dropout is active in train mode, so predictions vary between calls.
            preds = torch.stack(
                [
                    self.wrapper.predict_on_batch(input, iterations=1, cuda=False)
                    for _ in range(10)
                ]
            ).view(10, -1)
            assert not torch.allclose(torch.mean(preds, 0), preds[0])

        pred_with_dropout(replicate_in_memory=True)
        pred_with_dropout(replicate_in_memory=False)

        def pred_without_dropout(replicate_in_memory):
            self.wrapper = ModelWrapper(
                self.model, self.criterion, replicate_in_memory=replicate_in_memory
            )
            # Dropout is inactive in eval mode, so predictions are deterministic.
            self.wrapper.eval()
            preds = torch.stack(
                [
                    self.wrapper.predict_on_batch(input, iterations=1, cuda=False)
                    for _ in range(10)
                ]
            ).view(10, -1)
            assert torch.allclose(torch.mean(preds, 0), preds[0])

        pred_without_dropout(replicate_in_memory=True)
        pred_without_dropout(replicate_in_memory=False)

    def test_add_metric(self):
        self.wrapper.add_metric('cls_report', lambda: ClassificationReport(2))
        assert 'test_cls_report' in self.wrapper.metrics
        assert 'train_cls_report' in self.wrapper.metrics
        self.wrapper.train_on_dataset(self.dataset, self.optim, 32, 2, False)
        self.wrapper.test_on_dataset(self.dataset, 32, False)
        assert (self.wrapper.metrics['train_cls_report'].value['accuracy'] != 0).any()
        assert (self.wrapper.metrics['test_cls_report'].value['accuracy'] != 0).any()

    def test_train_and_test(self):
        res = self.wrapper.train_and_test_on_datasets(
            self.dataset, self.dataset, self.optim, 32, 5, False,
            return_best_weights=False
        )
        assert len(res) == 5
        res = self.wrapper.train_and_test_on_datasets(
            self.dataset, self.dataset, self.optim, 32, 5, False,
            return_best_weights=True
        )
        assert len(res) == 2
        assert len(res[0]) == 5
        assert isinstance(res[1], dict)

        # Mock a test loss that decreases and then rises; with patience=1,
        # early stopping triggers well before the 50-epoch budget.
        mock = Mock()
        mock.side_effect = (((np.linspace(0, 50) - 10) / 10) ** 2).tolist()
        self.wrapper.test_on_dataset = mock
        res = self.wrapper.train_and_test_on_datasets(
            self.dataset, self.dataset, self.optim, 32, 50, False,
            return_best_weights=True, patience=1
        )
        assert len(res) == 2
        assert len(res[0]) < 50

        # min_epoch_for_es delays early stopping until epoch 20.
        mock = Mock()
        mock.side_effect = (((np.linspace(0, 50) - 10) / 10) ** 2).tolist()
        self.wrapper.test_on_dataset = mock
        res = self.wrapper.train_and_test_on_datasets(
            self.dataset, self.dataset, self.optim, 32, 50, False,
            return_best_weights=True, patience=1, min_epoch_for_es=20
        )
        assert len(res) == 2
        assert 20 < len(res[0]) < 50
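# The tests above rely on two fixtures, DummyModel and DummyDataset, defined
# elsewhere in the test suite. The sketches below are assumptions reconstructed
# only from how the tests use them: the model must expose a conv layer at
# index 0 and a linear layer at index 3 of named_children(), contain dropout,
# and map (N, 3, 10, 10) inputs to (N, 1) logits; the dataset must be
# deterministic per index (so dataset- and batch-level predictions match),
# hold at least 20 items, and yield binary float targets for BCEWithLogitsLoss.
from torch.utils.data import Dataset


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=10)   # (N, 3, 10, 10) -> (N, 8, 1, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()                   # makes train-mode predictions stochastic
        self.linear = nn.Linear(8, 1)                 # index 3 in named_children()

    def forward(self, x):
        x = self.dropout(self.relu(self.conv(x)))
        return self.linear(x.flatten(1))


class DummyDataset(Dataset):
    def __len__(self):
        return 20

    def __getitem__(self, idx):
        # Deterministic input per index; target alternates between 0 and 1.
        return torch.full((3, 10, 10), idx / 255.0), torch.FloatTensor([idx % 2])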
# Example experiment script: MC-Dropout active learning with VGG16 on CIFAR10,
# driven by baal's ActiveLearningLoop. The parse_args() and get_datasets()
# helpers are defined elsewhere; hedged sketches follow main().
def main():
    args = parse_args()
    use_cuda = torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    random.seed(1337)
    torch.manual_seed(1337)
    if not use_cuda:
        print("Warning: this experiment will be extremely slow on CPU.")

    hyperparams = vars(args)

    active_set, test_set = get_datasets(hyperparams['initial_pool'])

    heuristic = get_heuristic(hyperparams['heuristic'], hyperparams['shuffle_prop'])
    criterion = CrossEntropyLoss()
    model = vgg16(pretrained=False, num_classes=10)
    # Load ImageNet weights, dropping the final classifier layer since
    # CIFAR10 has 10 classes instead of 1000.
    weights = load_state_dict_from_url(
        'https://download.pytorch.org/models/vgg16-397923af.pth')
    weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
    model.load_state_dict(weights, strict=False)

    # Change dropout layers to MCDropout so they stay active at test time.
    model = patch_module(model)

    if use_cuda:
        model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=hyperparams["lr"], momentum=0.9)

    # Wrap the model into a usable API.
    model = ModelWrapper(model, criterion)

    logs = {}
    logs['epoch'] = 0

    # For prediction we use a smaller batch size, since MC-Dropout inference
    # repeats each forward pass `iterations` times and is therefore slower.
    active_loop = ActiveLearningLoop(
        active_set,
        model.predict_on_dataset,
        heuristic,
        hyperparams.get('n_data_to_label', 1),
        batch_size=10,
        iterations=hyperparams['iterations'],
        use_cuda=use_cuda,
    )

    for epoch in tqdm(range(args.epoch)):
        model.train_on_dataset(
            active_set, optimizer, hyperparams["batch_size"], 1, use_cuda
        )

        # Validation!
        model.test_on_dataset(test_set, hyperparams["batch_size"], use_cuda)
        metrics = model.metrics

        if epoch % hyperparams['learning_epoch'] == 0:
            # Label new data, then reset the fully-connected layers before
            # retraining on the enlarged labelled set.
            should_continue = active_loop.step()
            model.reset_fcs()
            if not should_continue:
                break
        val_loss = metrics['test_loss'].value

        logs = {
            "val": val_loss,
            "epoch": epoch,
            "train": metrics['train_loss'].value,
            "labeled_data": active_set._labelled,
            "Next Training set size": len(active_set),
        }
        print(logs)
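# The two helpers used by main() are not shown above. The sketches below are
# assumptions based only on the hyperparameters main() reads; the argument
# names mirror that usage, but the defaults are illustrative, not the original
# script's. get_datasets() follows baal's documented pattern of wrapping a
# torchvision dataset in an ActiveLearningDataset and labelling a random
# initial pool.
import argparse

from torchvision import transforms
from torchvision.datasets import CIFAR10

from baal.active import ActiveLearningDataset


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--initial_pool", type=int, default=1000)
    parser.add_argument("--n_data_to_label", type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--heuristic", type=str, default="bald")
    parser.add_argument("--iterations", type=int, default=20)
    parser.add_argument("--shuffle_prop", type=float, default=0.05)
    parser.add_argument("--learning_epoch", type=int, default=20)
    return parser.parse_args()


def get_datasets(initial_pool):
    transform = transforms.ToTensor()
    train_ds = CIFAR10(".", train=True, transform=transform, download=True)
    test_set = CIFAR10(".", train=False, transform=transform, download=True)
    # Wrap the training set so baal can track labelled vs. unlabelled items,
    # then label a random initial pool to bootstrap training.
    active_set = ActiveLearningDataset(train_ds)
    active_set.label_randomly(initial_pool)
    return active_set, test_set


if __name__ == "__main__":
    main()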