def get_ptune_query(
    self,
    content: Dict,
    prompt_token_id: int,
    max_seq_len: int,
    templates: List[int],
    tokenizer: TokenizerSpec,
):
    """Build the p-tuning query token ids for one paragraph/question example.

    The output layout is::

        <templates[0] virtual tokens> " Paragraph: ..." <templates[1] virtual tokens>
        " Question: ...?" <templates[2] virtual tokens> " Answer:"

    Args:
        content: Mapping providing the ``'sentence'`` (paragraph) and
            ``'question'`` entries.
        prompt_token_id: Token id repeated to form every virtual-token run.
        max_seq_len: Length budget for the assembled sequence.
        templates: Number of virtual tokens per slot; indices 0-2 are used
            for the layout, while every entry counts towards the budget.
        tokenizer: Object exposing ``text_to_ids(str)``.

    Returns:
        The concatenated list of token ids.

    Note:
        When the budget is exceeded, only the *leading* paragraph tokens are
        dropped. If the overflow is larger than the paragraph itself, the
        returned sequence is still longer than ``max_seq_len``.
    """
    paragraph, question = content['sentence'], content['question']
    paragraph_prompt = f" Paragraph: {paragraph}"
    question_prompt = f" Question: {question}?"

    # Tokenize the three textual segments in their original order.
    paragraph_ids, question_ids, answer_ids = [
        tokenizer.text_to_ids(segment)
        for segment in (paragraph_prompt, question_prompt, " Answer:")
    ]

    text_len = len(paragraph_ids) + len(question_ids) + len(answer_ids)
    required_len = text_len + sum(templates)

    overflow = 0
    if required_len > max_seq_len:
        logging.warning("Input sequence is longer than the LM model max seq, will cut it off to fit")
        overflow = required_len - max_seq_len

    # One run of virtual tokens for each of the three template slots.
    virtual_runs = [[prompt_token_id] * templates[slot] for slot in range(3)]

    return (
        virtual_runs[0]
        + paragraph_ids[overflow:]
        + virtual_runs[1]
        + question_ids
        + virtual_runs[2]
        + answer_ids
    )
def get_ptune_query(
    self,
    content: Dict,
    prompt_token_id: int,
    max_seq_len: int,
    templates: List[int],
    tokenizer: TokenizerSpec,
):
    """Assemble the p-tuning query token ids from the template in ``self.pieces``.

    Each entry of ``self.pieces`` is handled in order:

    * a string has its ``{name}`` placeholders filled from ``content`` and is
      then tokenized with ``tokenizer.text_to_ids``;
    * any other value is used as an index into ``templates`` and expands to
      that many copies of ``prompt_token_id`` (a run of virtual tokens).

    If the assembled sequence exceeds ``max_seq_len``, the leading tokens of
    every string piece that references ``self.limit_length_field`` are dropped.

    Args:
        content: Mapping from placeholder name to the value substituted in.
        prompt_token_id: Token id used for virtual tokens.
        max_seq_len: Length budget for the assembled sequence.
        templates: Number of virtual tokens for each virtual-token slot.
        tokenizer: Object exposing ``text_to_ids(str)``.

    Returns:
        The flat list of token ids; an empty list when ``self.pieces`` is empty.

    Raises:
        KeyError: If a placeholder name is missing from ``content``.
        ValueError: If a truncatable piece is shorter than the number of
            tokens that must be removed.
    """
    segments = []
    truncatable = []
    for piece in self.pieces:
        if not isinstance(piece, str):
            # Non-string piece: index of a virtual-token slot.
            segments.append([prompt_token_id] * templates[piece])
            truncatable.append(False)
            continue

        # Substitute the placeholders, remembering whether this piece
        # contains the field that is allowed to be shortened.
        values = {}
        can_truncate = False
        for placeholder in re.findall(r'{\w*}', piece):
            field = placeholder[1:-1]
            values[field] = content[field]
            if field == self.limit_length_field:
                can_truncate = True
        segments.append(tokenizer.text_to_ids(piece.format(**values)))
        truncatable.append(can_truncate)

    total_len = sum(len(segment) for segment in segments)
    if total_len > max_seq_len:
        logging.warning("Input sequence is longer than the LM model max seq, will cut it off to fit")
        cut = total_len - max_seq_len
        # NOTE(review): every truncatable piece is shortened by the full
        # `cut`, so a template that references the limited field more than
        # once loses more tokens than strictly necessary; a template that
        # never references it is returned over-long. Kept as-is for
        # backward compatibility.
        trimmed = []
        for segment, can_truncate in zip(segments, truncatable):
            if can_truncate:
                if len(segment) < cut:
                    raise ValueError(
                        f"Some other field length is too long, cutting {self.limit_length_field} is not enough"
                    )
                segment = segment[cut:]
            trimmed.append(segment)
        segments = trimmed

    # Linear-time flatten; unlike reduce() without an initial value, this
    # also handles an empty template by returning [].
    return [token_id for segment in segments for token_id in segment]