def segmentCell(image, segmentator):
    '''
    segment cells (via nuclei + cell predictions) for a single image
    and return the labelled cell mask.
    ------------------------------------
    Argument:
    image -- list of three channel inputs, described by the original
             author as the (R, B, Y) filters (microtubules, endoplasmic
             reticulum, nuclei). image[2] is the channel given to
             segmentator.pred_nuclei; the whole list goes to pred_cells.
             NOTE(review): confirm the channel order pred_cells expects.
    segmentator -- CellSegmentator class object

    Returns:
    cell_mask -- labelled cell mask (second output of label_cell) for the
                 first prediction of each batch
    '''
    nuc_segmentations = segmentator.pred_nuclei(image[2])
    cell_segmentations = segmentator.pred_cells(image)
    nuclei_mask, cell_mask = label_cell(nuc_segmentations[0],
                                        cell_segmentations[0])

    # Drop the large intermediates first, THEN collect: calling gc.collect()
    # while these names still reference the predictions cannot free them.
    del nuc_segmentations
    del cell_segmentations
    del nuclei_mask
    gc.collect()

    return cell_mask
Example #2
0
def segmentCell(image, segmentator, seg_with="yellow"):
    '''
    segment cells (via nuclei + cell predictions) for a single image
    and return the labelled cell mask.
    ------------------------------------
    Argument:
    image -- list of three channel inputs, described by the original
             author as the (R, B, Y) filters (microtubules, endoplasmic
             reticulum, nuclei). The whole list goes to pred_cells.
    segmentator -- CellSegmentator class object
    seg_with -- which entry of image is fed to segmentator.pred_nuclei:
                "red" -> image[0], "blue" -> image[1],
                "yellow" -> image[2] (default). Any other value raises
                KeyError before the segmentator is run.

    Returns:
    cell_mask -- labelled cell mask (second output of label_cell) for the
                 first prediction of each batch
    '''
    mapping = {"red": 0, "blue": 1, "yellow": 2}

    nuc_segmentations = segmentator.pred_nuclei(image[mapping[seg_with]])
    cell_segmentations = segmentator.pred_cells(image)
    nuclei_mask, cell_mask = label_cell(nuc_segmentations[0],
                                        cell_segmentations[0])

    # Drop the large intermediates first, THEN collect: calling gc.collect()
    # while these names still reference the predictions cannot free them.
    del nuc_segmentations
    del cell_segmentations
    del nuclei_mask
    gc.collect()

    return cell_mask
Example #3
0
def get_segmentation_mask(ref_channels: List[str], segmentator: CellSegmentator):
    """Segment one image and return its labelled cell mask.

    ``ref_channels`` holds the paths of the reference channels and must be
    ordered red, yellow, blue. The segmentator only accepts list-of-lists
    (batched) input, so each path is wrapped in a one-element list and the
    first prediction of every batch is used.
    """
    channel_batches = []
    for channel_path in ref_channels:
        channel_batches.append([channel_path])

    nuclei_pred = segmentator.pred_nuclei(channel_batches[2])[0]
    cells_pred = segmentator.pred_cells(channel_batches)[0]
    labelled = label_cell(nuclei_pred, cells_pred)
    # Second element of label_cell's result is the cell mask.
    return labelled[1]
Example #4
0
def get_masks(imgs):
    """Compute a labelled cell mask for every image in ``imgs``.

    Each image is sliced as ``img[:, :, c]``. Channels 0, 3 and 2 are gathered,
    in that order, into three per-channel lists for the module-level
    ``segmentator``. The third list (channel 2) is also the nuclei input.
    Presumably ``imgs`` are RGBY arrays (channels-last) — confirm against caller.

    Args:
        imgs: sequence of arrays with at least 4 channels on the last axis.

    Returns:
        list of cell masks, one per image, as returned by
        ``label_cell(..., return_nuclei_label=False)``.
    """
    channel_order = (0, 3, 2)
    images = [[img[:, :, c] for img in imgs] for c in channel_order]

    nuc_segmentations, median_nuc_sizes = segmentator.pred_nuclei(images[2])
    # The second value returned by pred_cells is not needed here.
    cell_segmentations, _ = segmentator.pred_cells(
        images, median_nuc_sizes=median_nuc_sizes)

    return [
        label_cell(nuc_seg, cell_seg, nuc_size, return_nuclei_label=False)
        for nuc_seg, cell_seg, nuc_size in zip(
            nuc_segmentations, cell_segmentations, median_nuc_sizes)
    ]
Example #5
0
def save_masks(from_dir, to_dir, save_cell_mask=True, save_nuc_mask=True):
    """Segment every ``*_red.png`` image in ``from_dir`` and write masks to ``to_dir``.

    For each ``<stem>_red.png`` found, the matching ``<stem>_yellow.png`` and
    ``<stem>_blue.png`` are passed to the module-level ``SEGMENTATOR`` as the
    [red, yellow, blue] channel lists; the blue list is the nuclei input.
    The labelled masks from ``label_cell`` are written as
    ``<stem>_cell_mask.png`` and/or ``<stem>_nuc_mask.png``.

    Args:
        from_dir: directory holding the ``*_red.png`` / ``*_yellow.png`` /
            ``*_blue.png`` channel images.
        to_dir: output directory; created if missing.
        save_cell_mask: write the labelled cell mask for each image.
        save_nuc_mask: write the labelled nuclei mask for each image.
    """
    red_suffix = '_red.png'
    os.makedirs(to_dir, exist_ok=True)

    # Escape from_dir so glob metacharacters in the directory name ('[', '*',
    # '?') match literally; sort for a deterministic processing order.
    microtubule: List[str] = sorted(
        glob.glob(os.path.join(glob.escape(from_dir), '*' + red_suffix)))
    if not microtubule:
        return  # nothing to segment

    # Swap only the channel suffix. str.replace('red', ...) would also rewrite
    # any 'red' elsewhere in the path (e.g. a '/shared/' directory) and point
    # at files that do not exist.
    stems = [path[:-len(red_suffix)] for path in microtubule]
    endo_ret: List[str] = [stem + '_yellow.png' for stem in stems]
    nuclei: List[str] = [stem + '_blue.png' for stem in stems]
    images: List[List[str]] = [microtubule, endo_ret, nuclei]

    nuc_segmentations = SEGMENTATOR.pred_nuclei(images[2])
    cell_segmentations = SEGMENTATOR.pred_cells(images)

    for stem, nuc_seg, cell_seg in zip(stems, nuc_segmentations, cell_segmentations):
        nuc_mask, cell_mask = label_cell(nuc_seg, cell_seg)
        base_name = os.path.basename(stem)
        if save_cell_mask:
            imageio.imwrite(os.path.join(to_dir, base_name + '_cell_mask.png'), cell_mask)
        if save_nuc_mask:
            imageio.imwrite(os.path.join(to_dir, base_name + '_nuc_mask.png'), nuc_mask)
Example #6
0
def main(images=None, device="cuda:1"):
    """Demonstrate the segmentation pipeline on one batch of images.

    This part is for implementation, you can customize for your own usage.

    Args:
        images: three per-channel input lists as accepted by
            ``CellSegmentator.pred_cells``; ``images[2]`` is fed to the nuclei
            model. Required despite the ``None`` default.
        device: device string passed to ``CellSegmentator`` (default
            ``"cuda:1"``).

    Returns:
        (nuclei_preds, cell_preds): the raw nuclei and cell predictions for
        the whole batch.

    Raises:
        TypeError: if ``images`` is not provided.
    """
    # Fail before loading the models rather than on images[2] afterwards.
    if images is None:
        raise TypeError("main() requires `images`: three per-channel input lists")

    segmentator = CellSegmentator(
        nuclei_model="./nuclei_model.pth",
        cell_model="./cell_3ch_model.pth",
        scale_factor=0.25,
        device=device,
        padding=True,
        multi_channel_model=False,
    )

    nuclei_preds = segmentator.pred_nuclei(images[2])
    cell_preds = segmentator.pred_cells(images)

    # Post-process only the first image of the batch, as an example.
    nuclei_pred = nuclei_preds[0]
    cell_pred = cell_preds[0]
    # this is the post-processing part
    # this will give you both cell_mask and nuclei_mask
    nuclei_mask, cell_mask = label_cell(nuclei_pred, cell_pred)
    # this is for nuclei mask generation from the nuclei prediction alone
    nuclei_mask = label_nuclei(nuclei_pred)

    # The masks above are illustrative and are not returned; extend the
    # return value to get what ever you want to get.
    return nuclei_preds, cell_preds
Example #7
0
        # NOTE(review): fragment of a larger function -- the enclosing `def`
        # and the matching `try:` header are not visible here. data_root,
        # prefix_list, label_list, segmentor, save, save_root and valid_size
        # all come from that outer scope.
        # Build per-channel path lists for every image prefix.
        rpath = [f'{data_root}/{x}_red.png' for x in prefix_list]
        gpath = [f'{data_root}/{x}_green.png' for x in prefix_list]  # not used below
        bpath = [f'{data_root}/{x}_blue.png' for x in prefix_list]
        ypath = [f'{data_root}/{x}_yellow.png' for x in prefix_list]
        # Channel lists handed to the segmentor, in red, yellow, blue order.
        imgs = [rpath, ypath, bpath]
        print(f'input size: {len(imgs[0])}, {len(imgs[1])}, {len(imgs[2])}')

        segs = segmentor.pred_cells(imgs)
        # The blue list (imgs[2]) is the nuclei input.
        seg_nuc = segmentor.pred_nuclei(imgs[2])

        for i in range(len(seg_nuc)):
            prefix = prefix_list[i]
            label = label_list[i]
            # Green channel read as grayscale; this is the image cropped per cell.
            g = cv2.imread(f'{data_root}/{prefix}_green.png', cv2.IMREAD_GRAYSCALE)
            w, h = g.shape  # unused; note g.shape is (rows, cols)
            nuclei_mask, cell_mask = label_cell(seg_nuc[i], segs[i])
            # One boolean mask per cell label 1..cell_mask.max(); label 0 is not included.
            single_masks = [cell_mask == lb for lb in range(1, cell_mask.max() + 1)]
            for j in range(len(single_masks)):
                # Row (x) and column (y) coordinates of this cell's pixels.
                x, y = np.where(single_masks[j])
                # NOTE(review): with indexing='xy' both grids have shape
                # (n_bbox_cols, n_bbox_rows), so g[index] is the cell's bounding
                # box TRANSPOSED relative to g[min(x):max(x)+1, min(y):max(y)+1].
                # Also `index` is a list, not a tuple: NumPy >= 1.23 treats a
                # list of arrays as one index array, which changes the result
                # (and breaks the 2-value unpack below) -- confirm the NumPy
                # version, or use plain slicing.
                index = np.meshgrid(np.arange(min(x), max(x) + 1), np.arange(min(y), max(y) + 1), indexing='xy')
                cropped = g[index]
                ws, hs = cropped.shape
                # valid_size is defined elsewhere; crops it rejects are skipped.
                if not valid_size(cropped.shape):
                    continue
                # Record: [prefix, 0, 0, 0, cell index j, label, ws, hs].
                line = [prefix, 0, 0, 0, j, label, ws, hs]
                save.append(line)
                print(len(save), line)
                cv2.imwrite(os.path.join(save_root, f'{prefix}_{j}.png'), cropped)
        # NOTE(review): any exception above is only printed, so one failure
        # silently skips the rest of the try body.
    except Exception as e:
        print(e)
    # NOTE(review): bare `print` is a no-op expression (nothing is called);
    # `print()` was probably intended.
    print
    # NOTE(review): fragment -- the enclosing scope (and the definitions of
    # dataloader_ims_seg, segmentator_even_faster, im_proc, start_time and the
    # output lists) is not visible here.
    # Batched segmentation: blue_images feed pred_nuclei, ryb_images are passed
    # to pred_cells with precombined=True; per image, the cell mask, id and
    # size are appended to the output lists.
    for blue_images, ryb_images, sizes, _ids in dataloader_ims_seg:

        print(f"SEGMENT COUNT: {im_proc}")

        # .numpy() -- presumably torch tensors from a DataLoader; confirm.
        blue_batch = blue_images.numpy()
        ryb_batch = ryb_images.numpy()

        #print(blue_batch.shape)
        nuc_segmentations = segmentator_even_faster.pred_nuclei(blue_batch)
        # precombined=True: presumably ryb_batch already stacks the channels
        # instead of being separate per-channel lists -- confirm.
        cell_segmentations = segmentator_even_faster.pred_cells(
            ryb_batch, precombined=True)

        for data_id, nuc_seg, cell_seg, size in zip(_ids, nuc_segmentations,
                                                    cell_segmentations, sizes):
            # Keep only the cell mask; the nuclei mask is discarded.
            _, cell = utils.label_cell(nuc_seg, cell_seg)
            # NOTE(review): np.ubyte casts the label image to uint8, so cell
            # label ids above 255 wrap around -- confirm images never contain
            # more than 255 cells.
            even_faster_outputs.append(np.ubyte(cell))
            output_ids.append(data_id)
            sizes_list.append(size.numpy())
        im_proc += len(_ids)
        #if im_proc > 20:
        #break
    # Drop the reference to the dataloader once all batches are processed.
    del dataloader_ims_seg
    # Elapsed seconds, assuming start_time was taken with time.time().
    print(time.time() - start_time)

# One row per segmented image (ID, cell mask, original size), indexed by ID and
# aligned to the order of data_df['ID']; IDs without a mask become NaN rows.
cell_masks_df = (
    pd.DataFrame(
        list(zip(output_ids, even_faster_outputs, sizes_list)),
        columns=["ID", "mask", "ori_size"],
    )
    .set_index('ID')
    .reindex(index=data_df['ID'])
)
Example #9
0
    def save_predictions(self, save_cams=True, save_masks=False, viz=False, infer=False, retmasks=False):
        """Run CAM inference and cell segmentation over ``self.dl``.

        For every ``(image, unnorm_image, image_id)`` batch from ``self.dl``
        (under ``torch.no_grad`` and, per ``self.use_amp``, CUDA autocast):
          * computes ``get_cam`` twice, with ``ttaflag=True`` and
            ``ttaflag=False``;
          * segments cells from ``unnorm_image`` with ``self.segmentator``
            and labels them with ``label_cell``;
          * if ``infer``: resizes each cell mask to the original size found in
            ``self.df``, RLE-encodes every labelled cell, and collects
            ``self.cust1(cell_mask, CAM)`` per image;
          * optionally saves CAMs, cell masks and visualisations.

        Args:
            save_cams: ``torch.save`` each CAM to ``self.pth/<image_id>.npy``.
            save_masks: ``np.save`` each cell mask to ``self.cell_load(image_id)``.
            viz: call ``print_masked_img`` for both CAM passes.
            infer: collect per-image RLEs and ``self.cust1`` results.
            retmasks: return ``masklst`` right after the FIRST batch.

        Returns:
            ``masklst`` if ``retmasks`` (after one batch); otherwise
            ``[tot_inf, image_rles]`` if ``infer``; otherwise ``None``.

        NOTE(review): ``masklst`` is never appended to (the append below is
        commented out), so ``retmasks=True`` returns an empty list after
        processing only the first batch.
        """
        image_rles = []
        tot_inf = []
        masklst = []
        with torch.no_grad():
            for step, (image, unnorm_image, image_id) in (enumerate(tqdm(self.dl))):
                # print(unnorm_image.shape)
                # Re-checked every batch; could be hoisted above the loop.
                if not os.path.exists(self.pth):
                    os.makedirs(self.pth)

                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    # unnorm_image = unnorm_image.to(self.dev)
                    image = image.to(self.dev)

                    # preds1, CAMs1, masks1 = get_cam(self.model, image, ttaflag=True, scale=1)

                    # Two CAM passes (ttaflag=True / False). The second pass
                    # (preds1, CAMs1, masks1) is only consumed by the viz branch.
                    preds, CAMs, masks = get_cam(self.model, image, ttaflag=True, scale=1)
                    preds1, CAMs1, masks1 = get_cam(self.model, image, ttaflag=False, scale=1)
                    masks = masks.cpu().type(torch.FloatTensor)
                    masks1 = masks1.cpu().type(torch.FloatTensor)



                    # if retmasks:
                    #     masklst.append(masks)


                    # preds1, CAMs1, masks1 = get_cam(self.model, image, 1, flag=True)#TODO: few scales. flips. other tta; use mask
                    # CAMs1 = CAMs1.cpu()

                    CAMs = CAMs.cpu()
                    # Debug output of the CAM / mask tensor shapes.
                    print(CAMs.size())
                    print(masks.shape)
                    # return preds, CAMs, masks, unnorm_image

                    # unnorm_image /= 255

                    # unnorm_image = unnorm_image.cpu()

                    cells = []

                    # Segmentation inputs built from the unnormalised batch
                    # (channels on the last axis): channels 0, 3, 2 for the cell
                    # model, channel 2 for the nuclei model, both scaled by 1/255.
                    rgb_batch = [unnorm_image[..., [0, 3, 2]][i].numpy().astype(float) / 255 for i in
                          range(unnorm_image.size(0))]
                    blue_batch = [unnorm_image[..., 2][i].numpy().astype(float) / 255 for i in range(unnorm_image.size(0))]
                    nuc_segmentations = self.segmentator.pred_nuclei(blue_batch)
                    cell_segmentations = self.segmentator.pred_cells(rgb_batch, precombined=True)
                    # Keep only the labelled cell mask for each image.
                    for nuc_seg, cell_seg in zip(nuc_segmentations, cell_segmentations):
                        _, cell = label_cell(nuc_seg, cell_seg)
                        cells.append(cell)

                    if infer:
                        resized_cells = []

                        # Resize each cell mask to the original size from self.df.
                        # res_dict (defined elsewhere) is keyed by original width.
                        # NOTE(review): the dummy image has shape og_wh =
                        # (width, height), the reverse of numpy's (rows, cols)
                        # order -- confirm the transform does not depend on it.
                        for i in range(len(image_id)):
                            cur_cell = cells[i]
                            og_size = self.df[self.df.ID == image_id[i]]
                            og_wh = (og_size.ImageWidth.values[0], og_size.ImageHeight.values[0])
                            resized = res_dict[og_wh[0]](image=np.random.randn(*og_wh), mask=cur_cell)['mask']
                            resized_cells.append(resized)
                        # RLE-encode every non-zero label as its own binary mask.
                        for i in range(len(image_id)):


                            # loaded = np.load(os.path.join(b, lds))['arr_0']
                            uniqs = np.unique(resized_cells[i]).tolist()
                            cur_rles = []
                            # NOTE(review): assumes label 0 is present in every
                            # mask; assert is stripped under `python -O`.
                            assert uniqs[0] == 0
                            for u in uniqs[1:]:
                                cur_rles.append(encode_binary_mask((resized_cells[i] == u)).decode("utf-8"))
                            image_rles.append(cur_rles)

                        # self.cust1 gets the un-resized cell mask and the CAM.
                        for i in range(len(image_id)):
                            tot_inf.append(self.cust1(cells[i], CAMs[i]))
                    # for j in range(len(image_id)):
                    #     np.save(f'/common/danylokolinko/hpa_mask_semantic/nuc/{image_id[j]}', cv2.resize(nuc_segmentations[j], (1024, 1024)))


                    if save_cams:
                        # Written with torch.save despite the '.npy' extension:
                        # load with torch.load, not np.load.
                        for j in range(len(image_id)):
                            torch.save(CAMs[j], os.path.join(self.pth, f'{image_id[j]}.npy'))

                    if save_masks:
                        for j in range(len(image_id)):
                            np.save(self.cell_load(image_id[j]), cells[j])
                    if viz:
                        # Sigmoid of the CAMs resized to 512x512, for both passes.
                        hrcams = nn.Sigmoid()(resize_for_tensors(CAMs.type(torch.FloatTensor), (512, 512)))
                        hrcams1 = nn.Sigmoid()(resize_for_tensors(CAMs1.type(torch.FloatTensor), (512, 512)))


                        # Note: the loop variable shadows the builtin `id`.
                        for i, id in enumerate(image_id):

                            print_masked_img(self.config[self.train_str]['path'], self.train_str, id, hrcams[i], cell_mask=cells[i], cell_pred=nn.Sigmoid()(masks[i]))
                            print_masked_img(self.config[self.train_str]['path'], self.train_str, id, hrcams1[i],  cell_mask=cells[i], cell_pred=nn.Sigmoid()(masks1[i]))
                    # NOTE(review): returns inside the batch loop, i.e. after the
                    # first batch, and masklst is always empty (see docstring).
                    if retmasks:
                        return masklst

        if infer:
            return [tot_inf, image_rles]