def compute_q(self, f_df, q_df, return_f_nbow=False):
    """Compute, for every question, its WMD-nearest visible facts.

    Facts and questions are turned into nBOW documents and merged into a
    single WMD vocabulary; question ids are offset by the number of facts
    so that fact neighbours can be recognised by ``index < nb_facts``.

    :param f_df: DataFrame of facts with an ``original`` text column.
    :param q_df: DataFrame of questions with an ``original`` text column.
    :param return_f_nbow: when True, also return the fact nBOW mapping so
        it can be reused (e.g. by ``compute_f``).
    :return: pd.Series of np.ndarray with nearest fact ids, optionally
        paired with the fact nBOW dict.
    """
    logger.info('Computing question wmds')
    f_nbow = {
        row.Index: self.nbowify(row.Index, row.original)
        for row in f_df.itertuples()
    }
    nb_facts = len(f_nbow)
    q_nbow = {
        row.Index + nb_facts: self.nbowify(row.Index + nb_facts, row.original)
        for row in q_df.itertuples()
    }
    # Questions are appended after the facts in one shared vocabulary.
    merged_fnbow = dict(f_nbow)
    merged_fnbow.update(q_nbow)
    q_calc = WMD(SpacyEmbeddings(self.nlp), merged_fnbow,
                 vocabulary_min=1, verbosity=logging.WARNING)
    q_calc.cache_centroids()

    def fact_neighbours(question_id):
        # Keep only neighbours from the fact range; question-question
        # matches are discarded.
        pairs = q_calc.nearest_neighbors(
            question_id, k=self.config.nearest_k_visible)
        return np.array([doc_id for doc_id, _ in pairs if doc_id < nb_facts])

    q_closest = pd.Series(
        fact_neighbours(question_id)
        for question_id in tqdm(q_nbow.keys(), desc='Question wmd...'))
    if return_f_nbow:
        return q_closest, f_nbow
    return q_closest
def compute_f(self, f_df, f_nbow=None):
    """Compute, for every fact, its WMD-nearest facts.

    :param f_df: DataFrame of facts with an ``original`` text column.
    :param f_nbow: optional precomputed fact nBOW mapping (e.g. the one
        produced by ``compute_q``); built from ``f_df`` when None.
    :return: pd.Series of np.ndarray with the nearest fact ids.
    """
    logger.info('Computing fact wmds')
    if f_nbow is None:
        f_nbow = {
            row.Index: self.nbowify(row.Index, row.original)
            for row in f_df.itertuples()
        }
    f_calc = WMD(SpacyEmbeddings(self.nlp), f_nbow,
                 vocabulary_min=1, verbosity=logging.WARNING)
    f_calc.cache_centroids()
    top_k = self.config.nearest_k_visible
    f_closest = pd.Series(
        np.array([doc_id for doc_id, _ in
                  f_calc.nearest_neighbors(fact_id, k=top_k)])
        for fact_id in tqdm(f_nbow.keys(), desc='Fact wmd...'))
    return f_closest
def retrieve(self, top_id: str, k=None, only=None):
    """Rank the ``only`` documents by WMD proximity to ``top_id``.

    :param top_id: identifier of the query document.
    :param k: number of neighbours requested from the WMD engine; the
        result is asserted to have exactly this length.
    :param only: identifiers of the candidate documents to search among
        (must be non-empty).
    :return: list of ``Result`` tuples, nearest first.
    """
    assert only, 'not searching anything'
    index = self.db.mapping
    delta = common.timer()

    def to_nbow(doc_id):
        # transform to the nbow model used by wmd.WMD:
        # ('human readable name', 'item identifiers', 'weights')
        entry = index[doc_id]
        return (doc_id, entry.idx, entry.freq)

    # The query document itself must be part of the WMD corpus.
    docs = {}
    for doc_id in only + [top_id]:
        docs[doc_id] = to_nbow(doc_id)
    calc = WMDR(self.emb, docs, vocabulary_min=2)
    calc.cache_centroids()
    neighbours = calc.nearest_neighbors(top_id, k=k)
    # Record the elapsed wall time for this query.
    self._times.append(delta())
    assert len(neighbours) == k, f'{len(neighbours)} not {k}'
    return [Result(*pair) for pair in neighbours]
class SimilarRepositories:
    """Queries repositories similar to a given one using Word Mover's Distance
    over bag-of-words repository models."""

    # Captures the "github.com/<owner>/<name>" part of clone/browse URLs.
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")
    _log = logging.getLogger("SimilarRepositories")

    def __init__(self, id2vec=None, df=None, nbow=None,
                 prune_df_threshold=1,
                 wmd_cache_centroids=True,
                 wmd_kwargs: Dict[str, Any] = None,
                 languages: Tuple[List, bool] = (None, False),
                 engine_kwargs: Dict[str, Any] = None):
        """
        :param id2vec: Id2Vec model instance, or None to load the default.
        :param df: DocumentFrequencies instance, None to load the default,
            or False to disable document frequencies (foreign repository
            queries will then raise).
        :param nbow: BOW model instance, or None to load the default.
        :param prune_df_threshold: minimum document frequency kept after
            pruning.
        :param wmd_cache_centroids: precompute WMD centroids when True.
        :param wmd_kwargs: extra keyword arguments for the WMD engine.
        :param languages: forwarded to repo2bow for foreign queries.
        :param engine_kwargs: forwarded to repo2bow for foreign queries.
        """
        backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec().load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        # BUG FIX: the original nested `if df is not False` inside
        # `if df is None`, where it is always True — the "disabled" branch
        # was unreachable and passing df=False crashed on the isinstance
        # assertion below instead of disabling document frequencies.
        if df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif df is None:
            self._df = DocumentFrequencies().load(backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._bow = BOW().load(backend=backend)
        else:
            assert isinstance(nbow, BOW)
            self._bow = nbow
        self._log.info("Loaded BOW model: %s", self._bow)
        # The BOW model must have been built with the same id2vec model.
        assert self._bow.get_dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._bow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s"
                % (len(self._id2vec), self._bow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings, self._bow,
                        **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()
        self._languages = languages
        self._engine_kwargs = engine_kwargs

    def query(self, url_or_path_or_name: str,
              **kwargs) -> List[Tuple[str, float]]:
        """Return ``(repository, distance)`` neighbours of the argument.

        The argument may be a document name already in the BOW model, a
        GitHub URL, or a path/URL to a repository to analyze from scratch.
        """
        try:
            repo_index = self._bow.documents.index(url_or_path_or_name)
        except ValueError:
            repo_index = -1
        if repo_index == -1:
            # Try to normalize a GitHub URL to the stored document name.
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._bow.documents.index(name)
                except ValueError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        # Map BOW indices back to human-readable repository names.
        neighbours = [(self._bow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    def _query_domestic(self, repo_index, **kwargs):
        # The repository is already in the loaded BOW model: query by index.
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path: str, **kwargs):
        # The repository is unknown: clone/convert it and query by raw nBOW.
        df = self._df
        if df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        with tempfile.TemporaryDirectory(prefix="vecino-") as tempdir:
            target = os.path.join(tempdir, "repo")
            if os.path.isdir(url_or_path):
                url_or_path = os.path.abspath(url_or_path)
                # Symlink so repo2bow finds the repository inside tempdir.
                os.symlink(url_or_path, target, target_is_directory=True)
                repo_format = "standard"
            else:
                self._log.info("Cloning %s to %s", url_or_path, target)
                porcelain.clone(url_or_path, target, bare=True,
                                outstream=sys.stderr)
                repo_format = "bare"
            bow = repo2bow(tempdir, repo_format, 1, df, *self._languages,
                           engine_kwargs=self._engine_kwargs)
            # Translate tokens to id2vec indices; unknown tokens are dropped.
            ibow = {}
            for key, val in bow.items():
                try:
                    ibow[self._id2vec[key]] = val
                except KeyError:
                    continue
            # Guard added: zip(*[]) would otherwise raise an opaque
            # "not enough values to unpack" ValueError.
            if not ibow:
                raise ValueError("None of the repository's tokens are known "
                                 "to the id2vec model.")
            words, weights = zip(*sorted(ibow.items()))
            return self._wmd.nearest_neighbors((words, weights), **kwargs)
class SimilarRepositories:
    """Queries repositories similar to a given one using Word Mover's Distance
    over nBOW repository models."""

    # Captures the "github.com/<owner>/<name>" part of clone/browse URLs.
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

    def __init__(self, id2vec=None, df=None, nbow=None,
                 prune_df_threshold=1, verbosity=logging.DEBUG,
                 wmd_cache_centroids=True, wmd_kwargs=None,
                 gcs_bucket=None, repo2nbow_kwargs=None,
                 initialize_environment=True):
        """
        :param id2vec: Id2Vec model instance, or None to load the default.
        :param df: DocumentFrequencies instance, None to load the default,
            or False to disable document frequencies (foreign repository
            queries will then raise).
        :param nbow: NBOW model instance, or None to load the default.
        :param prune_df_threshold: minimum document frequency kept after
            pruning.
        :param verbosity: logging level for this object and loaded models.
        :param wmd_cache_centroids: precompute WMD centroids when True.
        :param wmd_kwargs: extra keyword arguments for the WMD engine.
        :param gcs_bucket: optional GCS bucket name for the model backend.
        :param repo2nbow_kwargs: extra keyword arguments for Repo2nBOW.
        :param initialize_environment: call initialize() first when True.
        """
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("similar_repos")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        # BUG FIX: the original nested `if df is not False` inside
        # `if df is None`, where it is always True — the "disabled" branch
        # was unreachable and passing df=False crashed on the isinstance
        # assertion below instead of disabling document frequencies.
        if df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif df is None:
            self._df = DocumentFrequencies(log_level=verbosity).load(
                backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._nbow = NBOW(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(nbow, NBOW)
            self._nbow = nbow
        self._log.info("Loaded nBOW model: %s", self._nbow)
        self._repo2nbow = Repo2nBOW(self._id2vec, self._df,
                                    log_level=verbosity,
                                    **(repo2nbow_kwargs or {}))
        # The nBOW model must have been built with the same id2vec model.
        assert self._nbow.get_dependency(
            "id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._nbow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s"
                % (len(self._id2vec), self._nbow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings, self._nbow,
                        verbosity=verbosity, **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()

    def query(self, url_or_path_or_name, **kwargs):
        """Return ``(repository, distance)`` neighbours of the argument.

        The argument may be a repository name already in the nBOW model, a
        GitHub URL, or a path/URL to a repository to analyze from scratch.
        """
        try:
            repo_index = self._nbow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            # Try to normalize a GitHub URL to the stored repository name.
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._nbow.repository_index_by_name(name)
                except KeyError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        # Map nBOW indices back to human-readable repository names.
        neighbours = [(self._nbow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    @staticmethod
    def unicorn_query(repo_name, id2vec=None, nbow=None,
                      wmd_kwargs=None, query_wmd_kwargs=None):
        """One-shot convenience query with sensible default WMD limits."""
        sr = SimilarRepositories(id2vec=id2vec, df=False, nbow=nbow,
                                 wmd_kwargs=wmd_kwargs or {
                                     "vocabulary_min": 50,
                                     "vocabulary_max": 500
                                 })
        return sr.query(
            repo_name, **(query_wmd_kwargs or {
                "early_stop": 0.1,
                "max_time": 180,
                "skipped_stop": 0.95
            }))

    def _query_domestic(self, repo_index, **kwargs):
        # The repository is already in the loaded nBOW model: query by index.
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path, **kwargs):
        # The repository is unknown: convert it to an nBOW and query by it.
        if self._df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        nbow_dict = self._repo2nbow.convert_repository(url_or_path)
        words = sorted(nbow_dict.keys())
        weights = [nbow_dict[k] for k in words]
        return self._wmd.nearest_neighbors((words, weights), **kwargs)