def __init__(
        self,
        dataset_loader: str,
        dataset_path: str,
        postures_generator: Optional[Generator] = None,
        video_name: str = None,
        **kwargs
    ):
        """Load a dataset and set up synthetic image generation for one of its videos.

        :param dataset_loader: forwarded to ``load_dataset`` as its first argument
        :param dataset_path: forwarded to ``load_dataset`` as its second argument
        :param postures_generator: source of postures; when None, the result of
            ``PosturesModel().generate()`` is used instead
        :param video_name: video whose features are used; when None, the first
            entry of ``dataset.video_names`` is used instead
        :param kwargs: passed both to ``ResizeOptions`` and to ``load_dataset``
        :raises ValueError: if no frame has a skeleton free of NaN values
        """
        options = ResizeOptions(**kwargs)
        loaded = load_dataset(dataset_loader, dataset_path, resize_options=options, **kwargs)

        # Fall back to defaults for whatever the caller left unspecified.
        chosen_generator = PosturesModel().generate() if postures_generator is None else postures_generator
        chosen_video = loaded.video_names[0] if video_name is None else video_name

        video_features = loaded.features_dataset[chosen_video]
        self.skeletons = video_features.skeletons
        self.measurements = video_features.measurements

        self.output_image_shape = loaded.image_shape

        # Random augmentations are switched off for this synthetic dataset.
        synthetic_options = dict(
            frame_preprocessing=loaded.frame_preprocessing,
            output_image_shape=self.output_image_shape,
            enable_random_augmentations=False,
        )
        self.synthetic_dataset = SyntheticDataset(**synthetic_options)

        # A frame counts as labelled only when none of its skeleton values
        # (reduced over axes 1 and 2) is NaN.
        has_nan = np.isnan(self.skeletons).any(axis=(1, 2))
        self.labelled_indexes = np.nonzero(~has_nan)[0]
        if self.labelled_indexes.size == 0:
            raise ValueError("No template frames found in the dataset, can't generate synthetic images.")

        self.frames_dataset = loaded.frames_dataset
        self.video_name = chosen_video
        self.postures_generator = chosen_generator
Esempio n. 2
0
def _write_to_file(
    out_filename: str,
    num_samples: int,
    template_filename: str,
    postures_generation_fn: Callable[[], Generator],
    progress_counter,
    writer: Type[GenericFileWriter],
    synthetic_dataset_args,
    random_seed: Optional[int],
):
    """Generate ``num_samples`` synthetic images and write them to ``out_filename``.

    Templates are unpickled from ``template_filename`` (which is deleted right
    after loading). For every sample a random template and a random head/tail
    choice are combined with the next posture from the postures generator, the
    image is drawn through ``SyntheticDataset.generate`` and the result is
    handed to the writer.

    :param out_filename: passed to ``writer`` to open the output
    :param num_samples: number of synthetic samples to produce
    :param template_filename: path of the pickled templates; the file is removed
        once it has been loaded
    :param postures_generation_fn: called once with no argument; ``next()`` is
        called on its result once per sample
    :param progress_counter: object whose ``value`` attribute is set to the
        number of samples completed so far (presumably a shared counter read by
        another process - TODO confirm)
    :param writer: called with ``out_filename`` and used as a context manager;
        its ``write`` method receives the dict returned by ``locals()``
    :param synthetic_dataset_args: keyword arguments unpacked into the
        ``SyntheticDataset`` constructor
    :param random_seed: seed for numpy's global random state; when None, a seed
        is built from 4 bytes of ``os.urandom``
    """
    # NOTE(review): function-level import - presumably to avoid a circular import
    # or to keep the import inside the worker; confirm.
    from wormpose.images.synthetic import SyntheticDataset

    # Use the caller's seed when given, otherwise draw 4 random bytes from the OS.
    seed = int.from_bytes(os.urandom(4), byteorder="little") if random_seed is None else random_seed
    np.random.seed(seed)

    # Load the pickled templates, then delete the file (it is not read again here).
    # NOTE(review): pickle.load can execute arbitrary code on crafted input - assumes
    # the file comes from a trusted producer; confirm.
    with open(template_filename, "rb") as templates_f:
        templates = pickle.load(templates_f)
    os.remove(template_filename)

    # For each video, average every measurement field (ignoring NaN) over the
    # templates that belong to that video.
    worm_measurements = {}
    for video_name in np.unique(templates.video_names):
        indexes = np.where(templates.video_names == video_name)[0]
        measurements = templates.measurements[indexes]
        # Same dtype as the measurements, with the leading (per-template) axis dropped.
        average_measurements = np.empty((measurements.shape[1:]), dtype=measurements.dtype)
        # Iterates the field names of the structured dtype.
        for name in measurements.dtype.names:
            # NOTE(review): indexing [0] assumes measurements.shape[1:] is non-empty - confirm.
            average_measurements[0][name] = np.nanmean(measurements[name])
        worm_measurements[video_name] = average_measurements

    synthetic_dataset = SyntheticDataset(**synthetic_dataset_args)

    # preallocate all the template indexes to create the synthetic images (random)
    template_indexes = np.random.randint(len(templates.frames), size=num_samples)

    # preallocate all the choices for the position of the head (random)
    headtail_choice = np.random.choice([0, 1], size=num_samples)

    # preallocate image buffer to draw the synthetic image on
    # (the same buffer is passed as out_image for every sample)
    image_data = np.empty(synthetic_dataset.output_image_shape, dtype=np.uint8)

    postures_gen = postures_generation_fn()
    with writer(out_filename) as synth_data_writer:

        for index, (cur_template_index, cur_headtail_choice) in enumerate(zip(template_indexes, headtail_choice)):
            theta = next(postures_gen)
            # Keep both labels: the posture itself and its flip_theta() counterpart.
            label_data = [theta, flip_theta(theta)]
            video_name = templates.video_names[cur_template_index]
            template_frame = templates.frames[cur_template_index]
            template_skeleton = templates.skeletons[cur_template_index]
            # Use the per-video averaged measurements computed above.
            template_measurements = worm_measurements[video_name]
            # Draw with whichever of the two labels the pre-drawn 0/1 choice selects.
            synthetic_dataset.generate(
                theta=label_data[cur_headtail_choice],
                template_frame=template_frame,
                template_skeleton=template_skeleton,
                template_measurements=template_measurements,
                out_image=image_data,
            )
            # The writer receives this function's entire local namespace as a dict, so
            # the local variable names here are effectively part of the writer's input:
            # renaming or removing a local may change what the writer sees.
            synth_data_writer.write(locals())
            # Publish the number of samples completed so far.
            progress_counter.value = index + 1
    def __init__(self, frame_preprocessing: BaseFramePreprocessing,
                 image_shape):
        """Create a real and a synthetic dataset that share preprocessing and image shape.

        :param frame_preprocessing: passed as ``frame_preprocessing`` to both
            ``RealDataset`` and ``SyntheticDataset``
        :param image_shape: passed as ``output_image_shape`` to both datasets and
            used as the shape of the ``last_synth_image`` buffer
        """
        # Both datasets are configured with the same preprocessing and output shape.
        shared_options = dict(
            frame_preprocessing=frame_preprocessing,
            output_image_shape=image_shape,
        )

        self.real_dataset = RealDataset(**shared_options)

        # Random augmentations are switched off for the synthetic dataset.
        self.synthetic_dataset = SyntheticDataset(**shared_options, enable_random_augmentations=False)

        # uint8 buffer of the output shape (uninitialised) for the synthetic image;
        # no real image is stored yet.
        self.last_synth_image = np.empty(image_shape, dtype=np.uint8)
        self.last_real_image = None