Code example #1
    def test_3_model_checkpoint_loading_best_epoch(self):
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)

        # load the best checkpoint saved under the 'default' version
        load_cfg = {
            'version': 'default',
            'epoch': -1,
            'load_best': True,
            'resume_optimizer': False,
            'resume_epoch': False
        }
        tester_cfg.model['load'] = load_cfg

        classifier = BinaryClassificationModel(tester_cfg)

        # checking if the loaded params are indeed the same as saved
        network_state = classifier.network.get_state_dict()
        load_path = classifier.checkpoint.get_saved_checkpoint_path(
            classifier.checkpoint_dir, load_cfg['load_best'],
            load_cfg['epoch'])
        self.assertIn('best_ckpt', load_path)
        saved_state = torch.load(load_path)['network']

        for key in tqdm(network_state.keys(), desc='Testing params'):
            if key.endswith('weight'):
                network_params = network_state[key]
                saved_params = saved_state[key]
                self.assertTrue(
                    bool(torch.all(torch.eq(saved_params, network_params))))
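
The test above reads the weights back with `torch.load(load_path)['network']` and expects the best checkpoint to live in a file whose name contains `best_ckpt`. Below is a minimal sketch of a checkpoint writer that would satisfy those expectations; the helper and its exact keys are assumptions for illustration, not the project's actual saving code.

import torch

def save_checkpoint(network, optimizer, epoch, path):
    # hypothetical helper: store the network weights under the 'network'
    # key so that torch.load(path)['network'] returns the state dict,
    # as the test above expects
    state = {
        'network': network.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(state, path)

# e.g. save_checkpoint(model.network, model.optimizer, epoch,
#                      join(checkpoint_dir, 'best_ckpt.pth.tar'))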
Code example #2
    def test_1_model_checkpoint_saving(self):
        """Tests model.save()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)

        # remove existing checkpoints for sake of testing
        os.system('rm -rf {}/*.pth.tar'.format(tester_cfg.checkpoint_dir))

        # set epochs to 4 in order to test saving best/regular models
        tester_cfg.model['epochs'] = 4

        # do not have to load existing checkpoints
        load_cfg = {
            'version': None,
            'epoch': -1,
            'load_best': False,
            'resume_optimizer': False
        }
        tester_cfg.model['load'] = load_cfg

        # saving after every two epochs and the best model
        save_cfg = {'period': 2, 'monitor': 'precision', 'monitor_mode': 'max'}
        tester_cfg.model['save'] = save_cfg

        classifier = BinaryClassificationModel(tester_cfg)
        classifier.fit(debug=True, use_wandb=False)

        # checking both best as well as regular checkpoints
        saved_models = [
            'best_ckpt.pth.tar', '1_ckpt.pth.tar', '3_ckpt.pth.tar'
        ]
        for saved_model in saved_models:
            model_path = join(tester_cfg.checkpoint_dir, saved_model)
            self.assertTrue(exists(model_path))
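
The expected filenames follow directly from the save config: with 4 epochs (indexed 0-3) and `period = 2`, a periodic checkpoint is written after epochs 1 and 3, and `best_ckpt.pth.tar` is rewritten whenever the monitored metric ('precision', mode 'max') improves. A rough sketch of that naming rule, assuming 0-indexed epochs (the rule itself is inferred from the assertions above):

epochs = 4
period = 2

# periodic checkpoints: every `period`-th epoch, named '<epoch>_ckpt.pth.tar'
periodic = [
    '{}_ckpt.pth.tar'.format(epoch)
    for epoch in range(epochs)
    if (epoch + 1) % period == 0
]
assert periodic == ['1_ckpt.pth.tar', '3_ckpt.pth.tar']

# the best checkpoint always reuses the same filename
best = 'best_ckpt.pth.tar'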
Code example #3
    def test_1_model_fitting(self):
        """Test model.fit()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['epochs'] = 1
        classifier = BinaryClassificationModel(tester_cfg)
        classifier.fit(debug=True, use_wandb=False)
Code example #4
    def test_optimizer(self):
        """Test model.fit()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['epochs'] = 1
        classifier = BinaryClassificationModel(tester_cfg)
        self.assertIsInstance(classifier.optimizer, optim.SGD)
        self.assertIsInstance(
            classifier.scheduler, optim.lr_scheduler.ReduceLROnPlateau)
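
The assertions imply that the default config builds an SGD optimizer wrapped in a ReduceLROnPlateau scheduler. A minimal sketch of how such a config might be consumed; the dict keys and hyperparameter values below are assumptions, not the project's actual config schema:

import torch.nn as nn
import torch.optim as optim

network = nn.Linear(10, 1)  # stand-in for the real network

optim_cfg = {'name': 'SGD', 'lr': 0.01, 'momentum': 0.9}   # assumed values
sched_cfg = {'mode': 'max', 'factor': 0.1, 'patience': 2}  # assumed values

optimizer = optim.SGD(network.parameters(),
                      lr=optim_cfg['lr'],
                      momentum=optim_cfg['momentum'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **sched_cfg)

# after each validation pass, step the scheduler on the monitored metric:
# scheduler.step(val_metric)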
Code example #5
    def test_2_evaluate(self):
        """Test model.evaluate()"""
        set_logger(join(self.cfg.log_dir, 'train.log'))

        tester_cfg = deepcopy(self.cfg)
        tester_cfg.model['load']['version'] = 'default'
        tester_cfg.model['load']['load_best'] = True
        model = BinaryClassificationModel(tester_cfg)
        dataloader, _ = get_dataloader(
            tester_cfg.data, 'val',
            tester_cfg.model['batch_size'],
            num_workers=4,
            shuffle=False,
            drop_last=False)
        model.evaluate(dataloader, 'val', False)
Code example #6
def main(args):
    seed_everything(args.seed)
    config = Config(args.version)

    set_logger(join(config.log_dir, 'train.log'))
    logging.info(args)

    if args.wandb:
        os.environ['WANDB_ENTITY'] = config.entity
        os.environ['WANDB_PROJECT'] = config.project
        os.environ['WANDB_DIR'] = dirname(config.checkpoint_dir)

        run_name = args.version.replace('/', '_')
        wandb.init(name=run_name,
                   dir=dirname(config.checkpoint_dir),
                   notes=config.description,
                   resume=args.resume,
                   id=args.id)
        wandb.config.update(config.__dict__,
                            allow_val_change=config.allow_val_change)

    config.num_workers = args.num_workers
    train(config, args.debug, args.overfit_batch, args.wandb)
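
main() only consumes the parsed arguments, so the script presumably wires it up with argparse. A plausible setup covering every attribute referenced above (flag names, types and defaults are guesses; only the attribute names come from the code):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Trains a model')
    parser.add_argument('-v', '--version', required=True,
                        help='config version to train')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--overfit_batch', action='store_true')
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--id', default=None, help='wandb run id to resume')
    main(parser.parse_args())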
Code example #7
def main(args):
    version = args.version
    config = Config(version)
    version = splitext(version)[0]

    set_logger(join(config.log_dir, 'eval.log'))
    logging.info(args)

    if args.bs is not None:
        config.model['batch_size'] = args.bs

    # add checkpoint loading values
    load_epoch = args.epoch
    load_best = args.best
    config.model['load']['version'] = version
    config.model['load']['epoch'] = load_epoch
    config.model['load']['load_best'] = load_best

    # ensures that the epoch_counter attribute is set to the
    # epoch number being loaded
    config.model['load']['resume_epoch'] = True

    if args.wandb:
        # set up wandb
        os.environ['WANDB_ENTITY'] = config.entity
        os.environ['WANDB_PROJECT'] = config.project
        os.environ['WANDB_DIR'] = dirname(config.checkpoint_dir)

        run_name = '_'.join(['evaluation', version.replace('/', '_')])
        wandb.init(name=run_name,
                   dir=dirname(config.checkpoint_dir),
                   notes=config.description)
        wandb.config.update(config.__dict__)

    config.num_workers = args.num_workers
    evaluate(config, args.mode, args.wandb, args.ignore_cache, args.n_tta)
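
The evaluation entry point can be wired up the same way; again, the flag names and defaults below are guesses, and only the attribute names are taken from the code above:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluates a trained model')
    parser.add_argument('-v', '--version', required=True)
    parser.add_argument('-m', '--mode', default='val',
                        help='data split to evaluate on')
    parser.add_argument('-e', '--epoch', type=int, default=-1,
                        help='epoch whose checkpoint should be loaded')
    parser.add_argument('-b', '--best', action='store_true',
                        help='load the best checkpoint instead of a specific epoch')
    parser.add_argument('--bs', type=int, default=None,
                        help='override the batch size from the config')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--n_tta', type=int, default=0,
                        help='number of test-time augmentations')
    parser.add_argument('--ignore_cache', action='store_true')
    parser.add_argument('--wandb', action='store_true')
    main(parser.parse_args())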
Code example #8
File: SWA.py  Project: Ares2013/coreml
val_preds = val['target'].values
val_labels = val['label'].values
roc = roc_auc_score(val_labels, val_preds)
print(roc)


# In[9]:


# `config_name` is defined in an earlier notebook cell (not shown here)
config = Config(join('/workspace/coreml', config_name + '.yml'))


# In[19]:


set_logger(join(config.log_dir, 'debug.log'))


# In[10]:


val_dataloader, _ = get_dataloader(
        config.data, 'val',
        config.model['batch_size'],
        num_workers=10,
        shuffle=False,
        drop_last=False)


# In[16]: