def init_data_processing():
    kvs = GlobalKVS()
    train_augs = init_train_augmentation_pipeline()

    dataset = SegmentationDataset(split=kvs['metadata_train'],
                                  trf=train_augs,
                                  read_img=read_gs_ocv,
                                  read_mask=read_gs_mask_ocv)

    mean_vector, std_vector, class_weights = init_mean_std(snapshots_dir=kvs['args'].snapshots,
                                                           dataset=dataset,
                                                           batch_size=kvs['args'].bs,
                                                           n_threads=kvs['args'].n_threads,
                                                           n_classes=kvs['args'].n_classes)

    norm_trf = transforms.Normalize(torch.from_numpy(mean_vector).float(),
                                    torch.from_numpy(std_vector).float())

    train_trf = transforms.Compose([
        train_augs,
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])

    val_trf = transforms.Compose([
        partial(apply_by_index, transform=gs2tens, idx=[0, 1]),
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])

    kvs.update('class_weights', class_weights)
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))


def log_metrics(writer, train_loss, val_loss, conf_matrix):
    kvs = GlobalKVS()

    dices = {'dice_{}'.format(cls): dice
             for cls, dice in enumerate(calculate_dice(conf_matrix))}

    print(colored('==> ', 'green') + 'Metrics:')
    print(colored('====> ', 'green') + 'Train loss:', train_loss)
    print(colored('====> ', 'green') + 'Val loss:', val_loss)
    print(colored('====> ', 'green') + f'Val Dices: {dices}')

    dices_tb = {}
    for cls in range(1, len(dices)):
        dices_tb[f"Dice [{cls}]"] = dices[f"dice_{cls}"]

    to_log = {'train_loss': train_loss, 'val_loss': val_loss}

    # Tensorboard logging
    writer.add_scalars(f"Losses_{kvs['args'].model}", to_log, kvs['cur_epoch'])
    writer.add_scalars('Metrics', dices_tb, kvs['cur_epoch'])

    # KVS logging
    to_log.update({'epoch': kvs['cur_epoch']})
    val_metrics = {'epoch': kvs['cur_epoch']}
    val_metrics.update(to_log)
    val_metrics.update(dices)
    val_metrics.update({'conf_matrix': conf_matrix})

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', to_log)
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', val_metrics)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
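

# calculate_dice(), used in log_metrics() above, is imported from the project's metrics
# module and is not defined in this file. For reference, the sketch below shows the
# per-class Dice formula it is expected to implement, computed from the confusion matrix
# returned by validate_epoch(). This helper is an illustration only and is not used by
# the pipeline; the actual implementation may differ.
def _dice_from_confusion_matrix_example(conf_matrix, eps=1e-8):
    # Dice_c = 2 * TP_c / (2 * TP_c + FP_c + FN_c), per class c
    tp = np.diag(conf_matrix).astype(np.float64)
    fp = conf_matrix.sum(axis=0) - tp
    fn = conf_matrix.sum(axis=1) - tp
    return 2 * tp / (2 * tp + fp + fn + eps)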


def validate_epoch(net, val_loader, criterion):
    kvs = GlobalKVS()
    net.train(False)

    epoch = kvs['cur_epoch']
    max_epoch = kvs['args'].n_epochs
    n_classes = kvs['args'].n_classes
    device = next(net.parameters()).device

    confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.uint64)
    val_loss = 0
    with torch.no_grad():
        for entry in tqdm(val_loader, total=len(val_loader), desc=f"[{epoch} / {max_epoch}] Val: "):
            img = entry['img'].to(device)
            mask = entry['mask'].to(device).squeeze()

            preds = net(img)
            val_loss += criterion(preds, mask).item()

            mask = mask.to('cpu').numpy()
            if n_classes == 2:
                # The network is assumed to output logits (e.g. when trained with
                # BCEWithLogitsLoss2d), so a sigmoid is applied before thresholding at 0.5.
                preds = (torch.sigmoid(preds).to('cpu').numpy() > 0.5).astype(float)
            elif n_classes > 2:
                preds = preds.to('cpu').numpy().argmax(axis=1)
            else:
                raise ValueError('n_classes must be >= 2')

            confusion_matrix += metrics.calculate_confusion_matrix_from_arrays(preds, mask, n_classes)

    val_loss /= len(val_loader)
    return val_loss, confusion_matrix


def train_epoch(net, train_loader, optimizer, criterion):
    kvs = GlobalKVS()
    net.train(True)

    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    max_ep = kvs['args'].n_epochs

    running_loss = 0.0
    n_batches = len(train_loader)
    device = next(net.parameters()).device
    pbar = tqdm(total=n_batches, ncols=200)

    for i, entry in enumerate(train_loader):
        inputs = entry['img'].to(device)
        mask = entry['mask'].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, mask)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_description(f"Fold [{fold_id}] [{epoch} | {max_ep}] | "
                             f"Running loss {running_loss / (i + 1):.5f} / {loss.item():.5f}")
        pbar.update()
        gc.collect()

    gc.collect()
    pbar.close()

    return running_loss / n_batches


def init_loaders(x_train, x_val):
    kvs = GlobalKVS()
    train_dataset = SegmentationDataset(split=x_train,
                                        trf=kvs['train_trf'],
                                        read_img=read_gs_ocv,
                                        read_mask=read_gs_mask_ocv)

    val_dataset = SegmentationDataset(split=x_val,
                                      trf=kvs['val_trf'],
                                      read_img=read_gs_ocv,
                                      read_mask=read_gs_mask_ocv)

    train_loader = DataLoader(train_dataset,
                              batch_size=kvs['args'].bs,
                              num_workers=kvs['args'].n_threads,
                              shuffle=True,
                              drop_last=True,
                              worker_init_fn=lambda wid: np.random.seed(np.uint32(torch.initial_seed() + wid)))

    val_loader = DataLoader(val_dataset,
                            batch_size=kvs['args'].val_bs,
                            num_workers=kvs['args'].n_threads)

    return train_loader, val_loader


def init_optimizer(net):
    kvs = GlobalKVS()
    if kvs['args'].optimizer == 'adam':
        return optim.Adam(net.parameters(),
                          lr=kvs['args'].lr,
                          weight_decay=kvs['args'].wd)
    elif kvs['args'].optimizer == 'sgd':
        return optim.SGD(net.parameters(),
                         lr=kvs['args'].lr,
                         weight_decay=kvs['args'].wd,
                         momentum=0.9)
    else:
        raise NotImplementedError


def init_loss():
    kvs = GlobalKVS()
    class_weights = kvs['class_weights']
    if kvs['args'].n_classes == 2:
        if kvs['args'].loss == 'combined':
            return CombinedLoss([BCEWithLogitsLoss2d(), BinaryDiceLoss()])
        elif kvs['args'].loss == 'bce':
            return BCEWithLogitsLoss2d()
        elif kvs['args'].loss == 'dice':
            return BinaryDiceLoss()
        elif kvs['args'].loss == 'wbce':
            return BCEWithLogitsLoss2d(weight=class_weights)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError


def init_model():
    kvs = GlobalKVS()
    if kvs['args'].model == 'unet':
        net = UNet(bw=kvs['args'].bw,
                   depth=kvs['args'].depth,
                   center_depth=kvs['args'].cdepth,
                   n_inputs=kvs['args'].n_inputs,
                   n_classes=kvs['args'].n_classes - 1,
                   activation='relu')

        if kvs['gpus'] > 1:
            net = nn.DataParallel(net).to('cuda')
        net = net.to('cuda')
    else:
        raise NotImplementedError

    return net


def init_train_augmentation_pipeline():
    kvs = GlobalKVS()
    ppl = transforms.Compose([
        img_mask2solt,
        slc.Stream([
            slt.RandomFlip(axis=1, p=0.5),
            slt.ImageGammaCorrection(gamma_range=(0.5, 2), p=0.5),
            slt.PadTransform(pad_to=(kvs['args'].crop_x + 1, kvs['args'].crop_y + 1)),
            slt.CropTransform(crop_size=(kvs['args'].crop_x, kvs['args'].crop_y), crop_mode='r')
        ]),
        solt2img_mask,
        partial(apply_by_index, transform=gs2tens, idx=[0, 1]),
    ])

    return ppl


def init_folds():
    kvs = GlobalKVS()
    gkf = GroupKFold(kvs['args'].n_folds)
    cv_split = []
    for fold_id, (train_ind, val_ind) in enumerate(gkf.split(X=kvs['metadata_train'],
                                                             y=kvs['metadata_train'].grade,
                                                             groups=kvs['metadata_train'].subject_id)):
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue

        cv_split.append((fold_id,
                         kvs['metadata_train'].iloc[train_ind],
                         kvs['metadata_train'].iloc[val_ind]))

        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)

    kvs.update('cv_split', cv_split)


def init_metadata():
    kvs = GlobalKVS()

    imgs = glob.glob(os.path.join(kvs['args'].dataset, '*', 'imgs', '*.png'))
    imgs.sort(key=lambda x: x.split('/')[-1])

    masks = glob.glob(os.path.join(kvs['args'].dataset, '*', 'masks', '*.png'))
    masks.sort(key=lambda x: x.split('/')[-1])

    sample_id = list(map(lambda x: x.split('/')[-3], imgs))
    subject_id = list(map(lambda x: x.split('/')[-3].split('_')[0], imgs))

    metadata = pd.DataFrame(data={'img_fname': imgs,
                                  'mask_fname': masks,
                                  'sample_id': sample_id,
                                  'subject_id': subject_id})

    grades = pd.read_csv(kvs['args'].grades)
    n_subj = np.unique(metadata.subject_id.values).shape[0]
    if n_subj < kvs['args'].train_size:
        raise ValueError('train_size exceeds the number of available subjects')

    metadata = pd.merge(metadata, grades, on='sample_id')

    # The held-out test split is currently disabled: all subjects are used for training.
    # gss = GroupShuffleSplit(n_splits=1,
    #                         train_size=kvs['args'].train_size,
    #                         test_size=n_subj - kvs['args'].train_size,
    #                         random_state=kvs['args'].seed)
    # train_ind, test_ind = next(gss.split(metadata, y=metadata.grade, groups=metadata.subject_id))
    metadata_train = metadata.iloc[:]
    # metadata_test = metadata.iloc[test_ind]

    kvs.update('metadata_train', metadata_train)
    # kvs.update('metadata_test', metadata_test)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    return metadata


def save_checkpoint(net, val_metric_name, comparator='lt'):
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    comparator = getattr(operator, comparator)
    cur_snapshot_name = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch}.pth')

    if kvs['prev_model'] is None:
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        torch.save(net.state_dict(), cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)
    else:
        if comparator(val_metric, kvs['best_val_metric']):
            print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
            os.remove(kvs['prev_model'])
            torch.save(net.state_dict(), cur_snapshot_name)
            kvs.update('prev_model', cur_snapshot_name)
            kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))


def init_session():
    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args_train()

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Creating the snapshot
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)

    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))

    return args, snapshot_name
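

# Example usage: a minimal driver sketching how the helpers above are typically wired
# together. This is an illustration only, not the project's actual entry point -- in
# particular, the way 'cur_fold', 'cur_epoch' and 'prev_model' are registered in the KVS,
# the choice of SummaryWriter, and the monitored checkpoint metric are assumptions here.
if __name__ == '__main__':
    from torch.utils.tensorboard import SummaryWriter

    kvs = GlobalKVS()
    args, snapshot_name = init_session()
    init_metadata()
    init_data_processing()
    init_folds()

    for fold_id, x_train, x_val in kvs['cv_split']:
        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)

        net = init_model()
        optimizer = init_optimizer(net)
        criterion = init_loss()
        train_loader, val_loader = init_loaders(x_train, x_val)
        writer = SummaryWriter(os.path.join(args.snapshots, snapshot_name,
                                            'logs', f'fold_{fold_id}'))

        for epoch in range(args.n_epochs):
            kvs.update('cur_epoch', epoch)
            train_loss = train_epoch(net, train_loader, optimizer, criterion)
            val_loss, conf_matrix = validate_epoch(net, val_loader, criterion)
            log_metrics(writer, train_loss, val_loss, conf_matrix)
            save_checkpoint(net, 'val_loss', 'lt')

        writer.close()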