Example #1
0
 def test_4fold(self):
     """Reads four disjoint 25% folds of 'train' and their complements.

     Each fold `[k%, k+25%)` is read as a test set; the matching train set is
     the concatenation (`+`) of the slices before and after that fold.
     """
     self._write_tfrecord('train', 5, 'abcdefghijkl')
     fold_starts = range(0, 100, 25)
     test_instructions = [
         tfrecords_reader.ReadInstruction(
             'train', from_=pct, to=pct + 25, unit='%')
         for pct in fold_starts
     ]
     tests = self.reader.read('mnist', test_instructions, self.SPLIT_INFOS)
     # Complement of each fold: everything before it plus everything after it.
     train_instructions = [
         tfrecords_reader.ReadInstruction('train', to=pct, unit='%')
         + tfrecords_reader.ReadInstruction('train', from_=pct + 25, unit='%')
         for pct in fold_starts
     ]
     trains = self.reader.read('mnist', train_instructions, self.SPLIT_INFOS)
     read_tests = [list(fold) for fold in tfds.as_numpy(tests)]
     read_trains = [list(fold) for fold in tfds.as_numpy(trains)]
     self.assertEqual(
         read_tests,
         [
             [b'a', b'b', b'c'],
             [b'd', b'e', b'f'],
             [b'g', b'h', b'i'],
             [b'j', b'k', b'l'],
         ])
     self.assertEqual(
         read_trains,
         [
             [b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l'],
             [b'a', b'b', b'c', b'g', b'h', b'i', b'j', b'k', b'l'],
             [b'a', b'b', b'c', b'd', b'e', b'f', b'j', b'k', b'l'],
             [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i'],
         ])
Example #2
0
 def test_valid(self):
     """Valid specs / ReadInstructions resolve to the expected (split, from, to)."""
     # Plain split name; also pin the ReadInstruction repr.
     train_ri = self.check_from_spec('train', [('train', None, None)])
     self.assertEqual(
         str(train_ri),
         "ReadInstruction(["
         "_RelativeInstruction(splitname='train', from_=None, to=None, "
         "unit='abs', rounding='closest')])")
     spec_cases = [
         # Plain split name.
         ('test', [('test', None, None)]),
         # Sum of two splits.
         ('train+test', [('train', None, None), ('test', None, None)]),
         # Absolute (example-count) slices.
         ('train[0:0]', [('train', None, 0)]),
         ('train[:10]', [('train', None, 10)]),
         ('train[0:10]', [('train', None, 10)]),
         ('train[-10:]', [('train', 190, None)]),
         ('train[-100:-50]', [('train', 100, 150)]),
         ('train[-10:200]', [('train', 190, None)]),
         ('train[10:-10]', [('train', 10, 190)]),
         ('train[42:99]', [('train', 42, 99)]),
         # Percent slices with the default 'closest' rounding.
         ('train[:10%]', [('train', None, 20)]),
         ('train[90%:]', [('train', 180, None)]),
         ('train[-1%:]', [('train', 198, None)]),
     ]
     for spec, expected in spec_cases:
         self.check_from_spec(spec, expected)
     test_ri = self.check_from_spec('test[:99%]', [('test', None, 100)])
     self.assertEqual(
         str(test_ri),
         "ReadInstruction([_RelativeInstruction(splitname='test', from_=None,"
         " to=99, unit='%', rounding='closest')])")
     # Adjacent percent slices must not overlap.
     self.check_from_spec('test[100%:]', [('test', 101, None)])

     def pct1(split, **kwargs):
         """Builds a percent ReadInstruction using 'pct1_dropremainder'."""
         return tfrecords_reader.ReadInstruction(
             split, unit='%', rounding='pct1_dropremainder', **kwargs)

     self.check_from_ri(pct1('train', to=20), [('train', None, 40)])
     # The test split has 101 examples.
     self.check_from_ri(pct1('test', to=100), [('test', None, 100)])
     # Adjacent slices must not overlap with 'pct1_dropremainder' either.
     head = pct1('test', to=99)
     tail = pct1('test', from_=100)
     self.check_from_ri(head, [('test', None, 99)])
     self.check_from_ri(tail, [('test', 100, None)])
     # Slices that resolve to an empty dataset are valid under 'closest'.
     self.check_from_spec('validation[:1%]', [('validation', None, 0)])
 def test_add_invalid(self):
   # Mixed rounding:
   ri1 = tfrecords_reader.ReadInstruction('test', unit='%', to=10,
                                          rounding='pct1_dropremainder')
   ri2 = tfrecords_reader.ReadInstruction('test', unit='%', from_=90,
                                          rounding='closest')
   with self.assertRaisesWithPredicateMatch(AssertionError,
                                            'different rounding'):
     unused_ = ri1 + ri2
 def test_invalid_spec(self):
   # Invalid format:
   self.assertRaises('validation[:250%:2]',
                     'Unrecognized split format: \'validation[:250%:2]\'')
   # Unexisting split:
   self.assertRaises('imaginary', "Unknown split 'imaginary'")
   # Invalid boundaries abs:
   self.assertRaises('validation[:31]', 'incompatible with 30 examples')
   # Invalid boundaries %:
   self.assertRaises('validation[:250%]',
                     'percent slice boundaries should be in [-100, 100]')
   self.assertRaises('validation[-101%:]',
                     'percent slice boundaries should be in [-100, 100]')
   # pct1_dropremainder with < 100 examples
   with self.assertRaisesWithPredicateMatch(
       ValueError, 'with less than 100 elements is forbidden'):
     ri = tfrecords_reader.ReadInstruction(
         'validation', to=99, unit='%', rounding='pct1_dropremainder')
     ri.to_absolute(self.splits)
Example #5
0
 def test_invalid_spec(self):
     """Malformed or out-of-range split specs fail with a helpful message."""
     bad_specs = (
         # Malformed spec string.
         ('validation[:250%:2]',
          'Unrecognized instruction format: validation[:250%:2]'),
         # Split that does not exist.
         ('imaginary', 'Requested split "imaginary" does not exist'),
         # Absolute bound past the end of the split.
         ('validation[:31]', 'incompatible with 30 examples'),
         # Percent bounds outside the allowed range.
         ('validation[:250%]',
          'Percent slice boundaries must be > -100 and < 100'),
         ('validation[-101%:]',
          'Percent slice boundaries must be > -100 and < 100'),
     )
     for spec, message in bad_specs:
         self.assertRaises(spec, message)
     # 'pct1_dropremainder' is refused on a split with fewer than 100 examples.
     with self.assertRaisesWithPredicateMatch(
             AssertionError, 'with less than 100 elements is forbidden'):
         small_ri = tfrecords_reader.ReadInstruction(
             'validation', to=99, unit='%', rounding='pct1_dropremainder')
         small_ri.to_absolute(self.splits)
Example #6
0
    def _absolute_to_read_instruction_for_index(
        self,
        abs_inst,
        split_infos: splits_lib.SplitDict,
    ) -> tfrecords_reader.ReadInstruction:
        """Returns the sub-range of `abs_inst` assigned to host `self.index`.

        The `[from_, to)` range of `abs_inst` (open ends resolved against the
        split's `num_examples`) is cut into `self.count` equal chunks. The
        `num_examples % self.count` leftover examples are either dropped with a
        warning (when `self.drop_remainder` is set) or handed out one each to
        the lowest-indexed hosts.

        Args:
          abs_inst: Absolute instruction exposing `splitname`, `from_`, `to`.
          split_infos: Mapping used to look up the split's `num_examples`
            when `abs_inst.to` is None.

        Returns:
          An absolute-unit `ReadInstruction` covering this host's shard.
        """
        begin = abs_inst.from_ or 0
        # `abs_inst.to == 0` is a legitimate (empty) bound, so compare with None
        # explicitly instead of relying on truthiness.
        stop = (
            split_infos[abs_inst.splitname].num_examples
            if abs_inst.to is None
            else abs_inst.to
        )
        assert stop >= begin, f'start={begin}, end={stop}'
        total = stop - begin

        per_host, leftover = divmod(total, self.count)
        shard_start = begin + per_host * self.index
        shard_end = shard_start + per_host

        assert leftover >= 0, leftover
        assert leftover < self.count, leftover
        if leftover > 0:
            if not self.drop_remainder:
                # Hosts [0, leftover) each take one extra example; every later
                # host is shifted right by the extras handed out before it.
                shard_start += min(self.index, leftover)
                shard_end += min(self.index + 1, leftover)
            else:
                logging.warning(
                    'Dropping %d examples of %d examples (host count: %d).',
                    leftover, total, self.count)

        return tfrecords_reader.ReadInstruction(
            abs_inst.splitname,
            from_=shard_start,
            to=shard_end,
            unit='abs',
        )
Example #7
0
 def test_invalid_unit(self):
     """An unknown `unit` raises a ValueError whose message mentions 'unit'."""
     with self.assertRaisesWithPredicateMatch(ValueError, 'unit'):
         tfrecords_reader.ReadInstruction(
             'test', rounding='closest', unit='kg')
Example #8
0
 def test_invalid_rounding(self):
     """An unknown `rounding` raises a ValueError mentioning 'rounding'."""
     with self.assertRaisesWithPredicateMatch(ValueError, 'rounding'):
         tfrecords_reader.ReadInstruction(
             'test', rounding='unexisting', unit='%')