Example #1
def getDataset(data_path, subsets, aug_t):

    # Get path to image files
    image_path_file = data_path + "imageset_path.txt"
    with open(image_path_file) as f:
        entries = f.read().splitlines()
    images_path = entries[0]

    img_names = []
    img_labels = []

    # Collect image names and labels for each subset
    for k in subsets:

        # Load image names for subset
        img_names_file_k = data_path + "subsets/subset_{}_ids.txt".format(k)
        with open(img_names_file_k) as f:
            img_names_k = f.read().splitlines()

        img_label_file_k = data_path + "subsets/subset_{}_labels.txt".format(k)
        with open(img_label_file_k) as f:
            img_labels_k = f.read().splitlines()

        img_names.extend(img_names_k)
        img_labels.extend(img_labels_k)

    img_paths = [images_path + f + ".npy" for f in img_names]
    img_labels = np.array(img_labels)

    dataset = VolumeDataset(img_paths, img_labels, aug_t)

    return (dataset, img_names)
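
A minimal usage sketch for this loader; the data directory, subset indices, and the torch DataLoader wrapping below are illustrative assumptions, not part of the example:

import numpy as np
from torch.utils.data import DataLoader

aug_t = np.array((0, 0, 0, 0))  # zero augmentation parameters, as in Example #3
train_dataset, train_names = getDataset("/data/volumes/", subsets=[1, 2, 3], aug_t=aug_t)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)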
Example #2
    def __init__(self, dspec_path, net_spec, params, auto_mask=True):
        """
        Initialize DataProvider.

        Args:
            dspec_path: Path to the dataset specification file.
            net_spec:   Net specification.
            params:     Various options.
            auto_mask:  Whether to automatically generate mask from
                        corresponding label.
        """
        # Params.
        drange = params['drange']  # Required.
        dprior = params.get('dprior', None)  # Optional.

        # Build Datasets.
        print('\n[VolumeDataProvider]')
        p = parser.Parser(dspec_path, net_spec, params, auto_mask=auto_mask)
        self.datasets = list()
        for d in drange:
            print('constructing dataset %d...' % d)
            config, dparams = p.parse_dataset(d)
            dataset = VolumeDataset(config, **dparams)
            self.datasets.append(dataset)

        # Sampling weight.
        self.set_sampling_weights(dprior)

        # Setup data augmentation.
        aug_spec = params.get('augment', [])  # Default is an empty list.
        self._data_aug = DataAugmentor(aug_spec)
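
An illustrative params dict for this constructor, based only on the keys read above; the values, the spec file name, and net_spec are placeholders:

params = {
    'drange': [0, 1, 2],         # required: IDs of the datasets to construct
    'dprior': [0.5, 0.3, 0.2],   # optional: per-dataset sampling weights
    'augment': [],               # optional: data-augmentation spec (defaults to [])
}
provider = VolumeDataProvider('train.spec', net_spec, params, auto_mask=True)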
Example #3
def getDataset(path_gt, path_img, path_in):

    # Get image files
    img_files = [
        f for f in os.listdir(path_img)
        if os.path.isfile(os.path.join(path_img, f))
    ]
    img_files = [f for f in img_files if f.endswith(".npy")]

    ids_img = [f.split(".")[0] for f in img_files]
    ids_img = np.array(ids_img).astype("int")

    # Get ground truth values
    with open(path_gt + "gt.txt") as f:
        entries = f.read().splitlines()
    entries.pop(0)  # drop the header line

    ids = [f.split(",")[0] for f in entries]
    values = [f.split(",")[1] for f in entries]

    ids = np.array(ids).astype("int")

    # Check if images exist for all ids from gt
    mask = np.invert(np.in1d(ids, ids_img))
    ids_missing = ids[mask]

    if len(ids_missing) > 0:
        print("ERROR: No image found for ids: {}".format(ids_missing))
        sys.exit()

    # Build image paths in ground-truth order
    img_paths = ["{}{}.npy".format(path_img, f) for f in ids]

    # Ground-truth values are used as floats
    values = np.array(values).astype("float")

    # Optionally standardize the labels
    path_std = path_in + "standardization.txt"
    if os.path.exists(path_std):
        print("Applying label standardization...")
        (std_mean, std_stdev) = readStandardization(path_std)
        values = (values - std_mean) / std_stdev

    aug_t = np.array((0, 0, 0, 0))
    dataset = VolumeDataset(img_paths, values, aug_t)

    return (dataset, img_paths)
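
readStandardization is not shown in this example; a plausible sketch, assuming standardization.txt stores the mean and standard deviation as two comma-separated numbers:

def readStandardization(path_std):
    # Parse "mean,stdev" from the standardization file (assumed format).
    with open(path_std) as f:
        std_mean, std_stdev = [float(v) for v in f.read().strip().split(",")]
    return (std_mean, std_stdev)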
Example #4
    checkpoint = torch.load(args.init_model, map_location={'cuda:0': 'cpu'})
    model.load_state_dict(checkpoint['state_dict'])

if use_gpu:
    model.cuda()
    cudnn.benchmark = True

# optimizer
optimizerSs = optim.Adam(model.parameters(), lr=args.learning_rate)

# loss function
criterionSs = nn.CrossEntropyLoss()
if use_gpu:
    criterionSs.cuda()

volume_dataset = VolumeDataset(rimg_in=None, cimg_in=args.train_t1w, bmsk_in=args.train_msk)
volume_loader = DataLoader(dataset=volume_dataset, batch_size=1, shuffle=True, num_workers=0)

blk_batch_size = 20

if not os.path.exists(args.out_dir):
    os.mkdir(args.out_dir)

# Init Dice and Loss Dict
DL_Dict = dict()
dice_list = list()
loss_list = list()

if use_validate:
    valid_model = nn.Sequential(model, nn.Softmax2d())
    dice_dict = predict_volumes(valid_model, rimg_in=None, cimg_in=args.validate_t1w, bmsk_in=args.validate_msk,
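
The snippet above is cut off mid-call. A minimal sketch of the training step it is building toward, reusing the BlockDataset API from Example #5; the axis choice, tensor shapes, and loss-target handling are assumptions:

model.train()
for cimg, bmsk in volume_loader:  # one (T1w image, brain mask) volume per batch
    block_dataset = BlockDataset(rimg=cimg, bfld=None, bmsk=bmsk,
                                 num_slice=3, rescale_dim=256)
    block_data, slice_list, slice_weight = block_dataset.get_one_directory(axis=0)
    for rimg_blk, bmsk_blk in block_data:
        if use_gpu:
            rimg_blk, bmsk_blk = rimg_blk.cuda(), bmsk_blk.cuda()
        optimizerSs.zero_grad()
        logits = model(torch.unsqueeze(rimg_blk, 0))  # [1, num_slice, H, W] assumed
        # assumes bmsk_blk is an [H, W] label slice for CrossEntropyLoss
        loss = criterionSs(logits, torch.unsqueeze(bmsk_blk, 0).long())
        loss.backward()
        optimizerSs.step()
        loss_list.append(loss.item())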
Example #5
def predict_volumes(model,
                    rimg_in=None,
                    cimg_in=None,
                    bmsk_in=None,
                    suffix="pre_mask",
                    save_dice=False,
                    save_nii=False,
                    nii_outdir=None,
                    verbose=False,
                    rescale_dim=256,
                    num_slice=3):
    use_gpu = torch.cuda.is_available()
    model_on_gpu = next(model.parameters()).is_cuda
    use_bn = True
    if use_gpu:
        if not model_on_gpu:
            model.cuda()
    else:
        if model_on_gpu:
            model.cpu()

    NoneType = type(None)
    if isinstance(rimg_in, NoneType) and isinstance(cimg_in, NoneType):
        print("Input rimg_in or cimg_in")
        sys.exit(1)

    if save_dice:
        dice_dict = dict()

    volume_dataset = VolumeDataset(rimg_in=rimg_in,
                                   cimg_in=cimg_in,
                                   bmsk_in=bmsk_in)
    volume_loader = DataLoader(dataset=volume_dataset, batch_size=1)

    for idx, vol in enumerate(volume_loader):
        if len(vol) == 1:  # just img
            ptype = 1  # Predict
            cimg = vol
            bmsk = None
            block_dataset = BlockDataset(rimg=cimg,
                                         bfld=None,
                                         bmsk=None,
                                         num_slice=num_slice,
                                         rescale_dim=rescale_dim)
        elif len(vol) == 2:  # img & msk
            ptype = 2  # image test
            cimg = vol[0]
            bmsk = vol[1]
            block_dataset = BlockDataset(rimg=cimg,
                                         bfld=None,
                                         bmsk=bmsk,
                                         num_slice=num_slice,
                                         rescale_dim=rescale_dim)
        elif len(vol) == 3:  # img, bias field & msk
            ptype = 3  # image bias correction test
            cimg = vol[0]
            bfld = vol[1]
            bmsk = vol[2]
            block_dataset = BlockDataset(rimg=cimg,
                                         bfld=bfld,
                                         bmsk=bmsk,
                                         num_slice=num_slice,
                                         rescale_dim=rescale_dim)
        else:
            print("Invalid Volume Dataset!")
            sys.exit(2)

        rescale_shape = block_dataset.get_rescale_shape()
        raw_shape = block_dataset.get_raw_shape()

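        # Predict along each of the three orthogonal slicing axes; the per-axis
        # probability maps are averaged after this loop.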
        for od in range(3):
            backward_ind = np.arange(3)
            backward_ind = np.insert(np.delete(backward_ind, 0), od, 0)

            block_data, slice_list, slice_weight = block_dataset.get_one_directory(
                axis=od)
            pr_bmsk = torch.zeros(
                [len(slice_weight), rescale_dim, rescale_dim])
            if use_gpu:
                pr_bmsk = pr_bmsk.cuda()
            for (i, ind) in enumerate(slice_list):
                if ptype == 1:
                    rimg_blk = block_data[i]
                    if use_gpu:
                        rimg_blk = rimg_blk.cuda()
                elif ptype == 2:
                    rimg_blk, bmsk_blk = block_data[i]
                    if use_gpu:
                        rimg_blk = rimg_blk.cuda()
                        bmsk_blk = bmsk_blk.cuda()
                else:
                    rimg_blk, bfld_blk, bmsk_blk = block_data[i]
                    if use_gpu:
                        rimg_blk = rimg_blk.cuda()
                        bfld_blk = bfld_blk.cuda()
                        bmsk_blk = bmsk_blk.cuda()
                pr_bmsk_blk = model(torch.unsqueeze(Variable(rimg_blk), 0))
                pr_bmsk[ind[1], :, :] = pr_bmsk_blk.data[0][1, :, :]

            if use_gpu:
                pr_bmsk = pr_bmsk.cpu()

            pr_bmsk = pr_bmsk.permute(backward_ind[0], backward_ind[1],
                                      backward_ind[2])
            pr_bmsk = pr_bmsk[:rescale_shape[0],
                              :rescale_shape[1],
                              :rescale_shape[2]]
            uns_pr_bmsk = torch.unsqueeze(pr_bmsk, 0)
            uns_pr_bmsk = torch.unsqueeze(uns_pr_bmsk, 0)
            uns_pr_bmsk = nn.functional.interpolate(uns_pr_bmsk,
                                                    size=raw_shape,
                                                    mode="trilinear",
                                                    align_corners=False)
            pr_bmsk = torch.squeeze(uns_pr_bmsk)

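            # Stack this axis's prediction along a new 4th dimension for later averaging.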
            if od == 0:
                pr_3_bmsk = torch.unsqueeze(pr_bmsk, 3)
            else:
                pr_3_bmsk = torch.cat((pr_3_bmsk, torch.unsqueeze(pr_bmsk, 3)),
                                      dim=3)

        pr_bmsk = pr_3_bmsk.mean(dim=3)

        pr_bmsk = pr_bmsk.numpy()
        pr_bmsk_final = extract_large_comp(pr_bmsk > 0.5)

        if isinstance(bmsk, torch.Tensor):
            bmsk = bmsk.data[0].numpy()
            dice = estimate_dice(bmsk, pr_bmsk_final)
            if verbose:
                print(dice)

        t1w_nii = volume_dataset.getCurCimgNii()
        t1w_path = t1w_nii.get_filename()
        t1w_dir, t1w_file = os.path.split(t1w_path)
        t1w_name = os.path.splitext(t1w_file)[0]
        t1w_name = os.path.splitext(t1w_name)[0]

        if save_nii:
            t1w_aff = t1w_nii.affine
            t1w_shape = t1w_nii.shape

            if isinstance(nii_outdir, NoneType):
                nii_outdir = t1w_dir

            if not os.path.exists(nii_outdir):
                os.mkdir(nii_outdir)
            out_path = os.path.join(nii_outdir,
                                    t1w_name + "_" + suffix + ".nii.gz")
            write_nifti(np.array(pr_bmsk_final, dtype=np.float32), t1w_aff,
                        t1w_shape, out_path)

        if save_dice:
            dice_dict[t1w_name] = dice

    if save_dice:
        return dice_dict
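
A hypothetical call for mask prediction with Dice evaluation, assuming model is a trained network such as the one set up in Example #4; the file names and output directory are placeholders:

dice_scores = predict_volumes(model,
                              cimg_in="sub-01_T1w.nii.gz",
                              bmsk_in="sub-01_brainmask.nii.gz",
                              save_dice=True,
                              save_nii=True,
                              nii_outdir="predicted_masks",
                              verbose=True)
print(dice_scores)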