def _gen_cmy_data(self):
   """Wraps the `self._gen` generator in an ImageClassifierDataLoader.

   The generator is declared to yield `(uint8, int64)` pairs: a tensor of
   shape `[IMAGE_SIZE, IMAGE_SIZE, 3]` and a scalar.

   Returns:
     The loader built over that dataset, constructed with
     `IMAGES_PER_CLASS * 3`, 3 and the label list cyan/magenta/yellow.
   """
   side = self.IMAGE_SIZE
   element_types = (tf.uint8, tf.int64)
   # Second shape is rank-0, i.e. one scalar per generated element.
   element_shapes = (tf.TensorShape([side, side, 3]), tf.TensorShape([]))
   dataset = tf.data.Dataset.from_generator(
       self._gen, element_types, element_shapes)
   return image_dataloader.ImageClassifierDataLoader(
       dataset, self.IMAGES_PER_CLASS * 3, 3, ['cyan', 'magenta', 'yellow'])
  def test_split(self):
    """Splits a 4-row loader in half without shuffling and checks both parts.

    The first part must hold rows [0, 1], [1, 1] and the second part rows
    [0, 0], [1, 0]. Each part must report size 2, 2 classes and the original
    label list.
    """
    rows = [[0, 1], [1, 1], [0, 0], [1, 0]]
    ds = tf.data.Dataset.from_tensor_slices(rows)
    loader = image_dataloader.ImageClassifierDataLoader(
        ds, 4, 2, ['pos', 'neg'])
    train_data, test_data = loader.split(0.5, shuffle=False)

    # Row i of a part is expected to be [i, second_col]: second_col is 1 for
    # the train part and 0 for the test part.
    for part, second_col in ((train_data, 1), (test_data, 0)):
      self.assertEqual(part.size, 2)
      for i, elem in enumerate(part.dataset):
        expected = np.array([i, second_col])
        matches = elem.numpy() == expected
        self.assertTrue(matches.all())
      self.assertEqual(part.num_classes, 2)
      self.assertEqual(part.index_to_label, ['pos', 'neg'])