Example #1
 def __init__(self,
              image_search_path,
              mask_search_path,
              output_dir,
              block_size=(100, 100),
              overlap=(0, 0),
              image_suffix=".tif",
              mask_suffix='_mask.tif',
              remove_zero_threshold=-1,
              ignore_padding=0,
              remove_fg_threshold=0.5,
              scale=1.0):
     super(PatchConverter, self).__init__(image_search_path,
                                          mask_search_path,
                                          output_dir,
                                          image_suffix=image_suffix,
                                          mask_suffix=mask_suffix)
     # NB. update block_size, overlap, central_size and padding_size together
     self.block_size = block_size
     self.overlap = overlap
     self.central_size = np.array(self.block_size) - np.array(self.overlap)
     # np.all over a generator is always truthy; use the built-in all()
     assert all(
         item % 2 == 0
         for item in self.overlap), 'overlap must be even integers!'
     # padding_size used to compensate for prediction center cropping
     self.padding_size = (overlap[0] // 2, overlap[1] // 2)
     # Temp folder to store cropped patches; timestamp avoids race conditions
     time_now = datetime.now()
     time_string = time_now.strftime("%Y%m%d-%Hh%M%p%S")
     self.prediction_patch_dir = '/tmp/tmp_patches_' + time_string
     self.remove_zero_threshold = remove_zero_threshold
     self.ignore_padding = ignore_padding
     fileio.maybe_make_new_dir(self.output_dir)
     self.remove_fg_threshold = remove_fg_threshold
     self.scale = scale
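
The even-overlap requirement and the block/overlap arithmetic above are easiest to see with concrete numbers. A minimal construction sketch (the paths are hypothetical placeholders):

# With block_size=(100, 100) and overlap=(20, 20):
#   central_size = (100, 100) - (20, 20) = (80, 80)
#   padding_size = (20 // 2, 20 // 2) = (10, 10)
# i.e. adjacent blocks share a 20-pixel border, and predictions are
# center-cropped by 10 pixels per side to discard border effects.
converter = PatchConverter('/path/to/images/*.tif',
                           '/path/to/masks/*_mask.tif',
                           '/path/to/output',
                           block_size=(100, 100),
                           overlap=(20, 20))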
Example #2
def stratefied_sampling_neg_and_pos(positive_patch_search_path,
                                    negative_patch_search_path,
                                    strata_regex_pattern,
                                    positive_dir,
                                    negative_dir,
                                    output_dir,
                                    max_ratio=2,
                                    seed=42):
    """Sample from positive and negative

    Args:
        positive_patch_search_path:
        negative_patch_search_path:
        strata_regex_pattern:
        positive_dir:
        negative_dir:
        output_path:
        max_ratio:
        seed: random seed for shuffling

    Returns:
        None
    """
    positive_files = glob2.glob(positive_patch_search_path)
    positive_files = [
        file for file in positive_files
        if re.search(strata_regex_pattern, file)
    ]
    negative_files = glob2.glob(negative_patch_search_path)
    negative_files = [
        file for file in negative_files
        if re.search(strata_regex_pattern, file)
    ]
    n_pos = len(positive_files)
    n_neg = len(negative_files)
    # if too many negatives, truncate at max_ratio
    if n_neg > n_pos * max_ratio:
        print('Truncate from {} to {} files'.format(n_neg, n_pos * max_ratio))
        negative_files = sorted(negative_files)
        np.random.seed(seed)
        np.random.shuffle(negative_files)
        negative_files = negative_files[:(n_pos * max_ratio)]
    # copy masks (and their paired images) into output_dir, preserving
    # the relative directory structure under each source root
    for files, source_dir in ((positive_files, positive_dir),
                              (negative_files, negative_dir)):
        for source_file in tqdm(files):
            new_file = source_file.replace(source_dir, output_dir)
            fileio.maybe_make_new_dir(os.path.dirname(new_file))
            shutil.copyfile(source_file, new_file)
            # copy the image paired with this mask
            shutil.copyfile(source_file.replace('_mask', ''),
                            new_file.replace('_mask', ''))
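
A usage sketch for the sampler above (all paths and the regex are hypothetical placeholders):

stratefied_sampling_neg_and_pos(
    positive_patch_search_path='/data/patches/pos/**/*_mask.png',
    negative_patch_search_path='/data/patches/neg/**/*_mask.png',
    strata_regex_pattern=r'scale1\.00',
    positive_dir='/data/patches/pos',
    negative_dir='/data/patches/neg',
    output_dir='/data/patches/balanced',
    max_ratio=2)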
Example #3
    def get_paired_image_and_mask(self, join='inner'):
        """Create paired image and mask files

        Find common files in two folders and create dictionary paired_dict. Each key has two keys:
            'image': path to the image file
            'mask': path to the mask file

        Args:
            join: {'inner', 'outer', 'left', 'right'}
                'inner': find intersection between two lists
                'outer': find union of two lists
                'left': first list (image list)
                'right': second list (mask list)
        Returns:
            paired_dict: paired_dict[key]['image'] and paired_dict[key]['mask'] is a list
        """

        # each key in image_dict and mask_dict corresponds to a list
        if not self.is_vector:
            self.image_dict = {self.get_image_key_fn(filepath): filepath for filepath in self.image_files}
            self.mask_dict = {self.get_mask_key_fn(filepath): filepath for filepath in self.mask_files}
        else:
            self.image_dict = {}
            for filepath in self.image_files:
                key = self.get_image_key_fn(filepath)
                self.image_dict.setdefault(key, []).append(filepath)
            self.mask_dict = {}
            for filepath in self.mask_files:
                key = self.get_mask_key_fn(filepath)
                self.mask_dict.setdefault(key, []).append(filepath)
        if join == 'inner':
            keys = set(self.image_dict.keys()) & set(self.mask_dict.keys())
        elif join == 'outer':
            keys = set(self.image_dict.keys()) | set(self.mask_dict.keys())
        elif join == 'left':
            keys = set(self.image_dict.keys())
        elif join == 'right':
            keys = set(self.mask_dict.keys())
        else:
            raise KeyError('Unsupported join method {}'.format(join))
        if self.is_vector:
            empty_val = []
        else:
            empty_val = ''
        paired_dict = {}
        for key in keys:
            paired_dict[key] = {}
            paired_dict[key]['image'] = self.image_dict.get(key, empty_val)
            paired_dict[key]['mask'] = self.mask_dict.get(key, empty_val)
        logging.debug('paired_dict with length {}'.format(len(paired_dict)))
        filepath = os.path.join(self.output_dir, 'paired_dict.json')
        fileio.maybe_make_new_dir(os.path.dirname(filepath))
        with open(filepath, 'w') as f_out:
            json.dump(paired_dict, f_out, indent=4, sort_keys=True)
        return paired_dict
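
A usage sketch, assuming `converter` is an instance of the enclosing class with image_files and mask_files already populated:

paired = converter.get_paired_image_and_mask(join='inner')
for key, pair in paired.items():
    print(key, pair['image'], pair['mask'])
# With join='left' every image key is kept; an unmatched mask comes back
# as '' (or [] when self.is_vector), so guard with `if pair['mask']:`.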
Example #4
 def _move_files(files_list, subdir='train'):
     # `dry_run` is read from the enclosing scope
     for file_path in files_list:
         new_dir = os.path.join(os.path.dirname(file_path), subdir)
         fileio.maybe_make_new_dir(new_dir)
         new_path = os.path.join(new_dir, os.path.basename(file_path))
         if dry_run:
             logging.debug('move {} to {}'.format(file_path, new_path))
         else:
             shutil.move(file_path, new_path)
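
A usage sketch; since _move_files reads `dry_run` from its enclosing scope, that name must exist there (the glob pattern is a hypothetical placeholder):

dry_run = True  # log the moves without touching the filesystem
_move_files(sorted(glob2.glob('/data/patches/*.png')), subdir='train')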
Example #5
 def deploy(self, clipLimit=2.0, tileGridSize=(8, 8)):
     fileio.maybe_make_new_dir(self.output_dir)
     for image_file in tqdm(sorted(self.image_files)):
         # -1 maps to cv2.IMREAD_UNCHANGED; plt.imread takes no such flag
         image_array = cv2.imread(image_file, -1)
         image_array = self.apply_clahe(image_array,
                                        clipLimit=clipLimit,
                                        tileGridSize=tileGridSize)
         output_file_path = os.path.join(self.output_dir,
                                         os.path.basename(image_file))
         cv2.imwrite(output_file_path, image_array)
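
apply_clahe presumably wraps OpenCV's CLAHE; a minimal standalone equivalent for a single-channel array looks like this:

clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(image_array)  # expects a uint8 or uint16 single-channel image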
Example #6
def log_flags(flags, logdir, log_name='config.json'):
    """Log tf FLAGS to json"""
    fileio.maybe_make_new_dir(logdir)
    config_log = os.path.join(logdir, log_name)
    if flags is None:
        config_dict = {}
    else:
        # for tensorflow 1.5 and above
        if StrictVersion(tf.__version__) >= StrictVersion('1.5.0'):
            flags = FlagsObjectView(flags)
        config_dict = flags.__dict__
    with open(config_log, 'w') as f:
        json.dump(config_dict, f, indent=1, sort_keys=True)
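
A usage sketch with the TF1-era flags object (logdir is a hypothetical placeholder):

log_flags(tf.app.flags.FLAGS, '/tmp/run_logs')  # writes /tmp/run_logs/config.json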
Example #7
def save_annotated_image(image, filename, label='', newdir=''):
    """Save image to newdir, reusing the basename of filename"""
    if newdir:
        basename = os.path.basename(filename)
        if label:
            file, ext = os.path.splitext(basename)
            basename = file + '_' + label + ext
        newfilename = os.path.join(newdir, basename)
    else:
        newfilename = filename
    dirname = os.path.dirname(newfilename)
    fileio.maybe_make_new_dir(dirname)
    cv2.imwrite(newfilename, image)
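
A usage sketch, assuming `overlay` is a grayscale or BGR array (paths are hypothetical placeholders):

save_annotated_image(overlay, '/data/images/case_001.png',
                     label='pred', newdir='/tmp/annotated')
# writes /tmp/annotated/case_001_pred.png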
Example #8
def png2nii(png_filepath, nii_filepath):
    """Convert png to nii format

    Args:
        png_filepath:
        nii_filepath:

    Returns:
        None
    """
    image = sitk.ReadImage(png_filepath)
    # make parent directory otherwise sitk will not write files
    fileio.maybe_make_new_dir(os.path.dirname(nii_filepath))
    sitk.WriteImage(image, nii_filepath)
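
A usage sketch (paths are hypothetical; SimpleITK picks the output format from the file extension):

png2nii('/data/png/slice_000.png', '/data/nii/slice_000.nii.gz')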
Example #9
    def get_paired_image_and_mask(self,
                                  join='inner',
                                  key_names=('image', 'mask')):
        """Create paired image and mask files

        Find common files in two folders and create dictionary paired_dict. Each key has two keys:
            'image': path to the image file
            'mask': path to the mask file

        Args:
            join: {'inner', 'outer', 'left', 'right'}
                'inner': find intersection between two lists
                'outer': find union of two lists
                'left': first list (image list)
                'right': second list (mask list)
        Returns:
            paired_dict: paired_dict[key]['image'] and paired_dict[key]['mask'] is a file
        """
        image_key, mask_key = key_names
        self.image_dict = {
            self.get_image_key_fn(filepath): filepath
            for filepath in self.image_files
        }
        self.mask_dict = {
            self.get_mask_key_fn(filepath): filepath
            for filepath in self.mask_files
        }
        if join == 'inner':
            keys = set(self.image_dict.keys()) & set(self.mask_dict.keys())
        elif join == 'outer':
            keys = set(self.image_dict.keys()) | set(self.mask_dict.keys())
        elif join == 'left':
            keys = set(self.image_dict.keys())
        elif join == 'right':
            keys = set(self.mask_dict.keys())
        else:
            raise KeyError('Unsupported join method {}'.format(join))
        paired_dict = {}
        for key in keys:
            paired_dict[key] = {}
            paired_dict[key][image_key] = self.image_dict.get(key, '')
            paired_dict[key][mask_key] = self.mask_dict.get(key, '')
        logging.debug('paired_dict with length {}'.format(len(paired_dict)))
        filepath = os.path.join(self.output_dir, 'paired_dict.json')
        fileio.maybe_make_new_dir(os.path.dirname(filepath))
        with open(filepath, 'w') as f_out:
            json.dump(paired_dict, f_out, indent=4, sort_keys=True)
        return paired_dict
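
A usage sketch showing the key_names override (the names and `converter` instance are hypothetical):

paired = converter.get_paired_image_and_mask(join='inner',
                                             key_names=('dicom', 'annotation'))
# paired[key]['dicom'] and paired[key]['annotation'] hold single paths,
# and the dict is also dumped to <output_dir>/paired_dict.json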
Example #10
    def visualize_multiple_gt(image_3ch,
                              boxes_dict,
                              masks_dict,
                              class_ids_dict,
                              class_names,
                              key,
                              fig_dir=None,
                              show_orig=False,
                              subplot_size=(16, 16)):
        assert set(boxes_dict.keys()) == set(masks_dict.keys()) == set(
            class_ids_dict.keys())
        n_annotation_series = len(boxes_dict.keys())
        if show_orig:
            # show original image without annotation
            axes = roc.get_ax(1, n_annotation_series + 1, size=subplot_size)
            ax, axes = axes[0], axes[1:]
            empty_array = np.array([])
            visualize.display_instances(image=image_3ch,
                                        boxes=empty_array,
                                        masks=empty_array,
                                        class_ids=empty_array,
                                        class_names=class_names,
                                        show_mask=False,
                                        show_bbox=False,
                                        ax=ax,
                                        title='orig image',
                                        verbose=False)
        else:
            axes = roc.get_ax(1, n_annotation_series)
        series_keys = boxes_dict.keys()
        assert len(axes) == len(series_keys)
        for idx, (ax, series_key) in enumerate(zip(axes, series_keys)):
            # Display GT bbox and mask
            visualize.display_instances(image=image_3ch,
                                        boxes=boxes_dict[series_key],
                                        masks=masks_dict[series_key],
                                        class_ids=class_ids_dict[series_key],
                                        class_names=class_names,
                                        ax=ax,
                                        title=series_key,
                                        verbose=False)

        # Save to model log dir
        fig_dir = fig_dir or '/tmp/tmp/'
        fileio.maybe_make_new_dir(fig_dir)
        fig_path = os.path.join(fig_dir, 'gt_{}.png'.format(key))
        plt.savefig(fig_path, bbox_inches='tight')
        plt.close('all')
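
A usage sketch, assuming Mask R-CNN-style inputs (boxes [N, 4], masks [H, W, N], class_ids [N]); every value here is a hypothetical placeholder:

visualize_multiple_gt(image_3ch=np.dstack([gray_image] * 3),
                      boxes_dict={'reader1': boxes1, 'reader2': boxes2},
                      masks_dict={'reader1': masks1, 'reader2': masks2},
                      class_ids_dict={'reader1': ids1, 'reader2': ids2},
                      class_names=['BG', 'lesion'],
                      key='case_001',
                      fig_dir='/tmp/figs')  # saves /tmp/figs/gt_case_001.png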
Example #11
def dicom2png(dicom_filepath,
              png_filepath=None,
              normalize_functor=scale_to_255,
              dryrun=False):
    """Convert dicom image to png file

    Args:
        dicom_filepath:
        png_filepath:
        normalize_functor: normalizing function

    Returns:
        image_array: a numpy array containing the image
    """
    image_array = data.get_pixel_array_from_dicom_path(dicom_filepath,
                                                       to_bit=-1)
    if normalize_functor:
        image_array = normalize_functor(image_array)
    if image_array.max() <= 1:
        image_array = (image_array * 255).astype(np.uint8)
    if not dryrun:
        fileio.maybe_make_new_dir(os.path.dirname(png_filepath))
        cv2.imwrite(png_filepath, image_array)
    return image_array
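
A usage sketch (paths are hypothetical placeholders):

image_array = dicom2png('/data/dicom/case_001.dcm', '/data/png/case_001.png')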
Example #12
                                 train_txt_path=train_txt_path,
                                 scale=1.0,
                                 ignore_padding=0,
                                 remove_zero_threshold=25).deploy()

    # for a smaller dataset size.
    if FLAGS.task == 'calc_crop_synth':
        image_search_path = r'/data1/Image_data/Mammography_data/INbreast/Calc_synthesis/20190711/synthesis_image/*png'
        if FLAGS.ignore_single_point:
            mask_search_path = r'/data1/Image_data/Mammography_data/INbreast/Calc_synthesis/20190711/synthesis_mask/*png'
        else:
            mask_search_path = r'/data1/Image_data/Mammography_data/INbreast/calc_mask/*png'
        valid_txt_path = r'/data1/Image_data/Mammography_data/INbreast/Calc_synthesis/20190711/valid.txt'
        train_txt_path = r'/data1/Image_data/Mammography_data/INbreast/Calc_synthesis/20190711/train.txt'
        output_dir = '/data1/Image_data/Mammography_data/INbreast/Calc_synthesis/20190711/calc_patches/'
        fileio.maybe_make_new_dir(output_dir)
        MassOrCalcPatchConverter(image_search_path,
                                 mask_search_path,
                                 output_dir,
                                 valid_txt_path=valid_txt_path,
                                 train_txt_path=train_txt_path,
                                 scale=1.0,
                                 ignore_padding=0,
                                 remove_zero_threshold=25).deploy()

    if FLAGS.task == 'calc_cluster_crop':
        image_search_path = r'/media/Data/Data02/Datasets/Mammogram/Ziwei_WIP/calc_cluster/png/*png'
        mask_search_path = r'/media/Data/Data02/Datasets/Mammogram/Ziwei_WIP/calc_cluster/bootstrap_mask_cleaned/*png'
        valid_txt_path = r'/media/Data/Data02/Datasets/Mammogram/Ziwei_WIP/evaluation/valid.txt'
        train_txt_path = r'/media/Data/Data02/Datasets/Mammogram/Ziwei_WIP/evaluation/train.txt'
        output_dir = '/data/log/mammo/calc_cluster_patches'
Example #13
def show_bbox(bbox_dict, raw_image_path_dict, image_pred_path_dict=None, n_demo=0, output_dir='', gt_only=False):
    """Overlay bbox coordinate to original image and show GT and pred side by side.

    Prediction overlay uses prediction probability map if image_pred_path_dict is not None, otherwise use raw image

    Args:
        bbox_dict: a dict with image name as key. Each key corresponds to another dict with the following keys
            'pred_bbox_list': input to is_bbox_list_overlapped()
            'gt_bbox_list': input to is_bbox_list_overlapped()
            'pred_box_correct': output of is_bbox_list_overlapped()
            'gt_box_covered': output of is_bbox_list_overlapped()
        raw_image_path_dict: a dict with image name as key. The corresponding value is the path to the raw image to
            overlay
        image_pred_path_dict: optional, a dict with image name as key. The corresponding value is the path to the pred
            results to overlay. Default to None, and if specified, use it to showcase pred result on the RHS of the
            stack image
        n_demo: number of times to run demo
        gt_only: boolean, whether to show gt only (no prediction) <TODO> Not tested yet

    Returns:
        None
    """
    colors = {
        "tp": (0, 255, 0), # green
        "fp": (255, 0, 0), # blue in BGR
        "fn": (0, 0, 255), # red in BGR
    }
    for idx, key in enumerate(bbox_dict.keys()):
        if key.startswith('_'):
            continue
        pred_bbox_list = bbox_dict[key]['pred_bbox_list']
        is_pred_bbox_correct_list = bbox_dict[key]['pred_box_correct']
        gt_bbox_list = bbox_dict[key]['gt_bbox_list']
        is_gt_bbox_covered_list = bbox_dict[key]['gt_box_covered']
        bbox_list_tp = [bbox for bbox, flag in zip(pred_bbox_list, is_pred_bbox_correct_list) if flag]
        bbox_list_fp = [bbox for bbox, flag in zip(pred_bbox_list, is_pred_bbox_correct_list) if not flag]
        bbox_list_fn = [bbox for bbox, flag in zip(gt_bbox_list, is_gt_bbox_covered_list) if not flag]
        bbox_list_tp_gt = [bbox for bbox, flag in zip(gt_bbox_list, is_gt_bbox_covered_list) if flag]

        image_path = raw_image_path_dict[key]
        image = fileio.load_image_to_array(image_path, np.uint8)
        if image_pred_path_dict:
            image_pred_path = image_pred_path_dict[key]
            # this can be a list of up to 3 elements to populate BGR channels
            if isinstance(image_pred_path, (list, tuple)) and len(image_pred_path) > 1:
                image_overlay = np.dstack([image] * 3)
                for idx_ch, single_pred_path in enumerate(image_pred_path):
                    logging.debug('assembling channel {}'.format(idx_ch))
                    image_pred = fileio.load_image_to_array(single_pred_path, np.uint8)
                    # generate overlay in green channel (low prob in magenta color)
                    logging.debug('before crop_or_pad {} {}'.format(image_pred.shape, image.shape))
                    image_pred = augmentation.crop_or_pad(image_pred, image.shape)
                    logging.debug('after crop_or_pad {} {}'.format(image_pred.shape, image.shape))
                    image_proba = np.where(image_pred > 0, image_pred, image) # as a single channel
                    image_overlay[:, :, idx_ch] = image_proba
                image_pred = image_overlay
            else:
                image_pred = fileio.load_image_to_array(image_pred_path, np.uint8)
                # generate overlay in green channel (low prob in magenta color)
                logging.debug('before crop_or_pad {} {}'.format(image_pred.shape, image.shape))
                image_pred = augmentation.crop_or_pad(image_pred, image.shape)
                logging.debug('after crop_or_pad {} {}'.format(image_pred.shape, image.shape))
                image_proba = np.where(image_pred > 0, image_pred, image) # as a single channel
                image_overlay = np.dstack([image, image_proba, image])
                image_pred = image_overlay
        else:
            image_pred = image
        image_overlay_pred = overlay_bbox_list_on_image(image_pred, bbox_list_tp, color=colors['tp'])
        image_overlay_pred = overlay_bbox_list_on_image(image_overlay_pred, bbox_list_fp, color=colors['fp'])
        image_overlay_gt = overlay_bbox_list_on_image(image, bbox_list_tp_gt, color=colors['tp'])
        image_overlay_gt = overlay_bbox_list_on_image(image_overlay_gt, bbox_list_fn, color=colors['fn'])
        if idx < n_demo:
            fig, ax = plt.subplots(1, 2, figsize=(16, 10))
            ax = np.atleast_2d(ax)
            ax[0, 0].imshow(image_overlay_gt)
            ax[0, 1].imshow(image_overlay_pred)
            plt.show()
        # stack image and image_overlay side by side
        logging.info('Processing key: {}'.format(key))
        if output_dir:
            if not gt_only:
                image_stack = np.hstack([image_overlay_gt, image_overlay_pred])
            else:
                image_stack = image_overlay_gt
            image_stack_path = os.path.join(output_dir, os.path.basename(image_path))
            fileio.maybe_make_new_dir(output_dir)
            cv2.imwrite(image_stack_path, image_stack)
        else:
            logging.warning('No output_dir specified. Skip key: {}'.format(key))
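
A usage sketch of the expected bbox_dict layout (coordinates are hypothetical, and the exact tuple ordering is whatever is_bbox_list_overlapped and overlay_bbox_list_on_image expect):

bbox_dict = {
    'case_001': {
        'pred_bbox_list': [(10, 20, 50, 60)],
        'gt_bbox_list': [(12, 22, 48, 58)],
        'pred_box_correct': [True],
        'gt_box_covered': [True],
    },
}
show_bbox(bbox_dict,
          raw_image_path_dict={'case_001': '/data/png/case_001.png'},
          output_dir='/tmp/bbox_overlay')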
Example #14
def plot_froc_from_data_dict(data_dict,
                             output_fig_path=None,
                             fig_title=None,
                             label_filter='',
                             xlim=(0.1, 50),
                             key_sorter=None,
                             plot_recall=False,
                             highlight_idx=None,
                             **kwargs):
    """Plot froc curve from a data dict

    Args:
        data_dict: a dict of dict. Each sub-dict has keys
            label: used as legend
            data: list of list in the format [[recall, fpc, threshold], ...]
        output_fig_path:
        fig_title:
        label_filter:
        xlim:
        key_sorter: a function to sort the keys. Default to sorting by last mod time
        plot_recall: defaults to False, where a semilogx and a linear plot are plotted side by side.
            When plot_recall is True, replcae the linear plot with plot of recall in chronological order
        highlight_idx: the idx (counting from 0) of the threshold list to plot trend over time

    Returns:
        None
    """
    fig = plt.figure(figsize=(12, 6))
    labels = sorted(set(val['label'] for val in data_dict.values()))
    line_styles = ['-', ':', '-.', '--']

    if len(labels) > len(line_styles) or len(data_dict) == len(labels):
        ls_dict = {label: line_styles[0] for idx, label in enumerate(labels)}
    else:
        ls_dict = {label: line_styles[idx] for idx, label in enumerate(labels)}
    if plot_recall:
        plot_fns = [plt.semilogx]
    else:
        plot_fns = [plt.semilogx, plt.plot]
    for idx, plot_func in enumerate(plot_fns):
        plt.subplot(1, 2, idx + 1)
        keys = sorted(data_dict.keys())
        # sort by last mod time
        key_sorter = key_sorter or (lambda x: os.path.getmtime(x))
        try:
            keys.sort(key=key_sorter)
        except Exception:  # a bare except would also swallow KeyboardInterrupt
            # if cannot sort with key_sorter, sort alphabetically by key string
            keys.sort(key=str)
        mid_recall_list = []
        mid_fp_list = []
        for key in keys:
            label = data_dict[key]['label']
            line_style = ls_dict[label]
            if 'num_images' in data_dict[key]:
                label = '{} (count:{})'.format(label,
                                               data_dict[key]['num_images'])
            if label_filter in label:
                data = data_dict[key]['data']
                fpc = [item[1] for item in data]
                recall = [item[0] for item in data]
                p = plot_func(fpc,
                              recall,
                              marker='.',
                              ls=line_style,
                              label=label)
                color = p[0].get_color()
                if highlight_idx is not None:
                    if highlight_idx == 'mid':
                        highlight_idx = (len(fpc) - 1) // 2
                    mid_recall_list.append(recall[highlight_idx])
                    mid_fp_list.append(fpc[highlight_idx])
                    plt.scatter(fpc[highlight_idx],
                                recall[highlight_idx],
                                marker='o',
                                s=100,
                                facecolors='none',
                                edgecolors=color)
        plt.xlabel('FP per Image')
        plt.ylabel('Recall')
        plt.title(fig_title)
        plt.xlim(xlim)
        plt.ylim([0, 1])
        plt.grid()
        plt.yticks([i / 10.0 for i in range(0, 10)])
        plt.grid(b=True, which='major', color='gray', linestyle='-')
        plt.grid(b=True, which='minor', color='gray', linestyle='--')
    # only plot the legend of the last subplot
    plt.legend(loc='best', fancybox=True, framealpha=0.5)
    if plot_recall:
        # plot on the RHS
        ax1 = plt.subplot(122)
        ax2 = ax1.twinx()
        plot_sharex_series(ax1,
                           ax2,
                           mid_recall_list,
                           mid_fp_list,
                           ylim1=(0, 1),
                           ylim2=(0.1, 10),
                           xlabel='ckpt',
                           ylabels=('Recall', 'FPC'))
        plt.grid()
        plt.title('Recall and FPC')
    if output_fig_path is not None:
        fileio.maybe_make_new_dir(os.path.dirname(output_fig_path))
        plt.savefig(output_fig_path, dpi=300)  # default dpi is usually 100
        plt.close('all')
    else:
        plt.show()
    return fig
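
A usage sketch; keys default to being sorted by file modification time, so they appear here as hypothetical file paths:

data_dict = {
    '/logs/run1/froc.json': {
        'label': 'baseline',
        # [[recall, fpc, threshold], ...]
        'data': [[0.60, 0.5, 0.9], [0.75, 1.0, 0.7], [0.85, 2.0, 0.5]],
    },
}
plot_froc_from_data_dict(data_dict,
                         output_fig_path='/tmp/froc/froc.png',
                         fig_title='FROC')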
Example #15
def generate_negative_sample(image_path,
                             label_path,
                             patch_size,
                             neg_imagedir,
                             isrotate=False,
                             ignore_padding=0,
                             n_patches=20,
                             key='',
                             nonezero_threshold=0.5,
                             scale=1.0,
                             resize_jitter_list=[0.75, 1.25],
                             max_trial_per_patch=5):
    """
    Generate the negative sample, random choose 100 points, to see if the result meet the demand
    Args:
        image_path(str)
        label_path(str): if empty, then use an all zero mask
        patch_size(int)
        neg_imagedir(str)

    Returns:
        None
    """
    assert image_path
    image = cv2.imread(image_path, -1)
    if label_path:
        label = cv2.imread(label_path, -1)
    else:
        print('Use all zero mask!')
        label = np.zeros_like(image, dtype=np.uint8)
    target_size = np.array([patch_size * 3, patch_size * 2])
    max_trial = n_patches * max_trial_per_patch  # for each patch try up to max_trial_per_patch times
    i = 0
    trial = 0
    max_nonzero_ratio = 0

    while trial <= max_trial and i < n_patches:
        trial += 1
        resize_ratio_lower, resize_ratio_upper = resize_jitter_list
        resize_jitter = np.random.uniform(resize_ratio_lower,
                                          resize_ratio_upper)
        image_resize = augmentation.resize(image, scale=resize_jitter * scale)
        label_resize = augmentation.resize(label, scale=resize_jitter * scale)
        image_resize_shape = np.asarray(image_resize.shape)
        if np.any(image_resize_shape < target_size):
            target_size = np.maximum(target_size, image_resize_shape)
            image_pad = augmentation.center_pad(image_resize, target_size)
            label_pad = augmentation.center_pad(label_resize, target_size)
        else:
            # without this branch image_pad/label_pad would be undefined
            # (or stale from a previous iteration) below
            image_pad = image_resize
            label_pad = label_resize
        # Generate rotation angle randomly
        if isrotate:
            degree = generate_rotate_list(rotations_per_axis=1, max_degree=180)
            M = cv2.getRotationMatrix2D(
                (image_pad.shape[0] / 2, image_pad.shape[1] / 2), degree[0],
                1)  # the rotation center must be tuple
            image_rotate = cv2.warpAffine(
                image_pad, M, (image_pad.shape[1], image_pad.shape[0]))
            label_rotate = cv2.warpAffine(label_pad, M, image_pad.shape)
            image_aug = image_rotate
            label_aug = label_rotate
        else:
            image_aug = image_pad
            label_aug = label_pad
        # random.randint needs ints; use floor division
        half = patch_size // 2
        y = random.randint(half, image_aug.shape[0] - half)
        x = random.randint(half, image_aug.shape[1] - half)
        label_patch = label_aug[y - half:y + half, x - half:x + half]
        image_patch = image_aug[y - half:y + half, x - half:x + half]
        if ignore_padding > 0:
            central_label_patch = label_patch[ignore_padding:-ignore_padding,
                                              ignore_padding:-ignore_padding]
            central_image_patch = image_patch[ignore_padding:-ignore_padding,
                                              ignore_padding:-ignore_padding]
        else:
            # [0:-0] would slice to an empty array; keep the full patch
            central_label_patch = label_patch
            central_image_patch = image_patch
        nonzero_ratio = np.count_nonzero(
            central_image_patch) / central_image_patch.size
        if not central_label_patch.any():
            max_nonzero_ratio = max(max_nonzero_ratio, nonzero_ratio)
            if nonzero_ratio >= nonezero_threshold:
                print('============', nonzero_ratio)
                i += 1
                neg_patch = image_patch
                neg_path = os.path.join(
                    neg_imagedir, key,
                    "{}_neg{:03d}_scale{:.2f}.png".format(key, i, scale))
                neg_label_path = os.path.join(
                    neg_imagedir, key,
                    "{}_neg{:03d}_scale{:.2f}_mask.png".format(key, i, scale))
                fileio.maybe_make_new_dir(os.path.dirname(neg_path))
                fileio.maybe_make_new_dir(os.path.dirname(neg_label_path))
                if neg_patch.shape == (patch_size,
                                       patch_size) and label_patch.shape == (
                                           patch_size, patch_size):
                    cv2.imwrite(neg_path, neg_patch)
                    cv2.imwrite(neg_label_path, label_patch)
                else:
                    continue
    print('max_nonzero_ratio', max_nonzero_ratio)
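
A usage sketch (paths are hypothetical placeholders):

generate_negative_sample('/data/png/case_001.png',
                         '/data/mask/case_001_mask.png',
                         patch_size=512,
                         neg_imagedir='/data/patches/neg',
                         key='case_001',
                         n_patches=20)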
Example #16
 def write(self, output_path, patch_array, write_rgb=False):
     fileio.maybe_make_new_dir(os.path.dirname(output_path))
     if write_rgb:
         patch_array = np.dstack([patch_array] * 3)
     cv2.imwrite(output_path, patch_array)
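
A usage sketch, assuming `writer` is an instance of the enclosing class (the path is a hypothetical placeholder):

writer.write('/tmp/patches/case_001_patch000.png', patch_array, write_rgb=True)
# write_rgb=True stacks the single-channel patch into 3 identical channels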