def sample_data(dump_paths, para=False, doc_sample_ratio=0.2, vec_sample_ratio=0.2, seed=29,
                max_norm=None, max_norm_cf=1.3, num_dummy_zeros=0, norm_th=999):
    """Sample start vectors from the phrase dumps to build a training set for the coarse quantizer."""
    vecs = []
    random.seed(seed)
    np.random.seed(seed)
    print('sampling from:')
    for dump_path in dump_paths:
        print(dump_path)
    dumps = [h5py.File(dump_path, 'r') for dump_path in dump_paths]

    for i, f in enumerate(tqdm(dumps)):
        doc_ids = list(f.keys())
        sampled_doc_ids = random.sample(doc_ids, int(doc_sample_ratio * len(doc_ids)))
        for doc_id in tqdm(sampled_doc_ids, desc='sampling from %d' % i):
            doc_group = f[doc_id]
            if para:
                groups = doc_group.values()
            else:
                groups = [doc_group]
            for group in groups:
                num_vecs, d = group['start'].shape
                sampled_vec_idxs = np.random.choice(num_vecs, int(vec_sample_ratio * num_vecs))
                cur_vecs = int8_to_float(group['start'][:],
                                         group.attrs['offset'],
                                         group.attrs['scale'])[sampled_vec_idxs]
                # Drop vectors whose norm exceeds the threshold (outliers hurt quantizer training).
                cur_vecs = cur_vecs[np.linalg.norm(cur_vecs, axis=1) <= norm_th]
                vecs.append(cur_vecs)
    out = np.concatenate(vecs, 0)
    for dump in dumps:
        dump.close()

    # Norm augmentation: append sqrt(max_norm^2 - ||x||^2) so every vector has norm max_norm.
    norms = np.linalg.norm(out, axis=1, keepdims=True)
    if max_norm is None:
        max_norm = max_norm_cf * np.max(norms)
    consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2))
    out = np.concatenate([consts, out], axis=1)
    if num_dummy_zeros > 0:
        out = np.concatenate([out, np.zeros([out.shape[0], num_dummy_zeros], dtype=out.dtype)], axis=1)
    return out, max_norm
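
# The vectors returned by `sample_data` (norm-augmented so that every vector has norm
# `max_norm`, which makes inner-product and L2 rankings agree) are typically used to
# train the coarse quantizer that `add_to_index` later reads. Below is a minimal sketch
# assuming an IVF+SQ8 Faiss index; the function name, factory string, nlist, metric,
# and paths are illustrative assumptions, not the repository's actual configuration.
def train_index_sketch(dump_paths, trained_index_path, nlist=4096):
    sampled, max_norm = sample_data(dump_paths)
    d = sampled.shape[1]  # includes the appended norm column (and dummy zeros, if any)
    index = faiss.index_factory(d, 'IVF%d,SQ8' % nlist, faiss.METRIC_INNER_PRODUCT)
    index.train(np.ascontiguousarray(sampled, dtype=np.float32))
    faiss.write_index(index, trained_index_path)
    return max_norm  # reuse the same max_norm when calling add_to_index
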
def dequant(group, input_):
    if 'offset' in group.attrs:
        return int8_to_float(input_, group.attrs['offset'], group.attrs['scale'])
    return input_
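
# NOTE: `int8_to_float` is referenced throughout this section but defined elsewhere
# in the repository. The sketch below (hypothetical name and body) shows a plausible
# inverse of the int8 quantization implied by the `offset`/`scale` attributes;
# verify against the real helper before relying on it.
def int8_to_float_sketch(num, offset, scale):
    # Assumed dequantization: float ~= int8 / scale + offset.
    return num.astype(np.float32) / scale + offset
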
def search_start(self, query_start, doc_idxs=None, para_idxs=None, start_top_k=100, mid_top_k=20,
                 out_top_k=5, nprobe=16, q_texts=None, doc_top_k=5, search_strategy='dense_first'):
    # doc_idxs = [Q], para_idxs = [Q]
    assert self.start_index is not None
    query_start = query_start.astype(np.float32)

    # Open setting: retrieve candidates from the full index.
    if doc_idxs is None:
        if len(self.sparse_type) != 0:
            q_spvecs = vstack([self.ranker.text2spvec(q) for q in q_texts])
            doc_scores = np.squeeze((q_spvecs * self.ranker.doc_mat).toarray())

        if search_strategy == 'dense_first':
            (doc_idxs, para_idxs, start_idxs), start_scores = self.search_dense(
                query_start, start_top_k, nprobe, doc_scores)
        elif search_strategy == 'sparse_first':
            (doc_idxs, start_idxs), start_scores = self.search_sparse(
                query_start, doc_scores, doc_top_k)
        elif search_strategy == 'hybrid':
            # Union of dense-first and sparse-first candidates.
            (doc_idxs, para_idxs, start_idxs), start_scores = self.search_dense(
                query_start, start_top_k, nprobe, doc_scores)
            (doc_idxs_, start_idxs_), start_scores_ = self.search_sparse(
                query_start, doc_scores, doc_top_k)
            doc_idxs = np.concatenate([doc_idxs, doc_idxs_], -1)
            start_idxs = np.concatenate([start_idxs, start_idxs_], -1)
            start_scores = np.concatenate([start_scores, start_scores_], -1)
        else:
            raise ValueError(search_strategy)

        mid_top_k = min(mid_top_k, start_scores.shape[-1])
        out_top_k = min(out_top_k, start_scores.shape[-1])

        # Rerank and reduce to mid_top_k candidates.
        rerank_idxs = start_scores.argsort()[-mid_top_k:][::-1]
        doc_idxs = doc_idxs[rerank_idxs]
        start_idxs = start_idxs[rerank_idxs]
        start_scores = start_scores[rerank_idxs]

        # Add paragraph-level sparse scores, then rerank and reduce to out_top_k.
        doc_idxs = np.reshape(doc_idxs, [-1])
        if self.para:
            para_idxs = np.reshape(para_idxs, [-1])
        start_idxs = np.reshape(start_idxs, [-1])
        start_scores = np.reshape(start_scores, [-1])
        start_scores += self.sparse_weight * self.get_para_scores(q_spvecs, doc_idxs, start_idxs=start_idxs)

        rerank_scores = np.reshape(start_scores, [-1, mid_top_k])
        rerank_idxs = np.array([scores.argsort()[-out_top_k:][::-1] for scores in rerank_scores])
        doc_idxs = doc_idxs[rerank_idxs]
        if para_idxs is not None:
            para_idxs = para_idxs[rerank_idxs]
        start_idxs = start_idxs[rerank_idxs]
        start_scores = start_scores[rerank_idxs]

    # Closed setting: documents and paragraphs are given, score their vectors directly.
    else:
        groups = [self.get_doc_group(doc_idx)[str(para_idx)]
                  for doc_idx, para_idx in zip(doc_idxs, para_idxs)]
        starts = [group['start'][:, :] for group in groups]
        starts = [int8_to_float(start, groups[0].attrs['offset'], groups[0].attrs['scale'])
                  for start in starts]
        all_scores = [np.squeeze(np.matmul(start, query_start[i:i + 1, :].transpose()), -1)
                      for i, start in enumerate(starts)]
        start_idxs = np.array([scores.argsort()[-out_top_k:][::-1] for scores in all_scores])
        start_scores = np.array([scores[idxs] for scores, idxs in zip(all_scores, start_idxs)])
        doc_idxs = np.tile(np.expand_dims(doc_idxs, -1), [1, out_top_k])
        para_idxs = np.tile(np.expand_dims(para_idxs, -1), [1, out_top_k])

    return start_scores, doc_idxs, para_idxs, start_idxs
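
# Example (hedged): in the open setting, `search_start` might be invoked roughly as
# follows, where `mips` is an instance of this class and `encode_query_starts` is a
# hypothetical query encoder producing [Q, d + 1] start vectors (the extra column is
# for the norm-augmentation trick); names and shapes are assumptions:
#
#     q_texts = ['Who wrote Hamlet?']
#     query_start = encode_query_starts(q_texts)
#     start_scores, doc_idxs, para_idxs, start_idxs = mips.search_start(
#         query_start, q_texts=q_texts, start_top_k=100, mid_top_k=20,
#         out_top_k=5, nprobe=16, search_strategy='hybrid')
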
def add_to_index(dump_paths, trained_index_path, target_index_path, idx2id_path, max_norm,
                 para=False, num_docs_per_add=1000, num_dummy_zeros=0, cuda=False,
                 fine_quant='SQ8', offset=0, norm_th=999, ignore_ids=None):
    """Add norm-augmented start vectors from the dumps to a trained Faiss index and write idx-to-id metadata."""
    idx2doc_id = []
    idx2para_id = []
    idx2word_id = []
    dumps = [h5py.File(dump_path, 'r') for dump_path in dump_paths]

    print('reading %s' % trained_index_path)
    start_index = faiss.read_index(trained_index_path)
    if cuda:
        if fine_quant.startswith('PQ'):
            print('PQ not supported on GPU; keeping CPU.')
        else:
            res = faiss.StandardGpuResources()
            start_index = faiss.index_cpu_to_gpu(res, 0, start_index)

    print('adding following dumps:')
    for dump_path in dump_paths:
        print(dump_path)

    if para:
        for di, phrase_dump in enumerate(tqdm(dumps, desc='dumps')):
            starts = []
            for i, (doc_idx, doc_group) in enumerate(tqdm(phrase_dump.items(), desc='faiss indexing')):
                for para_idx, group in doc_group.items():
                    num_vecs = group['start'].shape[0]
                    start = int8_to_float(group['start'][:], group.attrs['offset'], group.attrs['scale'])
                    # Norm augmentation with the same max_norm used at training time.
                    norms = np.linalg.norm(start, axis=1, keepdims=True)
                    consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2))
                    start = np.concatenate([consts, start], axis=1)
                    if num_dummy_zeros > 0:
                        start = np.concatenate(
                            [start, np.zeros([start.shape[0], num_dummy_zeros], dtype=start.dtype)], axis=1)
                    starts.append(start)
                    idx2doc_id.extend([int(doc_idx)] * num_vecs)
                    idx2para_id.extend([int(para_idx)] * num_vecs)
                    idx2word_id.extend(list(range(num_vecs)))
                if len(starts) > 0 and i % num_docs_per_add == 0:
                    print('concatenating')
                    concat = np.concatenate(starts, axis=0)
                    print('adding')
                    add_with_offset(start_index, concat, offset)
                    print('done')
                    starts = []
                if i % 100 == 0:
                    print('%d/%d' % (i + 1, len(phrase_dump.keys())))
            print('adding leftover')
            if len(starts) > 0:  # guard against an empty leftover batch
                add_with_offset(start_index, np.concatenate(starts, axis=0), offset)
            print('done')
    else:
        for di, phrase_dump in enumerate(tqdm(dumps, desc='dumps')):
            starts = []
            valids = []
            for i, (doc_idx, doc_group) in enumerate(tqdm(phrase_dump.items(), desc='adding %d' % di)):
                if ignore_ids is not None and doc_idx in ignore_ids:
                    continue
                num_vecs = doc_group['start'].shape[0]
                start = int8_to_float(doc_group['start'][:], doc_group.attrs['offset'], doc_group.attrs['scale'])
                # Mark vectors whose norm exceeds the threshold; they are filtered at add time.
                valid = np.linalg.norm(start, axis=1) <= norm_th
                norms = np.linalg.norm(start, axis=1, keepdims=True)
                consts = np.sqrt(np.maximum(0.0, max_norm ** 2 - norms ** 2))
                start = np.concatenate([consts, start], axis=1)
                if num_dummy_zeros > 0:
                    start = np.concatenate(
                        [start, np.zeros([start.shape[0], num_dummy_zeros], dtype=start.dtype)], axis=1)
                starts.append(start)
                valids.append(valid)
                idx2doc_id.extend([int(doc_idx)] * num_vecs)
                idx2word_id.extend(range(num_vecs))
                if len(starts) > 0 and i % num_docs_per_add == 0:
                    add_with_offset(start_index, np.concatenate(starts, axis=0), offset, np.concatenate(valids))
                    starts = []
                    valids = []
                if i % 100 == 0:
                    print('%d/%d' % (i + 1, len(phrase_dump.keys())))
            if len(starts) > 0:  # guard against an empty leftover batch
                add_with_offset(start_index, np.concatenate(starts, axis=0), offset, np.concatenate(valids))

    for dump in dumps:
        dump.close()

    if cuda and not fine_quant.startswith('PQ'):
        print('moving back to cpu')
        start_index = faiss.index_gpu_to_cpu(start_index)
    print('index ntotal: %d' % start_index.ntotal)

    idx2doc_id = np.array(idx2doc_id, dtype=np.int32)
    idx2para_id = np.array(idx2para_id, dtype=np.int32)
    idx2word_id = np.array(idx2word_id, dtype=np.int32)

    print('writing index and metadata')
    with h5py.File(idx2id_path, 'w') as f:
        g = f.create_group(str(offset))
        g.create_dataset('doc', data=idx2doc_id)
        g.create_dataset('para', data=idx2para_id)
        g.create_dataset('word', data=idx2word_id)
        g.attrs['offset'] = offset
    faiss.write_index(start_index, target_index_path)
    print('done')
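
# NOTE: `add_with_offset` is referenced above (in place of plain `start_index.add` calls)
# but defined elsewhere in the repository. The sketch below (hypothetical name, signature,
# and id scheme) assumes it assigns explicit Faiss ids shifted by `offset` and optionally
# filters rows with the `valids` mask; the real helper may differ.
def add_with_offset_sketch(index, concat, offset, valids=None):
    ids = np.arange(concat.shape[0]) + offset + index.ntotal
    if valids is not None:
        # Drop vectors whose norm exceeded norm_th (assumption).
        concat = concat[valids]
        ids = ids[valids]
    index.add_with_ids(np.ascontiguousarray(concat, dtype=np.float32), ids.astype(np.int64))
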