def __init__(self, dataset_loader: str, dataset_path: str, video_name=None, **kwargs):
        """
        Load a dataset and prepare the processing of the real images of one video.

        :param dataset_loader: First positional argument handed to load_dataset
        :param dataset_path: Second positional argument handed to load_dataset
        :param video_name: Video to use; when None, the first entry of the
            loaded dataset's video_names is used
        :param kwargs: Forwarded to ResizeOptions
        """
        resizing = ResizeOptions(**kwargs)
        loaded = load_dataset(dataset_loader, dataset_path, resize_options=resizing)

        # Fall back on the first video of the dataset when none was requested.
        if video_name is None:
            video_name = loaded.video_names[0]
        self.video_name = video_name

        self.real_dataset = RealDataset(
            frame_preprocessing=loaded.frame_preprocessing,
            output_image_shape=loaded.image_shape,
        )
        self.frames_dataset = loaded.frames_dataset
    def __init__(self, frame_preprocessing: BaseFramePreprocessing,
                 image_shape):
        """
        Create the real and synthetic dataset helpers and allocate the synthetic image buffer.

        :param frame_preprocessing: Preprocessing object handed to both RealDataset and SyntheticDataset
        :param image_shape: Output image shape of both datasets, also the shape of the synthetic image buffer
        """
        # Both datasets are configured with the same preprocessing and output shape.
        shared_options = dict(
            frame_preprocessing=frame_preprocessing,
            output_image_shape=image_shape,
        )
        self.real_dataset = RealDataset(**shared_options)
        self.synthetic_dataset = SyntheticDataset(
            enable_random_augmentations=False,
            **shared_options,
        )

        # Uninitialized uint8 array of shape image_shape.
        self.last_synth_image = np.empty(image_shape, dtype=np.uint8)
        # No real image has been processed yet.
        self.last_real_image = None
class RealSimpleVisualizer(object):
    """
    Utility class to visualize the real images processed.

    It loads a dataset, selects one video, and lazily yields each raw frame of
    that video together with its processed version.
    """

    def __init__(self, dataset_loader: str, dataset_path: str, video_name=None, **kwargs):
        """
        :param dataset_loader: First positional argument handed to load_dataset
        :param dataset_path: Second positional argument handed to load_dataset
        :param video_name: Video to visualize; when None, the first entry of the
            loaded dataset's video_names is used
        :param kwargs: Forwarded to both ResizeOptions and load_dataset
        """
        resizing = ResizeOptions(**kwargs)
        loaded = load_dataset(dataset_loader, dataset_path, resize_options=resizing, **kwargs)

        # Fall back on the first video of the dataset when none was requested.
        if video_name is None:
            video_name = loaded.video_names[0]
        self.video_name = video_name

        self.real_dataset = RealDataset(
            frame_preprocessing=loaded.frame_preprocessing,
            output_image_shape=loaded.image_shape,
        )
        self.frames_dataset = loaded.frames_dataset

    def _pair_with_processed(self, raw_frame):
        # Returns (raw frame, processed frame); the second value returned by
        # process_frame is not needed here.
        processed_frame, _ = self.real_dataset.process_frame(raw_frame)
        return raw_frame, processed_frame

    def generate(self):
        """
        Iterate over the frames of the selected video.

        :return: generator of (original frame, processed frame) tuples; the
            video stays open while the generator is being consumed
        """
        with self.frames_dataset.open(self.video_name) as video_frames:
            # map is lazy: each frame is processed only when it is requested
            yield from map(self._pair_with_processed, video_frames)
示例#4
0
def _assemble_images_batch(
    frame_preprocessing: BaseFramePreprocessing,
    data_reading_queue,
    results_queue,
    temp_dir: str,
    batch_size: int,
    image_shape,
):
    """
    Worker loop turning files of pickled raw frames into saved batches of processed images.

    Each item read from data_reading_queue is a (data_filename, chunk_index) pair.
    All the frames pickled in that file are processed with a RealDataset and
    accumulated in an _ImagesBatch, the file is deleted, and the batch is saved
    divided by 255 to "real_topredict_{chunk_index:09d}.npy" inside temp_dir.
    The saved filename is then pushed to results_queue.
    The loop ends when None is read from data_reading_queue.

    :param frame_preprocessing: Preprocessing object handed to RealDataset
    :param data_reading_queue: Queue of (data_filename, chunk_index), terminated by None
    :param results_queue: Queue receiving the path of each saved .npy file
    :param temp_dir: Folder where the .npy files are written
    :param batch_size: First argument of _ImagesBatch
    :param image_shape: Shape handed to _ImagesBatch and RealDataset
    """

    def _unpickle_all(stream):
        # Yields the objects pickled back to back in stream, until the end of the file.
        # NOTE: pickle must only be used on trusted files; presumably these are
        # temporary files written by this same pipeline - verify against the producer.
        while True:
            try:
                yield pickle.load(stream)
            except EOFError:
                return

    batch = _ImagesBatch(batch_size, image_shape)
    real_dataset = RealDataset(frame_preprocessing, image_shape)

    while True:
        job = data_reading_queue.get()
        if job is None:
            # None is the signal that there is nothing left to process
            return

        data_filename, chunk_index = job

        with open(data_filename, "rb") as data_file:
            batch.reset()
            for raw_frame in _unpickle_all(data_file):
                processed_frame, _ = real_dataset.process_frame(raw_frame)
                batch.add(processed_frame)

        # the raw frames file has been fully read, it can be deleted
        os.remove(data_filename)

        # normalize between 0 and 1
        normalized_images = batch.data / 255.0

        out_name = f"real_topredict_{chunk_index:09d}.npy"
        image_filename = os.path.join(temp_dir, out_name)
        np.save(image_filename, normalized_images)
        results_queue.put(image_filename)
class CenterlineAccuracyCheck(object):
    """
    Scores how well a centerline (theta) represents a real image, using image similarity.

    The original real image is first preprocessed by a RealDataset built on a
    BaseFramePreprocessing class (crop and set the background pixels to a uniform color).
    A synthetic image representing the centerline theta is then created from a provided
    template image. Typically, the template image was chosen to be the closest labelled
    image in time to the real image.
    The synthetic image is cropped to fit the worm, in order to apply a template matching
    function between the real image (full size) and the synthetic image (smaller).
    The result is an image similarity value and the synthetic image skeleton coordinates.
    """

    def __init__(self, frame_preprocessing: BaseFramePreprocessing, image_shape):
        """
        :param frame_preprocessing: Preprocessing object handed to both RealDataset and SyntheticDataset
        :param image_shape: Output image shape of both datasets, also the shape of the synthetic image buffer
        """
        # Both datasets are configured with the same preprocessing and output shape.
        shared_options = dict(
            frame_preprocessing=frame_preprocessing,
            output_image_shape=image_shape,
        )
        self.real_dataset = RealDataset(**shared_options)
        self.synthetic_dataset = SyntheticDataset(
            enable_random_augmentations=False,
            **shared_options,
        )

        # Buffer the synthetic image is drawn into on each call (passed as out_image).
        self.last_synth_image = np.empty(image_shape, dtype=np.uint8)
        # Set to the processed real image on each call that does not return early.
        self.last_real_image = None

    def __call__(
        self,
        theta,
        template_skeleton,
        template_frame,
        template_measurements,
        real_frame_orig,
    ):
        """
        Compare the synthetic image generated from theta with the processed real frame.

        :param theta: Centerline angles; if any value is NaN, no comparison is done
        :param template_skeleton: Skeleton of the template, handed to the synthetic image generation
        :param template_frame: Template image, handed to the synthetic image generation
        :param template_measurements: Template measurements, handed to the synthetic image generation
        :param real_frame_orig: Original real frame, processed with the RealDataset
        :return: (score, skeleton): the similarity score and the synthetic skeleton shifted
            by the best match location and the real image processing offset;
            (NaN, array of NaN shaped like template_skeleton) when theta contains a NaN
        """
        if np.isnan(theta).any():
            # nothing to compare against an undefined centerline
            return np.nan, np.full_like(template_skeleton, np.nan)

        real_image, real_offset = self.real_dataset.process_frame(real_frame_orig)
        self.last_real_image = real_image

        bg_color, skeleton = self.synthetic_dataset.generate(
            theta,
            template_frame=template_frame,
            template_skeleton=template_skeleton,
            out_image=self.last_synth_image,
            template_measurements=template_measurements,
        )

        # Crop the synthetic image to the object of interest before doing the image comparison,
        # we don't need the full image with all the background, still keep a little padding around the worm.
        crop = fit_bounding_box_to_worm(self.last_synth_image, bg_color)
        # Express the skeleton relative to the crop origin: crop[1] is subtracted from the
        # first coordinate, crop[0] from the second.
        skeleton -= (crop[1].start, crop[0].start)
        synth_crop = self.last_synth_image[crop]

        # Perform the image comparison between the real image and the reconstructed synthetic image cropped.
        # This gives a heat map, we are interested in the maximum value of the heatmap and its location.
        # This maximum score gives an estimation of the confidence of the prediction.
        score, best_location = image_scoring.calculate_similarity(
            source_image=self.last_real_image,
            template_image=synth_crop,
        )

        # Using the heatmap maximum coordinates, we can transform the coordinates of the reconstructed skeleton
        # to the original image coordinates system.
        skeleton += best_location
        skeleton += np.array([real_offset[0], real_offset[1]])

        return score, skeleton
示例#6
0
def generate(
    dataset: Dataset,
    num_samples: int,
    theta_dims: int,
    file_pattern: str,
) -> int:
    """
    Generates evaluation dataset composed of processed real images, saved to a .TFrecord

    A frame is considered labelled when its skeleton contains no NaN value.
    The source (video name and frame index) of each sample is saved next to the
    record file, in a .csv file with the same base name.

    :param dataset: WormPose Dataset
    :param num_samples: How many images to generate
    :param theta_dims: Dimensions of theta for the labels
    :param file_pattern: Path of the output files with "index" variable
        example: "path_to_out/eval_{index}.tfrecord"
    :return: How many samples were actually generated (if there is less data than the requested num_samples)
    :raises RuntimeError: if no video of the dataset has any labelled frame
    """
    # for each video, collect the indexes of the frames with a skeleton free of NaN
    labelled_frames = {}
    for video_name in dataset.video_names:
        skeletons = dataset.features_dataset[video_name].skeletons
        skel_is_not_nan = ~np.any(np.isnan(skeletons), axis=(1, 2))
        labelled_indexes = np.where(skel_is_not_nan)[0]
        if len(labelled_indexes) > 0:
            labelled_frames[video_name] = labelled_indexes

    if not labelled_frames:
        raise RuntimeError(
            "Can't create evaluation data because couldn't find any labelled frame in the dataset."
        )

    len_labelled_frames = sum(len(x) for x in labelled_frames.values())
    if len_labelled_frames < num_samples:
        logging.warning(
            f"Not enough labelled frames in the dataset "
            f"to create an evaluation set of {num_samples} unique samples, "
            f"using all available {len_labelled_frames} samples instead.")
        num_samples = len_labelled_frames

    real_dataset = RealDataset(dataset.frame_preprocessing,
                               dataset.image_shape)

    # only one output file is written, it takes the index 0
    tfrecord_filename = file_pattern.format(index=0)
    csv_infos_filename = os.path.splitext(tfrecord_filename)[0] + ".csv"

    # get num_samples total random labelled frames from the videos
    eval_frames = _populate_eval_frames(labelled_frames, num_samples)

    # write the eval.tfrecord file with the images and the labels, save also the source infos in a separate eval.csv
    # the frames are not shuffled by video, all the frames from one video are consecutive in the file
    # newline="" lets the csv module control the line endings itself, as its documentation requires
    # (otherwise each row ends with "\r\r\n" on platforms translating "\n" when writing)
    with tfrecord_file.Writer(tfrecord_filename) as record_writer, open(
            csv_infos_filename, "w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file,
                                delimiter=" ",
                                quotechar="|",
                                quoting=csv.QUOTE_MINIMAL)
        for video_name, cur_video_eval_indexes in eval_frames.items():

            with dataset.frames_dataset.open(video_name) as frames:
                # look up the skeletons of the video once, not once per frame
                video_skeletons = dataset.features_dataset[video_name].skeletons

                for eval_frame_index in cur_video_eval_indexes:
                    image_data, _ = real_dataset.process_frame(
                        frames[eval_frame_index])
                    cur_theta = skeleton_to_angle(
                        video_skeletons[eval_frame_index],
                        theta_dims=theta_dims)
                    cur_theta_flipped = flip_theta(cur_theta)

                    # each record holds the image with its label and the flipped version of the label
                    record_writer.write(image_data, cur_theta,
                                        cur_theta_flipped)
                    csv_writer.writerow([video_name, int(eval_frame_index)])

    return num_samples