def test_missing_shard_lengths(self):
    """Requests the files of a split whose shard lengths are None.

    `_get_dataset_files` is expected to raise an AssertionError whose message
    matches 'S3 tfrecords_reader cannot be used'.
    """
    whole_split = tfrecords_reader._AbsoluteInstruction('train', None, None)
    with self.assertRaisesWithPredicateMatch(
            AssertionError, 'S3 tfrecords_reader cannot be used'):
        tfrecords_reader._get_dataset_files(
            'mnist', '/foo/bar', whole_split, {'train': None})
def test_skip(self):
    """Skipping 4 examples drops file 0 and starts file 1 part-way through."""
    # One file is not taken at all, one file is only partially taken, and the
    # remaining files are read in full.
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', 4, None))
    partially_read = {'skip': 1, 'take': -1, 'filename': self.PATH_PATTERN % 1}
    fully_read = [
        {'skip': 0, 'take': -1, 'filename': self.PATH_PATTERN % shard}
        for shard in (2, 3, 4)
    ]
    self.assertEqual(files, [partially_read] + fully_read)
def check_from_ri(self, ri, expected):
    """Asserts that `ri` resolves to the given absolute instructions.

    Args:
      ri: Read instruction whose `to_absolute(self.splits)` result is checked.
      expected: Iterable of `(split_name, from_, to_)` tuples; each one is
        turned into a `tfrecords_reader._AbsoluteInstruction` for comparison.

    Returns:
      `ri`, unchanged.
    """
    absolute = ri.to_absolute(self.splits)
    wanted = [
        tfrecords_reader._AbsoluteInstruction(name, start, end)
        for name, start, end in expected
    ]
    self.assertEqual(absolute, wanted)
    return ri
def test_no_skip_no_take(self):
    """Reading a whole split yields every shard, each one read in full."""
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', None, None))
    # Per-shard example counts the fixture split is expected to contain.
    shard_sizes = (3, 2, 3, 2, 3)
    expected = [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % shard_index,
            skip=0,
            take=-1,
            num_examples=size)
        for shard_index, size in enumerate(shard_sizes)
    ]
    self.assertEqual(files, expected)
def test_touching_boundaries(self):
    """Instructions that select no examples must produce no files."""
    # Nothing to read in any of these cases: a zero-width range at the start,
    # an upper bound of 0, a zero-width range mid-split, and a range starting
    # at 13 (presumably the end of the fixture split -- confirm against setup).
    empty_ranges = ((0, 0), (None, 0), (3, 3), (13, None))
    for start, end in empty_ranges:
        instruction = tfrecords_reader._AbsoluteInstruction('train', start, end)
        self.assertEqual(self._get_files(instruction), [])
def test_skip_take1(self):
    """Range [1, 2) reads one example from file 0, skipping its first one."""
    # A single shard with both skip and take applied.
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', 1, 2))
    only_file = shard_utils.FileInstruction(
        filename=self.PATH_PATTERN % 0, skip=1, take=1, num_examples=1)
    self.assertEqual(files, [only_file])
def test_no_skip_no_take(self):
    """Reading a whole split lists all five files with no skip or take."""
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', None, None))
    expected = [
        dict(skip=0, take=-1, filename=self.PATH_PATTERN % shard)
        for shard in range(5)
    ]
    self.assertEqual(files, expected)
def test_skip_take2(self):
    """Range [7, 9) covers the tail of file 2 and the head of file 3."""
    # Two examples spread across two shards are taken from the middle of the
    # split: the last one of file 2 and the first one of file 3.
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', 7, 9))
    tail_of_shard_2 = shard_utils.FileInstruction(
        filename=self.PATH_PATTERN % 2, skip=2, take=-1, num_examples=1)
    head_of_shard_3 = shard_utils.FileInstruction(
        filename=self.PATH_PATTERN % 3, skip=0, take=1, num_examples=1)
    self.assertEqual(files, [tail_of_shard_2, head_of_shard_3])
def test_skip_take1(self):
    """Range [1, 2) reads one example from file 0, skipping its first one."""
    # A single shard with both skip and take applied.
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', 1, 2))
    self.assertEqual(
        files, [dict(skip=1, take=1, filename=self.PATH_PATTERN % 0)])
def test_take(self):
    """Taking the first 6 examples reads files 0-1 fully and file 2 partially."""
    # Files 3 and 4 are not read at all; file 2 contributes a single example.
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', None, 6))
    whole_files = [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % shard,
            skip=0,
            take=-1,
            num_examples=count)
        for shard, count in ((0, 3), (1, 2))
    ]
    partial_file = shard_utils.FileInstruction(
        filename=self.PATH_PATTERN % 2, skip=0, take=1, num_examples=1)
    self.assertEqual(files, whole_files + [partial_file])
def test_skip_take2(self):
    """Range [7, 9) covers the tail of file 2 and the head of file 3."""
    # Two examples spread across two shards are taken from the middle of the
    # split: the remainder of file 2 and the first example of file 3.
    files = self._get_files(
        tfrecords_reader._AbsoluteInstruction('train', 7, 9))
    self.assertEqual(files, [
        dict(skip=2, take=-1, filename=self.PATH_PATTERN % 2),
        dict(skip=0, take=1, filename=self.PATH_PATTERN % 3),
    ])