def test_3_model_checkpoint_loading_best_epoch(self):
    """Tests loading the best saved checkpoint into the model"""
    set_logger(join(self.cfg.log_dir, 'train.log'))
    tester_cfg = deepcopy(self.cfg)

    # load the best checkpoint from the 'default' version
    load_cfg = {
        'version': 'default',
        'epoch': -1,
        'load_best': True,
        'resume_optimizer': False,
        'resume_epoch': False
    }
    tester_cfg.model['load'] = load_cfg
    classifier = BinaryClassificationModel(tester_cfg)

    # check that the loaded params are indeed the same as the saved ones
    network_state = classifier.network.get_state_dict()
    load_path = classifier.checkpoint.get_saved_checkpoint_path(
        classifier.checkpoint_dir, load_cfg['load_best'], load_cfg['epoch'])
    self.assertIn('best_ckpt', load_path)
    saved_state = torch.load(load_path)['network']

    for key in tqdm(network_state.keys(), desc='Testing params'):
        if key.endswith('weight'):
            network_params = network_state[key]
            saved_params = saved_state[key]
            self.assertTrue(
                bool(torch.all(torch.eq(saved_params, network_params))))
def test_1_model_checkpoint_saving(self):
    """Tests model.save()"""
    set_logger(join(self.cfg.log_dir, 'train.log'))
    tester_cfg = deepcopy(self.cfg)

    # remove existing checkpoints for the sake of testing
    os.system('rm -rf {}/*.pth.tar'.format(tester_cfg.checkpoint_dir))

    # set epochs to 4 in order to test saving best/regular checkpoints
    tester_cfg.model['epochs'] = 4

    # do not have to load existing checkpoints
    load_cfg = {
        'version': None,
        'epoch': -1,
        'load_best': False,
        'resume_optimizer': False
    }
    tester_cfg.model['load'] = load_cfg

    # save after every two epochs, plus the best model
    save_cfg = {'period': 2, 'monitor': 'precision', 'monitor_mode': 'max'}
    tester_cfg.model['save'] = save_cfg

    classifier = BinaryClassificationModel(tester_cfg)
    classifier.fit(debug=True, use_wandb=False)

    # check both the best and the regular checkpoints
    saved_models = [
        'best_ckpt.pth.tar', '1_ckpt.pth.tar', '3_ckpt.pth.tar'
    ]
    for saved_model in saved_models:
        model_path = join(tester_cfg.checkpoint_dir, saved_model)
        self.assertTrue(exists(model_path))
def test_1_model_fitting(self):
    """Tests model.fit()"""
    set_logger(join(self.cfg.log_dir, 'train.log'))
    tester_cfg = deepcopy(self.cfg)
    tester_cfg.model['epochs'] = 1
    classifier = BinaryClassificationModel(tester_cfg)
    classifier.fit(debug=True, use_wandb=False)
def test_optimizer(self):
    """Tests that the optimizer and LR scheduler are set up correctly"""
    set_logger(join(self.cfg.log_dir, 'train.log'))
    tester_cfg = deepcopy(self.cfg)
    tester_cfg.model['epochs'] = 1
    classifier = BinaryClassificationModel(tester_cfg)
    self.assertIsInstance(classifier.optimizer, optim.SGD)
    self.assertIsInstance(
        classifier.scheduler, optim.lr_scheduler.ReduceLROnPlateau)
def test_2_evaluate(self):
    """Tests model.evaluate()"""
    set_logger(join(self.cfg.log_dir, 'train.log'))
    tester_cfg = deepcopy(self.cfg)
    tester_cfg.model['load']['version'] = 'default'
    tester_cfg.model['load']['load_best'] = True
    model = BinaryClassificationModel(tester_cfg)
    dataloader, _ = get_dataloader(
        tester_cfg.data, 'val', tester_cfg.model['batch_size'],
        num_workers=4, shuffle=False, drop_last=False)
    model.evaluate(dataloader, 'val', False)
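# ---------------------------------------------------------------------------
# The test methods above assume a unittest.TestCase subclass that prepares a
# shared config as `self.cfg`. The scaffolding below is a minimal sketch of
# that assumed setup: the `coreml.*` import paths and the config filename are
# assumptions for illustration, not the repository's exact values.
# ---------------------------------------------------------------------------
import os
import unittest
from copy import deepcopy
from os.path import join, exists

import torch
from torch import optim
from tqdm import tqdm

# assumed project imports
from coreml.config import Config
from coreml.models import BinaryClassificationModel
from coreml.data.dataloader import get_dataloader
from coreml.utils.logger import set_logger


class BinaryClassificationModelTestCase(unittest.TestCase):
    """Groups the test methods above under one shared config (assumed setup)"""
    @classmethod
    def setUpClass(cls):
        # hypothetical default experiment config used for testing
        cls.cfg = Config('defaults/binary-classification.yml')


if __name__ == '__main__':
    unittest.main()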
def main(args):
    seed_everything(args.seed)
    config = Config(args.version)
    set_logger(join(config.log_dir, 'train.log'))
    logging.info(args)

    if args.wandb:
        os.environ['WANDB_ENTITY'] = config.entity
        os.environ['WANDB_PROJECT'] = config.project
        os.environ['WANDB_DIR'] = dirname(config.checkpoint_dir)
        run_name = args.version.replace('/', '_')
        wandb.init(
            name=run_name,
            dir=dirname(config.checkpoint_dir),
            notes=config.description,
            resume=args.resume,
            id=args.id)
        wandb.config.update(
            config.__dict__, allow_val_change=config.allow_val_change)

    config.num_workers = args.num_workers
    train(config, args.debug, args.overfit_batch, args.wandb)
def main(args):
    version = args.version
    config = Config(version)
    version = splitext(version)[0]
    set_logger(join(config.log_dir, 'eval.log'))
    logging.info(args)

    if args.bs is not None:
        config.model['batch_size'] = args.bs

    # add checkpoint loading values
    load_epoch = args.epoch
    load_best = args.best
    config.model['load']['version'] = version
    config.model['load']['epoch'] = load_epoch
    config.model['load']['load_best'] = load_best

    # ensures that the epoch_counter attribute is set to the
    # epoch number being loaded
    config.model['load']['resume_epoch'] = True

    if args.wandb:
        # set up wandb
        os.environ['WANDB_ENTITY'] = config.entity
        os.environ['WANDB_PROJECT'] = config.project
        os.environ['WANDB_DIR'] = dirname(config.checkpoint_dir)
        run_name = '_'.join(['evaluation', version.replace('/', '_')])
        wandb.init(
            name=run_name,
            dir=dirname(config.checkpoint_dir),
            notes=config.description)
        wandb.config.update(config.__dict__)

    config.num_workers = args.num_workers
    evaluate(config, args.mode, args.wandb, args.ignore_cache, args.n_tta)
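# ---------------------------------------------------------------------------
# Both main() functions above read their inputs from an argparse namespace.
# Below is a minimal sketch of the evaluation entry point; the flag names,
# defaults, and help strings are assumptions inferred from the attributes the
# code accesses (args.version, args.bs, args.epoch, args.best, args.mode,
# args.wandb, args.ignore_cache, args.n_tta, args.num_workers), not the
# repository's exact CLI.
# ---------------------------------------------------------------------------
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate a trained model')
    parser.add_argument('-v', '--version', required=True, type=str,
                        help='path to the experiment config file')
    parser.add_argument('--bs', type=int, default=None,
                        help='override the batch size from the config')
    parser.add_argument('-e', '--epoch', type=int, default=-1,
                        help='epoch of the checkpoint to load')
    parser.add_argument('-b', '--best', action='store_true',
                        help='load the best checkpoint instead of a given epoch')
    parser.add_argument('-m', '--mode', type=str, default='val',
                        help='data split to evaluate on')
    parser.add_argument('--wandb', action='store_true',
                        help='log results to Weights & Biases')
    parser.add_argument('--ignore-cache', action='store_true',
                        help='recompute predictions even if cached')
    parser.add_argument('--n-tta', type=int, default=0,
                        help='number of test-time augmentation passes')
    parser.add_argument('-n', '--num_workers', type=int, default=4,
                        help='number of dataloader workers')
    args = parser.parse_args()
    main(args)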
val_preds = val['target'].values
val_labels = val['label'].values
roc = roc_auc_score(val_labels, val_preds)
print(roc)


# In[9]:

config = Config(join('/workspace/coreml', config_name + '.yml'))


# In[19]:

set_logger(join(config.log_dir, 'debug.log'))


# In[10]:

val_dataloader, _ = get_dataloader(
    config.data, 'val', config.model['batch_size'],
    num_workers=10, shuffle=False, drop_last=False)


# In[16]: