def __repr__(self):
     s = ""
     s += "guid: {}".format((self.guid))
     s += ", text: {}".format(tokenization.printable_text(self.text))
     if self.text_b:
         s += ", text_a: {}".format(tokenization.printable_text(
             self.text_a))
     if self.text_b:
         s += ", text_b: {}".format(self.text_b)
     if self.label:
         s += ", label: {}".format(self.label)
     if self.label_ratio:
         s += ", label_ratio: {}".format(self.label_ratio)
     return s
Example No. 2
def __str__(self):
		s = ""
		s += "guid: {}\n".format(self.guid)
		s += "tokens: %s\n" % (" ".join(
			[tokenization.printable_text(x) for x in self.tokens]))
		s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
		s += "is_random_next: %s\n" % self.is_random_next
		s += "masked_lm_positions: %s\n" % (" ".join(
			[str(x) for x in self.masked_lm_positions]))
		s += "masked_lm_labels: %s\n" % (" ".join(
			[tokenization.printable_text(x) for x in self.masked_lm_labels]))
		s += "label: {}\n".format(self.label)
		s += "\n"
		return s
Example No. 3
def __repr__(self):
     s = ""
     s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
     s += ", question_text: %s" % (tokenization.printable_text(
         self.question_text))
     s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
     if self.start_position:
         s += ", start_position: %d" % (self.start_position)
     if self.end_position:
         s += ", end_position: %d" % (self.end_position)
     if self.answer_choice:
         for index, answer in enumerate(self.answer_choice):
             s += ", answer_choice: {} {}".format(index, answer)
     if self.choice:
         s += ", correct choice: {}".format(self.choice)
     return s
Example No. 4
def __repr__(self):
     s = ""
     s += "guid: {}".format(self.guid)
     s += ", text_a: {}".format(tokenization.printable_text(self.text_a))
     if self.text_b:
         s += ", text_b: {}".format(self.text_b)
     if self.label:
         s += ", label: {}".format(self.label)
     if self.label_ratio:
         s += ", label_ratio: {}".format(self.label_ratio)
     if self.label_probs:
         s += ", label_probs: {}".format(self.label_probs)
     if self.distillation_ratio:
         s += ", distillation_ratio: {}".format(self.distillation_ratio)
     if self.adv_label:
         s += ", adv_label: {}".format(self.adv_label)
     return s
Example No. 5
def write_instance_to_example_files(examples,
                                    label_dict,
                                    tokenizer,
                                    max_seq_length,
                                    masked_lm_prob,
                                    max_predictions_per_seq,
                                    output_file,
                                    dupe,
                                    random_seed=2018,
                                    feature_type="pretrain_qa",
                                    log_cycle=100):
    """Create TF example files from `TrainingInstance`s."""
    rng = random.Random(random_seed)

    feature_writer = PairPreTrainingFeature(output_file, is_training=False)

    if feature_type == "pretrain_qa":
        instances = create_instances_qa(examples, dupe, max_seq_length,
                                        masked_lm_prob, tokenizer,
                                        max_predictions_per_seq, rng)
    elif feature_type == "pretrain_classification":
        instances = create_instances_classification(examples, dupe,
                                                    max_seq_length,
                                                    masked_lm_prob, tokenizer,
                                                    max_predictions_per_seq,
                                                    rng)
    else:
        raise ValueError("Unsupported feature_type: {}".format(feature_type))

    rng.shuffle(instances)

    total_written = 0
    for (inst_index, instance) in enumerate(instances):
        input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = list(instance.segment_ids)
        assert len(input_ids) <= max_seq_length

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        masked_lm_positions = list(instance.masked_lm_positions)
        masked_lm_ids = tokenizer.convert_tokens_to_ids(
            instance.masked_lm_labels)
        masked_lm_weights = [1.0] * len(masked_lm_ids)

        # Pad the masked LM fields up to max_predictions_per_seq.
        while len(masked_lm_positions) < max_predictions_per_seq:
            masked_lm_positions.append(0)
            masked_lm_ids.append(0)
            masked_lm_weights.append(0.0)

        if len(instance.label) == 1:
            label_id = label_dict[instance.label[0]]
        else:
            # Multi-label instance: build a multi-hot label vector, mirroring
            # the classifier converters below.
            label_id = [0] * len(label_dict)
            for item in instance.label:
                label_id[label_dict[item]] = 1

        features = pretrain_feature.PreTrainingFeature(
            guid=instance.guid,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            masked_lm_positions=masked_lm_positions,
            masked_lm_ids=masked_lm_ids,
            masked_lm_weights=masked_lm_weights,
            label_ids=label_id,
            is_random_next=instance.is_random_next)
        feature_writer.process_feature(features)
        total_written += 1

        if np.mod(inst_index, log_cycle) == 0:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (features.guid))
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in instance.tokens]))
            tf.logging.info("input_ids: %s" %
                            " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" %
                            " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" %
                            " ".join([str(x) for x in segment_ids]))
            tf.logging.info("masked_lm_positions: %s" %
                            " ".join([str(x) for x in masked_lm_positions]))
            tf.logging.info("masked_lm_ids: %s" %
                            " ".join([str(x) for x in masked_lm_ids]))
            tf.logging.info("masked_lm_weights: %s" %
                            " ".join([str(x) for x in masked_lm_weights]))
            tf.logging.info("label: {} (id = {})".format(
                instance.label, label_id))
            tf.logging.info("is_random_next: {} ".format(
                instance.is_random_next))
            tf.logging.info("length of tokens: {} ".format(len(
                instance.tokens)))

    feature_writer.close()

    tf.logging.info("Wrote %d total instances", total_written)
Example No. 6
def convert_classifier_examples_with_rule_to_features(examples, label_dict,
                                                      max_seq_length,
                                                      tokenizer, rule_detector,
                                                      output_file):

    feature_writer = ClassifierRuleFeatureWriter(output_file,
                                                 is_training=False)

    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" %
                            (ex_index, len(examples)))

        tokens_b = None
        if example.text_b:
            try:
                tokens_b = tokenizer.tokenize(example.text_b)
            except:
                print("==token b error==", example.text_b, ex_index)
                break

        if tokens_b:
            tf_data_utils._truncate_seq_pair(tokens_a, tokens_b,
                                             max_seq_length - 3)

        else:
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[0:(max_seq_length - 2)]

        rule_id_lst = rule_detector.infer(tokens_a)

        tokens = []
        segment_ids = []
        rule_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        rule_ids.append(0)
        for index, token in enumerate(tokens_a):
            tokens.append(token)
            segment_ids.append(0)
            rule_ids.append(rule_id_lst[index])
        tokens.append("[SEP]")
        segment_ids.append(0)
        rule_ids.append(0)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            rule_ids.append(0)

        try:
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(rule_ids) == max_seq_length
        except:
            print(len(input_ids), max_seq_length, ex_index, "length error")
            break

        if len(example.label) == 1:
            label_id = label_dict[example.label[0]]
        else:
            label_id = [0] * len(label_dict)
            for item in example.label:
                label_id[label_dict[item]] = 1
        if ex_index < 5:
            print(tokens)
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info(
                "tokens: %s" %
                " ".join([tokenization.printable_text(x) for x in tokens]))
            tf.logging.info("input_ids: %s" %
                            " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" %
                            " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" %
                            " ".join([str(x) for x in segment_ids]))
            tf.logging.info("rule_ids: %s" %
                            " ".join([str(x) for x in rule_ids]))
            tf.logging.info("label: {} (id = {})".format(
                example.label, label_id))

        feature = extra_mask_feature_classifier.InputFeatures(
            guid=example.guid,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            rule_ids=rule_ids,
            label_ids=label_id)
        feature_writer.process_feature(feature)
    feature_writer.close()
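
The rule_detector above only needs to expose an infer(tokens) method that returns one rule id per token. A minimal stand-in for local testing might look like this (purely hypothetical, not the project's implementation):

class KeywordRuleDetector(object):
    """Toy rule detector: rule id 1 for tokens in a keyword set, 0 otherwise."""

    def __init__(self, keywords):
        self.keywords = set(keywords)

    def infer(self, tokens):
        # One id per input token, matching how rule_id_lst is indexed above.
        return [1 if token in self.keywords else 0 for token in tokens]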
Example No. 7
def convert_pair_order_classifier_examples_to_features(examples, label_dict,
                                                       max_seq_length,
                                                       tokenizer, output_file):

    feature_writer = PairClassifierFeatureWriter(output_file,
                                                 is_training=False)

    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" %
                            (ex_index, len(examples)))

        tokens_b = tokenizer.tokenize(example.text_b)

        tf_data_utils._truncate_seq_pair(tokens_a, tokens_b,
                                         max_seq_length - 3)

        def get_input(input_tokens_a, input_tokens_b):
            tokens = []
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)

            for token in input_tokens_a:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for token in input_tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            return [tokens, input_ids, input_mask, segment_ids]

        [tokens_a_, input_ids_a, input_mask_a,
         segment_ids_a] = get_input(tokens_a, tokens_b)

        [tokens_b_, input_ids_b, input_mask_b,
         segment_ids_b] = get_input(tokens_b, tokens_a)

        try:
            assert len(input_ids_a) == max_seq_length
            assert len(input_mask_a) == max_seq_length
            assert len(segment_ids_a) == max_seq_length

            assert len(input_ids_b) == max_seq_length
            assert len(input_mask_b) == max_seq_length
            assert len(segment_ids_b) == max_seq_length

        except:
            print(len(input_ids_a), input_ids_a, max_seq_length, ex_index,
                  "length error")
            break

        if len(example.label) == 1:
            label_id = label_dict[example.label[0]]
        else:
            label_id = [0] * len(label_dict)
            for item in example.label:
                label_id[label_dict[item]] = 1
        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info(
                "tokens_a: %s" %
                " ".join([tokenization.printable_text(x) for x in tokens_a_]))
            tf.logging.info("input_ids_a: %s" %
                            " ".join([str(x) for x in input_ids_a]))
            tf.logging.info("input_mask_a: %s" %
                            " ".join([str(x) for x in input_mask_a]))
            tf.logging.info("segment_ids_a: %s" %
                            " ".join([str(x) for x in segment_ids_a]))

            tf.logging.info(
                "tokens_b: %s" %
                " ".join([tokenization.printable_text(x) for x in tokens_b_]))
            tf.logging.info("input_ids_b: %s" %
                            " ".join([str(x) for x in input_ids_b]))
            tf.logging.info("input_mask_b: %s" %
                            " ".join([str(x) for x in input_mask_b]))
            tf.logging.info("segment_ids_b: %s" %
                            " ".join([str(x) for x in segment_ids_b]))

            tf.logging.info("label: {} (id = {})".format(
                example.label, label_id))

        feature = pair_data_feature_classifier.InputFeatures(
            guid=example.guid,
            input_ids_a=input_ids_a,
            input_mask_a=input_mask_a,
            segment_ids_a=segment_ids_a,
            input_ids_b=input_ids_b,
            input_mask_b=input_mask_b,
            segment_ids_b=segment_ids_b,
            label_ids=label_id)
        feature_writer.process_feature(feature)
    feature_writer.close()
Example No. 8
def convert_span_mrc_examples_to_features(examples, label_dict, max_seq_length,
                                          tokenizer, output_file, is_training,
                                          max_query_length, doc_stride):
    """Converts examples into `InputFeatures` and writes them to `output_file`."""

    unique_id = 1000000000
    feature_writer = SpanFeatureWriter(output_file, is_training=False)
    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position +
                                                     1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position,
             tok_end_position) = tf_data_utils._improve_answer_span(
                 all_doc_tokens, tok_start_position, tok_end_position,
                 tokenizer, example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(
                    tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = tf_data_utils._check_is_max_context(
                    doc_spans, doc_span_index, split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            if is_training:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                if (example.start_position < doc_start
                        or example.end_position < doc_start
                        or example.start_position > doc_end
                        or example.end_position > doc_end):
                    continue

                doc_offset = len(query_tokens) + 2
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset

            if example_index < 20:
                tf.logging.info("*** Example ***")
                tf.logging.info("unique_id: %s" % (unique_id))
                tf.logging.info("example_index: %s" % (example_index))
                tf.logging.info("doc_span_index: %s" % (doc_span_index))
                tf.logging.info(
                    "tokens: %s" %
                    " ".join([tokenization.printable_text(x) for x in tokens]))
                tf.logging.info("token_to_orig_map: %s" % " ".join([
                    "%d:%d" % (x, y)
                    for (x, y) in six.iteritems(token_to_orig_map)
                ]))
                tf.logging.info("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y)
                    for (x, y) in six.iteritems(token_is_max_context)
                ]))
                tf.logging.info("input_ids: %s" %
                                " ".join([str(x) for x in input_ids]))
                tf.logging.info("input_mask: %s" %
                                " ".join([str(x) for x in input_mask]))
                tf.logging.info("segment_ids: %s" %
                                " ".join([str(x) for x in segment_ids]))
                if is_training:
                    answer_text = " ".join(
                        tokens[start_position:(end_position + 1)])
                    tf.logging.info("start_position: %d" % (start_position))
                    tf.logging.info("end_position: %d" % (end_position))
                    tf.logging.info("answer: %s" %
                                    (tokenization.printable_text(answer_text)))

            feature = data_feature_mrc.InputFeatures(
                unique_id=unique_id,
                example_index=example_index,
                doc_span_index=doc_span_index,
                tokens=tokens,
                token_to_orig_map=token_to_orig_map,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                start_position=start_position,
                end_position=end_position)

            feature_writer.process_feature(feature)
            unique_id += 1
    feature_writer.close()
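
To make the sliding-window bookkeeping above concrete, here is a toy reproduction of the doc_spans construction with made-up sizes (the numbers are for illustration only):

import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def toy_doc_spans(num_doc_tokens=10, max_tokens_for_doc=4, doc_stride=2):
    # Mirrors the span loop above: fixed-size windows advanced by doc_stride.
    doc_spans, start_offset = [], 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans

print(toy_doc_spans())
# [DocSpan(start=0, length=4), DocSpan(start=2, length=4),
#  DocSpan(start=4, length=4), DocSpan(start=6, length=4)]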
Example No. 9
def create_cls_problem_generator(task_type,
									examples,
									label_dict,
									multi_task_config,
									tokenizer,
									mode):
	max_seq_length = multi_task_config[task_type]["max_length"]
	lm_augumentation = multi_task_config[task_type]["lm_augumentation"]
	for (ex_index, example) in enumerate(examples):
		tokens_a = tokenizer.tokenize(example.text_a)
		if ex_index % 10000 == 0:
			tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

		tokens_b = None
		if example.text_b:
			try:
				tokens_b = tokenizer.tokenize(example.text_b)
			except:
				print("==token b error==", example.text_b, ex_index)
				break

		if tokens_b:
			tf_data_utils._truncate_seq_pair(tokens_a, tokens_b, max_seq_length-3)
		else:
			if len(tokens_a) > max_seq_length - 2:
				tokens_a = tokens_a[0:(max_seq_length - 2)]

		tokens = []
		segment_ids = []
		tokens.append("[CLS]")
		segment_ids.append(0)

		for token in tokens_a:
			tokens.append(token)
			segment_ids.append(0)
		tokens.append("[SEP]")
		segment_ids.append(0)

		if tokens_b:
			for token in tokens_b:
				tokens.append(token)
				segment_ids.append(1)
			tokens.append("[SEP]")
			segment_ids.append(1)

		if lm_augumentation and mode == 'train':
			rng = random.Random()
			(mask_lm_tokens, masked_lm_positions,
				masked_lm_labels) = create_masked_lm_predictions(
					tokens,
					multi_task_config[task_type]["masked_lm_prob"],
					multi_task_config[task_type]["max_predictions_per_seq"],
					list(tokenizer.vocab.keys()), rng)

			_, mask_lm_tokens, _ = create_mask_and_padding(
				mask_lm_tokens, copy(segment_ids), max_seq_length)
			masked_lm_weights, masked_lm_labels, masked_lm_positions = create_mask_and_padding(
				masked_lm_labels, masked_lm_positions, 
				multi_task_config[task_type]["max_predictions_per_seq"])
			mask_lm_input_ids = tokenizer.convert_tokens_to_ids(
				mask_lm_tokens)
			masked_lm_ids = tokenizer.convert_tokens_to_ids(masked_lm_labels)

			assert len(mask_lm_tokens) == max_seq_length

		input_mask, tokens, segment_ids = create_mask_and_padding(
			tokens, segment_ids, max_seq_length)

		input_ids = tokenizer.convert_tokens_to_ids(tokens)
		if len(example.label) == 1:
			label_id = label_dict[example.label[0]]
		else:
			label_id = [0] * len(label_dict)
			for item in example.label:
				label_id[label_dict[item]] = 1

		assert len(input_ids) == max_seq_length
		assert len(input_mask) == max_seq_length
		assert len(segment_ids) == max_seq_length

		if ex_index < 5:
			tf.logging.debug("*** Example ***")
			tf.logging.debug("tokens: %s" % " ".join(
				[tokenization.printable_text(x) for x in tokens]))
			tf.logging.debug("input_ids: %s" %
							 " ".join([str(x) for x in input_ids]))
			tf.logging.debug("input_mask: %s" %
							 " ".join([str(x) for x in input_mask]))
			tf.logging.debug("segment_ids: %s" %
							 " ".join([str(x) for x in segment_ids]))
			tf.logging.debug("%s_label_ids: %s" %
							 (task_type, str(label_id)))
			tf.logging.debug("%s_label: %s" %
							 (task_type, str(example.label)))
			if lm_augumentation and mode == 'train':
				tf.logging.debug("mask lm tokens: %s" % " ".join(
					[tokenization.printable_text(x) for x in mask_lm_tokens]))
				tf.logging.debug("mask lm input_ids: %s" %
								 " ".join([str(x) for x in mask_lm_input_ids]))
				tf.logging.debug("mask lm label ids: %s" %
								 " ".join([str(x) for x in masked_lm_ids]))
				tf.logging.debug("mask lm position: %s" %
								 " ".join([str(x) for x in masked_lm_positions]))
			
		if not lm_augumentation:
			return_dict = {
				'input_ids': input_ids,
				'input_mask': input_mask,
				'segment_ids': segment_ids,
				'%s_label_ids' % task_type: label_id
			}

		else:
			if mode == 'train':
				return_dict = {
					'input_ids': mask_lm_input_ids,
					'input_mask': input_mask,
					'segment_ids': segment_ids,
					'%s_label_ids' % task_type: label_id,
					"masked_lm_positions": masked_lm_positions,
					"masked_lm_ids": masked_lm_ids,
					"masked_lm_weights": masked_lm_weights,
				}
			else:
				return_dict = {
					'input_ids': input_ids,
					'input_mask': input_mask,
					'segment_ids': segment_ids,
					'%s_label_ids' % task_type: label_id,
					"masked_lm_positions": [0]*multi_task_config[task_type]["max_predictions_per_seq"],
					"masked_lm_ids": [0]*multi_task_config[task_type]["max_predictions_per_seq"],
					"masked_lm_weights": [0]*multi_task_config[task_type]["max_predictions_per_seq"],
				}
		
		yield return_dict
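
A minimal sketch of how such a generator could feed an input pipeline (the task name "sentiment", the config contents, and the batch size are assumptions, not values from this repository):

# Hypothetical consumption sketch; assumes lm_augumentation is enabled and
# mode == 'train', so the generator also yields the masked-LM keys.
def train_generator():
    return create_cls_problem_generator("sentiment", examples, label_dict,
                                        multi_task_config, tokenizer,
                                        mode="train")

dataset = tf.data.Dataset.from_generator(
    train_generator,
    output_types={
        "input_ids": tf.int32,
        "input_mask": tf.int32,
        "segment_ids": tf.int32,
        "sentiment_label_ids": tf.int32,
        "masked_lm_positions": tf.int32,
        "masked_lm_ids": tf.int32,
        "masked_lm_weights": tf.float32,
    })
dataset = dataset.batch(32)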
Example No. 10
def convert_classifier_examples_to_features(examples, label_dict, 
											max_seq_length,
											tokenizer, output_file, 
											rule_matcher, background_label):

	feature_writer = ClassifierFeatureWriter(output_file, is_training=False)

	for (ex_index, example) in enumerate(examples):
		tokens_a = tokenizer.tokenize(example.text_a)

		if ex_index % 10000 == 0:
			tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

		if len(tokens_a) > max_seq_length - 2:
			tokens_a = tokens_a[0:(max_seq_length - 2)]

		tokens_a_rule = rule_matcher.parse(tokens_a, background_label)

		tokens = []
		tokens.append("[CLS]")
		rule_ids = [label_dict[background_label]]
		rule_ids.extend([label_dict[rule[0]] for rule in tokens_a_rule])

		for token in tokens_a:
			tokens.append(token)

		tokens.append("[SEP]")
		rule_ids.append(label_dict[background_label])

		input_ids = tokenizer.convert_tokens_to_ids(tokens)
		input_mask = [1] * len(input_ids)

		# Zero-pad up to the sequence length.
		while len(input_ids) < max_seq_length:
			input_ids.append(0)
			input_mask.append(0)
			rule_ids.append(label_dict[background_label])

		try:

			assert len(input_ids) == max_seq_length
			assert len(input_mask) == max_seq_length
			assert len(rule_ids) == max_seq_length
		except:
			print(len(input_ids), max_seq_length, ex_index, "length error")
			break

		if len(example.label) == 1:
			label_id = label_dict[example.label[0]]
		else:
			label_id = [0] * len(label_dict)
			for item in example.label:
				label_id[label_dict[item]] = 1
		if ex_index < 5:
			print(tokens)
			tf.logging.info("*** Example ***")
			tf.logging.info("guid: %s" % (example.guid))
			tf.logging.info("tokens: %s" % " ".join(
					[tokenization.printable_text(x) for x in tokens]))
			tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
			tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
			tf.logging.info(
					"rule_ids: %s" % " ".join([str(x) for x in rule_ids]))
			tf.logging.info("label: {} (id = {})".format(example.label, label_id))

		feature = data_feature_classifier.InputFeatures(
					guid=example.guid,
					input_ids=input_ids,
					input_mask=input_mask,
					segment_ids=rule_ids,
					label_ids=label_id)
		feature_writer.process_feature(feature)
	feature_writer.close()