def test_touching_boundaries(self):
    """Absolute instructions that select zero examples yield no files.

    Each (from_, to) pair below selects nothing: either the bounds are equal,
    the upper bound is 0, or the lower bound is 13 (presumably the total
    number of examples in the fixture's 'train' split; verify against the
    fixture).
    """
    empty_ranges = [(0, 0), (None, 0), (3, 3), (13, None)]
    for start, end in empty_ranges:
        result = self._get_files(splits._AbsoluteInstruction('train', start, end))
        self.assertEqual(result, [])
def test_skip_take1(self):
    """Range [1, 2) is served by shard 0 alone, with both a skip and a take."""
    actual = self._get_files(splits._AbsoluteInstruction('train', 1, 2))
    expected = [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 0, skip=1, take=1, num_examples=1),
    ]
    self.assertEqual(actual, expected)
def test_no_skip_no_take(self):
    """An unbounded instruction reads every shard in full (skip=0, take=-1)."""
    shard_sizes = [3, 2, 3, 2, 3]
    actual = self._get_files(splits._AbsoluteInstruction('train', None, None))
    expected = [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % shard_idx,
            skip=0,
            take=-1,
            num_examples=size,
        )
        for shard_idx, size in enumerate(shard_sizes)
    ]
    self.assertEqual(actual, expected)
def check_from_ri(self, ri, expected):
  """Asserts that `ri` resolves to the expected absolute instructions.

  Args:
    ri: Object exposing `to_absolute(splits)`; presumably a relative read
      instruction (verify against callers). It is resolved against
      `self.splits`.
    expected: Iterable of `(split_name, from_, to_)` tuples, one per expected
      `splits._AbsoluteInstruction`, in the order `to_absolute` returns them.

  Returns:
    `ri`, unchanged (presumably so callers can reuse it).
  """
  res = ri.to_absolute(self.splits)
  expected_result = [
      splits._AbsoluteInstruction(split_name, from_, to_)
      for split_name, from_, to_ in expected
  ]
  self.assertEqual(res, expected_result)
  return ri
def test_skip_take2(self):
    """Two examples from the middle of the split, spanning two shards.

    Range [7, 9) takes the tail of shard 2 (skip its first two examples, read
    the rest) and the first example of shard 3.
    """
    actual = self._get_files(splits._AbsoluteInstruction('train', 7, 9))
    tail_of_shard_2 = shard_utils.FileInstruction(
        filename=self.PATH_PATTERN % 2, skip=2, take=-1, num_examples=1)
    head_of_shard_3 = shard_utils.FileInstruction(
        filename=self.PATH_PATTERN % 3, skip=0, take=1, num_examples=1)
    self.assertEqual(actual, [tail_of_shard_2, head_of_shard_3])
def test_take(self):
    """An upper bound that falls inside a shard truncates only that shard.

    Range [start, 6): shards 0 and 1 are read in full (take=-1), shard 2 is
    partially read (take=1), and no later shard appears in the result.
    """
    instruction = splits._AbsoluteInstruction('train', None, 6)
    files = self._get_files(instruction)
    self.assertEqual(files, [
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 0, skip=0, take=-1, num_examples=3),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 1, skip=0, take=-1, num_examples=2),
        shard_utils.FileInstruction(
            filename=self.PATH_PATTERN % 2, skip=0, take=1, num_examples=1),
    ])