Ejemplo n.º 1
0
 def testDatasetSize(self, mode, split):
     """Checks that one loaded batch has the requested leading dimension.

     A batch size of 9 is requested for the train split and 5 for any other
     split; the first batch's `text_a`, `text_b` and `labels` entries must
     all have that size along axis 0.

     Args:
       mode: Forwarded to `mnli.MnliDataset` as `mode`.
       split: Forwarded to `mnli.MnliDataset` as `split`; compared against
         `tfds.Split.TRAIN` to pick the batch size.
     """
     if split == tfds.Split.TRAIN:
         expected_size = 9
     else:
         expected_size = 5

     builder = mnli.MnliDataset(
         split=split, mode=mode, shuffle_buffer_size=20)
     # Only the first batch is needed to inspect the batch dimension.
     single_batch = builder.load(batch_size=expected_size).take(1)
     first_batch = next(iter(single_batch))

     for feature_name in ('text_a', 'text_b', 'labels'):
         self.assertEqual(first_batch[feature_name].shape[0], expected_size)
Ejemplo n.º 2
0
    def testTrainSplitError(self):
        """Tests for ValueError when calling train split for mismatched data.

        Constructs an `MnliDataset` in 'mismatched' mode and expects building
        the train split to fail.
        """
        batch_size = 9
        eval_batch_size = 5
        dataset_builder = mnli.MnliDataset(mode='mismatched',
                                           batch_size=batch_size,
                                           eval_batch_size=eval_batch_size,
                                           shuffle_buffer_size=20)

        # Assert the specific error type this test documents. Catching the
        # base `Exception` would also pass on unrelated failures (for example
        # an AttributeError inside `build`), masking real regressions.
        with self.assertRaises(ValueError):
            dataset_builder.build(base.Split.TRAIN)
Ejemplo n.º 3
0
    def testDatasetSize(self, mode, split):
        """Checks that the first built batch has the configured batch size.

        The builder is configured with a train batch size of 9 and an eval
        batch size of 5. The first batch of the requested split must have the
        train size along axis 0 when `split` is `base.Split.TRAIN`, and the
        eval size otherwise, for each of `text_a`, `text_b` and `labels`.

        Args:
          mode: Forwarded to `mnli.MnliDataset` as `mode`.
          split: Split passed to `build`; also selects the expected size.
        """
        train_size, eval_size = 9, 5
        builder = mnli.MnliDataset(
            mode=mode,
            batch_size=train_size,
            eval_batch_size=eval_size,
            shuffle_buffer_size=20)
        # Only the first batch is needed to inspect the batch dimension.
        first_batch = next(iter(builder.build(split).take(1)))
        features = [first_batch[key] for key in ('text_a', 'text_b', 'labels')]

        if split == base.Split.TRAIN:
            expected_size = train_size
        else:
            expected_size = eval_size

        leading_dims = [feature.shape[0] for feature in features]
        for leading_dim in leading_dims:
            self.assertEqual(leading_dim, expected_size)
Ejemplo n.º 4
0
 def testTrainSplitError(self):
     """Tests for ValueError when calling train split for mismatched data.

     Constructing an `MnliDataset` in 'mismatched' mode with the train split
     is expected to fail.
     """
     # Assert the specific error type this test documents. With the base
     # `Exception`, a TypeError from a misspelled keyword argument in the
     # constructor call below would also make the test pass.
     with self.assertRaises(ValueError):
         mnli.MnliDataset(mode='mismatched',
                          split=tfds.Split.TRAIN,
                          shuffle_buffer_size=20)