Example #1
def load_dataset(self):
    # read the list of input file names, one per line
    ds = TextLineDataset(str(pathlib.Path(self.log_dir, 'file_names.txt')))
    ds = ds.take(5)
    # parse each SVG file name into an image tensor, in parallel
    ds = ds.map(self.parse_svg_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # pad variable-sized elements so they can be batched together
    ds = ds.padded_batch(2, drop_remainder=True)
    return ds
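
A minimal consumption sketch for this pipeline (assumptions: TensorFlow 2.x eager mode, and that parse_svg_img yields one dense image tensor per file name; the owner class name below is illustrative, not from the original project):

# hypothetical usage sketch, names are illustrative
loader = SvgDatasetLoader()     # assumed class that defines load_dataset
for images in loader.load_dataset():
    print(images.shape)         # e.g. (2, max_h, max_w, channels) after padded_batch(2)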
Example #2

    def test_pipeline(self, num_threads):
        real_fname = os.path.join(self.dataset_path, 'test_real.txt')

        # extract directories
        real_dir, inst_dir = self.real_dir, self.inst_dir

        # count lines
        num_real = count_lines(real_fname)

        # dataset creation
        with tf.name_scope('dataset'):
            real = TextLineDataset(real_fname)

            # @see https://www.tensorflow.org/api_docs/python/tf/contrib/data/shuffle_and_repeat
            #synt.apply(shuffle_and_repeat(buffer_size = num_synt)) #, count = 1))
            #real.apply(shuffle_and_repeat(buffer_size = num_real)) #, count = ceil(ratio)))

            real = real.shuffle(num_real)  # shuffle the whole file list; no .repeat(), so a single epoch

            # real data only
            augment = 0  # augmentation disabled here; originally self.params.get('augment', 0)
            def name2real(name):
                # load the instance map for this sample
                inst = read_instr(os.path.join(inst_dir, name.decode() + '.png'))
                if augment:
                    src_dir = self.params.get('augment_src', 'best')
                    # load the full-resolution RGB image and its annotated points
                    full = read_image(os.path.join(real_dir, str(src_dir), 'rgb', name.decode() + '.jpg'), False)
                    pnts = read_points(os.path.join(real_dir, str(src_dir), 'points', name.decode() + '.txt'))
                    # a float src_dir doubles as a scale factor for the points
                    if isinstance(src_dir, float):
                        pnts *= src_dir
                    self.params['augment_scale'] = 0.
                    real = random_crop(full, pnts, self.params)
                else:
                    # no augmentation: use the pre-resized grayscale image
                    real = read_image(os.path.join(real_dir, '160x160', 'gray', name.decode() + '.jpg'))
                return real, inst, name.decode()
            # wrap the Python loader into the graph; tf.py_func runs name2real per element
            real = real.map(
                lambda name: tuple(tf.py_func(name2real, [name], [tf.float32, tf.int32, tf.string])),
                num_parallel_calls=num_threads)

            #dataset = Dataset.zip((rend, xfer, real, inst_synt, inst_real))
            dataset = Dataset.zip({'real': real})
            dataset = dataset.batch(self.batch_size, drop_remainder=True)  # we need full batches!
            dataset = dataset.prefetch(self.batch_size * 2)
            return dataset
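
Since the examples flow through tf.py_func, the resulting tensors carry no static shape information, so consumers usually restore it by hand. A minimal TF1-style consumption sketch (the owner object, image size, and channel count below are assumptions, not taken from the original code):

# hypothetical consumption sketch; 'pipeline' and the shapes are illustrative
dataset = pipeline.test_pipeline(num_threads=4)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
# tf.py_func erases static shapes; set them so downstream ops can build
batch['real'][0].set_shape([pipeline.batch_size, 160, 160, 1])  # assumed 160x160 grayscale
with tf.Session() as sess:
    real_imgs, inst_maps, names = sess.run(batch['real'])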