예제 #1
0
def preproc(ds_path, ratio):
    """Print suggested intensity-normalization parameters for a folder of TIFFs.

    All ``*.tiff`` / ``*.tif`` files directly under ``ds_path`` are collected
    and shuffled. When ``ratio`` is below 1, only that fraction of the shuffled
    list is kept. Each kept image is read, squeezed and handed to
    ``suggest_normalization_param``; the two values it returns per image are
    collected, and their raw lists, means, standard deviations,
    ``mean + 3 * stdev`` and maxima are printed.

    Args:
        ds_path: directory containing the images (joined with ``"/*.tiff"``).
        ratio: fraction of the files to sample; values >= 1 keep every file.

    Returns:
        None. Results are only printed.
    """
    # get filenames
    filenames = glob(ds_path + "/*.tiff") + glob(ds_path + "/*.tif")
    random.shuffle(filenames)

    if ratio < 1:
        filenames = filenames[:int(round(ratio * len(filenames)))]

    # statistics.mean() and max() both raise on an empty sequence, so bail out
    # with a readable message instead of a traceback.
    if not filenames:
        print(f"no .tiff/.tif files selected from {ds_path}")
        return

    set_a = []
    set_b = []
    for fn in tqdm(filenames):
        img = np.squeeze(imread(fn))
        a, b = suggest_normalization_param(img)
        set_a.append(a)
        set_b.append(b)

    print(set_a)
    print(set_b)

    mean_a = statistics.mean(set_a)
    mean_b = statistics.mean(set_b)
    print(mean_a)
    print(mean_b)

    # statistics.stdev() needs at least two samples; a single image has no
    # spread, so report 0.0 rather than raising StatisticsError.
    std_a = statistics.stdev(set_a) if len(set_a) > 1 else 0.0
    std_b = statistics.stdev(set_b) if len(set_b) > 1 else 0.0
    print(std_a)
    print(std_b)

    print(mean_a + 3 * std_a)
    print(mean_b + 3 * std_b)

    print("please use")

    print(max(set_a))
    print(max(set_b))
예제 #2
0
def create_test_image(structure_name: str, output_type: str = "default"):
    """Segment the stock random image with the workflow of one structure.

    The lower-cased ``structure_name`` selects the module
    ``DEFAULT_MODULE_PATH + structure_name`` and, inside it, the callable
    ``Workflow_<structure_name>``. That callable is run on
    ``random_input.tiff`` from ``TEST_IMG_DIR`` (reshaped to
    ``BASE_IMAGE_DIM``) and its result is returned unchanged.

    Args:
        structure_name: name of the structure; case-insensitive.
        output_type: forwarded to the workflow as ``output_type``.

    Returns:
        Whatever the workflow function returns.

    Raises:
        Exception: any error raised while importing the module or looking up
            the workflow function is re-raised after a message is printed.
    """
    structure_name = structure_name.lower()
    module_name = DEFAULT_MODULE_PATH + structure_name
    workflow_name = "Workflow_" + structure_name

    # Resolve the structure-specific workflow function.
    try:
        workflow = getattr(importlib.import_module(module_name), workflow_name)
    except Exception as err:
        print(
            f"raising failure while trying to get module/function for {module_name}"
        )
        raise err

    # Stock random input, reshaped to the dimensions the workflows expect.
    stock_path = Path(TEST_IMG_DIR + "random_input.tiff")
    stock_image = imread(stock_path).reshape(*BASE_IMAGE_DIM)

    # Run the segmentation and hand its output straight back.
    return workflow(
        struct_img=stock_image,
        rescale_ratio=RESCALE_RATIO,
        output_type=output_type,
        output_path=TEST_IMG_DIR,
        fn="expected_" + structure_name,
    )
예제 #3
0
 def open_path(self, path):
     """Replace the viewer's layers with the image stored at ``path``.

     Nothing happens when ``path`` is not an existing file. Files whose
     extension is not one of ``.tiff``, ``.tif``, ``.czi``, ``.jpg`` or
     ``.png`` are reported and skipped *before* the viewer is cleared, so an
     unsupported file no longer wipes the current layers and then fails with
     an unbound ``image`` variable.

     Args:
         path: path of the image file to display (a ``str``).
     """
     if not os.path.isfile(path):
         return

     # ".ome.tif(f)" files also end in ".tif(f)", so is_ome implies is_bio.
     is_ome = path.endswith((".ome.tiff", ".ome.tif"))
     is_bio = path.endswith((".tiff", ".tif", ".czi"))
     is_plain = path.endswith((".jpg", ".png"))

     if not (is_bio or is_plain):
         print(f"unsupported image file type: {path}")
         return

     viewer.layers.select_all()
     viewer.layers.remove_selected()

     if is_ome:
         # a little slow but dask_image imread doesn't work well
         javabridge.start_vm(class_path=bioformats.JARS)

     if is_bio:
         # Uses aicsimageio to open these files (original note: the other
         # reader puts the channel in the last dimension, which does not
         # work with napari). The file is read exactly once.
         image = imread(path)
     else:
         image = io.imread(path)

     viewer.add_image(image, name="image_1")
    def execute(self, args):
        """Convert every raw image in ``args.raw_path`` into training files.

        For each raw file matching ``*<args.data_type>`` (sorted by name) the
        method loads the requested channel, normalizes it with
        ``input_normalization``, loads the matching
        ``<name>_struct_segmentation.tiff`` from ``args.seg_path`` and, when
        present, ``<name>_mask.tiff`` from ``args.mask_path``. Three files are
        written to ``args.train_path``: ``img_NNN.ome.tif`` (image),
        ``img_NNN_GT.ome.tif`` (binary label) and ``img_NNN_CM.ome.tif``
        (cost map). Numbering continues after the files already in
        ``args.train_path``.

        Args:
            args: namespace providing ``data_type``, ``raw_path``,
                ``train_path``, ``seg_path``, ``mask_path`` and
                ``input_channel`` (plus whatever ``input_normalization``
                reads). ``args.data_type`` is normalized in place to start
                with a dot.
        """
        if not args.data_type.startswith('.'):
            args.data_type = '.' + args.data_type

        filenames = sorted(glob(args.raw_path + os.sep + '*' + args.data_type))

        existing_files = glob(args.train_path + os.sep + 'img_*.ome.tif')
        print(len(existing_files))

        # Each sample is stored as three files (image, _GT, _CM), all of which
        # match the glob above.
        training_data_count = len(existing_files) // 3
        for fn in filenames:

            training_data_count += 1

            # load raw
            reader = AICSImage(fn)
            struct_img = reader.get_image_data(
                "CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32)
            # Normalize the array that was just loaded (the original passed an
            # undefined name `img` here, which raised NameError).
            struct_img = input_normalization(struct_img, args)

            # Raw file name without its extension.
            base_name = os.path.basename(fn)[:-len(args.data_type)]

            # load seg
            seg_fn = args.seg_path + os.sep + base_name + '_struct_segmentation.tiff'
            seg = np.squeeze(imread(seg_fn)) > 0.01
            seg = seg.astype(np.uint8)
            seg[seg > 0] = 1

            # excluding mask: cost is 0 wherever the mask is 0, 1 elsewhere
            cmap = np.ones(seg.shape, dtype=np.float32)
            mask_fn = args.mask_path + os.sep + base_name + '_mask.tiff'
            if os.path.isfile(mask_fn):
                mask = np.squeeze(imread(mask_fn))
                cmap[mask == 0] = 0

            out_prefix = args.train_path + os.sep + 'img_' + f'{training_data_count:03}'

            with OmeTiffWriter(out_prefix + '.ome.tif') as writer:
                writer.save(struct_img)

            with OmeTiffWriter(out_prefix + '_GT.ome.tif') as writer:
                writer.save(seg)

            with OmeTiffWriter(out_prefix + '_CM.ome.tif') as writer:
                writer.save(cmap)
예제 #5
0
def test_imread(resources_dir, filename, expected_shape):
    """Check that ``imread`` yields the expected shape and leaks no file handle."""
    target = resources_dir / filename
    target_str = str(target)

    # The file must not be held open by this process before reading.
    process = Process()
    assert all(handle.path != target_str for handle in process.open_files())

    # Reading gives an array of the expected shape.
    loaded = imread(target)
    assert loaded.shape == expected_shape

    # ... and the file is not left open afterwards.
    assert all(handle.path != target_str for handle in process.open_files())
예제 #6
0
def unit_test(structure_name: str):
    """Compare the current segmentation of the stock image with the stored one.

    Runs ``create_test_image`` for the lower-cased ``structure_name`` with
    ``output_type="array"`` and asserts that the flattened result is
    element-wise close to the flattened
    ``expected_<structure_name>_struct_segmentation.tiff`` in ``TEST_IMG_DIR``.
    """
    structure_name = structure_name.lower()

    # Segment the stock random image with the current workflow version.
    actual = create_test_image(structure_name, output_type="array")
    actual = actual.ravel()

    # Flatten the stored reference so extra singleton dimensions do not matter.
    reference_path = Path(TEST_IMG_DIR + "expected_" + structure_name +
                          "_struct_segmentation.tiff")
    reference = imread(reference_path).ravel()

    failure_message = "Tested and expected outputs differ for " + structure_name
    assert np.allclose(actual, reference), failure_message
예제 #7
0
def test_large_imread(resources_dir, filename, expected_shape,
                      expected_task_count):
    """Check shape, profiled task count and handle hygiene of ``imread``."""
    target = resources_dir / filename
    target_str = str(target)

    # The file must not be held open by this process before reading.
    process = Process()
    assert all(handle.path != target_str for handle in process.open_files())

    # Read under the profiler so the number of recorded tasks can be checked.
    with Profiler() as profiler:
        loaded = imread(target)
        assert loaded.shape == expected_shape
        assert len(profiler.results) == expected_task_count

    # ... and the file is not left open afterwards.
    assert all(handle.path != target_str for handle in process.open_files())
예제 #8
0
def test_imread(resources_dir, filename, expected_shape):
    """Check shape, task count (two) and handle hygiene of ``imread``."""
    target = resources_dir / filename
    target_str = str(target)

    # The file must not be held open by this process before reading.
    process = Process()
    assert all(handle.path != target_str for handle in process.open_files())

    with Profiler() as profiler:
        loaded = imread(target)
        assert loaded.shape == expected_shape

        # Reshape and transpose are required so there should be two tasks in the graph
        recorded_tasks = len(profiler.results)
        assert recorded_tasks == 2

    # ... and the file is not left open afterwards.
    assert all(handle.path != target_str for handle in process.open_files())
    def load_from_file(self, filenamesA, filenamesB=None, num_patch=-1):
        """Load source (A) / target (B) images and cut them into patches.

        Resets ``self.imgA``, ``self.imgB``, ``self.imgA_path``,
        ``self.imgA_short_path``, ``self.imgB_path`` and
        ``self.for_calc_ave_offset`` and then fills them image by image.

        * ``num_patch == -1``: every image is tiled ("stitch" mode) with
          windows of ``self.size_in`` that advance by half a window;
          the window origins are also recorded in ``self.positionA`` /
          ``self.positionB``.
        * otherwise: random crops are taken, ``self.num_patch_per_img[i]`` per
          image, until ``num_patch`` patches have been collected.

        Each source image is normalized with ``self.source_norm`` and resized
        to the target image's shape (or, without targets, to its own shape
        scaled by ``opt.normalization["source"]["ratio_param"]``).

        Args:
            filenamesA: paths of the source-domain images.
            filenamesB: optional paths of the target-domain images; must have
                the same length as ``filenamesA`` and is paired by index.
            num_patch: total number of patches to collect, or ``-1`` for
                stitch mode.

        NOTE(review): the random-crop branch reads ``tar_img`` and ``fnB``,
        which are only assigned when ``filenamesB`` is given -- it looks like
        ``num_patch != -1`` requires ``filenamesB``; confirm with callers.
        """
        # assumption: transfer is from A to B
        self.imgA = []
        self.imgB = []
        self.imgA_path = []
        self.imgA_short_path = []
        self.imgB_path = []

        if filenamesB is not None:
            assert len(filenamesA) == len(
                filenamesB), "source/target num mismatch"

        num_data = len(filenamesA)
        assert num_data > 0, "no source type data found"

        # how many patches to take from each image
        self.num_patch_per_img = np.zeros((num_data, ), dtype=int)

        if num_patch == -1:
            # take all patches in a "stitch" way
            Stitch = True
        else:
            Stitch = False
            if num_data >= num_patch:
                # at least as many images as patches: one patch from each of
                # the first num_patch images, none from the rest
                print("suggest to use more patch in each buffer")
                self.num_patch_per_img[:num_patch] = 1
            else:
                # spread the patches evenly; the first (num_patch % num_data)
                # images get one extra patch
                basic_num = num_patch // num_data
                self.num_patch_per_img[:] = basic_num
                self.num_patch_per_img[:(num_patch - basic_num * num_data)] = (
                    self.num_patch_per_img[:(num_patch - basic_num * num_data)]
                    + 1)

        self.filenamesA = filenamesA
        self.filenamesB = filenamesB

        # for_calc_ave_offset is used to flag if a patch is to be used for
        # calculating mean offset for AutoAlign
        self.for_calc_ave_offset = []

        if self.opt.network["model"] == "stn" and self.opt.stn_adjust_fixed_z:
            print(f"read offsets from {self.opt.readoffsetfrom}")
            assert os.path.isfile(
                self.opt.readoffsetfrom
            ), f"opt.readoffsetfrom path: {self.opt.readoffsetfrom} is not found! \
                If you want to align images, set the correct path. If not, \
                set opt.stn_adjust_fixed_z=False"

            # Each line of the offsets file is "key,z,y,x"; all offsets that
            # share a key are grouped into fixed_dict1[key].
            with open(self.opt.readoffsetfrom, "r") as fp:
                fixed_dict1 = {}
                for line in fp.readlines():
                    key, z, y, x = line.strip().split(",")
                    z = float(z)
                    y = float(y)
                    x = float(x)
                    if key in fixed_dict1:
                        fixed_dict1[key].append([z, y, x])
                    else:
                        fixed_dict1[key] = [
                            [z, y, x],
                        ]
                for key in fixed_dict1:
                    # only use the parameters that are between [10,90] percents.
                    clip_param = np.percentile(
                        np.array(fixed_dict1[key])[:, 0], [10, 90])
                    offset_raw = []
                    # The comparison is strict on both sides and looks at the
                    # z offset only.
                    # NOTE(review): if every z equals the percentiles (e.g. a
                    # single entry) nothing is kept and the std/mean below run
                    # on an empty array -- confirm this cannot happen.
                    for i in range(len(fixed_dict1[key])):
                        if (fixed_dict1[key][i][0] > clip_param[0]
                                and fixed_dict1[key][i][0] < clip_param[1]):
                            offset_raw.append(fixed_dict1[key][i])
                    offset_raw = np.array(offset_raw)
                    z_std = np.std(offset_raw, axis=0)[0]

                    if z_std > 0.5:
                        print(
                            f"WARNING: The standard deviation of offsets estimation\
                             for {key} is {z_std}. Not accurate!")
                        # Opened with "w": a later key overwrites the warning
                        # of an earlier one.
                        with open(self.opt.resultroot / "WARNING", "w") as wp:
                            wp.write(
                                f"WARNING: The standard deviation of offsets \
                                estimation for {key} is {z_std}. Not accurate!"
                            )
                    # Replace the list of offsets by their mean (z, y, x).
                    fixed_dict1[key] = np.mean(offset_raw, axis=0)

        for idxA, fnA in tqdm(enumerate(filenamesA)):
            # short name = last "/"-separated component of the path
            fnnA = fnA.split("/")[-1]

            # expected patch is met (when num_patch = -1, the loading will go thr all)
            if len(self.imgA) == num_patch:
                break

            # load source domain image
            # source_reader = AICSImage(fnA)  # STCZYX
            # src_img = source_reader.get_image_data("ZYX", S=0, T=0, C=0)
            src_img = np.squeeze(imread(fnA))

            # run intensity normalization
            src_img = self.source_norm(src_img,
                                       bulk_params=self.source_norm_param)

            if filenamesB is not None:
                idxB = idxA
                fnB = filenamesB[idxB]

                # load target domain image
                # target_reader = AICSImage(fnB)  # STCZYX
                # tar_img = target_reader.get_image_data("ZYX", S=0, T=0, C=0)
                tar_img = np.squeeze(imread(fnB))

                # run intensity normalization
                tar_img = self.target_norm(tar_img,
                                           bulk_params=self.target_norm_param)

                # determine new size for source
                new_size = (tar_img.shape[0], tar_img.shape[1],
                            tar_img.shape[2])

            else:
                # no target: scale the source shape by the configured ratios
                r = self.opt.normalization["source"]["ratio_param"]
                nz = int(np.round(src_img.shape[0] * r[0]))
                ny = int(np.round(src_img.shape[1] * r[1]))
                nx = int(np.round(src_img.shape[2] * r[2]))
                new_size = (nz, ny, nx)

            src_img = resize_to(src_img, new_size, method="bilinear")

            if self.opt.network["model"] in ["stn"
                                             ] and self.opt.stn_adjust_image:
                if self.opt.isTrain:
                    # NOTE(review): resultroot is concatenated as a str here
                    # but joined with "/" (Path-style) for the WARNING file
                    # above -- one of the two fails depending on its type.
                    shifted_stacks_dir = self.opt.resultroot + "/shift/"
                    if not os.path.isdir(shifted_stacks_dir):
                        os.makedirs(shifted_stacks_dir)
                    if fnnA in self.stn_adjust_dict:
                        imsave(shifted_stacks_dir + f"{fnnA}_rA.tiff",
                               src_img[0])
                        print(fnnA, self.stn_adjust_dict[fnnA])
                        # offsets_zyx is the dict entry itself, not a copy, so
                        # the divisions below also rescale the stored entry if
                        # it is a mutable sequence/array.
                        offsets_zyx = self.stn_adjust_dict[fnnA]
                        offsets_zyx[
                            0] = offsets_zyx[0] * 1.0 / self.up_scale[0]
                        offsets_zyx[
                            1] = offsets_zyx[1] * 1.0 / self.up_scale[1]
                        offsets_zyx[
                            2] = offsets_zyx[2] * 1.0 / self.up_scale[2]
                        with open(shifted_stacks_dir + "shift.log", "a") as fp:
                            fp.write(
                                f"{fnnA},{offsets_zyx[0]},{offsets_zyx[1]},\
                                {offsets_zyx[2]}\n")
                        offsets_zyx = from_numpy(offsets_zyx)
                        tensor = from_numpy(
                            np.expand_dims(np.expand_dims(src_img, 0), 0))
                        # The shifted stack is only written to disk; src_img
                        # itself is left unchanged in this branch.
                        label = self.apply_adjust(tensor, offsets_zyx)
                        label = np.squeeze(label.detach().cpu().numpy(),
                                           axis=0)
                        imsave(shifted_stacks_dir + f"{fnnA}_rA_new.tiff",
                               label[0])
            elif self.opt.network["model"] in [
                    "stn"
            ] and self.opt.stn_adjust_fixed_z:
                if self.opt.isTrain:
                    print("adjust fixed z")
                    if fnnA in fixed_dict1:
                        z, y, x = fixed_dict1[fnnA]
                    else:
                        # the assignment has no effect: the raise follows
                        # immediately
                        z, y, x = 0, 0, 0
                        raise ValueError(f"****\n\nERROR: {fnnA} is not found \
                            in {self.opt.readoffsetfrom}, please train the AutoAlign \
                                first to get the offsets\n\n ***************\n"
                                         )

                    if self.opt.align_all_axis:
                        offsets_zyx = np.array((
                            z / self.up_scale[0],
                            y / self.up_scale[1],
                            x / self.up_scale[2],
                        ))
                    else:
                        # shift along z only
                        offsets_zyx = np.array((z / self.up_scale[0], 0, 0))
                    offsets_zyx = from_numpy(offsets_zyx)
                    tensor = from_numpy(
                        np.expand_dims(np.expand_dims(src_img, 0), 0))
                    label = self.apply_adjust(tensor, offsets_zyx)
                    label = np.squeeze(label.detach().cpu().numpy(), axis=0)
                    label = label[:, 1:-1, :, :]
                    # NOTE(review): `label` is not used again in this method,
                    # so the patches below are cut from the unshifted src_img
                    # while tar_img is trimmed along z -- confirm intended.
                    # tar_img is also undefined here when filenamesB is None.
                    tar_img = tar_img[
                        int(self.up_scale[0]):-int(self.up_scale[0]), :, :]

            if Stitch:
                # consecutive windows advance by half a window
                overlap_step = 0.5
                # positionA / positionB are re-created for every image: first
                # entry is the image size, followed by the window origins.
                # After the loop they describe the last image only.
                if filenamesB is not None:
                    self.positionB = [
                        new_size,
                    ]
                self.positionA = [
                    new_size,
                ]
                px_list, py_list, pz_list = [], [], []
                px, py, pz = 0, 0, 0
                # Window origins per axis; the final window is anchored to the
                # far edge so the whole extent is covered.
                # NOTE(review): if an axis is smaller than size_in the
                # appended origin is negative -- confirm inputs are large enough.
                while px < new_size[2] - self.size_in[2]:
                    px_list.append(px)
                    px += int(self.size_in[2] * overlap_step)
                px_list.append(new_size[2] - self.size_in[2])
                while py < new_size[1] - self.size_in[1]:
                    py_list.append(py)
                    py += int(self.size_in[1] * overlap_step)
                py_list.append(new_size[1] - self.size_in[1])
                while pz < new_size[0] - self.size_in[0]:
                    pz_list.append(pz)
                    pz += int(self.size_in[0] * overlap_step)
                pz_list.append(new_size[0] - self.size_in[0])
                for pz_in in pz_list:
                    for py_in in py_list:
                        for px_in in px_list:
                            # source patch with a leading singleton axis
                            (self.imgA).append(
                                np.expand_dims(
                                    src_img[pz_in:pz_in + self.size_in[0],
                                            py_in:py_in + self.size_in[1],
                                            px_in:px_in + self.size_in[2], ],
                                    axis=0,
                                ))
                            (self.imgA_path).append(fnA)
                            (self.imgA_short_path).append(fnnA)
                            # matching origin in the target: source origin
                            # scaled by up_scale
                            pz_out = pz_in * self.up_scale[0]
                            py_out = py_in * self.up_scale[1]
                            px_out = px_in * self.up_scale[2]

                            if filenamesB is not None:
                                (self.imgB).append(
                                    np.expand_dims(
                                        tar_img[pz_out:pz_out +
                                                self.size_out[0],
                                                py_out:py_out +
                                                self.size_out[1],
                                                px_out:px_out +
                                                self.size_out[2], ],
                                        axis=0,
                                    ))
                                (self.imgB_path).append(fnB)
                                self.positionB.append((pz_out, py_out, px_out))

                            self.positionA.append((pz_in, py_in, px_in))
                            self.for_calc_ave_offset.append(False)
            else:
                # TODO: data augmentation, only cropping now
                new_patch_num = 0
                while new_patch_num < self.num_patch_per_img[idxA]:
                    # NOTE(review): the source crop origin is bounded by
                    # tar_img.shape, although src_img was resized to new_size
                    # (tar_img may have been trimmed along z since) -- confirm.
                    pz = random.randint(0, tar_img.shape[0] - self.size_in[0])
                    py = random.randint(0, tar_img.shape[1] - self.size_in[1])
                    px = random.randint(0, tar_img.shape[2] - self.size_in[2])
                    (self.imgA).append(
                        np.expand_dims(
                            src_img[pz:pz + self.size_in[0],
                                    py:py + self.size_in[1],
                                    px:px + self.size_in[2], ],
                            axis=0,
                        ))
                    (self.imgA_path).append(fnA)
                    (self.imgA_short_path).append(fnnA)

                    # TODO: good crop?
                    if not self.aligned:
                        # unaligned data: independent random origin in target
                        # NOTE(review): this indexes tar_img.shape[1..3],
                        # while everywhere else tar_img is used as a 3-D
                        # array (shape[0..2]); shape[3] raises IndexError on a
                        # 3-D array -- looks like an off-by-one in the axes.
                        pz = random.randint(
                            0, tar_img.shape[1] - self.size_out[0])
                        py = random.randint(
                            0, tar_img.shape[2] - self.size_out[1])
                        px = random.randint(
                            0, tar_img.shape[3] - self.size_out[2])
                    else:
                        # aligned data: same location, scaled by up_scale
                        pz = pz * self.up_scale[0]
                        py = py * self.up_scale[1]
                        px = px * self.up_scale[2]
                    (self.imgB).append(
                        np.expand_dims(
                            tar_img[pz:pz + self.size_out[0],
                                    py:py + self.size_out[1],
                                    px:px + self.size_out[2], ],
                            axis=0,
                        ))
                    (self.imgB_path).append(fnB)
                    # Patches whose index within this image exceeds one third
                    # (integer division) of the image's quota are flagged for
                    # the average-offset calculation.
                    if new_patch_num > self.num_patch_per_img[idxA] * 1 // 3:
                        self.for_calc_ave_offset.append(True)
                    else:
                        self.for_calc_ave_offset.append(False)
                    new_patch_num += 1
    def __init__(self, filenames, num_patch, size_in, size_out):
        """Build a buffer of randomly rotated / flipped training patches.

        For every prefix ``fn`` in ``filenames`` three files are read:
        ``fn + '.ome.tif'`` (input image), ``fn + '_GT.ome.tif'`` (label) and
        ``fn + '_CM.ome.tif'`` (cost map). The input is zero-padded by
        ``(size_in - size_out) // 2`` per axis, one random in-plane rotation
        and optional left-right flip is applied slice by slice to image,
        label and cost map, and random crops are appended to ``self.img``
        (crop of ``size_in``), ``self.gt`` and ``self.cmap`` (crops of
        ``size_out``).

        Args:
            filenames: file-name prefixes of the training samples. The list is
                shuffled on a private copy; the caller's list is not modified.
            num_patch: total number of patches to collect.
            size_in: patch size (z, y, x) of the network input.
            size_out: patch size (z, y, x) of the network output.
        """
        self.img = []
        self.gt = []
        self.cmap = []

        padding = [(x - y) // 2 for x, y in zip(size_in, size_out)]

        # Shuffle a copy: the original shuffled the caller's list in place.
        filenames = list(filenames)
        num_data = len(filenames)
        shuffle(filenames)

        num_patch_per_img = np.zeros((num_data, ), dtype=int)
        if num_data >= num_patch:
            # all one
            num_patch_per_img[:num_patch] = 1
        else:
            basic_num = num_patch // num_data
            # assign each image the same number of patches to extract
            num_patch_per_img[:] = basic_num

            # assign one more patch to the first few images to achieve the total patch number
            num_patch_per_img[:(num_patch - basic_num * num_data)] += 1

        for img_idx, fn in enumerate(filenames):

            if len(self.img) == num_patch:
                break

            label = np.squeeze(imread(fn + '_GT.ome.tif'))
            label = np.expand_dims(label, axis=0)

            input_img = np.squeeze(imread(fn + '.ome.tif'))
            if len(input_img.shape) == 3:
                # add channel dimension
                input_img = np.expand_dims(input_img, axis=0)
            elif len(input_img.shape) == 4:
                # assume number of channel < number of Z, make sure channel dim comes first
                if input_img.shape[0] > input_img.shape[1]:
                    input_img = np.transpose(input_img, (1, 0, 2, 3))

            costmap = np.squeeze(imread(fn + '_CM.ome.tif'))

            # Zero-pad y/x first, then z, so that an input crop of size_in
            # taken at a label-space origin is centred on the size_out crop.
            img_pad0 = np.pad(input_img,
                              ((0, 0), (0, 0), (padding[1], padding[1]),
                               (padding[2], padding[2])), 'constant')
            raw = np.pad(img_pad0,
                         ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)),
                         'constant')

            cost_scale = costmap.max()
            if cost_scale < 1:  # this should not happen, but just in case
                cost_scale = 1

            # One rotation angle and one flip decision per image, shared by
            # label, cost map and raw image so they stay registered.
            deg = random.randrange(1, 180)
            flip_flag = random.random()

            for zz in range(label.shape[1]):

                for ci in range(label.shape[0]):
                    labi = label[ci, zz, :, :]
                    labi_pil = Image.fromarray(np.uint8(labi))
                    # nearest-neighbour keeps label values discrete
                    new_labi_pil = labi_pil.rotate(deg, resample=Image.NEAREST)
                    if flip_flag < 0.5:
                        new_labi_pil = new_labi_pil.transpose(
                            Image.FLIP_LEFT_RIGHT)
                    new_labi = np.array(new_labi_pil.convert('L'))
                    label[ci, zz, :, :] = new_labi.astype(int)

                # The cost map is scaled to 0..255 for PIL and scaled back
                # afterwards (values are quantized to 8 bits on the way).
                cmap = costmap[zz, :, :]
                cmap_pil = Image.fromarray(np.uint8(255 * (cmap / cost_scale)))
                new_cmap_pil = cmap_pil.rotate(deg, resample=Image.NEAREST)
                if flip_flag < 0.5:
                    new_cmap_pil = new_cmap_pil.transpose(
                        Image.FLIP_LEFT_RIGHT)
                new_cmap = np.array(new_cmap_pil.convert('L'))
                costmap[zz, :, :] = cost_scale * (new_cmap / 255.0)

            for zz in range(raw.shape[1]):
                for ci in range(raw.shape[0]):
                    str_im = raw[ci, zz, :, :]
                    # NOTE(review): the uint8 round trip presumes raw is
                    # normalized to [0, 1] -- confirm against the writer.
                    str_im_pil = Image.fromarray(np.uint8(str_im * 255))
                    new_str_im_pil = str_im_pil.rotate(deg,
                                                       resample=Image.BICUBIC)
                    if flip_flag < 0.5:
                        new_str_im_pil = new_str_im_pil.transpose(
                            Image.FLIP_LEFT_RIGHT)
                    new_str_image = np.array(new_str_im_pil.convert('L'))
                    raw[ci, zz, :, :] = (new_str_image.astype(float)) / 255.0
            new_patch_num = 0

            while new_patch_num < num_patch_per_img[img_idx]:

                # random origin in label space (randint is inclusive)
                pz = random.randint(0, label.shape[1] - size_out[0])
                py = random.randint(0, label.shape[2] - size_out[1])
                px = random.randint(0, label.shape[3] - size_out[2])

                # No quality check is applied to the crop at the moment: every
                # sampled location is accepted.
                ref_patch_cmap = costmap[pz:pz + size_out[0],
                                         py:py + size_out[1],
                                         px:px + size_out[2]]

                (self.img).append(raw[:, pz:pz + size_in[0],
                                      py:py + size_in[1], px:px + size_in[2]])
                (self.gt).append(label[:, pz:pz + size_out[0],
                                       py:py + size_out[1],
                                       px:px + size_out[2]])
                (self.cmap).append(ref_patch_cmap)

                new_patch_num += 1
예제 #11
0
    def execute(self, args):
        """Interactively sort segmentations, then build the training data.

        Part 1 walks over the rows of ``args.csv_name``. Rows whose ``score``
        is already 0 or 1 are skipped. For the others the raw image and its
        segmentation are shown through ``gt_sorting``; the resulting score
        (1 = keep, anything else is stored as 0) is written back, and for
        kept images the user may draw an exclusion mask, which is saved to
        ``args.mask_path`` and recorded in the ``mask`` column. The CSV is
        re-saved after every processed row.

        Part 2 writes, for every row with ``score == 1``, the normalized
        image, the binary label and the cost map to ``args.train_path`` as
        ``img_NNN.ome.tif``, ``img_NNN_GT.ome.tif`` and
        ``img_NNN_CM.ome.tif``, continuing the numbering of existing files.

        Args:
            args: namespace providing ``csv_name``, ``input_channel``,
                ``mask_path`` and ``train_path`` (plus whatever
                ``input_normalization`` reads).
        """
        # draw_mask is a module-level array read below after create_mask();
        # presumably create_mask fills it -- verify against its definition.
        global draw_mask
        # part 1: do sorting
        df = pd.read_csv(args.csv_name, index_col=False)

        for index, row in df.iterrows():

            if not np.isnan(row['score']) and (row['score'] == 1 or row['score'] == 0):
                continue

            reader = AICSImage(row['raw'])
            struct_img = reader.get_image_data("ZYX", S=0, T=0, C=args.input_channel)
            struct_img[struct_img > 5000] = struct_img.min()  # adjust contrast
            # rescale to 0..255 for display
            raw_img = (struct_img - struct_img.min() + 1e-8) / (struct_img.max() - struct_img.min() + 1e-8)
            raw_img = 255 * raw_img
            raw_img = raw_img.astype(np.uint8)

            seg = np.squeeze(imread(row['seg']))

            score = gt_sorting(raw_img, seg)
            if score == 1:
                # Use a single .loc assignment: the chained form
                # df['score'].iloc[index] = ... may write to a temporary copy
                # and is a no-op under pandas copy-on-write.
                df.loc[index, 'score'] = 1
                need_mask = input('Do you need to add a mask for this image, enter y or n:  ')
                if need_mask == 'y':
                    create_mask(raw_img, seg.astype(np.uint8))
                    # NOTE(review): [:-5] strips exactly five characters,
                    # i.e. it presumes a ".tiff"-length extension.
                    mask_fn = args.mask_path + os.sep + os.path.basename(row['raw'])[:-5] + '_mask.tiff'
                    crop_mask = np.zeros(seg.shape, dtype=np.uint8)
                    # replicate the 2-D drawing over every z slice
                    for zz in range(crop_mask.shape[0]):
                        crop_mask[zz, :, :] = draw_mask[:crop_mask.shape[1], :crop_mask.shape[2]]

                    crop_mask = crop_mask.astype(np.uint8)
                    crop_mask[crop_mask > 0] = 255
                    with OmeTiffWriter(mask_fn) as writer:
                        writer.save(crop_mask)
                    df.loc[index, 'mask'] = mask_fn
            else:
                df.loc[index, 'score'] = 0

            # persist progress after every row so a session can be resumed
            df.to_csv(args.csv_name, index=False)

        #########################################
        # generate training data:
        #  (we want to do this step after "sorting"
        #  (is mainly because we want to get the sorting
        #  step as smooth as possible, even though
        #  this may waste i/o time on reloading images)
        # #######################################
        print('finish merging, start building the training data ...')

        existing_files = glob(args.train_path + os.sep + 'img_*.ome.tif')
        print(len(existing_files))

        # each sample is stored as three files (image, _GT, _CM)
        training_data_count = len(existing_files) // 3

        for index, row in df.iterrows():
            if row['score'] == 1:
                training_data_count += 1

                # load raw image
                reader = AICSImage(row['raw'])
                img = reader.get_image_data("CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32)
                struct_img = input_normalization(img, args)
                struct_img = struct_img[0, :, :, :]

                # load segmentation gt
                seg = np.squeeze(imread(row['seg'])) > 0.01
                seg = seg.astype(np.uint8)
                seg[seg > 0] = 1

                # cost map: 0 inside the drawn mask, 1 elsewhere
                cmap = np.ones(seg.shape, dtype=np.float32)
                if os.path.isfile(str(row['mask'])):
                    # load exclusion mask
                    mask = np.squeeze(imread(row['mask']))
                    cmap[mask > 0] = 0

                out_prefix = args.train_path + os.sep + 'img_' + f'{training_data_count:03}'

                with OmeTiffWriter(out_prefix + '.ome.tif') as writer:
                    writer.save(struct_img)

                with OmeTiffWriter(out_prefix + '_GT.ome.tif') as writer:
                    writer.save(seg)

                with OmeTiffWriter(out_prefix + '_CM.ome.tif') as writer:
                    writer.save(cmap)

        print('training data is ready')
예제 #12
0
def test_image(data_dir):
    """Read ``cell_30827.tiff`` from *data_dir* and return it with all
    singleton axes squeezed out."""
    return np.squeeze(imread(data_dir / "cell_30827.tiff"))
    def train(self):
        """Run the training loop with periodic validation and checkpointing.

        Everything is driven by ``self.config``:

        * ``loss``: only the ``'Aux'`` loss is supported.
        * ``validation``: a metric must be configured; ``leaveout`` selects
          the hold-out images, either as a single fraction in (0, 1) or as
          explicit indices into the sorted ``*_GT.ome.tif`` file list.
        * ``loader``: ``'default'`` or ``'focus'`` patch loader.

        Any unsupported setting prints a message and terminates the process
        through ``quit()``. Checkpoints are written with ``save_checkpoint``
        using ``self.device`` and ``self.logger``.
        """

        ### load settings ###
        config = self.config  #TODO, fix this
        model = self.model

        # define loss
        #TODO, add more loss
        loss_config = config['loss']
        if loss_config['name'] == 'Aux':
            criterion = MultiAuxillaryElementNLLLoss(
                3, loss_config['loss_weight'], config['nclass'])
        else:
            print('do not support other loss yet')
            quit()

        # dataloader
        validation_config = config['validation']
        loader_config = config['loader']
        # a bare function object is used as a simple attribute container for
        # the inference settings assigned further below
        args_inference = lambda: None
        if validation_config['metric'] is not None:
            print('prepare the data ... ...')
            filenames = glob(loader_config['datafolder'] + '/*_GT.ome.tif')
            filenames.sort()
            total_num = len(filenames)
            LeaveOut = validation_config['leaveout']
            # NOTE(review): an empty LeaveOut matches neither branch below,
            # leaving valid_idx / train_idx unbound (NameError later on).
            if len(LeaveOut) == 1:
                if LeaveOut[0] > 0 and LeaveOut[0] < 1:
                    # a single fraction: hold out that share of the images,
                    # chosen at random
                    num_train = int(np.floor((1 - LeaveOut[0]) * total_num))
                    shuffled_idx = np.arange(total_num)
                    random.shuffle(shuffled_idx)
                    train_idx = shuffled_idx[:num_train]
                    valid_idx = shuffled_idx[num_train:]
                else:
                    # a single value outside (0, 1) is used as the index of
                    # the one image to hold out
                    valid_idx = [int(LeaveOut[0])]
                    train_idx = list(
                        set(range(total_num)) - set(map(int, LeaveOut)))
            elif LeaveOut:
                # several values: each one is the index of a held-out image
                valid_idx = list(map(int, LeaveOut))
                train_idx = list(set(range(total_num)) - set(valid_idx))

            valid_filenames = []
            train_filenames = []
            # [:-11] drops the trailing '_GT.ome.tif' (11 characters) so that
            # only the common per-image prefix is kept
            for fi, fn in enumerate(valid_idx):
                valid_filenames.append(filenames[fn][:-11])
            for fi, fn in enumerate(train_idx):
                train_filenames.append(filenames[fn][:-11])

            # settings handed to model_inference during validation
            args_inference.size_in = config['size_in']
            args_inference.size_out = config['size_out']
            args_inference.OutputCh = validation_config['OutputCh']
            args_inference.nclass = config['nclass']

        else:
            #TODO, update here
            print('need validation')
            quit()

        # pick the patch loader class; the name train_loader is reused below
        # when the training data is re-sampled
        if loader_config['name'] == 'default':
            from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader
            train_set_loader = DataLoader(
                train_loader(train_filenames, loader_config['PatchPerBuffer'],
                             config['size_in'], config['size_out']),
                num_workers=loader_config['NumWorkers'],
                batch_size=loader_config['batch_size'],
                shuffle=True)
        elif loader_config['name'] == 'focus':
            from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader
            train_set_loader = DataLoader(
                train_loader(train_filenames, loader_config['PatchPerBuffer'],
                             config['size_in'], config['size_out']),
                num_workers=loader_config['NumWorkers'],
                batch_size=loader_config['batch_size'],
                shuffle=True)
        else:
            print('other loader not support yet')
            quit()

        num_iterations = 0
        num_epoch = 0  #TODO: load num_epoch from checkpoint

        start_epoch = num_epoch
        # note: the upper bound is inclusive, so starting from 0 this runs
        # config['epochs'] + 1 passes over the data
        for _ in range(start_epoch, config['epochs'] + 1):

            # sets the model in training mode
            model.train()

            # NOTE(review): a new Adam optimizer is constructed on every
            # epoch, so no optimizer state is carried over from the previous
            # epoch -- confirm this is intended.
            optimizer = None
            optimizer = optim.Adam(model.parameters(),
                                   lr=config['learning_rate'],
                                   weight_decay=config['weight_decay'])

            # check if a re-load of the training data is needed: every
            # 'epoch_shuffle' epochs a fresh loader (and thus a fresh set of
            # patches) is built
            if num_epoch > 0 and num_epoch % loader_config[
                    'epoch_shuffle'] == 0:
                print('shuffling data')
                train_set_loader = None
                train_set_loader = DataLoader(
                    train_loader(train_filenames,
                                 loader_config['PatchPerBuffer'],
                                 config['size_in'], config['size_out']),
                    num_workers=loader_config['NumWorkers'],
                    batch_size=loader_config['batch_size'],
                    shuffle=True)

            # Training starts ...
            epoch_loss = []

            for i, current_batch in tqdm(enumerate(train_set_loader)):

                inputs = Variable(current_batch[0].cuda())
                targets = current_batch[1]
                outputs = model(inputs)

                # move the target(s) to the GPU; a single-element target
                # collection is unwrapped to that one element
                if len(targets) > 1:
                    for zidx in range(len(targets)):
                        targets[zidx] = Variable(targets[zidx].cuda())
                else:
                    targets = Variable(targets[0].cuda())

                optimizer.zero_grad()
                if len(current_batch) == 3:  # input + target + cmap
                    cmap = Variable(current_batch[2].cuda())
                    loss = criterion(outputs, targets, cmap)
                else:  # input + target
                    loss = criterion(outputs, targets)

                loss.backward()
                optimizer.step()

                epoch_loss.append(loss.data.item())
                num_iterations += 1

            # NOTE(review): raises ZeroDivisionError if the loader yielded no
            # batch in this epoch
            average_training_loss = sum(epoch_loss) / len(epoch_loss)

            # validation
            if num_epoch % validation_config['validate_every_n_epoch'] == 0:
                # one accumulator per model output; despite its name it sums
                # the values returned by compute_iou
                validation_loss = np.zeros(
                    (len(validation_config['OutputCh']) // 2, ))
                model.eval()

                for img_idx, fn in enumerate(valid_filenames):

                    # target
                    label = np.squeeze(imread(fn + '_GT.ome.tif'))
                    label = np.expand_dims(label, axis=0)

                    # input image
                    input_img = np.squeeze(imread(fn + '.ome.tif'))
                    if len(input_img.shape) == 3:
                        # add channel dimension
                        input_img = np.expand_dims(input_img, axis=0)
                    elif len(input_img.shape) == 4:
                        # assume number of channel < number of Z, make sure channel dim comes first
                        if input_img.shape[0] > input_img.shape[1]:
                            input_img = np.transpose(input_img, (1, 0, 2, 3))

                    # cmap tensor
                    costmap = np.squeeze(imread(fn + '_CM.ome.tif'))

                    # output
                    outputs = model_inference(model, input_img,
                                              model.final_activation,
                                              args_inference)

                    assert len(
                        validation_config['OutputCh']) // 2 == len(outputs)

                    # each prediction is thresholded at 0.5 and compared with
                    # the voxels whose label equals OutputCh[2 * vi + 1]
                    # (presumably OutputCh is a flat list of
                    # (channel, label value) pairs -- verify against config)
                    for vi in range(len(outputs)):
                        if label.shape[
                                0] == 1:  # the same label for all output
                            validation_loss[vi] += compute_iou(
                                outputs[vi][0, :, :, :] > 0.5,
                                label[0, :, :, :] ==
                                validation_config['OutputCh'][2 * vi + 1],
                                costmap)
                        else:
                            validation_loss[vi] += compute_iou(
                                outputs[vi][0, :, :, :] > 0.5,
                                label[vi, :, :, :] ==
                                validation_config['OutputCh'][2 * vi + 1],
                                costmap)

                # mean over the validation images, one value per output
                average_validation_loss = validation_loss / len(
                    valid_filenames)
                print(
                    f'Epoch: {num_epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}'
                )
            else:
                print(
                    f'Epoch: {num_epoch}, Training Loss: {average_training_loss}'
                )

            # periodic checkpoint of model and optimizer state
            if num_epoch % config['save_every_n_epoch'] == 0:
                save_checkpoint(
                    {
                        'epoch': num_epoch,
                        'num_iterations': num_iterations,
                        'model_state_dict': model.state_dict(),
                        #'best_val_score': self.best_val_score,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'device': str(self.device),
                    },
                    checkpoint_dir=config['checkpoint_dir'],
                    logger=self.logger)
            num_epoch += 1
    def __init__(self, filenames, num_patch, size_in, size_out):
        """Pre-extract randomly placed training patches from a set of images.

        For every entry ``fn`` of *filenames* the files ``fn + '.ome.tif'``
        (input), ``fn + '_GT.ome.tif'`` (label) and ``fn + '_CM.ome.tif'``
        (cost map) are read. The extracted patches are collected in
        ``self.img``, ``self.gt`` and ``self.cmap``.

        Parameters
        ----------
        filenames : list of str
            Common path prefixes of the training images. The list is
            shuffled IN PLACE.
        num_patch : int
            Total number of patches to extract over all images.
        size_in : sequence of 3 ints
            Size of an input patch along the three spatial axes.
        size_out : sequence of 3 ints
            Size of a label / cost-map patch along the three spatial axes.

        Notes
        -----
        Images that are assigned no patch (this happens when there are more
        images than requested patches) are not read from disk at all.
        Patch positions are drawn with ``random.randint`` and no candidate
        position is ever rejected.
        """

        self.img = []
        self.gt = []
        self.cmap = []

        # per-axis margin by which an input patch exceeds an output patch on
        # each side; the input image is padded by this amount below
        padding = [(x - y) // 2 for x, y in zip(size_in, size_out)]

        num_data = len(filenames)
        shuffle(filenames)
        num_patch_per_img = np.zeros((num_data, ), dtype=int)
        if num_data >= num_patch:
            # at most one patch per image: the first num_patch images of the
            # shuffled list get one patch, the remaining images get none
            num_patch_per_img[:num_patch] = 1
        else:
            basic_num = num_patch // num_data
            # assign each image the same number of patches to extract
            num_patch_per_img[:] = basic_num

            # assign one more patch to the first few images to achieve the
            # total patch number
            num_patch_per_img[:num_patch - basic_num * num_data] += 1

        for img_idx, fn in enumerate(filenames):

            # nothing to extract from this image: skip the costly file reads
            # and padding whose result would be thrown away
            if num_patch_per_img[img_idx] == 0:
                continue

            label = np.squeeze(imread(fn + '_GT.ome.tif'))
            label = np.expand_dims(label, axis=0)

            input_img = np.squeeze(imread(fn + '.ome.tif'))
            if len(input_img.shape) == 3:
                # add channel dimension
                input_img = np.expand_dims(input_img, axis=0)
            elif len(input_img.shape) == 4:
                # assume number of channel < number of Z, make sure channel
                # dim comes first
                if input_img.shape[0] > input_img.shape[1]:
                    input_img = np.transpose(input_img, (1, 0, 2, 3))

            costmap = np.squeeze(imread(fn + '_CM.ome.tif'))

            # the last two axes are padded by mirroring, the second axis is
            # padded with constant values (np.pad's default constant)
            img_pad0 = np.pad(input_img,
                              ((0, 0), (0, 0), (padding[1], padding[1]),
                               (padding[2], padding[2])), 'symmetric')
            raw = np.pad(img_pad0,
                         ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)),
                         'constant')

            for _ in range(num_patch_per_img[img_idx]):

                # corner of the output patch in label coordinates; because
                # the input was padded, the same corner addresses the larger
                # input patch in `raw`
                pz = random.randint(0, label.shape[1] - size_out[0])
                py = random.randint(0, label.shape[2] - size_out[1])
                px = random.randint(0, label.shape[3] - size_out[2])

                (self.img).append(raw[:, pz:pz + size_in[0],
                                      py:py + size_in[1], px:px + size_in[2]])
                (self.gt).append(label[:, pz:pz + size_out[0],
                                       py:py + size_out[1],
                                       px:px + size_out[2]])
                (self.cmap).append(costmap[pz:pz + size_out[0],
                                           py:py + size_out[1],
                                           px:px + size_out[2]])
from fov_processing_pipeline import data
import aicsimageio
from aicsimageio import writers
import numpy as np

# cell-level and FOV-level tables returned by the pipeline's data module
cell_data_parent, fov_data_parent = data.get_data()

# destination folder of the generated test resources
save_dir = "./fov_processing_pipeline/tests/resources"

# restrict the test data to the first two cells
cell_data = cell_data_parent.iloc[0:2]

# every column whose name contains "ReadPath"
im_columns = []
for column in cell_data.columns:
    if "ReadPath" in column:
        im_columns.append(column)

im = aicsimageio.imread(cell_data["MembraneSegmentationReadPath"][0])

# voxels that carry the label of either of the two selected cells
region = np.logical_or(im == cell_data["CellId"][0],
                       im == cell_data["CellId"][1])

# [min, max] index of that region along every axis of the image
region_bounds = [[axis_index.min(), axis_index.max()]
                 for axis_index in np.where(region)]

# one output file per image column
im_save_paths = [f"{save_dir}/{im_column}.tiff" for im_column in im_columns]

for im_column, im_save_path in zip(im_columns, im_save_paths):
    im = aicsimageio.imread(cell_data[im_column][0])

    region_slice = tuple(
예제 #16
0
def computeMetricsDict(selected_cell, cell_metrics_dir, AB_mode,
                       num_angular_compartments):
    """
    Compute the metrics dictionary of one 3D cell image and pickle it to disk

    Parameters
    ----------
    selected_cell : Pandas Series object describing the selected cell; the
        entries "Path", "CellId", "PixelScaleX/Y/Z", "ch_dna", "ch_memb",
        "ch_struct", "ch_seg_nuc", "ch_seg_cell" and "Structure" are read
    cell_metrics_dir : Path to the folder in which the pickled dictionary
        is stored
    AB_mode : str
        "quadrants" if AB compartments should split the cell into quadrants,
        "hemispheres" if AB compartments should split the cell into halves.
    num_angular_compartments : int
        The number of equal-size angles the cell should be split into for the
        angular compartment analysis.

    Returns
    -------
    pfile: Path
        Path to the pickle file holding the computed metrics dictionary
    """

    image_path = selected_cell["Path"]
    cellid = selected_cell["CellId"]

    # first try the path exactly as given (absolute path); if the file does
    # not exist, retry relative to ROOT_DIR and let a second
    # FileNotFoundError propagate to the caller
    try:
        im = np.squeeze(imread(image_path))
    except FileNotFoundError:
        im = np.squeeze(imread(ROOT_DIR + image_path))

    # physical size of a voxel along each axis
    scale_x = selected_cell["PixelScaleX"]
    scale_y = selected_cell["PixelScaleY"]
    scale_z = selected_cell["PixelScaleZ"]
    vol_scale_factor = scale_x * scale_y * scale_z
    # pixel scale factors stored in (z,y,x) order
    scale_factors = np.array([scale_z, scale_y, scale_x])

    # channel indices, looked up under the same names in the cell record
    channel_keys = ("ch_dna", "ch_memb", "ch_struct", "ch_seg_nuc",
                    "ch_seg_cell")
    channel_indices = {key: selected_cell[key] for key in channel_keys}

    # segmentations and the corresponding masked raw channels
    segmented = applySegmentationMasks(im, channel_indices)
    seg_dna, seg_mem, seg_gfp, dna, mem, gfp = segmented
    masked_channels = dict(
        seg_dna=seg_dna,
        seg_mem=seg_mem,
        seg_gfp=seg_gfp,
        dna=dna,
        mem=mem,
        gfp=gfp,
    )

    # z metrics
    cutoffs = findVerticalCutoffs(im, masked_channels)
    (bot_of_cell, bot_of_nucleus, centroid_of_nucleus, top_of_nucleus,
     top_of_cell) = cutoffs
    z_metrics = dict(
        bot_of_cell=bot_of_cell,
        bot_of_nucleus=bot_of_nucleus,
        centroid_of_nucleus=centroid_of_nucleus,
        top_of_nucleus=top_of_nucleus,
        top_of_cell=top_of_cell,
    )

    # fold change metrics for the AB and the angular compartments
    ab_result = findFoldChange_AB(masked_channels,
                                  z_metrics,
                                  vol_scale_factor,
                                  mode=AB_mode,
                                  silent=True)
    AB_fold_changes, AB_cyto_vol, AB_gfp_intensities = ab_result

    ang_result = findFoldChange_Angular(masked_channels,
                                        z_metrics,
                                        vol_scale_factor,
                                        num_sections=num_angular_compartments,
                                        silent=True)
    Ang_fold_changes, Ang_cyto_vol, Ang_gfp_intensities = ang_result

    # (nx4) voxel matrix
    voxel_matrix = compute_voxel_matrix(scale_factors, centroid_of_nucleus,
                                        masked_channels)

    # assemble all metrics
    structure = selected_cell["Structure"]
    cell_volume = np.sum(seg_mem > 0) * vol_scale_factor
    cell_height = (top_of_cell - bot_of_cell) * scale_z
    nucleus_volume = np.sum(seg_dna > 0) * vol_scale_factor
    nucleus_height = (top_of_nucleus - bot_of_nucleus) * scale_z
    seg_shape = seg_dna.shape

    metric = {
        "structure": structure,
        "vol_cell": cell_volume,
        "height_cell": cell_height,
        "vol_nucleus": nucleus_volume,
        "height_nucleus": nucleus_height,
        "min_z_dna": bot_of_nucleus,
        "max_z_dna": top_of_nucleus,
        "min_z_cell": bot_of_cell,
        "max_z_cell": top_of_cell,
        "nuclear_centroid": centroid_of_nucleus,
        "total_dna_intensity": np.sum(dna),
        "total_mem_intensity": np.sum(mem),
        "total_gfp_intensity": np.sum(gfp),
        "AB_mode": AB_mode,
        "AB_fold_changes": AB_fold_changes,
        "AB_cyto_vol": AB_cyto_vol,
        "AB_gfp_intensities": AB_gfp_intensities,
        "num_angular_compartments": num_angular_compartments,
        "Ang_fold_changes": Ang_fold_changes,
        "Ang_cyto_vol": Ang_cyto_vol,
        "Ang_gfp_intensities": Ang_gfp_intensities,
        "x_dim": seg_shape[2],
        "y_dim": seg_shape[1],
        "z_dim": seg_shape[0],
        "scale_factors": scale_factors,
        "voxel_matrix": voxel_matrix,
        "channels": masked_channels,
    }

    # write the pickle, removing any file left over from a previous run
    pfile = cell_metrics_dir / f"cell_{cellid}.pickle"
    if pfile.is_file():
        pfile.unlink()
    with open(pfile, "wb") as handle:
        pickle.dump(metric, handle)

    return pfile
예제 #17
0
def single_cell_gen_one_fov(
    row_index: int,
    row: pd.Series,
    single_cell_dir: Path,
    per_fov_dir: Path,
    overwrite: bool = False,
) -> List:
    ########################################
    # parameters
    ########################################
    # Don't use dask for image reading
    aicsimageio.use_dask(False)
    standard_res_qcb = 0.108

    print(f"ready to process FOV: {row.FOVId}")

    ########################################
    # check if results already exist
    ########################################
    this_fov_path = per_fov_dir / Path(str(row.FOVId))
    tag_file = this_fov_path / "done.txt"
    bad_tag_file = this_fov_path / "bad.txt"
    single_fov_csv = this_fov_path / "fov_meta.csv"
    cells_in_fov_csv = this_fov_path / "cell_meta.csv"
    if this_fov_path.exists():
        if overwrite:
            rmtree(this_fov_path)
            os.mkdir(this_fov_path)
        else:
            if tag_file.exists():
                # this fov has been fully processed, simply return
                # the path to fov csv and cells csv
                return [single_fov_csv, cells_in_fov_csv]
            else:
                if bad_tag_file.exists():
                    # this fov is known to be a bad one, no need to re-run
                    return [
                        row.FOVId, False, "bad FOV, check text file for detail"
                    ]
                else:
                    # this fov has only been partially processed, wipe out
                    rmtree(this_fov_path)
                    os.mkdir(this_fov_path)
    else:
        os.mkdir(this_fov_path)

    ########################################
    # load image and segmentation
    ########################################
    """
    if row.AlignedImageReadPath is None:
        raw_fn = row.SourceReadPath
    else:
        raw_fn = row.AlignedImageReadPath
    """
    # SourceReadPath should be always available
    raw_fn = row.SourceReadPath

    # verify filepaths
    if not (os.path.exists(raw_fn)
            and os.path.exists(row.MembraneSegmentationReadPath)
            and os.path.exists(row.StructureSegmentationReadPath)):
        # fail
        return [row.FOVId, True, "missing segmentation or raw files"]

    raw_reader = AICSImage(raw_fn)
    if raw_reader.shape[0] > 1:  # multi-scene
        return [row.FOVId, False, "multi scene image"]

    try:
        # get the raw image and split into different channels
        raw_data = np.squeeze(raw_reader.data)
        raw_mem0 = raw_data[int(row.ChannelNumber638), :, :, :]
        raw_nuc0 = raw_data[int(row.ChannelNumber405), :, :, :]
        raw_struct0 = raw_data[int(row.ChannelNumberStruct), :, :, :]

        assert row.MembraneSegmentationReadPath == row.NucleusSegmentationReadPath
        seg_reader = AICSImage(row.MembraneSegmentationReadPath)
        nuc_seg_whole = seg_reader.get_image_data("ZYX", S=0, T=0, C=0)
        mem_seg_whole = seg_reader.get_image_data("ZYX", S=0, T=0, C=1)

        assert (mem_seg_whole.shape[0] == raw_mem0.shape[0]
                and mem_seg_whole.shape[1] == raw_mem0.shape[1]
                and mem_seg_whole.shape[2]
                == raw_mem0.shape[2]), "raw and seg dim mismatch"

        assert ((not np.any(raw_mem0 < 1)) and (not np.any(raw_nuc0 < 1))
                and (not np.any(raw_struct0 < 1))
                ), "one z frame is blank, ignore this FOV"

        # get structure segmentation
        struct_seg_whole = np.squeeze(imread(
            row.StructureSegmentationReadPath))
        print(f"Segmentation load successfully: {row.FOVId}")
    except (Exception, AssertionError) as e:
        return [row.FOVId, True, e]

    # make a copy to be used for calculating true edge cell labels
    mem_seg_whole_copy = mem_seg_whole.copy()

    #########################################
    # run single cell qc in this fov
    #########################################
    ######################
    min_mem_size = 70000
    min_nuc_size = 10000
    ######################

    # flag for any segmented object in this FOV removed as bad cells
    full_fov_pass = 1

    # double check big failure, quick reject
    if mem_seg_whole.max() <= 3 or nuc_seg_whole.max() <= 3:
        # bad images, but not bug, use "False"
        with open(bad_tag_file, "w") as f:
            f.write("very few cells segmented")
        return [row.FOVId, False, "very few cells segmented"]

    # prune the results (remove cells touching image boundary)
    boundary_mask = np.zeros_like(mem_seg_whole)
    boundary_mask[:, :3, :] = 1
    boundary_mask[:, -3:, :] = 1
    boundary_mask[:, :, :3] = 1
    boundary_mask[:, :, -3:] = 1
    bd_idx = list(np.unique(mem_seg_whole[boundary_mask > 0]))

    # maintain a valid cell list, initialize with all cells minus
    # cells touching the image boundary, and minus cells with
    # no record in labkey (e.g., manually removed based on user's feedback)
    all_cell_index_list = list(np.unique(mem_seg_whole[mem_seg_whole > 0]))
    full_set_with_no_boundary = set(all_cell_index_list) - set(bd_idx)
    set_not_in_labkey = full_set_with_no_boundary - set(
        row.index_to_id_dict[0].keys())
    valid_cell = list(full_set_with_no_boundary - set_not_in_labkey)

    # single cell QC
    valid_cell_0 = valid_cell.copy()
    for list_idx, this_cell_index in enumerate(valid_cell):
        single_mem = mem_seg_whole == this_cell_index
        single_nuc = nuc_seg_whole == this_cell_index

        # remove too small cells from valid cell list
        if (np.count_nonzero(single_mem) < min_mem_size
                or np.count_nonzero(single_nuc) < min_nuc_size):
            valid_cell_0.remove(this_cell_index)
            full_fov_pass = 0

            # no need to go to next QC criteria
            continue

        # make sure the cell is not leaking to the bottom or top
        z_range_single = np.where(np.any(single_mem, axis=(1, 2)))
        single_min_z = z_range_single[0][0]
        single_max_z = z_range_single[0][-1]

        if single_min_z == 0 or single_max_z >= single_mem.shape[0] - 1:
            valid_cell_0.remove(this_cell_index)
            full_fov_pass = 0
    valid_cell = valid_cell_0.copy()

    # if only one cell left or no cell left, just throw it away
    if len(valid_cell_0) < 2:
        with open(bad_tag_file, "w") as f:
            f.write("very few cells left after single cell QC")
        return [row.FOVId, False, "very few cells left after single cell QC"]

    print(f"single cell QC done in FOV: {row.FOVId}")

    #################################################################
    # resize the image into isotropic dimension
    #################################################################
    raw_nuc = resize(
        raw_nuc0,
        (
            row.PixelScaleZ / standard_res_qcb,
            row.PixelScaleY / standard_res_qcb,
            row.PixelScaleX / standard_res_qcb,
        ),
        method="bilinear",
    ).astype(np.uint16)

    raw_mem = resize(
        raw_mem0,
        (
            row.PixelScaleZ / standard_res_qcb,
            row.PixelScaleY / standard_res_qcb,
            row.PixelScaleX / standard_res_qcb,
        ),
        method="bilinear",
    ).astype(np.uint16)

    raw_str = resize(
        raw_struct0,
        (
            row.PixelScaleZ / standard_res_qcb,
            row.PixelScaleY / standard_res_qcb,
            row.PixelScaleX / standard_res_qcb,
        ),
        method="bilinear",
    ).astype(np.uint16)

    mem_seg_whole = resize_to(mem_seg_whole, raw_mem.shape, method="nearest")
    nuc_seg_whole = resize_to(nuc_seg_whole, raw_nuc.shape, method="nearest")
    struct_seg_whole = resize_to(struct_seg_whole,
                                 raw_str.shape,
                                 method="nearest")

    #################################################################
    # calculate fov related info
    #################################################################
    index_to_cellid_map = dict()
    cellid_to_index_map = dict()
    for list_idx, this_cell_index in enumerate(valid_cell):
        # this is always valid since indices not in index_to_id_dict.keys()
        # have been removed
        index_to_cellid_map[this_cell_index] = row.index_to_id_dict[0][
            this_cell_index]
    for index_dict, cellid_dict in index_to_cellid_map.items():
        cellid_to_index_map[cellid_dict] = index_dict

    # compute center of mass
    index_to_centroid_map = dict()
    center_list = center_of_mass(nuc_seg_whole > 0, nuc_seg_whole, valid_cell)
    for list_idx, this_cell_index in enumerate(valid_cell):
        index_to_centroid_map[this_cell_index] = center_list[list_idx]

    # compute whole stack min/max z
    mem_seg_whole_valid = np.zeros_like(mem_seg_whole)
    for list_idx, this_cell_index in enumerate(valid_cell):
        mem_seg_whole_valid[mem_seg_whole == this_cell_index] = this_cell_index
    z_range_whole = np.where(np.any(mem_seg_whole_valid, axis=(1, 2)))
    stack_min_z = z_range_whole[0][0]
    stack_max_z = z_range_whole[0][-1]

    # find true edge cells, the cells in the outer layer of a colony
    true_edge_cells = []
    edge_fov_flag = False
    if row.ColonyPosition is None:
        # parse colony position from file name
        reg = re.compile("(-|_)((\d)?)(e)((\d)?)(-|_)")  # noqa: W605
        if reg.search(os.path.basename(raw_fn)):
            edge_fov_flag = True
    else:
        if row.ColonyPosition.lower() == "edge":
            edge_fov_flag = True

    if edge_fov_flag:
        true_edge_cells = find_true_edge_cells(mem_seg_whole_copy)

    #################################################################
    # calculate a dictionary to store FOV info
    #################################################################
    fov_meta = {
        "FOVId": row.FOVId,
        "structure_name": row.Gene,
        "position": row.ColonyPosition,
        "raw_fn": raw_fn,
        "str_filename": row.StructureSegmentationReadPath,
        "mem_seg_fn": row.MembraneSegmentationReadPath,
        "nuc_seg_fn": row.NucleusSegmentationReadPath,
        "index_to_id_dict": index_to_cellid_map,
        "id_to_index_dict": cellid_to_index_map,
        "xy_res": row.PixelScaleX,
        "z_res": row.PixelScaleZ,
        "stack_min_z": stack_min_z,
        "stack_max_z": stack_max_z,
        "scope_id": row.InstrumentId,
        "well_id": row.WellId,
        "well_name": row.WellName,
        "plateId": row.PlateId,
        "passage": row.Passage,
        "image_size": list(raw_mem.shape),
        "fov_seg_pass": full_fov_pass,
        "imaging_mode": row.ImagingMode,
    }

    df_fov_meta = pd.DataFrame([fov_meta])
    df_fov_meta.to_csv(single_fov_csv, header=True, index=False)
    print(f"FOV info is done: {row.FOVId}, ready to loop through cells")

    # loop through all valid cells in this fov
    cell_meta = []
    for list_idx, this_cell_index in enumerate(valid_cell):
        nuc_seg = nuc_seg_whole == this_cell_index
        mem_seg = mem_seg_whole == this_cell_index

        ###########################
        # implement nbr info
        ###########################
        single_mem_dilate = dilation(mem_seg, selem=ball(3))
        whole_template = mem_seg_whole.copy()
        whole_template[mem_seg] = 0
        this_cell_nbr_candiate_list = list(
            np.unique(whole_template[single_mem_dilate > 0]))
        this_cell_nbr_dist_3d = []
        this_cell_nbr_dist_2d = []
        this_cell_nbr_overlap_area = []
        this_cell_nbr_complete = 1

        for nbr_index, nbr_id in enumerate(this_cell_nbr_candiate_list):
            if nbr_id == 0 or nbr_id == this_cell_index:
                continue
            elif not (nbr_id in valid_cell):
                this_cell_nbr_complete = 0
                continue

            # only do calculation for valid neighbors
            nuc_dist_3d = euc_dist_3d(index_to_centroid_map[nbr_id],
                                      index_to_centroid_map[this_cell_index])
            nuc_dist_2d = euc_dist_2d(index_to_centroid_map[nbr_id],
                                      index_to_centroid_map[this_cell_index])
            overlap = overlap_area(mem_seg, mem_seg_whole == nbr_id)
            this_cell_nbr_dist_3d.append(
                (index_to_cellid_map[nbr_id], nuc_dist_3d))
            this_cell_nbr_dist_2d.append(
                (index_to_cellid_map[nbr_id], nuc_dist_2d))
            this_cell_nbr_overlap_area.append(
                (index_to_cellid_map[nbr_id], overlap))
        if len(this_cell_nbr_dist_3d) == 0:
            this_cell_nbr_complete = 0

        # get cell id
        cell_id = index_to_cellid_map[this_cell_index]

        # make the path for saving single cell crop result
        thiscell_path = single_cell_dir / Path(str(cell_id))
        if os.path.isdir(thiscell_path):
            rmtree(thiscell_path)
        os.mkdir(thiscell_path)

        ###############################
        # compute and  generate crop
        ###############################
        # determine crop roi
        z_range = np.where(np.any(mem_seg, axis=(1, 2)))
        y_range = np.where(np.any(mem_seg, axis=(0, 2)))
        x_range = np.where(np.any(mem_seg, axis=(0, 1)))
        z_range = z_range[0]
        y_range = y_range[0]
        x_range = x_range[0]

        # define a large ROI based on bounding box
        roi = [
            max(z_range[0] - 10, 0),
            min(z_range[-1] + 12, mem_seg.shape[0]),
            max(y_range[0] - 40, 0),
            min(y_range[-1] + 40, mem_seg.shape[1]),
            max(x_range[0] - 40, 0),
            min(x_range[-1] + 40, mem_seg.shape[2]),
        ]

        # roof augmentation
        mem_nearly_top_z = int(z_range[0] +
                               round(0.75 * (z_range[-1] - z_range[0] + 1)))
        mem_top_mask = np.zeros(mem_seg.shape, dtype=np.byte)
        mem_top_mask[
            mem_nearly_top_z:, :, :] = mem_seg[mem_nearly_top_z:, :, :] > 0
        mem_top_mask_dilate = dilation(mem_top_mask > 0,
                                       selem=np.ones((21, 1, 1),
                                                     dtype=np.byte))
        mem_top_mask_dilate[:mem_nearly_top_z, :, :] = (
            mem_seg[:mem_nearly_top_z, :, :] > 0)

        # crop mem/nuc seg
        mem_seg = mem_seg.astype(np.uint8)
        mem_seg = mem_seg[roi[0]:roi[1], roi[2]:roi[3], roi[4]:roi[5]]
        mem_seg[mem_seg > 0] = 255

        nuc_seg = nuc_seg.astype(np.uint8)
        nuc_seg = nuc_seg[roi[0]:roi[1], roi[2]:roi[3], roi[4]:roi[5]]
        nuc_seg[nuc_seg > 0] = 255

        mem_top_mask_dilate = mem_top_mask_dilate.astype(np.uint8)
        mem_top_mask_dilate = mem_top_mask_dilate[roi[0]:roi[1], roi[2]:roi[3],
                                                  roi[4]:roi[5]]
        mem_top_mask_dilate[mem_top_mask_dilate > 0] = 255

        # crop str seg (without roof augmentation)
        str_seg_crop = struct_seg_whole[roi[0]:roi[1], roi[2]:roi[3],
                                        roi[4]:roi[5]].astype(np.uint8)
        str_seg_crop[mem_seg < 1] = 0
        str_seg_crop[str_seg_crop > 0] = 255

        # crop str seg (with roof augmentation)
        str_seg_crop_roof = struct_seg_whole[roi[0]:roi[1], roi[2]:roi[3],
                                             roi[4]:roi[5]].astype(np.uint8)
        str_seg_crop_roof[mem_top_mask_dilate < 1] = 0
        str_seg_crop_roof[str_seg_crop_roof > 0] = 255

        # merge and save the cropped segmentation
        all_seg = np.stack(
            [
                nuc_seg, mem_seg, mem_top_mask_dilate, str_seg_crop,
                str_seg_crop_roof
            ],
            axis=0,
        )
        all_seg = np.expand_dims(np.transpose(all_seg, (1, 0, 2, 3)), axis=0)

        crop_seg_path = thiscell_path / "segmentation.ome.tif"
        writer = save_tif.OmeTiffWriter(crop_seg_path, overwrite_file=True)
        writer.save(all_seg)

        # crop raw image
        raw_nuc_thiscell = raw_nuc[roi[0]:roi[1], roi[2]:roi[3], roi[4]:roi[5]]
        raw_mem_thiscell = raw_mem[roi[0]:roi[1], roi[2]:roi[3], roi[4]:roi[5]]
        raw_str_thiscell = raw_str[roi[0]:roi[1], roi[2]:roi[3], roi[4]:roi[5]]
        crop_raw_merged = np.expand_dims(
            np.stack((raw_nuc_thiscell, raw_mem_thiscell, raw_str_thiscell),
                     axis=1),
            axis=0,
        )

        crop_raw_path = thiscell_path / "raw.ome.tif"
        writer = save_tif.OmeTiffWriter(crop_raw_path, overwrite_file=True)
        writer.save(crop_raw_merged)

        ############################
        # check for pair
        ############################
        dist_cutoff = 85
        dna_label, dna_num = label(nuc_seg > 0, return_num=True)

        if dna_num < 2:
            # cannot be a pair with fewer than two connected components
            this_cell_is_pair = 0
        else:
            stats = regionprops(dna_label)
            region_size = [stats[i]["area"] for i in range(dna_num)]
            large_two = sorted(range(len(region_size)),
                               key=lambda sub: region_size[sub])[-2:]
            dis = euc_dist_3d(stats[large_two[0]]["centroid"],
                              stats[large_two[1]]["centroid"])
            if dis > dist_cutoff:
                sz1 = stats[large_two[0]]["area"]
                sz2 = stats[large_two[1]]["area"]
                if sz1 / sz2 > 1.5625 or sz1 / sz2 < 0.64:
                    # the two parts do not have comparable sizes
                    this_cell_is_pair = 0
                else:
                    this_cell_is_pair = 1
            else:
                # the two largest components are not far enough apart
                this_cell_is_pair = 0

        name_dict = {
            "crop_raw": ["dna", "membrane", "structure"],
            "crop_seg": [
                "dna_segmentation",
                "membrane_segmentation",
                "membrane_segmentation_roof",
                "struct_segmentation",
                "struct_segmentation_roof",
            ],
        }

        # out for mitotic classifier
        img_out = build_one_cell_for_classification(crop_raw_merged, mem_seg)
        out_fn = thiscell_path / "for_mito_prediction.npy"
        if out_fn.exists():
            os.remove(out_fn)
        np.save(out_fn, img_out)

        #########################################
        if len(true_edge_cells) > 0 and (this_cell_index in true_edge_cells):
            this_is_edge_cell = 1
        else:
            this_is_edge_cell = 0

        # write qcb cell meta
        cell_meta.append({
            "CellId": cell_id,
            "structure_name": row.Gene,
            "pair": this_cell_is_pair,
            "this_cell_nbr_complete": this_cell_nbr_complete,
            "this_cell_nbr_dist_3d": this_cell_nbr_dist_3d,
            "this_cell_nbr_dist_2d": this_cell_nbr_dist_2d,
            "this_cell_nbr_overlap_area": this_cell_nbr_overlap_area,
            "roi": roi,
            "crop_raw": crop_raw_path,
            "crop_seg": crop_seg_path,
            "name_dict": name_dict,
            "scale_micron": [0.108333, 0.108333, 0.108333],
            "edge_flag": this_is_edge_cell,
            "fov_id": row.FOVId,
            "fov_path": raw_fn,
            "fov_seg_path": row.MembraneSegmentationReadPath,
            "struct_seg_path": row.StructureSegmentationReadPath,
            "this_cell_index": this_cell_index,
            "stack_min_z": stack_min_z,
            "stack_max_z": stack_max_z,
            "image_size": list(raw_mem.shape),
            "plateId": row.PlateId,
            "position": row.ColonyPosition,
            "scope_id": row.InstrumentId,
            "well_id": row.WellId,
            "well_name": row.WellName,
            "passage": row.Passage,
            "imaging_mode": row.ImagingMode,
        })
        print(f"Cell {cell_id} is done")

    df_cell_meta = pd.DataFrame(cell_meta)
    df_cell_meta.to_csv(cells_in_fov_csv, header=True, index=False)

    # single-cell generation succeeded for this FOV
    print(f"FOV {row.FOVId} is done")
    with open(tag_file, "w") as f:
        f.write("all cells completed")
    return [single_fov_csv, cells_in_fov_csv]