def __init__(
    self,
    dataset_loader: str,
    dataset_path: str,
    postures_generator: Optional[Generator] = None,
    video_name: Optional[str] = None,
    **kwargs
):
    """
    Load one video of a dataset and prepare it as the template source for synthetic images.

    :param dataset_loader: Name of the dataset loader, passed to ``load_dataset``.
    :param dataset_path: Path of the dataset, passed to ``load_dataset``.
    :param postures_generator: Generator of postures; defaults to ``PosturesModel().generate()``.
    :param video_name: Video whose frames serve as templates; defaults to the first entry of
        ``dataset.video_names``.
    :param kwargs: Used to build ``ResizeOptions`` and also forwarded to ``load_dataset``.
    :raises ValueError: If no frame of the video has a skeleton free of NaN values.
    """
    resize_options = ResizeOptions(**kwargs)
    dataset = load_dataset(dataset_loader, dataset_path, resize_options=resize_options, **kwargs)

    if postures_generator is None:
        postures_generator = PosturesModel().generate()
    if video_name is None:
        video_name = dataset.video_names[0]

    features = dataset.features_dataset[video_name]
    self.skeletons = features.skeletons
    self.measurements = features.measurements
    self.output_image_shape = dataset.image_shape

    # Synthetic images are rendered at the dataset's image shape, with random augmentations disabled.
    self.synthetic_dataset = SyntheticDataset(
        frame_preprocessing=dataset.frame_preprocessing,
        output_image_shape=self.output_image_shape,
        enable_random_augmentations=False,
    )

    # A frame is usable as a template only if its skeleton contains no NaN coordinate
    # (assumes skeletons is indexed per frame along axis 0 — reduced over axes 1 and 2).
    skel_is_not_nan = ~np.any(np.isnan(self.skeletons), axis=(1, 2))
    self.labelled_indexes = np.where(skel_is_not_nan)[0]
    if len(self.labelled_indexes) == 0:
        raise ValueError("No template frames found in the dataset, can't generate synthetic images.")

    self.frames_dataset = dataset.frames_dataset
    self.video_name = video_name
    self.postures_generator = postures_generator
def _write_to_file(
    out_filename: str,
    num_samples: int,
    template_filename: str,
    postures_generation_fn: Callable[[], Generator],
    progress_counter,
    writer: Type[GenericFileWriter],
    synthetic_dataset_args,
    random_seed: Optional[int],
):
    """
    Generate ``num_samples`` synthetic images and write them to ``out_filename``.

    :param out_filename: Destination passed to ``writer``.
    :param num_samples: Number of synthetic images to generate.
    :param template_filename: Pickle file holding the template frames, skeletons, measurements and
        video names. The file is DELETED after it has been loaded.
    :param postures_generation_fn: Called once; the returned generator yields one posture (theta) per sample.
    :param progress_counter: Object whose ``.value`` is set to the number of samples written so far
        (presumably a ``multiprocessing.Value`` shared with a parent process — confirm with callers).
    :param writer: Writer class used as a context manager; its ``write`` is called once per sample.
    :param synthetic_dataset_args: Keyword arguments for ``SyntheticDataset``.
    :param random_seed: Seed for NumPy's global RNG; when None, a seed is drawn from ``os.urandom``.

    NOTE: ``synth_data_writer.write`` receives ``locals()``, so the NAMES of the local variables
    below are an implicit contract with the writer implementations. Do not rename locals here
    without checking every writer.
    """
    # Imported locally rather than at module level (presumably so worker processes import it
    # lazily — confirm).
    from wormpose.images.synthetic import SyntheticDataset

    seed = int.from_bytes(os.urandom(4), byteorder="little") if random_seed is None else random_seed
    np.random.seed(seed)

    # NOTE(review): pickle.load executes arbitrary code from the file; only point this at files
    # produced by this package (presumably a temporary file written by the caller).
    with open(template_filename, "rb") as templates_f:
        templates = pickle.load(templates_f)
    os.remove(template_filename)

    # One averaged measurement record per video: each field is averaged with np.nanmean over all
    # templates of that video, so NaN entries are ignored.
    worm_measurements = {}
    for video_name in np.unique(templates.video_names):
        indexes = np.where(templates.video_names == video_name)[0]
        measurements = templates.measurements[indexes]
        # assumes templates.measurements is a 2-D structured array of shape (N, 1), so the
        # buffer below has shape (1,) and element [0] exists — TODO confirm
        average_measurements = np.empty((measurements.shape[1:]), dtype=measurements.dtype)
        for name in measurements.dtype.names:
            # Structured-array elements are views, so this writes into average_measurements.
            average_measurements[0][name] = np.nanmean(measurements[name])
        worm_measurements[video_name] = average_measurements

    synthetic_dataset = SyntheticDataset(**synthetic_dataset_args)

    # preallocate all the template indexes to create the synthetic images (random)
    template_indexes = np.random.randint(len(templates.frames), size=num_samples)
    # preallocate all the choices for the position of the head (random)
    headtail_choice = np.random.choice([0, 1], size=num_samples)
    # preallocate image buffer to draw the synthetic image on
    image_data = np.empty(synthetic_dataset.output_image_shape, dtype=np.uint8)

    postures_gen = postures_generation_fn()

    with writer(out_filename) as synth_data_writer:
        for index, (cur_template_index, cur_headtail_choice) in enumerate(zip(template_indexes, headtail_choice)):
            theta = next(postures_gen)
            # Both orientations of the posture; headtail_choice selects which one is drawn.
            label_data = [theta, flip_theta(theta)]

            video_name = templates.video_names[cur_template_index]
            template_frame = templates.frames[cur_template_index]
            template_skeleton = templates.skeletons[cur_template_index]
            template_measurements = worm_measurements[video_name]

            # Renders in place into the reused image_data buffer.
            synthetic_dataset.generate(
                theta=label_data[cur_headtail_choice],
                template_frame=template_frame,
                template_skeleton=template_skeleton,
                template_measurements=template_measurements,
                out_image=image_data,
            )

            # The writer picks what it needs out of this function's local variables.
            synth_data_writer.write(locals())
            progress_counter.value = index + 1
class SyntheticSimpleVisualizer(object):
    """
    Utility class to visualize the synthetic images.

    Uses the frames of one video whose skeletons contain no NaN as templates, and renders
    synthetic images for postures drawn from a posture generator.
    """

    def __init__(
        self,
        dataset_loader: str,
        dataset_path: str,
        postures_generator: Optional[Generator] = None,
        video_name: Optional[str] = None,
        **kwargs
    ):
        """
        :param dataset_loader: Name of the dataset loader, passed to ``load_dataset``.
        :param dataset_path: Path of the dataset, passed to ``load_dataset``.
        :param postures_generator: Generator of postures (theta); defaults to
            ``PosturesModel().generate()``.
        :param video_name: Video whose frames serve as templates; defaults to the first entry of
            ``dataset.video_names``.
        :param kwargs: Used to build the ``ResizeOptions`` given to ``load_dataset``.
        :raises ValueError: If no frame of the video has a skeleton free of NaN values.
        """
        resize_options = ResizeOptions(**kwargs)
        dataset = load_dataset(dataset_loader, dataset_path, resize_options=resize_options)

        if postures_generator is None:
            postures_generator = PosturesModel().generate()
        if video_name is None:
            video_name = dataset.video_names[0]

        features = dataset.features_dataset[video_name]
        self.skeletons = features.skeletons
        self.measurements = features.measurements
        self.output_image_shape = dataset.image_shape

        # Images are rendered at the dataset's image shape, with random augmentations disabled.
        self.synthetic_dataset = SyntheticDataset(
            frame_preprocessing=dataset.frame_preprocessing,
            output_image_shape=self.output_image_shape,
            enable_random_augmentations=False,
        )

        # A frame can be used as a template only if its skeleton contains no NaN coordinate.
        skel_is_not_nan = ~np.any(np.isnan(self.skeletons), axis=(1, 2))
        self.labelled_indexes = np.where(skel_is_not_nan)[0]
        if len(self.labelled_indexes) == 0:
            raise ValueError(
                "No template frames found in the dataset, can't generate synthetic images."
            )

        self.frames_dataset = dataset.frames_dataset
        self.video_name = video_name
        self.postures_generator = postures_generator

    def generate(self):
        """
        Endlessly yield ``(image, theta)`` pairs.

        Each step takes the next posture from ``postures_generator``, picks a random labelled
        frame as template, and renders the synthetic image.

        NOTE: the SAME uint8 array object is yielded on every step and overwritten by the next
        step. Copy it if it must outlive the iteration. The video's frames stay open until the
        generator is closed.
        """
        out_image = np.empty(self.output_image_shape, dtype=np.uint8)
        with self.frames_dataset.open(self.video_name) as frames:
            while True:
                theta = next(self.postures_generator)
                random_label_index = np.random.choice(self.labelled_indexes)
                self.synthetic_dataset.generate(
                    theta=theta,
                    template_skeleton=self.skeletons[random_label_index],
                    template_frame=frames[random_label_index],
                    out_image=out_image,
                    template_measurements=self.measurements,
                )
                yield out_image, theta
def __init__(self, frame_preprocessing: BaseFramePreprocessing, image_shape):
    """
    Create the real-frame processor and the synthetic-frame generator.

    Both share the same preprocessing and output shape.

    :param frame_preprocessing: Preprocessing applied to frames by both datasets.
    :param image_shape: Output image shape of both datasets, also used for the synthetic image buffer.
    """
    common_options = {
        "frame_preprocessing": frame_preprocessing,
        "output_image_shape": image_shape,
    }
    self.real_dataset = RealDataset(**common_options)
    self.synthetic_dataset = SyntheticDataset(enable_random_augmentations=False, **common_options)

    # Reusable uint8 buffer the synthetic image gets drawn into.
    self.last_synth_image = np.empty(image_shape, dtype=np.uint8)
    # No real frame has been processed yet.
    self.last_real_image = None
class CenterlineAccuracyCheck(object):
    """
    Measure how accurately a centerline (theta) represents a real image, via image similarity.

    Steps:
      1. The real image is preprocessed with a BaseFramePreprocessing class (crop, and set the
         background pixels to a uniform color).
      2. A synthetic image of the centerline theta is drawn from a template image, typically the
         labelled image closest in time to the real one.
      3. The synthetic image is cropped around the worm. A template-matching function then
         compares the full-size real image with this smaller synthetic image.

    Calling the object returns the similarity score and the synthetic skeleton coordinates
    expressed in the original real-image coordinate system.
    """

    def __init__(self, frame_preprocessing: BaseFramePreprocessing, image_shape):
        common_options = {
            "frame_preprocessing": frame_preprocessing,
            "output_image_shape": image_shape,
        }
        self.real_dataset = RealDataset(**common_options)
        self.synthetic_dataset = SyntheticDataset(enable_random_augmentations=False, **common_options)

        # Reusable uint8 buffer the synthetic image gets drawn into.
        self.last_synth_image = np.empty(image_shape, dtype=np.uint8)
        # Preprocessed real frame of the latest non-NaN call; None until then.
        self.last_real_image = None

    def __call__(
        self,
        theta,
        template_skeleton,
        template_frame,
        template_measurements,
        real_frame_orig,
    ):
        """
        Score ``theta`` against ``real_frame_orig``.

        :return: ``(score, skeleton)``. If theta contains any NaN, returns ``(nan, skeleton)``
            where the skeleton is a NaN-filled array shaped like ``template_skeleton``. In that
            case ``last_real_image`` is not updated.
        """
        if np.isnan(theta).any():
            # A centerline with undefined angles cannot be drawn.
            return np.nan, np.full_like(template_skeleton, np.nan)

        self.last_real_image, real_offset = self.real_dataset.process_frame(real_frame_orig)

        background_color, skeleton = self.synthetic_dataset.generate(
            theta,
            template_frame=template_frame,
            template_skeleton=template_skeleton,
            out_image=self.last_synth_image,
            template_measurements=template_measurements,
        )

        # Keep only the region around the worm (plus a little padding) for the comparison.
        # The skeleton is shifted so that it is relative to the crop's top-left corner.
        worm_box = fit_bounding_box_to_worm(self.last_synth_image, background_color)
        crop_origin = (worm_box[1].start, worm_box[0].start)
        skeleton -= crop_origin
        worm_crop = self.last_synth_image[worm_box]

        # Template matching yields a heat map. Its maximum value is the confidence score, and its
        # location is where the cropped synthetic image best fits inside the real image.
        score, match_location = image_scoring.calculate_similarity(
            source_image=self.last_real_image, template_image=worm_crop
        )

        # Move the skeleton from crop coordinates back to original-image coordinates.
        skeleton += match_location
        skeleton += np.array([real_offset[0], real_offset[1]])

        return score, skeleton