コード例 #1
0
def log_metrics(boardlogger, train_loss, val_loss, val_pred, val_gt):
    """Report one epoch's losses and validation metrics, then persist the session.

    Prints the train/validation losses, computes validation metrics with
    ``compute_metrics``, sends the losses and every ``kappa*`` metric to
    ``boardlogger``, appends both to the global key-value store under
    fold-specific keys, and pickles the session to
    ``<args.snapshots>/<snapshot_name>/session.pkl``.

    Args:
        boardlogger: writer exposing ``add_scalars(tag, values_dict, step)``
            (e.g. a TensorBoard ``SummaryWriter``).
        train_loss: training loss of the current epoch (formatted with ``.4f``).
        val_loss: validation loss of the current epoch (formatted with ``.4f``).
        val_pred: validation predictions, forwarded to ``compute_metrics``.
        val_gt: validation ground truth, forwarded to ``compute_metrics``.
    """
    kvs = GlobalKVS()
    epoch = kvs['cur_epoch']
    results = {'epoch': epoch, 'val_loss': val_loss}
    print(colored('==> ', 'green') + f'Train loss: {train_loss:.4f} / Val loss: {val_loss:.4f}')

    results.update(compute_metrics(val_pred, val_gt, no_kl=kvs['args'].no_kl))

    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, epoch)
    # Only the kappa-family metrics are plotted; the full dict is stored below.
    kappas = {name: value for name, value in results.items() if name.startswith('kappa')}
    boardlogger.add_scalars('Metrics', kappas, epoch)

    fold_suffix = f'fold_[{kvs["cur_fold"]}]'
    loss_record = {
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    kvs.update(f'losses_{fold_suffix}', loss_record)
    kvs.update(f'val_metrics_{fold_suffix}', results)

    session_path = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl')
    kvs.save_pkl(session_path)
コード例 #2
0
ファイル: session.py プロジェクト: nlebang/KneeOARSIGrading
def save_checkpoint(model, optimizer):
    """Save a model/optimizer snapshot when the tracked validation metric improves.

    The metric named by ``args.snapshot_on`` is read from the first element
    of the latest entry in the current fold's ``val_metrics_fold_[<fold>]``
    history. It is compared with ``best_val_metric`` using the ``operator``
    function named by ``args.snapshot_comparator`` (e.g. ``'gt'``/``'lt'``).

    - If ``prev_model`` is ``None`` (nothing saved yet for this fold), the
      snapshot is always written.
    - Otherwise it is written only when the comparator returns true. In that
      case the previous snapshot file is deleted first, unless
      ``args.keep_snapshots`` is set.

    ``prev_model`` and ``best_val_metric`` are updated whenever a snapshot is
    written. The session is pickled on every call.

    Args:
        model: network whose ``net_core(model).state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved alongside.
    """
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    args = kvs['args']
    latest_metrics = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0]
    val_metric = latest_metrics[args.snapshot_on]
    # Resolved eagerly so an invalid comparator name fails even on the first save.
    is_better = getattr(operator, args.snapshot_comparator)
    snapshot_dir = os.path.join(args.snapshots, kvs['snapshot_name'])
    cur_snapshot_name = os.path.join(snapshot_dir, f'fold_{fold_id}_epoch_{epoch+1}.pth')

    previous = kvs['prev_model']
    first_snapshot = previous is None
    # Short-circuit: the comparator is consulted only once a best value exists.
    if first_snapshot or is_better(val_metric, kvs['best_val_metric']):
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        if not first_snapshot and not args.keep_snapshots:
            os.remove(previous)

        checkpoint = {'epoch': epoch,
                      'net': net_core(model).state_dict(),
                      'optim': optimizer.state_dict()}
        torch.save(checkpoint, cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(snapshot_dir, 'session.pkl'))
コード例 #3
0
ファイル: session.py プロジェクト: nlebang/KneeOARSIGrading
def init_metadata():
    """Load (or build and cache) the OAI/MOST metadata and create CV splits.

    Metadata is cached as ``oai_meta.pkl`` and ``most_meta.pkl`` inside
    ``args.snapshots``. The cache is used only when BOTH files exist. If
    either is missing (e.g. a previous run died between the two writes), both
    tables are rebuilt with ``build_dataset_meta`` and re-cached, instead of
    failing with ``FileNotFoundError`` on the missing file.

    Rows are kept only when ``XRKL`` (presumably the Kellgren-Lawrence grade)
    lies in [0, 4]. The filtered tables are stored in the global KVS as
    ``'oai_meta'`` and ``'most_meta'``.

    ``GroupKFold`` splits of the ``<args.train_set>_meta`` table are grouped
    by its ``ID`` column, so rows sharing an ID never land in both the train
    and validation part of a fold. They are stored under
    ``'cv_split_all_folds'``. The session is then pickled.
    """
    kvs = GlobalKVS()
    args = kvs['args']
    oai_cache = os.path.join(args.snapshots, 'oai_meta.pkl')
    most_cache = os.path.join(args.snapshots, 'most_meta.pkl')

    if os.path.isfile(oai_cache) and os.path.isfile(most_cache):
        print('==> Loading cached metadata...')
        oai_meta = pd.read_pickle(oai_cache)
        most_meta = pd.read_pickle(most_cache)
    else:
        # Regenerate both tables so the two caches never get out of sync.
        print('==> Cached metadata is not found. Generating...')
        oai_meta, most_meta = build_dataset_meta(args)
        oai_meta.to_pickle(oai_cache, compression='infer', protocol=4)
        most_meta.to_pickle(most_cache, compression='infer', protocol=4)

    most_meta = most_meta[(most_meta.XRKL >= 0) & (most_meta.XRKL <= 4)]
    oai_meta = oai_meta[(oai_meta.XRKL >= 0) & (oai_meta.XRKL <= 4)]

    print(colored('==> ', 'green') + 'Images in OAI:', oai_meta.shape[0])
    print(colored('==> ', 'green') + 'Images in MOST:', most_meta.shape[0])

    kvs.update('most_meta', most_meta)
    kvs.update('oai_meta', oai_meta)

    # Split the (filtered) training table selected by args.train_set.
    train_meta = kvs[args.train_set + '_meta']
    gkf = GroupKFold(args.n_folds)
    cv_split = list(gkf.split(train_meta, groups=train_meta['ID'].values))

    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
コード例 #4
0
ファイル: session.py プロジェクト: nlebang/KneeOARSIGrading
def init_data_processing():
    """Estimate dataset mean/std and build the final train/validation transforms.

    First, transforms are built with ``init_transforms(None, None)``
    (presumably without normalization — confirm in ``init_transforms``). A
    dataset over the ``<args.train_set>_meta`` table is wrapped with them and
    passed to ``init_mean_std`` to obtain the mean and std vectors. Those
    vectors are printed and stored in the global KVS. They are then used to
    build the final train/validation transforms, which are stored as
    ``'train_trf'`` / ``'val_trf'``. The session is then pickled.
    """
    kvs = GlobalKVS()

    # Only the training transform is needed for estimating the statistics.
    stats_trf, _ = init_transforms(None, None)
    args = kvs['args']
    stats_dataset = OARSIGradingDataset(kvs[f'{args.train_set}_meta'], stats_trf)

    mean_vector, std_vector = init_mean_std(snapshots_dir=args.snapshots,
                                            dataset=stats_dataset,
                                            batch_size=args.bs,
                                            n_threads=args.n_threads)

    for label, vector in (('Mean:', mean_vector), ('Std:', std_vector)):
        print(colored('====> ', 'red') + label, vector)

    kvs.update('mean_vector', mean_vector)
    kvs.update('std_vector', std_vector)

    train_trf, val_trf = init_transforms(mean_vector, std_vector)
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)

    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
コード例 #5
0
ファイル: session.py プロジェクト: nlebang/KneeOARSIGrading
def init_folds():
    """Set up per-fold bookkeeping and a summary writer for every selected fold.

    Iterates over ``kvs['cv_split_all_folds']``. When ``args.fold`` is ``-1``
    every fold is selected; otherwise only the fold whose index equals
    ``args.fold`` is. For each selected fold it:

    - registers list-typed KVS entries ``losses_fold_[<id>]`` and
      ``val_metrics_fold_[<id>]``;
    - remembers its split;
    - opens a ``SummaryWriter`` logging to
      ``<snapshots>/<snapshot_name>/logs/fold_<id>/<snapshot_name>``.

    The selected splits are stored as ``'cv_split_train'`` and the session
    is pickled.

    Returns:
        dict: fold index -> ``SummaryWriter`` for that fold.
    """
    kvs = GlobalKVS()
    args = kvs['args']
    snapshot_name = kvs['snapshot_name']
    run_dir = os.path.join(args.snapshots, snapshot_name)

    writers = {}
    cv_split_train = {}
    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        # -1 selects all folds; any other value selects that single fold.
        if args.fold not in (-1, fold_id):
            continue

        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split

        log_dir = os.path.join(run_dir, 'logs', f'fold_{fold_id}', snapshot_name)
        writers[fold_id] = SummaryWriter(log_dir)

    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(run_dir, 'session.pkl'))
    return writers
コード例 #6
0
ファイル: session.py プロジェクト: nlebang/KneeOARSIGrading
def init_session():
    """Create a new training session and record its environment in the global KVS.

    Parses the command-line arguments and seeds the torch CPU RNG, the
    current CUDA device's RNG and NumPy's global RNG with ``args.seed``. It
    then creates a snapshot directory ``<args.snapshots>/<YYYY_MM_DD_HH_MM>``
    from local time and records the following in the global KVS:

    - git branch and commit (``None`` when ``git_info()`` returns ``None``);
    - PyTorch version, CUDA version and GPU count;
    - the snapshot name and the parsed arguments.

    Finally it pickles the session into the snapshot directory.

    Note: the name has minute resolution and the directory is created with
    ``exist_ok=True``, so two sessions started in the same minute share (and
    overwrite) one directory.

    Returns:
        tuple: ``(args, snapshot_name)``.

    Raises:
        EnvironmentError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    args = parse_args()
    # Seed torch (CPU + current CUDA device) and NumPy for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    snapshot_dir = os.path.join(args.snapshots, snapshot_name)
    os.makedirs(snapshot_dir, exist_ok=True)

    git = git_info()
    branch, commit = (None, None) if git is None else (git[0], git[1])
    kvs.update('git branch name', branch)
    kvs.update('git commit id', commit)

    kvs.update('pytorch_version', torch.__version__)
    # CUDA availability was enforced at the top of the function, so the former
    # "no CUDA" fallback that stored None here was unreachable and is removed.
    kvs.update('cuda', torch.version.cuda)
    kvs.update('gpus', torch.cuda.device_count())

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(snapshot_dir, 'session.pkl'))

    return args, snapshot_name
コード例 #7
0
ファイル: train.py プロジェクト: nlebang/KneeOARSIGrading
# NOTE(review): per the OpenCV docs, setNumThreads(0) disables OpenCV's own
# multithreading. This is presumably done to avoid CPU oversubscription when
# data loading already runs in parallel workers — confirm.
cv2.setNumThreads(0)

# True when a Python trace function is installed (sys.gettrace() returns it,
# or None), e.g. when the script is run under a debugger. Not referenced in
# the portion of this script visible here.
DEBUG = sys.gettrace() is not None

if __name__ == "__main__":
    kvs = GlobalKVS()
    session.init_session()
    session.init_metadata()
    writers = session.init_folds()
    session.init_data_processing()

    for fold_id in kvs['cv_split_train']:
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue

        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(kvs[f'{kvs["args"].train_set}_meta'].iloc[train_index],
                                                        kvs[f'{kvs["args"].train_set}_meta'].iloc[val_index])

        net, criterion = utils.init_model()

        if kvs['args'].pretrained:
            net.train(False)
            utils.net_core(net).classifier.train(True)
            optimizer = utils.init_optimizer(utils.layer_params(net, 'classifier'))
        else:
            print(colored('====> ', 'red') + 'The model will be trained from scratch!')