def test_file_path():
    """
    Test S3IterableDataset for existing and nonexistent path
    """
    # existing path
    s3_path = 's3://pt-s3plugin-test-data-west2/images/test'
    s3_dataset = S3IterableDataset(s3_path)
    assert s3_dataset

    # non-existent path
    s3_path_none = 's3://pt-s3plugin-test-data-west2/non_existent_path/test'
    with pytest.raises(AssertionError) as excinfo:
        s3_dataset = S3IterableDataset(s3_path_none)
    assert 'does not contain any objects' in str(excinfo.value)
def test_urls_list():
    """
    Test whether the urls_list input for S3IterableDataset works properly
    """
    os.environ['AWS_REGION'] = 'us-west-2'
    # provide url prefixes (paths within the bucket)
    prefix_to_directory = 'images/test'
    prefix_to_file = 'test_1.JPEG'
    prefix_list = [prefix_to_directory, prefix_to_file]

    # set up boto3
    s3 = boto3.resource('s3')
    bucket_name = 'pt-s3plugin-test-data-west2'
    test_bucket = s3.Bucket(bucket_name)

    # try individual valid urls and collect urls_list and all_boto3_files
    # to test the url list input
    urls_list = list()
    all_boto3_files = list()
    for prefix in prefix_list:
        # collect list of all file names using S3IterableDataset
        url = os.path.join('s3://', bucket_name, prefix)
        urls_list.append(url)
        s3_dataset = S3IterableDataset(url)
        s3_files = [item[0] for item in s3_dataset]

        # collect list of all file names using boto3
        boto3_files = [os.path.join('s3://', obj.bucket_name, obj.key)
                       for obj in test_bucket.objects.filter(Prefix=prefix)]
        all_boto3_files.extend(boto3_files)
        assert s3_files == boto3_files

    # test list of two valid urls as input
    s3_dataset = S3IterableDataset(urls_list)
    s3_files = [item[0] for item in s3_dataset]
    assert s3_files == all_boto3_files

    # add a non-existent url to the list of urls
    url_to_non_existent = 's3://pt-s3plugin-test-data-west2/non_existent_directory'
    urls_list.append(url_to_non_existent)
    with pytest.raises(AssertionError) as excinfo:
        s3_dataset = S3IterableDataset(urls_list)
    assert 'does not contain any objects' in str(excinfo.value)

    del os.environ['AWS_REGION']
def test_zip_file_s3iterabledataset():
    """
    Compare the number of files S3IterableDataset yields from a zip archive
    against the archive's member count
    """
    s3_dataset_path = 's3://pt-s3plugin-test-data-west2/tiny-imagenet-200.zip'
    dataset = S3IterableDataset(s3_dataset_path)

    list_of_files = []
    for files in dataset:
        list_of_files.append(files[0][0])

    result1 = len(list_of_files)
    result2 = get_zip(s3_dataset_path)
    assert result1 == len(result2)
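# `get_zip` is defined elsewhere in the test suite and is not shown in this
# excerpt. A minimal sketch of what it could look like, assuming it downloads
# the archive with boto3 and returns the list of member names so the test can
# compare counts:
import io
import zipfile

import boto3


def get_zip(s3_path):
    """Return the member names of a zip archive stored in S3 (sketch)."""
    bucket, key = s3_path.replace('s3://', '').split('/', 1)
    body = boto3.client('s3').get_object(Bucket=bucket, Key=key)['Body'].read()
    with zipfile.ZipFile(io.BytesIO(body)) as archive:
        # namelist() also returns directory entries; the real helper may filter them
        return archive.namelist()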
def test_csv_file_s3iterabledataset():
    os.environ['AWS_REGION'] = 'us-east-1'
    s3_dataset_path = 's3://pt-s3plugin-test-data-east1/genome-scores.csv'
    dataset = S3IterableDataset(s3_dataset_path)

    import pandas as pd

    # the path points to a single csv file, so the loop runs once
    for files in dataset:
        result1 = pd.read_csv(io.BytesIO(files[1]))

    # download the same file directly with boto3 and compare
    s3 = boto3.client('s3')
    obj = s3.get_object(Bucket=s3_dataset_path.split('/')[2],
                        Key=s3_dataset_path.split('/')[3])
    result2 = pd.read_csv(io.BytesIO(obj['Body'].read()))
    assert result1.equals(result2)

    del os.environ['AWS_REGION']
def test_shuffle_true():
    """
    Tests the shuffle_urls parameter, len, and the set_epoch function
    """
    os.environ['AWS_REGION'] = 'us-west-2'
    # create two datasets, one shuffled using the epoch set via set_epoch
    s3_dataset_path = 's3://pt-s3plugin-test-data-west2/images/test'
    s3_dataset0 = S3IterableDataset(s3_dataset_path)
    s3_dataset1 = S3IterableDataset(s3_dataset_path, shuffle_urls=True)
    s3_dataset1.set_epoch(5)

    # len is defined as the length of the urls_list created from the path
    assert len(s3_dataset0) == len(s3_dataset1)

    # check to make sure shuffling works
    filenames0 = [item[0] for item in s3_dataset0]
    filenames1 = [item[0] for item in s3_dataset1]
    assert len(filenames0) == len(filenames1)
    assert filenames0 != filenames1

    del os.environ['AWS_REGION']
def test_disable_multi_download():
    s3_dataset_path = 's3://pt-s3plugin-test-data-east1/genome-scores.csv'
    os.environ['S3_DISABLE_MULTI_PART_DOWNLOAD'] = "ON"
    os.environ['AWS_REGION'] = 'us-east-1'
    dataset = S3IterableDataset(s3_dataset_path)

    import pandas as pd

    # with multi-part download disabled, the file should still be read correctly
    for files in dataset:
        result1 = pd.read_csv(io.BytesIO(files[1]))

    s3 = boto3.client('s3')
    obj = s3.get_object(Bucket=s3_dataset_path.split('/')[2],
                        Key=s3_dataset_path.split('/')[3])
    result2 = pd.read_csv(io.BytesIO(obj['Body'].read()))
    assert result1.equals(result2)

    del os.environ['S3_DISABLE_MULTI_PART_DOWNLOAD'], os.environ['AWS_REGION']
def test_ShuffleDataset():
    """
    Args:
        bucket: name of the bucket
        tarfiles_list: list of all tarfiles with the prefix
        buffer_size: number of files the ShuffleDataset object caches

    Logic:
        Loop over the ShuffleDataset DataLoader twice.
        Across the two runs, the corresponding batches returned should not be
        the same - this ensures that shuffling is happening within tarfile
        constituents.
        After both runs, the overall data loaded should be the same.
        If either of these conditions fails, the test fails.
    """
    bucket = "pt-s3plugin-test-data-west2"
    tarfiles_list = [
        "integration_tests/imagenet-train-000000.tar",
        "integration_tests/imagenet-train-000001.tar"
    ]
    url_list = ["s3://" + bucket + "/" + tarfile for tarfile in tarfiles_list]
    batch_size = 32

    for num_workers in [0, 16]:
        for buffer_size in [30, 300, 3000]:
            dataset = ShuffleDataset(S3IterableDataset(url_list),
                                     buffer_size=buffer_size)
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    num_workers=num_workers)
            batch_list1 = get_batches(dataloader)
            batch_list2 = get_batches(dataloader)
            assert batches_shuffled(
                batch_list1,
                batch_list2), "ShuffleDataset Test fails: batches not shuffled"
            assert batches_congruent(
                batch_list1,
                batch_list2), "ShuffleDataset Test fails: data mismatch"
            print("ShuffleDataset test passes for {} buffer_size & {} workers".
                  format(buffer_size, num_workers))
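# `get_batches`, `batches_shuffled`, and `batches_congruent` are helpers defined
# elsewhere in the test suite. A minimal sketch of what they could look like,
# assuming each batch from the dataloader is a (filenames, file objects) pair
# and that comparisons are done on the filenames:
def get_batches(dataloader):
    """Materialize one full pass over the dataloader as a list of filename batches (sketch)."""
    return [list(fname) for fname, fobj in dataloader]


def batches_shuffled(batch_list1, batch_list2):
    """True if at least one corresponding batch differs between the two passes (sketch)."""
    return any(b1 != b2 for b1, b2 in zip(batch_list1, batch_list2))


def batches_congruent(batch_list1, batch_list2):
    """True if both passes loaded the same overall set of samples (sketch)."""
    flat1 = sorted(name for batch in batch_list1 for name in batch)
    flat2 = sorted(name for batch in batch_list2 for name in batch)
    return flat1 == flat2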
def test_shuffleurls():
    """
    Args:
        bucket: name of the bucket
        files_prefix: prefix of the location where the files are stored

    Logic:
        Loop over the dataloader twice, once with shuffle_urls=True and once
        with shuffle_urls=False.
        After both runs, the data loaded should be the same, but the loading
        order should be different.
        Maintains one dictionary of sets and one of lists, keyed by the state
        of shuffle_urls (True/False); the values are the set/list of samples.
        The test passes if the set of samples loaded in both cases is the same
        while the list of samples is different (loading order differs - data is
        being shuffled).
    """
    bucket = "pt-s3plugin-test-data-west2"
    files_prefix = "integration_tests/files"
    assert files_prefix[-1] != "/", "Enter prefix without a trailing \"/\""
    prefix_list = get_file_list(bucket, files_prefix)
    url_list = ["s3://" + bucket + "/" + prefix for prefix in prefix_list]
    batch_size = 32

    shuffled_sets = defaultdict(set)
    shuffled_lists = defaultdict(list)
    print("\nINITIATING SHUFFLE TEST")
    for shuffle_urls in [True, False]:
        dataset = S3IterableDataset(url_list, shuffle_urls=shuffle_urls)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        for fname, fobj in dataloader:
            fname = [x.split("/")[-1] for x in fname]
            batch_set = set(zip(fname, fobj))
            batch_list = list(zip(fname, fobj))
            shuffled_sets[str(shuffle_urls)].update(batch_set)
            shuffled_lists[str(shuffle_urls)].append(batch_list)

    assert shuffled_sets['True'] == shuffled_sets['False'] and \
        shuffled_lists['True'] != shuffled_lists['False'], \
        "Shuffling not working correctly"
    print("Shuffle test passed for S3IterableDataset")
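# `get_file_list` is defined elsewhere in the test suite. A minimal sketch of
# what it could look like, assuming it lists the object keys under a prefix
# with boto3:
import boto3


def get_file_list(bucket, files_prefix):
    """Return the keys of all objects in `bucket` starting with `files_prefix` (sketch)."""
    s3 = boto3.resource('s3')
    return [obj.key for obj in s3.Bucket(bucket).objects.filter(Prefix=files_prefix)]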
def __init__(self, url_list, shuffle_urls=False, transform=None):
    self.s3_iter_dataset = S3IterableDataset(url_list, shuffle_urls)
    self.transform = transform
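# This __init__ belongs to an IterableDataset wrapper around S3IterableDataset;
# the rest of the class is not shown in this excerpt. A minimal sketch of the
# matching __iter__, assuming the objects are image bytes decoded with PIL and
# that `transform` should be applied to each decoded image:
def __iter__(self):
    import io
    from PIL import Image  # assumed dependency for decoding the image bytes

    for fname, fobj in self.s3_iter_dataset:
        img = Image.open(io.BytesIO(fobj)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        yield fname, img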
def __init__(self, s3_directory):
    self.s3_directory = s3_directory
    self.dataset = S3IterableDataset(self.s3_directory, shuffle_urls=True)
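# The class this __init__ belongs to is also not shown in full. A minimal sketch
# of how it could be completed, assuming it simply delegates iteration to the
# underlying S3IterableDataset and re-seeds the url shuffle between epochs via
# set_epoch (as exercised in test_shuffle_true above):
def __iter__(self):
    return iter(self.dataset)


def set_epoch(self, epoch):
    # changes the shuffle order of the urls for the next pass over the data;
    # typically called once at the start of each training epoch
    self.dataset.set_epoch(epoch)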