Example #1
def init_metadata():
    kvs = GlobalKVS()
    if not os.path.isfile(os.path.join(kvs['args'].snapshots, 'oai_meta.pkl')):
        print('==> Cached metadata is not found. Generating...')
        oai_meta, most_meta = build_dataset_meta(kvs['args'])
        oai_meta.to_pickle(os.path.join(kvs['args'].snapshots, 'oai_meta.pkl'), compression='infer', protocol=4)
        most_meta.to_pickle(os.path.join(kvs['args'].snapshots, 'most_meta.pkl'), compression='infer', protocol=4)
    else:
        print('==> Loading cached metadata...')
        oai_meta = pd.read_pickle(os.path.join(kvs['args'].snapshots, 'oai_meta.pkl'))
        most_meta = pd.read_pickle(os.path.join(kvs['args'].snapshots, 'most_meta.pkl'))

    most_meta = most_meta[(most_meta.XRKL >= 0) & (most_meta.XRKL <= 4)]
    oai_meta = oai_meta[(oai_meta.XRKL >= 0) & (oai_meta.XRKL <= 4)]

    print(colored('==> ', 'green') + 'Images in OAI:', oai_meta.shape[0])
    print(colored('==> ', 'green') + 'Images in MOST:', most_meta.shape[0])

    kvs.update('most_meta', most_meta)
    kvs.update('oai_meta', oai_meta)

    gkf = GroupKFold(kvs['args'].n_folds)
    cv_split = [x for x in gkf.split(kvs[kvs["args"].train_set + '_meta'],
                                     groups=kvs[kvs["args"].train_set + '_meta']['ID'].values)]

    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
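All of these examples pull their state from GlobalKVS, whose definition is not shown on this page. A minimal sketch of a singleton key-value store with the same surface (item access, update, save_pkl) might look as follows; the names and behavior here are assumptions, not the actual oarsigrading implementation:

import pickle

class GlobalKVS:
    # Hypothetical sketch: a process-wide shared dict, so every call site
    # that constructs GlobalKVS() sees the same state ('args', 'cur_epoch',
    # the CV splits, and so on).
    _state = {}

    def __getitem__(self, key):
        return self._state[key]

    def update(self, key, value, container=None):
        # The real class appears to also accept a container type
        # (see init_folds below); here we simply store the value.
        self._state[key] = value

    def save_pkl(self, path):
        # Persist a picklable snapshot of the session state.
        with open(path, 'wb') as f:
            pickle.dump(self._state, f, protocol=4)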
Example #2
def log_metrics(boardlogger, train_loss, val_loss, val_pred, val_gt):
    kvs = GlobalKVS()
    res = {'epoch': kvs['cur_epoch'], 'val_loss': val_loss}
    print(
        colored('==> ', 'green') +
        f'Train loss: {train_loss:.4f} / Val loss: {val_loss:.4f}')

    res.update(compute_metrics(val_pred, val_gt, no_kl=kvs['args'].no_kl))

    boardlogger.add_scalars('Losses', {
        'train': train_loss,
        'val': val_loss
    }, kvs['cur_epoch'])
    boardlogger.add_scalars(
        'Metrics',
        {metric: res[metric]
         for metric in res if metric.startswith('kappa')}, kvs['cur_epoch'])

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', {
        'epoch': kvs['cur_epoch'],
        'train_loss': train_loss,
        'val_loss': val_loss
    })

    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)

    kvs.save_pkl(
        os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                     'session.pkl'))
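log_metrics only assumes that compute_metrics returns a dict whose Cohen's-kappa entries are prefixed with 'kappa', so they pass the metric.startswith('kappa') filter above. A minimal sketch of a compatible function (the key names and per-task layout are assumptions):

import numpy as np
from sklearn.metrics import cohen_kappa_score

def compute_metrics(val_pred, val_gt, no_kl=False):
    # Hypothetical sketch: one quadratic-weighted Cohen's kappa per grading
    # task. no_kl is accepted for signature compatibility; the real function
    # presumably drops the KL-grade task when it is set.
    res = {}
    for task_id in range(val_gt.shape[1]):
        res[f'kappa_task_{task_id}'] = cohen_kappa_score(
            val_gt[:, task_id], val_pred[:, task_id], weights='quadratic')
    return res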
Example #3
def init_loaders(x_train, x_val):
    kvs = GlobalKVS()
    train_dataset, val_dataset = init_datasets(x_train, x_val)

    if kvs['args'].weighted_sampling:
        if not kvs['args'].mtw:
            print(colored('====> ', 'red') + 'Using weighted sampling (KL)')
            _, weights = make_weights_for_multiclass(x_train.XRKL.values.astype(int))
        else:
            print(colored('====> ', 'red') + 'Using weighted sampling (MTW)')
            cols = ['XROSTL', 'XROSFL', 'XRJSL', 'XROSTM', 'XROSFM', 'XRJSM']
            weights = torch.stack([make_weights_for_multiclass(x_train[col].values.astype(int))[1]
                                   for col in cols], 1).max(1)[0]

        sampler = WeightedRandomSampler(weights, x_train.shape[0], True)

        train_loader = DataLoader(train_dataset, batch_size=kvs['args'].bs,
                                  num_workers=kvs['args'].n_threads,
                                  drop_last=True, sampler=sampler)

    else:
        train_loader = DataLoader(train_dataset, batch_size=kvs['args'].bs,
                                  num_workers=kvs['args'].n_threads,
                                  drop_last=True, shuffle=True)

    val_loader = DataLoader(val_dataset, batch_size=kvs['args'].val_bs,
                            num_workers=kvs['args'].n_threads)

    return train_loader, val_loader
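Both sampling branches rely on a make_weights_for_multiclass helper that is not shown here. Judging by the `_, weights = ...` unpacking and the `[1]` indexing above, it returns per-class and per-sample weights; a plausible inverse-frequency sketch (an assumption, not the repository's code):

import numpy as np
import torch

def make_weights_for_multiclass(labels):
    # Hypothetical sketch: inverse-frequency weights so that rare grades
    # are drawn by WeightedRandomSampler as often as common ones.
    # Returns (per-class weights, per-sample weights).
    counts = np.bincount(labels)
    class_weights = counts.sum() / (counts.astype(np.float64) + 1e-8)
    sample_weights = class_weights[labels]
    return (torch.from_numpy(class_weights).float(),
            torch.from_numpy(sample_weights).float())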
Example #4
def init_mean_std(snapshots_dir, dataset, batch_size, n_threads):
    kvs = GlobalKVS()
    if os.path.isfile(os.path.join(snapshots_dir, f'mean_std_{kvs["args"].train_set}.npy')):
        tmp = np.load(os.path.join(snapshots_dir, f'mean_std_{kvs["args"].train_set}.npy'))
        mean_vector, std_vector = tmp
    else:
        tmp_loader = DataLoader(dataset, batch_size=batch_size, num_workers=n_threads)
        mean_vector = None
        std_vector = None
        print(colored('==> ', 'green') + 'Calculating mean and std')
        for batch in tqdm(tmp_loader, total=len(tmp_loader)):
            if kvs['args'].siamese:
                imgs = torch.cat((batch['img_med'], batch['img_lat']))
            else:
                imgs = batch['img']

            if mean_vector is None:
                mean_vector = np.zeros(imgs.size(1))
                std_vector = np.zeros(imgs.size(1))
            for j in range(mean_vector.shape[0]):
                mean_vector[j] += imgs[:, j, :, :].mean()
                std_vector[j] += imgs[:, j, :, :].std()

        mean_vector /= len(tmp_loader)
        std_vector /= len(tmp_loader)
        np.save(os.path.join(snapshots_dir, f'mean_std_{kvs["args"].train_set}.npy'),
                [mean_vector.astype(np.float32), std_vector.astype(np.float32)])

    return mean_vector, std_vector
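Note that this accumulates per-batch channel statistics and divides by the number of batches, which is exact only when all batches have the same size (the last, possibly smaller batch is weighted equally). A hypothetical call site, assuming a dataset that yields dicts with an 'img' tensor:

mean_vector, std_vector = init_mean_std(snapshots_dir='/tmp/snapshots',  # hypothetical path
                                        dataset=train_dataset,
                                        batch_size=32,
                                        n_threads=4)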
Example #5
def init_folds():
    kvs = GlobalKVS()
    writers = {}
    cv_split_train = {}
    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue
        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split
        writers[fold_id] = SummaryWriter(os.path.join(kvs['args'].snapshots,
                                                      kvs['snapshot_name'],
                                                      'logs',
                                                      'fold_{}'.format(fold_id), kvs['snapshot_name']))

    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
    return writers
Example #6
def init_scheduler(optimizer: Optimizer, epoch_start: int) -> MultiStepLR:
    kvs = GlobalKVS()
    scheduler = MultiStepLR(optimizer,
                            milestones=list(
                                map(lambda x: x - epoch_start,
                                    kvs['args'].lr_drop)),
                            gamma=kvs['args'].lr_drop_gamma)

    return scheduler
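The milestones are shifted by epoch_start so that, when training resumes from a checkpoint, the learning-rate drops still happen at the same absolute epochs. A worked example with hypothetical values:

import torch
from torch.optim.lr_scheduler import MultiStepLR

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
# Resuming at epoch_start = 10 with lr_drop = [30, 45] gives milestones
# [30 - 10, 45 - 10] = [20, 35], i.e. the LR is still multiplied by
# gamma at absolute epochs 30 and 45.
scheduler = MultiStepLR(opt, milestones=[20, 35], gamma=0.1)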
Example #7
def init_optimizer(params) -> Optimizer:
    kvs = GlobalKVS()
    if kvs['args'].optimizer == 'adam':
        return optim.Adam(params,
                          lr=kvs['args'].lr,
                          weight_decay=kvs['args'].wd)
    elif kvs['args'].optimizer == 'sgd':
        return optim.SGD(params,
                         lr=kvs['args'].lr,
                         weight_decay=kvs['args'].wd,
                         momentum=kvs['args'].momentum,
                         nesterov=kvs['args'].nesterov)
    else:
        raise NotImplementedError
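A hypothetical call site; the argparse namespace stored under kvs['args'] is expected to define optimizer, lr, wd and, for the SGD branch, momentum and nesterov:

optimizer = init_optimizer(net.parameters())  # net is the model returned by init_model()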
Example #8
def init_model() -> Tuple[nn.Module, nn.Module]:
    kvs = GlobalKVS()

    if kvs['args'].siamese:
        net = OARSIGradingNetSiamese(backbone=kvs['args'].siamese_bb,
                                     dropout=kvs['args'].dropout_rate)
    else:
        net = OARSIGradingNet(bb_depth=kvs['args'].backbone_depth,
                              dropout=kvs['args'].dropout_rate,
                              cls_bnorm=kvs['args'].use_bnorm,
                              se=kvs['args'].se,
                              dw=kvs['args'].dw,
                              use_gwap=kvs['args'].use_gwap,
                              use_gwap_hidden=kvs['args'].use_gwap_hidden,
                              pretrained=kvs['args'].pretrained,
                              no_kl=kvs['args'].no_kl)

    if kvs['gpus'] > 1:
        net = nn.DataParallel(net).to('cuda')

    return net.to('cuda'), init_loss().to('cuda')
Example #9
def init_data_processing():
    kvs = GlobalKVS()

    train_trf, val_trf = init_transforms(None, None)

    dataset = OARSIGradingDataset(kvs[f'{kvs["args"].train_set}_meta'], train_trf)

    mean_vector, std_vector = init_mean_std(snapshots_dir=kvs['args'].snapshots,
                                            dataset=dataset, batch_size=kvs['args'].bs,
                                            n_threads=kvs['args'].n_threads)

    print(colored('====> ', 'red') + 'Mean:', mean_vector)
    print(colored('====> ', 'red') + 'Std:', std_vector)

    kvs.update('mean_vector', mean_vector)
    kvs.update('std_vector', std_vector)

    train_trf, val_trf = init_transforms(mean_vector, std_vector)

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
Example #10
def init_session():
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()
    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Creating the snapshot
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)

    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))

    return args, snapshot_name
Example #11
def save_checkpoint(model, optimizer):
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][kvs['args'].snapshot_on]
    comparator = getattr(operator, kvs['args'].snapshot_comparator)
    cur_snapshot_name = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch+1}.pth')

    if kvs['prev_model'] is None:
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        torch.save({'epoch': epoch, 'net': net_core(model).state_dict(),
                    'optim': optimizer.state_dict()}, cur_snapshot_name)

        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)

    else:
        if comparator(val_metric, kvs['best_val_metric']):
            print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
            if not kvs['args'].keep_snapshots:
                os.remove(kvs['prev_model'])

            torch.save({'epoch': epoch, 'net': net_core(model).state_dict(),
                        'optim': optimizer.state_dict()}, cur_snapshot_name)
            kvs.update('prev_model', cur_snapshot_name)
            kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
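save_checkpoint relies on a net_core helper that is not shown on this page. A plausible sketch (an assumption, not the repository's code) unwraps nn.DataParallel so that the saved state_dict keys carry no 'module.' prefix:

import torch.nn as nn

def net_core(model):
    # Return the underlying module when the model was wrapped in
    # nn.DataParallel (see init_model above), otherwise the model itself.
    return model.module if isinstance(model, nn.DataParallel) else model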
Example #12
def init_datasets(x_train, x_val):
    kvs = GlobalKVS()
    train_dataset = OARSIGradingDataset(x_train, kvs['train_trf'])
    val_dataset = OARSIGradingDataset(x_val, kvs['val_trf'])

    return train_dataset, val_dataset
Example #13
import os
import sys
import argparse
import pickle

import cv2

from oarsigrading.kvs import GlobalKVS
from oarsigrading.training.dataset import OARSIGradingDataset
from oarsigrading.evaluation import metrics
from oarsigrading.training.model_zoo import backbone_name
from oarsigrading.training.model import OARSIGradingNet, OARSIGradingNetSiamese
import oarsigrading.evaluation.tta as tta
from oarsigrading.training.transforms import apply_by_index
import pandas as pd

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

DEBUG = sys.gettrace() is not None

if __name__ == "__main__":
    kvs = GlobalKVS()
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', default='')
    parser.add_argument('--meta_root', default='')
    # argparse's type=bool treats any non-empty string as True; a flag is safer
    parser.add_argument('--tta', action='store_true')
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--n_threads', type=int, default=12)
    parser.add_argument('--snapshots_root', default='')
    parser.add_argument('--snapshot', default='')
    parser.add_argument('--save_dir', default='')
    args = parser.parse_args()

    with open(os.path.join(args.snapshots_root, args.snapshot, 'session.pkl'),
              'rb') as f:
        session_backup = pickle.load(f)
Example #14
def init_transforms(mean_vector, std_vector):
    kvs = GlobalKVS()

    if mean_vector is not None:
        mean_vector = torch.from_numpy(mean_vector).float()
        std_vector = torch.from_numpy(std_vector).float()
        norm_trf = partial(normalize_channel_wise,
                           mean=mean_vector,
                           std=std_vector)
        norm_trf = partial(apply_by_index, transform=norm_trf, idx=[0, 1, 2])
    else:
        norm_trf = None

    if kvs['args'].siamese:
        resize_train = slc.Stream()
        crop_train = slt.CropTransform(crop_size=(kvs['args'].imsize,
                                                  kvs['args'].imsize),
                                       crop_mode='c')
    else:
        resize_train = slt.ResizeTransform(
            (kvs['args'].inp_size, kvs['args'].inp_size))
        crop_train = slt.CropTransform(crop_size=(kvs['args'].crop_size,
                                                  kvs['args'].crop_size),
                                       crop_mode='r')

    train_trf = [
        wrap2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(kvs['args'].imsize, kvs['args'].imsize)),
            slt.CropTransform(crop_size=(kvs['args'].imsize,
                                         kvs['args'].imsize),
                              crop_mode='c'),
            resize_train,
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            crop_train,
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]),
        unpack_solt_data,
        partial(pack_tensors, no_kl=kvs['args'].no_kl),
    ]

    if not kvs['args'].siamese:
        resize_val = slc.Stream([
            slt.ResizeTransform((kvs['args'].inp_size, kvs['args'].inp_size)),
            slt.CropTransform(crop_size=(kvs['args'].crop_size,
                                         kvs['args'].crop_size),
                              crop_mode='c'),
        ])
    else:
        resize_val = slc.Stream()

    val_trf = [
        wrap2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(kvs['args'].imsize, kvs['args'].imsize)),
            slt.CropTransform(crop_size=(kvs['args'].imsize,
                                         kvs['args'].imsize),
                              crop_mode='c'),
            resize_val,
        ]),
        unpack_solt_data,
        partial(pack_tensors, no_kl=kvs['args'].no_kl),
    ]

    if norm_trf is not None:
        train_trf.append(norm_trf)
        val_trf.append(norm_trf)

    train_trf = transforms.Compose(train_trf)
    val_trf = transforms.Compose(val_trf)

    return train_trf, val_trf
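The normalization transform is wrapped with apply_by_index so that it touches only the image entries of the sample tuple and leaves the targets alone. A minimal sketch of such a helper; its exact signature here is an assumption inferred from the partial(...) call above:

def apply_by_index(items, transform, idx):
    # Hypothetical sketch: apply `transform` to the elements of `items`
    # whose positions are listed in `idx`; pass the rest through unchanged.
    if not isinstance(idx, (list, tuple)):
        idx = [idx]
    return tuple(transform(item) if i in idx else item
                 for i, item in enumerate(items))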
Example #15
import cv2
import sys
from termcolor import colored

from oarsigrading.kvs import GlobalKVS
from oarsigrading.training import session
from oarsigrading.training import utils
from oarsigrading.evaluation import metrics

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

DEBUG = sys.gettrace() is not None

if __name__ == "__main__":
    kvs = GlobalKVS()
    session.init_session()
    session.init_metadata()
    writers = session.init_folds()
    session.init_data_processing()

    for fold_id in kvs['cv_split_train']:
        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue

        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(kvs[f'{kvs["args"].train_set}_meta'].iloc[train_index],
                                                        kvs[f'{kvs["args"].train_set}_meta'].iloc[val_index])
Example #16
def epoch_pass(
    net: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: Optional[Optimizer],
    writer: Optional[SummaryWriter] = None
) -> Union[float, Tuple[float, List[str], np.ndarray, np.ndarray]]:

    kvs = GlobalKVS()
    if optimizer is not None:
        net.train(True)
    else:
        net.train(False)

    running_loss = 0.0
    n_batches = len(loader)
    pbar = tqdm(total=len(loader))
    epoch = kvs['cur_epoch']
    max_epoch = kvs['args'].n_epochs
    fold_id = kvs['cur_fold']

    device = next(net.parameters()).device
    predicts = []
    fnames = []
    gt = []
    with torch.set_grad_enabled(optimizer is not None):
        for i, batch in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()

            # forward + backward + optimize
            labels = batch['target'].squeeze().to(device)
            if kvs['args'].siamese:
                inp_med = batch['img_med'].squeeze().to(device)
                inp_lat = batch['img_lat'].squeeze().to(device)
                outputs = net(inp_med, inp_lat)
            else:
                inputs = batch['img'].squeeze().to(device)
                outputs = net(inputs)

            loss = criterion(outputs, labels)

            if optimizer is not None:
                loss.backward()
                optimizer.step()
                pbar.set_description(
                    f'[{fold_id}] Train:: [{epoch} / {max_epoch}]:: '
                    f'{running_loss / (i + 1):.3f} | {loss.item():.3f}')
            else:
                tmp_preds = np.zeros(batch['target'].squeeze().size(),
                                     dtype=np.int64)
                for task_id, o in enumerate(outputs):
                    tmp_preds[:, task_id] = o.to('cpu').squeeze().argmax(1)

                predicts.append(tmp_preds)
                gt.append(batch['target'].to('cpu').numpy().squeeze())
                fnames.extend(batch['ID'])

                pbar.set_description(
                    f'[{fold_id}] Validating [{epoch} / {max_epoch}]:')
            if writer is not None and optimizer is not None:
                writer.add_scalar('train_logs/loss', loss.item(),
                                  kvs['cur_epoch'] * len(loader) + i)
            running_loss += loss.item()
            pbar.update()

            gc.collect()
    gc.collect()
    pbar.close()
    if optimizer is not None:
        return running_loss / n_batches

    return running_loss / n_batches, fnames, np.vstack(
        gt).squeeze(), np.vstack(predicts)
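A hypothetical driver loop tying the pieces together: passing the optimizer switches epoch_pass into training mode, while passing None runs validation and additionally returns file names, ground truth and predictions:

for epoch in range(kvs['args'].n_epochs):
    kvs.update('cur_epoch', epoch)
    train_loss = epoch_pass(net, train_loader, criterion, optimizer, writer)
    val_loss, fnames, val_gt, val_pred = epoch_pass(net, val_loader, criterion, None)
    log_metrics(writer, train_loss, val_loss, val_pred, val_gt)
    save_checkpoint(net, optimizer)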