Exemplo n.º 1
0
    def dump_primitive_data(self, primitive, tar_path, config):
        """Generate synthetic samples for one primitive and pack them into a tarball.

        For every split, ``size`` samples are produced: a background is
        generated, the drawing function named ``primitive`` is looked up on
        ``synthetic_dataset`` and applied to it, and the resulting image is
        blurred and resized. Images are written as ``<i>.png`` and the
        rescaled points as ``<i>.npy`` under
        ``$TMPDIR/<primitive>/{images,points}/<split>/``. The temporary tree
        is then archived (gzip) at ``tar_path`` under the top-level name
        ``primitive`` and deleted.

        Args:
            primitive: Name of the drawing function to fetch from
                ``synthetic_dataset``; also used as the temporary directory
                name and as the archive root.
            tar_path: Destination path of the ``.tar.gz`` archive.
            config: Nested mapping providing ``generation.random_seed``,
                ``generation.image_size``, ``generation.params``,
                ``preprocessing.blur_size`` and ``preprocessing.resize``.

        Raises:
            KeyError: If the ``TMPDIR`` environment variable is not set or a
                required configuration key is missing.

        Note:
            The split sizes are read from ``self.config`` rather than from the
            ``config`` argument -- presumably both refer to the same mapping;
            confirm against the caller.
        """
        temp_dir = Path(os.environ['TMPDIR'], primitive)

        tf.logging.info('Generating tarfile for primitive {}.'.format(primitive))
        synthetic_dataset.set_random_state(np.random.RandomState(
                config['generation']['random_seed']))

        # Settings that do not change from one sample to the next.
        generation = config['generation']
        preprocessing = config['preprocessing']
        background_params = generation['params']['generate_background']
        primitive_params = generation['params'].get(primitive, {})
        draw_primitive = getattr(synthetic_dataset, primitive)
        blur = preprocessing['blur_size']
        # Builtin ``float`` replaces the ``np.float`` alias, which was removed
        # in NumPy 1.24; the resulting dtype (float64) is the same.
        resize = np.array(preprocessing['resize'], float)
        image_size = np.array(generation['image_size'], float)
        # cv2.resize takes (width, height), hence the reversed order.
        cv_target_size = tuple(preprocessing['resize'][::-1])

        for split, size in self.config['generation']['split_sizes'].items():
            im_dir, pts_dir = [Path(temp_dir, kind, split)
                               for kind in ['images', 'points']]
            im_dir.mkdir(parents=True, exist_ok=True)
            pts_dir.mkdir(parents=True, exist_ok=True)

            for i in tqdm(range(size), desc=split, leave=False):
                image = synthetic_dataset.generate_background(
                        generation['image_size'], **background_params)
                points = np.array(draw_primitive(image, **primitive_params))
                points = np.flip(points, 1)  # reverse convention with opencv

                image = cv2.GaussianBlur(image, (blur, blur), 0)
                # Rescale the point coordinates to the resized image frame.
                points = points * resize / image_size
                image = cv2.resize(image, cv_target_size,
                                   interpolation=cv2.INTER_LINEAR)

                cv2.imwrite(str(Path(im_dir, '{}.png'.format(i))), image)
                np.save(Path(pts_dir, '{}.npy'.format(i)), points)

        # Pack into a tar file; the context manager closes the archive even
        # if adding the directory fails.
        with tarfile.open(tar_path, mode='w:gz') as tar:
            tar.add(temp_dir, arcname=primitive)
        shutil.rmtree(temp_dir)
        tf.logging.info('Tarfile dumped to {}.'.format(tar_path))
Exemplo n.º 2
0
 def _gen_shape():
     """Endlessly yield ``(image, points)`` pairs of random synthetic shapes.

     The candidate primitives are resolved once through ``parse_primitives``.
     On every iteration one of them is picked with ``np.random.choice``, a
     fresh background is generated and the matching drawing function from
     ``synthetic_dataset`` is applied to it.

     Yields:
         A tuple of two float32 arrays: the image with a trailing axis
         appended, and the points returned by the drawing function flipped
         along axis 1.
     """
     candidates = parse_primitives(config['primitives'], self.drawing_primitives)
     while True:
         chosen = np.random.choice(candidates)

         background_kwargs = config['generation']['params']['generate_background']
         canvas = synthetic_dataset.generate_background(
                 config['generation']['image_size'], **background_kwargs)

         draw = getattr(synthetic_dataset, chosen)
         draw_kwargs = config['generation']['params'].get(chosen, {})
         keypoints = np.array(draw(canvas, **draw_kwargs))

         image_out = np.expand_dims(canvas, axis=-1).astype(np.float32)
         points_out = np.flip(keypoints.astype(np.float32), 1)
         yield (image_out, points_out)