Example #1
def load_data(datadir, img_size=416, crop_pct=0.875):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    scale_size = int(math.floor(img_size / crop_pct))

    print("Loading training data")
    st = time.time()
    dataset = VOCDetection(datadir, image_set='train', download=True,
                           transforms=Compose([VOCTargetTransform(classes),
                                               RandomResizedCrop((img_size, img_size), scale=(0.3, 1.0)),
                                               RandomHorizontalFlip(),
                                               convert_to_relative,
                                               ImageTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                                                                     saturation=0.1, hue=0.02)),
                                               ImageTransform(transforms.ToTensor()), ImageTransform(normalize)]))

    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    dataset_test = VOCDetection(datadir, image_set='val', download=True,
                                transforms=Compose([VOCTargetTransform(classes),
                                                    Resize(scale_size), CenterCrop(img_size),
                                                    convert_to_relative,
                                                    ImageTransform(transforms.ToTensor()), ImageTransform(normalize)]))

    print("Took", time.time() - st)
    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
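
A minimal sketch of how these datasets and samplers might be consumed downstream; the batch size, worker count, and collate_fn are assumptions, mirroring the loader setup in Example #5:

dataset, dataset_test, train_sampler, test_sampler = load_data('./VOC')  # hypothetical data dir

# collate_fn (not shown in these examples) would stack variable-size
# detection targets into a batch.
train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size=32,
                                           sampler=train_sampler,
                                           num_workers=4,
                                           collate_fn=collate_fn,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(dataset_test,
                                          batch_size=32,
                                          sampler=test_sampler,
                                          num_workers=4,
                                          collate_fn=collate_fn,
                                          pin_memory=True)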
Example #2
def main():
    '''
    Runs if the module is called as a script (e.g. python3 dataset.py <dataset_name> <frametime>).
    Executes self tests.
    '''
    dataset_name = argv[1] if len(argv) > 1 else "clarissa"
    wait = int(argv[2]) if len(argv) > 2 else 10
    print(
        "Dataset module running as script, executing dataset unit test for {}".
        format(dataset_name))

    if dataset_name == "adni_slices":
        unit_test(image_dataset=False,
                  adni=True,
                  hiponly=True,
                  plt_show=True,
                  nworkers=4,
                  e2d=True)
    elif dataset_name == "clarissa_slices":
        unit_test(image_dataset=False,
                  adni=False,
                  hiponly=True,
                  plt_show=True,
                  nworkers=4,
                  e2d=True)
    elif dataset_name == "concat":
        from transforms import ReturnPatch, Intensity, RandomFlip, Noisify, ToTensor, CenterCrop, RandomAffine
        train_transforms = Compose([
            ReturnPatch(patch_size=(32, 32)),
            RandomAffine(),
            Intensity(),
            RandomFlip(modes=['horflip']),
            Noisify(),
            ToTensor()
        ])  # default is a 32x32 patch
        data_transforms = {
            'train': train_transforms,
            'validation': train_transforms,
            'test': Compose([CenterCrop(160, 160),
                             ToTensor()])
        }
        mode = "train"
        data, dsizes = get_data(data_transforms=data_transforms,
                                db="concat",
                                e2d=True,
                                batch_size=50 + 150 * (mode != "test"))
        print("Dataset sizes: {}".format(dsizes))
        for o in orientations:
            batch = next(iter(data[o][mode]))
            display_batch(batch, o + " concat " + mode + " data")
        plt.show()
    else:
        view_volumes(dataset_name, wait=wait)
    print("All tests completed!")
Example #3
def load_data(datadir, img_size=416, crop_pct=0.875):
    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    scale_size = int(math.floor(img_size / crop_pct))

    print("Loading training data")
    st = time.time()
    train_set = VOCDetection(datadir,
                             image_set='train',
                             download=True,
                             transforms=Compose([
                                 VOCTargetTransform(VOC_CLASSES),
                                 RandomResizedCrop((img_size, img_size),
                                                   scale=(0.3, 1.0)),
                                 RandomHorizontalFlip(), convert_to_relative,
                                 ImageTransform(
                                     transforms.ColorJitter(brightness=0.3,
                                                            contrast=0.3,
                                                            saturation=0.1,
                                                            hue=0.02)),
                                 ImageTransform(transforms.ToTensor()),
                                 ImageTransform(normalize)
                             ]))

    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    val_set = VOCDetection(datadir,
                           image_set='val',
                           download=True,
                           transforms=Compose([
                               VOCTargetTransform(VOC_CLASSES),
                               Resize(scale_size),
                               CenterCrop(img_size), convert_to_relative,
                               ImageTransform(transforms.ToTensor()),
                               ImageTransform(normalize)
                           ]))

    print("Took", time.time() - st)

    return train_set, val_set
Example #4
def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.3, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--limit', type=int, default=10000, help='number of images in epoch')
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--model',
        type=str,
        default='UNet',
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'AlbuNet34'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    num_classes = 1
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained=True)
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'AlbuNet34':
        model = AlbuNet34(num_classes=num_classes, pretrained=True)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    loss = LossBinary(jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names, shuffle=False, transform=None, limit=None):
        return DataLoader(dataset=AngyodysplasiaDataset(file_names,
                                                        transform=transform,
                                                        limit=limit),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split(args.fold)

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    train_transform = DualCompose([
        SquarePaddingTraining(),
        CenterCrop([574, 574]),
        HorizontalFlip(),
        VerticalFlip(),
        Rotate(),
        ImageOnly(RandomHueSaturationValue()),
        ImageOnly(Normalize())
    ])

    val_transform = DualCompose([
        SquarePaddingTraining(),
        CenterCrop([574, 574]),
        ImageOnly(Normalize())
    ])

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform,
                               limit=args.limit)
    valid_loader = make_loader(val_file_names, transform=val_transform)

    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=validation_binary,
                fold=args.fold)
Example #5
def main(args):

    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    crop_pct = 0.875
    scale_size = int(math.floor(args.img_size / crop_pct))

    train_loader, val_loader = None, None

    if not args.test_only:
        st = time.time()
        train_set = VOCDetection(datadir,
                                 image_set='train',
                                 download=True,
                                 transforms=Compose([
                                     VOCTargetTransform(VOC_CLASSES),
                                     RandomResizedCrop(
                                         (args.img_size, args.img_size),
                                         scale=(0.3, 1.0)),
                                     RandomHorizontalFlip(),
                                     convert_to_relative,
                                     ImageTransform(
                                         T.ColorJitter(brightness=0.3,
                                                       contrast=0.3,
                                                       saturation=0.1,
                                                       hue=0.02)),
                                     ImageTransform(T.ToTensor()),
                                     ImageTransform(normalize)
                                 ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            collate_fn=collate_fn,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCDetection(datadir,
                               image_set='val',
                               download=True,
                               transforms=Compose([
                                   VOCTargetTransform(VOC_CLASSES),
                                   Resize(scale_size),
                                   CenterCrop(args.img_size),
                                   convert_to_relative,
                                   ImageTransform(T.ToTensor()),
                                   ImageTransform(normalize)
                               ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            collate_fn=collate_fn,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = detection.__dict__[args.model](args.pretrained,
                                           num_classes=len(VOC_CLASSES),
                                           pretrained_backbone=True)

    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model_params,
                                     args.lr,
                                     betas=(0.95, 0.99),
                                     eps=1e-6,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'ranger':
        optimizer = Lookahead(
            holocron.optim.RAdam(model_params,
                                 args.lr,
                                 betas=(0.95, 0.99),
                                 eps=1e-6,
                                 weight_decay=args.weight_decay))
    elif args.opt == 'tadam':
        optimizer = holocron.optim.TAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)

    trainer = DetectionTrainer(model, train_loader, val_loader, None,
                               optimizer, args.device, args.output_file)

    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Loc error: {eval_metrics['loc_err']:.2%} | Clf error: {eval_metrics['clf_err']:.2%} | "
            f"Det error: {eval_metrics['det_err']:.2%}")
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
Example #6
from prepare_train_val import get_split
from dataset import Polyp
import cv2
from models import UNet, UNet11, UNet16, AlbuNet34, MDeNet, EncDec, hourglass, MDeNetplus
import torch
from pathlib import Path
from tqdm import tqdm
import numpy as np
import utils
# import prepare_data
from torch.utils.data import DataLoader
from torch.nn import functional as F

from transforms import (ImageOnly, Normalize, CenterCrop, DualCompose)

img_transform = DualCompose([CenterCrop(512), ImageOnly(Normalize())])


def get_model(model_path, model_type):
    """

    :param model_path:
    :param model_type: 'UNet', 'UNet11', 'UNet16', 'AlbuNet34'
    :return:
    """

    num_classes = 1

    if model_type == 'UNet11':
        model = UNet11(num_classes=num_classes)
    elif model_type == 'UNet16':
        model = UNet16(num_classes=num_classes)  # assumed completion; the example is truncated here and this mirrors the UNet11 branch
Example #7
from pathlib import Path
from tqdm import tqdm
import numpy as np
import utils
# import prepare_data
from torch.utils.data import DataLoader
from torch.nn import functional as F
import torch

from transforms import (ImageOnly,
                        Normalize,
                        CenterCrop,
                        DualCompose)

img_transform = DualCompose([
    CenterCrop(512),
    ImageOnly(Normalize())
])


def get_model(model_path, model_type):
    """

    :param model_path:
    :param model_type: 'UNet', 'UNet11', 'UNet16', 'AlbuNet34'
    :return:
    """

    num_classes = 1

    if model_type == 'UNet11':
        model = UNet11(num_classes=num_classes)  # assumed completion; the example is truncated here (cf. Example #6)
Example #8
import logging

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == '__main__':


    tbwriter = SummaryWriter(log_dir=config.LOG_DIR_PATH)

    print("Initializing AlexNet model")
    alexnet = alexnetmodel.AlexNet(num_classes=config.PARAMETERS['NUM_CLASSES'])
    alexnet = torch.nn.parallel.DataParallel(alexnet, device_ids=config.DEVICE_IDS)
    print("Initilaizing transfrmations to apply on the image")
    transformations = tf.Compose([
        CenterCrop(227),
        # HorizontalFlip(),
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    print("Generating train images and labels")
    train_X, train_Y, enc = separate_dataset_and_classes()
    print("Generating Validation images and labels")
    val_X, val_Y = separate_classes_val_test(enc=enc)
    
    print("Initilazing data loader for Alexnet")
    train_dataset = alexnetdataloader.AlexNetDataLoader(train_X, train_Y, transform = transformations)
    val_dataset = alexnetdataloader.AlexNetDataLoader(val_X, val_Y, None)
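
Example #8 stops after building the datasets; a plausible continuation (a sketch only; 'BATCH_SIZE' is an assumed key in config.PARAMETERS, not shown above) would wrap them in loaders:

    # Hypothetical continuation; the batch size key is an assumption.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.PARAMETERS['BATCH_SIZE'],
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.PARAMETERS['BATCH_SIZE'],
                                             shuffle=False)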
Example #9
from train_results import TrainResults
from transforms import CenterCrop, ToTensor, Compose, Resize, ToNumpy, ReturnPatch, RandomFlip, Intensity, Noisify, RandomAffine
import multiprocessing as mp
from utils import check_name, parse_argv, plots
from sys import argv
import matplotlib.pyplot as plt
import torch

plt.rcParams.update({'font.size': 16})

display_volume, train, volume, db_name, notest, study_ths, wait, finetune = parse_argv(argv)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

loss = DICELoss(apply_sigmoid=False)
metric = DICEMetric(apply_sigmoid=False)
#train_transforms = Compose([ReturnPatch(patch_size=(32, 32)), Intensity(), RandomFlip(modes=['horflip']), Noisify(), ToTensor()]) #default is 32 32 patch, arg* in sheet
train_transforms = Compose([ReturnPatch(patch_size=(32, 32)), RandomAffine(), Intensity(), RandomFlip(modes=['horflip']), Noisify(), ToTensor()])
data_transforms = {'train': train_transforms, 'validation': train_transforms, 'test': Compose([CenterCrop(160, 160), ToTensor()])}

print("Train transforms: {}".format(train_transforms))

bnames = []
'''
basename = "FLOAT0.01-DICE-120"
bnames.append(basename)
basename = "full-newdice-test"  # trained in full data
bnames.append(basename)
basename = "tPATCH320.001-VOLDICE-400" # patch approach, 32, flip only, only positive patch
bnames.append(basename)
basename = "INTFLIP-PATCH320.001-400"
bnames.append(basename)
basename = "RES-INTFLIP-PATCH320.001-400"
bnames.append(basename)
Example #10
def test(FLG):
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])
    report = [ScoreReport() for _ in range(FLG.fold)]
    overall_report = ScoreReport()
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'))

    with open(FLG.model + '_stat.pkl', 'rb') as f:
        stat = pickle.load(f)
    summary = Summary(port=10001, env=str(FLG.model) + 'CAM')

    class Feature(object):
        def __init__(self):
            self.blob = None

        def capture(self, blob):
            self.blob = blob

    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels), name=FLG.model)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels), FLG.model)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels), FLG.model)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels), FLG.model)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels), FLG.model)
    else:
        raise NotImplementedError(FLG.model)
    model.to(device)

    ad_h = []
    nl_h = []
    adcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
    nlcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
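    # sb: descending score thresholds used below to bin CAMs by confidence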
    sb = [9.996e-01, 6.3e-01, 1.001e-01]
    for running_fold in range(FLG.fold):
        _, validblock, _ = fold_split(
            FLG.fold, running_fold, FLG.labels,
            np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)
        validset = ADNIDataset(FLG.labels,
                               pjoin(FLG.data_root, FLG.modal),
                               validblock,
                               target_dict,
                               transform=transform_presets(FLG.augmentation))
        validloader = DataLoader(validset, pin_memory=True)

        epoch, _ = load_checkpoint(model, FLG.checkpoint_root, running_fold,
                                   FLG.model, None, True)
        model.eval()
        feature = Feature()

        def hook(mod, inp, oup):
            return feature.capture(oup.data.cpu().numpy())

        _ = model.layer4.register_forward_hook(hook)
        fc_weights = model.fc.weight.data.cpu().numpy()
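        # CAM recipe: the forward hook captures the last conv block's
        # feature maps, and the fc weights for a class weight those maps
        # into a class activation map (Zhou et al., 2016).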

        transformer = Compose([CenterCrop((112, 144, 112)), ToFloatTensor()])
        im, _ = original_load(validblock, target_dict, transformer, device)

        for image, target in validloader:
            true = target
            npatches = 1
            if len(image.shape) == 6:
                _, npatches, c, x, y, z = image.shape
                image = image.view(-1, c, x, y, z)
                target = torch.stack([target
                                      for _ in range(npatches)]).squeeze()
            image = image.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)

            output = model(image)

            if npatches == 1:
                score = F.softmax(output, dim=1)
            else:
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)

            report[running_fold].update_true(true)
            report[running_fold].update_score(score)

            overall_report.update_true(true)
            overall_report.update_score(score)

            print(target)
            if FLG.cam:
                s = 0
                cams = []
                if target[0] == 0:
                    s = score[0][0]
                    #s = s.cpu().numpy()[()]
                    cams = adcams
                else:
                    s = score[0][1]
                    #s = s.cpu().numpy()[()]
                    cams = nlcams
                if s > sb[0]:
                    cams[0] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[0],
                                            s,
                                            num_images=5)
                elif s > sb[1]:
                    cams[1] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[1],
                                            s,
                                            num_images=5)
                elif s > sb[2]:
                    cams[2] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[2],
                                            s,
                                            num_images=5)
                else:
                    cams[3] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[3],
                                            s,
                                            num_images=5)
                #ad_h += [s]
                #nl_h += [sn]

        print('At {}'.format(epoch))
        print(
            metrics.classification_report(report[running_fold].y_true,
                                          report[running_fold].y_pred,
                                          target_names=FLG.labels,
                                          digits=4))
        print('accuracy {}'.format(report[running_fold].accuracy))

    #print(np.histogram(ad_h))
    #print(np.histogram(nl_h))

    print('overall')
    print(
        metrics.classification_report(overall_report.y_true,
                                      overall_report.y_pred,
                                      target_names=FLG.labels,
                                      digits=4))
    print('accuracy {}'.format(overall_report.accuracy))

    with open(FLG.model + '_stat.pkl', 'wb') as f:
        pickle.dump(report, f, pickle.HIGHEST_PROTOCOL)
Example #11
def test_cam():
    with open(FLG.model + '_stat.pkl', 'rb') as f:
        stat = pickle.load(f)
    summary = Summary(port=10001, env=str(FLG.model) + 'CAM')

    class Feature(object):
        def __init__(self):
            self.blob = None

        def capture(self, blob):
            self.blob = blob

    # TODO: create model
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])
    report = [ScoreReport() for _ in range(FLG.fold)]
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'))

    model = Plane(len(FLG.labels), name=FLG.model)
    model.to(device)

    transformer = Compose([CenterCrop((112, 144, 112)), ToFloatTensor()])

    def original_load(validblock):
        originalset = ADNIDataset(FLG.labels,
                                  pjoin(FLG.data_root, 'spm_normalized'),
                                  validblock,
                                  target_dict,
                                  transform=transformer)
        originloader = DataLoader(originalset, pin_memory=True)
        for image, target in originloader:
            if len(image.shape) == 6:
                _, npatches, c, x, y, z = image.shape
                image = image.view(-1, c, x, y, z)
                target = torch.stack([target
                                      for _ in range(npatches)]).squeeze()
            image = image.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)
            break
        return image, target

    hadcams = np.zeros((3, 112, 144, 112), dtype="f8")
    madcams = np.zeros((3, 112, 144, 112), dtype="f8")
    sadcams = np.zeros((3, 112, 144, 112), dtype="f8")
    zadcams = np.zeros((3, 112, 144, 112), dtype="f8")

    nlcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
    sb = [4.34444371e-16, 1.67179015e-18, 4.08813312e-23]
    #im, _ = original_load(validblock)
    for running_fold in range(FLG.fold):
        # validset
        _, validblock, _ = fold_split(
            FLG.fold, running_fold, FLG.labels,
            np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)
        validset = ADNIDataset(FLG.labels,
                               pjoin(FLG.data_root, FLG.modal),
                               validblock,
                               target_dict,
                               transform=transformer)
        validloader = DataLoader(validset, pin_memory=True)

        load_checkpoint(model,
                        FLG.checkpoint_root,
                        running_fold,
                        FLG.model,
                        epoch=None,
                        is_best=True)
        model.eval()
        feature = Feature()

        def hook(mod, inp, oup):
            return feature.capture(oup.data.cpu().numpy())

        _ = model.layer4.register_forward_hook(hook)
        fc_weights = model.fc.weight.data.cpu().numpy()

        im, _ = original_load(validblock)
        ad_s = []
        for image, target in validloader:
            true = target
            npatches = 1
            if len(image.shape) == 6:
                _, npatches, c, x, y, z = image.shape
                image = image.view(-1, c, x, y, z)
                target = torch.stack([target
                                      for _ in range(npatches)]).squeeze()
            image = image.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)

            #_ = model(image.view(*image.shape))
            output = model(image)

            if npatches == 1:
                score = F.softmax(output, dim=1)
            else:
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)

            sa = score[0][1]
            #name = 'k'+str(running_fold)
            sa = sa.cpu().numpy()[()]
            print(score, score.shape)
            if true == torch.tensor([0]):
                if sa > sb[0]:
                    hadcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            hadcams,
                                            sa,
                                            num_images=5)
                elif sa > sb[1]:
                    madcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            madcams,
                                            sa,
                                            num_images=5)
                elif sa > sb[2]:
                    sadcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            sadcams,
                                            sa,
                                            num_images=5)
                else:
                    zadcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            zadcams,
                                            sa,
                                            num_images=5)
            else:
                if sa > sb[0]:
                    nlcams[0] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[0],
                                              sa,
                                              num_images=5)
                elif sa > sb[1]:
                    nlcams[1] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[1],
                                              sa,
                                              num_images=5)
                elif sa > sb[2]:
                    nlcams[2] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[2],
                                              sa,
                                              num_images=5)
                else:
                    nlcams[3] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[3],
                                              sa,
                                              num_images=5)
            ad_s += [sa]
        print('histogram', np.histogram(ad_s))
Example #12
def run_once(volpath, models):
    '''
    Runs our best model on a provided volume and saves the mask,
    in a self-contained manner.
    '''
    print(
        "\nALPHA VERSION: For this version of the code, the provided volume should return slices in the following way for optimal performance:"
    )
    print("volume[0, :, :] sagittal, eyes facing down")
    print("volume[:, 0, :] coronal")
    print("volume[:, :, 0] axial, with eyes facing right\n")
    begin = time.time()
    save_path = volpath + "_e2dhipmask.nii.gz"
    device = get_device()
    orientations = ["sagital", "coronal", "axial"]
    CROP_SHAPE = 160
    slice_transform = Compose([CenterCrop(CROP_SHAPE, CROP_SHAPE), ToTensor()])

    sample_v = normalizeMri(nib.load(volpath).get_fdata().astype(np.float32))
    shape = sample_v.shape
    sum_vol_total = torch.zeros(shape)

    for o, model in models.items():
        model.eval()
        model.to(device)

    print("Performing segmentation...")
    for i, o in enumerate(orientations):
        try:
            slice_shape = myrotate(get_slice(sample_v, 0, o), 90).shape
            for j in range(shape[i]):
                # E2D
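                # ("extended 2D": the slice of interest and its two
                # neighbours are stacked as three input channels, giving
                # the 2D network limited through-plane context)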
                ts = np.zeros((3, slice_shape[0], slice_shape[1]),
                              dtype=np.float32)
                for ii, jj in enumerate(range(j - 1, j + 2)):
                    if jj < 0:
                        jj = 0
                    elif jj == shape[i]:
                        jj = shape[i] - 1

                    if i == 0:
                        ts[ii] = myrotate(sample_v[jj, :, :], 90)
                    elif i == 1:
                        ts[ii] = myrotate(sample_v[:, jj, :], 90)
                    elif i == 2:
                        ts[ii] = myrotate(sample_v[:, :, jj], 90)

                s, _ = slice_transform(ts, ts[1])  # workaround: the transform expects (image, mask), so pass a dummy mask
                s = s.to(device)

                probs = models[o](s.unsqueeze(0))

                cpup = probs.squeeze().detach().cpu()
                finalp = torch.from_numpy(myrotate(
                    cpup.numpy(), -90)).float()  # back to volume orientation

                # Add to final consensus volume, uses original orientation/shape
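                # F.pad takes (left, right, top, bottom) for the last two
                # dims: the cropped prediction is padded back to the full
                # slice size (one extra pixel on one side for odd shapes)
                # and divided by 3 so the three orientations average out.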
                if i == 0:
                    toppad = shape[1] // 2 - CROP_SHAPE // 2
                    sidepad = shape[2] // 2 - CROP_SHAPE // 2

                    tf = 1 if shape[1] % 2 == 1 else 0
                    sf = 1 if shape[2] % 2 == 1 else 0
                    pad = F.pad(
                        finalp,
                        (sidepad + sf, sidepad, toppad, toppad + tf)) / 3

                    sum_vol_total[j, :, :] += pad
                elif i == 1:
                    toppad = shape[0] // 2 - CROP_SHAPE // 2
                    sidepad = shape[2] // 2 - CROP_SHAPE // 2

                    tf = 1 if shape[0] % 2 == 1 else 0
                    sf = 1 if shape[2] % 2 == 1 else 0
                    pad = F.pad(
                        finalp,
                        (sidepad + sf, sidepad, toppad, toppad + tf)) / 3

                    sum_vol_total[:, j, :] += pad
                elif i == 2:
                    toppad = shape[0] // 2 - CROP_SHAPE // 2
                    sidepad = shape[1] // 2 - CROP_SHAPE // 2

                    tf = 1 if shape[0] % 2 == 1 else 0
                    sf = 1 if shape[1] % 2 == 1 else 0
                    pad = F.pad(
                        finalp,
                        (sidepad + sf, sidepad, toppad, toppad + tf)) / 3

                    sum_vol_total[:, :, j] += pad

        except Exception as e:
            print(
                "Error: {}, make sure your data is ok, please contact author https://github.com/dscarmo"
                .format(e))
            traceback.print_exc()
            quit()

    final_nppred = get_largest_components(sum_vol_total.numpy(), mask_ths=0.5)

    print("Processing took {}s".format(time.time() - begin))
    print("Saving to {}".format(save_path))
    nib.save(nib.nifti1.Nifti1Image(final_nppred, None), save_path)
    return sample_v, final_nppred
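
A hedged invocation sketch for run_once: it needs a path to a NIfTI volume and a dict mapping each orientation to a trained model. load_models below is a hypothetical helper, not part of the example.

models = load_models()  # hypothetical: {"sagital": net_s, "coronal": net_c, "axial": net_a}
volume, prediction = run_once("subject01.nii.gz", models)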
Example #13
import network, dataset, criterion, transform
from dataset import VOC12
from network import PSPNet, FCN8, SegNet, FCN8s
from criterion import CrossEntropyLoss2d
from transform import Relabel, ToLabel, Colorize
import deeplab_resnet
import torch.nn.functional as F
from accuracy_metrics import pixel_accuracy, mean_accuracy, mean_IU
from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor, ToPILImage

NUM_CHANNELS = 3
NUM_CLASSES = 22  # 6 for BraTS

color_transform = Colorize()
image_transform = ToPILImage()
input_transform = Compose([
    CenterCrop(256),
    #Scale(240),
    ToTensor(),
    Normalize([.485, .456, .406], [.229, .224, .225]),
])

input_transform1 = Compose([
    CenterCrop(256),
    #Scale(136),
    #ToTensor(),
])

target_transform = Compose([
    CenterCrop(256),
    #Scale(240),
    ToLabel(),