Example #1
0
 def get_ptune_query(
     self,
     content: Dict,
     prompt_token_id: int,
     max_seq_len: int,
     templates: List[int],
     tokenizer: TokenizerSpec,
 ):
     """Build the token ids for a paragraph/question p-tuning query.

     The query is laid out as::

         [prompt_token_id] * templates[0] + ids(" Paragraph: <sentence>")
         + [prompt_token_id] * templates[1] + ids(" Question: <question>?")
         + [prompt_token_id] * templates[2] + ids(" Answer:")

     When that is longer than ``max_seq_len``, tokens are dropped from the
     start of the paragraph so the query fits exactly.

     Args:
         content: example record; ``content['sentence']`` and
             ``content['question']`` are read.
         prompt_token_id: id repeated for every virtual (prompt) token.
         max_seq_len: maximum length of the returned id list.
         templates: virtual-token counts placed before the paragraph, the
             question and the answer marker; only entries 0-2 are used.
         tokenizer: object whose ``text_to_ids(str)`` returns a list of ids.

     Returns:
         The concatenated list of token ids.

     Raises:
         ValueError: if removing the whole paragraph still cannot bring the
             query down to ``max_seq_len``.
     """
     text_a = content['sentence']
     text_b = content['question']
     a_input_token_ids = tokenizer.text_to_ids(f" Paragraph: {text_a}")
     b_input_token_ids = tokenizer.text_to_ids(f" Question: {text_b}?")
     c_input_token_ids = tokenizer.text_to_ids(" Answer:")
     # Count only the three virtual-token runs that are actually emitted below,
     # so extra entries in `templates` do not inflate the length (and the cut).
     num_virtual_ids = templates[0] + templates[1] + templates[2]
     total_num_ids = (
         len(a_input_token_ids) + len(b_input_token_ids) + len(c_input_token_ids)
         + num_virtual_ids
     )
     cut = 0
     if total_num_ids > max_seq_len:
         logging.warning(
             "Input sequence is longer than the LM model max seq, will cut it off to fit"
         )
         cut = total_num_ids - max_seq_len
         if cut > len(a_input_token_ids):
             # Slicing past the end of the paragraph would silently return a
             # query that is still longer than max_seq_len.
             raise ValueError(
                 f"Some other field length is too long, cutting sentence is not enough "
                 f"(need to remove {cut} tokens, paragraph has {len(a_input_token_ids)})"
             )
     return (
         [prompt_token_id] * templates[0] + a_input_token_ids[cut:]
         + [prompt_token_id] * templates[1] + b_input_token_ids
         + [prompt_token_id] * templates[2] + c_input_token_ids
     )
Example #2
0
 def get_ptune_query(
     self,
     content: Dict,
     prompt_token_id: int,
     max_seq_len: int,
     templates: List[int],
     tokenizer: TokenizerSpec,
 ):
     """Build the token ids for the query described by ``self.pieces``.

     Each entry of ``self.pieces`` is either:

     * a ``str`` template whose ``{name}`` placeholders are filled from
       ``content`` and then tokenized, or
     * any other value, used as an index into ``templates``; it expands to
       ``templates[piece]`` copies of ``prompt_token_id`` (virtual tokens).

     Only text pieces that reference ``self.limit_length_field`` may be
     shortened. If the full query is longer than ``max_seq_len``, exactly the
     overflow is removed from the start of those pieces, first piece first.

     Args:
         content: example record providing a value for every placeholder.
         prompt_token_id: id repeated for every virtual (prompt) token.
         max_seq_len: maximum length of the returned id list.
         templates: virtual-token counts, indexed by non-string pieces.
         tokenizer: object whose ``text_to_ids(str)`` returns a list of ids.

     Returns:
         The flat list of token ids (empty if ``self.pieces`` is empty).

     Raises:
         KeyError: if a placeholder has no entry in ``content``.
         ValueError: if the truncatable pieces are too short to absorb the
             overflow.
     """
     all_ids = []
     limits = []  # limits[i] is True when all_ids[i] may be truncated
     for piece in self.pieces:
         if isinstance(piece, str):
             # Fill every {name} placeholder from the example record.
             variable_text = {}
             limit_length = False
             for var in re.findall(r'{\w*}', piece):
                 varname = var[1:-1]
                 variable_text[varname] = content[varname]
                 if varname == self.limit_length_field:
                     limit_length = True
             all_ids.append(tokenizer.text_to_ids(piece.format(**variable_text)))
             limits.append(limit_length)
         else:
             # Virtual (prompt) tokens are never truncated.
             all_ids.append([prompt_token_id] * templates[piece])
             limits.append(False)

     cut = sum(len(ids) for ids in all_ids) - max_seq_len
     if cut > 0:
         logging.warning(
             "Input sequence is longer than the LM model max seq, will cut it off to fit"
         )
         truncatable = sum(len(ids) for ids, limited in zip(all_ids, limits) if limited)
         if truncatable < cut:
             raise ValueError(
                 f"Some other field length is too long, cutting {self.limit_length_field} is not enough"
             )
         # Remove exactly `cut` tokens in total, so a field used by several
         # pieces is not over-truncated.
         remaining = cut
         for i, limited in enumerate(limits):
             if limited and remaining > 0:
                 drop = min(remaining, len(all_ids[i]))
                 all_ids[i] = all_ids[i][drop:]
                 remaining -= drop
     # Flatten in one linear pass; works for an empty piece list too.
     return [token for ids in all_ids for token in ids]