def __repr__(self): s = "" s += "qas_id: %s" % (printable_text(self.qas_id)) s += ", question_text: %s" % (printable_text(self.question_text)) s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text)) if self.start_position: s += ", start_position: %d" % (self.start_position) if self.start_position: s += ", is_impossible: %r" % (self.is_impossible) return s
def print_tokens(inputs: Inputs, inv_vocab, updates_mask=None):
    """Pretty-print model inputs."""
    pos_to_tokid = {}
    for tokid, pos, weight in zip(inputs.masked_lm_ids[0],
                                  inputs.masked_lm_positions[0],
                                  inputs.masked_lm_weights[0]):
        if weight != 0:
            pos_to_tokid[pos] = tokid

    text = ""
    provided_update_mask = (updates_mask is not None)
    if not provided_update_mask:
        updates_mask = np.zeros_like(inputs.input_ids)
    for pos, (tokid, um) in enumerate(
            zip(inputs.input_ids[0], updates_mask[0])):
        token = inv_vocab[tokid]
        if token == "[PAD]":
            break
        if pos in pos_to_tokid:
            token = RED + token + " (" + inv_vocab[pos_to_tokid[pos]] + ")" + ENDC
            if provided_update_mask:
                assert um == 1
        else:
            if provided_update_mask:
                assert um == 0
        text += token + " "
    utils.log(utils.printable_text(text))
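# Illustrative usage sketch (not part of the original module): `Inputs` below is a
# hypothetical stand-in for the batched inputs container, with field names assumed
# from how print_tokens reads them; the vocabulary is made up for the example.
#
# import collections
# import numpy as np
#
# Inputs = collections.namedtuple(
#     "Inputs",
#     ["input_ids", "masked_lm_ids", "masked_lm_positions", "masked_lm_weights"])
#
# inv_vocab = {0: "[PAD]", 1: "[CLS]", 2: "[MASK]", 3: "the", 4: "dog", 5: "[SEP]"}
#
# # One example: position 2 was masked out and its original token id was 4 ("dog").
# inputs = Inputs(
#     input_ids=np.array([[1, 3, 2, 5, 0, 0]]),
#     masked_lm_ids=np.array([[4]]),
#     masked_lm_positions=np.array([[2]]),
#     masked_lm_weights=np.array([[1.0]]))
#
# print_tokens(inputs, inv_vocab)  # would log roughly: [CLS] the [MASK] (dog) [SEP]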
def full_label_running(sess, model, dataset, output_dir, show_info=True, tokenizer=None):
    next_element = dataset.get_next()
    start_time = datetime.now()
    step = 0
    output_file = os.path.join(output_dir, f"filtered_{FLAGS.pl_domain}_part_label")
    # output_file = os.path.join(output_dir, "filtered_com_part_label")
    keep_count = 0
    batch_index = 0
    f = codecs.open(output_file, 'w', encoding='utf-8')
    while True:
        ground_truth = []
        predictions = []
        try:
            example = sess.run(next_element)
            input_ids = example["input_ids"]
            input_dicts = example["input_dicts"]
            label_ids = example["label_ids"]
            seq_length = example["seq_length"]
            loss, length, prediction = sess.run(
                [model.total_loss, model.seq_length, model.prediction],
                feed_dict={model.input_ids: input_ids,
                           model.input_dicts: input_dicts,
                           model.label_ids: label_ids,
                           model.seq_length: seq_length,
                           model.dropout_keep_prob: 1})
            step += 1
            # label_ids: [B, MaxLen, D]
            # prediction_one_hot: [B, MaxLen, D]
            # prediction: [B, MaxLen]
            true_batch_size = len(input_ids)
            ground_truth.extend([label_ids[i, :length[i]].tolist()
                                 for i in range(true_batch_size)])
            predictions.extend([prediction[i, :length[i]].tolist()
                                for i in range(true_batch_size)])
            texts_ids = [input_ids[i, :length[i]].tolist()
                         for i in range(true_batch_size)]
            tokens = list(map(lambda x: tokenizer.restore(x), texts_ids))
            texts = [list(map(lambda t: utils.printable_text(t), token))
                     for token in tokens]
            for index, (gt, pred, txt, inputs) in enumerate(
                    zip(ground_truth, predictions, texts, input_ids)):
                if keep(gt, pred):
                    keep_count += 1
                    f.write(f"{batch_index * FLAGS.batch_size + index}\n")
                    for g, p, t, i in zip(gt, pred, txt, inputs):
                        f.write(f"{t} {tokenizer.inv_type_vocab[i[6]]} {label2str(g)} {label2str(p)}"
                                f"{' ◁' if g[p] == 0 else ''}\n")
                    f.write("\n")
            if step % 1000 == 0 and show_info:
                now_time = datetime.now()
                tf.logging.info(
                    f"Step: {step} ({(now_time - start_time).total_seconds():.2f} sec)")
                start_time = now_time
            batch_index += 1
        except tf.errors.OutOfRangeError:
            tf.logging.info(f"Finish Keep: {keep_count}")
            break
    f.close()
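# Illustrative only (helper name is ours): the write loop above appends '◁' when
# `g[p] == 0`, i.e. when the predicted tag index is not permitted by the per-position
# multi-hot partial label. The same check in isolation:
#
# def violates_partial_label(gold_multihot, pred):
#     """True when the predicted tag falls outside the partial-label tag set."""
#     return gold_multihot[pred] == 0
#
# print(violates_partial_label([1, 0, 0, 1], 2))  # True  -> would be marked with '◁'
# print(violates_partial_label([1, 0, 0, 1], 0))  # False -> prediction is allowed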
def build_single_example(self, ex_index, example):
    """Converts a single `InputExample` into a single `InputFeatures`."""
    tokens_raw = example.text
    labels_raw = example.labels
    tokens = []
    label_ids = []
    assert len(tokens_raw) == len(labels_raw)
    for token, label in zip(tokens_raw, labels_raw):
        tokens.append(token)
        label_ids.append(self.label_map[label])

    input_features = {}
    seq_length = len(tokens)
    assert seq_length == len(label_ids)
    for feature_name, feature_extractor in self.extractors.items():
        feature = feature_extractor.extract(tokens)
        input_features[feature_name] = feature
        assert seq_length == len(feature)

    if ex_index < 1:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % example.guid)
        tf.logging.info("tokens: %s" % " ".join(
            [utils.printable_text(x) for x in tokens]))
        for feature_name, feature in input_features.items():
            tf.logging.info("%s: %s" % (feature_name,
                                        " ".join([str(x) for x in feature])))
        tf.logging.info("labels: %s" % " ".join([str(x) for x in example.labels]))
        tf.logging.info("labels_ids: %s" % " ".join([str(x) for x in label_ids]))

    feature = InputFeatures(
        input_features=input_features,
        label_ids=label_ids,
        seq_length=seq_length)
    return feature
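# Hypothetical extractor (name is ours) that satisfies the contract assumed above:
# `extract(tokens)` must return exactly one feature value per input token, so the
# length assertion in build_single_example holds.
#
# class UnigramIdExtractor(object):
#     def __init__(self, vocab, unk_id=1):
#         self.vocab = vocab
#         self.unk_id = unk_id
#
#     def extract(self, tokens):
#         return [self.vocab.get(token, self.unk_id) for token in tokens]
#
# extractor = UnigramIdExtractor({"[PAD]": 0, "[UNK]": 1, "我": 2, "们": 3})
# assert extractor.extract(["我", "们", "去"]) == [2, 3, 1]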
def main(_):
    if FLAGS.do_train:
        tf.logging.set_verbosity(tf.logging.INFO)
    np.random.seed(31415926)
    random.seed(31415926)
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict` must be selected.")

    model_class, config, dim_info, processor, extractors, data_augmenter = \
        prepare_form_config(FLAGS)

    test_dataset_map = {}
    cxt_feature_extractor = extractors["input_ids"]
    feat_builder = feature_builder.FeatureBuilder(
        extractors=extractors,
        label_map=processor.get_labels())

    train_features = []
    part_label_dataset = None
    train_dataset = None
    dev_dataset = None

    if FLAGS.do_train:
        train_examples = processor.get_examples(data_dir=FLAGS.data_dir,
                                                example_type="train")
        train_features = feat_builder.build_features_from_examples(
            examples=train_examples)
        train_dataset = dataset.PaddingDataset(
            train_features,
            batch_size=FLAGS.train_batch_size,
            dim_info=dim_info)
        del train_examples

    if FLAGS.do_eval:
        dev_examples = processor.get_examples(data_dir=FLAGS.data_dir,
                                              example_type="dev")
        dev_features = feat_builder.build_features_from_examples(
            examples=dev_examples)
        dev_dataset = dataset.PaddingDataset(
            dev_features,
            batch_size=FLAGS.eval_batch_size,
            dim_info=dim_info)
        del dev_examples

    if FLAGS.pl_domain is not None and FLAGS.do_train:
        if not FLAGS.multitag:
            raise ValueError("part label train must use multi tag!")
        part_label_examples = processor.get_examples(data_dir=FLAGS.data_dir,
                                                     example_type="pl",
                                                     domain=FLAGS.pl_domain)
        part_label_features = feat_builder.build_features_from_examples(
            examples=part_label_examples)
        if FLAGS.mix_pl_data:
            if FLAGS.corpus_weighting:
                part_label_dataset = dataset.CorpusWeightingDataset(
                    [train_features, part_label_features], [10000, 10000],
                    batch_size=FLAGS.train_batch_size,
                    dim_info=dim_info)
            else:
                part_label_dataset = dataset.BatchMixDataset(
                    [train_features, part_label_features], [1, 5],
                    batch_size=FLAGS.train_batch_size,
                    dim_info=dim_info)
        else:
            part_label_dataset = dataset.PaddingDataset(
                part_label_features,
                batch_size=FLAGS.train_batch_size,
                dim_info=dim_info)
        del part_label_examples

    if FLAGS.do_predict:
        if FLAGS.test_domain is not None:
            domains = FLAGS.test_domain.split(",")
            test_dataset_map = {
                domain: dataset.PaddingDataset(
                    feat_builder.build_features_from_examples(
                        examples=processor.get_examples(
                            data_dir=FLAGS.data_dir,
                            example_type="test",
                            domain=domain)),
                    batch_size=FLAGS.predict_batch_size,
                    dim_info=dim_info)
                for domain in domains
            }
        else:
            test_dataset_map = {
                "test": dataset.PaddingDataset(
                    feat_builder.build_features_from_examples(
                        examples=processor.get_examples(
                            data_dir=FLAGS.data_dir,
                            example_type="test")),
                    batch_size=FLAGS.predict_batch_size,
                    dim_info=dim_info)
            }

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory

    with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
        tf.set_random_seed(31415926)
        # train & eval
        model = models.ModelAdapter.ModelAdapter(
            model_class,
            dim_info=dim_info,
            config=config,
            init_checkpoint=FLAGS.init_checkpoint,
            tokenizer=cxt_feature_extractor,
            init_embedding=FLAGS.init_embedding,
            learning_rate=FLAGS.learning_rate)
        sess.run(tf.global_variables_initializer())

        # if FLAGS.pl_domain is not None:
        #     model_path = os.path.join(FLAGS.output_dir, f"{FLAGS.pl_domain}_model.ckpt")
        # else:
        #     model_path = os.path.join(FLAGS.output_dir, "model.ckpt")
        # saver = tf.train.Saver()

        if FLAGS.do_train:
            # saver = BestCheckpointSaver(
            #     save_dir=FLAGS.output_dir,
            #     num_to_keep=3,
            #     maximize=True
            # )
            best_valid_f1 = 0.
            best_epoch = 0
            best_heap = []
            very_start_time = datetime.now()
            for epoch in range(FLAGS.num_train_epochs):
                start_time = datetime.now()
                if FLAGS.pl_domain is not None:
                    tf.logging.info(f"Epoch: {epoch} Domain: {FLAGS.pl_domain}")

                if epoch < 10:
                    model.assign_lr(sess, FLAGS.learning_rate)
                if 10 <= epoch < 15:
                    model.assign_lr(sess, FLAGS.learning_rate * config.lr_decay)
                if 15 <= epoch < 20:
                    model.assign_lr(sess, FLAGS.learning_rate * config.lr_decay ** 2)
                if 20 <= epoch < 25:
                    model.assign_lr(sess, FLAGS.learning_rate * config.lr_decay ** 3)
                if 25 <= epoch:
                    model.assign_lr(sess, FLAGS.learning_rate * config.lr_decay ** 4)

                if part_label_dataset is None:
                    _, _, _, total_loss, total_step = dataset_running(
                        sess, model, train_dataset, dim_info, config,
                        is_training=True, show_info=True)
                else:
                    if FLAGS.mix_pl_data:
                        total_loss, total_step = dataset_running(
                            sess, model, part_label_dataset, dim_info, config,
                            is_training=True, show_info=True)
                    else:
                        total_pl_loss = 0
                        total_pl_step = 0
                        if epoch % FLAGS.whole_pl_training_epoch == 0:
                            total_pl_loss, total_pl_step = dataset_running(
                                sess, model, part_label_dataset, dim_info, config,
                                is_training=True, show_info=True)
                        _, _, _, total_loss, total_step = dataset_running(
                            sess, model, train_dataset, dim_info, config,
                            is_training=True, show_info=True)
                        total_loss += total_pl_loss
                        total_step += total_pl_step

                avg_loss = total_loss / total_step
                now_time = datetime.now()
                tf.logging.info(
                    f"Epoch: {epoch} Average Loss: {avg_loss} "
                    f"({(now_time - start_time).total_seconds():.2f} sec)")

                if FLAGS.do_eval:
                    dev_ground_true, dev_prediction, dev_texts, dev_loss, dev_step = dataset_running(
                        sess, model, dev_dataset, dim_info, config, is_training=False)
                    p, r, f = processor.evaluate(dev_prediction, dev_ground_true)
                    # if saver.handle(f, sess, epoch, FLAGS.pl_domain if FLAGS.pl_domain else None):
                    #     heapq.heappush(best_heap, (f, epoch))
                    #     if len(best_heap) > 3:
                    #         heapq.heappop(best_heap)
                    #     best_epoch = epoch
                    # else:
                    #     if epoch - best_epoch >= FLAGS.early_stop_epochs and FLAGS.early_stop:
                    #         tf.logging.info(f"Early Stop Best F1: {best_valid_f1}")
                    #         break
                    tf.logging.info(
                        "Epoch: %d Dev Dataset Precision: %.5f Recall: %.5f F1: %.5f"
                        % (epoch, p, r, f))
                    for rank, (top_f, top_epoch) in enumerate(
                            sorted(best_heap, reverse=True)):
                        tf.logging.info("Top %d: Epoch: %d F1: %.5f"
                                        % (rank + 1, top_epoch, top_f))

                if FLAGS.debug_mode:
                    for domain, test_dataset in test_dataset_map.items():
                        predict_ground_truth, predict_prediction, predict_texts, predict_loss, predict_step = dataset_running(
                            sess, model, test_dataset, dim_info, config, is_training=False)
                        p, r, f = processor.evaluate(predict_prediction, predict_ground_truth)
                        tf.logging.info('%s Domain: %s Test: P:%f R:%f F1:%f'
                                        % (FLAGS.data_dir, domain, p, r, f))
                        tokens = list(map(lambda x: cxt_feature_extractor.restore(x),
                                          predict_texts))
                        texts = [list(map(lambda t: utils.printable_text(t), token))
                                 for token in tokens]
                        processor.segment(texts, predict_prediction,
                                          FLAGS.output_dir, f"{domain}_predict")
                        processor.segment(texts, predict_ground_truth,
                                          FLAGS.output_dir, f"{domain}_predict_golden")

            now_time = datetime.now()
            tf.logging.info(f"Train Spent: {now_time - very_start_time} sec")
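# The chain of epoch-range checks above implements a staircase learning-rate decay.
# A compact equivalent, written as a standalone helper purely for clarity (the helper
# name is ours, not part of the training loop):
#
# def staircase_lr(epoch, base_lr, lr_decay):
#     # epochs 0-9: base rate; then one extra decay factor every 5 epochs,
#     # capped at lr_decay**4 from epoch 25 onward.
#     power = min(max((epoch - 5) // 5, 0), 4)
#     return base_lr * (lr_decay ** power)
#
# assert staircase_lr(9, 0.01, 0.5) == 0.01            # epoch < 10: base rate
# assert staircase_lr(12, 0.01, 0.5) == 0.005           # 10 <= epoch < 15: one decay step
# assert staircase_lr(30, 0.01, 0.5) == 0.01 * 0.5 ** 4  # 25 <= epoch: capped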
def convert_single_example(ex_index, example: InputExample, tokenizer, label_map, dict_builder=None):
    """Converts a single `InputExample` into a single `InputFeatures`."""
    # label_map = {"B": 0, "M": 1, "E": 2, "S": 3}
    # tokens_raw = tokenizer.tokenize(example.text)
    tokens_raw = list(example.text)
    labels_raw = example.labels

    # Account for [CLS] and [SEP] with "- 2"
    # The convention in BERT is:
    # (b) For single sequences:
    #   tokens:   [CLS] the dog is hairy . [SEP]
    #   type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    label_ids = []
    for token, label in zip(tokens_raw, labels_raw):
        tokens.append(token)
        label_ids.append(label_map[label])

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    if dict_builder is None:
        input_dicts = np.zeros_like(tokens_raw, dtype=np.int64)
    else:
        input_dicts = dict_builder.extract(tokens)
    seq_length = len(tokens)

    assert seq_length == len(input_ids)
    assert seq_length == len(input_dicts)
    assert seq_length == len(label_ids)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    if ex_index < 1:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % example.guid)
        tf.logging.info("tokens: %s" % " ".join([utils.printable_text(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_dicts: %s" % " ".join([str(x) for x in input_dicts]))
        tf.logging.info("labels: %s" % " ".join([str(x) for x in example.labels]))
        tf.logging.info("labels_ids: %s" % " ".join([str(x) for x in label_ids]))

    feature = InputFeatures(
        input_ids=input_ids,
        input_dicts=input_dicts,
        label_ids=label_ids,
        seq_length=seq_length)
    return feature
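# Illustrative helper (not part of the original code) showing how the BMES labels
# consumed above, via label_map = {"B": 0, "M": 1, "E": 2, "S": 3}, are typically
# produced from a pre-segmented sentence.
#
# def words_to_bmes(words):
#     labels = []
#     for word in words:
#         if len(word) == 1:
#             labels.append("S")
#         else:
#             labels.extend(["B"] + ["M"] * (len(word) - 2) + ["E"])
#     return labels
#
# label_map = {"B": 0, "M": 1, "E": 2, "S": 3}
# labels = words_to_bmes(["我们", "喜欢", "北京"])   # ['B','E','B','E','B','E']
# label_ids = [label_map[l] for l in labels]          # [0, 2, 0, 2, 0, 2]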
def main(_):
    if FLAGS.do_train:
        tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict` must be selected.")

    tf.gfile.MakeDirs(FLAGS.output_dir)

    if FLAGS.bigram_file is not None:
        tokenizer = tokenization.WindowBigramTokenizer(
            vocab_file=FLAGS.vocab_file,
            bigram_file=FLAGS.bigram_file,
            do_lower_case=FLAGS.do_lower_case,
            window_size=FLAGS.window_size)
    else:
        tokenizer = tokenization.WindowTokenizer(
            vocab_file=FLAGS.vocab_file,
            do_lower_case=FLAGS.do_lower_case,
            window_size=FLAGS.window_size)

    dict_builder = None
    if FLAGS.dict_file is not None:
        dict_builder = dictionary_builder.DefaultDictionaryBuilder(
            FLAGS.dict_file,
            min_word_len=FLAGS.min_word_len,
            max_word_len=FLAGS.max_word_len)
    augm = augmenter.DefaultAugmenter(FLAGS.dict_augment_rate)

    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps).replace(
            session_config=session_config)

    processor = getattr(process, FLAGS.processor)()

    train_examples = None
    single_epoch_steps = None
    num_early_steps = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        single_epoch_steps = int(len(train_examples) / FLAGS.train_batch_size)
        num_train_steps = int(single_epoch_steps * FLAGS.num_train_epochs)
        num_early_steps = int(single_epoch_steps * 5)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    cls = None
    if FLAGS.model == "baseline":
        cls = models.BaselineModel
    elif FLAGS.model == "dict_concat":
        cls = models.DictConcatModel
    elif FLAGS.model == "dict_hyper":
        cls = models.DictHyperModel
    elif FLAGS.model == "attend_dict":
        cls = models.AttendedDictModel
    elif FLAGS.model == "attend_input":
        cls = models.AttendedInputModel
    elif FLAGS.model == "dual_dict":
        cls = models.DictConcatModel
        assert FLAGS.bigram_file is not None, "dual_dict must need bigram file"
        tokenizer = tokenization.WindowNgramTokenizer(
            vocab_file=FLAGS.vocab_file,
            ngram_file=FLAGS.bigram_file,
            do_lower_case=FLAGS.do_lower_case,
            window_size=FLAGS.window_size)
        if dict_builder is None:
            dict_builder = dictionary_builder.DefaultDictionaryBuilder(
                FLAGS.bigram_file,
                min_word_len=FLAGS.min_word_len,
                max_word_len=FLAGS.max_word_len)
        augm = augmenter.DualAugmenter(FLAGS.window_size)

    config = ModelConfig.from_json_file(FLAGS.config_file)

    model_fn = model_fn_builder(
        cls,
        config=config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        tokenizer=tokenizer,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        init_embedding=FLAGS.init_embedding)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.data_dir, "train.tf_record")
        process.file_based_convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            dict_builder=dict_builder,
            label_map=processor.get_labels(),
            output_file=train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            batch_size=FLAGS.train_batch_size,
            is_training=True,
            drop_remainder=True,
            input_dim=tokenizer.dim,
            dict_dim=dict_builder.dim if dict_builder is not None else 1,
            shuffle_buffer=len(train_examples),
            augmenter=augm)

        eval_input_fn = None
        if FLAGS.do_eval:
            dev_file = os.path.join(FLAGS.data_dir, "dev.tf_record")
            dev_examples = processor.get_dev_examples(FLAGS.data_dir)
            process.file_based_convert_examples_to_features(
                examples=dev_examples,
                tokenizer=tokenizer,
                dict_builder=dict_builder,
                label_map=processor.get_labels(),
                output_file=dev_file)
            tf.logging.info("***** Running evaluation *****")
            tf.logging.info("  Num examples = %d", len(dev_examples))
            tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
            eval_input_fn = file_based_input_fn_builder(
                input_file=dev_file,
                batch_size=FLAGS.eval_batch_size,
                is_training=False,
                drop_remainder=False,
                input_dim=tokenizer.dim,
                dict_dim=dict_builder.dim if dict_builder is not None else 1)

        if FLAGS.early_stop:
            print("using early stop")
            assert eval_input_fn is not None, "early_stop requires do_eval"
            early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(
                estimator,
                metric_name='eval_accuracy',
                max_steps_without_increase=num_early_steps,
                min_steps=num_early_steps,
                run_every_secs=None,
                run_every_steps=single_epoch_steps)
            tf.estimator.train_and_evaluate(
                estimator,
                train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
                eval_spec=tf.estimator.EvalSpec(eval_input_fn, throttle_secs=60))
        else:
            if FLAGS.do_eval:
                print("do not use early stop")
                tf.estimator.train_and_evaluate(
                    estimator,
                    train_spec=tf.estimator.TrainSpec(train_input_fn, max_steps=num_train_steps),
                    eval_spec=tf.estimator.EvalSpec(eval_input_fn, throttle_secs=60))
            else:
                estimator.train(train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        test_file = os.path.join(FLAGS.data_dir, "test.tf_record")
        test_examples = processor.get_test_examples(FLAGS.data_dir)
        process.file_based_convert_examples_to_features(
            examples=test_examples,
            tokenizer=tokenizer,
            dict_builder=dict_builder,
            label_map=processor.get_labels(),
            output_file=test_file)
        tf.logging.info("***** Running prediction *****")
        tf.logging.info("  Num examples = %d", len(test_examples))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
        predict_input_fn = file_based_input_fn_builder(
            input_file=test_file,
            batch_size=FLAGS.predict_batch_size,
            is_training=False,
            drop_remainder=False,
            input_dim=tokenizer.dim,
            dict_dim=dict_builder.dim if dict_builder is not None else 1)

        predictions = []
        ground_truths = []
        texts = []
        for result in estimator.predict(input_fn=predict_input_fn,
                                        yield_single_examples=True):
            input_ids = result["input_ids"].astype(int)
            prediction = result["prediction"].astype(int)
            ground_truth = result["ground_truths"].astype(int)
            length = int(result["length"])
            if length == 0:
                continue
            tokens = tokenizer.convert_ids_to_tokens(input_ids[:length])
            predictions.append(prediction[:length].tolist())
            ground_truths.append(ground_truth[:length].tolist())
            text = [utils.printable_text(x) for x in tokens]
            texts.append(text)

        P, R, F = processor.evaluate_word_PRF(predictions, ground_truths)
        print('%s Test: P:%f R:%f F:%f' % (FLAGS.data_dir, P, R, F))
        processor.convert_word_segmentation(texts, predictions, FLAGS.output_dir, "predict")
        processor.convert_word_segmentation(texts, ground_truths, FLAGS.output_dir, "predict_golden")
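# A minimal sketch of word-level P/R/F over BMES id sequences (B=0, M=1, E=2, S=3),
# assuming evaluate_word_PRF scores exact word-span matches; the processor's actual
# implementation may differ in details, so this is illustrative only.
#
# def to_spans(labels):
#     spans, start = set(), 0
#     for i, lab in enumerate(labels):
#         if lab in (0, 3):          # B or S starts a word
#             start = i
#         if lab in (2, 3):          # E or S ends a word
#             spans.add((start, i))
#     return spans
#
# def word_prf(predictions, ground_truths):
#     tp = pred_cnt = gold_cnt = 0
#     for pred, gold in zip(predictions, ground_truths):
#         p_spans, g_spans = to_spans(pred), to_spans(gold)
#         tp += len(p_spans & g_spans)
#         pred_cnt += len(p_spans)
#         gold_cnt += len(g_spans)
#     P = tp / pred_cnt if pred_cnt else 0.0
#     R = tp / gold_cnt if gold_cnt else 0.0
#     F = 2 * P * R / (P + R) if P + R else 0.0
#     return P, R, F
#
# print(word_prf([[0, 2, 3]], [[0, 2, 3]]))  # (1.0, 1.0, 1.0)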
def convert_examples_to_features(config, examples, sp_model, max_seq_length,
                                 doc_stride, max_query_length, is_training,
                                 output_fn):
    print('reading and saving records ...')
    cnt_pos, cnt_neg = 0, 0
    unique_id = 1000000000
    max_N, max_M = 1024, 1024
    f = np.zeros((max_N, max_M), dtype=np.float32)

    for (example_index, example) in enumerate(examples):
        if example_index % 100 == 0:
            print('Converting {}/{} pos {} neg {}'.format(
                example_index, len(examples), cnt_pos, cnt_neg))

        query_tokens = encode_ids(
            sp_model,
            preprocess_text(example.question_text, lower=config.uncased))

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        paragraph_text = example.paragraph_text
        para_tokens = encode_pieces(
            sp_model,
            preprocess_text(example.paragraph_text, lower=config.uncased))

        chartok_to_tok_index = []
        tok_start_to_chartok_index = []
        tok_end_to_chartok_index = []
        char_cnt = 0
        for i, token in enumerate(para_tokens):
            chartok_to_tok_index.extend([i] * len(token))
            tok_start_to_chartok_index.append(char_cnt)
            char_cnt += len(token)
            tok_end_to_chartok_index.append(char_cnt - 1)

        tok_cat_text = ''.join(para_tokens).replace(SPIECE_UNDERLINE, ' ')
        N, M = len(paragraph_text), len(tok_cat_text)

        if N > max_N or M > max_M:
            max_N = max(N, max_N)
            max_M = max(M, max_M)
            f = np.zeros((max_N, max_M), dtype=np.float32)
            gc.collect()

        g = {}

        def _lcs_match(max_dist):
            ### longest common sub sequence
            # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
            f.fill(0)
            g.clear()

            for i in range(N):
                # note(zhiliny):
                # unlike standard LCS, this is specifically optimized for the setting
                # because the mismatch between sentence pieces and original text will
                # be small
                for j in range(i - max_dist, i + max_dist):
                    if j >= M or j < 0:
                        continue

                    if i > 0:
                        g[(i, j)] = 0
                        f[i, j] = f[i - 1, j]

                    if j > 0 and f[i, j - 1] > f[i, j]:
                        g[(i, j)] = 1
                        f[i, j] = f[i, j - 1]

                    f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
                    if (preprocess_text(paragraph_text[i], lower=config.uncased,
                                        remove_space=False) == tok_cat_text[j]
                            and f_prev + 1 > f[i, j]):
                        g[(i, j)] = 2
                        f[i, j] = f_prev + 1

        max_dist = abs(N - M) + 5
        for _ in range(2):
            _lcs_match(max_dist)
            if f[N - 1, M - 1] > 0.8 * N:
                break
            max_dist *= 2

        orig_to_chartok_index = [None] * N
        chartok_to_orig_index = [None] * M
        i, j = N - 1, M - 1
        while i >= 0 and j >= 0:
            if (i, j) not in g:
                break
            if g[(i, j)] == 2:
                orig_to_chartok_index[i] = j
                chartok_to_orig_index[j] = i
                i, j = i - 1, j - 1
            elif g[(i, j)] == 1:
                j = j - 1
            else:
                i = i - 1

        if (all(v is None for v in orig_to_chartok_index) or
                f[N - 1, M - 1] < 0.8 * N):
            print('MISMATCH DETECTED!')
            continue

        tok_start_to_orig_index = []
        tok_end_to_orig_index = []
        for i in range(len(para_tokens)):
            start_chartok_pos = tok_start_to_chartok_index[i]
            end_chartok_pos = tok_end_to_chartok_index[i]
            start_orig_pos = _convert_index(chartok_to_orig_index, start_chartok_pos,
                                            N, is_start=True)
            end_orig_pos = _convert_index(chartok_to_orig_index, end_chartok_pos,
                                          N, is_start=False)
            tok_start_to_orig_index.append(start_orig_pos)
            tok_end_to_orig_index.append(end_orig_pos)

        if not is_training:
            tok_start_position = tok_end_position = None

        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1

        if is_training and not example.is_impossible:
            start_position = example.start_position
            end_position = start_position + len(example.orig_answer_text) - 1

            start_chartok_pos = _convert_index(orig_to_chartok_index, start_position,
                                               is_start=True)
            tok_start_position = chartok_to_tok_index[start_chartok_pos]

            end_chartok_pos = _convert_index(
                orig_to_chartok_index, end_position, is_start=False)
            tok_end_position = chartok_to_tok_index[end_chartok_pos]
            assert tok_start_position <= tok_end_position

        def _piece_to_id(x):
            if six.PY2 and isinstance(x, unicode):
                x = x.encode('utf-8')
            return sp_model.PieceToId(x)

        all_doc_tokens = list(map(_piece_to_id, para_tokens))

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_is_max_context = {}
            segment_ids = []
            p_mask = []
            cur_tok_start_to_orig_index = []
            cur_tok_end_to_orig_index = []

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i

                cur_tok_start_to_orig_index.append(
                    tok_start_to_orig_index[split_token_index])
                cur_tok_end_to_orig_index.append(
                    tok_end_to_orig_index[split_token_index])

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(SEG_ID_P)
                p_mask.append(0)

            paragraph_len = len(tokens)

            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_P)
            p_mask.append(1)

            # note(zhiliny): we put P before Q
            # because during pretraining, B is always shorter than A
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(SEG_ID_Q)
                p_mask.append(1)
            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_Q)
            p_mask.append(1)

            cls_index = len(segment_ids)
            tokens.append(CLS_ID)
            segment_ids.append(SEG_ID_CLS)
            p_mask.append(0)

            input_ids = tokens

            # The mask has 0 for real tokens and 1 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(1)
                segment_ids.append(SEG_ID_PAD)
                p_mask.append(1)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(p_mask) == max_seq_length

            span_is_impossible = example.is_impossible
            start_position = None
            end_position = None
            if is_training and not span_is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    # continue
                    start_position = 0
                    end_position = 0
                    span_is_impossible = True
                else:
                    # note(zhiliny): we put P before Q, so doc_offset should be zero.
                    # doc_offset = len(query_tokens) + 2
                    doc_offset = 0
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and span_is_impossible:
                start_position = cls_index
                end_position = cls_index

            if example_index < 20:
                print("*** Example ***")
                print("unique_id: %s" % (unique_id))
                print("example_index: %s" % (example_index))
                print("doc_span_index: %s" % (doc_span_index))
                print("tok_start_to_orig_index: %s" % " ".join(
                    [str(x) for x in cur_tok_start_to_orig_index]))
                print("tok_end_to_orig_index: %s" % " ".join(
                    [str(x) for x in cur_tok_end_to_orig_index]))
                print("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
                ]))
                print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                print("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

                if is_training and span_is_impossible:
                    print("impossible example span")

                if is_training and not span_is_impossible:
                    pieces = [sp_model.IdToPiece(token) for token in
                              tokens[start_position:(end_position + 1)]]
                    answer_text = sp_model.DecodePieces(pieces)
                    print("start_position: %d" % (start_position))
                    print("end_position: %d" % (end_position))
                    print("answer: %s" % (printable_text(answer_text)))

            # note(zhiliny): With multi processing,
            # the example_index is actually the index within the current process
            # therefore we use example_index=None to avoid being used in the future.
            # The current code does not use example_index of training data.
            feat_example_index = example_index

            feature = InputFeatures(
                unique_id=unique_id,
                example_index=feat_example_index,
                doc_span_index=doc_span_index,
                tok_start_to_orig_index=cur_tok_start_to_orig_index,
                tok_end_to_orig_index=cur_tok_end_to_orig_index,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                p_mask=p_mask,
                segment_ids=segment_ids,
                paragraph_len=paragraph_len,
                cls_index=cls_index,
                start_position=start_position,
                end_position=end_position,
                is_impossible=span_is_impossible)

            # Run callback
            output_fn(feature)

            unique_id += 1
            if span_is_impossible:
                cnt_neg += 1
            else:
                cnt_pos += 1
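# Sketch of the `_check_is_max_context` routine called above, following the reference
# BERT/XLNet preprocessing; included only for readability since its definition is not
# part of this excerpt, and the project's own copy may differ slightly.
#
# def _check_is_max_context(doc_spans, cur_span_index, position):
#     """A token appearing in several overlapping doc spans is scored by its minimum
#     left/right context in each span; only the best-scoring span treats it as
#     having "max context"."""
#     best_score, best_span_index = None, None
#     for span_index, doc_span in enumerate(doc_spans):
#         end = doc_span.start + doc_span.length - 1
#         if position < doc_span.start or position > end:
#             continue
#         num_left_context = position - doc_span.start
#         num_right_context = end - position
#         score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
#         if best_score is None or score > best_score:
#             best_score, best_span_index = score, span_index
#     return cur_span_index == best_span_index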