def test_4fold(self):
    """Reads a 12-example 'train' split as four 25% folds and their complements.

    For each fold start `k` in 0, 25, 50, 75 the test requests the held-out
    slice `[k%, k+25%)` and, separately, everything else
    (`[:k%] + [k+25%:]`). The held-out slices must partition the letters
    a..l into four groups of three, and each complement must hold the other
    nine letters in order.
    """
    # NOTE(review): the meaning of the second argument (5) is defined by
    # `_write_tfrecord`, which lives outside this block — confirm there.
    self._write_tfrecord('train', 5, 'abcdefghijkl')
    fold_starts = range(0, 100, 25)

    # Held-out fold: the 25% window starting at each fold boundary.
    held_out_instructions = [
        tfrecords_reader.ReadInstruction(
            'train', from_=lo, to=lo + 25, unit='%')
        for lo in fold_starts
    ]
    tests = self.reader.read(
        'mnist', held_out_instructions, self.SPLIT_INFOS)

    # Training part: everything before the window plus everything after it.
    complement_instructions = []
    for lo in fold_starts:
        before = tfrecords_reader.ReadInstruction('train', to=lo, unit='%')
        after = tfrecords_reader.ReadInstruction(
            'train', from_=lo + 25, unit='%')
        complement_instructions.append(before + after)
    trains = self.reader.read(
        'mnist', complement_instructions, self.SPLIT_INFOS)

    read_tests = [list(ds) for ds in tfds.as_numpy(tests)]
    read_trains = [list(ds) for ds in tfds.as_numpy(trains)]

    expected_folds = [
        [b'a', b'b', b'c'],
        [b'd', b'e', b'f'],
        [b'g', b'h', b'i'],
        [b'j', b'k', b'l'],
    ]
    # Each training set is the concatenation of all folds but the held-out one.
    expected_trains = []
    for skipped in range(len(expected_folds)):
        remaining = expected_folds[:skipped] + expected_folds[skipped + 1:]
        expected_trains.append(
            [item for fold in remaining for item in fold])

    self.assertEqual(read_tests, expected_folds)
    self.assertEqual(read_trains, expected_trains)
def test_valid(self):
    """Checks that valid specs and instructions resolve to the expected ranges.

    Covers whole splits, split addition, absolute slicing, percent slicing
    with 'closest' rounding, percent slicing with 'pct1_dropremainder'
    rounding, and a slice that resolves to an empty range.

    NOTE(review): the expected values appear to be lists of
    `(splitname, from, to)` tuples interpreted by `check_from_spec` /
    `check_from_ri`, and they depend on the split sizes of the test fixture;
    both are defined outside this block — confirm there.
    """

    def pct1(splitname, **bounds):
        # Percent instruction using the 'pct1_dropremainder' rounding mode.
        return tfrecords_reader.ReadInstruction(
            splitname, unit='%', rounding='pct1_dropremainder', **bounds)

    # Simple split:
    whole_train = self.check_from_spec('train', [('train', None, None)])
    self.assertEqual(
        str(whole_train),
        "ReadInstruction([_RelativeInstruction(splitname='train', "
        "from_=None, to=None, unit='abs', rounding='closest')])",
    )
    self.check_from_spec('test', [('test', None, None)])

    # Addition of splits:
    self.check_from_spec(
        'train+test', [('train', None, None), ('test', None, None)])

    # Absolute slicing:
    absolute_cases = (
        ('train[0:0]', None, 0),
        ('train[:10]', None, 10),
        ('train[0:10]', None, 10),
        ('train[-10:]', 190, None),
        ('train[-100:-50]', 100, 150),
        ('train[-10:200]', 190, None),
        ('train[10:-10]', 10, 190),
        ('train[42:99]', 42, 99),
    )
    for spec, lo, hi in absolute_cases:
        self.check_from_spec(spec, [('train', lo, hi)])

    # Percent slicing, closest rounding:
    percent_cases = (
        ('train[:10%]', None, 20),
        ('train[90%:]', 180, None),
        ('train[-1%:]', 198, None),
    )
    for spec, lo, hi in percent_cases:
        self.check_from_spec(spec, [('train', lo, hi)])
    most_of_test = self.check_from_spec('test[:99%]', [('test', None, 100)])
    self.assertEqual(
        str(most_of_test),
        "ReadInstruction([_RelativeInstruction(splitname='test', "
        "from_=None, to=99, unit='%', rounding='closest')])",
    )
    # No overlap:
    self.check_from_spec('test[100%:]', [('test', 101, None)])

    # Percent slicing, pct1_dropremainder rounding:
    self.check_from_ri(pct1('train', to=20), [('train', None, 40)])
    # test split has 101 examples.
    self.check_from_ri(pct1('test', to=100), [('test', None, 100)])
    # No overlap using 'pct1_dropremainder' rounding:
    head = pct1('test', to=99)
    tail = pct1('test', from_=100)
    self.check_from_ri(head, [('test', None, 99)])
    self.check_from_ri(tail, [('test', 100, None)])

    # Empty:
    # Slices resulting in empty datasets are valid with 'closest' rounding:
    self.check_from_spec('validation[:1%]', [('validation', None, 0)])
def test_add_invalid(self):
    """Adding two instructions with different rounding modes must fail."""
    # Mixed rounding: one operand per rounding mode.
    dropremainder_head = tfrecords_reader.ReadInstruction(
        'test', unit='%', to=10, rounding='pct1_dropremainder')
    closest_tail = tfrecords_reader.ReadInstruction(
        'test', unit='%', from_=90, rounding='closest')

    expect_failure = self.assertRaisesWithPredicateMatch(
        AssertionError, 'different rounding')
    with expect_failure:
        unused_ = dropremainder_head + closest_tail
def test_invalid_spec(self):
    """Checks the error reported for each kind of invalid split spec.

    NOTE(review): `self.assertRaises` is called here with `(spec, message)`,
    so the enclosing test class presumably overrides the unittest method of
    the same name — confirm against the class definition.
    """
    percent_bounds_msg = 'percent slice boundaries should be in [-100, 100]'
    rejected_specs = (
        # Invalid format:
        ('validation[:250%:2]',
         "Unrecognized split format: 'validation[:250%:2]'"),
        # Unexisting split:
        ('imaginary', "Unknown split 'imaginary'"),
        # Invalid boundaries abs:
        ('validation[:31]', 'incompatible with 30 examples'),
        # Invalid boundaries %:
        ('validation[:250%]', percent_bounds_msg),
        ('validation[-101%:]', percent_bounds_msg),
    )
    for spec, message in rejected_specs:
        self.assertRaises(spec, message)

    # pct1_dropremainder with < 100 examples
    with self.assertRaisesWithPredicateMatch(
        ValueError, 'with less than 100 elements is forbidden'):
        # Construction stays inside the context so an error raised by either
        # the constructor or `to_absolute` is captured.
        tfrecords_reader.ReadInstruction(
            'validation', to=99, unit='%', rounding='pct1_dropremainder',
        ).to_absolute(self.splits)
def test_invalid_spec(self):
    """Checks the error reported for each kind of invalid instruction spec.

    NOTE(review): `self.assertRaises` is called here with `(spec, message)`,
    so the enclosing test class presumably overrides the unittest method of
    the same name — confirm against the class definition.
    """
    percent_bounds_msg = 'Percent slice boundaries must be > -100 and < 100'
    rejected_specs = (
        # Invalid format:
        ('validation[:250%:2]',
         'Unrecognized instruction format: validation[:250%:2]'),
        # Unexisting split:
        ('imaginary', 'Requested split "imaginary" does not exist'),
        # Invalid boundaries abs:
        ('validation[:31]', 'incompatible with 30 examples'),
        # Invalid boundaries %:
        ('validation[:250%]', percent_bounds_msg),
        ('validation[-101%:]', percent_bounds_msg),
    )
    for spec, message in rejected_specs:
        self.assertRaises(spec, message)

    # pct1_dropremainder with < 100 examples
    with self.assertRaisesWithPredicateMatch(
        AssertionError, 'with less than 100 elements is forbidden'):
        # Construction stays inside the context so an error raised by either
        # the constructor or `to_absolute` is captured.
        tfrecords_reader.ReadInstruction(
            'validation', to=99, unit='%', rounding='pct1_dropremainder',
        ).to_absolute(self.splits)
def _absolute_to_read_instruction_for_index(
    self,
    abs_inst,
    split_infos: splits_lib.SplitDict,
) -> tfrecords_reader.ReadInstruction:
    """Builds the read instruction for shard `self.index` of `self.count`.

    The `end - start` examples selected by `abs_inst` are divided into
    `self.count` contiguous shards of equal size. When the division leaves a
    remainder, it is either dropped (`self.drop_remainder` is true; a
    warning is logged) or handed out one example at a time to the
    lowest-indexed shards.

    Args:
      abs_inst: Absolute instruction; its `splitname`, `from_` and `to`
        attributes are read.
      split_infos: Indexed by split name; `num_examples` of the entry is used
        as the end bound when `abs_inst.to` is None.

    Returns:
      A `ReadInstruction` in 'abs' units covering this shard's examples.
    """
    start = abs_inst.from_ or 0
    # Only `None` means "until the end of the split": `to == 0` is valid.
    end = (
        abs_inst.to
        if abs_inst.to is not None
        else split_infos[abs_inst.splitname].num_examples
    )
    assert end >= start, f'start={start}, end={end}'

    total = end - start
    per_shard, leftover = divmod(total, self.count)
    lower = start + per_shard * self.index
    upper = lower + per_shard

    # Handle remaining examples.
    assert leftover >= 0, leftover
    assert leftover < self.count, leftover
    if leftover > 0:
        if not self.drop_remainder:
            # Shards with index < leftover each take one extra example, so
            # every later boundary shifts by the number of extras before it.
            lower += min(self.index, leftover)
            upper += min(self.index + 1, leftover)
        else:
            logging.warning(
                'Dropping %d examples of %d examples (host count: %d).',
                leftover, total, self.count)

    return tfrecords_reader.ReadInstruction(
        abs_inst.splitname,
        from_=lower,
        to=upper,
        unit='abs',
    )
def test_invalid_unit(self):
    """Constructing an instruction with an unknown unit raises ValueError."""
    expect_unit_error = self.assertRaisesWithPredicateMatch(
        ValueError, 'unit')
    with expect_unit_error:
        tfrecords_reader.ReadInstruction(
            'test', rounding='closest', unit='kg')
def test_invalid_rounding(self):
    """Constructing an instruction with an unknown rounding raises ValueError."""
    expect_rounding_error = self.assertRaisesWithPredicateMatch(
        ValueError, 'rounding')
    with expect_rounding_error:
        tfrecords_reader.ReadInstruction(
            'test', rounding='unexisting', unit='%')