import os
import unittest

# NOTE: the import paths below are assumed (the originals are not shown);
# adjust them to the actual project layout.
from ast2vec.df import DocumentFrequencies
from ast2vec.tests import paths


class DocumentFrequenciesTests(unittest.TestCase):
    def setUp(self):
        self.model = DocumentFrequencies().load(
            source=os.path.join(os.path.dirname(__file__), paths.DOCFREQ))

    def test_docs(self):
        docs = self.model.docs
        self.assertIsInstance(docs, int)
        self.assertEqual(docs, 1000)

    def test_get(self):
        self.assertEqual(self.model["aaaaaaa"], 341)
        with self.assertRaises(KeyError):
            print(self.model["xaaaaaa"])
        self.assertEqual(self.model.get("aaaaaaa", 0), 341)
        self.assertEqual(self.model.get("xaaaaaa", 100500), 100500)

    def test_tokens(self):
        tokens = self.model.tokens()
        self.assertEqual(sorted(tokens), tokens)
        for t in tokens:
            self.assertGreater(self.model[t], 0)

    def test_len(self):
        # Only 982 of the 1000 entries are unique - the remaining 18 are
        # duplicates (the test model was generated badly).
        self.assertEqual(len(self.model), 982)

    def test_iter(self):
        aaa = False
        for tok, freq in self.model:
            if "aaaaaaa" in tok:
                aaa = True
                int(freq)  # the frequency must be convertible to int
                break
        self.assertTrue(aaa)

    def test_prune(self):
        pruned = self.model.prune(4)
        for tok, freq in pruned:
            self.assertGreaterEqual(freq, 4)
        self.assertEqual(len(pruned), 346)
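
# Conventional unittest entry point so the suite above can be executed
# directly; a minimal runner with no assumptions beyond the standard library.
if __name__ == "__main__":
    unittest.main()
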
import logging
import re

import numpy
from scipy.sparse import csr_matrix

# NOTE: the model/helper import paths below are assumed (the originals are
# not shown); adjust them to the actual project layout.
from ast2vec.bow import BOWBase
from ast2vec.df import DocumentFrequencies
from ast2vec.environment import initialize
from ast2vec.repo2.bow import Repo2BOW
from ast2vec.topics import Topics
from modelforge.backends import create_backend


class TopicDetector:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

    def __init__(self, topics=None, docfreq=None, bow=None,
                 verbosity=logging.DEBUG, prune_df_threshold=1,
                 gcs_bucket=None, initialize_environment=True,
                 repo2bow_kwargs=None):
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("topic_detector")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if topics is None:
            self._topics = Topics(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(topics, Topics)
            self._topics = topics
        self._log.info("Loaded topics model: %s", self._topics)
        # docfreq may be None (load the model referenced by the topics
        # model), False (disable custom repository queries) or a ready
        # DocumentFrequencies instance.
        if docfreq is False:
            self._docfreq = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif docfreq is None:
            self._docfreq = DocumentFrequencies(log_level=verbosity).load(
                source=self._topics.dep("docfreq")["uuid"], backend=backend)
        else:
            assert isinstance(docfreq, DocumentFrequencies)
            self._docfreq = docfreq
        if self._docfreq is not None:
            self._docfreq = self._docfreq.prune(prune_df_threshold)
            self._log.info("Loaded docfreq model: %s", self._docfreq)
        if bow is not None:
            assert isinstance(bow, BOWBase)
            self._bow = bow
            if self._topics.matrix.shape[1] != self._bow.matrix.shape[1]:
                raise ValueError(
                    "Models do not match: topics has %s tokens while bow "
                    "has %s" % (self._topics.matrix.shape[1],
                                self._bow.matrix.shape[1]))
            self._log.info("Attached BOW model: %s", self._bow)
        else:
            self._bow = None
            self._log.warning("No BOW cache was loaded.")
        if self._docfreq is not None:
            self._repo2bow = Repo2BOW(
                {t: i for i, t in enumerate(self._topics.tokens)},
                self._docfreq, **(repo2bow_kwargs or {}))
        else:
            self._repo2bow = None

    def query(self, url_or_path_or_name, size=5):
        if size > len(self._topics):
            raise ValueError(
                "size may not be greater than the number of topics - %d"
                % len(self._topics))
        if self._bow is not None:
            # Try to find the repository in the BOW cache, first by the raw
            # query string, then by the GitHub name extracted from it.
            try:
                repo_index = self._bow.repository_index_by_name(
                    url_or_path_or_name)
            except KeyError:
                repo_index = -1
            if repo_index == -1:
                match = self.GITHUB_URL_RE.match(url_or_path_or_name)
                if match is not None:
                    name = match.group(2)
                    try:
                        repo_index = self._bow.repository_index_by_name(name)
                    except KeyError:
                        pass
        else:
            repo_index = -1
        if repo_index >= 0:
            token_vector = self._bow.matrix[repo_index]
        else:
            if self._docfreq is None:
                raise ValueError(
                    "You need to specify a document frequencies model to "
                    "process custom repositories")
            bow_dict = self._repo2bow.convert_repository(url_or_path_or_name)
            token_vector = numpy.zeros(self._topics.matrix.shape[1],
                                       dtype=numpy.float32)
            for i, v in bow_dict.items():
                token_vector[i] = v
            token_vector = csr_matrix(token_vector)
        # Negate so that argsort() yields the most relevant topics first.
        topic_vector = -numpy.squeeze(
            self._topics.matrix.dot(token_vector.T).toarray())
        order = numpy.argsort(topic_vector)
        result = []
        i = 0
        while len(result) < size and i < len(self._topics):
            topic = self._topics.topics[order[i]]
            if topic:
                result.append((topic, -topic_vector[order[i]]))
            i += 1
        return result
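
# A minimal usage sketch for TopicDetector. The repository URL is a made-up
# placeholder; with the default constructor arguments the models are fetched
# through the configured backend, which may take a while on the first run.
def _topic_detector_example():
    detector = TopicDetector()
    # query() returns up to `size` (topic, relevance) pairs, most relevant
    # topic first.
    for topic, relevance in detector.query(
            "https://github.com/user/repo", size=5):
        print("%.4f\t%s" % (relevance, topic))
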
import logging
import re

from wmd import WMD

# NOTE: the model/helper import paths below are assumed (the originals are
# not shown); adjust them to the actual project layout.
from ast2vec.df import DocumentFrequencies
from ast2vec.environment import initialize
from ast2vec.id2vec import Id2Vec
from ast2vec.nbow import NBOW
from ast2vec.repo2.nbow import Repo2nBOW
from modelforge.backends import create_backend


class SimilarRepositories:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

    def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
                 verbosity=logging.DEBUG, wmd_cache_centroids=True,
                 wmd_kwargs=None, gcs_bucket=None, repo2nbow_kwargs=None,
                 initialize_environment=True):
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("similar_repos")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        # df may be None (load the default model), False (disable custom
        # repository queries) or a ready DocumentFrequencies instance.
        if df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif df is None:
            self._df = DocumentFrequencies(log_level=verbosity).load(
                backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
            self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._nbow = NBOW(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(nbow, NBOW)
            self._nbow = nbow
        self._log.info("Loaded nBOW model: %s", self._nbow)
        self._repo2nbow = Repo2nBOW(
            self._id2vec, self._df, log_level=verbosity,
            **(repo2nbow_kwargs or {}))
        assert self._nbow.get_dependency(
            "id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._nbow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s"
                % (len(self._id2vec), self._nbow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings, self._nbow,
                        verbosity=verbosity, **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()

    def query(self, url_or_path_or_name, **kwargs):
        try:
            repo_index = self._nbow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            # Fall back to the GitHub repository name extracted from the URL.
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._nbow.repository_index_by_name(name)
                except KeyError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        return [(self._nbow[n[0]][0], n[1]) for n in neighbours]

    @staticmethod
    def unicorn_query(repo_name, id2vec=None, nbow=None, wmd_kwargs=None,
                      query_wmd_kwargs=None):
        sr = SimilarRepositories(
            id2vec=id2vec, df=False, nbow=nbow,
            wmd_kwargs=wmd_kwargs or {"vocabulary_min": 50,
                                      "vocabulary_max": 500})
        return sr.query(
            repo_name,
            **(query_wmd_kwargs or {"early_stop": 0.1, "max_time": 180,
                                    "skipped_stop": 0.95}))

    def _query_domestic(self, repo_index, **kwargs):
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path, **kwargs):
        if self._df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        nbow_dict = self._repo2nbow.convert_repository(url_or_path)
        words = sorted(nbow_dict.keys())
        weights = [nbow_dict[k] for k in words]
        return self._wmd.nearest_neighbors((words, weights), **kwargs)
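
# A minimal usage sketch for SimilarRepositories. The repository name is a
# made-up placeholder; since unicorn_query() disables document frequencies
# (df=False), it only works for repositories already present in the nBOW
# model.
def _similar_repositories_example():
    neighbours = SimilarRepositories.unicorn_query("github.com/user/repo")
    for name, score in neighbours:
        print("%.3f\t%s" % (score, name))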