def test_python_sentencepiece_preprocessing(self, use_tfds):
  """Checks feature keys and shapes for in-graph SentencePiece tokenization."""
  batch_size = 10
  seq_length = 256  # Non-default value.
  lower_case = True

  record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  text_fields = ['sentence1', 'sentence2']
  if not use_tfds:
    # Only the file-based path needs a fixture; TFDS supplies its own data.
    _create_fake_raw_dataset(record_path, text_fields, label_type='int')

  sp_model_path = _create_fake_sentencepiece_model(self.get_temp_dir())
  config = loader.SentencePredictionTextDataConfig(
      input_path='' if use_tfds else record_path,
      tfds_name='glue/mrpc' if use_tfds else '',
      tfds_split='train' if use_tfds else '',
      text_fields=text_fields,
      global_batch_size=batch_size,
      seq_length=seq_length,
      is_training=True,
      lower_case=lower_case,
      tokenization='SentencePiece',
      vocab_file=sp_model_path,
  )

  dataset = loader.SentencePredictionTextDataLoader(config).load()
  features = next(iter(dataset))
  label_field = config.label_field

  expected_keys = [
      'input_word_ids', 'input_type_ids', 'input_mask', label_field
  ]
  if use_tfds:
    # TFDS examples carry an extra example-index feature.
    expected_keys.append('idx')
  self.assertCountEqual(expected_keys, features.keys())

  for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
    self.assertEqual(features[key].shape, (batch_size, seq_length))
  self.assertEqual(features[label_field].shape, (batch_size,))
def test_saved_model_preprocessing(self, use_tfds):
  """Checks feature keys and shapes when preprocessing via a TF-Hub module."""
  batch_size = 10
  seq_length = 256  # Non-default value.

  record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  text_fields = ['sentence1', 'sentence2']
  if not use_tfds:
    # Only the file-based path needs a fixture; TFDS supplies its own data.
    _create_fake_raw_dataset(record_path, text_fields, label_type='float')

  vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
  _create_fake_vocab_file(vocab_path)
  config = loader.SentencePredictionTextDataConfig(
      input_path='' if use_tfds else record_path,
      tfds_name='glue/mrpc' if use_tfds else '',
      tfds_split='train' if use_tfds else '',
      text_fields=text_fields,
      global_batch_size=batch_size,
      seq_length=seq_length,
      is_training=True,
      preprocessing_hub_module_url=(
          'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
      # The TFDS labels are integers; the fake raw dataset uses floats.
      label_type='int' if use_tfds else 'float',
  )

  dataset = loader.SentencePredictionTextDataLoader(config).load()
  features = next(iter(dataset))
  label_field = config.label_field

  expected_keys = [
      'input_word_ids', 'input_type_ids', 'input_mask', label_field
  ]
  if use_tfds:
    # TFDS examples carry an extra example-index feature.
    expected_keys.append('idx')
  self.assertCountEqual(expected_keys, features.keys())

  for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
    self.assertEqual(features[key].shape, (batch_size, seq_length))
  self.assertEqual(features[label_field].shape, (batch_size,))
def test_python_wordpiece_preprocessing(self, use_tfds):
  """Checks feature keys and shapes for in-graph WordPiece tokenization."""
  batch_size = 10
  seq_length = 256  # Non-default value.
  lower_case = True

  record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  text_fields = ['sentence1', 'sentence2']
  if not use_tfds:
    # Only the file-based path needs a fixture; TFDS supplies its own data.
    _create_fake_raw_dataset(record_path, text_fields, label_type='int')

  vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
  _create_fake_vocab_file(vocab_path)
  config = loader.SentencePredictionTextDataConfig(
      input_path='' if use_tfds else record_path,
      tfds_name='glue/mrpc' if use_tfds else '',
      tfds_split='train' if use_tfds else '',
      tfds_download=True,
      text_fields=text_fields,
      global_batch_size=batch_size,
      seq_length=seq_length,
      is_training=True,
      lower_case=lower_case,
      vocab_file=vocab_path)

  dataset = loader.SentencePredictionTextDataLoader(config).load()
  # This loader configuration yields (features, labels) tuples rather than a
  # single feature dict with the label inside it.
  features, labels = next(iter(dataset))

  self.assertCountEqual(
      ['input_word_ids', 'input_type_ids', 'input_mask'], features.keys())
  for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
    self.assertEqual(features[key].shape, (batch_size, seq_length))
  self.assertEqual(labels.shape, (batch_size,))