def execute(self, args):
        """Build training triplets (image, ground truth, cost map) from folders.

        Every ``*<args.data_type>`` file in ``args.raw_path`` (sorted by name)
        is paired with ``<stem>_struct_segmentation.tiff`` in ``args.seg_path``
        and, if present, ``<stem>_mask.tiff`` in ``args.mask_path``. For each
        one, ``img_NNN.ome.tif`` (normalized input channel),
        ``img_NNN_GT.ome.tif`` (segmentation > 0.01 as 0/1) and
        ``img_NNN_CM.ome.tif`` (cost map) are written to ``args.train_path``.
        Numbering continues after the samples already present there.

        Note: ``args.data_type`` is modified in place to start with ``'.'``.
        """
        if not args.data_type.startswith('.'):
            args.data_type = '.' + args.data_type

        filenames = sorted(glob(args.raw_path + os.sep + '*' + args.data_type))

        # Every existing sample contributes three files matching this pattern
        # (image, _GT and _CM), so new samples are numbered after them.
        existing_files = glob(args.train_path + os.sep + 'img_*.ome.tif')
        print(len(existing_files))
        training_data_count = len(existing_files) // 3

        for fn in filenames:
            training_data_count += 1
            stem = os.path.basename(fn)[:-len(args.data_type)]

            # load raw: the selected channel is kept as a length-1 C axis
            reader = AICSImage(fn)
            struct_img = reader.get_image_data(
                "CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32)
            # Normalize the image that was just loaded (this previously
            # referenced an undefined name ``img`` and raised NameError).
            struct_img = input_normalization(struct_img, args)

            # load seg and binarize it to {0, 1}
            seg_fn = args.seg_path + os.sep + stem + '_struct_segmentation.tiff'
            reader = AICSImage(seg_fn)
            seg = (reader.get_image_data("ZYX", S=0, T=0, C=0) > 0.01).astype(
                np.uint8)

            # excluding mask: voxels where the mask is 0 get zero weight
            cmap = np.ones(seg.shape, dtype=np.float32)
            mask_fn = args.mask_path + os.sep + stem + '_mask.tiff'
            if os.path.isfile(mask_fn):
                reader = AICSImage(mask_fn)
                mask = reader.get_image_data("ZYX", S=0, T=0, C=0)
                cmap[mask == 0] = 0

            out_prefix = (args.train_path + os.sep + 'img_' +
                          f'{training_data_count:03}')
            with OmeTiffWriter(out_prefix + '.ome.tif') as writer:
                writer.save(struct_img)
            with OmeTiffWriter(out_prefix + '_GT.ome.tif') as writer:
                writer.save(seg)
            with OmeTiffWriter(out_prefix + '_CM.ome.tif') as writer:
                writer.save(cmap)
Example #2
0
    def execute(self, args):
        """Interactively merge two candidate segmentations, then build training data.

        Part 1 (sorting): every row of ``args.csv_name`` whose ``score`` is not
        already 0 or 1 is shown to the user through ``create_merge_mask``,
        together with the raw channel and the ``seg1``/``seg2`` candidates.
        That helper presumably sets the module globals ``ignore_img`` and
        ``draw_mask`` that are read here (TODO confirm). Ignored images get
        score 0. Accepted ones get score 1 and a merging mask written under
        ``args.mask_path``; on request they also get an excluding mask under
        ``args.ex_mask_path``. The CSV is rewritten after every processed row.

        Part 2: for every row with score 1, three files are written to
        ``args.train_path``, numbered after the samples already there:

        * ``img_NNN.ome.tif`` -- the normalized input channel;
        * ``img_NNN_GT.ome.tif`` -- seg1 outside the merging mask combined
          with seg2 inside it, as 0/1;
        * ``img_NNN_CM.ome.tif`` -- the cost map, 0 inside the excluding mask.
        """
        global draw_mask, ignore_img
        df = pd.read_csv(args.csv_name, index_col=False)

        def _set_cell(label, column, value):
            # Write through ``df.loc``. The former chained form
            # ``df[col].iloc[i] = v`` may update only a temporary copy, and
            # never updates ``df`` under pandas copy-on-write. A mask column
            # that is still empty is read as all-NaN float64, so upcast it to
            # object before storing a path string in it.
            if isinstance(value, str) and df[column].dtype != object:
                df[column] = df[column].astype(object)
            df.loc[label, column] = value

        def _load_seg(path, threshold):
            # Binarize a segmentation whose singleton axis (presumably the
            # channel axis) sits at position 1 or 2 of ``AICSImage.data``.
            data = AICSImage(path).data
            assert data.shape[0] == 1
            assert data.shape[1] == 1 or data.shape[2] == 1
            if data.shape[1] == 1:
                return data[0, 0, :, :, :] > threshold
            return data[0, :, 0, :, :] > threshold

        def _save_drawn_mask(mask_fn, shape):
            # Copy the drawn mask (cropped to the Y/X extent) into every
            # z-slice and save it as uint8 with 255 marking drawn pixels.
            crop_mask = np.zeros(shape, dtype=np.uint8)
            for zz in range(shape[0]):
                crop_mask[zz, :, :] = draw_mask[:shape[1], :shape[2]]
            crop_mask[crop_mask > 0] = 255
            writer = omeTifWriter.OmeTifWriter(mask_fn)
            writer.save(crop_mask)

        # part 1: do sorting
        for index, row in df.iterrows():
            if not np.isnan(row['score']) and (row['score'] == 1
                                               or row['score'] == 0):
                continue  # already sorted in an earlier session

            # 8-bit display copy of the input channel
            reader = AICSImage(row['raw'])
            struct_img = reader.data[0, args.input_channel, :, :, :]
            raw_img = (struct_img - struct_img.min() +
                       1e-8) / (struct_img.max() - struct_img.min() + 1e-8)
            raw_img = (255 * raw_img).astype(np.uint8)

            seg1 = _load_seg(row['seg1'], 0.1)
            seg2 = _load_seg(row['seg2'], 0)

            create_merge_mask(raw_img, seg1.astype(np.uint8),
                              seg2.astype(np.uint8), 'merging_mask')

            if ignore_img:
                _set_cell(index, 'score', 0)
            else:
                _set_cell(index, 'score', 1)

                # [:-5] strips a 5-character extension such as '.tiff'
                stem = os.path.basename(row['raw'])[:-5]
                mask_fn = args.mask_path + os.sep + stem + '_mask.tiff'
                _save_drawn_mask(mask_fn, seg1.shape)
                _set_cell(index, 'merging_mask', mask_fn)

                need_mask = input(
                    'Do you need to add an excluding mask for this image, enter y or n:  '
                )
                if need_mask == 'y':
                    create_merge_mask(raw_img, seg1.astype(np.uint8),
                                      seg2.astype(np.uint8), 'excluding mask')
                    mask_fn = args.ex_mask_path + os.sep + stem + '_mask.tiff'
                    _save_drawn_mask(mask_fn, seg1.shape)
                    _set_cell(index, 'excluding_mask', mask_fn)

            # save progress after every image so an interrupted session resumes
            df.to_csv(args.csv_name, index=False)

        #########################################
        # generate training data:
        #  (this is done after "sorting" mainly to keep the interactive
        #  sorting step as smooth as possible, even though it may waste
        #  i/o time on reloading images)
        #########################################
        print('finish merging, start building the training data ...')

        # Every existing sample contributes three files matching this pattern.
        existing_files = glob(args.train_path + os.sep + 'img_*.ome.tif')
        print(len(existing_files))
        training_data_count = len(existing_files) // 3

        for index, row in df.iterrows():
            if row['score'] != 1:
                continue
            training_data_count += 1

            # load raw image and normalize the selected channel
            reader = AICSImage(row['raw'])
            img = reader.data.astype(np.float32)
            struct_img = input_normalization(
                img[0, [args.input_channel], :, :, :], args)
            struct_img = struct_img[0, :, :, :]

            seg1 = _load_seg(row['seg1'], 0.1)
            seg2 = _load_seg(row['seg2'], 0)

            if os.path.isfile(str(row['merging_mask'])):
                img = AICSImage(row['merging_mask']).data
                assert img.shape[0] == 1 and img.shape[1] == 1
                mask = img[0, 0, :, :, :] > 0
                # keep seg1 outside the merging mask and seg2 inside it
                seg1[mask] = 0
                seg2[~mask] = 0
                seg1 = np.logical_or(seg1, seg2)

            # cost map: zero weight inside the excluding mask
            cmap = np.ones(seg1.shape, dtype=np.float32)
            if os.path.isfile(str(row['excluding_mask'])):
                img = AICSImage(row['excluding_mask']).data
                assert img.shape[0] == 1 and img.shape[1] == 1
                cmap[img[0, 0, :, :, :] > 0] = 0

            out_prefix = (args.train_path + os.sep + 'img_' +
                          f'{training_data_count:03}')
            writer = omeTifWriter.OmeTifWriter(out_prefix + '.ome.tif')
            writer.save(struct_img)

            writer = omeTifWriter.OmeTifWriter(out_prefix + '_GT.ome.tif')
            writer.save(seg1.astype(np.uint8))  # bool -> 0/1

            writer = omeTifWriter.OmeTifWriter(out_prefix + '_CM.ome.tif')
            writer.save(cmap)
        print('training data is ready')
Example #3
0
    def execute(self, args):
        """Interactively sort candidate segmentations, then build training data.

        Part 1 (sorting): every row of ``args.csv_name`` whose ``score`` is not
        already 0 or 1 is shown via ``gt_sorting``. The display uses the raw
        channel with values above 5000 replaced by the image minimum. The
        returned verdict is stored as score 1 (accepted) or 0. For accepted
        images the user may draw a mask with ``create_mask``, which presumably
        sets the global ``draw_mask`` read here (TODO confirm). The mask is
        saved under ``args.mask_path`` and its path is recorded in the
        ``mask`` column. The CSV is rewritten after every processed row.

        Part 2: for every row with score 1, three files are written to
        ``args.train_path``, numbered after the samples already there:

        * ``img_NNN.ome.tif`` -- the normalized input channel;
        * ``img_NNN_GT.ome.tif`` -- segmentation > 0.01 as 0/1;
        * ``img_NNN_CM.ome.tif`` -- the cost map, 0 where the mask is > 0.
        """
        global draw_mask
        df = pd.read_csv(args.csv_name, index_col=False)

        def _set_cell(label, column, value):
            # Write through ``df.loc``. The former chained form
            # ``df[col].iloc[i] = v`` may update only a temporary copy, and
            # never updates ``df`` under pandas copy-on-write. A mask column
            # that is still empty is read as all-NaN float64, so upcast it to
            # object before storing a path string in it.
            if isinstance(value, str) and df[column].dtype != object:
                df[column] = df[column].astype(object)
            df.loc[label, column] = value

        # part 1: do sorting
        for index, row in df.iterrows():
            if not np.isnan(row['score']) and (row['score'] == 1
                                               or row['score'] == 0):
                continue  # already sorted in an earlier session

            reader = AICSImage(row['raw'])
            struct_img = reader.get_image_data("ZYX", S=0, T=0,
                                               C=args.input_channel)
            struct_img[struct_img > 5000] = struct_img.min()  # adjust contrast
            raw_img = (struct_img - struct_img.min() +
                       1e-8) / (struct_img.max() - struct_img.min() + 1e-8)
            raw_img = (255 * raw_img).astype(np.uint8)

            seg = np.squeeze(imread(row['seg']))

            score = gt_sorting(raw_img, seg)
            if score == 1:
                _set_cell(index, 'score', 1)
                need_mask = input(
                    'Do you need to add a mask for this image, enter y or n:  ')
                if need_mask == 'y':
                    create_mask(raw_img, seg.astype(np.uint8))
                    # [:-5] strips a 5-character extension such as '.tiff'
                    mask_fn = (args.mask_path + os.sep +
                               os.path.basename(row['raw'])[:-5] +
                               '_mask.tiff')
                    # copy the drawn mask (cropped to Y/X) into every z-slice
                    crop_mask = np.zeros(seg.shape, dtype=np.uint8)
                    for zz in range(crop_mask.shape[0]):
                        crop_mask[zz, :, :] = draw_mask[:crop_mask.shape[1],
                                                        :crop_mask.shape[2]]
                    crop_mask[crop_mask > 0] = 255
                    with OmeTiffWriter(mask_fn) as writer:
                        writer.save(crop_mask)
                    _set_cell(index, 'mask', mask_fn)
            else:
                _set_cell(index, 'score', 0)

            # save progress after every image so an interrupted session resumes
            df.to_csv(args.csv_name, index=False)

        #########################################
        # generate training data:
        #  (this is done after "sorting" mainly to keep the interactive
        #  sorting step as smooth as possible, even though it may waste
        #  i/o time on reloading images)
        #########################################
        print('finish merging, start building the training data ...')

        # Every existing sample contributes three files matching this pattern.
        existing_files = glob(args.train_path + os.sep + 'img_*.ome.tif')
        print(len(existing_files))
        training_data_count = len(existing_files) // 3

        for index, row in df.iterrows():
            if row['score'] != 1:
                continue
            training_data_count += 1

            # load raw image (selected channel kept as a length-1 C axis)
            reader = AICSImage(row['raw'])
            img = reader.get_image_data(
                "CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32)
            struct_img = input_normalization(img, args)[0, :, :, :]

            # load segmentation gt as {0, 1}
            seg = (np.squeeze(imread(row['seg'])) > 0.01).astype(np.uint8)

            # cost map: zero weight wherever the drawn mask is set
            cmap = np.ones(seg.shape, dtype=np.float32)
            if os.path.isfile(str(row['mask'])):
                mask = np.squeeze(imread(row['mask']))
                cmap[mask > 0] = 0

            out_prefix = (args.train_path + os.sep + 'img_' +
                          f'{training_data_count:03}')
            with OmeTiffWriter(out_prefix + '.ome.tif') as writer:
                writer.save(struct_img)
            with OmeTiffWriter(out_prefix + '_GT.ome.tif') as writer:
                writer.save(seg)
            with OmeTiffWriter(out_prefix + '_CM.ome.tif') as writer:
                writer.save(cmap)

        print('training data is ready')
Example #4
0
    def execute(self, args):
        """Assemble training triplets from raw images and their segmentations.

        Every file in ``args.raw_path`` ending in ``args.data_type`` is
        processed in sorted order. A leading ``'.'`` is added to
        ``args.data_type`` in place if it is missing. Each file is paired with
        ``<stem>_struct_segmentation.tiff`` in ``args.seg_path`` and an
        optional ``<stem>_mask.tiff`` in ``args.mask_path``. Three files are
        written to ``args.train_path`` per sample, with NNN continuing after
        the samples already there:

        * ``img_NNN.ome.tif`` -- the normalized input channel;
        * ``img_NNN_GT.ome.tif`` -- segmentation > 0 as 0/1;
        * ``img_NNN_CM.ome.tif`` -- the cost map, 0 where the mask is 0.
        """
        suffix = args.data_type
        if not suffix.startswith('.'):
            suffix = '.' + suffix
            args.data_type = suffix

        sources = sorted(glob(args.raw_path + os.sep + '*' + suffix))

        # each stored sample owns three files matching this pattern
        already_written = glob(args.train_path + os.sep + 'img_*.ome.tif')
        print(len(already_written))
        sample_id = len(already_written) // 3

        for raw_fn in sources:
            sample_id += 1
            stem = os.path.basename(raw_fn)[:-len(suffix)]

            # raw image: drop the leading singleton axis; if axis 0 is longer
            # than axis 1, swap them (heuristic: fewer channels than Z slices)
            volume = AICSImage(raw_fn).data.astype(np.float32)
            assert volume.shape[0] == 1
            volume = volume[0, :, :, :, :]
            if volume.shape[0] > volume.shape[1]:
                volume = np.transpose(volume, (1, 0, 2, 3))
            struct_img = input_normalization(
                volume[[args.input_channel], :, :, :], args)

            # ground truth: any positive value counts as foreground
            seg_data = AICSImage(args.seg_path + os.sep + stem +
                                 '_struct_segmentation.tiff').data
            assert seg_data.shape[0] == 1 and seg_data.shape[1] == 1
            ground_truth = (seg_data[0, 0, :, :, :] > 0).astype(np.uint8)

            # cost map: voxels where the optional mask is 0 get zero weight
            cost_map = np.ones(ground_truth.shape, dtype=np.float32)
            mask_fn = args.mask_path + os.sep + stem + '_mask.tiff'
            if os.path.isfile(mask_fn):
                mask_data = AICSImage(mask_fn).data
                assert mask_data.shape[0] == 1 and mask_data.shape[1] == 1
                cost_map[mask_data[0, 0, :, :, :] == 0] = 0

            prefix = args.train_path + os.sep + 'img_' + f'{sample_id:03}'
            outputs = (
                ('.ome.tif', struct_img),
                ('_GT.ome.tif', ground_truth),
                ('_CM.ome.tif', cost_map),
            )
            for tail, payload in outputs:
                omeTifWriter.OmeTifWriter(prefix + tail).save(payload)
Example #5
0
def main():
    """Segment an image file or a folder of images with a trained model.

    The single ``--config`` command-line argument names a config file (read
    with ``load_config``). It provides the model definition and checkpoint
    (``model_path``), the normalization and inference settings, and the run
    mode in ``config['mode']``:

    * ``name == 'file'``: segment ``InputFile``. With ``timelapse`` set, every
      index along axis 0 is segmented separately (the array is assumed to be
      TCZYX). Otherwise only index 0 along that axis is used.
    * ``name == 'folder'``: segment every ``InputDir/*<DataType>`` file.

    Results are written to ``config['OutputDir']``. When ``ResizeRatio`` is
    non-empty, the input is resampled by it before inference. Every output is
    then resampled back by the inverse ratio, before thresholding when a
    threshold is applied.

    Raises:
        ValueError: if ``config['mode']['name']`` is neither 'file' nor
            'folder'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    args = parser.parse_args()

    config = load_config(args.config)

    # declare the model and load the trained model instance
    model = build_model(config)
    model_path = config['model_path']
    print(f'Loading model from {model_path}...')
    load_checkpoint(model_path, model)

    # parameters for preparing the input image (a bare function object is
    # used as a simple attribute container)
    args_norm = lambda: None
    args_norm.Normalization = config['Normalization']

    # parameters for running the model inference
    args_inference = lambda: None
    args_inference.size_in = config['size_in']
    args_inference.size_out = config['size_out']
    args_inference.OutputCh = config['OutputCh']
    args_inference.nclass = config['nclass']

    resize_ratio = config['ResizeRatio']
    do_resize = resize_ratio is not None and len(resize_ratio) > 0
    threshold = config['Threshold']
    output_dir = config['OutputDir']
    output_ch = config['OutputCh']
    # OutputCh is read in pairs; entry 2*i labels output i. A single pair
    # means one output, saved as "<stem>_struct_segmentation.tiff".
    single_output = len(output_ch) == 2

    def _stem(fn):
        # pathlib.Path honours the host separator; PurePosixPath would keep
        # the directory part of a Windows path inside the "stem".
        return pathlib.Path(fn).stem

    def _rescale(arr):
        # Min-max rescale to [0, 1]. A constant array would give 0/0 = NaN,
        # so it is mapped to zeros instead.
        lo = arr.min()
        span = arr.max() - lo
        if span == 0:
            return np.zeros(arr.shape,
                            dtype=np.result_type(arr.dtype, np.float32))
        return (arr - lo) / span

    def _to_czyx(img):
        # Heuristic: assumes fewer channels than Z slices, so a longer axis 0
        # is taken to mean the array is ZCYX and the first two axes are swapped.
        if img.shape[1] < img.shape[0]:
            img = np.transpose(img, (1, 0, 2, 3))
        return img

    def _prepare(img):
        # Normalize the selected channels; when resizing, resample with cubic
        # interpolation and rescale each channel back to [0, 1].
        img = input_normalization(img, args_norm)
        if do_resize:
            img = resize(img, (1, resize_ratio[0], resize_ratio[1],
                               resize_ratio[2]),
                         method='cubic')
            for ch_idx in range(img.shape[0]):
                img[ch_idx, :, :, :] = _rescale(img[ch_idx, :, :, :])
        return img

    def _restore_size(out):
        # Undo the input resampling so the result matches the input grid.
        if do_resize:
            out = resize(out, (1.0, 1 / resize_ratio[0], 1 / resize_ratio[1],
                               1 / resize_ratio[2]),
                         method='cubic')
        return out

    def _binarize(mask):
        # boolean mask -> uint8 with 255 for foreground
        mask = mask.astype(np.uint8)
        mask[mask > 0] = 255
        return mask

    def _save(path, data):
        writer = omeTifWriter.OmeTifWriter(path)
        writer.save(data)

    def _channel_suffix(ch_idx):
        return '_seg_' + str(output_ch[2 * ch_idx])

    def _file_mode_output(out):
        # rescale -> undo resize -> float32 -> binarize if Threshold > 0
        out = _restore_size(_rescale(out)).astype(np.float32)
        if threshold > 0:
            out = _binarize(out > threshold)
        return out

    def _write_file_mode(output_img, base):
        if single_output:
            _save(base + '_struct_segmentation.tiff',
                  _file_mode_output(output_img[0]))
        else:
            for ch_idx in range(len(output_ch) // 2):
                _save(base + _channel_suffix(ch_idx) + '.tiff',
                      _file_mode_output(output_img[ch_idx]))

    def _folder_mode_output(out, denoise):
        if threshold < 0:
            # probability map: rescale, undo resize, rescale again
            return _rescale(_restore_size(_rescale(out)).astype(np.float32))
        # Undo the resize before thresholding so the mask matches the input
        # grid (previously skipped in folder mode).
        mask = _restore_size(out) > threshold
        if denoise:
            mask = remove_small_objects(mask, min_size=2, connectivity=1)
        return _binarize(mask)

    def _write_folder_mode(output_img, base):
        if single_output:
            _save(base + '_struct_segmentation.tiff',
                  _folder_mode_output(output_img[0], denoise=True))
        else:
            for ch_idx in range(len(output_ch) // 2):
                _save(base + _channel_suffix(ch_idx) + '.ome.tif',
                      _folder_mode_output(output_img[ch_idx], denoise=False))

    # run
    inf_config = config['mode']
    if inf_config['name'] == 'file':
        fn = inf_config['InputFile']
        img0 = AICSImage(fn).data

        if inf_config['timelapse']:
            assert img0.shape[0] > 1
            for tt in range(img0.shape[0]):
                # Assume:  dimensions = TCZYX
                img = _prepare(img0[tt, config['InputCh'], :, :, :].astype(float))
                output_img = model_inference(model, img,
                                             model.final_activation,
                                             args_inference)
                _write_file_mode(output_img, output_dir + os.sep + _stem(fn) +
                                 '_T_' + f'{tt:03}')
        else:
            img = img0[0, :, :, :, :].astype(float)
            print(f'processing one image of size {img.shape}')
            img = _prepare(_to_czyx(img)[config['InputCh'], :, :, :])
            output_img = model_inference(model, img, model.final_activation,
                                         args_inference)
            _write_file_mode(output_img, output_dir + os.sep + _stem(fn))
            print(f'Image {fn} has been segmented')

    elif inf_config['name'] == 'folder':
        from glob import glob
        filenames = sorted(
            glob(inf_config['InputDir'] + '/*' + inf_config['DataType']))

        for fn in filenames:
            img = AICSImage(fn).data[0, :, :, :, :].astype(float)
            img = _prepare(_to_czyx(img)[config['InputCh'], :, :, :])
            output_img = model_inference(model, img, model.final_activation,
                                         args_inference)
            _write_folder_mode(output_img, output_dir + os.sep + _stem(fn))
            print(f'Image {fn} has been segmented')

    else:
        raise ValueError(f"unknown inference mode {inf_config['name']!r}; "
                         "expected 'file' or 'folder'")
Example #6
0
def evaluate(args, model):
    """Segment images with ``model`` and write one file per output channel.

    ``args.mode`` selects the input:

    * ``'eval'`` -- every ``args.InputDir/*<args.DataType>`` file. Each is
      loaded with ``load_single_image`` and segmented with ``apply_on_image``.
    * ``'eval_file'`` -- the file ``args.InputFile``, segmented with
      ``model_inference``. With ``args.timelapse`` set, every index along
      axis 0 is segmented separately (assumed TCZYX).

    Outputs are written to ``args.OutputDir`` as
    ``<stem>[_T_NNN]_seg_<ch>.ome.tif``. With ``args.Threshold < 0`` the raw
    channel output is saved as float. Otherwise it is binarized at
    ``args.Threshold`` and saved as uint8 0/255. In ``'eval_file'`` mode with
    a non-empty ``args.ResizeRatio``, outputs are resampled back to the input
    grid.
    """
    import os  # local import: only used to join output paths

    model.eval()
    softmax = nn.Softmax(dim=1)
    softmax.cuda()

    # check validity of parameters
    assert args.nchannel == len(args.InputCh), \
        'number of input channel does not match input channel indices'

    def _stem(fn):
        # pathlib.Path honours the host separator; PurePosixPath would keep
        # the directory part of a Windows path inside the "stem".
        return pathlib.Path(fn).stem

    def _rescale(arr):
        # Min-max rescale to [0, 1]. A constant array would give 0/0 = NaN,
        # so it is mapped to zeros instead.
        lo = arr.min()
        span = arr.max() - lo
        if span == 0:
            return np.zeros(arr.shape,
                            dtype=np.result_type(arr.dtype, np.float32))
        return (arr - lo) / span

    def _prepare(img):
        # Normalize the selected channels; when resizing, resample with cubic
        # interpolation and rescale each channel back to [0, 1].
        img = input_normalization(img, args)
        if len(args.ResizeRatio) > 0:
            img = resize(img, (1, args.ResizeRatio[0], args.ResizeRatio[1],
                               args.ResizeRatio[2]),
                         method='cubic')
            for ch_idx in range(img.shape[0]):
                img[ch_idx, :, :, :] = _rescale(img[ch_idx, :, :, :])
        return img

    def _restore(arr, method):
        # Undo the input resampling so the result matches the input grid.
        return resize(arr, (1.0, 1 / args.ResizeRatio[0],
                            1 / args.ResizeRatio[1], 1 / args.ResizeRatio[2]),
                      method=method)

    def _save_channels(output_img, name, restore=None):
        # Write every requested output channel; ``restore`` (if given) maps
        # the result back to the input grid before saving.
        for ch_idx in range(len(args.OutputCh) // 2):
            out_fn = os.path.join(
                args.OutputDir,
                name + '_seg_' + str(args.OutputCh[2 * ch_idx]) + '.ome.tif')
            writer = omeTifWriter.OmeTifWriter(out_fn)
            if args.Threshold < 0:
                out = output_img[ch_idx].astype(float)
                if restore is not None:
                    out = restore(out, 'cubic')
            else:
                out = output_img[ch_idx] > args.Threshold
                if restore is not None:
                    # nearest-neighbour keeps the mask binary
                    out = restore(out, 'nearest')
                out = out.astype(np.uint8)
                out[out > 0] = 255
            writer.save(out)

    if args.mode == 'eval':

        filenames = sorted(glob.glob(args.InputDir + '/*' + args.DataType))

        for fn in filenames:
            print(fn)
            # load data
            struct_img = load_single_image(args, fn, time_flag=False)
            print(struct_img.shape)

            # apply the model
            output_img = apply_on_image(model, struct_img, softmax, args)
            _save_channels(output_img, _stem(fn))
            print(f'Image {fn} has been segmented')

    elif args.mode == 'eval_file':

        fn = args.InputFile
        print(fn)
        data_reader = AICSImage(fn)
        img0 = data_reader.data
        # Only resize outputs back when the input was actually resized; the
        # timelapse branch previously did so unconditionally (IndexError on an
        # empty ResizeRatio) and the single-image branch never did.
        restore = _restore if len(args.ResizeRatio) > 0 else None

        if args.timelapse:
            assert data_reader.shape[0] > 1

            for tt in range(data_reader.shape[0]):
                # Assume:  TCZYX
                img = _prepare(img0[tt, args.InputCh, :, :, :].astype(float))
                output_img = model_inference(model, img, softmax, args)
                _save_channels(output_img, _stem(fn) + '_T_' + f'{tt:03}',
                               restore)
        else:
            img = img0[0, :, :, :].astype(float)
            # Heuristic: assumes fewer channels than Z slices, so a longer
            # axis 0 is taken to mean ZCYX and the first two axes are swapped.
            if img.shape[1] < img.shape[0]:
                img = np.transpose(img, (1, 0, 2, 3))
            img = _prepare(img[args.InputCh, :, :, :])

            output_img = model_inference(model, img, softmax, args)
            _save_channels(output_img, _stem(fn), restore)

        print(f'Image {fn} has been segmented')