def test_text_dataset_from_directory_binary(self):
    """Check batch shapes/dtypes for 'int', 'binary' and 'categorical' labels."""
    directory = self._prepare_directory(num_classes=2)

    # Sparse integer labels, with texts truncated to max_length characters.
    ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='int', max_length=10)
    batch = next(iter(ds))
    self.assertLen(batch, 2)
    texts, labels = batch
    self.assertEqual(texts.shape, (8,))
    self.assertEqual(texts.dtype.name, 'string')
    self.assertEqual(len(texts.numpy()[0]), 10)  # Test max_length
    self.assertEqual(labels.shape, (8,))
    self.assertEqual(labels.dtype.name, 'int32')

    # Binary labels come back as a float32 column vector.
    ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='binary')
    batch = next(iter(ds))
    self.assertLen(batch, 2)
    texts, labels = batch
    self.assertEqual(texts.shape, (8,))
    self.assertEqual(texts.dtype.name, 'string')
    self.assertEqual(labels.shape, (8, 1))
    self.assertEqual(labels.dtype.name, 'float32')

    # Categorical labels are one-hot over the 2 classes.
    ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='categorical')
    batch = next(iter(ds))
    self.assertLen(batch, 2)
    texts, labels = batch
    self.assertEqual(texts.shape, (8,))
    self.assertEqual(texts.dtype.name, 'string')
    self.assertEqual(labels.shape, (8, 2))
    self.assertEqual(labels.dtype.name, 'float32')
def test_text_dataset_from_directory_multiclass(self):
    """Check batch shapes/dtypes for None, 'int' and 'categorical' labels."""
    directory = self._prepare_directory(num_classes=4, count=15)

    # label_mode=None: batches are plain string tensors with no labels.
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    batch = next(iter(dataset))
    self.assertEqual(batch.shape, (8,))

    # All 15 samples are yielded exactly once across batches.
    # Fix: the original drove a second `iter(dataset)` with `next()` inside
    # the `for` loop, iterating the dataset twice in lockstep (reading every
    # file twice and relying on both iterators having equal length). Count
    # directly from the loop variable instead.
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    sample_count = 0
    for batch in dataset:
        sample_count += batch.shape[0]
    self.assertEqual(sample_count, 15)

    # 'int' labels: sparse integer class ids.
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='int')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8,))
    self.assertEqual(batch[1].dtype.name, 'int32')

    # 'categorical' labels: one-hot over the 4 classes.
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='categorical')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8, 4))
    self.assertEqual(batch[1].dtype.name, 'float32')
def test_text_dataset_from_directory_errors(self):
    """Invalid argument combinations raise descriptive ValueErrors."""
    directory = self._prepare_directory(num_classes=3, count=5)

    # labels=None is rejected outright.
    with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
        text_dataset.text_dataset_from_directory(directory, labels=None)

    # Unknown label_mode.
    with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
        text_dataset.text_dataset_from_directory(directory, label_mode='other')

    # class_names only makes sense when labels are inferred.
    with self.assertRaisesRegex(
            ValueError, 'only pass `class_names` if the labels are inferred'):
        text_dataset.text_dataset_from_directory(
            directory,
            labels=[0, 0, 1, 1, 1],
            class_names=['class_0', 'class_1', 'class_2'])

    # Explicit labels must match the number of files.
    with self.assertRaisesRegex(
            ValueError,
            'Expected the lengths of `labels` to match the number of files'):
        text_dataset.text_dataset_from_directory(directory, labels=[0, 0, 1, 1])

    # class_names must match the subdirectories actually found.
    with self.assertRaisesRegex(ValueError,
                                '`class_names` passed did not match'):
        text_dataset.text_dataset_from_directory(
            directory, class_names=['class_0', 'class_2'])

    # 'binary' requires exactly two classes; this directory has three.
    with self.assertRaisesRegex(ValueError, 'there must exactly 2 classes'):
        text_dataset.text_dataset_from_directory(directory, label_mode='binary')

    # validation_split must lie in (0, 1).
    with self.assertRaisesRegex(
            ValueError, '`validation_split` must be between 0 and 1'):
        text_dataset.text_dataset_from_directory(directory, validation_split=2)

    # subset must be one of the two known values.
    with self.assertRaisesRegex(ValueError,
                                '`subset` must be either "training" or'):
        text_dataset.text_dataset_from_directory(
            directory, validation_split=0.2, subset='other')

    # A subset without a nonzero validation_split is meaningless.
    with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
        text_dataset.text_dataset_from_directory(
            directory, validation_split=0, subset='training')

    # Splitting deterministically requires a seed.
    with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
        text_dataset.text_dataset_from_directory(
            directory, validation_split=0.2, subset='training')
def test_text_dataset_from_directory_validation_split(self):
    """A 0.2 split of 10 samples yields subsets of 8 and 2 samples."""
    directory = self._prepare_directory(num_classes=2, count=10)
    # Same seed for both calls so the two subsets are complementary.
    for subset, expected_size in (('training', 8), ('validation', 2)):
        ds = text_dataset.text_dataset_from_directory(
            directory,
            batch_size=10,
            validation_split=0.2,
            subset=subset,
            seed=1337)
        batch = next(iter(ds))
        self.assertLen(batch, 2)
        self.assertEqual(batch[0].shape, (expected_size,))
def test_text_dataset_from_directory_standalone(self):
    """Texts are retrieved without labels from a directory and its subdirs."""
    # Save a few extra files directly in the parent directory; with
    # label_mode=None they are picked up alongside the subdirectory files.
    directory = self._prepare_directory(count=7, num_classes=2)
    for i in range(3):
        filename = 'text_%s.txt' % (i,)
        text = ''.join([random.choice(string.printable) for _ in range(20)])
        # Fix: use a context manager instead of manual open/close so the
        # file handle is released even if the write raises.
        with open(os.path.join(directory, filename), 'w') as f:
            f.write(text)

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=5, label_mode=None, max_length=10)
    batch = next(iter(dataset))
    # We just return the texts, no labels.
    self.assertEqual(batch.shape, (5,))
    self.assertEqual(batch.dtype.name, 'string')

    # Count samples: 7 prepared + 3 extra = 10, in 2 batches of 5.
    batch_count = 0
    sample_count = 0
    for batch in dataset:
        batch_count += 1
        sample_count += batch.shape[0]
    self.assertEqual(batch_count, 2)
    self.assertEqual(sample_count, 10)
def test_text_dataset_from_directory_manual_labels(self):
    """Explicitly passed labels are attached to the (unshuffled) files."""
    directory = self._prepare_directory(num_classes=2, count=2)
    ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, labels=[0, 1], shuffle=False)
    first_batch = next(iter(ds))
    self.assertLen(first_batch, 2)
    # shuffle=False keeps file order deterministic, so labels line up.
    self.assertAllClose(first_batch[1], [0, 1])
def test_sample_count(self):
    """Iterating the whole dataset yields every file exactly once."""
    directory = self._prepare_directory(num_classes=4, count=15)
    ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    # Sum batch sizes over all (possibly ragged last) batches.
    total = sum(batch.shape[0] for batch in ds)
    self.assertEqual(total, 15)
def test_text_dataset_from_directory_follow_links(self):
    """With follow_links=True, files in nested subdirectories are found."""
    directory = self._prepare_directory(
        num_classes=2, count=25, nested_dirs=True)
    ds = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None, follow_links=True)
    total = 0
    for texts in ds:
        total += texts.shape[0]
    self.assertEqual(total, 25)
def test_text_dataset_from_directory_no_files(self):
    """An empty directory raises a clear ValueError."""
    directory = self._prepare_directory(num_classes=2, count=0)
    with self.assertRaisesRegex(ValueError, 'No text files found.'):
        text_dataset.text_dataset_from_directory(directory)
# Fix: `os.path.join` is used below but `os` was not imported anywhere in
# this block (which appears to be the top of the file) — add it. NOTE(review):
# if an earlier `import os` exists above this chunk, this line is redundant
# but harmless.
import os

from tensorflow.keras.layers import Input, Layer, MultiHeadAttention, Dense
from tensorflow.keras.layers import LayerNormalization, Dropout, Embedding, GlobalMaxPooling1D
from tensorflow.keras.models import Model
from tensorflow.python.data.experimental import cardinality
from tensorflow.python.keras.preprocessing.text_dataset import text_dataset_from_directory
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
from tensorflow.python.keras.utils.vis_utils import plot_model

# Download the dataset from:
# https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
# Once extracted, delete the 'unsup' folder (only useful for unsupervised
# learning).
data_path = os.path.join('data', 'nlp', 'aclImdb')
batch_size = 32

# 80/20 train/validation split; the same seed on both calls keeps the two
# subsets complementary.
raw_train_ds = text_dataset_from_directory(
    os.path.join(data_path, 'train'),
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=42)
raw_val_ds = text_dataset_from_directory(
    os.path.join(data_path, 'train'),
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=42)
raw_test_ds = text_dataset_from_directory(
    os.path.join(data_path, 'test'), batch_size=batch_size)

print(f'Number of batches in raw_train_ds: {cardinality(raw_train_ds)}')
print(f'Number of batches in raw_val_ds: {cardinality(raw_val_ds)}')
print(f'Number of batches in raw_test_ds: {cardinality(raw_test_ds)}')

# Print 5 instances to see what the data looks like