Example #1
 def __call__(self, head: RDD):
     c = Uast2BagFeatures.Columns
     self._log.info("Estimating the average bag size...")
     avglen = head \
         .map(lambda x: (x[c.document], 1)) \
         .reduceByKey(operator.add) \
         .map(lambda x: x[1])
     if self.explained:
         self._log.info("toDebugString():\n%s", avglen.toDebugString().decode())
     avglen = avglen.mean()
     self._log.info("Result: %.0f", avglen)
     avgdocnamelen = numpy.mean([len(v) for v in self.document_indexer.value_to_index])
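     # Assumption: (4 + 4) below sizes each bag entry as an int32 token index
     # plus a float32 weight, matching the dtypes of the CSR arrays built later.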
     nparts = int(numpy.ceil(
         len(self.document_indexer) * (avglen * (4 + 4) + avgdocnamelen) / self.chunk_size))
     self._log.info("Estimated number of partitions: %d", nparts)
     doc_index_to_name = [None] * len(self.document_indexer)
     for k, i in self.document_indexer.value_to_index.items():
         doc_index_to_name[i] = k
     tokens = self.df.tokens()
     bags = head \
         .map(lambda x: (x[c.document], (x[c.token], x[c.value]))) \
         .groupByKey() \
         .repartition(nparts) \
         .glom()
     if self.explained:
         self._log.info("toDebugString():\n%s", bags.toDebugString().decode())
     # toLocalIterator() returns a plain generator, so the lineage is logged
     # from the RDD before switching to local iteration.
     it = bags.toLocalIterator()
     ndocs = 0
     self._log.info("Writing files to %s", self.filename)
     for i, part in enumerate(it):
         docs = [doc_index_to_name[p[0]] for p in part]
         if not len(docs):
             self._log.info("Batch %d is empty, skipping.", i + 1)
             continue
         size = sum(len(p[1]) for p in part)
         data = numpy.zeros(size, dtype=numpy.float32)
         indices = numpy.zeros(size, dtype=numpy.int32)
         indptr = numpy.zeros(len(docs) + 1, dtype=numpy.int32)
         pos = 0
         for pi, (_, bag) in enumerate(part):
             for tok, val in sorted(bag):
                 indices[pos] = tok
                 data[pos] = val
                 pos += 1
             indptr[pi + 1] = indptr[pi] + len(bag)
         assert pos == size
         matrix = csr_matrix((data, indices, indptr), shape=(len(docs), len(tokens)))
         filename = self.get_bow_file_name(self.filename, i)
         BOW() \
             .construct(docs, tokens, matrix) \
             .save(filename, deps=(self.df,))
         self._log.info("%d -> %s with %d documents, %d nnz (%s)",
                        i + 1, filename, len(docs), size,
                        humanize.naturalsize(os.path.getsize(filename)))
         ndocs += len(docs)
     self._log.info("Final number of documents: %d", ndocs)
Example #2
 def test_convert_bow_to_vw(self):
     bow = BOW().load(source=paths.BOW)
     vocabulary = [
         "i.", "i.*", "i.Activity", "i.AdapterView", "i.ArrayAdapter",
         "i.Arrays"
     ]
     with tempfile.NamedTemporaryFile(prefix="sourced.ml-vw-") as fout:
         logging.getLogger().level = logging.ERROR
         try:
             bow.convert_bow_to_vw(fout.name)
         finally:
             logging.getLogger().level = logging.INFO
         fout.seek(0)
         contents = fout.read().decode()
     hits = 0
     for word in vocabulary:
         if " %s:" % word in contents:
             hits += 1
     self.assertEqual(hits, 4)
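The assertion only relies on tokens showing up as " token:" substrings in the output. A rough sketch of writing one document per line in that token:weight style from a csr_matrix (the exact layout emitted by convert_bow_to_vw is an assumption here, not the library's implementation):

from scipy.sparse import csr_matrix

def write_vw_like(documents, tokens, matrix: csr_matrix, path: str):
    # One document per line: "<doc> <token>:<weight> <token>:<weight> ...".
    with open(path, "w") as fout:
        for di, name in enumerate(documents):
            row = matrix.getrow(di)
            pairs = " ".join("%s:%g" % (tokens[ti], val)
                             for ti, val in zip(row.indices, row.data))
            fout.write("%s %s\n" % (name, pairs))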
Example #3
 def test_finalize_reduce(self):
     self.merge_bow.convert_model(self.model1)
     self.merge_bow.features_namespaces = "f."
     with tempfile.TemporaryDirectory(prefix="merge-bow-") as tmpdir:
         dest = os.path.join(tmpdir, "bow.asdf")
         self.merge_bow.finalize(0, dest)
         bow = BOW().load(dest)
         self.assertListEqual(bow.documents, ["doc_1", "doc_2", "doc_3"])
         self.assertListEqual(bow.tokens, ["f.tok_1", "f.tok_3"])
         for i, row in enumerate(bow.matrix.toarray()):
             self.assertListEqual(list(row), self.merge_results[i][::2])
         self.assertEqual(bow.meta["dependencies"], [{'uuid': 'uuid', 'model': 'docfreq'}])
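Comparing with Example #4 below, the full token list is ["f.tok_1", "k.tok_2", "f.tok_3"], so filtering by the "f." namespace keeps columns 0 and 2; the [::2] slice on the expected rows selects exactly those columns (values here are illustrative):

full_row = [1.0, 5.0, 3.0]  # columns: f.tok_1, k.tok_2, f.tok_3
print(full_row[::2])        # [1.0, 3.0] -> only the f.* columns remain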
Example #4
 def test_finalize_base(self):
     self.merge_bow.convert_model(self.model1)
     self.merge_bow.convert_model(self.model2)
     with tempfile.TemporaryDirectory(prefix="merge-bow-") as tmpdir:
         dest = os.path.join(tmpdir, "bow.asdf")
         self.merge_bow.finalize(0, dest)
         bow = BOW().load(dest)
         self.assertListEqual(
             bow.documents,
             ["doc_1", "doc_2", "doc_3", "doc_4", "doc_5", "doc_6"])
         self.assertListEqual(bow.tokens, ["f.tok_1", "k.tok_2", "f.tok_3"])
         for i, row in enumerate(bow.matrix.toarray()):
             self.assertListEqual(list(row), self.merge_results[i])
         self.assertEqual(bow.meta["dependencies"], [{
             "uuid": "uuid",
             "model": "docfreq"
         }])
Example #5
 def __iter__(self):
     return (BOW().load(path) for path in self.files)
Example #6
class SimilarRepositories:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")
    _log = logging.getLogger("SimilarRepositories")

    def __init__(self,
                 id2vec=None,
                 df=None,
                 nbow=None,
                 prune_df_threshold=1,
                 wmd_cache_centroids=True,
                 wmd_kwargs: Dict[str, Any] = None,
                 languages: Tuple[List, bool] = (None, False),
                 engine_kwargs: Dict[str, Any] = None):
        backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec().load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        # df=False explicitly disables document frequencies; None loads the default model.
        if df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        elif df is None:
            self._df = DocumentFrequencies().load(backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._bow = BOW().load(backend=backend)
        else:
            assert isinstance(nbow, BOW)
            self._bow = nbow
        self._log.info("Loaded BOW model: %s", self._bow)
        assert self._bow.get_dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._bow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s" %
                (len(self._id2vec), self._bow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(
            self._id2vec.embeddings, self._bow, **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()
        self._languages = languages
        self._engine_kwargs = engine_kwargs

    def query(self, url_or_path_or_name: str,
              **kwargs) -> List[Tuple[str, float]]:
        try:
            repo_index = self._bow.documents.index(url_or_path_or_name)
        except ValueError:
            repo_index = -1
        if repo_index == -1:
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._bow.documents.index(name)
                except ValueError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        neighbours = [(self._bow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    def _query_domestic(self, repo_index, **kwargs):
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path: str, **kwargs):
        df = self._df
        if df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")

        with tempfile.TemporaryDirectory(prefix="vecino-") as tempdir:
            target = os.path.join(tempdir, "repo")
            if os.path.isdir(url_or_path):
                url_or_path = os.path.abspath(url_or_path)
                os.symlink(url_or_path, target, target_is_directory=True)
                repo_format = "standard"
            else:
                self._log.info("Cloning %s to %s", url_or_path, target)
                porcelain.clone(url_or_path,
                                target,
                                bare=True,
                                outstream=sys.stderr)
                repo_format = "bare"
            bow = repo2bow(tempdir,
                           repo_format,
                           1,
                           df,
                           *self._languages,
                           engine_kwargs=self._engine_kwargs)
        ibow = {}
        for key, val in bow.items():
            try:
                ibow[self._id2vec[key]] = val
            except KeyError:
                continue
        words, weights = zip(*sorted(ibow.items()))
        return self._wmd.nearest_neighbors((words, weights), **kwargs)
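query() first looks the argument up verbatim in bow.documents and then falls back to the "github.com/owner/repo" form extracted by GITHUB_URL_RE; the normalization can be checked in isolation (the URL is an arbitrary example):

import re

GITHUB_URL_RE = re.compile(
    r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

match = GITHUB_URL_RE.match("https://github.com/src-d/ml")
print(match.group(2))  # github.com/src-d/ml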
Example #7
def bow2vw_entry(args: argparse.Namespace):
    bow = BOW().load(source=args.bow)
    bow.convert_bow_to_vw(args.output)
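bow2vw_entry only needs a Namespace carrying bow and output attributes; a minimal way to wire it into a parser (the flag names, placeholder paths and the set_defaults handler are assumptions, not the actual sourced.ml CLI):

import argparse

parser = argparse.ArgumentParser(description="Convert a BOW model to Vowpal Wabbit format.")
parser.add_argument("--bow", required=True, help="Path or URL of the BOW model.")
parser.add_argument("-o", "--output", required=True, help="Path of the output VW file.")
parser.set_defaults(handler=bow2vw_entry)

args = parser.parse_args(["--bow", "model.asdf", "--output", "dataset.vw"])  # placeholder paths
args.handler(args)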
Example #8
 def setUp(self):
     self.model = BOW().load(source=paths.BOW)
Example #9
File: __main__.py  Project: zurk/vecino
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Repository URL or path or name.")
    parser.add_argument("--log-level", default="INFO",
                        choices=logging._nameToLevel,
                        help="Logging verbosity.")
    parser.add_argument("--id2vec", default=None,
                        help="id2vec model URL or path.")
    parser.add_argument("--df", default=None,
                        help="Document frequencies URL or path.")
    parser.add_argument("--bow", default=None,
                        help="BOW model URL or path.")
    parser.add_argument("--prune-df", default=20, type=int,
                        help="Minimum number of times an identifier must occur in the dataset "
                             "to be taken into account.")
    parser.add_argument("--vocabulary-min", default=50, type=int,
                        help="Minimum number of words in a bag.")
    parser.add_argument("--vocabulary-max", default=500, type=int,
                        help="Maximum number of words in a bag.")
    parser.add_argument("-n", "--nnn", default=10, type=int,
                        help="Number of nearest neighbours.")
    parser.add_argument("--early-stop", default=0.1, type=float,
                        help="Maximum fraction of the nBOW dataset to scan.")
    parser.add_argument("--max-time", default=300, type=int,
                        help="Maximum time to spend scanning in seconds.")
    parser.add_argument("--skipped-stop", default=0.95, type=float,
                        help="Minimum fraction of skipped samples to stop.")
    languages = ["Java", "Python", "Go", "JavaScript", "TypeScript", "Ruby", "Bash", "Php"]
    parser.add_argument(
        "-l", "--languages", nargs="+", choices=languages,
        # The default must stay None: Parquet files without a "lang" column
        # cannot be processed when any --languages filter is applied.
        default=None,
        help="The programming languages to analyse.")
    parser.add_argument("--blacklist-languages", action="store_true",
                        help="Exclude the languages in --languages from the analysis "
                             "instead of filtering by default.")
    parser.add_argument(
        "-s", "--spark", default=SparkDefault.MASTER_ADDRESS,
        help="Spark's master address.")
    parser.add_argument("--bblfsh", default=EngineDefault.BBLFSH,
                        help="Babelfish server's address.")
    parser.add_argument("--engine", default=EngineDefault.VERSION,
                        help="source{d} jgit-spark-connector version.")
    args = parser.parse_args()
    setup_logging(args.log_level)
    backend = create_backend()
    if args.id2vec is not None:
        args.id2vec = Id2Vec().load(source=args.id2vec, backend=backend)
    if args.df is not None:
        args.df = DocumentFrequencies().load(source=args.df, backend=backend)
    if args.bow is not None:
        args.bow = BOW().load(source=args.bow, backend=backend)
    sr = SimilarRepositories(
        id2vec=args.id2vec, df=args.df, nbow=args.bow,
        prune_df_threshold=args.prune_df,
        wmd_cache_centroids=False,  # useless for a single query
        wmd_kwargs={"vocabulary_min": args.vocabulary_min,
                    "vocabulary_max": args.vocabulary_max},
        languages=(args.languages, args.blacklist_languages),
        engine_kwargs={"spark": args.spark,
                       "bblfsh": args.bblfsh,
                       "engine": args.engine},
    )
    neighbours = sr.query(
        args.input, k=args.nnn, early_stop=args.early_stop,
        max_time=args.max_time, skipped_stop=args.skipped_stop)
    for index, rate in neighbours:
        print("%48s\t%.2f" % (index, rate))