Example #1
def init_metadata():
    kvs = GlobalKVS()

    imgs = glob.glob(os.path.join(kvs['args'].dataset, '*', 'imgs', '*.png'))
    imgs.sort(key=lambda x: x.split('/')[-1])

    masks = glob.glob(os.path.join(kvs['args'].dataset, '*', 'masks', '*.png'))
    masks.sort(key=lambda x: x.split('/')[-1])

    sample_id = list(map(lambda x: x.split('/')[-3], imgs))
    subject_id = list(map(lambda x: x.split('/')[-3].split('_')[0], imgs))

    metadata = pd.DataFrame(
        data={
            'img_fname': imgs,
            'mask_fname': masks,
            'sample_id': sample_id,
            'subject_id': subject_id
        })

    metadata['sample_subject_proj'] = metadata.apply(
        lambda x: gen_image_id(x.img_fname, x.sample_id), 1)
    grades = pd.read_csv(kvs['args'].grades)
    metadata = pd.merge(metadata, grades, on='sample_id')
    kvs.update('metadata', metadata)
    return metadata
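The glob patterns above imply a particular on-disk layout: every sample has its own folder under `args.dataset` with `imgs` and `masks` subfolders, and the identifiers are recovered from the file path. A minimal sketch of the assumed structure (folder and file names are hypothetical):

# <args.dataset>/
#     subject1_sample1/
#         imgs/   00.png, 01.png, ...
#         masks/  00.png, 01.png, ...
#     subject1_sample2/
#         ...
#
# For 'subject1_sample1/imgs/00.png':
#     sample_id  = path.split('/')[-3]      -> 'subject1_sample1'
#     subject_id = sample_id.split('_')[0]  -> 'subject1'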
Example #2
def save_checkpoint(net, loss, optimizer, val_metric_name, comparator='lt'):
    """
    Flexible function that saves the model and the optimizer states using a metric and a comparator.

    Parameters
    ----------
    net : torch.nn.Module
        Model
    loss : torch.nn.Module
        Loss module; its state is saved alongside the model and the optimizer states.
    optimizer : torch.optim.Optimizer
        Optimizer
    val_metric_name : str
        Name of the metric used for snapshot comparison.
        This name needs to match the one created in the callback function passed to
        `log_metrics`.
    comparator : str
        How to compare the previous and the current metric values - `lt` is less than, and `gt` is greater than.

    Returns
    -------
    out : None

    """
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    comparator = getattr(operator, comparator)
    cur_snapshot_name = os.path.join(kvs['args'].workdir, 'snapshots',
                                     kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch}.pth')

    state = {
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss': loss.state_dict()
    }
    if kvs['prev_model'] is None:
        print(
            colored('====> ', 'red') + 'Snapshot was saved to',
            cur_snapshot_name)
        torch.save(state, cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)

    else:
        if comparator(val_metric, kvs['best_val_metric']):
            print(
                colored('====> ', 'red') + 'Snapshot was saved to',
                cur_snapshot_name)
            os.remove(kvs['prev_model'])
            torch.save(state, cur_snapshot_name)
            kvs.update('prev_model', cur_snapshot_name)
            kvs.update('best_val_metric', val_metric)
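A minimal usage sketch inside a training loop, assuming the metric names were logged into `val_metrics_fold_[...]` by `log_metrics` beforehand (the Dice metric name is illustrative):

# Keep the snapshot with the lowest validation loss
save_checkpoint(net, criterion, optimizer, val_metric_name='val_loss', comparator='lt')

# Keep the snapshot with the highest Dice of class 1
save_checkpoint(net, criterion, optimizer, val_metric_name='dice_1', comparator='gt')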
Example #3
def init_pd_meta():
    """
    Basic implementation of metadata loading. Loads the pandas data frame and stores
    it in global KVS under the `metadata` tag.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()
    metadata = pd.read_csv(os.path.join(kvs['args'].workdir, kvs['args'].metadata))
    kvs.update('metadata', metadata)
Example #4
def init_augs():
    kvs = GlobalKVS()
    args = kvs['args']
    cutout = slt.ImageCutOut(cutout_size=(int(args.cutout * args.crop_x),
                                          int(args.cutout * args.crop_y)),
                             p=0.5)
    # plus-minus 1.3 pixels
    jitter = slt.KeypointsJitter(dx_range=(-0.003, 0.003),
                                 dy_range=(-0.003, 0.003))
    ppl = tvt.Compose([
        jitter if args.use_target_jitter else slc.Stream(),
        slc.SelectiveStream([
            slc.Stream([
                slt.RandomFlip(p=0.5, axis=1),
                slt.RandomProjection(affine_transforms=slc.Stream([
                    slt.RandomScale(range_x=(0.8, 1.3), p=1),
                    slt.RandomRotate(rotation_range=(-90, 90), p=1),
                    slt.RandomShear(
                        range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
                ]),
                                     v_range=(1e-5, 2e-3),
                                     p=0.5),
                slt.RandomScale(range_x=(0.5, 2.5), p=0.5),
            ]),
            slc.Stream()
        ],
                            probs=[0.7, 0.3]),
        slc.Stream([
            slt.PadTransform((args.pad_x, args.pad_y), padding='z'),
            slt.CropTransform((args.crop_x, args.crop_y), crop_mode='r'),
        ]),
        slc.SelectiveStream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
            slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
            slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
            slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
            slc.Stream([
                slt.ImageSaltAndPepper(p=1, gain_range=0.05),
                slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            ]),
            slc.Stream([
                slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
                slt.ImageSaltAndPepper(p=1, gain_range=0.01),
            ]),
            slc.Stream()
        ],
                            n=1),
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if args.use_cutout else slc.Stream(),
        partial(solt2torchhm, downsample=None, sigma=None),
    ])
    kvs.update('train_trf', ppl)
Example #5
def train_fold(net, train_loader, optimizer, criterion, val_loader, scheduler):
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    writer = SummaryWriter(
        os.path.join(kvs['args'].workdir, 'snapshots', kvs['snapshot_name'],
                     'logs', 'fold_{}'.format(fold_id), kvs['snapshot_name']))

    for epoch in range(kvs['args'].n_epochs):
        print(
            colored('==> ', 'green') +
            f'Training epoch [{epoch}] with LR {scheduler.get_lr()}')
        kvs.update('cur_epoch', epoch)
        train_loss, _ = pass_epoch(net, train_loader, optimizer, criterion)
        val_loss, conf_matrix = pass_epoch(net, val_loader, None, criterion)
        log_metrics(writer, train_loss, val_loss, conf_matrix)
        save_checkpoint(net, optimizer, 'val_loss', 'lt')
        scheduler.step()
Example #6
def pass_epoch(net, loader, optimizer, criterion):
    kvs = GlobalKVS()
    net.train(optimizer is not None)

    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    max_ep = kvs['args'].n_epochs
    n_classes = kvs['args'].n_classes

    running_loss = 0.0
    n_batches = len(loader)
    confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.uint64)
    device = next(net.parameters()).device
    pbar = tqdm(total=n_batches, ncols=200)
    with torch.set_grad_enabled(optimizer is not None):
        for i, entry in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()

            inputs = entry['img'].to(device)
            mask = entry['mask'].to(device)
            outputs = net(inputs)
            loss = criterion(outputs, mask)

            if optimizer is not None:
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                pbar.set_description(
                    f"Fold [{fold_id}] [{epoch} | {max_ep}] | "
                    f"Running loss {running_loss / (i + 1):.5f} / {loss.item():.5f}"
                )
            else:
                running_loss += loss.item()
                pbar.set_description(
                    desc=
                    f"Fold [{fold_id}] [{epoch} | {max_ep}] | Validation progress"
                )

                if n_classes == 2:
                    preds = outputs.gt(kvs['args'].binary_threshold)
                elif n_classes > 2:
                    preds = outputs.argmax(axis=1)
                else:
                    raise ValueError(f'Unsupported number of classes: {n_classes}')

                preds = preds.float().to('cpu').numpy()
                mask = mask.float().to('cpu').numpy()
                confusion_matrix += metrics.calculate_confusion_matrix_from_arrays(
                    preds, mask, n_classes)

            pbar.update()
            gc.collect()
        gc.collect()
        pbar.close()

    return running_loss / n_batches, confusion_matrix
Example #7
def init_folds(img_group_id_colname=None, img_class_colname=None):
    """
    Initializes the cross-validation splits.

    Parameters
    ----------
    img_group_id_colname : str or None
        Column in `metadata` that is used to create the cross-validation splits.
        If not None, images sharing the same group id are never split between train and validation.
    img_class_colname : str or None
        Column in `metadata` that is used to create the cross-validation splits. If not None,
        the splits are stratified to ensure the same distribution of `img_class_colname` in train and validation.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()

    if img_group_id_colname is not None:
        gkf = GroupKFold(kvs['args'].n_folds)
        if img_class_colname is not None:
            class_col_name = getattr(kvs['metadata'], img_class_colname, None)
        else:
            class_col_name = None
        splitter = gkf.split(X=kvs['metadata'],
                             y=class_col_name,
                             groups=getattr(kvs['metadata'], img_group_id_colname))
    else:
        if img_class_colname is not None:
            skf = StratifiedKFold(kvs['args'].n_folds)
            splitter = skf.split(X=kvs['metadata'],
                                 y=getattr(kvs['metadata'], img_class_colname, None))
        else:
            kf = KFold(kvs['args'].n_folds)
            splitter = kf.split(X=kvs['metadata'])

    cv_split = []
    for fold_id, (train_ind, val_ind) in enumerate(splitter):

        if kvs['args'].fold != -1 and fold_id != kvs['args'].fold:
            continue

        np.random.shuffle(train_ind)
        train_ind = train_ind[::kvs['args'].skip_train]

        cv_split.append((fold_id,
                         kvs['metadata'].iloc[train_ind],
                         kvs['metadata'].iloc[val_ind]))

        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)

    kvs.update('cv_split', cv_split)
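A sketch of how the stored `cv_split` is typically consumed afterwards; the column names are illustrative and must exist in `metadata`:

init_folds(img_group_id_colname='subject_id', img_class_colname='grade')

kvs = GlobalKVS()
for fold_id, x_train, x_val in kvs['cv_split']:
    kvs.update('cur_fold', fold_id)
    # ... build the loaders from x_train / x_val and train the fold ...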
Example #8
def init_binary_segmentation_augs():
    kvs = GlobalKVS()
    ppl = tvt.Compose([
        img_binary_mask2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(kvs['args'].pad_x, kvs['args'].pad_y)),
            slt.RandomFlip(axis=1, p=0.5),
            slt.CropTransform(crop_size=(kvs['args'].crop_x,
                                         kvs['args'].crop_y),
                              crop_mode='r'),
            slt.ImageGammaCorrection(gamma_range=(kvs['args'].gamma_min,
                                                  kvs['args'].gamma_max),
                                     p=0.5),
        ]),
        solt2img_binary_mask,
        partial(apply_by_index, transform=numpy2tens, idx=[0, 1]),
    ])

    kvs.update('train_trf', ppl)

    return ppl
Example #9
def init_ms_scheduler(optimizer):
    """
    Initializes a simple multi-step learning rate scheduler.
    The schedule is controlled by the `lr_drop` argument, which lists the epochs
    at which the learning rate is dropped.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer for which the scheduler needs to be created.
    Returns
    -------
    out : lr_scheduler.MultiStepLR
        The created scheduler.

    """
    kvs = GlobalKVS()
    return lr_scheduler.MultiStepLR(optimizer, kvs['args'].lr_drop)
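Since `lr_drop` is passed to `MultiStepLR` as the `milestones` argument, it is expected to be a list of epoch indices. A hedged sketch of a matching argument definition (the default values are illustrative):

parser.add_argument('--lr_drop', type=int, nargs='+', default=[30, 45])
# ...
scheduler = init_ms_scheduler(optimizer)
# The LR of every parameter group is multiplied by the default gamma=0.1 at epochs 30 and 45.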
Example #10
def init_model():
    kvs = GlobalKVS()
    net = init_model_from_args(kvs['args'])

    if kvs['args'].init_model_from != '':
        cur_fold = kvs['cur_fold']
        pattern_snp = os.path.join(kvs['args'].init_model_from,
                                   f'fold_{cur_fold}_*.pth')
        state_dict = torch.load(glob.glob(pattern_snp)[0])['model']
        pretrained_dict = {
            k: v
            for k, v in state_dict.items() if 'out_block' not in k
        }
        net_state_dict = net.state_dict()
        net_state_dict.update(pretrained_dict)
        net.load_state_dict(net_state_dict)

    return net.to('cuda')
Example #11
def init_binary_loss():
    kvs = GlobalKVS()
    if kvs['args'].n_classes == 2:
        if kvs['args'].loss == 'combined':
            return CombinedLoss(
                [
                    BCEWithLogitsLoss2d(),
                    SoftJaccardLoss(use_log=kvs['args'].log_jaccard)
                ],
                weights=[1 - kvs['args'].loss_weight, kvs['args'].loss_weight])
        elif kvs['args'].loss == 'bce':
            return BCEWithLogitsLoss2d()
        elif kvs['args'].loss == 'jaccard':
            return SoftJaccardLoss(use_log=kvs['args'].log_jaccard)
        elif kvs['args'].loss == 'focal':
            return FocalLoss()
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
Example #12
def init_data_processing(img_reader=read_rgb_ocv,
                         mask_reader=read_gs_binary_mask_ocv):
    kvs = GlobalKVS()

    dataset = SegmentationDataset(split=kvs['metadata'],
                                  trf=kvs['train_trf'],
                                  read_img=img_reader,
                                  read_mask=mask_reader)

    tmp = init_mean_std(snapshots_dir=os.path.join(kvs['args'].workdir,
                                                   'snapshots'),
                        dataset=dataset,
                        batch_size=kvs['args'].bs,
                        n_threads=kvs['args'].n_threads,
                        n_classes=kvs['args'].n_classes)

    if len(tmp) == 3:
        mean_vector, std_vector, class_weights = tmp
    elif len(tmp) == 2:
        mean_vector, std_vector = tmp
        class_weights = None
    else:
        raise ValueError('Incorrect format of mean/std/class-weights')

    norm_trf = partial(normalize_channel_wise,
                       mean=mean_vector,
                       std=std_vector)

    train_trf = tvt.Compose(
        [kvs['train_trf'],
         partial(apply_by_index, transform=norm_trf, idx=0)])

    val_trf = tvt.Compose([
        partial(apply_by_index, transform=numpy2tens, idx=[0, 1]),
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])
    kvs.update('class_weights', class_weights)
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
Example #13
def init_loss():
    kvs = GlobalKVS()
    if kvs['args'].loss_type == 'elastic':
        loss = ElasticLoss(w=kvs['args'].loss_weight)
    elif kvs['args'].loss_type == 'l2':
        loss = LNLoss(space='l2')
    elif kvs['args'].loss_type == 'l1':
        loss = LNLoss(space='l1')
    elif kvs['args'].loss_type == 'wing':
        loss = WingLoss(width=kvs['args'].wing_w, curvature=kvs['args'].wing_c)
    elif kvs['args'].loss_type == 'robust':
        loss = GeneralizedRobustLoss(num_dims=16 * 2 if kvs['args'].annotations == 'hc' else 2,
                                     alpha_init=kvs['args'].alpha_robust,
                                     scale_init=kvs['args'].c_robust,
                                     alpha_lo=kvs['args'].alpha_robust_min,
                                     alpha_hi=kvs['args'].alpha_robust_max)
    else:
        raise NotImplementedError

    return loss.to('cuda')
Example #14
def log_metrics(writer,
                train_loss,
                val_loss,
                val_results,
                val_results_callback=None):
    """
    Basic function to log the results from the validation stage.
    Takes the Tensorboard writer, train loss, validation loss, the artifacts produced during the validation phase,
    and an optional callback that can process these data, e.g. compute metrics and
    visualize them in Tensorboard. By default, train and validation losses are logged outside of the callback.
    If any metric is computed in the callback, it is useful to store it in the `to_log` dictionary.

    Parameters
    ----------
    writer : SummaryWriter
        Tensorboard summary writer
    train_loss : float
        Training loss
    val_loss : float
        Validation loss
    val_results : object
        Artifacts produced during the validation
    val_results_callback : Callable or None
        A callback function that can process the artifacts and e.g. display those in Tensorboard.

    Returns
    -------
    out : None

    """
    kvs = GlobalKVS()

    print(colored('==> ', 'green') + 'Metrics:')
    print(colored('====> ', 'green') + 'Train loss:', train_loss)
    print(colored('====> ', 'green') + 'Val loss:', val_loss)

    to_log = {'train_loss': train_loss, 'val_loss': val_loss}
    val_metrics = {'epoch': kvs['cur_epoch']}
    val_metrics.update(to_log)
    writer.add_scalars(f"Losses_{kvs['args'].experiment_tag}", to_log,
                       kvs['cur_epoch'])
    if val_results_callback is not None:
        val_results_callback(writer, val_metrics, to_log, val_results)

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', to_log)
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', val_metrics)
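The callback receives exactly the arguments used in the call above, i.e. `(writer, val_metrics, to_log, val_results)`. A minimal sketch, assuming `val_results` is the per-landmark error dictionary produced by the landmark `pass_epoch` shown further below:

def landmark_error_callback(writer, val_metrics, to_log, val_results):
    # val_results maps keypoint id -> 1-D array of errors; aggregate them into one scalar
    mean_err = float(np.mean(np.hstack(list(val_results.values()))))
    to_log['mean_landmark_error'] = mean_err
    val_metrics['mean_landmark_error'] = mean_err
    writer.add_scalar('Metrics/mean_landmark_error', mean_err, val_metrics['epoch'])

# Passed at training time:
# log_metrics(writer, train_loss, val_loss, val_results, val_results_callback=landmark_error_callback)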
Example #15
def init_data_processing():
    kvs = GlobalKVS()

    dataset = LandmarkDataset(data_root=kvs['args'].dataset_root,
                              split=kvs['metadata'],
                              hc_spacing=kvs['args'].hc_spacing,
                              lc_spacing=kvs['args'].lc_spacing,
                              transform=kvs['train_trf'],
                              ann_type=kvs['args'].annotations,
                              image_pad=kvs['args'].img_pad)

    tmp = init_mean_std(snapshots_dir=os.path.join(kvs['args'].workdir,
                                                   'snapshots'),
                        dataset=dataset,
                        batch_size=kvs['args'].bs,
                        n_threads=kvs['args'].n_threads,
                        n_classes=-1)

    if len(tmp) == 3:
        mean_vector, std_vector, class_weights = tmp
    elif len(tmp) == 2:
        mean_vector, std_vector = tmp
    else:
        raise ValueError('Incorrect format of mean/std/class-weights')

    norm_trf = partial(normalize_channel_wise,
                       mean=mean_vector,
                       std=std_vector)

    train_trf = tvt.Compose(
        [kvs['train_trf'],
         partial(apply_by_index, transform=norm_trf, idx=0)])

    val_trf = tvt.Compose([
        slc.Stream([
            slt.PadTransform((kvs['args'].pad_x, kvs['args'].pad_y),
                             padding='z'),
            slt.CropTransform((kvs['args'].crop_x, kvs['args'].crop_y),
                              crop_mode='c'),
        ]),
        partial(solt2torchhm, downsample=None, sigma=None),
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
Example #16
def init_segmentation_loaders(x_train,
                              x_val,
                              img_reader=read_rgb_ocv,
                              mask_reader=read_gs_binary_mask_ocv,
                              img_id_colname=None,
                              img_group_id_colname=None):
    kvs = GlobalKVS()

    train_dataset = SegmentationDataset(
        split=x_train,
        trf=kvs['train_trf'],
        read_img=img_reader,
        read_mask=mask_reader,
        img_id_colname=img_id_colname,
        img_group_id_colname=img_group_id_colname)

    val_dataset = SegmentationDataset(
        split=x_val,
        trf=kvs['val_trf'],
        read_img=img_reader,
        read_mask=mask_reader,
        img_id_colname=img_id_colname,
        img_group_id_colname=img_group_id_colname)

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=kvs['args'].bs,
                                   num_workers=kvs['args'].n_threads,
                                   shuffle=True,
                                   drop_last=True,
                                   worker_init_fn=lambda wid: np.random.seed(
                                       np.uint32(torch.initial_seed() + wid)))

    val_loader = data.DataLoader(val_dataset,
                                 batch_size=kvs['args'].val_bs,
                                 num_workers=kvs['args'].n_threads)

    return train_loader, val_loader
Example #17
def init_loaders(x_train, x_val, sequential_val_sampler=False):
    kvs = GlobalKVS()
    train_ds = LandmarkDataset(data_root=kvs['args'].dataset_root,
                               split=x_train,
                               hc_spacing=kvs['args'].hc_spacing,
                               lc_spacing=kvs['args'].lc_spacing,
                               transform=kvs['train_trf'],
                               ann_type=kvs['args'].annotations,
                               image_pad=kvs['args'].img_pad)

    val_ds = LandmarkDataset(data_root=kvs['args'].dataset_root,
                             split=x_val,
                             hc_spacing=kvs['args'].hc_spacing,
                             lc_spacing=kvs['args'].lc_spacing,
                             transform=kvs['val_trf'],
                             ann_type=kvs['args'].annotations,
                             image_pad=kvs['args'].img_pad)

    train_loader = DataLoader(train_ds,
                              batch_size=kvs['args'].bs,
                              num_workers=kvs['args'].n_threads,
                              shuffle=True,
                              drop_last=True,
                              worker_init_fn=lambda wid: np.random.seed(
                                  np.uint32(torch.initial_seed() + wid)))

    if sequential_val_sampler:
        sampler = torch.utils.data.sampler.SequentialSampler(
            data_source=val_ds)
    else:
        sampler = None
    val_loader = DataLoader(val_ds,
                            batch_size=kvs['args'].val_bs,
                            num_workers=kvs['args'].n_threads,
                            sampler=sampler)

    return train_loader, val_loader
Example #18
def init_optimizer_default(net, loss):
    """
    Initializes the optimizer for a given model.
    Currently supported optimizers are Adam and SGD with default parameters.
    Learning rate (LR) and weight decay (WD) must be specified in the arguments as `lr` and `wd`, respectively.
    LR and WD are retrieved automatically from global KVS.
    Parameters
    ----------
    net : torch.nn.Module
        Model whose parameters will be optimized.
    loss : torch.nn.Module
        Loss module; its parameters (if any) are also added to the optimizer.

    Returns
    -------
    out : torch.optim.Optimizer
        Initialized optimizer.

    """
    kvs = GlobalKVS()
    if kvs['args'].optimizer == 'adam':
        return optim.Adam([{
            'params': net.parameters()
        }, {
            'params': loss.parameters()
        }],
                          lr=kvs['args'].lr,
                          weight_decay=kvs['args'].wd)
    elif kvs['args'].optimizer == 'sgd':
        return optim.SGD([{
            'params': net.parameters()
        }, {
            'params': loss.parameters()
        }],
                         lr=kvs['args'].lr,
                         weight_decay=kvs['args'].wd,
                         momentum=0.9)
    else:
        raise NotImplementedError
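The loss parameters are added to the optimizer so that losses with learnable parameters (e.g. the adaptive robust loss above) are trained as well. A minimal usage sketch with illustrative argument definitions:

parser.add_argument('--optimizer', choices=['adam', 'sgd'], default='adam')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--wd', type=float, default=5e-5)
# ...
optimizer = init_optimizer_default(net, criterion)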
Example #19
def log_metrics(writer, train_loss, val_loss, conf_matrix):
    kvs = GlobalKVS()

    dices = {
        'dice_{}'.format(cls): dice
        for cls, dice in enumerate(calculate_dice(conf_matrix))
    }
    ious = {
        'iou_{}'.format(cls): iou
        for cls, iou in enumerate(calculate_iou(conf_matrix))
    }
    print(colored('==> ', 'green') + 'Metrics:')
    print(colored('====> ', 'green') + 'Train loss:', train_loss)
    print(colored('====> ', 'green') + 'Val loss:', val_loss)
    print(colored('====> ', 'green') + f'Val Dice: {dices}')
    print(colored('====> ', 'green') + f'Val IoU: {ious}')
    dices_tb = {}
    for cls in range(1, len(dices)):
        dices_tb[f"Dice [{cls}]"] = dices[f"dice_{cls}"]

    ious_tb = {}
    for cls in range(1, len(ious)):
        ious_tb[f"IoU [{cls}]"] = ious[f"iou_{cls}"]

    to_log = {'train_loss': train_loss, 'val_loss': val_loss}
    # Tensorboard logging
    writer.add_scalars(f"Losses_{kvs['args'].model}", to_log, kvs['cur_epoch'])
    writer.add_scalars('Metrics/Dice', dices_tb, kvs['cur_epoch'])
    writer.add_scalars('Metrics/IoU', ious_tb, kvs['cur_epoch'])
    # KVS logging
    to_log.update({'epoch': kvs['cur_epoch']})
    val_metrics = {'epoch': kvs['cur_epoch']}
    val_metrics.update(to_log)
    val_metrics.update(dices)
    val_metrics.update({'conf_matrix': conf_matrix})

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]', to_log)
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', val_metrics)
Example #20
def init_session(args):
    """
    Basic function that initializes each training loop.
    Sets the seed based on the parsed args, creates the snapshots dir and initializes global KVS.

    Parameters
    ----------
    args : Namespace
        Arguments from argparse.

    Returns
    -------
    out : tuple
        Args, snapshot name and global KVS.
    """
    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if args.experiment_config != '':
        with open(args.experiment_config, 'r') as f:
            conf = yaml.safe_load(f)
    else:
        conf = None
        raise Warning('No experiment config has been provided')

    # Creating the snapshot
    snapshot_name = time.strftime(f'{socket.gethostname()}_%Y_%m_%d_%H_%M_%S')
    os.makedirs(os.path.join(args.workdir, 'snapshots', snapshot_name),
                exist_ok=True)

    kvs = GlobalKVS(
        os.path.join(args.workdir, 'snapshots', snapshot_name, 'session.pkl'))
    if conf is not None:
        kvs.update('config', conf)
        with open(
                os.path.join(args.workdir, 'snapshots', snapshot_name,
                             'config.yml'), 'w') as conf_file:
            yaml.dump(conf, conf_file)

    res = git_info()
    if res is not None:
        kvs.update('git branch name', res[0])
        kvs.update('git commit id', res[1])
    else:
        kvs.update('git branch name', None)
        kvs.update('git commit id', None)

    kvs.update('pytorch_version', torch.__version__)

    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)

    return args, snapshot_name, kvs
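A sketch of a typical entry point, assuming the parsed arguments contain at least `seed`, `workdir` and `experiment_config`, plus whatever the downstream `init_*` helpers read:

if __name__ == '__main__':
    args = parse_args()  # hypothetical helper that builds the argparse Namespace
    args, snapshot_name, kvs = init_session(args)
    init_pd_meta()
    init_folds(img_group_id_colname='subject_id')
    # ... proceed to per-fold training ...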
Example #21
def train_fold(pass_epoch,
               net,
               train_loader,
               optimizer,
               criterion,
               val_loader,
               scheduler,
               save_by='val_loss',
               cmp='lt',
               log_metrics_cb=None,
               img_key=None):
    """
    A common implementation of training one fold of a neural network. It is meant to be called
    within a cross-validation loop.

    Parameters
    ----------
    pass_epoch : Callable
        Function that trains or validates one epoch
    net : torch.nn.Module
        Model to train
    train_loader : torch.utils.data.DataLoader
        Training data loader
    optimizer : torch.optim.Optimizer
        Optimizer
    criterion : torch.nn.Module
        Loss function
    val_loader : torch.utils.data.DataLoader
        Validation data loader
    scheduler : lr_scheduler.Scheduler
        Learning rate scheduler
    save_by : str
        Name of the metric used to save the snapshot. Validation loss by default.
        ReduceLROnPlateau will also use this metric to drop the LR.
    cmp : str
        Comparator for saving the snapshots. Can be `lt` (less than) or `gt` (greater than).
    log_metrics_cb : Callable or None
        Callback that processes the artifacts from the validation stage.
    img_key : str
        Key in the data loader batch that allows extracting an image. Used in SWA.

    Returns
    -------
    out : None
    """
    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    writer = SummaryWriter(
        os.path.join(kvs['args'].workdir, 'snapshots', kvs['snapshot_name'],
                     'logs', 'fold_{}'.format(fold_id), kvs['snapshot_name']))

    for epoch in range(kvs['args'].n_epochs):
        if scheduler is not None:
            lrs = [param_group['lr'] for param_group in optimizer.param_groups]
            print(
                colored('==> ', 'green') +
                f'Training epoch [{epoch}] with LR {lrs}')
        else:
            print(colored('==> ', 'green') + f'Training epoch [{epoch}]')
        kvs.update('cur_epoch', epoch)
        train_loss, _ = pass_epoch(net, train_loader, optimizer, criterion)
        if isinstance(optimizer, swa.SWA):
            optimizer.swap_swa_sgd()
            assert img_key is not None
            bn_update_cb(net, train_loader, img_key)

        val_loss, val_results = pass_epoch(net, val_loader, None, criterion)
        log_metrics(writer, train_loss, val_loss, val_results, log_metrics_cb)
        save_checkpoint(net, criterion, optimizer, save_by, cmp)
        if scheduler is not None:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(kvs[f'val_metrics_fold_[{kvs["cur_fold"]}]'][-1]
                               [0][save_by])
            else:
                scheduler.step()
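A hedged sketch of the cross-validation loop this function is meant to be called from, combining the helpers shown in the other examples; `prev_model` is reset per fold because `save_checkpoint` reads it:

kvs = GlobalKVS()
for fold_id, x_train, x_val in kvs['cv_split']:
    kvs.update('cur_fold', fold_id)
    kvs.update('prev_model', None)

    net = init_model()
    criterion = init_loss()
    optimizer = init_optimizer_default(net, criterion)
    scheduler = init_ms_scheduler(optimizer)
    train_loader, val_loader = init_loaders(x_train, x_val)

    train_fold(pass_epoch, net, train_loader, optimizer, criterion,
               val_loader, scheduler, save_by='val_loss', cmp='lt')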
Example #22
def pass_epoch(net, loader, optimizer, criterion):
    kvs = GlobalKVS()
    net.train(optimizer is not None)

    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    max_ep = kvs['args'].n_epochs

    running_loss = 0.0
    n_batches = len(loader)
    landmark_errors = {}
    device = next(net.parameters()).device
    pbar = tqdm(total=n_batches, ncols=200)

    with torch.set_grad_enabled(optimizer is not None):
        for i, entry in enumerate(loader):
            if optimizer is not None:
                optimizer.zero_grad()

            inputs = entry['img'].to(device)
            target = entry['kp_gt'].to(device).float()

            if kvs['args'].use_mixup and optimizer is not None:
                loss = mixup_pass(net, criterion, inputs, target,
                                  kvs['args'].mixup_alpha)
            else:
                outputs = net(inputs)
                loss = criterion(outputs, target)

            if optimizer is not None:
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                pbar.set_description(
                    f"Fold [{fold_id}] [{epoch} | {max_ep}] | "
                    f"Running loss {running_loss / (i + 1):.5f} / {loss.item():.5f}"
                )
            else:
                running_loss += loss.item()
                pbar.set_description(
                    desc=
                    f"Fold [{fold_id}] [{epoch} | {max_ep}] | Validation progress"
                )
            if optimizer is None:
                target_kp = entry['kp_gt'].numpy()
                h, w = inputs.size(2), inputs.size(3)
                if isinstance(outputs, tuple):
                    predicts = outputs[-1].to('cpu').numpy()
                else:
                    predicts = outputs.to('cpu').numpy()

                xy_batch = predicts
                xy_batch[:, :, 0] *= (w - 1)
                xy_batch[:, :, 1] *= (h - 1)

                target_kp[:, :, 0] *= (w - 1)
                target_kp[:, :, 1] *= (h - 1)

                for kp_id in range(target_kp.shape[1]):
                    spacing = getattr(kvs['args'],
                                      f"{kvs['args'].annotations}_spacing")
                    d = target_kp[:, kp_id] - xy_batch[:, kp_id]
                    err = np.sqrt(np.sum(d**2, 1)) * spacing
                    if kp_id not in landmark_errors:
                        landmark_errors[kp_id] = list()

                    landmark_errors[kp_id].append(err)

            pbar.update()
            gc.collect()
        gc.collect()
        pbar.close()

    if len(landmark_errors) > 0:
        for kp_id in landmark_errors:
            landmark_errors[kp_id] = np.hstack(landmark_errors[kp_id])
    else:
        landmark_errors = None

    return running_loss / n_batches, landmark_errors
Example #23
import cv2
import pickle
import argparse
import os


from deeppipeline.kvs import GlobalKVS
from deeppipeline.io import read_gs_binary_mask_ocv, read_gs_ocv
from deeppipeline.segmentation.evaluation import run_oof_binary

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


if __name__ == "__main__":
    kvs = GlobalKVS(None)
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', default='')
    parser.add_argument('--tta', type=bool, default=False)
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--n_threads', type=int, default=12)
    parser.add_argument('--snapshots_root', default='')
    parser.add_argument('--snapshot', default='')
    args = parser.parse_args()

    with open(os.path.join(args.snapshots_root, args.snapshot, 'session.pkl'), 'rb') as f:
        session_backup = pickle.load(f)

        args.model = session_backup['args'][0].model
        args.n_inputs = session_backup['args'][0].n_inputs
        args.n_classes = session_backup['args'][0].n_classes
Example #24
    with open(os.path.join(snp_full_path, 'config.yml'), 'r') as f:
        cfg = yaml.safe_load(f)
    print(
        colored('==> Experiment: ', 'red') +
        cfg['experiment'][0]['experiment_description'])
    print(colored('==> Snapshot: ', 'green') + args.snapshot)

    snp_args = snapshot_session['args'][0]
    for arg in vars(snp_args):
        if not hasattr(args, arg):
            setattr(args, arg, getattr(snp_args, arg))
    args.init_model_from = ''

    if not os.path.isfile(os.path.join(oof_results_dir, 'oof_results.npz')):
        kvs = GlobalKVS()
        kvs.update('args', args)
        kvs.update('val_trf', snapshot_session['val_trf'][0])
        kvs.update('train_trf', snapshot_session['train_trf'][0])

        oof_inference = []
        oof_gt = []
        subject_ids = []
        kls = []
        with torch.no_grad():
            for fold_id, train_split, val_split in snapshot_session[
                    'cv_split'][0]:
                _, val_loader = init_loaders(train_split,
                                             val_split,
                                             sequential_val_sampler=True)
                net = init_model()
Example #25
def segmentation_unet(data_xy, arguments, sample):
    """
    The newest pipeline for Unet segmentation. Model training utilizes augmentations to improve robustness.

    Parameters
    ----------
    data_xy : ndarray (3-dimensional)
        Input data.
    arguments : Namespace
        Input arguments.
    sample : str
        Sample name

    Returns
    -------
    out : ndarray (3-dimensional)
        Segmented calcified tissue mask (largest connected object).
    """
    kvs = GlobalKVS(None)

    parser = ArgumentParser()
    parser.add_argument('--dataset_root', default='../Data/')
    parser.add_argument('--tta', type=bool, default=False)
    parser.add_argument('--bs', type=int, default=28)
    parser.add_argument('--n_threads', type=int, default=12)
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--n_inputs', type=int, default=1)
    parser.add_argument('--n_classes', type=int, default=2)
    parser.add_argument('--bw', type=int, default=24)
    parser.add_argument('--depth', type=int, default=6)
    parser.add_argument('--cdepth', type=int, default=1)
    parser.add_argument('--seed', type=int, default=42)
    # parser.add_argument('--snapshots_root', default='../workdir/snapshots/')
    # parser.add_argument('--snapshot', default='dios-erc-gpu_2019_12_29_13_24')
    args = parser.parse_args()

    kvs.update('args', args)

    # Load model
    models = glob(str(arguments.model_path / f'fold_[0-9]*.pth'))
    #models = glob(str(arguments.model_path / f'fold_3*.pth'))
    models.sort()

    # List the models
    device = 'cuda'
    model_list = []

    for fold in range(len(models)):
        model = init_model(ignore_data_parallel=True)
        snp = torch.load(models[fold])
        if isinstance(snp, dict):
            snp = snp['model']
        model.load_state_dict(snp)
        model_list.append(model)

    # Merge folds into one model
    model = InferenceModel(model_list).to(device)
    # Initialize model
    model.eval()

    tmp = np.load(str(arguments.model_path.parent / 'mean_std.npy'),
                  allow_pickle=True)
    mean, std = tmp[0][0], tmp[1][0]

    # Flip the z-dimension
    #data_xy = np.flip(data_xy, axis=2)
    # Transpose data
    data_xz = np.transpose(data_xy, (2, 0, 1))  # X-Z-Y
    data_yz = np.transpose(data_xy, (2, 1, 0))  # Y-Z-X  # Y-Z-X-Ch
    mask_xz = np.zeros(data_xz.shape)
    mask_yz = np.zeros(data_yz.shape)
    # res_xz = int(data_xz.shape[2] % args.bs > 0)
    # res_yz = int(data_yz.shape[2] % args.bs > 0)

    with torch.no_grad():
        # for idx in tqdm(range(data_xz.shape[2] // args.bs + res_xz), desc='Running inference, XZ'):
        for idx in tqdm(range(data_xz.shape[2]), desc='Running inference, XZ'):
            """
            try:
                img = np.expand_dims(data_xz[:, :, args.bs * idx:args.bs * (idx + 1)], axis=2)
                mask_xz[:, :, args.bs * idx: args.bs * (idx + 1)] = inference(model, img, shape=arguments.input_shape)
            except IndexError:
                img = np.expand_dims(data_xz[:, :, args.bs * idx:], axis=2)
                mask_xz[:, :, args.bs * idx:] = inference(model, img, shape=arguments.input_shape)
            """
            img = np.expand_dims(data_xz[:, :, idx], axis=2)
            mask_xz[:, :, idx] = inference_tiles(model,
                                                 img,
                                                 shape=arguments.input_shape,
                                                 mean=mean,
                                                 std=std)
        # 2nd orientation
        # for idx in tqdm(range(data_yz.shape[2] // args.bs + res_yz), desc='Running inference, YZ'):
        for idx in tqdm(range(data_yz.shape[2]), desc='Running inference, YZ'):
            """
            try:
                img = np.expand_dims(data_yz[:, :, args.bs * idx: args.bs * (idx + 1)], axis=2)
                mask_yz[:, :, args.bs * idx: args.bs * (idx + 1)] = inference(model, img, shape=arguments.input_shape)
            except IndexError:
                img = np.expand_dims(data_yz[:, :, args.bs * idx:], axis=2)
                mask_yz[:, :, args.bs * idx:] = inference(model, img, shape=arguments.input_shape)
            """
            img = np.expand_dims(data_yz[:, :, idx], axis=2)
            mask_yz[:, :, idx] = inference_tiles(model,
                                                 img,
                                                 shape=arguments.input_shape,
                                                 mean=mean,
                                                 std=std)
    # Average probability maps
    mask_final = (
        (mask_xz + np.transpose(mask_yz,
                                (0, 2, 1))) / 2) >= arguments.threshold
    mask_xz = list()
    mask_yz = list()
    data_xz = list()

    mask_final = np.transpose(mask_final, (1, 2, 0))
    mask_final[:, :, -mask_final.shape[2] // 3:] = False

    largest = largest_object(mask_final)

    return largest
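A minimal usage sketch; `arguments` must provide the attributes read above (`model_path`, `input_shape`, `threshold`), and all values below are illustrative:

from pathlib import Path
from argparse import Namespace

arguments = Namespace(model_path=Path('/path/to/snapshot/models'),
                      input_shape=(448, 448),
                      threshold=0.5)
# data_xy: 3-dimensional grayscale volume (ndarray)
calcified_mask = segmentation_unet(data_xy, arguments, sample='sample_001')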
Example #26
def init_scheduler(optimizer):
    kvs = GlobalKVS()
    return lr_scheduler.MultiStepLR(optimizer, kvs['args'].lr_drop)