def transform(self, df):
    queries = []
    docs = []
    idx_by_query = {}
    query_idxs = []
    # We do not want to redo the calculation of query representations, but due to logging
    # in the ance package, doing a groupby or pt.apply.by_query here would result in
    # excessive log messages. So we instead calculate each query representation once and
    # keep track of the corresponding index, so we can project back out the original sequence.
    for q in df["query"].to_list():
        if q in idx_by_query:
            query_idxs.append(idx_by_query[q])
        else:
            passage = self.tokenizer.encode(
                q,
                add_special_tokens=True,
                max_length=self.args.max_seq_length,
            )
            passage_len = min(len(passage), self.args.max_query_length)
            input_id_b = pad_input_ids(passage, self.args.max_query_length)
            queries.append([passage_len, input_id_b])
            qidx = len(idx_by_query)
            idx_by_query[q] = qidx
            query_idxs.append(qidx)
    for d in df[self.text_field].to_list():
        passage = self.tokenizer.encode(
            d,
            add_special_tokens=True,
            max_length=self.args.max_seq_length,
        )
        passage_len = min(len(passage), self.args.max_seq_length)
        input_id_b = pad_input_ids(passage, self.args.max_seq_length)
        docs.append([passage_len, input_id_b])
    query_embeddings, _ = StreamInferenceDoc(
        self.args, self.model, GetProcessingFn(self.args, query=True),
        "transform", queries, is_query_inference=True)
    passage_embeddings, _ = StreamInferenceDoc(
        self.args, self.model, GetProcessingFn(self.args, query=False),
        "transform", docs, is_query_inference=False)
    # Project out the query representations (see comment above), so that row i of
    # query_embeddings aligns with row i of passage_embeddings.
    query_embeddings = query_embeddings[query_idxs]
    # Score each (query, document) pair as the dot product of its embeddings.
    scores = (query_embeddings * passage_embeddings).sum(axis=1)
    return df.assign(score=scores)
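# Usage sketch (not from the original source): `scorer` stands for an instance of
# the class defining transform() above, already constructed with its tokenizer,
# model and args, and with text_field set to "text". The input frame follows
# PyTerrier's query-document convention.
#
# import pandas as pd
# df = pd.DataFrame(
#     [["q1", "chemical reactions", "d1", "chemical reactions release energy"],
#      ["q1", "chemical reactions", "d2", "the stock market fell sharply"]],
#     columns=["qid", "query", "docno", "text"])
# scored = scorer.transform(df)  # same frame with an added "score" column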
def PassagePreprocessingFn(args, line, tokenizer):
    if args.data_type == 0:
        line_arr = line.split('\t')
        p_id = int(line_arr[0][1:])  # remove the leading "D"
        url = line_arr[1].rstrip()
        title = line_arr[2].rstrip()
        p_text = line_arr[3].rstrip()
        full_text = url + "<sep>" + title + "<sep>" + p_text
        # keep only the first 10,000 characters, which is sufficient for any
        # experiment that uses fewer than 500-1k tokens
        full_text = full_text[:args.max_doc_character]
    else:
        line = line.strip()
        line_arr = line.split('\t')
        p_id = int(line_arr[0])
        p_text = line_arr[1].rstrip()
        # keep only the first 10,000 characters, which is sufficient for any
        # experiment that uses fewer than 500-1k tokens
        full_text = p_text[:args.max_doc_character]
    passage = tokenizer.encode(
        full_text,
        add_special_tokens=True,
        max_length=args.max_seq_length,
    )
    passage_len = min(len(passage), args.max_seq_length)
    input_id_b = pad_input_ids(passage, args.max_seq_length)
    # pack the record as: 8-byte big-endian id, 4-byte big-endian length,
    # then the padded token ids as int32
    return p_id.to_bytes(8, 'big') + passage_len.to_bytes(4, 'big') + np.array(
        input_id_b, np.int32).tobytes()
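# Sketch (not part of the original code): the inverse of the record packing used
# by PassagePreprocessingFn above and QueryPreprocessingFn below. The layout is
# fixed-width: an 8-byte big-endian id, a 4-byte big-endian passage length, then
# max_length int32 token ids, where max_length must match the value used when
# the record was written (max_seq_length for passages, max_query_length for
# queries).
import numpy as np

def decode_preprocessed_record(record, max_length):
    p_id = int.from_bytes(record[:8], 'big')
    passage_len = int.from_bytes(record[8:12], 'big')
    input_ids = np.frombuffer(record[12:12 + 4 * max_length], dtype=np.int32)
    return p_id, passage_len, input_ids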
def QueryPreprocessingFn(args, line, tokenizer):
    line_arr = line.split('\t')
    q_id = int(line_arr[0])
    passage = tokenizer.encode(
        line_arr[1].rstrip(),
        add_special_tokens=True,
        max_length=args.max_query_length,
    )
    passage_len = min(len(passage), args.max_query_length)
    input_id_b = pad_input_ids(passage, args.max_query_length)
    return q_id.to_bytes(8, 'big') + passage_len.to_bytes(4, 'big') + np.array(
        input_id_b, np.int32).tobytes()
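# Minimal sketch of the padding behaviour assumed throughout this file. The real
# pad_input_ids comes from the ance package; this illustrative version truncates
# to max_length when the input is too long and otherwise right-pads with
# pad_token (assumed to be 0 here).
def pad_input_ids_sketch(input_ids, max_length, pad_token=0):
    if len(input_ids) >= max_length:
        return input_ids[:max_length]
    return input_ids + [pad_token] * (max_length - len(input_ids))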
def transform(self, topics):
    from pyterrier import tqdm
    queries = []
    qid2q = {}
    # tokenize and pad each query, remembering the qid -> query mapping
    for q, qid in zip(topics["query"].to_list(), topics["qid"].to_list()):
        passage = self.tokenizer.encode(
            q,
            add_special_tokens=True,
            max_length=self.args.max_seq_length,
        )
        passage_len = min(len(passage), self.args.max_query_length)
        input_id_b = pad_input_ids(passage, self.args.max_query_length)
        queries.append([passage_len, input_id_b])
        qid2q[qid] = q

    print("***** inference of %d queries *****" % len(queries))
    dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
        self.args, self.model, GetProcessingFn(self.args, query=True),
        "transform", queries, is_query_inference=True)

    print("***** faiss search for %d queries on %d shards *****" %
          (len(queries), self.segments))
    rtr = []
    # search each index shard separately, then combine the per-shard results
    for i, offset in enumerate(tqdm(self.shard_offsets, unit="shard")):
        scores, neighbours = self.cpu_index[i].search(dev_query_embedding,
                                                      self.num_results)
        res = self._calc_scores(
            topics["qid"].values,
            self.passage_embedding2id[i],
            neighbours,
            scores,
            num_results=self.num_results,
            offset=offset,
            qid2q=qid2q)
        rtr.append(res)

    rtr = pd.concat(rtr)
    rtr = add_ranks(rtr)
    # keep only the top num_results per query after merging the shards
    rtr = rtr[rtr["rank"] < self.num_results]
    rtr = rtr.sort_values(by=["qid", "score", "docno"],
                          ascending=[True, False, True])
    return rtr
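# Usage sketch (not from the original source): `retriever` stands for an instance
# of the class defining transform() above, already loaded with a model, tokenizer,
# FAISS shards and the passage-id mappings. The topics frame follows PyTerrier's
# convention; the result carries qid, docno, score and rank columns.
#
# import pandas as pd
# topics = pd.DataFrame([["q1", "chemical reactions"]], columns=["qid", "query"])
# res = retriever.transform(topics)
# print(res.head())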
def gen_tokenize():
    text_attr = self.text_attr
    kwargs = {}
    if self.num_docs is not None:
        kwargs['total'] = self.num_docs
    docs = pt.tqdm(generator, desc="Indexing", unit="d",
                   **kwargs) if self.verbose else generator
    for doc in docs:
        contents = doc[text_attr]
        docid2docno.append(doc["docno"])
        passage = tokenizer.encode(
            contents,
            add_special_tokens=True,
            max_length=self.args.max_seq_length,
        )
        passage_len = min(len(passage), self.args.max_seq_length)
        input_id_b = pad_input_ids(passage, self.args.max_seq_length)
        yield passage_len, input_id_b
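# Usage sketch (not from the original source): gen_tokenize() is a closure, so it
# assumes the enclosing scope provides `generator` (an iterator of dicts with a
# "docno" key and the configured text attribute), `tokenizer`, `docid2docno` and
# `self`. Each yielded pair has the (length, padded token ids) form used by the
# `queries`/`docs` lists built above, and docid2docno records, in yield order,
# the external docno for each internal docid.
#
# generator = iter([{"docno": "d1", "text": "chemical reactions release energy"},
#                   {"docno": "d2", "text": "the stock market fell sharply"}])
# for passage_len, input_ids in gen_tokenize():
#     ...  # stream into embedding inference; docid2docno[i] maps row i to a docno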