Example #1
class DocumentFrequenciesTests(unittest.TestCase):
    def setUp(self):
        self.model = DocumentFrequencies().load(
            source=os.path.join(os.path.dirname(__file__), paths.DOCFREQ))

    def test_docs(self):
        docs = self.model.docs
        self.assertIsInstance(docs, int)
        self.assertEqual(docs, 1000)

    def test_get(self):
        self.assertEqual(self.model["aaaaaaa"], 341)
        with self.assertRaises(KeyError):
            print(self.model["xaaaaaa"])
        self.assertEqual(self.model.get("aaaaaaa", 0), 341)
        self.assertEqual(self.model.get("xaaaaaa", 100500), 100500)

    def test_tokens(self):
        tokens = self.model.tokens()
        self.assertEqual(sorted(tokens), tokens)
        for t in tokens:
            self.assertGreater(self.model[t], 0)

    def test_len(self):
        # the remaining 18 are not unique - the model was generated badly
        self.assertEqual(len(self.model), 982)

    def test_iter(self):
        aaa = False
        for tok, freq in self.model:
            if "aaaaaaa" in tok:
                aaa = True
                int(freq)
                break
        self.assertTrue(aaa)
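A minimal standalone sketch of the same DocumentFrequencies API that the test above exercises, assuming a local docfreq.asdf model file (the path and the import location are assumptions, not confirmed by the examples):

import os

from ast2vec import DocumentFrequencies  # import path is an assumption

# Load a document-frequencies model from a local file (hypothetical path).
model = DocumentFrequencies().load(source=os.path.join("models", "docfreq.asdf"))

print(model.docs)               # number of documents the model was built from
print(model.get("aaaaaaa", 0))  # frequency of a token, 0 if the token is unknown
print(len(model))               # number of unique tokens

# prune() is assumed to drop tokens seen in fewer documents than the threshold.
pruned = model.prune(2)
for token, freq in pruned:      # iteration yields (token, frequency) pairs
    assert freq >= 2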
Example #2
 def __init__(self,
              topics=None,
              docfreq=None,
              bow=None,
              verbosity=logging.DEBUG,
              prune_df_threshold=1,
              gcs_bucket=None,
              initialize_environment=True,
              repo2bow_kwargs=None):
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("topic_detector")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if topics is None:
         self._topics = Topics(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(topics, Topics)
         self._topics = topics
     self._log.info("Loaded topics model: %s", self._topics)
     if docfreq is None:
         if docfreq is not False:
             self._docfreq = DocumentFrequencies(log_level=verbosity).load(
                 source=self._topics.dep("docfreq")["uuid"],
                 backend=backend)
         else:
             self._docfreq = None
             self._log.warning("Disabled document frequencies - you will "
                               "not be able to query custom repositories.")
     else:
         assert isinstance(docfreq, DocumentFrequencies)
         self._docfreq = docfreq
     if self._docfreq is not None:
         self._docfreq = self._docfreq.prune(prune_df_threshold)
     self._log.info("Loaded docfreq model: %s", self._docfreq)
     if bow is not None:
         assert isinstance(bow, BOWBase)
         self._bow = bow
         if self._topics.matrix.shape[1] != self._bow.matrix.shape[1]:
             raise ValueError(
                 "Models do not match: topics has %s tokens while bow has %s"
                 %
                 (self._topics.matrix.shape[1], self._bow.matrix.shape[1]))
         self._log.info("Attached BOW model: %s", self._bow)
     else:
         self._bow = None
         self._log.warning("No BOW cache was loaded.")
     if self._docfreq is not None:
         self._repo2bow = Repo2BOW(
             {t: i
              for i, t in enumerate(self._topics.tokens)}, self._docfreq,
             **(repo2bow_kwargs or {}))
     else:
         self._repo2bow = None
Example #3
 def test_obj(self):
     basedir = os.path.dirname(__file__)
     id2vec = Id2Vec().load(os.path.join(basedir, ID2VEC))
     df = DocumentFrequencies().load(os.path.join(basedir, DOCFREQ))
     df._df["xxyyzz"] = 10
     id2vec._token2index[id2vec.tokens[0]] = 1
     id2vec.tokens[0] = "xxyyzz"
     id2vec._token2index["xxyyzz"] = 0
     xxyyzz = Repo2nBOW(id2vec=id2vec, docfreq=df, linguist=tests.ENRY)
     nbow = xxyyzz.convert_repository(os.path.join(basedir, "..", ".."))
     self.assertIsInstance(nbow, dict)
     self.assertAlmostEqual(nbow[0], 3.192060730416365)
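Example #3 patches the loaded models in place to make the conversion deterministic; the essential flow is Repo2nBOW.convert_repository(). A hedged sketch of that flow, assuming local model files and an enry binary (all paths and the import location are hypothetical):

from ast2vec import Id2Vec, DocumentFrequencies, Repo2nBOW  # import path is an assumption

id2vec = Id2Vec().load("id2vec.asdf")            # hypothetical local model file
df = DocumentFrequencies().load("docfreq.asdf")  # hypothetical local model file

converter = Repo2nBOW(id2vec=id2vec, docfreq=df, linguist="./enry")
# convert_repository() returns a dict; judging by the examples, it maps
# id2vec token indices to weights.
nbow = converter.convert_repository("path/to/repository")
for index, weight in sorted(nbow.items(), key=lambda kv: -kv[1])[:10]:
    print(id2vec.tokens[index], weight)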
Example #4
 def test_asdf(self):
     basedir = os.path.dirname(__file__)
     id2vec = Id2Vec().load(os.path.join(basedir, ID2VEC))
     del id2vec._token2index[id2vec.tokens[0]]
     id2vec.tokens[0] = "test"
     id2vec._token2index["test"] = 0
     df = DocumentFrequencies().load(os.path.join(basedir, DOCFREQ))
     df._df["test"] = 10
     with tempfile.NamedTemporaryFile() as file:
         args = argparse.Namespace(
             id2vec=id2vec, docfreq=df, linguist=tests.ENRY,
             gcs_bucket=None, output=file.name, bblfsh_endpoint=None, timeout=None,
             repository=os.path.join(basedir, "..", ".."), prune_df=1)
         repo2nbow_entry(args)
         self.assertTrue(os.path.isfile(file.name))
         validate_asdf_file(self, file.name)
Example #5
 def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
              verbosity=logging.DEBUG, wmd_cache_centroids=True, wmd_kwargs=None,
              gcs_bucket=None, repo2nbow_kwargs=None, initialize_environment=True):
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("similar_repos")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if id2vec is None:
         self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(id2vec, Id2Vec)
         self._id2vec = id2vec
     self._log.info("Loaded id2vec model: %s", self._id2vec)
     if df is None:
         if df is not False:
             self._df = DocumentFrequencies(log_level=verbosity).load(backend=backend)
         else:
             self._df = None
             self._log.warning("Disabled document frequencies - you will "
                               "not be able to query custom repositories.")
     else:
         assert isinstance(df, DocumentFrequencies)
         self._df = df
     if self._df is not None:
         self._df = self._df.prune(prune_df_threshold)
     self._log.info("Loaded document frequencies: %s", self._df)
     if nbow is None:
         self._nbow = NBOW(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(nbow, NBOW)
         self._nbow = nbow
     self._log.info("Loaded nBOW model: %s", self._nbow)
     self._repo2nbow = Repo2nBOW(
         self._id2vec, self._df, log_level=verbosity, **(repo2nbow_kwargs or {}))
     assert self._nbow.dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
     if len(self._id2vec) != self._nbow.matrix.shape[1]:
         raise ValueError("Models do not match: id2vec has %s tokens while nbow has %s" %
                          (len(self._id2vec), self._nbow.matrix.shape[1]))
     self._log.info("Creating the WMD engine...")
     self._wmd = WMD(self._id2vec.embeddings, self._nbow,
                     verbosity=verbosity, **(wmd_kwargs or {}))
     if wmd_cache_centroids:
         self._wmd.cache_centroids()
Example #6
 def test_transform(self):
     basedir = os.path.dirname(__file__)
     id2vec = Id2Vec().load(os.path.join(basedir, ID2VEC))
     del id2vec._token2index[id2vec.tokens[0]]
     id2vec.tokens[0] = "test"
     id2vec._token2index["test"] = 0
     df = DocumentFrequencies().load(os.path.join(basedir, DOCFREQ))
     df._df["test"] = 10
     with tempfile.TemporaryDirectory() as tmpdir:
         repo2nbow = Repo2nBOWTransformer(
             id2vec=id2vec, docfreq=df, linguist=tests.ENRY, gcs_bucket=None,
         )
         outfile = repo2nbow.prepare_filename(basedir, tmpdir)
         status = repo2nbow.transform(repos=basedir, output=tmpdir)
         self.assertTrue(os.path.isfile(outfile))
         validate_asdf_file(self, outfile)
         self.assertEqual(status, 1)
Example #7
 def __init__(self,
              id2vec=None,
              df=None,
              nbow=None,
              verbosity=logging.DEBUG,
              wmd_cache_centroids=True,
              wmd_kwargs=None,
              gcs_bucket=None,
              repo2nbow_kwargs=None,
              initialize_environment=True):
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("similar_repos")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if id2vec is None:
         self._id2vec = Id2Vec(log_level=verbosity, backend=backend)
     else:
         assert isinstance(id2vec, Id2Vec)
         self._id2vec = id2vec
     self._log.info("Loaded id2vec model: %s", self._id2vec)
     if df is None:
         if df is not False:
             self._df = DocumentFrequencies(log_level=verbosity,
                                            backend=backend)
         else:
             self._df = None
             self._log.warning("Disabled document frequencies - you will "
                               "not be able to query custom repositories.")
     else:
         assert isinstance(df, DocumentFrequencies)
         self._df = df
     self._log.info("Loaded document frequencies: %s", self._df)
     if nbow is None:
         self._nbow = NBOW(log_level=verbosity, backend=backend)
     else:
         assert isinstance(nbow, NBOW)
         self._nbow = nbow
     self._log.info("Loaded nBOW model: %s", self._nbow)
     self._repo2nbow = Repo2nBOW(self._id2vec,
                                 self._df,
                                 log_level=verbosity,
                                 **(repo2nbow_kwargs or {}))
     self._log.info("Creating the WMD engine...")
     self._wmd = WMD(self._id2vec.embeddings,
                     self._nbow,
                     verbosity=verbosity,
                     **(wmd_kwargs or {}))
     if wmd_cache_centroids:
         self._wmd.cache_centroids()
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Repository URL or path or name.")
    parser.add_argument("--log-level",
                        default="INFO",
                        choices=logging._nameToLevel,
                        help="Logging verbosity.")
    parser.add_argument("--topics",
                        default=None,
                        help="Topic model URL or path.")
    parser.add_argument("--df",
                        default=None,
                        help="Document frequencies URL or path.")
    parser.add_argument("--bow", default=None, help="BOW model URL or path.")
    parser.add_argument("--bblfsh",
                        default=None,
                        help="babelfish server address.")
    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help="Babelfish timeout - longer requests are dropped. Default is %s."
        % DEFAULT_BBLFSH_TIMEOUT)
    parser.add_argument("--gcs", default=None, help="GCS bucket to use.")
    parser.add_argument("--linguist",
                        default=None,
                        help="Path to src-d/enry or github/linguist.")
    parser.add_argument(
        "--prune-df",
        default=20,
        type=int,
        help="Minimum number of times an identifer must occur in different "
        "documents to be taken into account.")
    parser.add_argument("-n",
                        "--nnn",
                        default=10,
                        type=int,
                        help="Number of topics to print.")
    parser.add_argument("-f",
                        "--format",
                        default="human",
                        choices=["json", "human"],
                        help="Output format.")

    args = parser.parse_args()
    if args.linguist is None:
        args.linguist = "./enry"
    initialize(args.log_level, enry=args.linguist)
    if args.gcs:
        backend = create_backend(args="bucket=" + args.gcs)
    else:
        backend = create_backend()
    if args.topics is not None:
        args.topics = Topics(log_level=args.log_level).load(source=args.topics,
                                                            backend=backend)
    if args.df is not None:
        args.df = DocumentFrequencies(log_level=args.log_level).load(
            source=args.df, backend=backend)
    if args.bow is not None:
        args.bow = BOWBase(log_level=args.log_level).load(source=args.bow,
                                                          backend=backend)
    sr = TopicDetector(topics=args.topics,
                       docfreq=args.df,
                       bow=args.bow,
                       verbosity=args.log_level,
                       prune_df_threshold=args.prune_df,
                       gcs_bucket=args.gcs,
                       repo2bow_kwargs={
                           "linguist": args.linguist,
                           "bblfsh_endpoint": args.bblfsh,
                           "timeout": args.timeout
                       })
    topics = sr.query(args.input, size=args.nnn)
    if args.format == "json":
        json.dump({"repository": args.input, "topics": topics}, sys.stdout)
    elif args.format == "human":
        for t, r in topics:
            print("%64s" % t, "%.2f" % r, sep="\t")
Example #9
 def setUp(self):
     self.model = DocumentFrequencies().load(
         source=os.path.join(os.path.dirname(__file__), paths.DOCFREQ))
Example #10
 def test_preprocess(self):
     with tempfile.TemporaryDirectory() as tmpdir:
         args = default_preprocess_params(tmpdir, self.VOCAB)
         with captured_output() as (out, err, log):
             preprocess(args)
         self.assertFalse(out.getvalue())
         self.assertFalse(err.getvalue())
         self.assertIn("Skipped", log.getvalue())
         self.assertIn("error.asdf", log.getvalue())
         self.assertIn("empty_coocc.asdf", log.getvalue())
         self.assertEqual(sorted(os.listdir(tmpdir)), [
             "col_sums.txt", "col_vocab.txt", "docfreq.asdf",
             "row_sums.txt", "row_vocab.txt", "shard-000-000.pb"
         ])
         df = DocumentFrequencies().load(
             source=os.path.join(tmpdir, "docfreq.asdf"))
         self.assertEqual(len(df), self.VOCAB)
         self.assertEqual(df.docs, len(os.listdir(args.input[0])) - 1)
         with open(os.path.join(tmpdir, "col_sums.txt")) as fin:
             col_sums = fin.read()
         with open(os.path.join(tmpdir, "row_sums.txt")) as fin:
             row_sums = fin.read()
         self.assertEqual(col_sums, row_sums)
         with open(os.path.join(tmpdir, "col_vocab.txt")) as fin:
             col_vocab = fin.read()
         with open(os.path.join(tmpdir, "row_vocab.txt")) as fin:
             row_vocab = fin.read()
         self.assertEqual(col_vocab, row_vocab)
         self.assertEqual(row_vocab.split("\n"), df.tokens())
         for word in row_vocab.split("\n"):
             self.assertGreater(df[word], 0)
         with open(os.path.join(tmpdir, "shard-000-000.pb"), "rb") as fin:
             features = tf.parse_single_example(
                 fin.read(),
                 features={
                     "global_row":
                     tf.FixedLenFeature([self.VOCAB], dtype=tf.int64),
                     "global_col":
                     tf.FixedLenFeature([self.VOCAB], dtype=tf.int64),
                     "sparse_local_row":
                     tf.VarLenFeature(dtype=tf.int64),
                     "sparse_local_col":
                     tf.VarLenFeature(dtype=tf.int64),
                     "sparse_value":
                     tf.VarLenFeature(dtype=tf.float32)
                 })
         with tf.Session() as session:
             global_row, global_col, local_row, local_col, value = session.run(
                 [
                     features[n]
                     for n in ("global_row", "global_col",
                               "sparse_local_row", "sparse_local_col",
                               "sparse_value")
                 ])
         self.assertEqual(set(range(self.VOCAB)), set(global_row))
         self.assertEqual(set(range(self.VOCAB)), set(global_col))
         nnz = 1421193
         self.assertEqual(value.values.shape, (nnz, ))
         self.assertEqual(local_row.values.shape, (nnz, ))
         self.assertEqual(local_col.values.shape, (nnz, ))
         numpy.random.seed(0)
         all_tokens = row_vocab.split("\n")
         chosen_indices = numpy.random.choice(list(range(self.VOCAB)),
                                              128,
                                              replace=False)
         chosen = [all_tokens[i] for i in chosen_indices]
         freqs = numpy.zeros((len(chosen), ) * 2, dtype=int)
         index = {w: i for i, w in enumerate(chosen)}
         chosen = set(chosen)
         for path in os.listdir(args.input[0]):
             with asdf.open(os.path.join(args.input[0], path)) as model:
                 if model.tree["meta"]["model"] != "co-occurrences":
                     continue
                 matrix = assemble_sparse_matrix(
                     model.tree["matrix"]).tocsr()
                 tokens = split_strings(model.tree["tokens"])
                 interesting = {
                     i
                     for i, t in enumerate(tokens) if t in chosen
                 }
                 for y in interesting:
                     row = matrix[y]
                     yi = index[tokens[y]]
                     for x, v in zip(row.indices, row.data):
                         if x in interesting:
                             freqs[yi, index[tokens[x]]] += v
         matrix = coo_matrix(
             (value.values,
              ([global_row[row] for row in local_row.values
                ], [global_col[col] for col in local_col.values])),
             shape=(self.VOCAB, self.VOCAB))
         matrix = matrix.tocsr()[chosen_indices][:, chosen_indices].todense(
         ).astype(int)
         self.assertTrue((matrix == freqs).all())
Example #11
class TopicDetector:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

    def __init__(self,
                 topics=None,
                 docfreq=None,
                 bow=None,
                 verbosity=logging.DEBUG,
                 prune_df_threshold=1,
                 gcs_bucket=None,
                 initialize_environment=True,
                 repo2bow_kwargs=None):
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("topic_detector")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if topics is None:
            self._topics = Topics(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(topics, Topics)
            self._topics = topics
        self._log.info("Loaded topics model: %s", self._topics)
        if docfreq is None:
            if docfreq is not False:
                self._docfreq = DocumentFrequencies(log_level=verbosity).load(
                    source=self._topics.dep("docfreq")["uuid"],
                    backend=backend)
            else:
                self._docfreq = None
                self._log.warning("Disabled document frequencies - you will "
                                  "not be able to query custom repositories.")
        else:
            assert isinstance(docfreq, DocumentFrequencies)
            self._docfreq = docfreq
        if self._docfreq is not None:
            self._docfreq = self._docfreq.prune(prune_df_threshold)
        self._log.info("Loaded docfreq model: %s", self._docfreq)
        if bow is not None:
            assert isinstance(bow, BOWBase)
            self._bow = bow
            if self._topics.matrix.shape[1] != self._bow.matrix.shape[1]:
                raise ValueError(
                    "Models do not match: topics has %s tokens while bow has %s"
                    %
                    (self._topics.matrix.shape[1], self._bow.matrix.shape[1]))
            self._log.info("Attached BOW model: %s", self._bow)
        else:
            self._bow = None
            self._log.warning("No BOW cache was loaded.")
        if self._docfreq is not None:
            self._repo2bow = Repo2BOW(
                {t: i
                 for i, t in enumerate(self._topics.tokens)}, self._docfreq,
                **(repo2bow_kwargs or {}))
        else:
            self._repo2bow = None

    def query(self, url_or_path_or_name, size=5):
        if size > len(self._topics):
            raise ValueError(
                "size may not be greater than the number of topics - %d" %
                len(self._topics))
        if self._bow is not None:
            try:
                repo_index = self._bow.repository_index_by_name(
                    url_or_path_or_name)
            except KeyError:
                repo_index = -1
            if repo_index == -1:
                match = self.GITHUB_URL_RE.match(url_or_path_or_name)
                if match is not None:
                    name = match.group(2)
                    try:
                        repo_index = self._bow.repository_index_by_name(name)
                    except KeyError:
                        pass
        else:
            repo_index = -1
        if repo_index >= 0:
            token_vector = self._bow.matrix[repo_index]
        else:
            if self._docfreq is None:
                raise ValueError(
                    "You need to specify document frequencies model to process "
                    "custom repositories")
            bow_dict = self._repo2bow.convert_repository(url_or_path_or_name)
            token_vector = numpy.zeros(self._topics.matrix.shape[1],
                                       dtype=numpy.float32)
            for i, v in bow_dict.items():
                token_vector[i] = v
            token_vector = csr_matrix(token_vector)
        topic_vector = -numpy.squeeze(
            self._topics.matrix.dot(token_vector.T).toarray())
        order = numpy.argsort(topic_vector)
        result = []
        i = 0
        while len(result) < size and i < len(self._topics):
            topic = self._topics.topics[order[i]]
            if topic:
                result.append((topic, -topic_vector[order[i]]))
            i += 1
        return result
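A short usage sketch for the class above, assuming the default topics, docfreq and BOW models can be fetched through the configured backend; the repository URL and the enry path are hypothetical:

import logging

detector = TopicDetector(verbosity=logging.INFO,
                         prune_df_threshold=20,
                         repo2bow_kwargs={"linguist": "./enry"})

# query() accepts a GitHub URL, a local path, or a repository name known to the
# BOW cache and returns up to `size` (topic, weight) pairs, strongest first.
for topic, weight in detector.query("https://github.com/example/project", size=5):
    print("%64s" % topic, "%.2f" % weight, sep="\t")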
Example #12
class SimilarRepositories:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

    def __init__(self,
                 id2vec=None,
                 df=None,
                 nbow=None,
                 prune_df_threshold=1,
                 verbosity=logging.DEBUG,
                 wmd_cache_centroids=True,
                 wmd_kwargs=None,
                 gcs_bucket=None,
                 repo2nbow_kwargs=None,
                 initialize_environment=True):
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("similar_repos")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        if df is None:
            if df is not False:
                self._df = DocumentFrequencies(log_level=verbosity).load(
                    backend=backend)
            else:
                self._df = None
                self._log.warning("Disabled document frequencies - you will "
                                  "not be able to query custom repositories.")
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._nbow = NBOW(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(nbow, NBOW)
            self._nbow = nbow
        self._log.info("Loaded nBOW model: %s", self._nbow)
        self._repo2nbow = Repo2nBOW(self._id2vec,
                                    self._df,
                                    log_level=verbosity,
                                    **(repo2nbow_kwargs or {}))
        assert self._nbow.get_dependency(
            "id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._nbow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s" %
                (len(self._id2vec), self._nbow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings,
                        self._nbow,
                        verbosity=verbosity,
                        **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()

    def query(self, url_or_path_or_name, **kwargs):
        try:
            repo_index = self._nbow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._nbow.repository_index_by_name(name)
                except KeyError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        neighbours = [(self._nbow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    @staticmethod
    def unicorn_query(repo_name,
                      id2vec=None,
                      nbow=None,
                      wmd_kwargs=None,
                      query_wmd_kwargs=None):
        sr = SimilarRepositories(id2vec=id2vec,
                                 df=False,
                                 nbow=nbow,
                                 wmd_kwargs=wmd_kwargs or {
                                     "vocabulary_min": 50,
                                     "vocabulary_max": 500
                                 })
        return sr.query(
            repo_name,
            **(query_wmd_kwargs or {
                "early_stop": 0.1,
                "max_time": 180,
                "skipped_stop": 0.95
            }))

    def _query_domestic(self, repo_index, **kwargs):
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path, **kwargs):
        if self._df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        nbow_dict = self._repo2nbow.convert_repository(url_or_path)
        words = sorted(nbow_dict.keys())
        weights = [nbow_dict[k] for k in words]
        return self._wmd.nearest_neighbors((words, weights), **kwargs)
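A hedged usage sketch for SimilarRepositories, assuming the default id2vec, docfreq and nBOW models are available through the backend; the repository name is hypothetical:

import logging

sr = SimilarRepositories(verbosity=logging.INFO,
                         wmd_kwargs={"vocabulary_min": 50, "vocabulary_max": 500})

# query() accepts a GitHub URL, a local path, or a repository name from the nBOW
# dataset and returns (repository name, score) pairs from the WMD engine.
for name, score in sr.query("github.com/example/project",
                            k=10, early_stop=0.1, max_time=300, skipped_stop=0.95):
    print("%48s\t%.2f" % (name, score))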
Example #13
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Repository URL or path or name.")
    parser.add_argument("--log-level",
                        default="INFO",
                        choices=logging._nameToLevel,
                        help="Logging verbosity.")
    parser.add_argument("--id2vec",
                        default=None,
                        help="id2vec model URL or path.")
    parser.add_argument("--df",
                        default=None,
                        help="Document frequencies URL or path.")
    parser.add_argument("--nbow", default=None, help="nBOW model URL or path.")
    parser.add_argument("--no-cache-centroids",
                        action="store_true",
                        help="Do not cache WMD centroids.")
    parser.add_argument("--bblfsh",
                        default=None,
                        help="babelfish server address.")
    parser.add_argument(
        "--timeout",
        type=int,
        default=Repo2Base.DEFAULT_BBLFSH_TIMEOUT,
        help="Babelfish timeout - longer requests are dropped.")
    parser.add_argument("--gcs", default=None, help="GCS bucket to use.")
    parser.add_argument("--linguist",
                        default=None,
                        help="Path to github/linguist or src-d/enry.")
    parser.add_argument("--vocabulary-min",
                        default=50,
                        type=int,
                        help="Minimum number of words in a bag.")
    parser.add_argument("--vocabulary-max",
                        default=500,
                        type=int,
                        help="Maximum number of words in a bag.")
    parser.add_argument("-n",
                        "--nnn",
                        default=10,
                        type=int,
                        help="Number of nearest neighbours.")
    parser.add_argument("--early-stop",
                        default=0.1,
                        type=float,
                        help="Maximum fraction of the nBOW dataset to scan.")
    parser.add_argument("--max-time",
                        default=300,
                        type=int,
                        help="Maximum time to spend scanning in seconds.")
    parser.add_argument("--skipped-stop",
                        default=0.95,
                        type=float,
                        help="Minimum fraction of skipped samples to stop.")
    args = parser.parse_args()
    if args.linguist is None:
        args.linguist = "./enry"
    initialize(args.log_level, enry=args.linguist)
    if args.gcs:
        backend = create_backend(args="bucket=" + args.gcs)
    else:
        backend = create_backend()
    if args.id2vec is not None:
        args.id2vec = Id2Vec(source=args.id2vec, backend=backend)
    if args.df is not None:
        args.df = DocumentFrequencies(source=args.df, backend=backend)
    if args.nbow is not None:
        args.nbow = NBOW(source=args.nbow, backend=backend)
    sr = SimilarRepositories(id2vec=args.id2vec,
                             df=args.df,
                             nbow=args.nbow,
                             verbosity=args.log_level,
                             wmd_cache_centroids=not args.no_cache_centroids,
                             gcs_bucket=args.gcs,
                             repo2nbow_kwargs={
                                 "linguist": args.linguist,
                                 "bblfsh_endpoint": args.bblfsh,
                                 "timeout": args.timeout
                             },
                             wmd_kwargs={
                                 "vocabulary_min": args.vocabulary_min,
                                 "vocabulary_max": args.vocabulary_max
                             })
    neighbours = sr.query(args.input,
                          k=args.nnn,
                          early_stop=args.early_stop,
                          max_time=args.max_time,
                          skipped_stop=args.skipped_stop)
    for index, rate in neighbours:
        print("%48s\t%.2f" % (index, rate))