def init_metadata():
    """Build the metadata DataFrame from the dataset directory layout.

    Scans ``<dataset>/<sample>/imgs/*.png`` and ``<dataset>/<sample>/masks/*.png``,
    derives sample and subject identifiers from the directory names, merges the
    per-sample grades CSV and stores the result in the global KVS under the
    ``metadata`` tag.

    Returns
    -------
    out : pd.DataFrame
        The assembled metadata.
    """
    kvs = GlobalKVS()

    def _parts(path):
        # Split a path into its components portably. The original code split
        # on '/', which breaks when os.path.join produces OS-native separators.
        return os.path.normpath(path).split(os.sep)

    imgs = glob.glob(os.path.join(kvs['args'].dataset, '*', 'imgs', '*.png'))
    imgs.sort(key=lambda x: _parts(x)[-1])
    masks = glob.glob(os.path.join(kvs['args'].dataset, '*', 'masks', '*.png'))
    masks.sort(key=lambda x: _parts(x)[-1])

    # NOTE(review): imgs and masks are sorted independently by filename;
    # row-wise correspondence relies on matching filenames in both
    # sub-directories -- confirm against the dataset layout.
    sample_id = [_parts(x)[-3] for x in imgs]
    subject_id = [_parts(x)[-3].split('_')[0] for x in imgs]

    metadata = pd.DataFrame(data={'img_fname': imgs,
                                  'mask_fname': masks,
                                  'sample_id': sample_id,
                                  'subject_id': subject_id})
    metadata['sample_subject_proj'] = metadata.apply(
        lambda x: gen_image_id(x.img_fname, x.sample_id), 1)

    grades = pd.read_csv(kvs['args'].grades)
    metadata = pd.merge(metadata, grades, on='sample_id')
    kvs.update('metadata', metadata)
    return metadata
def init_pd_meta():
    """Load the metadata CSV and register it in the global KVS.

    Reads ``<workdir>/<metadata>`` into a pandas DataFrame and stores it
    under the `metadata` tag.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()
    meta_path = os.path.join(kvs['args'].workdir, kvs['args'].metadata)
    kvs.update('metadata', pd.read_csv(meta_path))
def init_augs():
    """Compose the training augmentation pipeline and store it under `train_trf`.

    The pipeline combines optional keypoint jitter, randomly-applied geometric
    transforms, padding/cropping, one randomly-selected noise transform, gamma
    correction, optional cutout, and the final solt-to-torch conversion.
    """
    kvs = GlobalKVS()
    args = kvs['args']

    cutout = slt.ImageCutOut(cutout_size=(int(args.cutout * args.crop_x),
                                          int(args.cutout * args.crop_y)),
                             p=0.5)
    # Roughly plus-minus 1.3 pixels of jitter (in relative coordinates).
    jitter = slt.KeypointsJitter(dx_range=(-0.003, 0.003),
                                 dy_range=(-0.003, 0.003))

    geometric = slc.Stream([
        slt.RandomFlip(p=0.5, axis=1),
        slt.RandomProjection(
            affine_transforms=slc.Stream([
                slt.RandomScale(range_x=(0.8, 1.3), p=1),
                slt.RandomRotate(rotation_range=(-90, 90), p=1),
                slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
            ]),
            v_range=(1e-5, 2e-3),
            p=0.5),
        slt.RandomScale(range_x=(0.5, 2.5), p=0.5),
    ])

    pad_and_crop = slc.Stream([
        slt.PadTransform((args.pad_x, args.pad_y), padding='z'),
        slt.CropTransform((args.crop_x, args.crop_y), crop_mode='r'),
    ])

    # Exactly one of these noise/blur variants is drawn per sample.
    noise = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream(),
    ], n=1)

    ppl = tvt.Compose([
        jitter if args.use_target_jitter else slc.Stream(),
        slc.SelectiveStream([geometric, slc.Stream()], probs=[0.7, 0.3]),
        pad_and_crop,
        noise,
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if args.use_cutout else slc.Stream(),
        partial(solt2torchhm, downsample=None, sigma=None),
    ])
    kvs.update('train_trf', ppl)
def log_metrics(writer, train_loss, val_loss, val_results, val_results_callback=None):
    """Log the results of the validation stage.

    Takes a Tensorboard writer, the train and validation losses, the
    artifacts produced during validation, and an optional callback that can
    process these data, e.g. compute metrics and visualize them in
    Tensorboard. Train and validation losses are always logged outside of
    the callback; any metric computed inside the callback should be recorded
    into the `to_log` dictionary.

    Parameters
    ----------
    writer : SummaryWriter
        Tensorboard summary writer.
    train_loss : float
        Training loss.
    val_loss : float
        Validation loss.
    val_results : object
        Artifacts produced during the validation.
    val_results_callback : Callable or None
        A callback function that can process the artifacts and e.g. display
        those in Tensorboard.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()

    print(colored('==> ', 'green') + 'Metrics:')
    print(colored('====> ', 'green') + 'Train loss:', train_loss)
    print(colored('====> ', 'green') + 'Val loss:', val_loss)

    to_log = {'train_loss': train_loss, 'val_loss': val_loss}
    val_metrics = dict(epoch=kvs['cur_epoch'], **to_log)

    writer.add_scalars(f"Losses_{kvs['args'].experiment_tag}", to_log,
                       kvs['cur_epoch'])

    if val_results_callback is not None:
        # The callback may add its own metrics to `to_log` / `val_metrics`.
        val_results_callback(writer, val_metrics, to_log, val_results)

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', to_log)
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', val_metrics)
def save_checkpoint(net, loss, optimizer, val_metric_name, comparator='lt'):
    """Save the model, loss and optimizer states based on a validation metric.

    The snapshot is (re)written only when the tracked metric improves
    according to `comparator`; the previously saved snapshot is removed.

    Parameters
    ----------
    net : torch.nn.Module
        Model.
    loss : torch.nn.Module
        Loss function; its state is stored alongside the model.
        (This parameter was missing from the original docstring.)
    optimizer : torch.optim.Optimizer
        Optimizer.
    val_metric_name : str
        Name of the metric used for snapshot comparison. This name needs to
        match the ones created in the callback function passed to
        `log_metrics`.
    comparator : str
        How to compare the previous and the current metric values -- `lt` is
        less than, and `gt` is greater than.

    Returns
    -------
    out : None
    """
    if isinstance(net, torch.nn.DataParallel):
        # Unwrap DataParallel so the checkpoint loads without the wrapper.
        net = net.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    # Latest entry of the per-fold metric history.
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    comparator = getattr(operator, comparator)
    # Fixed: the original redundantly nested os.path.join inside itself.
    cur_snapshot_name = os.path.join(kvs['args'].workdir, 'snapshots',
                                     kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch}.pth')
    state = {
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss': loss.state_dict()
    }

    if kvs['prev_model'] is None:
        print(
            colored('====> ', 'red') + 'Snapshot was saved to',
            cur_snapshot_name)
        torch.save(state, cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)
    elif comparator(val_metric, kvs['best_val_metric']):
        print(
            colored('====> ', 'red') + 'Snapshot was saved to',
            cur_snapshot_name)
        os.remove(kvs['prev_model'])
        torch.save(state, cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)
def init_data_processing():
    """Estimate normalization statistics and build the train/val transforms.

    Computes channel-wise mean/std over the training dataset, then composes
    the final training and validation pipelines and stores them in the global
    KVS under `train_trf` and `val_trf`.
    """
    kvs = GlobalKVS()
    args = kvs['args']

    dataset = LandmarkDataset(data_root=args.dataset_root,
                              split=kvs['metadata'],
                              hc_spacing=args.hc_spacing,
                              lc_spacing=args.lc_spacing,
                              transform=kvs['train_trf'],
                              ann_type=args.annotations,
                              image_pad=args.img_pad)

    stats = init_mean_std(snapshots_dir=os.path.join(args.workdir, 'snapshots'),
                          dataset=dataset,
                          batch_size=args.bs,
                          n_threads=args.n_threads,
                          n_classes=-1)
    if len(stats) == 3:
        mean_vector, std_vector, class_weights = stats
    elif len(stats) == 2:
        mean_vector, std_vector = stats
    else:
        raise ValueError('Incorrect format of mean/std/class-weights')

    norm_trf = partial(normalize_channel_wise, mean=mean_vector, std=std_vector)

    train_trf = tvt.Compose([
        kvs['train_trf'],
        partial(apply_by_index, transform=norm_trf, idx=0),
    ])
    val_trf = tvt.Compose([
        slc.Stream([
            slt.PadTransform((args.pad_x, args.pad_y), padding='z'),
            slt.CropTransform((args.crop_x, args.crop_y), crop_mode='c'),
        ]),
        partial(solt2torchhm, downsample=None, sigma=None),
        partial(apply_by_index, transform=norm_trf, idx=0),
    ])

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
def train_fold(net, train_loader, optimizer, criterion, val_loader, scheduler):
    """Train one cross-validation fold of a segmentation model.

    Parameters
    ----------
    net : torch.nn.Module
        Model to train.
    train_loader : torch.utils.data.DataLoader
        Training data loader.
    optimizer : torch.optim.Optimizer
        Optimizer.
    criterion : torch.nn.Module
        Loss function.
    val_loader : torch.utils.data.DataLoader
        Validation data loader.
    scheduler : lr_scheduler
        Learning rate scheduler, stepped once per epoch.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    writer = SummaryWriter(
        os.path.join(kvs['args'].workdir, 'snapshots', kvs['snapshot_name'],
                     'logs', 'fold_{}'.format(fold_id), kvs['snapshot_name']))

    for epoch in range(kvs['args'].n_epochs):
        print(
            colored('==> ', 'green') +
            f'Training epoch [{epoch}] with LR {scheduler.get_lr()}')
        kvs.update('cur_epoch', epoch)
        train_loss, _ = pass_epoch(net, train_loader, optimizer, criterion)
        val_loss, conf_matrix = pass_epoch(net, val_loader, None, criterion)
        log_metrics(writer, train_loss, val_loss, conf_matrix)
        # Fixed: `save_checkpoint(net, loss, optimizer, metric, cmp)` expects
        # the criterion as its second argument; the original omitted it,
        # shifting every subsequent argument by one position.
        save_checkpoint(net, criterion, optimizer, 'val_loss', 'lt')
        scheduler.step()
def init_binary_segmentation_augs():
    """Compose the augmentation pipeline for binary segmentation training.

    Stores the pipeline in the global KVS under `train_trf` and returns it.

    Returns
    -------
    out : tvt.Compose
        The composed pipeline.
    """
    kvs = GlobalKVS()
    args = kvs['args']

    spatial_stream = slc.Stream([
        slt.PadTransform(pad_to=(args.pad_x, args.pad_y)),
        slt.RandomFlip(axis=1, p=0.5),
        slt.CropTransform(crop_size=(args.crop_x, args.crop_y), crop_mode='r'),
        slt.ImageGammaCorrection(gamma_range=(args.gamma_min, args.gamma_max),
                                 p=0.5),
    ])

    ppl = tvt.Compose([
        img_binary_mask2solt,
        spatial_stream,
        solt2img_binary_mask,
        partial(apply_by_index, transform=numpy2tens, idx=[0, 1]),
    ])

    kvs.update('train_trf', ppl)
    return ppl
def log_metrics(writer, train_loss, val_loss, conf_matrix):
    """Log losses and per-class Dice/IoU derived from a confusion matrix.

    Prints the metrics, pushes them to Tensorboard, and records the per-fold
    history in the global KVS.
    """
    kvs = GlobalKVS()

    dices = {f'dice_{c}': d
             for c, d in enumerate(calculate_dice(conf_matrix))}
    ious = {f'iou_{c}': v
            for c, v in enumerate(calculate_iou(conf_matrix))}

    print(colored('==> ', 'green') + 'Metrics:')
    print(colored('====> ', 'green') + 'Train loss:', train_loss)
    print(colored('====> ', 'green') + 'Val loss:', val_loss)
    print(colored('====> ', 'green') + f'Val Dice: {dices}')
    print(colored('====> ', 'green') + f'Val IoU: {ious}')

    # Skip the background class (0) in the Tensorboard plots.
    dices_tb = {f"Dice [{c}]": dices[f"dice_{c}"]
                for c in range(1, len(dices))}
    ious_tb = {f"IoU [{c}]": ious[f"iou_{c}"]
               for c in range(1, len(ious))}

    to_log = {'train_loss': train_loss, 'val_loss': val_loss}

    # Tensorboard logging
    writer.add_scalars(f"Losses_{kvs['args'].model}", to_log, kvs['cur_epoch'])
    writer.add_scalars('Metrics/Dice', dices_tb, kvs['cur_epoch'])
    writer.add_scalars('Metrics/IoU', ious_tb, kvs['cur_epoch'])

    # KVS logging
    to_log.update({'epoch': kvs['cur_epoch']})
    val_metrics = {'epoch': kvs['cur_epoch']}
    val_metrics.update(to_log)
    val_metrics.update(dices)
    val_metrics.update({'conf_matrix': conf_matrix})

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', to_log)
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', val_metrics)
def init_folds(img_group_id_colname=None, img_class_colname=None):
    """Initialize the cross-validation splits.

    Parameters
    ----------
    img_group_id_colname : str or None
        Column in `metadata` used to create cross-validation splits. If not
        None, images that share the same group id are never split between
        train and validation.
    img_class_colname : str or None
        Column in `metadata` used to create cross-validation splits. If not
        None, splits are stratified to ensure the same distribution of
        `img_class_colname` in train and validation.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()
    metadata = kvs['metadata']
    n_folds = kvs['args'].n_folds

    if img_group_id_colname is not None:
        # Group-aware split, optionally stratified by class labels.
        if img_class_colname is not None:
            class_labels = getattr(metadata, img_class_colname, None)
        else:
            class_labels = None
        splitter = GroupKFold(n_folds).split(
            X=metadata,
            y=class_labels,
            groups=getattr(metadata, img_group_id_colname))
    elif img_class_colname is not None:
        splitter = StratifiedKFold(n_folds).split(
            X=metadata, y=getattr(metadata, img_class_colname, None))
    else:
        splitter = KFold(n_folds).split(X=metadata)

    cv_split = []
    for fold_id, (train_ind, val_ind) in enumerate(splitter):
        # Optionally restrict the run to a single requested fold.
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue

        np.random.shuffle(train_ind)
        # Sub-sample the training set if `skip_train` > 1.
        train_ind = train_ind[::kvs['args'].skip_train]

        cv_split.append((fold_id,
                         metadata.iloc[train_ind],
                         metadata.iloc[val_ind]))
        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)

    kvs.update('cv_split', cv_split)
def init_data_processing(img_reader=read_rgb_ocv, mask_reader=read_gs_binary_mask_ocv):
    """Estimate normalization statistics and build the train/val transforms.

    Computes channel-wise mean/std (and, when available, class weights) over
    the training dataset, composes the final training and validation
    pipelines and stores `class_weights`, `train_trf` and `val_trf` in the
    global KVS.

    Parameters
    ----------
    img_reader : Callable
        Function that reads an image from disk.
    mask_reader : Callable
        Function that reads a binary mask from disk.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()
    dataset = SegmentationDataset(split=kvs['metadata'],
                                  trf=kvs['train_trf'],
                                  read_img=img_reader,
                                  read_mask=mask_reader)

    tmp = init_mean_std(snapshots_dir=os.path.join(kvs['args'].workdir,
                                                   'snapshots'),
                        dataset=dataset,
                        batch_size=kvs['args'].bs,
                        n_threads=kvs['args'].n_threads,
                        n_classes=kvs['args'].n_classes)
    if len(tmp) == 3:
        mean_vector, std_vector, class_weights = tmp
    elif len(tmp) == 2:
        mean_vector, std_vector = tmp
        # Fixed: the original left `class_weights` unbound on this path,
        # causing a NameError at the kvs.update('class_weights', ...) below.
        class_weights = None
    else:
        raise ValueError('Incorrect format of mean/std/class-weights')

    norm_trf = partial(normalize_channel_wise, mean=mean_vector, std=std_vector)

    train_trf = tvt.Compose(
        [kvs['train_trf'],
         partial(apply_by_index, transform=norm_trf, idx=0)])
    val_trf = tvt.Compose([
        partial(apply_by_index, transform=numpy2tens, idx=[0, 1]),
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])

    kvs.update('class_weights', class_weights)
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
def init_session(args):
    """Initialize one training session.

    Sets the seeds based on the parsed args, optionally loads the experiment
    config, creates the snapshots dir and initializes the global KVS with
    environment info (git, pytorch, cuda).

    Parameters
    ----------
    args : Namespace
        Arguments from argparse.

    Returns
    -------
    out : tuple
        Args, snapshot name and global KVS.
    """
    import warnings  # local import: top-of-file import block not editable here

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.experiment_config != '':
        with open(args.experiment_config, 'r') as f:
            conf = yaml.safe_load(f)
    else:
        conf = None
        # Fixed: the original did `raise Warning(...)`, which aborts the
        # session and made the `conf = None` fallback unreachable; a warning
        # is what was intended (also fixes the doubled "has has" in the text).
        warnings.warn('No experiment config has been provided')

    # Creating the snapshot
    snapshot_name = time.strftime(f'{socket.gethostname()}_%Y_%m_%d_%H_%M_%S')
    os.makedirs(os.path.join(args.workdir, 'snapshots', snapshot_name),
                exist_ok=True)

    kvs = GlobalKVS(
        os.path.join(args.workdir, 'snapshots', snapshot_name, 'session.pkl'))

    if conf is not None:
        kvs.update('config', conf)
        # Keep a copy of the effective config next to the snapshot.
        with open(
                os.path.join(args.workdir, 'snapshots', snapshot_name,
                             'config.yml'), 'w') as conf_file:
            yaml.dump(conf, conf_file)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)
    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)

    return args, snapshot_name, kvs
def train_fold(pass_epoch, net, train_loader, optimizer, criterion, val_loader,
               scheduler, save_by='val_loss', cmp='lt', log_metrics_cb=None,
               img_key=None):
    """A common implementation of training one fold of a neural network.

    Presumably, it should be called within a cross-validation loop.

    Parameters
    ----------
    pass_epoch : Callable
        Function that trains or validates one epoch.
    net : torch.nn.Module
        Model to train.
    train_loader : torch.utils.data.DataLoader
        Training data loader.
    optimizer : torch.optim.Optimizer
        Optimizer.
    criterion : torch.nn.Module
        Loss function.
    val_loader : torch.utils.data.DataLoader
        Validation data loader.
    scheduler : lr_scheduler.Scheduler
        Learning rate scheduler.
    save_by : str
        Name of the metric used to save the snapshot. Val loss by default.
        Also, ReduceOnPlateau will use this metric to drop LR.
    cmp : str
        Comparator for saving the snapshots. Can be `lt` (less than) or
        `gt` (greater than).
    log_metrics_cb : Callable or None
        Callback that processes the artifacts from the validation stage.
    img_key : str
        Key in the dataloader that allows to extract an image. Used in SWA.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    writer = SummaryWriter(
        os.path.join(kvs['args'].workdir, 'snapshots', kvs['snapshot_name'],
                     'logs', 'fold_{}'.format(fold_id), kvs['snapshot_name']))

    for epoch in range(kvs['args'].n_epochs):
        if scheduler is None:
            print(colored('==> ', 'green') + f'Training epoch [{epoch}]')
        else:
            lrs = [param_group['lr'] for param_group in optimizer.param_groups]
            print(
                colored('==> ', 'green') +
                f'Training epoch [{epoch}] with LR {lrs}')

        kvs.update('cur_epoch', epoch)
        train_loss, _ = pass_epoch(net, train_loader, optimizer, criterion)

        if isinstance(optimizer, swa.SWA):
            # Swap in the averaged weights and refresh BN statistics before
            # validating.
            optimizer.swap_swa_sgd()
            assert img_key is not None
            bn_update_cb(net, train_loader, img_key)

        val_loss, val_results = pass_epoch(net, val_loader, None, criterion)
        log_metrics(writer, train_loss, val_loss, val_results, log_metrics_cb)
        save_checkpoint(net, criterion, optimizer, save_by, cmp)

        if scheduler is not None:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(
                    kvs[f'val_metrics_fold_[{kvs["cur_fold"]}]'][-1][0][save_by])
            else:
                scheduler.step()
# Disable OpenCL and intra-op threading in OpenCV to avoid clashing with the
# torch DataLoader workers.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

if __name__ == "__main__":
    kvs = GlobalKVS(None)

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', default='')
    # Fixed: argparse `type=bool` treats ANY non-empty string (including
    # "False") as True; parse the value explicitly instead.
    parser.add_argument('--tta',
                        type=lambda s: s.lower() in ('1', 'true', 'yes'),
                        default=False)
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--n_threads', type=int, default=12)
    parser.add_argument('--snapshots_root', default='')
    parser.add_argument('--snapshot', default='')
    args = parser.parse_args()

    session_path = os.path.join(args.snapshots_root, args.snapshot,
                                'session.pkl')
    with open(session_path, 'rb') as f:
        session_backup = pickle.load(f)

    # Restore the model-defining arguments from the training session so the
    # network is rebuilt exactly as it was trained.
    snp_args = session_backup['args'][0]
    for attr in ('model', 'n_inputs', 'n_classes', 'bw', 'depth', 'cdepth',
                 'seed'):
        setattr(args, attr, getattr(snp_args, attr))

    kvs.update('args', args)

    run_oof_binary(args=args,
                   session_backup=session_backup,
                   read_img=read_gs_ocv,
                   read_mask=read_gs_binary_mask_ocv,
                   img_group_id_colname='sample_id')
with open(os.path.join(snp_full_path, 'config.yml'), 'r') as f: cfg = yaml.load(f) print( colored('==> Experiment: ', 'red') + cfg['experiment'][0]['experiment_description']) print(colored('==> Snapshot: ', 'green') + args.snapshot) snp_args = snapshot_session['args'][0] for arg in vars(snp_args): if not hasattr(args, arg): setattr(args, arg, getattr(snp_args, arg)) args.init_model_from = '' if not os.path.isfile(os.path.join(oof_results_dir, 'oof_results.npz')): kvs = GlobalKVS() kvs.update('args', args) kvs.update('val_trf', snapshot_session['val_trf'][0]) kvs.update('train_trf', snapshot_session['train_trf'][0]) oof_inference = [] oof_gt = [] subject_ids = [] kls = [] with torch.no_grad(): for fold_id, train_split, val_split in snapshot_session[ 'cv_split'][0]: _, val_loader = init_loaders(train_split, val_split, sequential_val_sampler=True) net = init_model() snp_weigths_path = glob.glob(
def segmentation_unet(data_xy, arguments, sample):
    """Segment calcified tissue from a 3D stack with a fold-ensembled U-Net.

    Runs slice-wise inference along two orthogonal orientations (XZ and YZ),
    averages the probability maps, thresholds them, suppresses the last third
    of the z-range and returns the largest connected object.

    Parameters
    ----------
    data_xy : ndarray (3-dimensional)
        Input data.
    arguments : Namespace
        Input arguments; must provide `model_path`, `input_shape` and
        `threshold`. (Docstring fixed: the original documented `data`/`args`
        instead of the actual parameter names.)
    sample : str
        Sample name.

    Returns
    -------
    out : ndarray
        Segmented calcified tissue mask.
    """
    kvs = GlobalKVS(None)

    parser = ArgumentParser()
    parser.add_argument('--dataset_root', default='../Data/')
    # Fixed: argparse `type=bool` treats any non-empty string (including
    # "False") as True; parse the value explicitly instead.
    parser.add_argument('--tta',
                        type=lambda s: s.lower() in ('1', 'true', 'yes'),
                        default=False)
    parser.add_argument('--bs', type=int, default=28)
    parser.add_argument('--n_threads', type=int, default=12)
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--n_inputs', type=int, default=1)
    parser.add_argument('--n_classes', type=int, default=2)
    parser.add_argument('--bw', type=int, default=24)
    parser.add_argument('--depth', type=int, default=6)
    parser.add_argument('--cdepth', type=int, default=1)
    parser.add_argument('--seed', type=int, default=42)
    # NOTE(review): parse_args() here consumes the process argv even though
    # this is a library function -- confirm callers rely on that.
    args = parser.parse_args()
    kvs.update('args', args)

    # Load one model per fold and merge them into a single inference model.
    models = glob(str(arguments.model_path / f'fold_[0-9]*.pth'))
    models.sort()
    device = 'cuda'
    model_list = []
    for snapshot_path in models:
        model = init_model(ignore_data_parallel=True)
        snp = torch.load(snapshot_path)
        if isinstance(snp, dict):
            snp = snp['model']
        model.load_state_dict(snp)
        model_list.append(model)
    model = InferenceModel(model_list).to(device)
    model.eval()

    # Channel-wise normalization statistics saved during training.
    tmp = np.load(str(arguments.model_path.parent / 'mean_std.npy'),
                  allow_pickle=True)
    mean, std = tmp[0][0], tmp[1][0]

    # Transpose the stack into the two inference orientations.
    data_xz = np.transpose(data_xy, (2, 0, 1))  # X-Z-Y
    data_yz = np.transpose(data_xy, (2, 1, 0))  # Y-Z-X

    mask_xz = np.zeros(data_xz.shape)
    mask_yz = np.zeros(data_yz.shape)

    with torch.no_grad():
        for idx in tqdm(range(data_xz.shape[2]),
                        desc='Running inference, XZ'):
            img = np.expand_dims(data_xz[:, :, idx], axis=2)
            mask_xz[:, :, idx] = inference_tiles(model, img,
                                                 shape=arguments.input_shape,
                                                 mean=mean, std=std)

        # 2nd orientation
        for idx in tqdm(range(data_yz.shape[2]),
                        desc='Running inference, YZ'):
            img = np.expand_dims(data_yz[:, :, idx], axis=2)
            mask_yz[:, :, idx] = inference_tiles(model, img,
                                                 shape=arguments.input_shape,
                                                 mean=mean, std=std)

    # Average the two probability maps and threshold.
    mask_final = ((mask_xz + np.transpose(mask_yz, (0, 2, 1))) /
                  2) >= arguments.threshold
    # Free the large intermediates before the connected-component analysis.
    mask_xz = list()
    mask_yz = list()
    data_xz = list()

    mask_final = np.transpose(mask_final, (1, 2, 0))
    # Suppress the last third of the z-range before picking the largest object.
    mask_final[:, :, -mask_final.shape[2] // 3:] = False
    largest = largest_object(mask_final)
    return largest