Example #1
    def test_python_sentencepiece_preprocessing(self, use_tfds):
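        """Checks batched feature shapes when tokenizing with SentencePiece in Python."""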
        batch_size = 10
        seq_length = 256  # Non-default value.
        lower_case = True

        tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        text_fields = ['sentence1', 'sentence2']
        if not use_tfds:
            _create_fake_raw_dataset(tf_record_path,
                                     text_fields,
                                     label_type='int')

        sp_model_file_path = _create_fake_sentencepiece_model(
            self.get_temp_dir())
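        # Read either TFDS glue/mrpc or the fake TFRecord file, depending on
        # the use_tfds test parameter.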
        data_config = loader.SentencePredictionTextDataConfig(
            input_path='' if use_tfds else tf_record_path,
            tfds_name='glue/mrpc' if use_tfds else '',
            tfds_split='train' if use_tfds else '',
            text_fields=text_fields,
            global_batch_size=batch_size,
            seq_length=seq_length,
            is_training=True,
            lower_case=lower_case,
            tokenization='SentencePiece',
            vocab_file=sp_model_file_path,
        )
        dataset = loader.SentencePredictionTextDataLoader(data_config).load()
        features = next(iter(dataset))
        label_field = data_config.label_field
        expected_keys = [
            'input_word_ids', 'input_type_ids', 'input_mask', label_field
        ]
        if use_tfds:
            expected_keys += ['idx']
        self.assertCountEqual(expected_keys, features.keys())
        self.assertEqual(features['input_word_ids'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features['input_mask'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features['input_type_ids'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features[label_field].shape, (batch_size,))
Example #2
    def test_saved_model_preprocessing(self, use_tfds):
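        """Checks batched feature shapes when preprocessing with a TF Hub SavedModel."""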
        batch_size = 10
        seq_length = 256  # Non-default value.

        tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        text_fields = ['sentence1', 'sentence2']
        if not use_tfds:
            _create_fake_raw_dataset(tf_record_path,
                                     text_fields,
                                     label_type='float')

        vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
        _create_fake_vocab_file(vocab_file_path)
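        # Tokenization is delegated to the BERT preprocessing SavedModel from
        # TF Hub instead of being done in Python.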
        data_config = loader.SentencePredictionTextDataConfig(
            input_path='' if use_tfds else tf_record_path,
            tfds_name='glue/mrpc' if use_tfds else '',
            tfds_split='train' if use_tfds else '',
            text_fields=text_fields,
            global_batch_size=batch_size,
            seq_length=seq_length,
            is_training=True,
            preprocessing_hub_module_url=(
                'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
            label_type='int' if use_tfds else 'float',
        )
        dataset = loader.SentencePredictionTextDataLoader(data_config).load()
        features = next(iter(dataset))
        label_field = data_config.label_field
        expected_keys = [
            'input_word_ids', 'input_type_ids', 'input_mask', label_field
        ]
        if use_tfds:
            expected_keys += ['idx']
        self.assertCountEqual(expected_keys, features.keys())
        self.assertEqual(features['input_word_ids'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features['input_mask'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features['input_type_ids'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features[label_field].shape, (batch_size,))
Example #3
    def test_python_wordpiece_preprocessing(self, use_tfds):
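        """Checks batched feature and label shapes when tokenizing with a WordPiece vocab in Python."""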
        batch_size = 10
        seq_length = 256  # Non-default value.
        lower_case = True

        tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        text_fields = ['sentence1', 'sentence2']
        if not use_tfds:
            _create_fake_raw_dataset(tf_record_path,
                                     text_fields,
                                     label_type='int')

        vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
        _create_fake_vocab_file(vocab_file_path)

        data_config = loader.SentencePredictionTextDataConfig(
            input_path='' if use_tfds else tf_record_path,
            tfds_name='glue/mrpc' if use_tfds else '',
            tfds_split='train' if use_tfds else '',
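            # tfds_download=True lets the loader download glue/mrpc when the
            # TFDS source is used.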
            tfds_download=True,
            text_fields=text_fields,
            global_batch_size=batch_size,
            seq_length=seq_length,
            is_training=True,
            lower_case=lower_case,
            vocab_file=vocab_file_path)
        dataset = loader.SentencePredictionTextDataLoader(data_config).load()
        features, labels = next(iter(dataset))
        self.assertCountEqual(
            ['input_word_ids', 'input_type_ids', 'input_mask'],
            features.keys())
        self.assertEqual(features['input_word_ids'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features['input_mask'].shape,
                         (batch_size, seq_length))
        self.assertEqual(features['input_type_ids'].shape,
                         (batch_size, seq_length))
        self.assertEqual(labels.shape, (batch_size,))