Example #1
0
class ModelWrapperTest(unittest.TestCase):
    """Tests for ``ModelWrapper`` built around a dummy model and dataset."""

    def setUp(self):
        """Create a fresh model, loss, wrapper, optimizer and dataset."""
        self.model = DummyModel()
        self.criterion = nn.BCEWithLogitsLoss()
        self.wrapper = ModelWrapper(self.model, self.criterion)
        self.optim = torch.optim.SGD(self.wrapper.get_params(), 0.01)
        self.dataset = DummyDataset()

    def test_train_on_batch(self):
        """A training step updates parameters; the reset helpers re-draw weights."""
        self.wrapper.train()
        before = [p.clone() for p in self.model.parameters()]
        batch, target = torch.randn([1, 3, 10, 10]), torch.randn(1, 1)
        self.wrapper.train_on_batch(batch, target, self.optim)
        after = [p.clone() for p in self.model.parameters()]

        # At least one parameter tensor must have moved.
        assert any(not torch.allclose(a, b) for a, b in zip(before, after))

        def child_weight(position):
            # Snapshot the weight of the child module at `position`.
            children = list(self.wrapper.model.named_children())
            return children[position][1].weight.clone()

        # Child 3 is treated as the linear layer and child 0 as the conv
        # layer (presumably matching DummyModel's layout).
        linear_before = child_weight(3)
        conv_before = child_weight(0)
        self.wrapper.reset_fcs()
        linear_after = child_weight(3)
        conv_after = child_weight(0)

        # reset_fcs: every row of the linear weight changes, conv is untouched.
        assert all(
            not torch.allclose(new_row, old_row)
            for new_row, old_row in zip(linear_after, linear_before)
        )
        assert all(
            torch.allclose(new_row, old_row)
            for new_row, old_row in zip(conv_after, conv_before)
        )

        # reset_all: the conv weight changes as well.
        self.wrapper.reset_all()
        conv_after_full_reset = child_weight(0)
        assert all(
            not torch.allclose(prev, curr)
            for prev, curr in zip(conv_after, conv_after_full_reset)
        )

    def test_test_on_batch(self):
        """In eval mode repeated ``test_on_batch`` calls return the same value."""
        self.wrapper.eval()
        batch, target = torch.randn([1, 3, 10, 10]), torch.randn(1, 1)

        def stacked_results(**extra):
            # Ten evaluations of the same batch, one row per run.
            runs = [
                self.wrapper.test_on_batch(batch, target, cuda=False, **extra)
                for _ in range(10)
            ]
            return torch.stack(runs).view(10, -1)

        # Same loss on every run, so the mean equals the first row.
        results = stacked_results()
        assert torch.allclose(torch.mean(results, 0), results[0])

        results = stacked_results(average_predictions=10)
        assert torch.allclose(torch.mean(results, 0), results[0])

    def test_predict_on_batch(self):
        """The last prediction axis equals the requested iteration count."""
        self.wrapper.eval()
        batch = torch.randn([2, 3, 10, 10])

        def check_output_sizes():
            # One iteration first, then several.
            for n_iterations in (1, 10):
                pred = self.wrapper.predict_on_batch(batch, n_iterations,
                                                     False)
                assert pred.size() == (2, 1, n_iterations)

        check_output_sizes()

        # Repeat the same checks without replicating the batch in memory.
        self.wrapper = ModelWrapper(self.model,
                                    self.criterion,
                                    replicate_in_memory=False)
        check_output_sizes()

    def test_train(self):
        """``train_on_dataset`` returns one history entry per epoch."""
        n_epochs = 2
        history = self.wrapper.train_on_dataset(self.dataset,
                                                self.optim,
                                                10,
                                                n_epochs,
                                                use_cuda=False,
                                                workers=0)
        assert len(history) == 2

    def test_test(self):
        """``test_on_dataset`` yields a finite value with and without averaging."""
        for extra in ({}, {'average_predictions': 10}):
            loss = self.wrapper.test_on_dataset(self.dataset,
                                                10,
                                                use_cuda=False,
                                                workers=0,
                                                **extra)
            assert np.isfinite(loss)

    def test_predict(self):
        """Dataset-level predictions agree with batch-level and generator ones."""
        preds = self.wrapper.predict_on_dataset(self.dataset,
                                                10,
                                                20,
                                                use_cuda=False,
                                                workers=0)
        self.wrapper.eval()
        for index in (0, 19):
            sample = self.dataset[index][0].unsqueeze(0)
            single = self.wrapper.predict_on_batch(sample, 20)[0]
            assert np.allclose(single.detach().numpy(), preds[index])
        assert preds.shape == (len(self.dataset), 1, 20)

        # Test generators
        batches = self.wrapper.predict_on_dataset_generator(self.dataset,
                                                            10,
                                                            20,
                                                            use_cuda=False,
                                                            workers=0)
        assert np.allclose(next(batches)[0], preds[0])
        for last in batches:
            pass  # Exhaust the generator to get the last batch
        assert np.allclose(last[-1], preds[-1])

        # Test Half
        half_batches = self.wrapper.predict_on_dataset_generator(
            self.dataset, 10, 20, use_cuda=False, workers=0, half=True)
        half_preds = self.wrapper.predict_on_dataset(self.dataset,
                                                     10,
                                                     20,
                                                     use_cuda=False,
                                                     workers=0,
                                                     half=True)
        assert next(half_batches).dtype == np.float16
        assert half_preds.dtype == np.float16

    def test_states(self):
        """Predictions vary in train mode and are stable in eval mode."""
        batch = torch.randn([1, 3, 10, 10])

        def mean_equals_first(replicate_in_memory, training):
            # Rebuild the wrapper, set its mode, and report whether ten
            # single-iteration predictions all agree with the first one.
            self.wrapper = ModelWrapper(
                self.model,
                self.criterion,
                replicate_in_memory=replicate_in_memory)
            if training:
                self.wrapper.train()
            else:
                self.wrapper.eval()
            runs = [
                self.wrapper.predict_on_batch(batch, iterations=1, cuda=False)
                for _ in range(10)
            ]
            stacked = torch.stack(runs).view(10, -1)
            return torch.allclose(torch.mean(stacked, 0), stacked[0])

        # Dropout makes the predictions change in train mode.
        for replicate in (True, False):
            assert not mean_equals_first(replicate, training=True)

        # Dropout is not active in eval mode.
        for replicate in (True, False):
            assert mean_equals_first(replicate, training=False)

    def test_add_metric(self):
        """``add_metric`` registers train/test metrics that then get updated."""
        self.wrapper.add_metric('cls_report', lambda: ClassificationReport(2))
        registered = self.wrapper.metrics
        assert 'test_cls_report' in registered
        assert 'train_cls_report' in registered

        self.wrapper.train_on_dataset(self.dataset, self.optim, 32, 2, False)
        self.wrapper.test_on_dataset(self.dataset, 32, False)

        train_accuracy = self.wrapper.metrics['train_cls_report'].value[
            'accuracy']
        assert (train_accuracy != 0).any()
        test_accuracy = self.wrapper.metrics['test_cls_report'].value[
            'accuracy']
        assert (test_accuracy != 0).any()

    def test_train_and_test(self):
        """Check return shapes and early stopping of ``train_and_test_on_datasets``."""

        def run(n_epochs, **options):
            # Train and test on the same dataset with a batch size of 32.
            return self.wrapper.train_and_test_on_datasets(
                self.dataset, self.dataset, self.optim, 32, n_epochs, False,
                **options)

        def fake_losses():
            # Parabola with its minimum around the 10th value.
            return (((np.linspace(0, 50) - 10) / 10)**2).tolist()

        # Without best weights only the history is returned.
        result = run(5, return_best_weights=False)
        assert len(result) == 5

        # With best weights a (history, state dict) pair is returned.
        result = run(5, return_best_weights=True)
        assert len(result) == 2
        assert len(result[0]) == 5
        assert isinstance(result[1], dict)

        # Early stopping: the mocked test loss starts rising, so training
        # must end before the 50 requested epochs.
        loss_mock = Mock()
        loss_mock.side_effect = fake_losses()
        self.wrapper.test_on_dataset = loss_mock
        result = run(50, return_best_weights=True, patience=1)

        assert len(result) == 2
        assert len(result[0]) < 50

        # Same scenario, but early stopping may not trigger before epoch 20.
        loss_mock = Mock()
        loss_mock.side_effect = fake_losses()
        self.wrapper.test_on_dataset = loss_mock
        result = run(50,
                     return_best_weights=True,
                     patience=1,
                     min_epoch_for_es=20)
        assert len(result) == 2
        assert 20 < len(result[0]) < 50
Example #2
0
def main():
    """Run the active-learning experiment end to end.

    Parses the command-line arguments, builds a 10-class VGG16 initialised
    from the downloaded checkpoint (minus the ``classifier.6`` entries),
    patches it, wraps it in a ``ModelWrapper`` and alternates training,
    validation and active-learning steps, printing a log dict per epoch.
    """
    args = parse_args()
    use_cuda = torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    random.seed(1337)
    torch.manual_seed(1337)
    if not use_cuda:
        print("warning, the experiments would take ages to run on cpu")

    config = vars(args)

    active_set, test_set = get_datasets(config['initial_pool'])

    heuristic = get_heuristic(config['heuristic'], config['shuffle_prop'])
    criterion = CrossEntropyLoss()

    # Load every checkpoint entry except the 'classifier.6' ones;
    # strict=False lets the keys missing from the filtered dict be skipped.
    network = vgg16(pretrained=False, num_classes=10)
    checkpoint = load_state_dict_from_url(
        'https://download.pytorch.org/models/vgg16-397923af.pth')
    checkpoint = {
        name: tensor
        for name, tensor in checkpoint.items()
        if 'classifier.6' not in name
    }
    network.load_state_dict(checkpoint, strict=False)

    # change dropout layer to MCDropout
    network = patch_module(network)

    if use_cuda:
        network.cuda()
    optimizer = optim.SGD(network.parameters(),
                          lr=config["lr"],
                          momentum=0.9)

    # Wraps the model into a usable API.
    wrapper = ModelWrapper(network, criterion)

    logs = {'epoch': 0}

    # for prediction we use a smaller batchsize
    # since it is slower
    active_loop = ActiveLearningLoop(active_set,
                                     wrapper.predict_on_dataset,
                                     heuristic,
                                     config.get('n_data_to_label', 1),
                                     batch_size=10,
                                     iterations=config['iterations'],
                                     use_cuda=use_cuda)

    for epoch in tqdm(range(args.epoch)):
        # One training epoch on the currently labelled data.
        wrapper.train_on_dataset(active_set, optimizer, config["batch_size"],
                                 1, use_cuda)

        # Validation!
        wrapper.test_on_dataset(test_set, config["batch_size"], use_cuda)
        metrics = wrapper.metrics

        # Every `learning_epoch` epochs: run an active-learning step and
        # reset the fully-connected layers. Stop when the step returns a
        # falsy value.
        if epoch % config['learning_epoch'] == 0:
            keep_going = active_loop.step()
            wrapper.reset_fcs()
            if not keep_going:
                break

        val_loss = metrics['test_loss'].value
        logs = {
            "val": val_loss,
            "epoch": epoch,
            "train": metrics['train_loss'].value,
            "labeled_data": active_set._labelled,
            "Next Training set size": len(active_set)
        }
        print(logs)