Exemple #1
0
def init_data_processing():
    """Create the train/validation transforms and publish them in the KVS.

    A dataset built with the augmentation pipeline is handed to
    ``init_mean_std`` to obtain the mean vector, std vector and class
    weights. A ``transforms.Normalize`` built from those statistics is
    appended to both pipelines (applied through ``apply_by_index`` with
    ``idx=0``). The results are stored in the global KVS under
    ``'class_weights'``, ``'train_trf'`` and ``'val_trf'``, and the session
    is pickled to ``session.pkl`` in the snapshot directory.
    """
    kvs = GlobalKVS()
    augmentations = init_train_augmentation_pipeline()

    # Dataset used only for estimating the normalization statistics.
    stats_dataset = SegmentationDataset(
        split=kvs['metadata_train'],
        trf=augmentations,
        read_img=read_gs_ocv,
        read_mask=read_gs_mask_ocv,
    )

    args = kvs['args']
    mean, std, weights = init_mean_std(
        snapshots_dir=args.snapshots,
        dataset=stats_dataset,
        batch_size=args.bs,
        n_threads=args.n_threads,
        n_classes=args.n_classes,
    )

    normalize = transforms.Normalize(torch.from_numpy(mean).float(),
                                     torch.from_numpy(std).float())
    # Normalization touches only element 0 of the sample
    # (presumably the image -- confirm against apply_by_index).
    normalize_first = partial(apply_by_index, transform=normalize, idx=0)
    to_tensors = partial(apply_by_index, transform=gs2tens, idx=[0, 1])

    train_trf = transforms.Compose([augmentations, normalize_first])
    val_trf = transforms.Compose([to_tensors, normalize_first])

    for key, value in (('class_weights', weights),
                       ('train_trf', train_trf),
                       ('val_trf', val_trf)):
        kvs.update(key, value)

    session_path = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                'session.pkl')
    kvs.save_pkl(session_path)
Exemple #2
0
def log_metrics(writer, train_loss, val_loss, conf_matrix):
    """Print, log and persist the metrics of the current epoch.

    Args:
        writer: Object exposing ``add_scalars(tag, scalars, step)``.
        train_loss: Training loss of the epoch.
        val_loss: Validation loss of the epoch.
        conf_matrix: Confusion matrix passed to ``calculate_dice`` and stored
            alongside the validation metrics.

    The losses and the per-class Dice scores are printed and sent to
    ``writer``; the Dice of class 0 is left out of the ``'Metrics'`` scalars.
    Losses and validation metrics are then recorded in the global KVS for
    the current fold, and the session is pickled to ``session.pkl``.
    """
    kvs = GlobalKVS()

    dices = {}
    for cls, dice in enumerate(calculate_dice(conf_matrix)):
        dices['dice_{}'.format(cls)] = dice

    arrow = colored('====> ', 'green')
    print(colored('==> ', 'green') + 'Metrics:')
    print(arrow + 'Train loss:', train_loss)
    print(arrow + 'Val loss:', val_loss)
    print(arrow + f'Val Dices: {dices}')

    # Class 0 is skipped in the scalars sent to the writer.
    dices_tb = {
        f"Dice [{cls}]": dices[f"dice_{cls}"]
        for cls in range(1, len(dices))
    }

    epoch = kvs['cur_epoch']
    to_log = {'train_loss': train_loss, 'val_loss': val_loss}

    # Tensorboard logging (the epoch is the step, not one of the scalars)
    writer.add_scalars(f"Losses_{kvs['args'].model}", to_log, epoch)
    writer.add_scalars('Metrics', dices_tb, epoch)

    # KVS logging
    to_log['epoch'] = epoch
    val_metrics = {
        'epoch': epoch,
        **to_log,
        **dices,
        'conf_matrix': conf_matrix,
    }

    fold = kvs["cur_fold"]
    kvs.update(f'losses_fold_[{fold}]', to_log)
    kvs.update(f'val_metrics_fold_[{fold}]', val_metrics)

    session_path = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                'session.pkl')
    kvs.save_pkl(session_path)
Exemple #3
0
def validate_epoch(net, val_loader, criterion):
    """Run one validation pass and accumulate a confusion matrix.

    Args:
        net: Model to evaluate; switched to evaluation mode here.
        val_loader: Iterable of batches, each a mapping with ``'img'`` and
            ``'mask'`` tensors.
        criterion: Callable returning a scalar loss tensor for
            ``(predictions, mask)``.

    Returns:
        Tuple ``(val_loss, confusion_matrix)``: the loss averaged over the
        number of batches, and an ``(n_classes, n_classes)`` ``uint64`` array.

    Raises:
        ValueError: If the configured number of classes is below 2
            (raised once the first batch has been processed).
    """
    kvs = GlobalKVS()
    net.train(False)

    epoch = kvs['cur_epoch']
    max_epoch = kvs['args'].n_epochs
    n_classes = kvs['args'].n_classes

    device = next(net.parameters()).device
    confusion_matrix = np.zeros((n_classes, n_classes), dtype=np.uint64)
    total_loss = 0

    progress = tqdm(val_loader,
                    total=len(val_loader),
                    desc=f"[{epoch} / {max_epoch}] Val: ")
    with torch.no_grad():
        for batch in progress:
            images = batch['img'].to(device)
            # squeeze() removes every size-1 dimension, which includes the
            # batch dimension when a batch holds a single sample.
            target = batch['mask'].to(device).squeeze()
            output = net(images)
            total_loss += criterion(output, target).item()

            target = target.to('cpu').numpy()
            if n_classes < 2:
                raise ValueError

            output = output.to('cpu').numpy()
            if n_classes == 2:
                # NOTE(review): the 0.5 cut-off is applied to the raw network
                # output; if the network emits logits this is not a 0.5
                # probability threshold -- confirm.
                labels = (output > 0.5).astype(float)
            else:
                labels = output.argmax(axis=1)

            confusion_matrix += metrics.calculate_confusion_matrix_from_arrays(
                labels, target, n_classes)

    total_loss /= len(val_loader)

    return total_loss, confusion_matrix
Exemple #4
0
def train_epoch(net, train_loader, optimizer, criterion):
    """Train ``net`` for one epoch and return the mean batch loss.

    Args:
        net: Model to train; switched to training mode here.
        train_loader: Sized iterable of batches, each a mapping with
            ``'img'`` and ``'mask'`` tensors.
        optimizer: Optimizer stepping the parameters of ``net``.
        criterion: Callable returning a scalar loss tensor for
            ``(outputs, mask)``.

    Returns:
        Sum of the batch losses divided by ``len(train_loader)``.
    """
    kvs = GlobalKVS()
    net.train(True)

    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    max_ep = kvs['args'].n_epochs

    running_loss = 0.0
    n_batches = len(train_loader)
    device = next(net.parameters()).device

    # The context manager closes the progress bar even when a batch raises,
    # so a failed epoch does not leave a dangling bar behind.
    with tqdm(total=n_batches, ncols=200) as pbar:
        for i, entry in enumerate(train_loader):
            inputs = entry['img'].to(device)
            mask = entry['mask'].to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, mask)
            loss.backward()
            optimizer.step()

            # Read the scalar once: every .item() call copies the value back
            # to the host.
            batch_loss = loss.item()
            running_loss += batch_loss
            pbar.set_description(
                f"Fold [{fold_id}] [{epoch} | {max_ep}] | "
                f"Running loss {running_loss / (i + 1):.5f} / {batch_loss:.5f}")
            pbar.update()

            # Kept from the original implementation: explicit collection
            # after every batch (presumably to limit memory growth).
            gc.collect()
    gc.collect()

    return running_loss / n_batches
Exemple #5
0
def _seed_numpy_worker(wid):
    """Seed NumPy's global RNG inside a DataLoader worker.

    The seed is ``torch.initial_seed() + wid`` reduced modulo ``2 ** 32``,
    the range ``np.random.seed`` accepts. Casting the 64-bit torch seed
    with ``np.uint32`` instead overflows on recent NumPy releases.

    Defined at module level (not as a lambda) so that it can be pickled
    when worker processes are started with the ``spawn`` method.
    """
    np.random.seed((torch.initial_seed() + wid) % 2 ** 32)


def init_loaders(x_train, x_val):
    """Build the training and validation data loaders.

    Args:
        x_train: Split passed to the training ``SegmentationDataset``.
        x_val: Split passed to the validation ``SegmentationDataset``.

    Returns:
        Tuple ``(train_loader, val_loader)``. The training loader shuffles,
        drops the last incomplete batch and re-seeds NumPy in every worker;
        the validation loader uses the ``val_bs`` batch size and the default
        ordering.
    """
    kvs = GlobalKVS()
    train_dataset = SegmentationDataset(split=x_train,
                                        trf=kvs['train_trf'],
                                        read_img=read_gs_ocv,
                                        read_mask=read_gs_mask_ocv)

    val_dataset = SegmentationDataset(split=x_val,
                                      trf=kvs['val_trf'],
                                      read_img=read_gs_ocv,
                                      read_mask=read_gs_mask_ocv)

    train_loader = DataLoader(train_dataset,
                              batch_size=kvs['args'].bs,
                              num_workers=kvs['args'].n_threads,
                              shuffle=True,
                              drop_last=True,
                              worker_init_fn=_seed_numpy_worker)

    val_loader = DataLoader(val_dataset,
                            batch_size=kvs['args'].val_bs,
                            num_workers=kvs['args'].n_threads)

    return train_loader, val_loader
Exemple #6
0
def init_optimizer(net):
    """Create the optimizer selected by the ``optimizer`` argument.

    Args:
        net: Model whose ``parameters()`` are optimized.

    Returns:
        ``optim.Adam`` for ``'adam'`` or ``optim.SGD`` (momentum 0.9) for
        ``'sgd'``, both configured with the ``lr`` and ``wd`` arguments.

    Raises:
        NotImplementedError: For any other optimizer name.
    """
    kvs = GlobalKVS()
    name = kvs['args'].optimizer

    if name == 'adam':
        settings = kvs['args']
        return optim.Adam(net.parameters(),
                          lr=settings.lr,
                          weight_decay=settings.wd)

    if name == 'sgd':
        settings = kvs['args']
        return optim.SGD(net.parameters(),
                         lr=settings.lr,
                         weight_decay=settings.wd,
                         momentum=0.9)

    raise NotImplementedError
Exemple #7
0
def init_loss():
    """Create the loss selected by the ``loss`` argument.

    Only the two-class setting is supported. Recognized names are
    ``'combined'`` (BCE-with-logits plus binary Dice), ``'bce'``, ``'dice'``
    and ``'wbce'`` (BCE-with-logits weighted by the stored class weights).

    Returns:
        The constructed loss object.

    Raises:
        NotImplementedError: If ``n_classes`` is not 2 or the loss name is
            not recognized.
    """
    kvs = GlobalKVS()
    class_weights = kvs['class_weights']

    if kvs['args'].n_classes != 2:
        raise NotImplementedError

    # Lazy factories: only the selected loss is ever instantiated.
    factories = {
        'combined': lambda: CombinedLoss([BCEWithLogitsLoss2d(),
                                          BinaryDiceLoss()]),
        'bce': lambda: BCEWithLogitsLoss2d(),
        'dice': lambda: BinaryDiceLoss(),
        'wbce': lambda: BCEWithLogitsLoss2d(weight=class_weights),
    }
    make_loss = factories.get(kvs['args'].loss)
    if make_loss is None:
        raise NotImplementedError
    return make_loss()
Exemple #8
0
def init_model():
    """Instantiate the network selected by the ``model`` argument on CUDA.

    Only ``'unet'`` is supported. The network gets ``n_classes - 1`` outputs
    and is wrapped in ``nn.DataParallel`` when more than one GPU is recorded
    in the KVS.

    Returns:
        The model moved to the ``'cuda'`` device.

    Raises:
        NotImplementedError: For any model name other than ``'unet'``.
    """
    kvs = GlobalKVS()
    if kvs['args'].model != 'unet':
        raise NotImplementedError

    settings = kvs['args']
    net = UNet(bw=settings.bw,
               depth=settings.depth,
               center_depth=settings.cdepth,
               n_inputs=settings.n_inputs,
               n_classes=settings.n_classes - 1,
               activation='relu')

    # NOTE(review): init_session stores None under 'gpus' when CUDA is not
    # available, in which case this comparison raises TypeError.
    if kvs['gpus'] > 1:
        net = nn.DataParallel(net).to('cuda')

    return net.to('cuda')
Exemple #9
0
def init_train_augmentation_pipeline():
    """Compose the training-time augmentation pipeline.

    The pipeline converts the sample with ``img_mask2solt``, runs a stream of
    random flip (axis 1, p=0.5), gamma correction (range 0.5-2, p=0.5),
    padding to ``(crop_x + 1, crop_y + 1)`` and a crop of
    ``(crop_x, crop_y)`` with ``crop_mode='r'``, converts back with
    ``solt2img_mask`` and finally applies ``gs2tens`` to elements 0 and 1.

    Returns:
        The composed ``transforms.Compose`` object.
    """
    kvs = GlobalKVS()
    settings = kvs['args']

    # Padding is one pixel larger than the crop in both directions.
    pad_to = (settings.crop_x + 1, settings.crop_y + 1)
    crop_size = (settings.crop_x, settings.crop_y)

    augmentation_stream = slc.Stream([
        slt.RandomFlip(axis=1, p=0.5),
        slt.ImageGammaCorrection(gamma_range=(0.5, 2), p=0.5),
        slt.PadTransform(pad_to=pad_to),
        slt.CropTransform(crop_size=crop_size, crop_mode='r'),
    ])
    to_tensors = partial(apply_by_index, transform=gs2tens, idx=[0, 1])

    return transforms.Compose([
        img_mask2solt,
        augmentation_stream,
        solt2img_mask,
        to_tensors,
    ])
Exemple #10
0
def init_folds():
    """Split the training metadata into grouped cross-validation folds.

    ``GroupKFold`` is run over ``kvs['metadata_train']`` with the ``grade``
    column as target and ``subject_id`` as the grouping key. When the
    ``fold`` argument is not ``-1`` only that fold is kept. For every kept
    fold, list-typed KVS entries for its losses and validation metrics are
    registered, and the list of ``(fold_id, train_rows, val_rows)`` tuples is
    stored under ``'cv_split'``.
    """
    kvs = GlobalKVS()
    splitter = GroupKFold(kvs['args'].n_folds)
    metadata = kvs['metadata_train']
    fold_indices = splitter.split(X=metadata,
                                  y=metadata.grade,
                                  groups=metadata.subject_id)

    cv_split = []
    for fold_id, (train_ind, val_ind) in enumerate(fold_indices):
        selected = kvs['args'].fold
        if selected == -1 or fold_id == selected:
            train_rows = kvs['metadata_train'].iloc[train_ind]
            val_rows = kvs['metadata_train'].iloc[val_ind]
            cv_split.append((fold_id, train_rows, val_rows))

            for prefix in ('losses_fold_', 'val_metrics_fold_'):
                kvs.update(f'{prefix}[{fold_id}]', None, list)

    kvs.update('cv_split', cv_split)
Exemple #11
0
def init_metadata():
    """Collect image/mask paths, join them with the grades and store them.

    Images are looked up as ``<dataset>/*/imgs/*.png`` and masks as
    ``<dataset>/*/masks/*.png``. The name of the directory two levels above
    each image is its ``sample_id``; the part of it before the first ``'_'``
    is the ``subject_id``. The table is merged with the grades CSV on
    ``sample_id`` and stored in the global KVS under ``'metadata_train'``,
    after which the session is pickled to ``session.pkl``.

    Returns:
        The merged metadata ``DataFrame``.

    Raises:
        ValueError: If the numbers of images and masks differ, or if there
            are fewer unique subjects than the ``train_size`` argument.
    """
    kvs = GlobalKVS()

    def _sort_key(path):
        # Basename first keeps the historical ordering; the full path breaks
        # ties so images and masks sharing a file name across samples are
        # ordered identically instead of relying on glob's arbitrary order.
        return os.path.basename(path), path

    def _sample_dir(path):
        # <dataset>/<sample>/imgs/<file>.png -> <sample>
        return os.path.basename(os.path.dirname(os.path.dirname(path)))

    dataset_root = kvs['args'].dataset
    imgs = sorted(glob.glob(os.path.join(dataset_root, '*', 'imgs', '*.png')),
                  key=_sort_key)
    masks = sorted(glob.glob(os.path.join(dataset_root, '*', 'masks', '*.png')),
                   key=_sort_key)

    if len(imgs) != len(masks):
        raise ValueError(
            f'Found {len(imgs)} images but {len(masks)} masks in '
            f'{dataset_root}')

    sample_id = [_sample_dir(path) for path in imgs]
    subject_id = [sample.split('_')[0] for sample in sample_id]

    metadata = pd.DataFrame(
        data={
            'img_fname': imgs,
            'mask_fname': masks,
            'sample_id': sample_id,
            'subject_id': subject_id
        })

    grades = pd.read_csv(kvs['args'].grades)

    n_subj = np.unique(metadata.subject_id.values).shape[0]

    if n_subj < kvs['args'].train_size:
        raise ValueError(
            f'Only {n_subj} unique subjects found, fewer than the requested '
            f'train_size={kvs["args"].train_size}')

    metadata = pd.merge(metadata, grades, on='sample_id')

    # NOTE: the group-wise hold-out (train/test) split is disabled; every
    # merged row is used for training.
    metadata_train = metadata.iloc[:]

    kvs.update('metadata_train', metadata_train)
    kvs.save_pkl(
        os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                     'session.pkl'))

    return metadata
Exemple #12
0
def save_checkpoint(net, val_metric_name, comparator='lt'):
    """Save the model when it is the first or the best one seen so far.

    Args:
        net: Model to save; a ``DataParallel`` wrapper is unwrapped first so
            that the state dict of the underlying module is stored.
        val_metric_name: Key of the metric inside the latest validation
            metrics entry of the current fold.
        comparator: Name of a function in the ``operator`` module (e.g.
            ``'lt'`` or ``'gt'``) called as ``f(current, best)``; a truthy
            result means the current model is better.

    The snapshot is written to
    ``<snapshots>/<snapshot_name>/fold_<fold>_epoch_<epoch>.pth``. The
    previous best snapshot is removed only after the new one has been
    written, so a failed save never leaves the run without a checkpoint.
    ``'prev_model'`` and ``'best_val_metric'`` are updated in the KVS and the
    session is pickled to ``session.pkl`` in every case.
    """
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    # Latest entry of the fold's history; element 0 is presumably the metrics
    # dict stored by the KVS next to bookkeeping data -- TODO confirm.
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    is_better = getattr(operator, comparator)

    snapshot_dir = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'])
    cur_snapshot_name = os.path.join(snapshot_dir,
                                     f'fold_{fold_id}_epoch_{epoch}.pth')
    prev_snapshot_name = kvs['prev_model']

    if (prev_snapshot_name is None
            or is_better(val_metric, kvs['best_val_metric'])):
        print(
            colored('====> ', 'red') + 'Snapshot was saved to',
            cur_snapshot_name)
        # Write the new snapshot first; only then drop the old one.
        torch.save(net.state_dict(), cur_snapshot_name)
        if (prev_snapshot_name is not None
                and prev_snapshot_name != cur_snapshot_name):
            try:
                os.remove(prev_snapshot_name)
            except FileNotFoundError:
                # The old snapshot is already gone; nothing to clean up.
                pass
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(snapshot_dir, 'session.pkl'))
Exemple #13
0
def init_session():
    """Parse the arguments, seed the RNGs and open a new snapshot session.

    The seed from the parsed arguments is applied to torch (CPU and CUDA)
    and NumPy. A snapshot directory named after the current time
    (``%Y_%m_%d_%H_%M``) is created under ``args.snapshots``. Git
    information, the PyTorch version, the CUDA version and the GPU count
    (both ``None`` when CUDA is unavailable), the snapshot name and the
    arguments are stored in the global KVS, which is then pickled to
    ``session.pkl`` in the snapshot directory.

    Returns:
        Tuple ``(args, snapshot_name)``.
    """
    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args_train()

    # Initializing the seeds
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed,
                    np.random.seed):
        seed_fn(args.seed)

    # Creating the snapshot
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    snapshot_dir = os.path.join(args.snapshots, snapshot_name)
    os.makedirs(snapshot_dir, exist_ok=True)

    # Git branch and commit, or None for both when git_info() gives nothing.
    res = git_info()
    for position, key in enumerate(('git branch name', 'git commit id')):
        kvs.update(key, None if res is None else res[position])

    kvs.update('pytorch_version', torch.__version__)

    has_cuda = torch.cuda.is_available()
    kvs.update('cuda', torch.version.cuda if has_cuda else None)
    kvs.update('gpus', torch.cuda.device_count() if has_cuda else None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(snapshot_dir, 'session.pkl'))

    return args, snapshot_name