def init_snapshot_dir(snapshot_name_prefix=None):
    kvs = GlobalKVS()
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    if snapshot_name_prefix is not None:
        snapshot_name = f'{snapshot_name_prefix}_{snapshot_name}'
    os.makedirs(os.path.join(kvs['args'].snapshots, snapshot_name), exist_ok=True)
    kvs.update('snapshot_name', snapshot_name)
def init_age_sex_bmi_metadata():
    kvs = GlobalKVS()

    oai_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'OAI_progression.csv'))
    clinical_data_oai = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'OAI_participants.csv'))
    oai_meta = pd.merge(oai_meta, clinical_data_oai, on=('ID', 'Side'))
    oai_meta = oai_meta[~oai_meta.BMI.isna() & ~oai_meta.AGE.isna() & ~oai_meta.SEX.isna()]

    clinical_data_most = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'MOST_participants.csv'))
    metadata_test = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'MOST_progression.csv'))
    metadata_test = pd.merge(metadata_test, clinical_data_most, on=('ID', 'Side'))

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', metadata_test)

    gkf = GroupKFold(n_splits=5)
    cv_split = [x for x in gkf.split(kvs['metadata'],
                                     kvs['metadata'][kvs['args'].target_var],
                                     kvs['metadata']['ID'].astype(str))]

    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def log_metrics_prog(boardlogger, train_loss, val_loss,
                     gt_progression, preds_progression,
                     gt_kl, preds_kl):
    kvs = GlobalKVS()

    res = testtools.calc_metrics(gt_progression, gt_kl, preds_progression, preds_kl)
    res['val_loss'] = val_loss
    res['epoch'] = kvs['cur_epoch']

    print(colored('====> ', 'green') + f'Train loss: {train_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation loss: {val_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation AUC [prog]: {res["auc_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation F1 @ 0.3 [prog]: {res["f1_score_03_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation F1 @ 0.4 [prog]: {res["f1_score_04_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation F1 @ 0.5 [prog]: {res["f1_score_05_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation AP [prog]: {res["ap_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation AUC [oa]: {res["auc_oa"]:.5f}')
    print(colored('====> ', 'green') + f'Kappa [oa]: {res["kappa_kl"]:.5f}')

    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, kvs['cur_epoch'])
    boardlogger.add_scalars('AUC progression', {'val': res['auc_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('F1-score @ 0.3 progression', {'val': res['f1_score_03_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('F1-score @ 0.4 progression', {'val': res['f1_score_04_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('F1-score @ 0.5 progression', {'val': res['f1_score_05_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('Average Precision progression', {'val': res['ap_prog']}, kvs['cur_epoch'])

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]',
               {'epoch': kvs['cur_epoch'], 'train_loss': train_loss, 'val_loss': val_loss})
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_data_processing():
    kvs = GlobalKVS()
    train_augs = init_train_augs()

    dataset = OAProgressionDataset(dataset=kvs['args'].dataset_root,
                                   split=kvs['metadata'],
                                   trf=train_augs)

    mean_vector, std_vector = init_mean_std(snapshots_dir=kvs['args'].snapshots,
                                            dataset=dataset,
                                            batch_size=kvs['args'].bs,
                                            n_threads=kvs['args'].n_threads)

    print(colored('====> ', 'red') + 'Mean:', mean_vector)
    print(colored('====> ', 'red') + 'Std:', std_vector)

    norm_trf = tv_transforms.Normalize(mean_vector.tolist(), std_vector.tolist())

    train_trf = tv_transforms.Compose([
        train_augs,
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])

    val_trf = tv_transforms.Compose([
        img_labels2solt,
        slc.Stream([
            slt.ResizeTransform((310, 310)),
            slt.CropTransform(crop_size=(300, 300), crop_mode='c'),
            slt.ImageColorTransform(mode='gs2rgb'),
        ], interpolation='bicubic'),
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_folds(project='OA_progression'):
    kvs = GlobalKVS()
    writers = {}
    cv_split_train = {}
    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue

        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split
        writers[fold_id] = SummaryWriter(os.path.join(kvs['args'].logs,
                                                      project,
                                                      'fold_{}'.format(fold_id),
                                                      kvs['snapshot_name']))

    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
    return writers
def init_optimizer(parameters):
    kvs = GlobalKVS()
    if kvs['args'].optimizer == 'adam':
        return optim.Adam(parameters, lr=kvs['args'].lr, weight_decay=kvs['args'].wd)
    elif kvs['args'].optimizer == 'sgd':
        return optim.SGD(parameters, lr=kvs['args'].lr, weight_decay=kvs['args'].wd, momentum=0.9)
    else:
        raise NotImplementedError
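# Usage sketch (illustrative, not executed here): `parameters` is forwarded straight to
# torch.optim, so both a plain parameter iterable and a list of per-group dicts work:
#
#     optimizer = init_optimizer(net.parameters())
#     optimizer = init_optimizer([{'params': net.module.classifier_kl.parameters()},
#                                 {'params': net.module.classifier_prog.parameters()}])
#
# The second form is how train_folds() builds the optimizer before the backbone is unfrozen.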
def init_progression_metadata():
    # We should get rid of non-progressors from MOST because we can check
    # non-progressors only up to 84 months
    kvs = GlobalKVS()

    most_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'MOST_progression.csv'))
    oai_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'OAI_progression.csv'))

    if kvs['args'].subsample_train != -1:
        n_train = oai_meta.shape[0]
        prevalence = (oai_meta.Progressor > 0).sum() / n_train

        sample_pos = int(kvs['args'].subsample_train * prevalence)
        sample_neg = kvs['args'].subsample_train - sample_pos

        train_pos = oai_meta[oai_meta.Progressor > 0]
        train_neg = oai_meta[oai_meta.Progressor == 0]

        pos_sampled = train_pos.iloc[np.random.choice(train_pos.shape[0], sample_pos)]
        neg_sampled = train_neg.iloc[np.random.choice(train_neg.shape[0], sample_neg)]

        new_meta = pd.concat((pos_sampled, neg_sampled))
        oai_meta = new_meta.iloc[np.random.choice(new_meta.shape[0], new_meta.shape[0])]
        print(colored("==> ", 'red') + f"Train set has been sub-sampled. New # pos/neg {sample_pos}/{sample_neg}")

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    gkf = GroupKFold(n_splits=5)
    cv_split = [x for x in gkf.split(kvs['metadata'],
                                     kvs['metadata']['Progressor'],
                                     kvs['metadata']['ID'].astype(str))]

    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    print(colored("==> ", 'green') + f"Train dataset has {(kvs['metadata'].Progressor == 0).sum()} non-progressed knees")
    print(colored("==> ", 'green') + f"Train dataset has {(kvs['metadata'].Progressor > 0).sum()} progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {(kvs['metadata_test'].Progressor == 0).sum()} non-progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {(kvs['metadata_test'].Progressor > 0).sum()} progressed knees")
def init_epoch_pass(net, optimizer, loader):
    kvs = GlobalKVS()
    net.train(optimizer is not None)

    running_loss = 0.0
    n_batches = len(loader)
    pbar = tqdm(total=n_batches)
    epoch = kvs['cur_epoch']
    max_epoch = kvs['args'].n_epochs
    device = next(net.parameters()).device

    return running_loss, pbar, n_batches, epoch, max_epoch, device
def init_progression_metadata():
    # We should get rid of non-progressors from MOST because we can check
    # non-progressors only up to 84 months
    kvs = GlobalKVS()

    most_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'MOST_progression.csv'))
    oai_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'OAI_progression.csv'))

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    gkf = GroupKFold(n_splits=5)
    cv_split = [x for x in gkf.split(kvs['metadata'],
                                     kvs['metadata']['Progressor'],
                                     kvs['metadata']['ID'].astype(str))]

    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    print(colored("==> ", 'green') + f"Train dataset has {(kvs['metadata'].Progressor == 0).sum()} non-progressed knees")
    print(colored("==> ", 'green') + f"Train dataset has {(kvs['metadata'].Progressor > 0).sum()} progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {(kvs['metadata_test'].Progressor == 0).sum()} non-progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {(kvs['metadata_test'].Progressor > 0).sum()} progressed knees")
def train_folds(writers):
    kvs = GlobalKVS()
    for fold_id in kvs['cv_split_train']:
        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(kvs['metadata'].iloc[train_index],
                                                        kvs['metadata'].iloc[val_index])

        net = init_model()
        optimizer = init_optimizer([{'params': net.module.classifier_kl.parameters()},
                                    {'params': net.module.classifier_prog.parameters()}])
        scheduler = MultiStepLR(optimizer, milestones=kvs['args'].lr_drop, gamma=0.1)

        for epoch in range(kvs['args'].n_epochs):
            kvs.update('cur_epoch', epoch)
            if epoch == kvs['args'].unfreeze_epoch:
                print(colored('====> ', 'red') + 'Unfreezing the layers!')
                new_lr_drop_milestones = list(map(lambda x: x - kvs['args'].unfreeze_epoch, kvs['args'].lr_drop))
                optimizer.add_param_group({'params': net.module.features.parameters()})
                scheduler = MultiStepLR(optimizer, milestones=new_lr_drop_milestones, gamma=0.1)

            print(colored('====> ', 'red') + 'LR:', scheduler.get_lr())

            train_loss = prog_epoch_pass(net, optimizer, train_loader)
            val_out = prog_epoch_pass(net, None, val_loader)
            val_loss, val_ids, gt_progression, preds_progression, gt_kl, preds_kl = val_out

            log_metrics_prog(writers[fold_id], train_loss, val_loss,
                             gt_progression, preds_progression,
                             gt_kl, preds_kl)

            session.save_checkpoint(net, 'ap_prog', 'gt')
            scheduler.step()
def debug_augmentations(n_iter=20):
    kvs = GlobalKVS()
    ds = OAProgressionDataset(dataset=kvs['args'].dataset_root,
                              split=kvs['metadata'],
                              trf=init_train_augs())

    for ind in np.random.choice(len(ds), n_iter, replace=False):
        sample = ds[ind]
        img = np.clip(sample['img'].numpy() * 255, 0, 255).astype(np.uint8)
        img = np.swapaxes(img, 0, -1)
        img = np.swapaxes(img, 0, 1)

        plt.figure()
        plt.imshow(img)
        plt.show()
def init_model(kneenet=True):
    kvs = GlobalKVS()
    if kneenet:
        net = KneeNet(kvs['args'].backbone, kvs['args'].dropout_rate)
    else:
        if not kvs['args'].predict_age_sex_bmi:
            net = PretrainedModel(kvs['args'].backbone, kvs['args'].dropout_rate, 1, True)
        else:
            net = PretrainedModel(kvs['args'].backbone, kvs['args'].dropout_rate, 3, True)

    if kvs['gpus'] > 1:
        net = nn.DataParallel(net).to('cuda')

    net = net.to('cuda')
    return net
def init_loaders(x_train, x_val, progression=True):
    kvs = GlobalKVS()
    ds = OAProgressionDataset
    if not progression:
        ds = AgeSexBMIDataset

    train_dataset = ds(dataset=kvs['args'].dataset_root,
                       split=x_train,
                       trf=kvs['train_trf'])

    val_dataset = ds(dataset=kvs['args'].dataset_root,
                     split=x_val,
                     trf=kvs['val_trf'])

    # Re-seed NumPy in every worker so that numpy-based augmentations differ across workers
    train_loader = DataLoader(train_dataset,
                              batch_size=kvs['args'].bs,
                              num_workers=kvs['args'].n_threads,
                              drop_last=True,
                              shuffle=True,
                              worker_init_fn=lambda wid: np.random.seed(np.uint32(torch.initial_seed() + wid)))

    val_loader = DataLoader(val_dataset,
                            batch_size=kvs['args'].val_bs,
                            num_workers=kvs['args'].n_threads)

    return train_loader, val_loader
def save_checkpoint(model, val_metric_name, comparator='lt'):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    comparator = getattr(operator, comparator)
    cur_snapshot_name = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch + 1}.pth')

    if kvs['prev_model'] is None:
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        torch.save(model.state_dict(), cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)
    else:
        if comparator(val_metric, kvs['best_val_metric']):
            print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
            os.remove(kvs['prev_model'])
            torch.save(model.state_dict(), cur_snapshot_name)
            kvs.update('prev_model', cur_snapshot_name)
            kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
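# Usage sketch (illustrative, not executed here): `comparator` is resolved through the
# `operator` module, so 'gt' keeps the snapshot whenever the tracked metric increases and
# 'lt' whenever it decreases. train_folds() tracks the Average Precision of progression:
#
#     save_checkpoint(net, 'ap_prog', 'gt')
#
# A loss-like metric would flip the comparator, e.g. save_checkpoint(net, 'val_loss', 'lt'),
# assuming that key is present in the fold's validation metrics (log_metrics_prog stores it).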
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

DEBUG = sys.gettrace() is not None

sides = [None, 'R', 'L']
JSW_features = ['V00JSW150', 'V00JSW175', 'V00JSW200', 'V00JSW225',
                'V00JSW250', 'V00JSW275', 'V00JSW300',
                'V00LJSW700', 'V00LJSW725', 'V00LJSW750', 'V00LJSW775',
                'V00LJSW800', 'V00LJSW825', 'V00LJSW850', 'V00LJSW875', 'V00LJSW900']

if __name__ == "__main__":
    kvs = GlobalKVS()
    session.init_session()
    sites, metadata = read_jsw_metadata_oai(kvs['args'].metadata_root, kvs['args'].oai_data_root)

    base_snapshot_name = kvs['snapshot_name']
    for test_site in sites:
        # Creating a sub-snapshot for every site in OAI
        os.makedirs(os.path.join(kvs['args'].snapshots, base_snapshot_name, f'site_{test_site}'), exist_ok=True)
        kvs.update('snapshot_name', os.path.join(base_snapshot_name, f'site_{test_site}'))

        # Splitting the data to exclude the current site from training and keeping it only for testing
        top_subj_train = metadata[metadata.V00SITE != test_site]
        top_subj_test = metadata[metadata.V00SITE == test_site]
def init_session():
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Creating the snapshot
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)

    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))

    return args, snapshot_name
def prog_epoch_pass(net, optimizer, loader):
    kvs = GlobalKVS()
    running_loss, pbar, n_batches, epoch, max_epoch, device = init_epoch_pass(net, optimizer, loader)

    preds_progression = []
    gt_progression = []
    ids = []
    preds_kl = []
    gt_kl = []

    with torch.set_grad_enabled(optimizer is not None):
        for i, batch in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()

            # forward + backward + optimize if train
            labels_prog = batch['label'].long().to(device)
            labels_kl = batch['KL'].long().to(device)
            inputs = batch['img'].to(device)

            outputs_kl, outputs_prog = net(inputs)

            loss_kl = F.cross_entropy(outputs_kl, labels_kl)
            loss_prog = F.cross_entropy(outputs_prog, labels_prog)

            # Weighted combination of the progression and KL-grade losses
            loss = loss_prog.mul(kvs['args'].loss_weight) + loss_kl.mul(1 - kvs['args'].loss_weight)

            if optimizer is not None:
                loss.backward()
                if kvs['args'].clip_grad:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), kvs['args'].clip_grad_norm)
                optimizer.step()
            else:
                probs_progression_batch = F.softmax(outputs_prog, 1).data.to('cpu').numpy()
                probs_kl_batch = F.softmax(outputs_kl, 1).data.to('cpu').numpy()

                preds_progression.append(probs_progression_batch)
                gt_progression.append(batch['label'].numpy())

                preds_kl.append(probs_kl_batch)
                gt_kl.append(batch['KL'])

                ids.extend(batch['ID_SIDE'])

            running_loss += loss.item()
            if optimizer is not None:
                pbar.set_description(f'Training [{epoch} / {max_epoch}]:: {running_loss / (i + 1):.3f}')
            else:
                pbar.set_description(f'Validating [{epoch} / {max_epoch}]:')
            pbar.update()

            gc.collect()

    if optimizer is None:
        preds_progression = np.vstack(preds_progression)
        gt_progression = np.hstack(gt_progression)

        preds_kl = np.vstack(preds_kl)
        gt_kl = np.hstack(gt_kl)

    gc.collect()
    pbar.close()

    if optimizer is not None:
        return running_loss / n_batches
    else:
        return running_loss / n_batches, ids, gt_progression, preds_progression, gt_kl, preds_kl
def log_metrics_age_sex_bmi(boardlogger, train_loss, val_res):
    kvs = GlobalKVS()
    res = dict()
    val_loss = val_res[0]
    res['val_loss'] = val_loss
    res['epoch'] = kvs['cur_epoch']

    print(colored('====> ', 'green') + f'Train loss: {train_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation loss: {val_loss:.5f}')
    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, kvs['cur_epoch'])

    if not kvs['args'].predict_age_sex_bmi:
        _, ids, gt, preds = val_res
        if kvs['args'].target_var == 'SEX':
            val_auc = roc_auc_score(gt.astype(int), preds)
            res['sex_auc'] = val_auc
            print(colored('====> ', 'green') + f'Validation AUC: {val_auc:.5f}')
            boardlogger.add_scalars('AUC sex', {'val': res['sex_auc']}, kvs['cur_epoch'])
        else:
            val_mse = mean_squared_error(gt, preds)
            val_mae = median_absolute_error(gt, preds)
            res[f"{kvs['args'].target_var}_mse"] = val_mse
            res[f"{kvs['args'].target_var}_mae"] = val_mae
            print(colored('====> ', 'green') + f'Validation mae: {val_mae:.5f}')
            print(colored('====> ', 'green') + f'Validation mse: {val_mse:.5f}')
            boardlogger.add_scalars(f"MSE [{kvs['args'].target_var}]", {'val': val_mse}, kvs['cur_epoch'])
            boardlogger.add_scalars(f"MAE [{kvs['args'].target_var}]", {'val': val_mae}, kvs['cur_epoch'])
    else:
        _, ids, gt_age, preds_age, gt_sex, preds_sex, gt_bmi, preds_bmi = val_res

        val_mse_age = mean_squared_error(gt_age, preds_age)
        val_mae_age = median_absolute_error(gt_age, preds_age)
        val_sex_auc = roc_auc_score(gt_sex.astype(int), preds_sex)
        val_mse_bmi = mean_squared_error(gt_bmi, preds_bmi)
        val_mae_bmi = median_absolute_error(gt_bmi, preds_bmi)

        res["AGE_mse"] = val_mse_age
        res["AGE_mae"] = val_mae_age
        res["BMI_mse"] = val_mse_bmi
        res["BMI_mae"] = val_mae_bmi
        res["SEX_auc"] = val_sex_auc

        print(colored('====> ', 'green') + f'Validation mae [Age]: {val_mae_age:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [Age]: {val_mse_age:.5f}')
        print(colored('====> ', 'green') + f'Validation AUC [Sex]: {val_sex_auc:.5f}')
        print(colored('====> ', 'green') + f'Validation mae [BMI]: {val_mae_bmi:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [BMI]: {val_mse_bmi:.5f}')

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]',
               {'epoch': kvs['cur_epoch'], 'train_loss': train_loss, 'val_loss': val_loss})
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_session(snapshot_name_prefix=None):
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()
    kvs.update('args', args)

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Creating the snapshot
    init_snapshot_dir(snapshot_name_prefix)

    kvs.update('pytorch_version', torch.__version__)

    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))

    return args, kvs['snapshot_name']
def epoch_pass(net, optimizer, loader):
    kvs = GlobalKVS()
    running_loss, pbar, n_batches, epoch, max_epoch, device = init_epoch_pass(net, optimizer, loader)

    if kvs['args'].target_var == 'SEX':
        criterion = F.binary_cross_entropy_with_logits
    else:
        criterion = F.mse_loss

    # Individual factors prediction
    preds = list()
    gt = list()

    # Predicting Age, Sex, BMI
    preds_age = list()
    preds_sex = list()
    preds_bmi = list()

    gt_age = list()
    gt_sex = list()
    gt_bmi = list()

    ids = list()

    with torch.set_grad_enabled(optimizer is not None):
        for i, batch in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()

            inp = batch['img'].to(device)
            output = net(inp).squeeze()

            if not kvs['args'].predict_age_sex_bmi:
                target = batch[kvs['args'].target_var].float().to(device)
                loss = criterion(output, target)
            else:
                target_age = batch['AGE'].float().to(device)
                target_sex = batch['SEX'].float().to(device)
                target_bmi = batch['BMI'].float().to(device)

                loss_age = F.mse_loss(output[:, 0].squeeze(), target_age)
                loss_sex = F.binary_cross_entropy_with_logits(output[:, 1].squeeze(), target_sex)
                loss_bmi = F.mse_loss(output[:, 2].squeeze(), target_bmi)

                loss = loss_age + loss_sex + loss_bmi

            if optimizer is not None:
                loss.backward()
                if kvs['args'].clip_grad:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), kvs['args'].clip_grad_norm)
                optimizer.step()
            else:
                if not kvs['args'].predict_age_sex_bmi:
                    if kvs['args'].target_var == 'SEX':
                        pred_batch = torch.sigmoid(output).data.to('cpu').numpy().squeeze()
                    else:
                        pred_batch = output.data.to('cpu').numpy().squeeze()

                    preds.append(pred_batch)
                    gt.append(batch[kvs['args'].target_var].numpy().squeeze())
                else:
                    preds_age_batch = output[:, 0].data.to('cpu').numpy().squeeze()
                    preds_sex_batch = torch.sigmoid(output[:, 1]).data.to('cpu').numpy().squeeze()
                    preds_bmi_batch = output[:, 2].data.to('cpu').numpy().squeeze()

                    preds_age.append(preds_age_batch)
                    preds_sex.append(preds_sex_batch)
                    preds_bmi.append(preds_bmi_batch)

                    gt_age.append(batch['AGE'].numpy().squeeze())
                    gt_sex.append(batch['SEX'].numpy().squeeze())
                    gt_bmi.append(batch['BMI'].numpy().squeeze())

                ids.extend(batch['ID_SIDE'])

            running_loss += loss.item()
            if optimizer is not None:
                pbar.set_description(f'Training [{epoch} / {max_epoch}]:: {running_loss / (i + 1):.3f}')
            else:
                pbar.set_description(f'Validating [{epoch} / {max_epoch}]:')
            pbar.update()

            gc.collect()

    gc.collect()
    pbar.close()

    if optimizer is not None:
        return running_loss / n_batches
    else:
        if not kvs['args'].predict_age_sex_bmi:
            preds = np.hstack(preds)
            gt = np.hstack(gt)
            return running_loss / n_batches, ids, gt, preds
        else:
            preds_age = np.hstack(preds_age)
            gt_age = np.hstack(gt_age)

            preds_sex = np.hstack(preds_sex)
            gt_sex = np.hstack(gt_sex)

            preds_bmi = np.hstack(preds_bmi)
            gt_bmi = np.hstack(gt_bmi)

            return running_loss / n_batches, ids, gt_age, preds_age, gt_sex, preds_sex, gt_bmi, preds_bmi
import cv2

from oaprogression.kvs import GlobalKVS
from oaprogression.training import dataset
from oaprogression.training import session
from oaprogression.training import train_utils

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

if __name__ == "__main__":
    kvs = GlobalKVS()
    session.init_session()
    dataset.init_progression_metadata()
    session.init_data_processing()
    writers = session.init_folds()

    train_utils.train_folds(writers)
import sys

import cv2
from termcolor import colored
from torch.optim.lr_scheduler import MultiStepLR

from oaprogression.kvs import GlobalKVS
from oaprogression.training import dataset
from oaprogression.training import session
from oaprogression.training import train_utils

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

DEBUG = sys.gettrace() is not None

if __name__ == "__main__":
    kvs = GlobalKVS()
    session.init_session()
    dataset.init_progression_metadata()
    session.init_data_processing()
    writers = session.init_folds()

    if DEBUG:
        dataset.debug_augmentations()

    for fold_id in kvs['cv_split_train']:
        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(kvs['metadata'].iloc[train_index],