Example #1
    def compute_q(self, f_df, q_df, return_f_nbow=False):
        logger.info('Computing question wmds')
        f_nbow = {
            row.Index: self.nbowify(row.Index, row.original)
            for row in f_df.itertuples()
        }
        nb_facts = len(f_nbow)
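        # Offset question ids by the number of facts so that facts and
        # questions live in one shared WMD index space.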
        q_nbow = {
            row.Index + nb_facts: self.nbowify(row.Index + nb_facts,
                                               row.original)
            for row in q_df.itertuples()
        }

        merged_fnbow = copy.copy(f_nbow)
        merged_fnbow.update(q_nbow)
        q_calc = WMD(SpacyEmbeddings(self.nlp),
                     merged_fnbow,
                     vocabulary_min=1,
                     verbosity=logging.WARNING)
        q_calc.cache_centroids()
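        # For each question, keep only the neighbours that are facts
        # (ids below nb_facts), i.e. drop other questions from the result.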
        q_closest = pd.Series(
            np.array([
                i for i, _ in q_calc.nearest_neighbors(
                    idx, k=self.config.nearest_k_visible) if i < nb_facts
            ]) for idx in tqdm(q_nbow.keys(), desc='Question wmd...'))
        return (q_closest, f_nbow) if return_f_nbow else q_closest
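
compute_q hands the bags built by nbowify() to wmd.WMD and then asks for nearest neighbours. Below is a minimal, self-contained sketch of that underlying API, assuming these snippets use the wmd-relax package (from wmd import WMD); the embeddings and bags are toy values invented purely for illustration:

import numpy as np
from wmd import WMD

# Toy embeddings for a 3-token vocabulary (illustrative values only).
embeddings = np.array([[0.1, 1.0],
                       [1.0, 0.1],
                       [0.9, 0.2]], dtype=np.float32)
# nBOW bags: id -> (human readable name, token ids, token weights).
nbow = {
    0: ("fact 0", [0, 1], np.array([1.5, 0.5], dtype=np.float32)),
    1: ("fact 1", [1, 2], np.array([1.0, 1.0], dtype=np.float32)),
    2: ("question", [0, 2], np.array([0.7, 0.3], dtype=np.float32)),
}
calc = WMD(embeddings, nbow, vocabulary_min=1)
calc.cache_centroids()
# Nearest bags to the "question" bag, returned as (id, score) pairs.
print(calc.nearest_neighbors(2, k=2))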
Example #2
    def compute_f(self, f_df, f_nbow=None):
        logger.info('Computing fact wmds')
        if f_nbow is None:
            f_nbow = {
                row.Index: self.nbowify(row.Index, row.original)
                for row in f_df.itertuples()
            }

        f_calc = WMD(SpacyEmbeddings(self.nlp),
                     f_nbow,
                     vocabulary_min=1,
                     verbosity=logging.WARNING)
        f_calc.cache_centroids()
        f_closest = pd.Series(
            np.array([
                i for i, _ in f_calc.nearest_neighbors(
                    idx, k=self.config.nearest_k_visible)
            ]) for idx in tqdm(f_nbow.keys(), desc='Fact wmd...'))
        return f_closest
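
For context, a hypothetical call site for the two methods above; the retriever instance, its spaCy pipeline (self.nlp), nbowify() and config.nearest_k_visible are assumed to be set up elsewhere, and both frames only need an 'original' text column:

import pandas as pd

f_df = pd.DataFrame({'original': ['gravity pulls objects toward the earth',
                                  'plants need sunlight to grow']})
q_df = pd.DataFrame({'original': ['why does a dropped ball fall?']})

q_closest, f_nbow = retriever.compute_q(f_df, q_df, return_f_nbow=True)
f_closest = retriever.compute_f(f_df, f_nbow=f_nbow)  # reuse the fact bags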
Example #3
    def retrieve(self, top_id: str, k=None, only=None):
        assert only, 'not searching anything'
        index = self.db.mapping
        delta = common.timer()

        def to_nbow(doc_id):
            # transform to the nbow model used by wmd.WMD:
            # ('human readable name', 'item identifiers', 'weights')
            doc = index[doc_id]
            return (doc_id, doc.idx, doc.freq)

        docs = {d: to_nbow(d) for d in only + [top_id]}
        calc = WMDR(self.emb, docs, vocabulary_min=2)

        calc.cache_centroids()
        nn = calc.nearest_neighbors(top_id, k=k)

        self._times.append(delta())

        assert len(nn) == k, f'{len(nn)} not {k}'
        return [Result(*n) for n in nn]
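
A hedged usage sketch for retrieve(); the retriever object, its db.mapping index, the embeddings in self.emb and the Result type are assumed to exist elsewhere, and the document ids below are invented:

candidates = ['doc-17', 'doc-42', 'doc-99', 'doc-7', 'doc-3']
results = retriever.retrieve('doc-0', k=5, only=candidates)
for result in results:
    print(result)  # each Result wraps one neighbour tuple from nearest_neighbors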
Example #4
class SimilarRepositories:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")
    _log = logging.getLogger("SimilarRepositories")

    def __init__(self,
                 id2vec=None,
                 df=None,
                 nbow=None,
                 prune_df_threshold=1,
                 wmd_cache_centroids=True,
                 wmd_kwargs: Dict[str, Any] = None,
                 languages: Tuple[List, bool] = (None, False),
                 engine_kwargs: Dict[str, Any] = None):
        backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec().load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        if df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif df is None:
            self._df = DocumentFrequencies().load(backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._bow = BOW().load(backend=backend)
        else:
            assert isinstance(nbow, BOW)
            self._bow = nbow
        self._log.info("Loaded BOW model: %s", self._bow)
        assert self._bow.get_dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._bow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s" %
                (len(self._id2vec), self._bow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings, self._bow,
                        **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()
        self._languages = languages
        self._engine_kwargs = engine_kwargs

    def query(self, url_or_path_or_name: str,
              **kwargs) -> List[Tuple[str, float]]:
        try:
            repo_index = self._bow.documents.index(url_or_path_or_name)
        except ValueError:
            repo_index = -1
        if repo_index == -1:
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._bow.documents.index(name)
                except ValueError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        neighbours = [(self._bow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    def _query_domestic(self, repo_index, **kwargs):
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path: str, **kwargs):
        df = self._df
        if df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")

        with tempfile.TemporaryDirectory(prefix="vecino-") as tempdir:
            target = os.path.join(tempdir, "repo")
            if os.path.isdir(url_or_path):
                url_or_path = os.path.abspath(url_or_path)
                os.symlink(url_or_path, target, target_is_directory=True)
                repo_format = "standard"
            else:
                self._log.info("Cloning %s to %s", url_or_path, target)
                porcelain.clone(url_or_path,
                                target,
                                bare=True,
                                outstream=sys.stderr)
                repo_format = "bare"
            bow = repo2bow(tempdir,
                           repo_format,
                           1,
                           df,
                           *self._languages,
                           engine_kwargs=self._engine_kwargs)
        ibow = {}
        for key, val in bow.items():
            try:
                ibow[self._id2vec[key]] = val
            except KeyError:
                continue
        words, weights = zip(*sorted(ibow.items()))
        return self._wmd.nearest_neighbors((words, weights), **kwargs)
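
Typical usage, assuming the default models can be fetched by the configured backend; the repository URL below is only an example:

sr = SimilarRepositories()
# Query a repository that is already present in the BOW model (domestic path)...
for name, distance in sr.query("https://github.com/src-d/ml", k=10):
    print(name, distance)
# ...or point it at a local clone / unknown URL to exercise _query_foreign().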
Example #5
class SimilarRepositories:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

    def __init__(self,
                 id2vec=None,
                 df=None,
                 nbow=None,
                 prune_df_threshold=1,
                 verbosity=logging.DEBUG,
                 wmd_cache_centroids=True,
                 wmd_kwargs=None,
                 gcs_bucket=None,
                 repo2nbow_kwargs=None,
                 initialize_environment=True):
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("similar_repos")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        if df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif df is None:
            self._df = DocumentFrequencies(log_level=verbosity).load(
                backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._nbow = NBOW(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(nbow, NBOW)
            self._nbow = nbow
        self._log.info("Loaded nBOW model: %s", self._nbow)
        self._repo2nbow = Repo2nBOW(self._id2vec,
                                    self._df,
                                    log_level=verbosity,
                                    **(repo2nbow_kwargs or {}))
        assert self._nbow.get_dependency(
            "id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._nbow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s" %
                (len(self._id2vec), self._nbow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings,
                        self._nbow,
                        verbosity=verbosity,
                        **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()

    def query(self, url_or_path_or_name, **kwargs):
        try:
            repo_index = self._nbow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._nbow.repository_index_by_name(name)
                except KeyError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        neighbours = [(self._nbow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    @staticmethod
    def unicorn_query(repo_name,
                      id2vec=None,
                      nbow=None,
                      wmd_kwargs=None,
                      query_wmd_kwargs=None):
        sr = SimilarRepositories(id2vec=id2vec,
                                 df=False,
                                 nbow=nbow,
                                 wmd_kwargs=wmd_kwargs or {
                                     "vocabulary_min": 50,
                                     "vocabulary_max": 500
                                 })
        return sr.query(
            repo_name,
            **(query_wmd_kwargs or {
                "early_stop": 0.1,
                "max_time": 180,
                "skipped_stop": 0.95
            }))

    def _query_domestic(self, repo_index, **kwargs):
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path, **kwargs):
        if self._df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        nbow_dict = self._repo2nbow.convert_repository(url_or_path)
        words = sorted(nbow_dict.keys())
        weights = [nbow_dict[k] for k in words]
        return self._wmd.nearest_neighbors((words, weights), **kwargs)
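
The static helper above bundles vocabulary and early-stopping defaults; a minimal sketch with an illustrative repository name:

# One-shot query using the built-in wmd_kwargs and query defaults.
neighbours = SimilarRepositories.unicorn_query("github.com/apache/spark")
for name, distance in neighbours:
    print(name, distance)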