Example No. 1
def read_as_tfdataset(path,
                      output_types,
                      config=None,
                      output_shapes=None,
                      *args,
                      **kwargs):
    """
    return a orca.data.tf.data.Dataset
    :param path:
    :return:
    """
    path, _ = pa_fs(path)
    import tensorflow as tf

    schema_path = os.path.join(path, "_orca_metadata")
    j_str = open_text(schema_path)[0]
    schema = decode_schema(j_str)

    # collect the "chunk=*" sub-directories under path, one per parquet row group
    row_group = []

    for root, dirs, files in os.walk(path):
        for name in dirs:
            if name.startswith("chunk="):
                chunk_path = os.path.join(path, name)
                row_group.append(chunk_path)

    dataset = ParquetIterable(row_group=row_group,
                              schema=schema,
                              num_shards=config.get("num_shards"),
                              rank=config.get("rank"))

    return tf.data.Dataset.from_generator(dataset,
                                          output_types=output_types,
                                          output_shapes=output_shapes)
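A hedged usage sketch for the function above: the dataset path, the num_shards/rank values in config, and the output types/shapes are illustrative assumptions, not values taken from this example.

import tensorflow as tf

# Hypothetical call: assumes an Orca parquet dataset at this path whose records
# carry a float "feature" tensor and an integer "label".
tf_dataset = read_as_tfdataset(
    path="file:///tmp/orca_parquet_dataset",
    config={"num_shards": 1, "rank": 0},
    output_types={"feature": tf.float32, "label": tf.int32},
    output_shapes={"feature": tf.TensorShape([128]),
                   "label": tf.TensorShape([])})
tf_dataset = tf_dataset.batch(32)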
Example No. 2
def test_open_s3_text(self):
    access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
    secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
    if access_key_id and secret_access_key:
        file_path = "s3://analytics-zoo-data/hyperseg/trainingData/train_tiled.txt"
        lines = open_text(file_path)
        assert lines[0] == "CONTENTAI_000001"
Example No. 3
def test_write_text_local_2(self):
    temp = tempfile.mkdtemp()
    path = os.path.join(temp, "test.txt")
    write_text("file://" + path, "abc\n")
    text = open_text("file://" + path)
    shutil.rmtree(temp)
    assert text == ['abc']
Example No. 4
    def _read_as_dict_rdd(path):
        sc = SparkContext.getOrCreate()
        spark = SparkSession(sc)

        df = spark.read.parquet(path)
        schema_path = os.path.join(path, "_orca_metadata")

        j_str = open_text(schema_path)[0]

        schema = decode_schema(j_str)

        rdd = df.rdd.map(lambda r: row_to_dict(schema, r))
        return rdd, schema
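A hedged sketch of how this helper might be driven; the dataset path is an assumption, and the enclosing class the method belongs to is omitted here.

# Hypothetical call (path is an assumption): materialise a couple of records
# from the RDD of dicts produced by the helper above.
rdd, schema = _read_as_dict_rdd("file:///tmp/orca_parquet_dataset")
print(schema)
print(rdd.take(2))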
Example No. 5
def test_write_text_s3(self):
    access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
    secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
    if access_key_id and secret_access_key:
        file_path = "s3://analytics-zoo-data/test.txt"
        text = 'abc\ndef\n'
        write_text(file_path, text)
        lines = open_text(file_path)
        assert lines == ['abc', 'def']
        import boto3
        s3_client = boto3.Session(
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key).client('s3', verify=False)
        s3_client.delete_object(Bucket='analytics-zoo-data', Key='test.txt')
Example No. 6
def read_as_dataloader(path,
                       config=None,
                       transforms=None,
                       batch_size=1,
                       *args,
                       **kwargs):
    path, _ = pa_fs(path)
    import math
    import torch

    schema_path = os.path.join(path, "_orca_metadata")
    j_str = open_text(schema_path)[0]
    schema = decode_schema(j_str)

    # collect the "chunk=*" sub-directories under path, one per parquet row group
    row_group = []

    for root, dirs, files in os.walk(path):
        for name in dirs:
            if name.startswith("chunk="):
                chunk_path = os.path.join(path, name)
                row_group.append(chunk_path)

    class ParquetIterableDataset(torch.utils.data.IterableDataset):
        def __init__(self,
                     row_group,
                     schema,
                     num_shards=None,
                     rank=None,
                     transforms=None):
            super().__init__()
            self.iterator = ParquetIterable(row_group, schema, num_shards,
                                            rank, transforms)
            self.cur = self.iterator.cur
            self.cur_tail = self.iterator.cur_tail

        def __iter__(self):
            return self.iterator.__iter__()

        def __next__(self):
            return self.iterator.__next__()

    def worker_init_fn(w_id):
        # assign each DataLoader worker a contiguous slice of the row-group range
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset
        iter_start = dataset.cur
        iter_end = dataset.cur_tail
        per_worker = int(
            math.ceil((iter_end - iter_start) / float(worker_info.num_workers)))
        w_id = worker_info.id
        dataset.cur = iter_start + w_id * per_worker
        dataset.cur_tail = min(dataset.cur + per_worker, iter_end)

    dataset = ParquetIterableDataset(row_group=row_group,
                                     schema=schema,
                                     num_shards=config.get("num_shards"),
                                     rank=config.get("rank"),
                                     transforms=transforms)

    return torch.utils.data.DataLoader(dataset,
                                       num_workers=config.get(
                                           "num_workers", 0),
                                       batch_size=batch_size,
                                       worker_init_fn=worker_init_fn)
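A hedged usage sketch for read_as_dataloader: the path and the num_shards/rank/num_workers values in config are illustrative assumptions.

# Hypothetical call: iterate one batch from the resulting PyTorch DataLoader.
loader = read_as_dataloader(
    path="file:///tmp/orca_parquet_dataset",
    config={"num_shards": 1, "rank": 0, "num_workers": 2},
    transforms=None,
    batch_size=32)
first_batch = next(iter(loader))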
Example No. 7
def test_open_local_text_2(self):
    file_path = os.path.join(self.resource_path, "qa/relations.txt")
    lines = open_text("file://" + file_path)
    assert lines == ["Q1,Q1,1", "Q1,Q2,0", "Q2,Q1,0", "Q2,Q2,1"]