示例#1
0
    def __init__(self, metadata, eval, remove_white):
        """Build a datalist of regions from a single whole-slide image.

        Args:
            metadata: dict mapping region id -> region dict. All regions are
                assumed to come from the same slide; 'wsipath' and
                'scan_level' are read from the first entry only.
            eval: truthy for evaluation mode (selects augmentation pipeline).
            remove_white: when true, discard sample points that fall on
                white background, using a low-resolution nuclei mask.
        """
        self.base_path = Path(__file__).parent
        first_region_id = list(metadata.keys())[0]
        pth = metadata[first_region_id]['wsipath']
        pth = (self.base_path / pth).resolve().as_posix()
        self.scan = openslide.OpenSlide(pth)

        if remove_white:
            # Get low-res nuclei image / foreground mask.
            # BUG FIX: this class stores its slide in `self.scan`; the
            # original read `self.wsis[pth]`, an attribute never defined
            # here (copied from the multi-slide variant of this loader).
            scan = self.scan
            x, y = scan.level_dimensions[-1]
            mask = scan.read_region((0, 0), scan.level_count - 1,
                                    (x, y)).convert('RGB')
            # Shrink before nuclei detection for speed, then upsample the
            # mask back to the lowest-level dimensions.
            mask = mask.resize((x // 4, y // 4))
            mask = preprocessing.find_nuclei(mask)
            mask = Image.fromarray(mask.astype(np.uint8)).resize((x, y))
            mask = np.asarray(mask)

        params = {
            'iw': self.scan.level_dimensions[0][0],
            'ih': self.scan.level_dimensions[0][1],
            'tile_w': HR_PATCH_W,
            'tile_h': HR_PATCH_H,
            'scan_level': metadata[first_region_id]['scan_level']
        }
        params = preprocessing.DotDict(params)

        # Build the datalist.
        self.datalist = []
        for key in metadata:
            region_obj = metadata[key].copy()

            if remove_white:
                # Given points, remove patches that are only white.
                region_obj[
                    'cnt_xy'], num_cnt_pts = regiontools.remove_white_region(
                        mask, region_obj['cnt_xy'], params)
                region_obj[
                    'perim_xy'], num_perim_pts = regiontools.remove_white_region(
                        mask, region_obj['perim_xy'], params)

            # Map points to the requested patch size / scan level combo.
            region_obj['cnt_xy'], num_cnt_pts = regiontools.map_points(
                region_obj['cnt_xy'], params)
            region_obj['perim_xy'], num_perim_pts = regiontools.map_points(
                region_obj['perim_xy'], params)

            # Keep only regions with enough usable sample points.
            if num_cnt_pts >= HR_NUM_CNT_SAMPLES and \
                    num_perim_pts >= HR_NUM_PERIM_SAMPLES:
                self.datalist.append(region_obj)

        self.eval = eval

        # Augmentation settings.
        self.image_aug = preprocessing.standard_augmentor(self.eval)
示例#2
0
    def __init__(self, impth, eval):
        """Collect (image, label) file pairs from a directory.

        Pairs every '*_image.png' under `impth` with its '*_gt.png'
        counterpart; in training mode each pair is repeated 10x so an
        epoch covers more iterations.
        """
        self.eval = eval
        # Augmentation settings.
        self.image_aug = preprocessing.standard_augmentor(eval)

        # Build the dataset: one entry per image / ground-truth pair.
        self.datalist = [{
            'image': image_path,
            'label': image_path.replace('_image.png', '_gt.png')
        } for image_path in glob.glob('{}/*_image.png'.format(impth))]

        # Training mode: repeat each entry 10 times (consecutively).
        if not self.eval:
            self.datalist = [
                item for item in self.datalist for _ in range(10)
            ]
示例#3
0
    def __init__(self, impth, eval, duplicate_dataset):
        """Build a tile datalist from the pickled gt.npy dictionary.

        Every tile entry carries a 'wsi' path and a 'label'; in training
        mode each entry is repeated `duplicate_dataset` times.
        """
        self.eval = eval
        # Augmentation settings.
        self.image_aug = preprocessing.standard_augmentor(eval)

        # Build the dataset from the ground-truth dictionary on disk.
        gt = np.load('{}/gt.npy'.format(impth), allow_pickle=True).flatten()[0]
        self.datalist = []
        for group in gt.values():
            for tile in group.values():
                self.datalist.append({
                    'wsi': tile['wsi'],
                    'label': tile['label'],
                })

        # Training mode: repeat each entry consecutively to lengthen epochs.
        if not self.eval:
            self.datalist = [
                item for item in self.datalist
                for _ in range(duplicate_dataset)
            ]
示例#4
0
    def __init__(self, pth, eval, remove_white, duplicate_dataset):
        """Build a mixed patch + WSI-region datalist from gt.npy metadata.

        Args:
            pth: dataset directory name; metadata is read from
                '../{pth}/gt.npy' relative to this file.
            eval: truthy for evaluation mode; controls augmentation and
                whether class ratios are recorded into args.cls_ratios.
            remove_white: when true, drop sample points lying on white
                background using a low-resolution nuclei mask.
            duplicate_dataset: training-mode repetition factor per entry.
        """
        self.base_path = Path(__file__).parent
        metadata_pth = (self.base_path /
                        '../{}/gt.npy'.format(pth)).resolve().as_posix()
        metadata = ufs.fetch_metadata(metadata_pth)
        '''
        dataset structure:
        dataset is comprised of patches+wsi regions.
        patches:
        metadata['P'] indicates where all the patches are.
        wsi:
        0. metadata[filename/svs file name]
        1. m[f][connected component id]
        2. m[f][c][region within the connected component]
        @ level 1, we have the connected component
        as given in gt mask. at this level m[f][c][0] 
        always points to the large region
        if the region is large enough, we then split it
        to smaller sub-regions at m[f][c][>=1].
        '''

        ' build the datalist '
        self.datalist = []
        # Per-class sample counts; normalized to ratios at the end.
        cls = np.zeros(args.num_classes, )

        ' build patch portion of ds '
        if 'P' in metadata:
            # Deep-copy so deleting 'P' from metadata below cannot affect
            # the patch dicts we keep.
            P = copy.deepcopy(metadata['P'][0])
            del metadata['P']

            # Cache key points per unique patch dimension, so
            # get_key_points_for_patch runs once per distinct size.
            P_dims = {}
            for key in P:
                d = P[key]['dimensions']

                if d not in P_dims:
                    params = {
                        'num_center_points': HR_NUM_CNT_SAMPLES,
                        'num_perim_points': HR_NUM_PERIM_SAMPLES,
                        'scan_level': HR_SCAN_LEVEL,
                        'tile_w': HR_PATCH_W,
                        'tile_h': HR_PATCH_H,
                        'dimensions': d
                    }
                    params = preprocessing.DotDict(params)
                    P_dims[d] = regiontools.get_key_points_for_patch(params)

                # Merge patch metadata with its precomputed key points.
                item = {**P[key], **P_dims[d]}

                self.datalist.append(item)
                cls[item['label']] += 1

        ' build wsi regions portion '
        # Map of resolved slide path -> open OpenSlide handle.
        self.wsis = {}

        for filename in metadata:
            # Slide path / scan level are taken from the first sub-region;
            # presumably all regions of one file share a slide — verify
            # against how gt.npy is generated.
            first_region_id = list(metadata[filename].keys())[0]
            first_sub_region_id = list(
                metadata[filename][first_region_id].keys())[0]
            pth = metadata[filename][first_region_id][first_sub_region_id][
                'wsipath']
            pth = (self.base_path / pth).resolve().as_posix()
            self.wsis[pth] = openslide.OpenSlide(pth)

            if remove_white:
                'get low res. nuclei image/foreground mask'
                scan = self.wsis[pth]
                x, y = scan.level_dimensions[-1]
                mask = scan.read_region((0, 0), scan.level_count - 1,
                                        (x, y)).convert('RGB')
                # Shrink before nuclei detection for speed, then resize the
                # mask back to the lowest-level dimensions.
                mask = mask.resize((x // 4, y // 4))
                mask = preprocessing.find_nuclei(mask)
                mask = Image.fromarray(mask.astype(np.uint8)).resize((x, y))
                mask = np.asarray(mask)

            params = {
                'iw':
                self.wsis[pth].level_dimensions[0][0],
                'ih':
                self.wsis[pth].level_dimensions[0][1],
                'tile_w':
                HR_PATCH_W,
                'tile_h':
                HR_PATCH_H,
                'scan_level':
                metadata[filename][first_region_id][first_sub_region_id]
                ['scan_level']
            }
            params = preprocessing.DotDict(params)

            for conncomp in metadata[filename]:
                for id in metadata[filename][conncomp]:
                    region_obj = metadata[filename][conncomp][id].copy()

                    if remove_white:
                        'given points, remove patches that are only white'
                        region_obj[
                            'cnt_xy'], num_cnt_pts = regiontools.remove_white_region(
                                mask, region_obj['cnt_xy'], params)
                        region_obj[
                            'perim_xy'], num_perim_pts = regiontools.remove_white_region(
                                mask, region_obj['perim_xy'], params)

                    'which points valid for this patch size, scan level combo?'
                    region_obj['cnt_xy'], num_cnt_pts = regiontools.map_points(
                        region_obj['cnt_xy'], params)
                    region_obj[
                        'perim_xy'], num_perim_pts = regiontools.map_points(
                            region_obj['perim_xy'], params)

                    # Keep regions with enough usable center/perimeter points.
                    if num_cnt_pts >= HR_NUM_CNT_SAMPLES and \
                            num_perim_pts >= HR_NUM_PERIM_SAMPLES:
                        self.datalist.append(region_obj)
                        cls[region_obj['label']] += 1

        self.eval = eval

        cls = np.array(cls)
        '''cls[0] += cls[1]
        cls[1] = cls[2]
        cls[2] = cls[3]
        cls[3] = 0'''

        # Normalize counts to ratios; recorded globally in training mode.
        print(cls)
        cls = cls / cls.sum()
        print(cls)
        if not self.eval:
            args.cls_ratios = cls

        ' augmentation settings '
        self.image_aug = preprocessing.standard_augmentor(self.eval)

        if not self.eval:
            from itertools import chain
            self.datalist = list(
                chain(*[[i] * duplicate_dataset for i in self.datalist]))
示例#5
0
    def __init__(self, wsipth, params):
        """Enumerate foreground patch coordinates over a whole-slide image.

        Slides with fewer levels than args.scan_level yield an empty
        datalist. Candidate positions cover an interior grid plus the
        right-edge column and bottom-edge row; positions whose window in
        the down-sampled tissue mask is background are discarded.
        """
        self.params = params

        # Build the dataset.
        self.datalist = []

        # Read the wsi scan.
        filename = os.path.basename(wsipth)
        self.scan = openslide.OpenSlide(wsipth)

        # If a slide has fewer levels than our desired scan level, skip it.
        if len(self.scan.level_dimensions) - 1 >= args.scan_level:

            self.params.iw, self.params.ih = self.scan.level_dimensions[
                args.scan_level]

            # gt mask: find_nuclei is slow, so masks computed during
            # preprocessing are cached on disk for later use.
            msk_pth = '{}/{}.png'.format(args.wsi_mask_pth, filename)
            if os.path.exists(msk_pth):
                mask = Image.open(msk_pth).convert('L')
                mask = np.asarray(mask)
            else:
                thumbnail = self.scan.read_region(
                    (0, 0), 2, self.scan.level_dimensions[2]).convert('RGB')
                mask = preprocessing.find_nuclei(thumbnail)
                Image.fromarray(mask.astype(np.uint8)).save(msk_pth)

            # Augmentation settings (always eval-mode here).
            self.image_aug = preprocessing.standard_augmentor(True)

            # Downsample multiplier between scan level and mask level (2).
            m = self.scan.level_downsamples[
                args.scan_level] / self.scan.level_downsamples[2]
            dx, dy = int(self.params.pw * m), int(self.params.ph * m)

            def _append_if_foreground(xpos, ypos):
                # Map the patch origin into mask coordinates and keep the
                # position only when the mask window contains tissue.
                yp, xp = int(ypos * m), int(xpos * m)
                if preprocessing.isforeground(mask[yp:yp + dy, xp:xp + dx]):
                    self.datalist.append((xpos, ypos))

            xs = range(1, self.params.iw - 1 - self.params.pw, self.params.sw)
            ys = range(1, self.params.ih - 1 - self.params.ph, self.params.sh)

            # Interior grid.
            for ypos in ys:
                for xpos in xs:
                    _append_if_foreground(xpos, ypos)

            # Right-edge column.
            right = self.params.iw - 1 - self.params.pw
            for ypos in ys:
                _append_if_foreground(right, ypos)

            # Bottom-edge row.
            bottom = self.params.ih - 1 - self.params.ph
            for xpos in xs:
                _append_if_foreground(xpos, bottom)
def predict_breastpathq(model, ep, dataset_path, label_csv_path):
    """Run test-time-augmented regression over a labeled image list.

    Reads (image_id, region_id) rows from `label_csv_path`, scores each
    '{dataset_path}/{image_id}_{region_id}.tif' with the model's regressor
    averaged over four flip/transpose views, clamps the score to [0, 1],
    and writes one row per image to 'Ozan_Results_{ep}.csv'.
    """
    import csv

    augment = preprocessing.standard_augmentor(True)
    resize = torchvision.transforms.Resize((args.tile_h, args.tile_w))

    # Switch every relevant sub-module to inference mode.
    for module in (model, model.regressor, model.classifier, model.decoder):
        module.eval()

    with torch.no_grad():

        with open('Ozan_Results_{}.csv'.format(ep), 'w',
                  newline='') as csv_write:

            fieldnames = ['slide', 'rid', 'p']
            writer = csv.DictWriter(csv_write, fieldnames=fieldnames)
            writer.writeheader()

            with open('{}'.format(label_csv_path)) as csv_file:

                csv_reader = csv.reader(csv_file, delimiter=',')
                next(csv_reader)  # skip the header row

                for row in csv_reader:

                    image_id, region_id = int(row[0]), int(row[1])

                    pth = '{}/{}_{}.tif'.format(dataset_path, image_id,
                                                region_id)

                    # Load, resize, normalize and move to GPU as a batch
                    # of one.
                    image = Image.open(pth).convert('RGB')
                    image = augment(resize(image)).cuda().unsqueeze_(0)

                    # Four-view test-time augmentation: identity,
                    # transpose, vertical flip, transpose+flip.
                    views = [
                        image,
                        image.transpose(2, 3),
                        image.flip(2),
                        image.transpose(2, 3).flip(3)
                    ]
                    total = None
                    for view in views:
                        encoding = model.encoder(view)
                        prediction = model.regressor(encoding[0]).view(-1)
                        total = prediction if total is None \
                            else total + prediction

                    # Average the views and clamp the score into [0, 1].
                    score = (total / len(views)).cpu().numpy()[0]
                    score = np.maximum(score, 0.0)
                    score = np.minimum(score, 1.0)

                    writer.writerow({
                        fieldnames[0]: image_id,
                        fieldnames[1]: region_id,
                        fieldnames[2]: score
                    })