예제 #1
0
    def test_read_parquet_images_tf_dataset(self):
        """Write sample images to Parquet and read them back two ways.

        Writes ``images_generator()`` rows with ``images_schema`` to a
        temporary directory (``block_size=4``). It then reads the data back
        as a ``"tf_dataset"``, printing the keys of the first element, and as
        a ``"dataloader"``, printing the ``'label'`` entry of every item it
        yields. The temporary directory is removed even if the test fails.
        """
        temp_dir = tempfile.mkdtemp()

        try:
            path = "file://" + temp_dir
            ParquetDataset.write(path,
                                 images_generator(),
                                 images_schema,
                                 block_size=4)
            output_types = {
                "id": tf.string,
                "image": tf.string,
                "label": tf.float32
            }
            dataset = read_parquet("tf_dataset",
                                   input_path=path,
                                   output_types=output_types)
            for dt in dataset.take(1):
                print(dt.keys())

            dataloader = read_parquet("dataloader", input_path=path)
            # ``for`` calls iter() and stops on StopIteration itself.
            for item in dataloader:
                print(item['label'])

        finally:
            shutil.rmtree(temp_dir)
    def test_read_parquet_images_tf_dataset(self):
        """Write sample images to Parquet and check whole and sharded reads.

        Writes ``images_generator()`` rows with ``images_schema`` to a
        temporary directory (``block_size=4``). It reads the data back as a
        ``"tf_dataset"`` and a ``"dataloader"``, each both unsharded and with
        ``config={"num_shards": 3, "rank": 1}``. It then asserts that:

        * the sharded tf_dataset has at most ``len(whole) // num_shards``
          elements;
        * the sharded dataloader yields as many items as the sharded
          tf_dataset has elements;
        * the sharded dataloader yields no more items than the unsharded one.

        The temporary directory is removed even if the test fails.
        """
        temp_dir = tempfile.mkdtemp()

        try:
            path = "file://" + temp_dir
            ParquetDataset.write(path,
                                 images_generator(),
                                 images_schema,
                                 block_size=4)
            output_types = {
                "id": tf.string,
                "image": tf.string,
                "label": tf.float32
            }
            dataset = read_parquet("tf_dataset",
                                   path=path,
                                   output_types=output_types)
            for dt in dataset.take(1):
                print(dt.keys())

            num_shards, rank = 3, 1
            dataset_shard = read_parquet("tf_dataset",
                                         path=path,
                                         config={
                                             "num_shards": num_shards,
                                             "rank": rank
                                         },
                                         output_types=output_types)
            # Materialize each dataset only once; the shard length is reused
            # below when checking the sharded dataloader.
            shard_len = len(list(dataset_shard))
            full_len = len(list(dataset))
            assert shard_len <= full_len // num_shards, \
                "len of dataset_shard should be 1/`num_shards` of the whole dataset."

            dataloader = read_parquet("dataloader", path=path)
            dataloader_shard = read_parquet("dataloader",
                                            path=path,
                                            config={
                                                "num_shards": num_shards,
                                                "rank": rank
                                            })
            cur_count = 0
            for item in dataloader_shard:
                print(item['label'])
                cur_count += 1
            assert cur_count == shard_len
            # Previously ``dataloader`` was built but never used: a shard can
            # never yield more items than the unsharded loader.
            assert cur_count <= sum(1 for _ in dataloader)
        finally:
            shutil.rmtree(temp_dir)
예제 #3
0
 def data_creator(config, batch_size):
     """Build the shuffled, batched ``tf.data`` pipeline read from ``path``.

     Args:
         config: Creator config; not read by this function.
         batch_size: Number of examples per batch.

     Returns:
         A dataset that shuffles elements with a buffer of 10, turns each
         element dict into an ``(image, label)`` pair, applies
         ``parse_data_train`` and batches the result.
     """
     def to_pair(record):
         # Elements arrive as dicts; keep only the image and label entries.
         return record["image"], record["label"]

     source = read_parquet("tf_dataset",
                           input_path=path,
                           output_types=output_types,
                           output_shapes=output_shapes)
     return (source.shuffle(10)
                   .map(to_pair)
                   .map(parse_data_train)
                   .batch(batch_size))
예제 #4
0
 def val_data_creator(config, batch_size):
     """Build the batched validation ``tf.data`` pipeline from ``voc_val_path``.

     Args:
         config: Creator config; not read by this function.
         batch_size: Number of examples per batch.

     Returns:
         A dataset that maps each element dict to an ``(image, label)`` pair,
         applies ``parse_data_train``, batches, and then applies
         ``transform_images`` / ``transform_targets`` to every batch.
         No shuffling is applied.
     """
     def to_pair(record):
         # Elements arrive as dicts; keep only the image and label entries.
         return record["image"], record["label"]

     def transform_batch(x_batch, y_batch):
         return (transform_images(x_batch, DEFAULT_IMAGE_SIZE),
                 transform_targets(y_batch, anchors, anchor_masks,
                                   DEFAULT_IMAGE_SIZE))

     source = read_parquet(format="tf_dataset", path=voc_val_path,
                           output_types=output_types,
                           output_shapes=output_shapes)
     return (source
             .map(to_pair)
             .map(parse_data_train)
             .batch(batch_size)
             .map(transform_batch))
예제 #5
0
 def train_data_creator(config, batch_size):
     """Build the shuffled, batched training ``tf.data`` pipeline.

     Reads ``voc_train_path`` via ``read_parquet``, maps each element dict
     to an ``(image, label)`` pair, applies ``parse_data_train``, shuffles
     with a 512-element buffer, batches, applies ``transform_images`` /
     ``transform_targets`` to each batch, and prefetches with
     ``tf.data.experimental.AUTOTUNE``.

     Args:
         config: Creator config; not read by this function.
         batch_size: Number of examples per batch.

     Returns:
         The prepared training dataset.
     """
     def to_pair(record):
         # Elements arrive as dicts; keep only the image and label entries.
         return record["image"], record["label"]

     def transform_batch(x_batch, y_batch):
         return (transform_images(x_batch, DEFAULT_IMAGE_SIZE),
                 transform_targets(y_batch, anchors, anchor_masks,
                                   DEFAULT_IMAGE_SIZE))

     pipeline = read_parquet(format="tf_dataset", path=voc_train_path,
                             output_types=output_types,
                             output_shapes=output_shapes)
     pipeline = (pipeline
                 .map(to_pair)
                 .map(parse_data_train)
                 .shuffle(buffer_size=512)
                 .batch(batch_size)
                 .map(transform_batch))
     return pipeline.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
예제 #6
0
    def test_read_parquet_images_tf_dataset(self):
        """Write sample images to Parquet and read one element back.

        Writes ``images_generator()`` rows with ``images_schema`` into a fresh
        temporary directory, then reads them as a ``"tf_dataset"`` and prints
        the keys of the first element. The directory is always removed.
        """
        work_dir = tempfile.mkdtemp()

        try:
            url = "file://" + work_dir
            ParquetDataset.write(url, images_generator(), images_schema)
            column_types = {
                "id": tf.string,
                "image": tf.string,
                "label": tf.float32,
            }
            tf_ds = read_parquet("tf_dataset",
                                 input_path=url,
                                 output_types=column_types)
            for first in tf_ds.take(1):
                print(first.keys())
        finally:
            shutil.rmtree(work_dir)