Exemplo n.º 1
0
    def test_text_dataset_from_directory_binary(self):
        directory = self._prepare_directory(num_classes=2)
        dataset = text_dataset.text_dataset_from_directory(directory,
                                                           batch_size=8,
                                                           label_mode='int',
                                                           max_length=10)
        batch = next(iter(dataset))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (8, ))
        self.assertEqual(batch[0].dtype.name, 'string')
        self.assertEqual(len(batch[0].numpy()[0]), 10)  # Test max_length
        self.assertEqual(batch[1].shape, (8, ))
        self.assertEqual(batch[1].dtype.name, 'int32')

        dataset = text_dataset.text_dataset_from_directory(directory,
                                                           batch_size=8,
                                                           label_mode='binary')
        batch = next(iter(dataset))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (8, ))
        self.assertEqual(batch[0].dtype.name, 'string')
        self.assertEqual(batch[1].shape, (8, 1))
        self.assertEqual(batch[1].dtype.name, 'float32')

        dataset = text_dataset.text_dataset_from_directory(
            directory, batch_size=8, label_mode='categorical')
        batch = next(iter(dataset))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (8, ))
        self.assertEqual(batch[0].dtype.name, 'string')
        self.assertEqual(batch[1].shape, (8, 2))
        self.assertEqual(batch[1].dtype.name, 'float32')
Exemplo n.º 2
0
  def test_text_dataset_from_directory_multiclass(self):
    directory = self._prepare_directory(num_classes=4, count=15)

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    batch = next(iter(dataset))
    self.assertEqual(batch.shape, (8,))

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    sample_count = 0
    iterator = iter(dataset)
    for batch in dataset:
      sample_count += next(iterator).shape[0]
    self.assertEqual(sample_count, 15)

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='int')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8,))
    self.assertEqual(batch[1].dtype.name, 'int32')

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='categorical')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8, 4))
    self.assertEqual(batch[1].dtype.name, 'float32')
Exemplo n.º 3
0
    def test_text_dataset_from_directory_errors(self):
        directory = self._prepare_directory(num_classes=3, count=5)

        with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         labels='other')

        with self.assertRaisesRegex(ValueError,
                                    '`label_mode` argument must be'):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         label_mode='other')

        with self.assertRaisesRegex(
                ValueError,
                'only pass `class_names` if the labels are inferred'):
            _ = text_dataset.text_dataset_from_directory(
                directory,
                labels=[0, 0, 1, 1, 1],
                class_names=['class_0', 'class_1', 'class_2'])

        with self.assertRaisesRegex(
                ValueError,
                'Expected the lengths of `labels` to match the number of files'
        ):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         labels=[0, 0, 1, 1])

        with self.assertRaisesRegex(ValueError,
                                    '`class_names` passed did not match'):
            _ = text_dataset.text_dataset_from_directory(
                directory, class_names=['class_0', 'class_2'])

        with self.assertRaisesRegex(ValueError,
                                    'there must exactly 2 classes'):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         label_mode='binary')

        with self.assertRaisesRegex(
                ValueError, '`validation_split` must be between 0 and 1'):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         validation_split=2)

        with self.assertRaisesRegex(ValueError,
                                    '`subset` must be either "training" or'):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         validation_split=0.2,
                                                         subset='other')

        with self.assertRaisesRegex(ValueError,
                                    '`validation_split` must be set'):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         validation_split=0,
                                                         subset='training')

        with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
            _ = text_dataset.text_dataset_from_directory(directory,
                                                         validation_split=0.2,
                                                         subset='training')
Exemplo n.º 4
0
 def test_text_dataset_from_directory_validation_split(self):
   directory = self._prepare_directory(num_classes=2, count=10)
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=10, validation_split=0.2, subset='training',
       seed=1337)
   batch = next(iter(dataset))
   self.assertLen(batch, 2)
   self.assertEqual(batch[0].shape, (8,))
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=10, validation_split=0.2, subset='validation',
       seed=1337)
   batch = next(iter(dataset))
   self.assertLen(batch, 2)
   self.assertEqual(batch[0].shape, (2,))
Exemplo n.º 5
0
    def test_text_dataset_from_directory_standalone(self):
        # Test retrieving txt files without labels from a directory and its subdirs.
        # Save a few extra files in the parent directory.
        directory = self._prepare_directory(count=7, num_classes=2)
        for i in range(3):
            filename = 'text_%s.txt' % (i, )
            f = open(os.path.join(directory, filename), 'w')
            text = ''.join(
                [random.choice(string.printable) for _ in range(20)])
            f.write(text)
            f.close()

        dataset = text_dataset.text_dataset_from_directory(directory,
                                                           batch_size=5,
                                                           label_mode=None,
                                                           max_length=10)
        batch = next(iter(dataset))
        # We just return the texts, no labels
        self.assertEqual(batch.shape, (5, ))
        self.assertEqual(batch.dtype.name, 'string')
        # Count samples
        batch_count = 0
        sample_count = 0
        for batch in dataset:
            batch_count += 1
            sample_count += batch.shape[0]
        self.assertEqual(batch_count, 2)
        self.assertEqual(sample_count, 10)
Exemplo n.º 6
0
 def test_text_dataset_from_directory_manual_labels(self):
   directory = self._prepare_directory(num_classes=2, count=2)
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=8, labels=[0, 1], shuffle=False)
   batch = next(iter(dataset))
   self.assertLen(batch, 2)
   self.assertAllClose(batch[1], [0, 1])
Exemplo n.º 7
0
 def test_sample_count(self):
   directory = self._prepare_directory(num_classes=4, count=15)
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=8, label_mode=None)
   sample_count = 0
   for batch in dataset:
     sample_count += batch.shape[0]
   self.assertEqual(sample_count, 15)
Exemplo n.º 8
0
    def test_text_dataset_from_directory_not_batched(self):
        directory = self._prepare_directory()
        dataset = text_dataset.text_dataset_from_directory(directory,
                                                           batch_size=None,
                                                           label_mode=None,
                                                           follow_links=True)

        sample = next(iter(dataset))
        self.assertEqual(len(sample.shape), 0)
Exemplo n.º 9
0
 def test_text_dataset_from_directory_follow_links(self):
   directory = self._prepare_directory(num_classes=2, count=25,
                                       nested_dirs=True)
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=8, label_mode=None, follow_links=True)
   sample_count = 0
   for batch in dataset:
     sample_count += batch.shape[0]
   self.assertEqual(sample_count, 25)
Exemplo n.º 10
0
 def test_text_dataset_from_directory_no_files(self):
     directory = self._prepare_directory(num_classes=2, count=0)
     with self.assertRaisesRegex(ValueError, 'No text files found.'):
         _ = text_dataset.text_dataset_from_directory(directory)