def log_metrics(boardlogger, train_loss, val_loss, val_pred, val_gt):
    """Print, tensorboard-log and persist the losses/metrics of the current epoch.

    Args:
        boardlogger: tensorboard SummaryWriter for the current fold.
        train_loss: average training loss of the epoch.
        val_loss: average validation loss of the epoch.
        val_pred: validation predictions, forwarded to ``compute_metrics``.
        val_gt: validation ground truth, forwarded to ``compute_metrics``.
    """
    kvs = GlobalKVS()
    epoch = kvs['cur_epoch']
    fold = kvs['cur_fold']

    print(colored('==> ', 'green') + f'Train loss: {train_loss:.4f} / Val loss: {val_loss:.4f}')

    # Epoch summary: losses plus whatever compute_metrics reports
    res = {'epoch': epoch, 'val_loss': val_loss}
    res.update(compute_metrics(val_pred, val_gt, no_kl=kvs['args'].no_kl))

    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, epoch)
    # Only the kappa-* entries go to the tensorboard metrics plot
    kappa_metrics = {metric: res[metric] for metric in res if metric.startswith('kappa')}
    boardlogger.add_scalars('Metrics', kappa_metrics, epoch)

    kvs.update(f'losses_fold_[{fold}]',
               {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})
    kvs.update(f'val_metrics_fold_[{fold}]', res)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def save_checkpoint(model, optimizer):
    """Save a model/optimizer snapshot when the monitored validation metric improves.

    The first snapshot of a fold is always saved; afterwards a new snapshot is
    written only when ``snapshot_comparator`` (an ``operator`` function name,
    e.g. ``gt``/``lt``) says the new metric beats the stored best.

    Args:
        model: the network being trained (possibly wrapped, unwrapped via ``net_core``).
        optimizer: its optimizer; state is stored alongside the weights.
    """
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    # Latest metrics entry for this fold; snapshot_on selects which metric to monitor.
    # NOTE(review): the [-1][0] indexing assumes each stored entry is a sequence
    # whose first element is the metrics dict — confirm against log_metrics' writer.
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][kvs['args'].snapshot_on]
    comparator = getattr(operator, kvs['args'].snapshot_comparator)
    cur_snapshot_name = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch+1}.pth')

    if kvs['prev_model'] is None:
        # First epoch of the fold: always save.
        _store_snapshot(kvs, model, optimizer, epoch, cur_snapshot_name, val_metric)
    elif comparator(val_metric, kvs['best_val_metric']):
        # Metric improved: optionally drop the previous snapshot, then save.
        if not kvs['args'].keep_snapshots:
            os.remove(kvs['prev_model'])
        _store_snapshot(kvs, model, optimizer, epoch, cur_snapshot_name, val_metric)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))


def _store_snapshot(kvs, model, optimizer, epoch, cur_snapshot_name, val_metric):
    """Serialize the snapshot to disk and record it as the current best in KVS."""
    print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
    torch.save({'epoch': epoch,
                'net': net_core(model).state_dict(),
                'optim': optimizer.state_dict()},
               cur_snapshot_name)
    kvs.update('prev_model', cur_snapshot_name)
    kvs.update('best_val_metric', val_metric)
def init_metadata():
    """Build or load the cached OAI/MOST metadata, filter KL grades and set up CV folds.

    Side effects: stores ``most_meta``, ``oai_meta`` and ``cv_split_all_folds``
    in the global KVS and persists the session pickle.
    """
    kvs = GlobalKVS()
    snapshots_dir = kvs['args'].snapshots
    oai_cache = os.path.join(snapshots_dir, 'oai_meta.pkl')
    most_cache = os.path.join(snapshots_dir, 'most_meta.pkl')

    if not os.path.isfile(oai_cache):
        print('==> Cached metadata is not found. Generating...')
        oai_meta, most_meta = build_dataset_meta(kvs['args'])
        oai_meta.to_pickle(oai_cache, compression='infer', protocol=4)
        most_meta.to_pickle(most_cache, compression='infer', protocol=4)
    else:
        print('==> Loading cached metadata...')
        oai_meta = pd.read_pickle(oai_cache)
        most_meta = pd.read_pickle(most_cache)

    # Keep only valid KL grades (0..4)
    most_meta = most_meta[(most_meta.XRKL >= 0) & (most_meta.XRKL <= 4)]
    oai_meta = oai_meta[(oai_meta.XRKL >= 0) & (oai_meta.XRKL <= 4)]
    print(colored('==> ', 'green') + 'Images in OAI:', oai_meta.shape[0])
    print(colored('==> ', 'green') + 'Images in MOST:', most_meta.shape[0])

    kvs.update('most_meta', most_meta)
    kvs.update('oai_meta', oai_meta)

    # Group-aware CV split: the same subject ID never spans train and val.
    # list(...) replaces the original no-op comprehension [x for x in ...].
    train_meta = kvs[kvs['args'].train_set + '_meta']
    gkf = GroupKFold(kvs['args'].n_folds)
    cv_split = list(gkf.split(train_meta, groups=train_meta['ID'].values))
    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(snapshots_dir, kvs['snapshot_name'], 'session.pkl'))
def init_data_processing():
    """Estimate the dataset mean/std and store the final train/val transforms in KVS."""
    kvs = GlobalKVS()
    args = kvs['args']

    # Transforms without normalization — only needed to iterate the data once
    # while estimating the normalization statistics.
    unnormalized_trf, _ = init_transforms(None, None)
    dataset = OARSIGradingDataset(kvs[f'{args.train_set}_meta'], unnormalized_trf)
    mean_vector, std_vector = init_mean_std(snapshots_dir=args.snapshots,
                                            dataset=dataset,
                                            batch_size=args.bs,
                                            n_threads=args.n_threads)
    print(colored('====> ', 'red') + 'Mean:', mean_vector)
    print(colored('====> ', 'red') + 'Std:', std_vector)
    kvs.update('mean_vector', mean_vector)
    kvs.update('std_vector', std_vector)

    # Re-create the transforms, this time with normalization baked in.
    train_trf, val_trf = init_transforms(mean_vector, std_vector)
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_folds():
    """Register per-fold loss/metric stores in KVS and create tensorboard writers.

    Returns:
        dict mapping fold id -> SummaryWriter for the folds that will be trained.
    """
    kvs = GlobalKVS()
    target_fold = kvs['args'].fold
    writers = {}
    cv_split_train = {}
    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        # When a specific fold was requested (!= -1), skip all the others.
        if target_fold != -1 and fold_id != target_fold:
            continue
        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split
        log_dir = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                               'logs', 'fold_{}'.format(fold_id), kvs['snapshot_name'])
        writers[fold_id] = SummaryWriter(log_dir)
    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
    return writers
def init_session():
    """Parse args, seed the RNGs, create the snapshot dir and record environment info.

    Returns:
        Tuple ``(args, snapshot_name)``.

    Raises:
        EnvironmentError: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')
    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()

    # Initializing the seeds for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Creating the snapshot directory (timestamped name)
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)
    # CUDA availability was already enforced at the top of this function, so the
    # former `if torch.cuda.is_available(): ... else: ...` here was dead code;
    # the else-branch (storing None) could never execute and has been removed.
    kvs.update('cuda', torch.version.cuda)
    kvs.update('gpus', torch.cuda.device_count())

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))
    return args, snapshot_name
cv2.setNumThreads(0) DEBUG = sys.gettrace() is not None if __name__ == "__main__": kvs = GlobalKVS() session.init_session() session.init_metadata() writers = session.init_folds() session.init_data_processing() for fold_id in kvs['cv_split_train']: if kvs['args'].fold != -1 and fold_id != kvs['args'].fold: continue kvs.update('cur_fold', fold_id) kvs.update('prev_model', None) print(colored('====> ', 'blue') + f'Training fold {fold_id}....') train_index, val_index = kvs['cv_split_train'][fold_id] train_loader, val_loader = session.init_loaders(kvs[f'{kvs["args"].train_set}_meta'].iloc[train_index], kvs[f'{kvs["args"].train_set}_meta'].iloc[val_index]) net, criterion = utils.init_model() if kvs['args'].pretrained: net.train(False) utils.net_core(net).classifier.train(True) optimizer = utils.init_optimizer(utils.layer_params(net, 'classifier')) else: print(colored('====> ', 'red') + 'The model will be trained from scratch!')