示例#1
0
文件: test_df.py 项目: sniperkit/ml
class DocumentFrequenciesTests(unittest.TestCase):
    # Exercises a DocumentFrequencies model loaded from the paths.DOCFREQ
    # fixture: metadata, lookup, iteration, pruning and save/load round trip.

    def setUp(self):
        # Every test starts from a freshly loaded model.
        self.model = DocumentFrequencies().load(source=paths.DOCFREQ)

    def test_docs(self):
        # The fixture reports its document count as the int 1000.
        n_docs = self.model.docs
        self.assertIsInstance(n_docs, int)
        self.assertEqual(n_docs, 1000)

    def test_get(self):
        model = self.model
        # Subscript access: known token returns its frequency, unknown raises.
        self.assertEqual(model["aaaaaaa"], 341)
        with self.assertRaises(KeyError):
            print(model["xaaaaaa"])
        # get(): same lookup, but falls back to the supplied default.
        self.assertEqual(model.get("aaaaaaa", 0), 341)
        self.assertEqual(model.get("xaaaaaa", 100500), 100500)

    def test_tokens(self):
        # tokens() must equal the keys of the underlying mapping, in order.
        expected = list(self.model._df)
        self.assertEqual(expected, self.model.tokens())

    def test_len(self):
        # the remaining 18 are not unique - the model was generated badly
        size = len(self.model)
        self.assertEqual(size, 982)

    def test_iter(self):
        # Iterating the model yields (token, frequency) pairs; stop at the
        # first token containing "aaaaaaa".
        hit = next(
            ((tok, freq) for tok, freq in self.model if "aaaaaaa" in tok),
            None)
        found = hit is not None
        if found:
            # The frequency has to be convertible to int.
            int(hit[1])
        self.assertTrue(found)

    def test_prune(self):
        threshold = 4
        survivors = self.model.prune(threshold)
        for _, freq in survivors:
            self.assertGreaterEqual(freq, threshold)
        self.assertEqual(len(survivors), 346)

    def test_prune_self(self):
        # Pruning with threshold 1 hands back the very same object.
        self.assertIs(self.model, self.model.prune(1))

    def test_greatest(self):
        top = self.model.greatest(100)
        # Frequency found at index 100 of the descending ordering; every
        # retained entry must be at least that large.
        border = sorted(self.model._df.values(), reverse=True)[100]
        for freq in top._df.values():
            self.assertGreaterEqual(freq, border)
        # A second call must produce the same selection.
        first = top._df
        second = self.model.greatest(100)._df
        self.assertEqual(first, second)

    def test_write(self):
        # Save into an in-memory buffer and load it back.
        stream = BytesIO()
        self.model.save(stream)
        stream.seek(0)
        restored = DocumentFrequencies().load(stream)
        self.assertEqual(self.model._df, restored._df)
        self.assertEqual(self.model.docs, restored.docs)
示例#2
0
class SimilarRepositories:
    """
    Looks up the nearest neighbours of a repository with a WMD engine built
    from an id2vec model and a BOW model.

    Repositories which are already present in the BOW model are queried
    by their index; any other URL or local path is cloned/symlinked into a
    temporary directory, converted to a bag-of-words with `repo2bow` and
    queried as a custom document.
    """

    # Groups: (1) URL scheme, (2) "github.com/<owner>/<repo>" with an optional
    # ".git" suffix stripped, (3) the "/" that follows the name, or "".
    # The repository part is lazy so that the ".git" suffix is not swallowed
    # into group 2; names such as "x.github.io" are still matched whole.
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)"
        r"(github\.com/[^/]+/[^/]+?)(?:\.git)?(/|$)")
    _log = logging.getLogger("SimilarRepositories")

    def __init__(self,
                 id2vec=None,
                 df=None,
                 nbow=None,
                 prune_df_threshold=1,
                 wmd_cache_centroids=True,
                 wmd_kwargs: Dict[str, Any] = None,
                 languages: Tuple[List, bool] = (None, False),
                 engine_kwargs: Dict[str, Any] = None):
        """
        Load (or accept) the models and create the WMD engine.

        :param id2vec: `Id2Vec` instance; None loads it through the backend.
        :param df: `DocumentFrequencies` instance; None loads it through the \
                   backend; False disables document frequencies, in which \
                   case only repositories present in the BOW model can be \
                   queried.
        :param nbow: `BOW` instance; None loads it through the backend.
        :param prune_df_threshold: value passed to `DocumentFrequencies.prune()`.
        :param wmd_cache_centroids: whether to call `WMD.cache_centroids()`.
        :param wmd_kwargs: extra keyword arguments for the `WMD` constructor.
        :param languages: pair which is unpacked into the positional \
                          arguments of `repo2bow` after the document \
                          frequencies.
        :param engine_kwargs: forwarded to `repo2bow` as `engine_kwargs`.
        :raises ValueError: if the id2vec size does not match the number of \
                            BOW matrix columns.
        :raises AssertionError: if a supplied model has the wrong type or the \
                                BOW model depends on a different id2vec.
        """
        backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec().load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        # `False` must be tested before `None`: nesting the `is not False`
        # check under `df is None` would make the "disabled" branch
        # unreachable.
        if df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif df is None:
            self._df = DocumentFrequencies().load(backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._bow = BOW().load(backend=backend)
        else:
            assert isinstance(nbow, BOW)
            self._bow = nbow
        self._log.info("Loaded BOW model: %s", self._bow)
        # The BOW model must have been built on top of this very id2vec.
        assert self._bow.get_dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._bow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s" %
                (len(self._id2vec), self._bow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings, self._bow,
                        **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()
        self._languages = languages
        self._engine_kwargs = engine_kwargs

    def query(self, url_or_path_or_name: str,
              **kwargs) -> List[Tuple[str, float]]:
        """
        Find the nearest neighbours of a repository.

        The argument is first looked up verbatim among the BOW documents,
        then - if it is a GitHub URL - by its "github.com/<owner>/<repo>"
        part. When neither is found it is treated as a custom repository.

        :param url_or_path_or_name: document name, repository URL or local path.
        :param kwargs: forwarded to `WMD.nearest_neighbors()`.
        :return: list of pairs; the first element is item 0 of the BOW entry \
                 of the neighbour, the second is the value reported by the \
                 WMD engine.
        :raises ValueError: if a custom repository is queried while the \
                            document frequencies are disabled.
        """
        try:
            repo_index = self._bow.documents.index(url_or_path_or_name)
        except ValueError:
            repo_index = -1
        if repo_index == -1:
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._bow.documents.index(name)
                except ValueError:
                    # Unknown GitHub repository: handled as a custom one below.
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        return [(self._bow[n[0]][0], n[1]) for n in neighbours]

    def _query_domestic(self, repo_index, **kwargs):
        """Query the WMD engine for a document already present in the BOW model."""
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path: str, **kwargs):
        """
        Query the WMD engine for a repository absent from the BOW model.

        A local directory is symlinked, anything else is cloned as a bare
        repository; both land in a temporary directory which is removed
        once the bag-of-words has been extracted.

        :param url_or_path: local directory or URL understood by \
                            `porcelain.clone()`.
        :param kwargs: forwarded to `WMD.nearest_neighbors()`.
        :raises ValueError: if the document frequencies are disabled or none \
                            of the extracted tokens exists in id2vec.
        """
        df = self._df
        if df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")

        with tempfile.TemporaryDirectory(prefix="vecino-") as tempdir:
            target = os.path.join(tempdir, "repo")
            if os.path.isdir(url_or_path):
                url_or_path = os.path.abspath(url_or_path)
                os.symlink(url_or_path, target, target_is_directory=True)
                repo_format = "standard"
            else:
                self._log.info("Cloning %s to %s", url_or_path, target)
                porcelain.clone(url_or_path,
                                target,
                                bare=True,
                                outstream=sys.stderr)
                repo_format = "bare"
            # NOTE(review): the meaning of the positional `1` is defined by
            # repo2bow's signature, which is not visible here - confirm.
            bow = repo2bow(tempdir,
                           repo_format,
                           1,
                           df,
                           *self._languages,
                           engine_kwargs=self._engine_kwargs)
        # Re-key the bag by id2vec indices, dropping tokens id2vec lacks.
        ibow = {}
        for key, val in bow.items():
            try:
                ibow[self._id2vec[key]] = val
            except KeyError:
                continue
        if not ibow:
            # zip(*[]) yields nothing and the unpacking below would fail with
            # an unhelpful "not enough values to unpack" ValueError.
            raise ValueError(
                "None of the tokens extracted from %s are present in the "
                "id2vec model." % url_or_path)
        words, weights = zip(*sorted(ibow.items()))
        return self._wmd.nearest_neighbors((words, weights), **kwargs)