Пример #1
0
class DocumentFrequenciesTests(unittest.TestCase):
    """Exercise the DocumentFrequencies model loaded from the test fixture."""

    def setUp(self):
        # The fixture path is resolved relative to this module's directory so
        # the tests do not depend on the current working directory.
        fixture = os.path.join(os.path.dirname(__file__), paths.DOCFREQ)
        self.model = DocumentFrequencies().load(source=fixture)

    def test_docs(self):
        """The `docs` attribute is a plain int equal to 1000."""
        total = self.model.docs
        self.assertIsInstance(total, int)
        self.assertEqual(total, 1000)

    def test_get(self):
        """Subscript access raises on unknown tokens, get() falls back."""
        known, unknown = "aaaaaaa", "xaaaaaa"
        self.assertEqual(self.model[known], 341)
        with self.assertRaises(KeyError):
            print(self.model[unknown])
        self.assertEqual(self.model.get(known, 0), 341)
        self.assertEqual(self.model.get(unknown, 100500), 100500)

    def test_tokens(self):
        """tokens() comes back sorted and every token has a positive value."""
        vocabulary = self.model.tokens()
        self.assertEqual(sorted(vocabulary), vocabulary)
        for token in vocabulary:
            self.assertGreater(self.model[token], 0)

    def test_len(self):
        """len() of the model equals the number of unique tokens."""
        # the remaining 18 are not unique - the model was generated badly
        expected_size = 982
        self.assertEqual(len(self.model), expected_size)

    def test_iter(self):
        """Iterating yields (token, frequency) pairs including "aaaaaaa"."""
        found = False
        for token, frequency in self.model:
            if "aaaaaaa" not in token:
                continue
            found = True
            # The frequency must be convertible to int; a failure raises here.
            int(frequency)
            break
        self.assertTrue(found)

    def test_prune(self):
        """prune(4) keeps only entries with frequency >= 4 (346 of them)."""
        threshold = 4
        reduced = self.model.prune(threshold)
        for _token, frequency in reduced:
            self.assertGreaterEqual(frequency, threshold)
        self.assertEqual(len(reduced), 346)
Пример #2
0
class TopicDetector:
    """Rank topics for a repository using a topics model and bag-of-words data.

    A repository is looked up in the optional BOW cache first; when it is not
    found there it is converted on the fly, which requires the document
    frequencies model.
    """

    # Group 2 is the canonical "github.com/<owner>/<repo>" name. The repository
    # segment is matched lazily and must be followed by an optional ".git" and
    # then "/" or the end of the string; otherwise a greedy segment would
    # swallow the ".git" suffix and it would never be stripped.
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github\.com/[^/]+/[^/]+?)(\.git|)"
        r"(?:/|$)")

    def __init__(self,
                 topics=None,
                 docfreq=None,
                 bow=None,
                 verbosity=logging.DEBUG,
                 prune_df_threshold=1,
                 gcs_bucket=None,
                 initialize_environment=True,
                 repo2bow_kwargs=None):
        """Load or attach the models needed to answer queries.

        :param topics: Topics instance, or None to load it from the backend.
        :param docfreq: DocumentFrequencies instance, None to load the one \
                        referenced by the topics model, or False to disable \
                        document frequencies (custom repositories can then \
                        not be queried).
        :param bow: BOWBase instance used as a cache of known repositories, \
                    or None to work without a cache.
        :param verbosity: logging level for this object and the loaded models.
        :param prune_df_threshold: value passed to DocumentFrequencies.prune().
        :param gcs_bucket: if set, the backend is created with \
                           "bucket=<gcs_bucket>".
        :param initialize_environment: whether to call initialize() first.
        :param repo2bow_kwargs: extra keyword arguments for Repo2BOW.
        :raises ValueError: if the topics and BOW models have different \
                            numbers of tokens.
        """
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("topic_detector")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()

        if topics is None:
            self._topics = Topics(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(topics, Topics)
            self._topics = topics
        self._log.info("Loaded topics model: %s", self._topics)

        # None means "load the default", False means "explicitly disabled".
        # These must be tested separately: nesting the False check inside the
        # None check makes the disabled branch unreachable.
        if docfreq is None:
            self._docfreq = DocumentFrequencies(log_level=verbosity).load(
                source=self._topics.dep("docfreq")["uuid"],
                backend=backend)
        elif docfreq is False:
            self._docfreq = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        else:
            assert isinstance(docfreq, DocumentFrequencies)
            self._docfreq = docfreq
        if self._docfreq is not None:
            self._docfreq = self._docfreq.prune(prune_df_threshold)
        self._log.info("Loaded docfreq model: %s", self._docfreq)

        if bow is not None:
            assert isinstance(bow, BOWBase)
            self._bow = bow
            if self._topics.matrix.shape[1] != self._bow.matrix.shape[1]:
                raise ValueError(
                    "Models do not match: topics has %s tokens while bow has %s"
                    %
                    (self._topics.matrix.shape[1], self._bow.matrix.shape[1]))
            self._log.info("Attached BOW model: %s", self._bow)
        else:
            self._bow = None
            self._log.warning("No BOW cache was loaded.")

        if self._docfreq is not None:
            token_index = {
                token: index
                for index, token in enumerate(self._topics.tokens)
            }
            self._repo2bow = Repo2BOW(token_index, self._docfreq,
                                      **(repo2bow_kwargs or {}))
        else:
            self._repo2bow = None

    def _find_repository_index(self, url_or_path_or_name):
        """Return the BOW cache index of the repository, or -1 if unknown."""
        if self._bow is None:
            return -1
        try:
            repo_index = self._bow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            # Retry with the canonical name extracted from a GitHub URL.
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                try:
                    repo_index = self._bow.repository_index_by_name(
                        match.group(2))
                except KeyError:
                    pass
        return repo_index

    def query(self, url_or_path_or_name, size=5):
        """Return the best matching topics for a repository.

        :param url_or_path_or_name: name known to the BOW cache, a GitHub URL, \
                                    or anything Repo2BOW.convert_repository() \
                                    accepts.
        :param size: maximum number of topics to return.
        :return: list of (topic, score) tuples ordered by decreasing score; \
                 topics with an empty name are skipped.
        :raises ValueError: if size exceeds the number of topics, or if the \
                            repository is not cached and document frequencies \
                            are disabled.
        """
        n_topics = len(self._topics)
        if size > n_topics:
            raise ValueError(
                "size may not be greater than the number of topics - %d" %
                n_topics)
        repo_index = self._find_repository_index(url_or_path_or_name)
        if repo_index >= 0:
            token_vector = self._bow.matrix[repo_index]
        else:
            if self._docfreq is None:
                raise ValueError(
                    "You need to specify document frequencies model to process "
                    "custom repositories")
            bow_dict = self._repo2bow.convert_repository(url_or_path_or_name)
            token_vector = numpy.zeros(self._topics.matrix.shape[1],
                                       dtype=numpy.float32)
            for i, v in bow_dict.items():
                token_vector[i] = v
            token_vector = csr_matrix(token_vector)
        # Scores are negated so that an ascending argsort yields the highest
        # scoring topics first. ravel() keeps the result one-dimensional even
        # when the model holds a single topic (squeeze would yield a 0-d array
        # that cannot be indexed below).
        topic_vector = -self._topics.matrix.dot(
            token_vector.T).toarray().ravel()
        order = numpy.argsort(topic_vector)
        result = []
        for index in order[:n_topics]:
            if len(result) >= size:
                break
            topic = self._topics.topics[index]
            if topic:
                result.append((topic, -topic_vector[index]))
        return result
Пример #3
0
class SimilarRepositories:
    """Find repositories similar to a given one through a WMD engine.

    Repositories present in the nBOW model are queried by index; any other
    repository is converted on the fly, which requires document frequencies.
    """

    # Group 2 is the canonical "github.com/<owner>/<repo>" name. The repository
    # segment is matched lazily and must be followed by an optional ".git" and
    # then "/" or the end of the string; otherwise a greedy segment would
    # swallow the ".git" suffix and it would never be stripped.
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github\.com/[^/]+/[^/]+?)(\.git|)"
        r"(?:/|$)")

    def __init__(self,
                 id2vec=None,
                 df=None,
                 nbow=None,
                 prune_df_threshold=1,
                 verbosity=logging.DEBUG,
                 wmd_cache_centroids=True,
                 wmd_kwargs=None,
                 gcs_bucket=None,
                 repo2nbow_kwargs=None,
                 initialize_environment=True):
        """Load or attach the models and build the WMD engine.

        :param id2vec: Id2Vec instance, or None to load it from the backend.
        :param df: DocumentFrequencies instance, None to load it from the \
                   backend, or False to disable document frequencies (custom \
                   repositories can then not be queried).
        :param nbow: NBOW instance, or None to load it from the backend.
        :param prune_df_threshold: value passed to DocumentFrequencies.prune().
        :param verbosity: logging level for this object and its components.
        :param wmd_cache_centroids: whether to call WMD.cache_centroids().
        :param wmd_kwargs: extra keyword arguments for WMD.
        :param gcs_bucket: if set, the backend is created with \
                           "bucket=<gcs_bucket>".
        :param repo2nbow_kwargs: extra keyword arguments for Repo2nBOW.
        :param initialize_environment: whether to call initialize() first.
        :raises ValueError: if id2vec and nbow have different numbers of \
                            tokens.
        """
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("similar_repos")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()

        if id2vec is None:
            self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)

        # None means "load the default", False means "explicitly disabled".
        # These must be tested separately: nesting the False check inside the
        # None check makes the disabled branch unreachable, and df=False (as
        # passed by unicorn_query) would trip the isinstance assertion.
        if df is None:
            self._df = DocumentFrequencies(log_level=verbosity).load(
                backend=backend)
        elif df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)

        if nbow is None:
            self._nbow = NBOW(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(nbow, NBOW)
            self._nbow = nbow
        self._log.info("Loaded nBOW model: %s", self._nbow)

        # The converter is only needed for repositories outside the nBOW model,
        # and _query_foreign() refuses to run without document frequencies, so
        # it is not built when they are disabled.
        if self._df is not None:
            self._repo2nbow = Repo2nBOW(self._id2vec,
                                        self._df,
                                        log_level=verbosity,
                                        **(repo2nbow_kwargs or {}))
        else:
            self._repo2nbow = None

        assert self._nbow.get_dependency(
            "id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._nbow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s" %
                (len(self._id2vec), self._nbow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings,
                        self._nbow,
                        verbosity=verbosity,
                        **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()

    def _find_repository_index(self, url_or_path_or_name):
        """Return the nBOW index of the repository, or -1 if it is unknown."""
        try:
            repo_index = self._nbow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            # Retry with the canonical name extracted from a GitHub URL.
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                try:
                    repo_index = self._nbow.repository_index_by_name(
                        match.group(2))
                except KeyError:
                    pass
        return repo_index

    def query(self, url_or_path_or_name, **kwargs):
        """Return the nearest neighbours of a repository.

        :param url_or_path_or_name: name known to the nBOW model, a GitHub \
                                    URL, or anything \
                                    Repo2nBOW.convert_repository() accepts.
        :param kwargs: forwarded to WMD.nearest_neighbors().
        :return: list of (name, value) tuples, where name is the first item \
                 of the nBOW entry for each neighbour and value is what the \
                 WMD engine reported for it.
        :raises ValueError: if the repository is not in the nBOW model and \
                            document frequencies are disabled.
        """
        repo_index = self._find_repository_index(url_or_path_or_name)
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        # Replace each neighbour index with the first item of its nBOW entry.
        return [(self._nbow[n[0]][0], n[1]) for n in neighbours]

    @staticmethod
    def unicorn_query(repo_name,
                      id2vec=None,
                      nbow=None,
                      wmd_kwargs=None,
                      query_wmd_kwargs=None):
        """Build a SimilarRepositories without document frequencies and query.

        :param repo_name: repository to look up; passed to query().
        :param id2vec: forwarded to the constructor.
        :param nbow: forwarded to the constructor.
        :param wmd_kwargs: WMD keyword arguments; when falsy, \
                           vocabulary_min=50 and vocabulary_max=500 are used.
        :param query_wmd_kwargs: keyword arguments for query(); when falsy, \
                                 early_stop=0.1, max_time=180 and \
                                 skipped_stop=0.95 are used.
        :return: the result of query().
        """
        sr = SimilarRepositories(id2vec=id2vec,
                                 df=False,
                                 nbow=nbow,
                                 wmd_kwargs=wmd_kwargs or {
                                     "vocabulary_min": 50,
                                     "vocabulary_max": 500
                                 })
        return sr.query(
            repo_name,
            **(query_wmd_kwargs or {
                "early_stop": 0.1,
                "max_time": 180,
                "skipped_stop": 0.95
            }))

    def _query_domestic(self, repo_index, **kwargs):
        """Query the WMD engine for a repository present in the nBOW model."""
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path, **kwargs):
        """Convert an unknown repository to nBOW and query the WMD engine.

        :raises ValueError: if document frequencies are disabled.
        """
        if self._df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        nbow_dict = self._repo2nbow.convert_repository(url_or_path)
        # Sort the keys so that words and weights are aligned deterministically.
        words = sorted(nbow_dict.keys())
        weights = [nbow_dict[k] for k in words]
        return self._wmd.nearest_neighbors((words, weights), **kwargs)