def test_basic_functionality(self):
    num_batches = 13
    batch_labels = 75  # note: these settings imply a few iterations through the chunks
    # basic operation, should not crash
    bg = BucketedReadaheadBatchIterator(
        chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=True, buffer_size=1000, seed=1),
        read_ahead=100, seed=1,
        key=lambda line: len(line),
        batch_size=lambda line: batch_labels // (1 + len(line)))
    batches1 = list(itertools.islice(bg, num_batches))
    # verify determinism: rebuilding with the same seeds must yield identical batches
    bg = BucketedReadaheadBatchIterator(
        chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=True, buffer_size=1000, seed=1),
        read_ahead=100, seed=1,
        key=lambda line: len(line),
        batch_size=lambda line: batch_labels // (1 + len(line)))
    batches2 = list(itertools.islice(bg, num_batches))
    print([(len(batch[0]), len(batch)) for batch in batches1])
    self.assertListEqual(batches1, batches2)
def test_no_shuffle(self):
    items = list(
        itertools.islice(
            chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=False, buffer_size=1000),
            len(self.flattened_test_data),
        ))
    self.assertListEqual(items, self.flattened_test_data)
def test_two_instances(self):
    dataset0 = chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=False, buffer_size=1000,
                                        num_instances=2, instance_rank=0)
    dataset1 = chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=False, buffer_size=1000,
                                        num_instances=2, instance_rank=1)
    items0 = list(itertools.islice(dataset0, len(self.test_data[0]) + len(self.test_data[2])))
    items1 = list(itertools.islice(dataset1, len(self.test_data[1]) + len(self.test_data[3])))
    self.assertMultisetEqual(set(items0 + items1), self.flattened_test_data)
def test_transform(self):
    transform = lambda s: s + "!"
    modified_test_data = [transform(s) for s in self.flattened_test_data]
    items = list(
        itertools.islice(
            chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=False, buffer_size=1000,
                                     transform=transform),
            len(self.flattened_test_data),
        ))
    self.assertListEqual(items, modified_test_data)
def test_other_files_present(self):
    # an unrelated file in the data directory must not be picked up as a chunk
    with open(os.path.join(self.data_dir, "i_do_not_belong_here.txt"), "w") as f:
        f.write("really ...")
    items = list(
        itertools.islice(
            chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=False, buffer_size=1000),
            len(self.flattened_test_data),
        ))
    self.assertListEqual(items, self.flattened_test_data)
def test_checkpointing(self):
    random = Random(1)
    for use_windowed in (True, False):
        for i in range(2):
            first_length = random.randrange(11, 21)
            extra_length = random.randrange(11, 21)
            dataset = chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk,
                                               shuffle=(i % 2 == 0), buffer_size=1000, seed=i,
                                               num_instances=2, instance_rank=0,
                                               use_windowed=use_windowed)
            for _ in range(first_length):
                next(dataset)
            checkpoint = dataset.getstate()
            items1 = list(itertools.islice(dataset, extra_length))
            dataset.setstate(checkpoint)
            items2 = list(itertools.islice(dataset, extra_length))
            self.assertListEqual(items1, items2)
def test_checkpointing(self):
    first_batches = 12
    extra_batches = 7
    batch_labels = 123
    bg = BucketedReadaheadBatchIterator(
        chunked_dataset_iterator(self.chunk_file_paths, self.read_chunk, shuffle=True, buffer_size=1000, seed=1),
        read_ahead=100, seed=1,
        key=lambda line: len(line),
        batch_size=lambda line: batch_labels // (1 + len(line)))
    _ = list(itertools.islice(bg, first_batches))
    checkpoint = bg.getstate()
    batches1 = list(itertools.islice(bg, extra_batches))
    bg.setstate(checkpoint)
    batches2 = list(itertools.islice(bg, extra_batches))
    self.assertListEqual(batches1, batches2)
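# The contract exercised by the two checkpointing tests above is: getstate() returns a state
# object, and setstate(state) resumes the iterator exactly where getstate() was taken.
# The sketch below shows how a training loop might persist that state across process restarts.
# It is only an illustration, assuming the state object is picklable; the file path and helper
# names are placeholders, not part of the library.
import pickle

def save_data_checkpoint(iterator, path="data_checkpoint.pkl"):
    # snapshot the iterator position so a restarted job can resume mid-epoch
    with open(path, "wb") as f:
        pickle.dump(iterator.getstate(), f)

def restore_data_checkpoint(iterator, path="data_checkpoint.pkl"):
    # rewind/fast-forward the iterator to the saved position
    with open(path, "rb") as f:
        iterator.setstate(pickle.load(f))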
def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool = True, buffer_size: int = 2**20,
             transform=None, seed: int = None, world_size: int = 1, rank: int = 0,
             num_workers_per_rank: int = 1):
    super().__init__()
    self.rank = rank
    self.num_workers_per_rank = num_workers_per_rank
    # instance_rank is set assuming that num_workers_per_rank = 1 and adapted dynamically in __iter__
    self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size,
                                            transform=transform, seed=seed,
                                            num_instances=world_size * num_workers_per_rank,
                                            instance_rank=rank)
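# A possible __iter__ to pair with the constructor above: the comment there says the instance
# rank is adapted dynamically per DataLoader worker. This is only a sketch of that idea,
# assuming the underlying iterator allows its instance rank to be re-assigned (written here as
# a hypothetical _instance_rank attribute); the library's actual mechanism may differ.
def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # single-process data loading: this rank reads its own shard directly
        self.dataset._instance_rank = self.rank
    else:
        # multi-worker loading: give each worker of this rank its own disjoint sub-shard
        assert worker_info.num_workers == self.num_workers_per_rank
        self.dataset._instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
    return iter(self.dataset)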
    container_uri = "https://" + account + ".blob.core.windows.net/" + container
    container_client = ContainerClient.from_container_url(container_uri,
                                                          credential=_get_azure_key(account, credentials))
    if not blob_path.endswith("/"):
        blob_path += "/"
    blob_uris = [container_uri + "/" + blob["name"]
                 for blob in container_client.walk_blobs(blob_path, delimiter="")
                 if (ext is None or blob["name"].endswith(ext))]
    print("enumerate_files:", len(blob_uris), "blobs found", file=sys.stderr, flush=True)
    for blob_name in blob_uris[:10]:
        print(blob_name, file=sys.stderr, flush=True)
    return blob_uris

if sys.argv[1] == "--azure-storage-key":
    credential = sys.argv[2]
    paths = sys.argv[3:]
else:
    credential = None
    paths = sys.argv[1:]

chunk_file_paths = [  # enumerate all .gz files in the given paths
    subpath
    for path in paths
    for subpath in enumerate_files(path, '.gz', credential)
]
chunk_file_paths.sort()  # make sure file order is always the same, independent of OS
print("block_randomize: reading from", len(chunk_file_paths), "chunk files", file=sys.stderr)

ds = chunked_dataset_iterator(chunk_refs=chunk_file_paths,
                              read_chunk_fn=lambda path: read_utf8_file(path, credential),
                              shuffle=True, buffer_size=1000000, seed=1, use_windowed=True)
for line in ds:
    print(line)
#!/usr/bin/python3.6
# simple command-line wrapper around BucketedReadaheadBatchIterator on an IterableChunkedDataset
# Example:
#   block_randomize_and_batch my_chunked_data
import os, sys, inspect

from infinibatch.datasets import chunked_dataset_iterator
from infinibatch.iterators import BucketedReadaheadBatchIterator

sets = sys.argv[1:]

ds = chunked_dataset_iterator(sets, shuffle=True, buffer_size=10000000, seed=1)
batch_labels = 500
bg = BucketedReadaheadBatchIterator(ds, read_ahead=100,
                                    key=lambda line: len(line),
                                    batch_size=lambda line: batch_labels // (1 + len(line)),
                                    seed=1)
for batch in bg:
    print(f"\n---- size {len(batch)} ---\n")
    print("\n".join(batch))