Beispiel #1
0
  def test_text_dataset_from_directory_binary(self):
    """Checks text/label batch structure for each label_mode on 2 classes."""
    directory = self._prepare_directory(num_classes=2)
    # (label_mode, extra kwargs, expected label shape, expected label dtype)
    scenarios = (
        ('int', {'max_length': 10}, (8,), 'int32'),
        ('binary', {}, (8, 1), 'float32'),
        ('categorical', {}, (8, 2), 'float32'),
    )
    for mode, extra_kwargs, label_shape, label_dtype in scenarios:
      dataset = text_dataset.text_dataset_from_directory(
          directory, batch_size=8, label_mode=mode, **extra_kwargs)
      texts_and_labels = next(iter(dataset))
      self.assertLen(texts_and_labels, 2)
      texts, labels = texts_and_labels
      self.assertEqual(texts.shape, (8,))
      self.assertEqual(texts.dtype.name, 'string')
      if 'max_length' in extra_kwargs:
        # The first text must be truncated to exactly max_length characters.
        self.assertEqual(len(texts.numpy()[0]), extra_kwargs['max_length'])
      self.assertEqual(labels.shape, label_shape)
      self.assertEqual(labels.dtype.name, label_dtype)
Beispiel #2
0
  def test_text_dataset_from_directory_multiclass(self):
    """Checks batch structure and label encodings for a 4-class directory.

    The directory holds 15 text files, so with batch_size=8 a full pass
    yields one batch of 8 followed by a partial batch of 7.
    """
    directory = self._prepare_directory(num_classes=4, count=15)

    # label_mode=None: each batch is a bare tensor of texts, with no labels.
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    batch = next(iter(dataset))
    self.assertEqual(batch.shape, (8,))

    # A full pass over the dataset must account for all 15 files. Count the
    # batches produced by the loop itself instead of advancing a second,
    # independent iterator in lockstep with an unused loop variable.
    sample_count = 0
    for batch in dataset:
      sample_count += batch.shape[0]
    self.assertEqual(sample_count, 15)

    # label_mode='int': labels are a rank-1 int32 tensor, one per text.
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='int')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8,))
    self.assertEqual(batch[1].dtype.name, 'int32')

    # label_mode='categorical': float32 labels with one column per class.
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='categorical')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8, 4))
    self.assertEqual(batch[1].dtype.name, 'float32')
Beispiel #3
0
    def test_text_dataset_from_directory_errors(self):
        """Invalid argument combinations must raise ValueError with a clear message."""
        directory = self._prepare_directory(num_classes=3, count=5)

        # Each entry: (regex the error message must match,
        #              keyword arguments that should trigger it).
        invalid_calls = [
            ('`labels` argument should be', {'labels': None}),
            ('`label_mode` argument must be', {'label_mode': 'other'}),
            ('only pass `class_names` if the labels are inferred',
             {'labels': [0, 0, 1, 1, 1],
              'class_names': ['class_0', 'class_1', 'class_2']}),
            ('Expected the lengths of `labels` to match the number of files',
             {'labels': [0, 0, 1, 1]}),
            ('`class_names` passed did not match',
             {'class_names': ['class_0', 'class_2']}),
            # Pattern mirrors the library's message wording verbatim.
            ('there must exactly 2 classes', {'label_mode': 'binary'}),
            ('`validation_split` must be between 0 and 1',
             {'validation_split': 2}),
            ('`subset` must be either "training" or',
             {'validation_split': 0.2, 'subset': 'other'}),
            ('`validation_split` must be set',
             {'validation_split': 0, 'subset': 'training'}),
            ('must provide a `seed`',
             {'validation_split': 0.2, 'subset': 'training'}),
        ]

        for message_regex, bad_kwargs in invalid_calls:
            with self.assertRaisesRegex(ValueError, message_regex):
                text_dataset.text_dataset_from_directory(directory,
                                                         **bad_kwargs)
Beispiel #4
0
 def test_text_dataset_from_directory_validation_split(self):
   """A 0.2 split of 10 files yields 8 training and 2 validation texts."""
   directory = self._prepare_directory(num_classes=2, count=10)
   # Both subsets are requested with the same seed so the split is consistent.
   for subset, expected_size in (('training', 8), ('validation', 2)):
     dataset = text_dataset.text_dataset_from_directory(
         directory, batch_size=10, validation_split=0.2, subset=subset,
         seed=1337)
     first_batch = next(iter(dataset))
     self.assertLen(first_batch, 2)
     self.assertEqual(first_batch[0].shape, (expected_size,))
Beispiel #5
0
  def test_text_dataset_from_directory_standalone(self):
    """Retrieves unlabeled .txt files from a directory and its subdirs."""
    # Save a few extra files directly in the parent directory, alongside the
    # 7 files created by _prepare_directory (10 files in total).
    directory = self._prepare_directory(count=7, num_classes=2)
    for i in range(3):
      filename = f'text_{i}.txt'
      text = ''.join(random.choice(string.printable) for _ in range(20))
      # A context manager closes the handle even if write() raises.
      with open(os.path.join(directory, filename), 'w',
                encoding='utf-8') as f:
        f.write(text)

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=5, label_mode=None, max_length=10)
    batch = next(iter(dataset))
    # label_mode=None: only the texts are returned, no labels.
    self.assertEqual(batch.shape, (5,))
    self.assertEqual(batch.dtype.name, 'string')
    # 10 files with batch_size=5 must give exactly two batches.
    batch_count = 0
    sample_count = 0
    for batch in dataset:
      batch_count += 1
      sample_count += batch.shape[0]
    self.assertEqual(batch_count, 2)
    self.assertEqual(sample_count, 10)
Beispiel #6
0
 def test_text_dataset_from_directory_manual_labels(self):
   """Explicitly passed labels come back unchanged when shuffling is off."""
   directory = self._prepare_directory(num_classes=2, count=2)
   manual_labels = [0, 1]
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=8, labels=manual_labels, shuffle=False)
   texts_and_labels = next(iter(dataset))
   self.assertLen(texts_and_labels, 2)
   _, labels = texts_and_labels
   self.assertAllClose(labels, manual_labels)
Beispiel #7
0
 def test_sample_count(self):
   """A full pass over an unlabeled dataset yields all 15 files."""
   directory = self._prepare_directory(num_classes=4, count=15)
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=8, label_mode=None)
   total_samples = sum(batch.shape[0] for batch in dataset)
   self.assertEqual(total_samples, 15)
Beispiel #8
0
 def test_text_dataset_from_directory_follow_links(self):
   """With follow_links=True, all 25 files of a nested layout are found."""
   directory = self._prepare_directory(num_classes=2, count=25,
                                       nested_dirs=True)
   dataset = text_dataset.text_dataset_from_directory(
       directory, batch_size=8, label_mode=None, follow_links=True)
   total_samples = sum(batch.shape[0] for batch in dataset)
   self.assertEqual(total_samples, 25)
Beispiel #9
0
 def test_text_dataset_from_directory_no_files(self):
   """A directory containing no text files raises ValueError."""
   empty_directory = self._prepare_directory(num_classes=2, count=0)
   # NOTE: the '.' in the pattern is an unescaped regex wildcard, so it
   # matches whatever character follows "found" in the message.
   with self.assertRaisesRegex(ValueError, 'No text files found.'):
     text_dataset.text_dataset_from_directory(empty_directory)
Beispiel #10
0
from tensorflow.keras.layers import Input, Layer, MultiHeadAttention, Dense
from tensorflow.keras.layers import LayerNormalization, Dropout, Embedding, GlobalMaxPooling1D
from tensorflow.keras.models import Model
from tensorflow.python.data.experimental import cardinality
from tensorflow.python.keras.preprocessing.text_dataset import text_dataset_from_directory
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
from tensorflow.python.keras.utils.vis_utils import plot_model

import os  # used below but not imported in the original script

# Download the dataset from: https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
# Once extracted, delete the 'unsup' folder (it is only useful for unsupervised learning).
data_path = os.path.join('data', 'nlp', 'aclImdb')
train_path = os.path.join(data_path, 'train')
test_path = os.path.join(data_path, 'test')

batch_size = 32
# Training and validation subsets are both carved out of the 'train' folder;
# request them with the same validation_split and seed so the split agrees.
raw_train_ds = text_dataset_from_directory(train_path,
                                           batch_size=batch_size,
                                           validation_split=0.2,
                                           subset='training',
                                           seed=42)
raw_val_ds = text_dataset_from_directory(train_path,
                                         batch_size=batch_size,
                                         validation_split=0.2,
                                         subset='validation',
                                         seed=42)
raw_test_ds = text_dataset_from_directory(test_path,
                                          batch_size=batch_size)

# cardinality() returns a scalar tensor; convert it to int so the message
# shows the batch count rather than the tensor's repr.
print(f'Number of batches in raw_train_ds: {int(cardinality(raw_train_ds))}')
print(f'Number of batches in raw_val_ds: {int(cardinality(raw_val_ds))}')
print(f'Number of batches in raw_test_ds: {int(cardinality(raw_test_ds))}')

# Imprimimos 5 instancias para ver que aspecto tienen los datos