Exemplo n.º 1
0
    def test_load_dataset(self, include_sentence_id):
        """Loads one batch from a fake tagging record file and checks it.

        A fake dataset is written to a temp-dir file, read back through
        `TaggingDataLoader`, and the first batch is checked for the
        expected feature keys and tensor shapes.

        Args:
          include_sentence_id: Flag forwarded to both the fake-dataset writer
            and the data config. When true, the batch must also contain
            'sentence_id' and 'sub_sentence_id' features of shape
            (batch_size,).
        """
        seq_length = 16
        batch_size = 10
        record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        _create_fake_dataset(record_path, seq_length, include_sentence_id)

        config = tagging_dataloader.TaggingDataConfig(
            input_path=record_path,
            seq_length=seq_length,
            global_batch_size=batch_size,
            include_sentence_id=include_sentence_id)
        loader = tagging_dataloader.TaggingDataLoader(config)
        features, labels = next(iter(loader.load()))

        # Per-token features are always present; per-example sentence ids
        # are only expected when they were requested.
        token_level_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
        sentence_level_keys = (
            ['sentence_id', 'sub_sentence_id'] if include_sentence_id else [])
        self.assertCountEqual(token_level_keys + sentence_level_keys,
                              features.keys())

        token_level_shape = (batch_size, seq_length)
        for key in token_level_keys:
            self.assertEqual(features[key].shape, token_level_shape)
        self.assertEqual(labels.shape, token_level_shape)
        for key in sentence_level_keys:
            self.assertEqual(features[key].shape, (batch_size, ))
Exemplo n.º 2
0
    def test_load_dataset(self):
        """Loads one batch from a fake tagging record file and checks it.

        A fake dataset is written to a temp-dir file, read back through
        `TaggingDataLoader`, and the first batch is checked: the features
        must have exactly the three expected input keys, and every feature
        and the labels must have shape (batch_size, seq_length).
        """
        seq_length = 16
        batch_size = 10
        record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        _create_fake_dataset(record_path, seq_length)

        config = tagging_dataloader.TaggingDataConfig(
            input_path=record_path,
            seq_length=seq_length,
            global_batch_size=batch_size)
        loader = tagging_dataloader.TaggingDataLoader(config)
        features, labels = next(iter(loader.load()))

        feature_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
        self.assertCountEqual(feature_keys, features.keys())

        batch_shape = (batch_size, seq_length)
        for key in feature_keys:
            self.assertEqual(features[key].shape, batch_shape)
        self.assertEqual(labels.shape, batch_shape)