Example 1
    def test_1_model_checkpoint_saving(self):
        """Tests model.save()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)

        # remove existing checkpoints for the sake of testing
        os.system('rm -rf {}/*.pth.tar'.format(tester_cfg.checkpoint_dir))

        # set epochs to 4 so that both regular and best checkpoints get saved
        tester_cfg.model['epochs'] = 4

        # do not have to load existing checkpoints
        load_cfg = {
            'version': None,
            'epoch': -1,
            'load_best': False,
            'resume_optimizer': False
        }
        tester_cfg.model['load'] = load_cfg

        # save a regular checkpoint every two epochs, plus the best model
        save_cfg = {'period': 2, 'monitor': 'precision', 'monitor_mode': 'max'}
        tester_cfg.model['save'] = save_cfg

        classifier = BinaryClassificationModel(tester_cfg)
        classifier.fit(debug=True, use_wandb=False)

        # check both the best and the regular checkpoints
        saved_models = [
            'best_ckpt.pth.tar', '1_ckpt.pth.tar', '3_ckpt.pth.tar'
        ]
        for saved_model in saved_models:
            model_path = join(tester_cfg.checkpoint_dir, saved_model)
            self.assertTrue(exists(model_path))
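
The expected file names follow from the configuration above: with epochs set to 4 and a save period of 2, a regular checkpoint lands on zero-indexed epochs 1 and 3, in addition to the single best checkpoint tracked via the monitored metric. A quick sketch that derives the same list (the zero-indexed '<epoch>_ckpt.pth.tar' naming is inferred from the assertion itself, not from the library code):

epochs, period = 4, 2

# Regular checkpoints every `period` epochs (epochs assumed zero-indexed),
# plus the single best checkpoint selected by the monitored metric.
expected = ['best_ckpt.pth.tar'] + [
    '{}_ckpt.pth.tar'.format(epoch)
    for epoch in range(epochs)
    if (epoch + 1) % period == 0
]
print(expected)  # ['best_ckpt.pth.tar', '1_ckpt.pth.tar', '3_ckpt.pth.tar']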
Example 2
    def test_1_model_fitting(self):
        """Test model.fit()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['epochs'] = 1
        classifier = BinaryClassificationModel(tester_cfg)
        classifier.fit(debug=True, use_wandb=False)
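
All of these examples reference a shared self.cfg plus helpers such as set_logger, get_dataloader, and BinaryClassificationModel without showing the surrounding test class. A minimal sketch of the fixture they appear to assume is below; the import paths and the read_config helper are hypothetical placeholders inferred from how the snippets use those names.

import os
import unittest
from copy import deepcopy
from os.path import exists, join

import torch
from torch import optim
from tqdm import tqdm

# Hypothetical import paths: the real module layout is not shown in these
# examples, only the names the test methods use.
from project.models import BinaryClassificationModel
from project.data import get_dataloader
from project.utils import read_config, set_logger  # read_config is assumed


class TestBinaryClassificationModel(unittest.TestCase):
    """Shared fixture that the test methods above copy via deepcopy(self.cfg)."""

    @classmethod
    def setUpClass(cls):
        # cfg is expected to expose log_dir, checkpoint_dir, data and a model
        # dict with 'epochs', 'batch_size', 'load' and 'save' entries.
        cls.cfg = read_config('configs/default.yml')  # hypothetical config path


if __name__ == '__main__':
    unittest.main()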
Example 3
    def test_compute_metrics_recall_none(self):
        """Tests minimum recall not specified"""
        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['epochs'] = 1
        classifier = BinaryClassificationModel(tester_cfg)

        predictions = torch.Tensor([0.6, 0.4, 0.3, 0.1, 0.8, 0.9])
        targets = torch.Tensor([1, 0, 1, 0, 1, 1])
        metrics = classifier.compute_epoch_metrics(predictions, targets)
        self.assertEqual(metrics['recall'], 1)
Example 4
    def test_compute_metrics_threshold_none(self):
        """Tests no threshold specified"""
        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['epochs'] = 1
        classifier = BinaryClassificationModel(tester_cfg)

        predictions = torch.Tensor([1, -1, 0.5])
        targets = torch.Tensor([1, 0, 1])
        metrics = classifier.compute_epoch_metrics(predictions, targets)
        self.assertEqual(metrics['recall'], 1)
        self.assertEqual(metrics['precision'], 1)
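
The assertions in the two metric tests above pin down the resulting values but not how compute_epoch_metrics picks its decision threshold. For reference, the sketch below computes precision and recall for a fixed threshold with plain torch; the threshold of 0 is an arbitrary illustration that reproduces the expected values for the second test, not the library's actual selection logic.

import torch


def metrics_at_threshold(predictions, targets, threshold):
    """Precision and recall for a fixed decision threshold (illustrative only)."""
    predicted = (predictions >= threshold).float()
    tp = ((predicted == 1) & (targets == 1)).sum().item()
    fp = ((predicted == 1) & (targets == 0)).sum().item()
    fn = ((predicted == 0) & (targets == 1)).sum().item()
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


predictions = torch.Tensor([1, -1, 0.5])
targets = torch.Tensor([1, 0, 1])
print(metrics_at_threshold(predictions, targets, 0.0))  # (1.0, 1.0)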
Example 5
    def test_2_evaluate(self):
        """Test model.evaluate()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['load']['version'] = 'default'
        tester_cfg.model['load']['load_best'] = True
        model = BinaryClassificationModel(tester_cfg)
        dataloader, _ = get_dataloader(
            tester_cfg.data, 'val',
            tester_cfg.model['batch_size'],
            num_workers=4,
            shuffle=False,
            drop_last=False)
        model.evaluate(dataloader, 'val', False)
Example 6
    def test_3_model_checkpoint_loading_best_epoch(self):
        """Tests loading the best saved checkpoint"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)

        # load the best checkpoint from the 'default' version
        load_cfg = {
            'version': 'default',
            'epoch': -1,
            'load_best': True,
            'resume_optimizer': False,
            'resume_epoch': False
        }
        tester_cfg.model['load'] = load_cfg

        classifier = BinaryClassificationModel(tester_cfg)

        # check that the loaded parameters match the saved values
        network_state = classifier.network.get_state_dict()
        load_path = classifier.checkpoint.get_saved_checkpoint_path(
            classifier.checkpoint_dir, load_cfg['load_best'],
            load_cfg['epoch'])
        self.assertIn('best_ckpt', load_path)
        saved_state = torch.load(load_path)['network']

        for key in tqdm(network_state.keys(), desc='Testing params'):
            if key.endswith('weight'):
                network_params = network_state[key]
                saved_params = saved_state[key]
                self.assertTrue(
                    bool(torch.all(torch.eq(saved_params, network_params))))
Example 7
    def test_optimizer(self):
        """Test model.fit()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['epochs'] = 1
        classifier = BinaryClassificationModel(tester_cfg)
        self.assertIsInstance(classifier.optimizer, optim.SGD)
        self.assertIsInstance(
            classifier.scheduler, optim.lr_scheduler.ReduceLROnPlateau)
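
The final assertions only check the types of the optimizer and scheduler. For reference, a plain-PyTorch pair that would satisfy them is sketched below; the network, learning rate, and scheduler hyperparameters are arbitrary placeholders rather than values taken from the project configuration.

import torch
from torch import nn, optim

network = nn.Linear(16, 1)  # stand-in for the classifier's network

# Any SGD + ReduceLROnPlateau pair satisfies the isinstance checks above;
# the hyperparameters here are placeholders, not the project's settings.
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.1, patience=2)

assert isinstance(optimizer, optim.SGD)
assert isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)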