Example #1
 def test_basic_functionality(self):
     num_batches = 13
     batch_labels = 75  # note: these settings imply a few iterations through the chunks
     # basic operation, should not crash
     bg = BucketedReadaheadBatchIterator(
         chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=True, buffer_size=1000, seed=1),
         read_ahead=100, seed=1,
         key=lambda line: len(line),
         batch_size=lambda line: batch_labels // (1+len(line)))
     batches1 = list(itertools.islice(bg, num_batches))
     # verify determinism
     bg = BucketedReadaheadBatchIterator(
         chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=True, buffer_size=1000, seed=1),
         read_ahead=100, seed=1,
         key=lambda line: len(line),
         batch_size=lambda line: batch_labels // (1+len(line)))
     batches2 = list(itertools.islice(bg, num_batches))
     print([(len(batch[0]), len(batch)) for batch in batches1])
     self.assertListEqual(batches1, batches2)
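
A note on the batch_size callable used above (my reading of the test, not a statement of the library's contract): it appears to receive a single line and return how many items go into that batch, so longer lines yield smaller batches. A tiny standalone illustration, with batch_labels taken from the test:

# illustration only: how the batch_size lambda above scales batch size with line length
batch_labels = 75
for length in (5, 20, 70):
    print(length, "->", batch_labels // (1 + length))  # prints 12, 3, 1
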
Example #2
 def test_no_shuffle(self):
     items = list(
         itertools.islice(
             chunked_dataset_iterator(self.chunk_file_paths,
                                      self.read_chunk,
                                      shuffle=False,
                                      buffer_size=1000),
             len(self.flattened_test_data),
         ))
     self.assertListEqual(items, self.flattened_test_data)
Example #3
 def test_two_instances(self):
     dataset0 = chunked_dataset_iterator(self.chunk_file_paths,
                                         self.read_chunk,
                                         shuffle=False,
                                         buffer_size=1000,
                                         num_instances=2,
                                         instance_rank=0)
     dataset1 = chunked_dataset_iterator(self.chunk_file_paths,
                                         self.read_chunk,
                                         shuffle=False,
                                         buffer_size=1000,
                                         num_instances=2,
                                         instance_rank=1)
     items0 = list(
         itertools.islice(dataset0,
                          len(self.test_data[0]) + len(self.test_data[2])))
     items1 = list(
         itertools.islice(dataset1,
                          len(self.test_data[1]) + len(self.test_data[3])))
     self.assertMultisetEqual(set(items0 + items1),
                              self.flattened_test_data)
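
The split asserted above (rank 0 sees the items of chunks 0 and 2, rank 1 those of chunks 1 and 3) suggests that chunks are dealt out round-robin by instance_rank. A minimal sketch of that assignment, as an assumption drawn from the test rather than from the library's code:

def chunks_for_instance(chunk_file_paths, num_instances, instance_rank):
    # instance k would read chunks k, k + num_instances, k + 2*num_instances, ...
    return chunk_file_paths[instance_rank::num_instances]
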
Example #4
 def test_transform(self):
     transform = lambda s: s + "!"
     modified_test_data = [transform(s) for s in self.flattened_test_data]
     items = list(
         itertools.islice(
             chunked_dataset_iterator(self.chunk_file_paths,
                                      self.read_chunk,
                                      shuffle=False,
                                      buffer_size=1000,
                                      transform=transform),
             len(self.flattened_test_data),
         ))
     self.assertListEqual(items, modified_test_data)
Example #5
 def test_other_files_present(self):
     with open(os.path.join(self.data_dir, "i_do_not_belong_here.txt"),
               "w") as f:
         f.write("really ...")
     items = list(
         itertools.islice(
             chunked_dataset_iterator(self.chunk_file_paths,
                                      self.read_chunk,
                                      shuffle=False,
                                      buffer_size=1000),
             len(self.flattened_test_data),
         ))
     self.assertListEqual(items, self.flattened_test_data)
Example #6
 def test_checkpointing(self):
     random = Random(1)
     for use_windowed in (True, False):
         for i in range(2):
             first_length = random.randrange(11, 21)
             extra_length = random.randrange(11, 21)
             dataset = chunked_dataset_iterator(self.chunk_file_paths,
                                                self.read_chunk,
                                                shuffle=(i % 2 == 0),
                                                buffer_size=1000,
                                                seed=i,
                                                num_instances=2,
                                                instance_rank=0,
                                                use_windowed=use_windowed)
             for _ in range(first_length):
                 next(dataset)
             checkpoint = dataset.getstate()
             items1 = list(itertools.islice(dataset, extra_length))
             dataset.setstate(checkpoint)
             items2 = list(itertools.islice(dataset, extra_length))
             self.assertListEqual(items1, items2)
Example #7
 def test_checkpointing(self):
     first_batches = 12
     extra_batches = 7
     batch_labels = 123
     bg = BucketedReadaheadBatchIterator(
         chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=True, buffer_size=1000, seed=1),
         read_ahead=100, seed=1,
         key=lambda line: len(line),
         batch_size=lambda line: batch_labels // (1+len(line)))
     _ = list(itertools.islice(bg, first_batches))
     checkpoint = bg.getstate()
     batches1 = list(itertools.islice(bg, extra_batches))
     bg.setstate(checkpoint)
     batches2 = list(itertools.islice(bg, extra_batches))
     self.assertListEqual(batches1, batches2)
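
Both checkpointing tests rely only on getstate() and setstate(). A sketch of how that pair could be used to persist progress across runs, assuming the returned state object is picklable; the helper names and file path are illustrative, not part of the library:

import pickle

def save_iterator_state(it, path="batcher.ckpt"):
    # snapshot the iterator's position so a later run can resume from it
    with open(path, "wb") as f:
        pickle.dump(it.getstate(), f)

def restore_iterator_state(it, path="batcher.ckpt"):
    # rewind/advance the iterator to the previously saved position
    with open(path, "rb") as f:
        it.setstate(pickle.load(f))
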
Example #8
 def __init__(self,
              paths: Union[str, Iterable[str]],
              shuffle: bool = True,
              buffer_size: int = 2**20,
              transform=None,
              seed: int = None,
              world_size: int = 1,
              rank: int = 0,
              num_workers_per_rank: int = 1):
     super().__init__()
     self.rank = rank
     self.num_workers_per_rank = num_workers_per_rank
     # instance_rank is set assuming that num_workers_per_rank = 1 and adapted dynamically in __iter__
     self.dataset = chunked_dataset_iterator(paths,
                                             shuffle=shuffle,
                                             buffer_size=buffer_size,
                                             transform=transform,
                                             seed=seed,
                                             num_instances=world_size *
                                             num_workers_per_rank,
                                             instance_rank=rank)
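
The comment above says instance_rank is adapted dynamically in __iter__ once the DataLoader worker is known. One way that adaptation could look, assuming this class wraps a torch.utils.data.IterableDataset and using PyTorch's standard get_worker_info() hook; this is a sketch, not the class's actual __iter__:

import torch.utils.data

def _effective_instance_rank(rank: int, num_workers_per_rank: int) -> int:
    # each of the world_size * num_workers_per_rank instances gets a distinct slot:
    # rank r with worker w maps to instance r * num_workers_per_rank + w
    worker_info = torch.utils.data.get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    return rank * num_workers_per_rank + worker_id
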
Example #9
        container_uri = "https://" + account + ".blob.core.windows.net/" + container
        container_client = ContainerClient.from_container_url(container_uri, credential=_get_azure_key(account, credentials))
        if not blob_path.endswith("/"):
            blob_path += "/"
        blob_uris = [container_uri + "/" + blob["name"] for blob in container_client.walk_blobs(blob_path, delimiter="") if (ext is None or blob["name"].endswith(ext))]
        print("enumerate_files:", len(blob_uris), "blobs found", file=sys.stderr, flush=True)
        for blob_name in blob_uris[:10]:
            print(blob_name, file=sys.stderr, flush=True)
        return blob_uris


if sys.argv[1] == "--azure-storage-key":
    credential = sys.argv[2]
    paths = sys.argv[3:]
else:
    credential = None
    paths = sys.argv[1:]

chunk_file_paths = [  # enumerate all .gz files in the given paths
    subpath
    for path in paths
    for subpath in enumerate_files(path, '.gz', credential)
]
chunk_file_paths.sort()  # make sure file order is always the same, independent of OS
print("block_randomize: reading from", len(chunk_file_paths), "chunk files", file=sys.stderr)

ds = chunked_dataset_iterator(chunk_refs=chunk_file_paths, read_chunk_fn=lambda path: read_utf8_file(path, credential),
                              shuffle=True, buffer_size=1000000, seed=1, use_windowed=True)
for line in ds:
    print(line)
Example #10
#!/usr/bin/python3.6

# simple command-line wrapper around BucketedReadaheadBatchIterator on a chunked_dataset_iterator
# Example:
#   block_randomize_and_batch my_chunked_data

import sys

from infinibatch.datasets import chunked_dataset_iterator
from infinibatch.iterators import BucketedReadaheadBatchIterator

sets = sys.argv[1:]

ds = chunked_dataset_iterator(sets, shuffle=True, buffer_size=10000000, seed=1)
batch_labels = 500
bg = BucketedReadaheadBatchIterator(ds, read_ahead=100, key=lambda line: len(line), batch_size=lambda line: batch_labels // (1+len(line)), seed=1)
for batch in bg:
    print(f"\n---- size {len(batch)} ---\n")
    print("\n".join(batch))