def init_metadata():
    kvs = GlobalKVS()

    if not os.path.isfile(os.path.join(kvs['args'].snapshots, 'oai_meta.pkl')):
        print('==> Cached metadata is not found. Generating...')
        oai_meta, most_meta = build_dataset_meta(kvs['args'])
        oai_meta.to_pickle(os.path.join(kvs['args'].snapshots, 'oai_meta.pkl'),
                           compression='infer', protocol=4)
        most_meta.to_pickle(os.path.join(kvs['args'].snapshots, 'most_meta.pkl'),
                            compression='infer', protocol=4)
    else:
        print('==> Loading cached metadata...')
        oai_meta = pd.read_pickle(os.path.join(kvs['args'].snapshots, 'oai_meta.pkl'))
        most_meta = pd.read_pickle(os.path.join(kvs['args'].snapshots, 'most_meta.pkl'))

    most_meta = most_meta[(most_meta.XRKL >= 0) & (most_meta.XRKL <= 4)]
    oai_meta = oai_meta[(oai_meta.XRKL >= 0) & (oai_meta.XRKL <= 4)]

    print(colored('==> ', 'green') + 'Images in OAI:', oai_meta.shape[0])
    print(colored('==> ', 'green') + 'Images in MOST:', most_meta.shape[0])

    kvs.update('most_meta', most_meta)
    kvs.update('oai_meta', oai_meta)

    gkf = GroupKFold(kvs['args'].n_folds)
    cv_split = [x for x in gkf.split(kvs[kvs['args'].train_set + '_meta'],
                                     groups=kvs[kvs['args'].train_set + '_meta']['ID'].values)]
    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
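# Illustrative sketch only (toy data, hypothetical helper name): GroupKFold with
# subject IDs as groups guarantees that every subject ends up on exactly one side
# of each split, which is what init_metadata() above relies on.
def _demo_group_kfold_by_subject():
    import pandas as pd
    from sklearn.model_selection import GroupKFold

    toy = pd.DataFrame({'ID': ['A', 'A', 'B', 'B', 'C', 'C'],
                        'XRKL': [0, 1, 2, 2, 3, 4]})
    for tr_idx, val_idx in GroupKFold(n_splits=3).split(toy, groups=toy['ID'].values):
        # No subject appears in both the train and the validation indices
        assert not (set(toy.iloc[tr_idx].ID) & set(toy.iloc[val_idx].ID))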
def log_metrics(boardlogger, train_loss, val_loss, val_pred, val_gt):
    kvs = GlobalKVS()
    res = {'epoch': kvs['cur_epoch'], 'val_loss': val_loss}
    print(colored('==> ', 'green') + f'Train loss: {train_loss:.4f} / Val loss: {val_loss:.4f}')
    res.update(compute_metrics(val_pred, val_gt, no_kl=kvs['args'].no_kl))

    boardlogger.add_scalars('Losses',
                            {'train': train_loss, 'val': val_loss},
                            kvs['cur_epoch'])
    boardlogger.add_scalars('Metrics',
                            {metric: res[metric] for metric in res if metric.startswith('kappa')},
                            kvs['cur_epoch'])

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]',
               {'epoch': kvs['cur_epoch'], 'train_loss': train_loss, 'val_loss': val_loss})
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_loaders(x_train, x_val):
    kvs = GlobalKVS()
    train_dataset, val_dataset = init_datasets(x_train, x_val)

    if kvs['args'].weighted_sampling:
        if not kvs['args'].mtw:
            print(colored('====> ', 'red') + 'Using weighted sampling (KL)')
            _, weights = make_weights_for_multiclass(x_train.XRKL.values.astype(int))
        else:
            print(colored('====> ', 'red') + 'Using weighted sampling (MTW)')
            cols = ['XROSTL', 'XROSFL', 'XRJSL', 'XROSTM', 'XROSFM', 'XRJSM']
            weights = torch.stack([make_weights_for_multiclass(x_train[col].values.astype(int))[1]
                                   for col in cols], 1).max(1)[0]

        sampler = WeightedRandomSampler(weights, x_train.shape[0], True)
        train_loader = DataLoader(train_dataset, batch_size=kvs['args'].bs,
                                  num_workers=kvs['args'].n_threads,
                                  drop_last=True, sampler=sampler)
    else:
        train_loader = DataLoader(train_dataset, batch_size=kvs['args'].bs,
                                  num_workers=kvs['args'].n_threads,
                                  drop_last=True, shuffle=True)

    val_loader = DataLoader(val_dataset, batch_size=kvs['args'].val_bs,
                            num_workers=kvs['args'].n_threads)

    return train_loader, val_loader
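# Illustrative sketch only: make_weights_for_multiclass is a repository helper;
# the inverse-frequency weighting below is an assumption of what such a helper
# typically returns (per-class weights and the per-sample weights fed to
# WeightedRandomSampler), not its actual implementation.
def _demo_inverse_frequency_weights(labels):
    import numpy as np
    import torch

    labels = np.asarray(labels, dtype=int)
    counts = np.bincount(labels)
    class_w = counts.sum() / np.maximum(counts, 1)  # rarer grade -> larger weight
    sample_w = torch.from_numpy(class_w[labels]).float()
    return class_w, sample_w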
def init_mean_std(snapshots_dir, dataset, batch_size, n_threads):
    kvs = GlobalKVS()
    if os.path.isfile(os.path.join(snapshots_dir, f'mean_std_{kvs["args"].train_set}.npy')):
        tmp = np.load(os.path.join(snapshots_dir, f'mean_std_{kvs["args"].train_set}.npy'))
        mean_vector, std_vector = tmp
    else:
        tmp_loader = DataLoader(dataset, batch_size=batch_size, num_workers=n_threads)
        mean_vector = None
        std_vector = None
        print(colored('==> ', 'green') + 'Calculating mean and std')
        for batch in tqdm(tmp_loader, total=len(tmp_loader)):
            if kvs['args'].siamese:
                imgs = torch.cat((batch['img_med'], batch['img_lat']))
            else:
                imgs = batch['img']

            if mean_vector is None:
                mean_vector = np.zeros(imgs.size(1))
                std_vector = np.zeros(imgs.size(1))

            # Accumulate per-batch channel statistics; the result is the average of
            # batch-wise means/stds, an approximation of the dataset-level values.
            for j in range(mean_vector.shape[0]):
                mean_vector[j] += imgs[:, j, :, :].mean()
                std_vector[j] += imgs[:, j, :, :].std()

        mean_vector /= len(tmp_loader)
        std_vector /= len(tmp_loader)
        np.save(os.path.join(snapshots_dir, f'mean_std_{kvs["args"].train_set}.npy'),
                [mean_vector.astype(np.float32), std_vector.astype(np.float32)])

    return mean_vector, std_vector
def init_folds():
    kvs = GlobalKVS()
    writers = {}
    cv_split_train = {}
    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue
        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split
        writers[fold_id] = SummaryWriter(os.path.join(kvs['args'].snapshots,
                                                      kvs['snapshot_name'],
                                                      'logs',
                                                      'fold_{}'.format(fold_id),
                                                      kvs['snapshot_name']))

    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
    return writers
def init_scheduler(optimizer: Optimizer, epoch_start: int) -> MultiStepLR:
    kvs = GlobalKVS()
    scheduler = MultiStepLR(optimizer,
                            milestones=list(map(lambda x: x - epoch_start, kvs['args'].lr_drop)),
                            gamma=kvs['args'].lr_drop_gamma)
    return scheduler
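# Illustration (values assumed): with --lr_drop 30 40 and a run resumed from
# epoch_start=10, the scheduler receives milestones [30 - 10, 40 - 10] == [20, 30],
# so the learning rate is still reduced at absolute epochs 30 and 40.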
def init_optimizer(params) -> Optimizer:
    kvs = GlobalKVS()
    if kvs['args'].optimizer == 'adam':
        return optim.Adam(params, lr=kvs['args'].lr, weight_decay=kvs['args'].wd)
    elif kvs['args'].optimizer == 'sgd':
        return optim.SGD(params, lr=kvs['args'].lr, weight_decay=kvs['args'].wd,
                         momentum=kvs['args'].momentum, nesterov=kvs['args'].nesterov)
    else:
        raise NotImplementedError
def init_model() -> Tuple[nn.Module, nn.Module]:
    kvs = GlobalKVS()
    if kvs['args'].siamese:
        net = OARSIGradingNetSiamese(backbone=kvs['args'].siamese_bb,
                                     dropout=kvs['args'].dropout_rate)
    else:
        net = OARSIGradingNet(bb_depth=kvs['args'].backbone_depth,
                              dropout=kvs['args'].dropout_rate,
                              cls_bnorm=kvs['args'].use_bnorm,
                              se=kvs['args'].se,
                              dw=kvs['args'].dw,
                              use_gwap=kvs['args'].use_gwap,
                              use_gwap_hidden=kvs['args'].use_gwap_hidden,
                              pretrained=kvs['args'].pretrained,
                              no_kl=kvs['args'].no_kl)

    if kvs['gpus'] > 1:
        net = nn.DataParallel(net).to('cuda')

    return net.to('cuda'), init_loss().to('cuda')
def init_data_processing():
    kvs = GlobalKVS()
    train_trf, val_trf = init_transforms(None, None)

    dataset = OARSIGradingDataset(kvs[f'{kvs["args"].train_set}_meta'], train_trf)
    mean_vector, std_vector = init_mean_std(snapshots_dir=kvs['args'].snapshots,
                                            dataset=dataset,
                                            batch_size=kvs['args'].bs,
                                            n_threads=kvs['args'].n_threads)

    print(colored('====> ', 'red') + 'Mean:', mean_vector)
    print(colored('====> ', 'red') + 'Std:', std_vector)

    kvs.update('mean_vector', mean_vector)
    kvs.update('std_vector', std_vector)

    train_trf, val_trf = init_transforms(mean_vector, std_vector)

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_session():
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Creating the snapshot
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)

    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))

    return args, snapshot_name
def save_checkpoint(model, optimizer):
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][kvs['args'].snapshot_on]
    comparator = getattr(operator, kvs['args'].snapshot_comparator)
    cur_snapshot_name = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch + 1}.pth')

    if kvs['prev_model'] is None:
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        torch.save({'epoch': epoch,
                    'net': net_core(model).state_dict(),
                    'optim': optimizer.state_dict()}, cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)
    else:
        if comparator(val_metric, kvs['best_val_metric']):
            print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
            if not kvs['args'].keep_snapshots:
                os.remove(kvs['prev_model'])
            torch.save({'epoch': epoch,
                        'net': net_core(model).state_dict(),
                        'optim': optimizer.state_dict()}, cur_snapshot_name)
            kvs.update('prev_model', cur_snapshot_name)
            kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
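# Note on the comparator above: args.snapshot_comparator is expected to hold the
# name of a function from the operator module, e.g. 'gt' for metrics where higher
# is better or 'lt' for loss-like metrics. For instance,
# getattr(operator, 'gt')(0.85, 0.80) -> True, so the new snapshot would replace
# the previous best one.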
def init_datasets(x_train, x_val):
    kvs = GlobalKVS()
    train_dataset = OARSIGradingDataset(x_train, kvs['train_trf'])
    val_dataset = OARSIGradingDataset(x_val, kvs['val_trf'])
    return train_dataset, val_dataset
import argparse
import os
import pickle
import sys

import cv2
import pandas as pd

from oarsigrading.kvs import GlobalKVS
from oarsigrading.training.dataset import OARSIGradingDataset
from oarsigrading.evaluation import metrics
from oarsigrading.training.model_zoo import backbone_name
from oarsigrading.training.model import OARSIGradingNet, OARSIGradingNetSiamese
import oarsigrading.evaluation.tta as tta
from oarsigrading.training.transforms import apply_by_index

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

DEBUG = sys.gettrace() is not None

if __name__ == "__main__":
    kvs = GlobalKVS()
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', default='')
    parser.add_argument('--meta_root', default='')
    parser.add_argument('--tta', type=bool, default=False)
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--n_threads', type=int, default=12)
    parser.add_argument('--snapshots_root', default='')
    parser.add_argument('--snapshot', default='')
    parser.add_argument('--save_dir', default='')
    args = parser.parse_args()

    with open(os.path.join(args.snapshots_root, args.snapshot, 'session.pkl'), 'rb') as f:
        session_backup = pickle.load(f)
def init_transforms(mean_vector, std_vector):
    kvs = GlobalKVS()
    if mean_vector is not None:
        mean_vector = torch.from_numpy(mean_vector).float()
        std_vector = torch.from_numpy(std_vector).float()
        norm_trf = partial(normalize_channel_wise, mean=mean_vector, std=std_vector)
        norm_trf = partial(apply_by_index, transform=norm_trf, idx=[0, 1, 2])
    else:
        norm_trf = None

    if kvs['args'].siamese:
        resize_train = slc.Stream()
        crop_train = slt.CropTransform(crop_size=(kvs['args'].imsize, kvs['args'].imsize),
                                       crop_mode='c')
    else:
        resize_train = slt.ResizeTransform((kvs['args'].inp_size, kvs['args'].inp_size))
        crop_train = slt.CropTransform(crop_size=(kvs['args'].crop_size, kvs['args'].crop_size),
                                       crop_mode='r')

    train_trf = [
        wrap2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(kvs['args'].imsize, kvs['args'].imsize)),
            slt.CropTransform(crop_size=(kvs['args'].imsize, kvs['args'].imsize), crop_mode='c'),
            resize_train,
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            crop_train,
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]),
        unpack_solt_data,
        partial(pack_tensors, no_kl=kvs['args'].no_kl),
    ]

    if not kvs['args'].siamese:
        resize_val = slc.Stream([
            slt.ResizeTransform((kvs['args'].inp_size, kvs['args'].inp_size)),
            slt.CropTransform(crop_size=(kvs['args'].crop_size, kvs['args'].crop_size),
                              crop_mode='c'),
        ])
    else:
        resize_val = slc.Stream()

    val_trf = [
        wrap2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(kvs['args'].imsize, kvs['args'].imsize)),
            slt.CropTransform(crop_size=(kvs['args'].imsize, kvs['args'].imsize), crop_mode='c'),
            resize_val,
        ]),
        unpack_solt_data,
        partial(pack_tensors, no_kl=kvs['args'].no_kl),
    ]

    if norm_trf is not None:
        train_trf.append(norm_trf)
        val_trf.append(norm_trf)

    train_trf = transforms.Compose(train_trf)
    val_trf = transforms.Compose(val_trf)

    return train_trf, val_trf
import cv2
import sys

from termcolor import colored

from oarsigrading.kvs import GlobalKVS
from oarsigrading.training import session
from oarsigrading.training import utils
from oarsigrading.evaluation import metrics

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

DEBUG = sys.gettrace() is not None

if __name__ == "__main__":
    kvs = GlobalKVS()
    session.init_session()
    session.init_metadata()
    writers = session.init_folds()
    session.init_data_processing()

    for fold_id in kvs['cv_split_train']:
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue
        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(
            kvs[f'{kvs["args"].train_set}_meta'].iloc[train_index],
            kvs[f'{kvs["args"].train_set}_meta'].iloc[val_index])
def epoch_pass(net: nn.Module, loader: DataLoader, criterion: nn.Module,
               optimizer: Optimizer or None,
               writer: SummaryWriter or None = None
               ) -> float or Tuple[float, List[str], np.ndarray, np.ndarray]:
    kvs = GlobalKVS()

    # Training mode if an optimizer is given, evaluation mode otherwise
    net.train(optimizer is not None)

    running_loss = 0.0
    n_batches = len(loader)
    pbar = tqdm(total=len(loader))
    epoch = kvs['cur_epoch']
    max_epoch = kvs['args'].n_epochs
    fold_id = kvs['cur_fold']
    device = next(net.parameters()).device

    predicts = []
    fnames = []
    gt = []

    with torch.set_grad_enabled(optimizer is not None):
        for i, batch in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()

            # forward + backward + optimize
            labels = batch['target'].squeeze().to(device)
            if kvs['args'].siamese:
                inp_med = batch['img_med'].squeeze().to(device)
                inp_lat = batch['img_lat'].squeeze().to(device)
                outputs = net(inp_med, inp_lat)
            else:
                inputs = batch['img'].squeeze().to(device)
                outputs = net(inputs)

            loss = criterion(outputs, labels)

            if optimizer is not None:
                loss.backward()
                optimizer.step()
                pbar.set_description(f'[{fold_id}] Train:: [{epoch} / {max_epoch}]:: '
                                     f'{running_loss / (i + 1):.3f} | {loss.item():.3f}')
            else:
                # Collect the per-task argmax predictions together with the ground truth
                tmp_preds = np.zeros(batch['target'].squeeze().size(), dtype=np.int64)
                for task_id, o in enumerate(outputs):
                    tmp_preds[:, task_id] = o.to('cpu').squeeze().argmax(1)

                predicts.append(tmp_preds)
                gt.append(batch['target'].to('cpu').numpy().squeeze())
                fnames.extend(batch['ID'])
                pbar.set_description(f'[{fold_id}] Validating [{epoch} / {max_epoch}]:')

            if writer is not None and optimizer is not None:
                writer.add_scalar('train_logs/loss', loss.item(),
                                  kvs['cur_epoch'] * len(loader) + i)

            running_loss += loss.item()
            pbar.update()

        gc.collect()

    gc.collect()
    pbar.close()

    if optimizer is not None:
        return running_loss / n_batches

    return running_loss / n_batches, fnames, np.vstack(gt).squeeze(), np.vstack(predicts)
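# Illustrative call pattern (variable names assumed): the same epoch_pass is used
# for training (optimizer given, returns only the mean loss) and for validation
# (optimizer=None, returns the loss together with the record IDs, the ground truth
# and the per-task predictions), e.g.
#
#     train_loss = epoch_pass(net, train_loader, criterion, optimizer, writer)
#     val_loss, val_ids, val_gt, val_pred = epoch_pass(net, val_loader, criterion, None)
#     log_metrics(writer, train_loss, val_loss, val_pred, val_gt)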