def _add_candidate_tokens(self, example: WikiHopExample,
                          begin_sentence_id: int) -> int:
  """Adds candidate tokens and returns the end_sentence_id.

  Every candidate is treated as a separate sentence.

  Args:
    example: The `WikiHopExample` to add the candidate tokens to.
    begin_sentence_id: Begin sentence id to assign to candidates.

  Returns:
    end_sentence_id = begin_sentence_id + num_candidates
  """
  sentence_id = begin_sentence_id
  candidates = example.candidate_answers
  for (i, candidate) in enumerate(candidates):
    self.global_paragraph_breakpoints.append(1)
    self.global_token_ids.append(CANDIDATE_GLOBAL_TOKEN_ID)
    self.global_token_type_ids.append(CANDIDATE_GLOBAL_TOKEN_TYPE_ID)
    candidate = tokenization.convert_to_unicode(candidate)
    candidate = self.tokenizer.tokenize(candidate)
    candidate_token_ids = self.tokenizer.convert_tokens_to_ids(candidate)
    if i not in self.cand_to_span_positions:
      self.cand_to_span_positions[i] = []
    # Trivial span addition. Every candidate is present by default in the
    # long input.
    self.cand_to_span_positions[i].append(
        (len(self.long_token_ids),
         len(self.long_token_ids) + len(candidate_token_ids) - 1))
    for token_id in candidate_token_ids:
      self.long_token_ids.append(token_id)
      self.long_token_type_ids.append(CANDIDATE_TOKEN_TYPE_ID)
      self.long_sentence_ids.append(sentence_id)
      self.long_paragraph_ids.append(-1)
      self.long_paragraph_breakpoints.append(0)
    self.long_paragraph_breakpoints[-1] = 1
    sentence_id += 1
  return sentence_id
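# Worked example of the span bookkeeping above (hypothetical numbers, not from
# the original module): if `long_token_ids` already holds 10 tokens and a
# candidate tokenizes into 3 WordPieces, the span recorded in
# `cand_to_span_positions` is (10, 12), an inclusive [begin, end] pair of
# offsets into the long input.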
def _add_query_tokens(self, example: WikiHopExample,
                      begin_sentence_id: int) -> int:
  """Adds query tokens to long / global input.

  We mirror query tokens in global as well, i.e., we will have one global
  token per query WordPiece. Every WordPiece of the query is treated as a
  separate sentence.

  Args:
    example: The `WikiHopExample` to add the query tokens to.
    begin_sentence_id: The begin sentence id to be used to start assigning
      sentence ids to query tokens.

  Returns:
    end_sentence_id = begin_sentence_id + num_query_word_pieces
  """
  sentence_id = begin_sentence_id
  query = example.query
  query = tokenization.convert_to_unicode(query)
  query_tokens = self.tokenizer.tokenize(query)
  query_token_ids = self.tokenizer.convert_tokens_to_ids(query_tokens)
  for token_id in query_token_ids:
    self.long_token_ids.append(token_id)
    self.global_token_ids.append(QUESTION_GLOBAL_TOKEN_ID)
    self.long_token_type_ids.append(QUESTION_TOKEN_TYPE_ID)
    self.global_token_type_ids.append(QUESTION_GLOBAL_TOKEN_TYPE_ID)
    self.long_sentence_ids.append(sentence_id)
    self.long_paragraph_ids.append(-1)
    self.long_paragraph_breakpoints.append(0)
    self.global_paragraph_breakpoints.append(0)
    sentence_id += 1
  self.long_paragraph_breakpoints[-1] = 1
  self.global_paragraph_breakpoints[-1] = 1
  return sentence_id
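# A minimal usage sketch (an assumption, not part of the original module):
# both helpers return the next unused sentence id, so a caller is expected to
# chain them when building the full input. `converter` and `example` are
# hypothetical stand-ins for an instance of this class and a `WikiHopExample`.
#
#   sentence_id = converter._add_query_tokens(example, begin_sentence_id=0)
#   sentence_id = converter._add_candidate_tokens(example, sentence_id)
#   # sentence_id is now num_query_word_pieces + num_candidates.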