コード例 #1
0
def init_data_processing():
    """Build the train/validation transform pipelines and register them in the KVS.

    Normalisation statistics are estimated with ``init_mean_std`` on an
    ``OAProgressionDataset`` built from ``kvs['metadata']`` with the training
    augmentations. The resulting ``Normalize`` transform is applied to element 0
    of each sample (via ``apply_by_index``) in both pipelines:

    * ``train_trf`` -- training augmentations followed by normalisation;
    * ``val_trf``   -- resize to 310x310, crop 300x300 (crop_mode 'c'),
      grayscale-to-RGB, unpack, ToTensor and normalisation.

    Both pipelines are stored under ``train_trf`` / ``val_trf`` and the session
    is re-pickled to ``<snapshots>/<snapshot_name>/session.pkl``.
    """
    kvs = GlobalKVS()
    augmentations = init_train_augs()
    args = kvs['args']

    stats_dataset = OAProgressionDataset(dataset=args.dataset_root,
                                         split=kvs['metadata'],
                                         trf=augmentations)
    mean_vector, std_vector = init_mean_std(snapshots_dir=args.snapshots,
                                            dataset=stats_dataset,
                                            batch_size=args.bs,
                                            n_threads=args.n_threads)

    marker = colored('====> ', 'red')
    print(marker + 'Mean:', mean_vector)
    print(marker + 'Std:', std_vector)

    normalize = tv_transforms.Normalize(mean_vector.tolist(), std_vector.tolist())

    # Only picklable callables (partial objects, transform instances) are used
    # here, because the pipelines end up inside session.pkl.
    train_pipeline = tv_transforms.Compose([
        augmentations,
        partial(apply_by_index, transform=normalize, idx=0),
    ])

    val_geometry = slc.Stream([
        slt.ResizeTransform((310, 310)),
        slt.CropTransform(crop_size=(300, 300), crop_mode='c'),
        slt.ImageColorTransform(mode='gs2rgb'),
    ], interpolation='bicubic')
    val_pipeline = tv_transforms.Compose([
        img_labels2solt,
        val_geometry,
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=normalize, idx=0),
    ])

    kvs.update('train_trf', train_pipeline)
    kvs.update('val_trf', val_pipeline)
    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
コード例 #2
0
def save_checkpoint(model, val_metric_name, comparator='lt'):
    """Keep only the best-so-far model snapshot for the current fold.

    The latest validation metric ``val_metric_name`` is read from the
    per-fold metric log in the KVS. The model's ``state_dict`` is saved to
    ``<snapshots>/<snapshot_name>/fold_<fold>_epoch_<epoch + 1>.pth`` when no
    snapshot exists yet, or when ``comparator(new, best)`` is true. In that
    case the previous best snapshot file is removed, and ``prev_model`` /
    ``best_val_metric`` are updated. The session is always re-pickled.

    Args:
        model: the network; a ``torch.nn.DataParallel`` wrapper is unwrapped
            so the saved keys do not carry the ``module.`` prefix.
        val_metric_name: key of the metric inside the per-fold metric record.
        comparator: name of a function in the ``operator`` module
            (``'lt'`` means "lower is better", ``'gt'`` "higher is better").
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    # [-1] is the most recent log entry; the [0] unwraps how the KVS stores
    # list entries -- NOTE(review): presumably (value, timestamp); confirm.
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    is_better = getattr(operator, comparator)
    snapshot_dir = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'])
    cur_snapshot_name = os.path.join(snapshot_dir, f'fold_{fold_id}_epoch_{epoch + 1}.pth')

    prev_snapshot = kvs['prev_model']
    # Short-circuit: best_val_metric is only consulted once a snapshot exists.
    if prev_snapshot is None or is_better(val_metric, kvs['best_val_metric']):
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        # Write the new snapshot before deleting the old one, so that a failed
        # save (e.g. full disk) never leaves the run without any checkpoint.
        torch.save(model.state_dict(), cur_snapshot_name)
        if prev_snapshot is not None and prev_snapshot != cur_snapshot_name:
            os.remove(prev_snapshot)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(snapshot_dir, 'session.pkl'))
コード例 #3
0
def init_session(snapshot_name_prefix=None):
    """Initialise a training session and persist its metadata.

    Parses the command-line arguments, seeds the torch, CUDA and NumPy RNGs,
    and creates the snapshot directory via ``init_snapshot_dir``. It then
    records the PyTorch version, CUDA version and GPU count in the global KVS
    and pickles the session to ``<snapshots>/<snapshot_name>/session.pkl``.

    Args:
        snapshot_name_prefix: forwarded unchanged to ``init_snapshot_dir``.

    Returns:
        ``(args, snapshot_name)``.

    Raises:
        EnvironmentError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()
    kvs.update('args', args)

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Creating the snapshot; presumably this stores kvs['snapshot_name'],
    # which is read below -- NOTE(review): confirm in init_snapshot_dir.
    init_snapshot_dir(snapshot_name_prefix)

    kvs.update('pytorch_version', torch.__version__)
    # CUDA availability is guaranteed by the guard at the top, so the former
    # CPU fallback branch (cuda/gpus = None) was unreachable and was removed.
    kvs.update('cuda', torch.version.cuda)
    kvs.update('gpus', torch.cuda.device_count())

    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))

    return args, kvs['snapshot_name']
コード例 #4
0
def log_metrics_prog(boardlogger, train_loss, val_loss, gt_progression, preds_progression, gt_kl, preds_kl):
    """Compute, print and record validation metrics for the progression task.

    Metrics are computed with ``testtools.calc_metrics`` and the resulting
    dict is extended with ``val_loss`` and ``epoch``. The metrics are printed,
    sent to ``boardlogger.add_scalars`` and appended to the per-fold KVS logs
    (``losses_fold_[k]`` and ``val_metrics_fold_[k]``). Finally the session is
    re-pickled.

    Args:
        boardlogger: object exposing ``add_scalars(tag, values, step)``.
        train_loss: training loss for the epoch.
        val_loss: validation loss for the epoch.
        gt_progression, preds_progression: progression ground truth / predictions.
        gt_kl, preds_kl: KL-grade ground truth / predictions.
    """
    kvs = GlobalKVS()

    res = testtools.calc_metrics(gt_progression, gt_kl, preds_progression, preds_kl)
    # Store the scalar loss. A stray trailing comma used to wrap it in a 1-tuple.
    res['val_loss'] = val_loss
    epoch = kvs['cur_epoch']
    res['epoch'] = epoch

    marker = colored('====> ', 'green')
    print(marker + f'Train loss: {train_loss:.5f}')
    print(marker + f'Validation loss: {val_loss:.5f}')
    # (printed label, key in res) in display order
    printed_metrics = (
        ('Validation AUC [prog]', 'auc_prog'),
        ('Validation F1 @ 0.3 [prog]', 'f1_score_03_prog'),
        ('Validation F1 @ 0.4 [prog]', 'f1_score_04_prog'),
        ('Validation F1 @ 0.5 [prog]', 'f1_score_05_prog'),
        ('Validation AP [prog]', 'ap_prog'),
        ('Validation AUC [oa]', 'auc_oa'),
        ('Kappa [oa]', 'kappa_kl'),
    )
    for label, key in printed_metrics:
        print(marker + f'{label}: {res[key]:.5f}')

    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, epoch)
    # (board tag, key in res); the OA metrics are printed but not sent to the board.
    board_metrics = (
        ('AUC progression', 'auc_prog'),
        ('F1-score @ 0.3 progression', 'f1_score_03_prog'),
        ('F1-score @ 0.4 progression', 'f1_score_04_prog'),
        ('F1-score @ 0.5 progression', 'f1_score_05_prog'),
        ('Average Precision progression', 'ap_prog'),
    )
    for tag, key in board_metrics:
        boardlogger.add_scalars(tag, {'val': res[key]}, epoch)

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', {'epoch': epoch,
                                                    'train_loss': train_loss,
                                                    'val_loss': val_loss})

    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
コード例 #5
0
ファイル: dataset.py プロジェクト: hunglethanh9/OAProgression
def init_age_sex_bmi_metadata():
    """Load OAI/MOST metadata joined with participant data and build CV folds.

    Training metadata: ``OAI_progression.csv`` merged with
    ``OAI_participants.csv`` on ``ID`` and ``Side``. Rows with a missing BMI,
    AGE or SEX are dropped. Test metadata: ``MOST_progression.csv`` merged
    with ``MOST_participants.csv`` on the same keys, without any NA
    filtering.

    The tables are stored as ``metadata`` / ``metadata_test``. A 5-fold
    ``GroupKFold`` split is computed, stratification-free and grouped by
    ``ID`` (as str) so both knees of a subject land in the same fold. The
    split uses ``args.target_var`` as the label, is stored as
    ``cv_split_all_folds``, and the session is re-pickled.
    """
    kvs = GlobalKVS()
    args = kvs['args']
    join_keys = ('ID', 'Side')

    def _read_meta(file_name):
        # All metadata tables live directly under args.metadata_root.
        return pd.read_csv(os.path.join(args.metadata_root, file_name))

    train_meta = pd.merge(_read_meta('OAI_progression.csv'),
                          _read_meta('OAI_participants.csv'),
                          on=join_keys)
    train_meta = train_meta.dropna(subset=['BMI', 'AGE', 'SEX'])

    most_clinical = _read_meta('MOST_participants.csv')
    test_meta = pd.merge(_read_meta('MOST_progression.csv'), most_clinical, on=join_keys)

    kvs.update('metadata', train_meta)
    kvs.update('metadata_test', test_meta)

    meta = kvs['metadata']
    splitter = GroupKFold(n_splits=5)
    folds = list(splitter.split(meta, meta[args.target_var], meta['ID'].astype(str)))

    kvs.update('cv_split_all_folds', folds)
    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
コード例 #6
0
def init_folds(project='OA_progression'):
    """Prepare per-fold logs and TensorBoard writers.

    Every split in ``kvs['cv_split_all_folds']`` is selected, unless
    ``args.fold`` is not -1, in which case only that fold is selected. For
    each selected fold the function:

    * registers list-typed KVS entries ``losses_fold_[k]`` and
      ``val_metrics_fold_[k]``;
    * keeps the split;
    * opens a ``SummaryWriter`` under
      ``<logs>/<project>/fold_<k>/<snapshot_name>``.

    The kept splits are stored as ``cv_split_train`` and the session is
    re-pickled.

    Returns:
        dict mapping fold index -> ``SummaryWriter``.
    """
    kvs = GlobalKVS()
    args = kvs['args']
    only_fold = args.fold
    writers, cv_split_train = {}, {}

    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        selected = only_fold == -1 or fold_id == only_fold
        if not selected:
            continue
        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split
        log_dir = os.path.join(args.logs, project, f'fold_{fold_id}', kvs['snapshot_name'])
        writers[fold_id] = SummaryWriter(log_dir)

    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
    return writers
コード例 #7
0
def init_progression_metadata():
    """Load OAI (train) / MOST (test) progression metadata and build CV folds.

    When ``args.subsample_train != -1``, the OAI set is reduced to
    ``subsample_train`` rows while keeping the original progressor
    prevalence, then shuffled. The tables are stored as ``metadata`` /
    ``metadata_test``. A 5-fold ``GroupKFold`` split (label ``Progressor``,
    groups ``ID`` as str) is stored as ``cv_split_all_folds``, the session is
    re-pickled, and class counts of both sets are printed.
    """
    # NOTE(review): the original comment said non-progressors should be removed
    # from MOST because they can only be verified up to 84 months, but no such
    # filtering happens here -- confirm it is already done in MOST_progression.csv.
    kvs = GlobalKVS()

    most_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'MOST_progression.csv'))
    oai_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'OAI_progression.csv'))

    if kvs['args'].subsample_train != -1:
        n_train = oai_meta.shape[0]
        prevalence = (oai_meta.Progressor > 0).sum() / n_train
        sample_pos = int(kvs['args'].subsample_train * prevalence)
        sample_neg = kvs['args'].subsample_train - sample_pos
        train_pos = oai_meta[oai_meta.Progressor > 0]
        train_neg = oai_meta[oai_meta.Progressor == 0]

        # Draw distinct knees; fall back to sampling with replacement only when
        # more rows are requested than the class actually has.
        pos_idx = np.random.choice(train_pos.shape[0], sample_pos,
                                   replace=sample_pos > train_pos.shape[0])
        neg_idx = np.random.choice(train_neg.shape[0], sample_neg,
                                   replace=sample_neg > train_neg.shape[0])
        pos_sampled = train_pos.iloc[pos_idx]
        neg_sampled = train_neg.iloc[neg_idx]

        new_meta = pd.concat((pos_sampled, neg_sampled))
        # Shuffle with a permutation. The former np.random.choice(n, n) drew
        # with replacement, duplicating some rows and dropping others, so the
        # pos/neg counts printed below did not match the actual data.
        oai_meta = new_meta.iloc[np.random.permutation(new_meta.shape[0])]
        print(colored("==> ", 'red') + f"Train set has been sub-sampled. New # pos/neg {sample_pos}/{sample_neg}")

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    gkf = GroupKFold(n_splits=5)
    cv_split = [x for x in gkf.split(kvs['metadata'],
                                     kvs['metadata']['Progressor'],
                                     kvs['metadata']['ID'].astype(str))]

    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    print(colored("==> ", 'green') +
          f"Train dataset has {(kvs['metadata'].Progressor == 0).sum()} non-progressed knees")

    print(colored("==> ", 'green') +
          f"Train dataset has {(kvs['metadata'].Progressor > 0).sum()} progressed knees")

    print(colored("==> ", 'green') +
          f"Test dataset has {(kvs['metadata_test'].Progressor == 0).sum()} non-progressed knees")

    print(colored("==> ", 'green') +
          f"Test dataset has {(kvs['metadata_test'].Progressor > 0).sum()} progressed knees")
コード例 #8
0
ファイル: dataset.py プロジェクト: hunglethanh9/OAProgression
def init_progression_metadata():
    """Load OAI (train) / MOST (test) progression metadata and build CV folds.

    ``OAI_progression.csv`` becomes ``metadata`` and ``MOST_progression.csv``
    becomes ``metadata_test``. A 5-fold ``GroupKFold`` split (label
    ``Progressor``, groups ``ID`` as str, so both knees of a subject share a
    fold) is stored as ``cv_split_all_folds``. The session is re-pickled and
    the numbers of non-progressed / progressed knees in each set are printed.
    """
    # NOTE(review): the original comment said non-progressors should be removed
    # from MOST because they can only be verified up to 84 months, but this
    # function does no such filtering -- confirm it happens in the CSV itself.
    kvs = GlobalKVS()
    args = kvs['args']
    csv_dir = args.metadata_root

    test_meta = pd.read_csv(os.path.join(csv_dir, 'MOST_progression.csv'))
    train_meta = pd.read_csv(os.path.join(csv_dir, 'OAI_progression.csv'))

    kvs.update('metadata', train_meta)
    kvs.update('metadata_test', test_meta)

    meta = kvs['metadata']
    folds = list(GroupKFold(n_splits=5).split(meta, meta['Progressor'], meta['ID'].astype(str)))

    kvs.update('cv_split_all_folds', folds)
    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))

    for split_name, frame in (('Train', kvs['metadata']), ('Test', kvs['metadata_test'])):
        n_stable = (frame.Progressor == 0).sum()
        n_progressed = (frame.Progressor > 0).sum()
        print(colored("==> ", 'green') + f"{split_name} dataset has {n_stable} non-progressed knees")
        print(colored("==> ", 'green') + f"{split_name} dataset has {n_progressed} progressed knees")
コード例 #9
0
ファイル: session.py プロジェクト: hunglethanh9/OAProgression
def init_session():
    """Initialise a training session and persist its metadata.

    Parses the command-line arguments, seeds the torch, CUDA and NumPy RNGs,
    and creates ``<snapshots>/<YYYY_MM_DD_HH_MM>``. It then records the git
    branch/commit (``None`` when ``git_info()`` returns ``None``), the PyTorch
    version, CUDA version, GPU count, snapshot name and args in the global
    KVS. Finally the session is pickled to ``session.pkl`` inside the
    snapshot directory.

    Returns:
        ``(args, snapshot_name)``.

    Raises:
        EnvironmentError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()
    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Creating the snapshot. NOTE(review): the name has minute resolution and
    # exist_ok=True, so two runs started within the same minute share a directory.
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)
    # CUDA availability is guaranteed by the guard at the top, so the former
    # CPU fallback branch (cuda/gpus = None) was unreachable and was removed.
    kvs.update('cuda', torch.version.cuda)
    kvs.update('gpus', torch.cuda.device_count())

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))

    return args, snapshot_name
コード例 #10
0
def log_metrics_age_sex_bmi(boardlogger, train_loss, val_res):
    """Compute, print and record validation metrics for AGE/SEX/BMI prediction.

    ``val_res[0]`` is the validation loss. The rest of ``val_res`` depends on
    ``args.predict_age_sex_bmi``:

    * False: ``(loss, ids, gt, preds)`` for the single target ``args.target_var``.
      SEX is scored with ROC AUC; any other target with MSE and median
      absolute error.
    * True: ``(loss, ids, gt_age, preds_age, gt_sex, preds_sex, gt_bmi, preds_bmi)``,
      scored jointly. Only the loss is sent to ``boardlogger`` in this mode.

    Results are appended to the per-fold KVS logs and the session is re-pickled.
    """
    kvs = GlobalKVS()
    res = dict()
    val_loss = val_res[0]
    # Store the scalar loss. A stray trailing comma used to wrap it in a 1-tuple.
    res['val_loss'] = val_loss
    res['epoch'] = kvs['cur_epoch']
    print(colored('====> ', 'green') + f'Train loss: {train_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation loss: {val_loss:.5f}')
    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, kvs['cur_epoch'])

    # NOTE: the "mae" keys/labels below hold the *median* absolute error
    # (median_absolute_error), not the mean absolute error.
    if not kvs['args'].predict_age_sex_bmi:
        _, _, gt, preds = val_res
        if kvs['args'].target_var == 'SEX':
            val_auc = roc_auc_score(gt.astype(int), preds)
            res['sex_auc'] = val_auc
            print(colored('====> ', 'green') + f'Validation AUC: {val_auc:.5f}')
            boardlogger.add_scalars('AUC sex', {'val': res['sex_auc']}, kvs['cur_epoch'])
        else:
            val_mse = mean_squared_error(gt, preds)
            val_mae = median_absolute_error(gt, preds)
            res[f"{kvs['args'].target_var}_mse"] = val_mse
            res[f"{kvs['args'].target_var}_mae"] = val_mae

            print(colored('====> ', 'green') + f'Validation mae: {val_mae:.5f}')
            print(colored('====> ', 'green') + f'Validation mse: {val_mse:.5f}')

            boardlogger.add_scalars(f"MSE [{kvs['args'].target_var}]", {'val': val_mse},
                                    kvs['cur_epoch'])
            boardlogger.add_scalars(f"MAE [{kvs['args'].target_var}]", {'val': val_mae},
                                    kvs['cur_epoch'])
    else:
        _, _, gt_age, preds_age, gt_sex, preds_sex, gt_bmi, preds_bmi = val_res
        val_mse_age = mean_squared_error(gt_age, preds_age)
        val_mae_age = median_absolute_error(gt_age, preds_age)
        val_sex_auc = roc_auc_score(gt_sex.astype(int), preds_sex)
        val_mse_bmi = mean_squared_error(gt_bmi, preds_bmi)
        val_mae_bmi = median_absolute_error(gt_bmi, preds_bmi)

        res["AGE_mse"] = val_mse_age
        res["AGE_mae"] = val_mae_age

        res["BMI_mse"] = val_mse_bmi
        res["BMI_mae"] = val_mae_bmi

        res["SEX_auc"] = val_sex_auc

        print(colored('====> ', 'green') + f'Validation mae [Age]: {val_mae_age:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [Age]: {val_mse_age:.5f}')

        print(colored('====> ', 'green') + f'Validation val_auc [Sex]: {val_sex_auc:.5f}')

        print(colored('====> ', 'green') + f'Validation mae [BMI]: {val_mae_bmi:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [BMI]: {val_mse_bmi:.5f}')

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', {'epoch': kvs['cur_epoch'],
                                                    'train_loss': train_loss,
                                                    'val_loss': val_loss})

    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))