Exemplo n.º 1
0
def init_snapshot_dir(snapshot_name_prefix=None):
    """Create a timestamped snapshot directory and register its name in the KVS.

    The directory name is the current local time formatted as
    ``YYYY_MM_DD_HH_MM``, optionally prefixed with ``<snapshot_name_prefix>_``.
    It is created under ``kvs['args'].snapshots`` (an existing directory is
    reused) and the name is stored under the ``'snapshot_name'`` key.

    Args:
        snapshot_name_prefix: Optional string prepended to the timestamp.
    """
    kvs = GlobalKVS()
    timestamp = time.strftime('%Y_%m_%d_%H_%M')
    if snapshot_name_prefix is None:
        name = timestamp
    else:
        name = f'{snapshot_name_prefix}_{timestamp}'
    snapshot_dir = os.path.join(kvs['args'].snapshots, name)
    os.makedirs(snapshot_dir, exist_ok=True)
    kvs.update('snapshot_name', name)
Exemplo n.º 2
0
def init_age_sex_bmi_metadata():
    """Load OAI (train) / MOST (test) metadata joined with participant data.

    Both datasets are merged with their participant tables on ``ID`` and
    ``Side``. OAI knees with a missing BMI, AGE or SEX value are dropped. The
    tables are stored as ``'metadata'`` / ``'metadata_test'``, 5 ``GroupKFold``
    splits of the OAI table (target ``args.target_var``, grouped by ``ID``)
    are stored as ``'cv_split_all_folds'`` and the session is pickled.
    """
    kvs = GlobalKVS()
    meta_root = kvs['args'].metadata_root
    join_keys = ['ID', 'Side']

    oai_progression = pd.read_csv(os.path.join(meta_root, 'OAI_progression.csv'))
    oai_participants = pd.read_csv(os.path.join(meta_root, 'OAI_participants.csv'))
    oai_meta = pd.merge(oai_progression, oai_participants, on=join_keys)
    # Keep only knees for which all three factors are known.
    oai_meta = oai_meta.dropna(subset=['BMI', 'AGE', 'SEX'])

    most_participants = pd.read_csv(os.path.join(meta_root, 'MOST_participants.csv'))
    most_progression = pd.read_csv(os.path.join(meta_root, 'MOST_progression.csv'))
    most_meta = pd.merge(most_progression, most_participants, on=join_keys)

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    train_meta = kvs['metadata']
    folds = GroupKFold(n_splits=5).split(train_meta,
                                         train_meta[kvs['args'].target_var],
                                         train_meta['ID'].astype(str))
    kvs.update('cv_split_all_folds', list(folds))

    session_path = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl')
    kvs.save_pkl(session_path)
Exemplo n.º 3
0
def log_metrics_prog(boardlogger, train_loss, val_loss, gt_progression, preds_progression, gt_kl, preds_kl):
    """Compute, print, log and store the validation metrics of the progression model.

    Args:
        boardlogger: TensorBoard writer-like object (its ``add_scalars`` is used).
        train_loss: Mean training loss of the current epoch.
        val_loss: Mean validation loss of the current epoch.
        gt_progression: Progression ground truth.
        preds_progression: Predicted progression probabilities.
        gt_kl: KL-grade ground truth.
        preds_kl: Predicted KL-grade probabilities.

    The metrics dict returned by ``testtools.calc_metrics``, extended with
    ``val_loss`` and ``epoch``, is stored in ``val_metrics_fold_[<fold>]``. The
    losses go to ``losses_fold_[<fold>]``, and the session is pickled.
    """
    kvs = GlobalKVS()
    epoch = kvs['cur_epoch']

    res = testtools.calc_metrics(gt_progression, gt_kl, preds_progression, preds_kl)
    # Store the loss as a plain number. A stray trailing comma used to turn it
    # into a 1-tuple.
    res['val_loss'] = val_loss
    res['epoch'] = epoch

    prefix = colored('====> ', 'green')
    print(prefix + f'Train loss: {train_loss:.5f}')
    print(prefix + f'Validation loss: {val_loss:.5f}')

    # (console label, metrics key)
    console_metrics = [
        ('Validation AUC [prog]', 'auc_prog'),
        ('Validation F1 @ 0.3 [prog]', 'f1_score_03_prog'),
        ('Validation F1 @ 0.4 [prog]', 'f1_score_04_prog'),
        ('Validation F1 @ 0.5 [prog]', 'f1_score_05_prog'),
        ('Validation AP [prog]', 'ap_prog'),
        ('Validation AUC [oa]', 'auc_oa'),
        ('Kappa [oa]', 'kappa_kl'),
    ]
    for label, key in console_metrics:
        print(prefix + f'{label}: {res[key]:.5f}')

    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, epoch)
    # (TensorBoard tag, metrics key)
    board_metrics = [
        ('AUC progression', 'auc_prog'),
        ('F1-score @ 0.3 progression', 'f1_score_03_prog'),
        ('F1-score @ 0.4 progression', 'f1_score_04_prog'),
        ('F1-score @ 0.5 progression', 'f1_score_05_prog'),
        ('Average Precision progression', 'ap_prog'),
    ]
    for tag, key in board_metrics:
        boardlogger.add_scalars(tag, {'val': res[key]}, epoch)

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', {'epoch': epoch,
                                                    'train_loss': train_loss,
                                                    'val_loss': val_loss})

    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
Exemplo n.º 4
0
def init_data_processing():
    """Build the normalized train/validation transforms and store them in the KVS.

    ``init_mean_std`` is run on the training metadata (``kvs['metadata']``)
    with the training augmentations applied. The resulting mean/std are used
    to normalize element 0 of each pipeline's output
    (``apply_by_index(..., idx=0)``).

    The validation pipeline:
    - resizes to 310x310;
    - crops to 300x300 (``crop_mode='c'``, presumably center);
    - applies ``ImageColorTransform(mode='gs2rgb')``;
    - converts to a tensor and normalizes.

    The transforms are stored as ``'train_trf'`` / ``'val_trf'`` and the
    session is pickled.
    """
    kvs = GlobalKVS()
    args = kvs['args']
    train_augs = init_train_augs()

    stats_dataset = OAProgressionDataset(dataset=args.dataset_root, split=kvs['metadata'], trf=train_augs)

    mean_vector, std_vector = init_mean_std(snapshots_dir=args.snapshots,
                                            dataset=stats_dataset,
                                            batch_size=args.bs,
                                            n_threads=args.n_threads)

    for label, values in (('Mean:', mean_vector), ('Std:', std_vector)):
        print(colored('====> ', 'red') + label, values)

    norm_trf = tv_transforms.Normalize(mean_vector.tolist(), std_vector.tolist())

    train_steps = [
        train_augs,
        partial(apply_by_index, transform=norm_trf, idx=0),
    ]
    train_trf = tv_transforms.Compose(train_steps)

    solt_stream = slc.Stream([
        slt.ResizeTransform((310, 310)),
        slt.CropTransform(crop_size=(300, 300), crop_mode='c'),
        slt.ImageColorTransform(mode='gs2rgb'),
    ], interpolation='bicubic')
    val_steps = [
        img_labels2solt,
        solt_stream,
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=norm_trf, idx=0),
    ]
    val_trf = tv_transforms.Compose(val_steps)

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
Exemplo n.º 5
0
def init_folds(project='OA_progression'):
    """Select the CV folds to train and create one TensorBoard writer per fold.

    Every fold of ``kvs['cv_split_all_folds']`` is used when
    ``args.fold == -1``; otherwise only the fold with index ``args.fold`` is
    used. For each selected fold, list-typed ``losses_fold_[<id>]`` and
    ``val_metrics_fold_[<id>]`` entries are registered in the KVS. A
    ``SummaryWriter`` is created under
    ``<args.logs>/<project>/fold_<id>/<snapshot_name>``.

    Args:
        project: Sub-directory of ``args.logs`` that holds the logs.

    Returns:
        dict: fold index -> ``SummaryWriter``.
    """
    kvs = GlobalKVS()
    writers = {}
    cv_split_train = {}
    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        args = kvs['args']
        use_all_folds = args.fold == -1
        if not (use_all_folds or fold_id == args.fold):
            continue
        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split
        log_dir = os.path.join(args.logs, project, f'fold_{fold_id}', kvs['snapshot_name'])
        writers[fold_id] = SummaryWriter(log_dir)

    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
    return writers
Exemplo n.º 6
0
def init_optimizer(parameters):
    """Create the optimizer selected by ``kvs['args'].optimizer``.

    ``'adam'`` gives ``optim.Adam``. ``'sgd'`` gives ``optim.SGD`` with
    momentum 0.9. Both use ``args.lr`` as the learning rate and ``args.wd`` as
    the weight decay.

    Args:
        parameters: Parameters or parameter-group dicts passed to the optimizer.

    Raises:
        NotImplementedError: For any other optimizer name.
    """
    args = GlobalKVS()['args']
    name = args.optimizer
    if name == 'adam':
        return optim.Adam(parameters, lr=args.lr, weight_decay=args.wd)
    if name == 'sgd':
        return optim.SGD(parameters, lr=args.lr, weight_decay=args.wd, momentum=0.9)
    raise NotImplementedError
Exemplo n.º 7
0
def init_progression_metadata():
    """Load OAI (train) / MOST (test) progression metadata and build CV folds.

    When ``args.subsample_train != -1``, the OAI table is sub-sampled to
    ``args.subsample_train`` knees, keeping the progressor prevalence. Knees
    are drawn without replacement and the result is shuffled.

    The tables are stored as ``'metadata'`` / ``'metadata_test'``. 5
    ``GroupKFold`` splits of the training table, grouped by ``ID``, are
    stored as ``'cv_split_all_folds'``. The session is pickled and the
    (non-)progressor counts of both sets are printed.
    """
    # NOTE: we should get rid of non-progressors from MOST, because
    # non-progression can only be verified up to 84 months. No such filtering
    # is performed in this function.
    kvs = GlobalKVS()

    most_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'MOST_progression.csv'))
    oai_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'OAI_progression.csv'))

    if kvs['args'].subsample_train != -1:
        n_train = oai_meta.shape[0]
        prevalence = (oai_meta.Progressor > 0).sum() / n_train
        sample_pos = int(kvs['args'].subsample_train * prevalence)
        sample_neg = kvs['args'].subsample_train - sample_pos
        train_pos = oai_meta[oai_meta.Progressor > 0]
        train_neg = oai_meta[oai_meta.Progressor == 0]

        # Draw without replacement so that no knee is duplicated. Fall back to
        # sampling with replacement only when more knees are requested than
        # exist, because replace=False would raise in that case.
        n_pos_avail = train_pos.shape[0]
        n_neg_avail = train_neg.shape[0]
        pos_idx = np.random.choice(n_pos_avail, sample_pos, replace=sample_pos > n_pos_avail)
        neg_idx = np.random.choice(n_neg_avail, sample_neg, replace=sample_neg > n_neg_avail)

        new_meta = pd.concat((train_pos.iloc[pos_idx], train_neg.iloc[neg_idx]))
        # Shuffle with a permutation. Drawing indices with replacement here
        # would duplicate and drop rows, breaking the pos/neg counts below.
        oai_meta = new_meta.iloc[np.random.permutation(new_meta.shape[0])]
        print(colored("==> ", 'red') + f"Train set has been sub-sampled. New # pos/neg {sample_pos}/{sample_neg}")

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    gkf = GroupKFold(n_splits=5)
    cv_split = [x for x in gkf.split(kvs['metadata'],
                                     kvs['metadata']['Progressor'],
                                     kvs['metadata']['ID'].astype(str))]

    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    print(colored("==> ", 'green') +
          f"Train dataset has {(kvs['metadata'].Progressor == 0).sum()} non-progressed knees")

    print(colored("==> ", 'green') +
          f"Train dataset has {(kvs['metadata'].Progressor > 0).sum()} progressed knees")

    print(colored("==> ", 'green') +
          f"Test dataset has {(kvs['metadata_test'].Progressor == 0).sum()} non-progressed knees")

    print(colored("==> ", 'green') +
          f"Test dataset has {(kvs['metadata_test'].Progressor > 0).sum()} progressed knees")
Exemplo n.º 8
0
def init_epoch_pass(net, optimizer, loader):
    """Prepare the state shared by one training or validation pass.

    Puts ``net`` into training mode when an optimizer is given and into eval
    mode otherwise. Creates a tqdm progress bar sized to ``len(loader)``.

    Returns:
        tuple: ``(running_loss, pbar, n_batches, epoch, max_epoch, device)``.
        ``running_loss`` starts at 0.0, ``epoch`` is ``kvs['cur_epoch']`` and
        ``max_epoch`` is ``args.n_epochs``. ``device`` is the device of the
        first model parameter.
    """
    kvs = GlobalKVS()
    is_training = optimizer is not None
    net.train(is_training)
    n_batches = len(loader)
    progress = tqdm(total=n_batches)
    cur_epoch = kvs['cur_epoch']
    n_epochs = kvs['args'].n_epochs
    param_device = next(net.parameters()).device
    return 0.0, progress, n_batches, cur_epoch, n_epochs, param_device
Exemplo n.º 9
0
def init_progression_metadata():
    """Load OAI (train) / MOST (test) progression metadata and build CV folds.

    Stores the tables as ``'metadata'`` / ``'metadata_test'``. Stores 5
    ``GroupKFold`` splits of the OAI table (target ``Progressor``, grouped by
    ``ID``, so rows sharing an ID land in the same fold) as
    ``'cv_split_all_folds'``. Pickles the session and prints the
    (non-)progressor counts of both sets.
    """
    # NOTE: we should get rid of non-progressors from MOST, because
    # non-progression can only be verified up to 84 months. No such filtering
    # is performed in this function.
    kvs = GlobalKVS()
    meta_root = kvs['args'].metadata_root

    most_meta = pd.read_csv(os.path.join(meta_root, 'MOST_progression.csv'))
    oai_meta = pd.read_csv(os.path.join(meta_root, 'OAI_progression.csv'))

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    train_meta = kvs['metadata']
    splitter = GroupKFold(n_splits=5)
    kvs.update('cv_split_all_folds',
               list(splitter.split(train_meta, train_meta['Progressor'],
                                   train_meta['ID'].astype(str))))
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    for split_label, kvs_key in (('Train', 'metadata'), ('Test', 'metadata_test')):
        progressor = kvs[kvs_key].Progressor
        print(colored("==> ", 'green') +
              f"{split_label} dataset has {(progressor == 0).sum()} non-progressed knees")
        print(colored("==> ", 'green') +
              f"{split_label} dataset has {(progressor > 0).sum()} progressed knees")
Exemplo n.º 10
0
def train_folds(writers):
    """Train and validate the progression model on every fold in ``kvs['cv_split_train']``.

    For each fold:
    - builds the loaders and a fresh model;
    - optimizes only the two classifier heads at first;
    - at ``args.unfreeze_epoch``, adds the feature extractor to the optimizer
      and restarts the LR schedule with milestones shifted by the unfreeze
      epoch;
    - after every epoch, logs the validation metrics and keeps the best
      checkpoint by ``ap_prog`` (higher is better).

    Args:
        writers: dict fold id -> TensorBoard writer, as returned by ``init_folds``.
    """
    kvs = GlobalKVS()
    for fold_id in kvs['cv_split_train']:
        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(
            kvs['metadata'].iloc[train_index], kvs['metadata'].iloc[val_index])

        net = init_model()
        # The network is not necessarily wrapped in nn.DataParallel (e.g. on a
        # single GPU), so ``net.module`` may not exist. Fall back to the bare
        # network in that case.
        model = getattr(net, 'module', net)
        optimizer = init_optimizer([
            {'params': model.classifier_kl.parameters()},
            {'params': model.classifier_prog.parameters()},
        ])

        scheduler = MultiStepLR(optimizer,
                                milestones=kvs['args'].lr_drop,
                                gamma=0.1)

        for epoch in range(kvs['args'].n_epochs):
            kvs.update('cur_epoch', epoch)
            if epoch == kvs['args'].unfreeze_epoch:
                print(colored('====> ', 'red') + 'Unfreezing the layers!')
                # The new scheduler counts epochs from the unfreeze point.
                new_lr_drop_milestones = [x - kvs['args'].unfreeze_epoch
                                          for x in kvs['args'].lr_drop]
                optimizer.add_param_group({'params': model.features.parameters()})
                scheduler = MultiStepLR(optimizer,
                                        milestones=new_lr_drop_milestones,
                                        gamma=0.1)

            # Report the learning rates actually in use. ``scheduler.get_lr()``
            # is only meant to be called inside ``step()`` in recent PyTorch.
            # Outside it, it applies an extra ``gamma`` factor at milestone epochs.
            current_lrs = [group['lr'] for group in optimizer.param_groups]
            print(colored('====> ', 'red') + 'LR:', current_lrs)
            train_loss = prog_epoch_pass(net, optimizer, train_loader)
            val_out = prog_epoch_pass(net, None, val_loader)
            val_loss, _, gt_progression, preds_progression, gt_kl, preds_kl = val_out
            log_metrics_prog(writers[fold_id], train_loss, val_loss,
                             gt_progression, preds_progression, gt_kl,
                             preds_kl)

            session.save_checkpoint(net, 'ap_prog', 'gt')
            scheduler.step()
Exemplo n.º 11
0
def debug_augmentations(n_iter=20):
    """Show randomly chosen training samples after the training augmentations.

    Args:
        n_iter: Number of distinct samples to display. Capped at the dataset
            size, because sampling without replacement cannot draw more items
            than exist.
    """
    kvs = GlobalKVS()

    ds = OAProgressionDataset(dataset=kvs['args'].dataset_root,
                              split=kvs['metadata'],
                              trf=init_train_augs())

    n_show = min(n_iter, len(ds))
    for ind in np.random.choice(len(ds), n_show, replace=False):
        sample = ds[ind]
        # assumes sample['img'] is a tensor with values in [0, 1] — TODO confirm
        img = np.clip(sample['img'].numpy() * 255, 0, 255).astype(np.uint8)
        # For a 3-D channel-first array, these two swaps yield (H, W, C) for imshow.
        img = np.swapaxes(img, 0, -1)
        img = np.swapaxes(img, 0, 1)
        fig = plt.figure()
        plt.imshow(img)
        plt.show()
        # Release the figure so that non-interactive backends do not accumulate them.
        plt.close(fig)
Exemplo n.º 12
0
def init_model(kneenet=True):
    """Build the network and move it to CUDA.

    Args:
        kneenet: When True, build ``KneeNet(backbone, dropout_rate)``.
            Otherwise build a ``PretrainedModel`` with 3 outputs if
            ``args.predict_age_sex_bmi`` is set, or 1 output otherwise.

    Returns:
        The model on ``'cuda'``, wrapped in ``nn.DataParallel`` when more
        than one GPU is registered in ``kvs['gpus']``.
    """
    kvs = GlobalKVS()
    args = kvs['args']
    if kneenet:
        net = KneeNet(args.backbone, args.dropout_rate)
    else:
        n_outputs = 3 if args.predict_age_sex_bmi else 1
        net = PretrainedModel(args.backbone, args.dropout_rate, n_outputs, True)

    if kvs['gpus'] > 1:
        net = nn.DataParallel(net).to('cuda')

    return net.to('cuda')
Exemplo n.º 13
0
def _seed_worker(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    This is a module-level function (not a lambda) so that it can be pickled
    when workers are started with the ``spawn`` method. The seed is reduced
    modulo 2**32 because ``np.random.seed`` only accepts 32-bit seeds, while
    ``torch.initial_seed()`` returns a 64-bit value. Converting that value
    with ``np.uint32`` overflows (an error in NumPy 2).
    """
    np.random.seed((torch.initial_seed() + worker_id) % 2 ** 32)


def init_loaders(x_train, x_val, progression=True):
    """Create the training and validation ``DataLoader`` objects.

    Args:
        x_train: Metadata rows of the training split.
        x_val: Metadata rows of the validation split.
        progression: Use ``OAProgressionDataset`` when True, otherwise
            ``AgeSexBMIDataset``.

    Returns:
        tuple: ``(train_loader, val_loader)``. The training loader shuffles,
        drops the last incomplete batch and seeds NumPy per worker. The
        validation loader uses ``args.val_bs`` and keeps all samples.
    """
    kvs = GlobalKVS()

    ds = OAProgressionDataset if progression else AgeSexBMIDataset

    train_dataset = ds(dataset=kvs['args'].dataset_root,
                       split=x_train,
                       trf=kvs['train_trf'])

    val_dataset = ds(dataset=kvs['args'].dataset_root,
                     split=x_val,
                     trf=kvs['val_trf'])

    train_loader = DataLoader(train_dataset, batch_size=kvs['args'].bs,
                              num_workers=kvs['args'].n_threads,
                              drop_last=True, shuffle=True,
                              worker_init_fn=_seed_worker)

    val_loader = DataLoader(val_dataset, batch_size=kvs['args'].val_bs,
                            num_workers=kvs['args'].n_threads)

    return train_loader, val_loader
Exemplo n.º 14
0
def save_checkpoint(model, val_metric_name, comparator='lt'):
    """Keep only the best checkpoint of the current fold on disk.

    The newest value of ``val_metric_name`` is read from the last entry of
    ``kvs['val_metrics_fold_[<fold>]']``. The model's ``state_dict`` is saved
    to ``<snapshots>/<snapshot_name>/fold_<fold>_epoch_<epoch + 1>.pth`` in
    two cases: no checkpoint has been saved yet for the fold, or
    ``comparator(new_value, best_value)`` is true. The previous best file is
    then deleted. The session is pickled in every case.

    Args:
        model: Network to save. A ``torch.nn.DataParallel`` wrapper is
            unwrapped first.
        val_metric_name: Key of the metric in the stored metrics dict.
        comparator: Name of a function in the ``operator`` module, e.g.
            ``'lt'`` when lower is better or ``'gt'`` when higher is better.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    # assumes list-typed KVS entries are stored as (value, ...) tuples — TODO confirm
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    comparator = getattr(operator, comparator)
    cur_snapshot_name = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch + 1}.pth')

    prev_model = kvs['prev_model']
    if prev_model is None or comparator(val_metric, kvs['best_val_metric']):
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        # Write the new checkpoint before deleting the old one, so that a
        # failed save never leaves the fold without any checkpoint on disk.
        torch.save(model.state_dict(), cur_snapshot_name)
        if prev_model is not None and prev_model != cur_snapshot_name:
            os.remove(prev_model)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
# Disable OpenCV's OpenCL usage and its internal thread pool.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

# True when the script runs under a debugger/tracer (sys.gettrace() is set).
# NOTE(review): `sys` must be imported by this module; its import is not visible here.
DEBUG = sys.gettrace() is not None

# NOTE(review): `sides` and `JSW_features` are not used in the visible part of this script.
sides = [None, 'R', 'L']
# Baseline ("V00") joint space width column names, presumably OAI JSW measurements — verify.
JSW_features = [
    'V00JSW150', 'V00JSW175', 'V00JSW200', 'V00JSW225', 'V00JSW250',
    'V00JSW275', 'V00JSW300', 'V00LJSW700', 'V00LJSW725', 'V00LJSW750',
    'V00LJSW775', 'V00LJSW800', 'V00LJSW825', 'V00LJSW850', 'V00LJSW875',
    'V00LJSW900'
]

if __name__ == "__main__":
    kvs = GlobalKVS()
    # Parses arguments, seeds the RNGs and creates the base snapshot directory,
    # whose name is read back from kvs['snapshot_name'] below.
    session.init_session()
    sites, metadata = read_jsw_metadata_oai(kvs['args'].metadata_root,
                                            kvs['args'].oai_data_root)

    base_snapshot_name = kvs['snapshot_name']
    # Leave-one-site-out: every OAI site is held out once.
    for test_site in sites:
        # Creating a sub-snapshot for every site in OAI
        os.makedirs(os.path.join(kvs['args'].snapshots, base_snapshot_name,
                                 f'site_{test_site}'),
                    exist_ok=True)
        # Subsequent session files for this site go into its sub-directory.
        kvs.update('snapshot_name',
                   os.path.join(base_snapshot_name, f'site_{test_site}'))
        # Split the data so that the current site is excluded from training and kept only for testing.
        # NOTE(review): the loop body appears truncated; these frames are unused in the visible code.
        top_subj_train = metadata[metadata.V00SITE != test_site]
        top_subj_test = metadata[metadata.V00SITE == test_site]
Exemplo n.º 16
0
def init_session():
    """Initialize a GPU training session.

    Performs these steps:
    - parses the command-line arguments;
    - seeds the torch (CPU and CUDA) and NumPy RNGs;
    - creates a snapshot directory named after the current local time
      (``YYYY_MM_DD_HH_MM``);
    - records git, PyTorch and CUDA information in the KVS;
    - pickles the session.

    Returns:
        tuple: ``(args, snapshot_name)``.

    Raises:
        EnvironmentError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()
    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Creating the snapshot
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)

    # CUDA availability is guaranteed by the check at the top of this function,
    # so no "no CUDA" fallback is needed here.
    kvs.update('cuda', torch.version.cuda)
    kvs.update('gpus', torch.cuda.device_count())

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))

    return args, snapshot_name
Exemplo n.º 17
0
def prog_epoch_pass(net, optimizer, loader):
    """Run one epoch of the joint KL-grade / progression model.

    The epoch is a training pass (backprop, optional gradient clipping) when
    ``optimizer`` is given. Otherwise it is a validation pass that collects
    softmax probabilities, labels and ``ID_SIDE`` identifiers. The loss is
    ``w * CE(progression) + (1 - w) * CE(KL)`` with ``w = args.loss_weight``.

    Returns:
        Training: mean loss over batches.
        Validation: ``(mean_loss, ids, gt_progression, preds_progression,
        gt_kl, preds_kl)``. ``preds_*`` are vertically stacked softmax
        outputs and ``gt_*`` are horizontally stacked label arrays.
    """
    kvs = GlobalKVS()
    running_loss, pbar, n_batches, epoch, max_epoch, device = init_epoch_pass(net, optimizer, loader)

    preds_progression = []
    gt_progression = []
    ids = []
    preds_kl = []
    gt_kl = []

    with torch.set_grad_enabled(optimizer is not None):
        for i, batch in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()
            # forward + backward + optimize if train
            labels_prog = batch['label'].long().to(device)
            labels_kl = batch['KL'].long().to(device)

            inputs = batch['img'].to(device)

            outputs_kl, outputs_prog = net(inputs)
            loss_kl = F.cross_entropy(outputs_kl, labels_kl)
            loss_prog = F.cross_entropy(outputs_prog, labels_prog)

            loss = loss_prog.mul(kvs['args'].loss_weight) + loss_kl.mul(1 - kvs['args'].loss_weight)

            if optimizer is not None:
                loss.backward()
                if kvs['args'].clip_grad:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), kvs['args'].clip_grad_norm)
                optimizer.step()
            else:
                probs_progression_batch = F.softmax(outputs_prog, 1).data.to('cpu').numpy()
                probs_kl_batch = F.softmax(outputs_kl, 1).data.to('cpu').numpy()

                preds_progression.append(probs_progression_batch)
                gt_progression.append(batch['label'].numpy())

                preds_kl.append(probs_kl_batch)
                # Convert to NumPy like the progression labels, so that
                # np.hstack below always receives NumPy arrays.
                gt_kl.append(batch['KL'].numpy())
                ids.extend(batch['ID_SIDE'])

            running_loss += loss.item()
            if optimizer is not None:
                pbar.set_description(f'Training   [{epoch} / {max_epoch}]:: {running_loss / (i + 1):.3f}')
            else:
                pbar.set_description(f'Validating [{epoch} / {max_epoch}]:')
            pbar.update()

            gc.collect()

    if optimizer is None:
        preds_progression = np.vstack(preds_progression)
        gt_progression = np.hstack(gt_progression)

        preds_kl = np.vstack(preds_kl)
        gt_kl = np.hstack(gt_kl)

    gc.collect()
    pbar.close()

    if optimizer is not None:
        return running_loss / n_batches
    else:
        return running_loss / n_batches, ids, gt_progression, preds_progression, gt_kl, preds_kl
Exemplo n.º 18
0
def log_metrics_age_sex_bmi(boardlogger, train_loss, val_res):
    """Print, log and store the validation metrics for age / sex / BMI prediction.

    Args:
        boardlogger: TensorBoard writer-like object (its ``add_scalars`` is used).
        train_loss: Mean training loss of the current epoch.
        val_res: Output of the validation pass. For a single target
            (``args.predict_age_sex_bmi`` unset) it is
            ``(val_loss, ids, gt, preds)``. Otherwise it is
            ``(val_loss, ids, gt_age, preds_age, gt_sex, preds_sex, gt_bmi,
            preds_bmi)``.

    The ``*_mae`` metrics are computed with ``median_absolute_error``. The
    metrics dict is stored in ``val_metrics_fold_[<fold>]``, the losses in
    ``losses_fold_[<fold>]``, and the session is pickled.
    """
    kvs = GlobalKVS()
    val_loss = val_res[0]
    # Store the loss as a plain number. A stray trailing comma used to turn
    # it into a 1-tuple.
    res = {'val_loss': val_loss, 'epoch': kvs['cur_epoch']}
    print(colored('====> ', 'green') + f'Train loss: {train_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation loss: {val_loss:.5f}')
    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, kvs['cur_epoch'])

    if not kvs['args'].predict_age_sex_bmi:
        _, _, gt, preds = val_res
        target_var = kvs['args'].target_var
        if target_var == 'SEX':
            val_auc = roc_auc_score(gt.astype(int), preds)
            res['sex_auc'] = val_auc
            print(colored('====> ', 'green') + f'Validation AUC: {val_auc:.5f}')
            boardlogger.add_scalars('AUC sex', {'val': res['sex_auc']}, kvs['cur_epoch'])
        else:
            val_mse = mean_squared_error(gt, preds)
            val_mae = median_absolute_error(gt, preds)
            res[f"{target_var}_mse"] = val_mse
            res[f"{target_var}_mae"] = val_mae

            print(colored('====> ', 'green') + f'Validation mae: {val_mae:.5f}')
            print(colored('====> ', 'green') + f'Validation mse: {val_mse:.5f}')

            boardlogger.add_scalars(f"MSE [{target_var}]", {'val': val_mse},
                                    kvs['cur_epoch'])
            boardlogger.add_scalars(f"MAE [{target_var}]", {'val': val_mae},
                                    kvs['cur_epoch'])
    else:
        _, _, gt_age, preds_age, gt_sex, preds_sex, gt_bmi, preds_bmi = val_res
        val_mse_age = mean_squared_error(gt_age, preds_age)
        val_mae_age = median_absolute_error(gt_age, preds_age)
        val_sex_auc = roc_auc_score(gt_sex.astype(int), preds_sex)
        val_mse_bmi = mean_squared_error(gt_bmi, preds_bmi)
        val_mae_bmi = median_absolute_error(gt_bmi, preds_bmi)

        res["AGE_mse"] = val_mse_age
        res["AGE_mae"] = val_mae_age

        res["BMI_mse"] = val_mse_bmi
        res["BMI_mae"] = val_mae_bmi

        res["SEX_auc"] = val_sex_auc

        print(colored('====> ', 'green') + f'Validation mae [Age]: {val_mae_age:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [Age]: {val_mse_age:.5f}')

        print(colored('====> ', 'green') + f'Validation val_auc [Sex]: {val_sex_auc:.5f}')

        print(colored('====> ', 'green') + f'Validation mae [BMI]: {val_mae_bmi:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [BMI]: {val_mse_bmi:.5f}')

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', {'epoch': kvs['cur_epoch'],
                                                    'train_loss': train_loss,
                                                    'val_loss': val_loss})

    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
Exemplo n.º 19
0
def init_session(snapshot_name_prefix=None):
    """Initialize a GPU training session.

    Performs these steps:
    - parses the command-line arguments and stores them as ``'args'``;
    - seeds the torch (CPU and CUDA) and NumPy RNGs;
    - creates the snapshot directory via ``init_snapshot_dir``;
    - records PyTorch and CUDA information;
    - pickles the session.

    Args:
        snapshot_name_prefix: Optional prefix for the timestamped snapshot name.

    Returns:
        tuple: ``(args, snapshot_name)``.

    Raises:
        EnvironmentError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()
    kvs.update('args', args)

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Creating the snapshot (also sets kvs['snapshot_name'])
    init_snapshot_dir(snapshot_name_prefix)

    kvs.update('pytorch_version', torch.__version__)

    # CUDA availability is guaranteed by the check at the top of this function,
    # so no "no CUDA" fallback is needed here.
    kvs.update('cuda', torch.version.cuda)
    kvs.update('gpus', torch.cuda.device_count())

    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))

    return args, kvs['snapshot_name']
Exemplo n.º 20
0
def epoch_pass(net, optimizer, loader):
    """Run one training or validation epoch for age / sex / BMI prediction.

    Two modes are selected by ``args.predict_age_sex_bmi``:
    - Unset: the single target ``args.target_var`` is trained with
      BCE-with-logits for ``'SEX'`` and MSE otherwise.
    - Set: output columns 0 / 1 / 2 are trained as AGE (MSE) / SEX
      (BCE-with-logits) / BMI (MSE), and the three losses are summed.

    The epoch is a training pass when ``optimizer`` is given. Otherwise it is
    a validation pass that collects predictions (sigmoid applied to SEX
    logits), targets and ``ID_SIDE`` identifiers.

    Returns:
        Training: mean loss over batches.
        Validation, single target: ``(mean_loss, ids, gt, preds)``.
        Validation, all three: ``(mean_loss, ids, gt_age, preds_age, gt_sex,
        preds_sex, gt_bmi, preds_bmi)``.
    """
    kvs = GlobalKVS()
    running_loss, pbar, n_batches, epoch, max_epoch, device = init_epoch_pass(net, optimizer, loader)
    if kvs['args'].target_var == 'SEX':
        criterion = F.binary_cross_entropy_with_logits
    else:
        criterion = F.mse_loss
    # Individual factors prediction
    preds = list()
    gt = list()
    # Predicting Age, Sex, BMI
    preds_age = list()
    preds_sex = list()
    preds_bmi = list()

    gt_age = list()
    gt_sex = list()
    gt_bmi = list()

    ids = list()

    with torch.set_grad_enabled(optimizer is not None):
        for i, batch in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()
            inp = batch['img'].to(device)
            # assumes the network returns (batch, n_outputs) — TODO confirm.
            # Squeeze only the trailing axis. A bare ``squeeze()`` also drops
            # the batch axis for a 1-sample batch (validation loaders keep
            # incomplete batches), which breaks ``output[:, k]`` indexing and
            # the BCE target-shape check.
            output = net(inp).squeeze(-1)
            if not kvs['args'].predict_age_sex_bmi:
                target = batch[kvs['args'].target_var].float().to(device)
                loss = criterion(output, target)
            else:
                target_age = batch['AGE'].float().to(device)
                target_sex = batch['SEX'].float().to(device)
                target_bmi = batch['BMI'].float().to(device)
                # Column slices are already 1-D (batch,), matching the targets.
                loss_age = F.mse_loss(output[:, 0], target_age)
                loss_sex = F.binary_cross_entropy_with_logits(output[:, 1], target_sex)
                loss_bmi = F.mse_loss(output[:, 2], target_bmi)
                loss = loss_age + loss_sex + loss_bmi

            if optimizer is not None:
                loss.backward()
                if kvs['args'].clip_grad:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), kvs['args'].clip_grad_norm)
                optimizer.step()
            else:
                if not kvs['args'].predict_age_sex_bmi:
                    if kvs['args'].target_var == 'SEX':
                        pred_batch = torch.sigmoid(output).data.to('cpu').numpy().squeeze()
                    else:
                        pred_batch = output.data.to('cpu').numpy().squeeze()
                    preds.append(pred_batch)
                    gt.append(batch[kvs['args'].target_var].numpy().squeeze())
                else:
                    preds_age_batch = output[:, 0].data.to('cpu').numpy().squeeze()
                    preds_sex_batch = torch.sigmoid(output[:, 1]).data.to('cpu').numpy().squeeze()
                    preds_bmi_batch = output[:, 2].data.to('cpu').numpy().squeeze()

                    preds_age.append(preds_age_batch)
                    preds_sex.append(preds_sex_batch)
                    preds_bmi.append(preds_bmi_batch)

                    gt_age.append(batch['AGE'].numpy().squeeze())
                    gt_sex.append(batch['SEX'].numpy().squeeze())
                    gt_bmi.append(batch['BMI'].numpy().squeeze())

                ids.extend(batch['ID_SIDE'])

            running_loss += loss.item()
            if optimizer is not None:
                pbar.set_description(f'Training   [{epoch} / {max_epoch}]:: {running_loss / (i + 1):.3f}')
            else:
                pbar.set_description(f'Validating [{epoch} / {max_epoch}]:')
            pbar.update()

            gc.collect()

    gc.collect()
    pbar.close()

    if optimizer is not None:
        return running_loss / n_batches
    else:
        if not kvs['args'].predict_age_sex_bmi:
            preds = np.hstack(preds)
            gt = np.hstack(gt)
            return running_loss / n_batches, ids, gt, preds
        else:
            preds_age = np.hstack(preds_age)
            gt_age = np.hstack(gt_age)

            preds_sex = np.hstack(preds_sex)
            gt_sex = np.hstack(gt_sex)

            preds_bmi = np.hstack(preds_bmi)
            gt_bmi = np.hstack(gt_bmi)

            return running_loss / n_batches, ids, gt_age, preds_age, gt_sex, preds_sex, gt_bmi, preds_bmi
Exemplo n.º 21
0
import cv2

from oaprogression.kvs import GlobalKVS
from oaprogression.training import dataset
from oaprogression.training import session
from oaprogression.training import train_utils

# Stop OpenCV from using OpenCL and from running its own worker threads.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


def main():
    """Run the OA progression training pipeline end to end.

    Steps, in order:
    - session setup;
    - progression metadata and CV splits;
    - data transforms;
    - per-fold TensorBoard writers;
    - training of every selected fold.
    """
    # Acquire the global key-value store before the pipeline starts.
    kvs = GlobalKVS()
    session.init_session()
    dataset.init_progression_metadata()
    session.init_data_processing()
    fold_writers = session.init_folds()
    train_utils.train_folds(fold_writers)


if __name__ == "__main__":
    main()
Exemplo n.º 22
0
import cv2
from termcolor import colored
from torch.optim.lr_scheduler import MultiStepLR

from oaprogression.kvs import GlobalKVS
from oaprogression.training import dataset
from oaprogression.training import session
from oaprogression.training import train_utils

import sys  # needed for sys.gettrace() below; it is missing from the imports above

# Stop OpenCV from using OpenCL and from running its own worker threads.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

# True when the script runs under a debugger/tracer; the main block uses it to
# switch on augmentation visualization.
DEBUG = sys.gettrace() is not None

if __name__ == "__main__":
    kvs = GlobalKVS()
    session.init_session()
    dataset.init_progression_metadata()
    session.init_data_processing()
    writers = session.init_folds()

    if DEBUG:
        dataset.debug_augmentations()

    for fold_id in kvs['cv_split_train']:
        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(kvs['metadata'].iloc[train_index],