示例#1
0
 def test_obj(self):
     """Convert this repository with a patched vocabulary and check one nBOW weight.

     Token 0 of the test id2vec model is replaced with the synthetic token
     "xxyyzz" (given a document frequency of 10), the directory two levels
     above this file is converted with Repo2nBOW, and the weight stored for
     index 0 is compared with a recorded value.
     """
     basedir = os.path.dirname(__file__)
     id2vec = Id2Vec().load(os.path.join(basedir, ID2VEC))
     df = DocumentFrequencies().load(os.path.join(basedir, DOCFREQ))
     df._df["xxyyzz"] = 10
     # Drop the mapping of the token being replaced instead of pointing it at
     # index 1: index 1 presumably already belongs to tokens[1], so an alias
     # would leave two tokens sharing one index.
     del id2vec._token2index[id2vec.tokens[0]]
     id2vec.tokens[0] = "xxyyzz"
     id2vec._token2index["xxyyzz"] = 0
     xxyyzz = Repo2nBOW(id2vec=id2vec, docfreq=df, linguist=tests.ENRY)
     nbow = xxyyzz.convert_repository(os.path.join(basedir, "..", ".."))
     self.assertIsInstance(nbow, dict)
     # NOTE(review): this expected weight depends on the current contents of
     # the converted directory and may need re-recording when it changes.
     self.assertAlmostEqual(nbow[0], 3.192060730416365)
示例#2
0
 def __init__(self,
              id2vec=None,
              df=None,
              nbow=None,
              verbosity=logging.DEBUG,
              wmd_cache_centroids=True,
              wmd_kwargs=None,
              gcs_bucket=None,
              repo2nbow_kwargs=None,
              initialize_environment=True):
     """Set up the models and the WMD engine used to find similar repositories.

     :param id2vec: Id2Vec model. When None, one is constructed with the
         selected backend.
     :param df: DocumentFrequencies model. When None, one is constructed with
         the selected backend; when False, document frequencies are disabled
         (custom repositories cannot be queried then).
     :param nbow: NBOW model. When None, one is constructed with the selected
         backend.
     :param verbosity: logging level for the "similar_repos" logger; also
         forwarded to the models, Repo2nBOW and WMD.
     :param wmd_cache_centroids: call ``cache_centroids()`` on the WMD engine.
     :param wmd_kwargs: extra keyword arguments for WMD.
     :param repo2nbow_kwargs: extra keyword arguments for Repo2nBOW.
     :param gcs_bucket: GCS bucket for the model backend; the default backend
         is used when it is empty.
     :param initialize_environment: call ``initialize()`` first.
     """
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("similar_repos")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if id2vec is None:
         self._id2vec = Id2Vec(log_level=verbosity, backend=backend)
     else:
         assert isinstance(id2vec, Id2Vec)
         self._id2vec = id2vec
     self._log.info("Loaded id2vec model: %s", self._id2vec)
     # False is the explicit "disable" sentinel and must be checked before
     # None ("use the default model"): with None tested first, the disabled
     # branch can never be reached.
     if df is False:
         self._df = None
         self._log.warning("Disabled document frequencies - you will "
                           "not be able to query custom repositories.")
     elif df is None:
         self._df = DocumentFrequencies(log_level=verbosity,
                                        backend=backend)
     else:
         assert isinstance(df, DocumentFrequencies)
         self._df = df
     self._log.info("Loaded document frequencies: %s", self._df)
     if nbow is None:
         self._nbow = NBOW(log_level=verbosity, backend=backend)
     else:
         assert isinstance(nbow, NBOW)
         self._nbow = nbow
     self._log.info("Loaded nBOW model: %s", self._nbow)
     self._repo2nbow = Repo2nBOW(self._id2vec,
                                 self._df,
                                 log_level=verbosity,
                                 **(repo2nbow_kwargs or {}))
     self._log.info("Creating the WMD engine...")
     self._wmd = WMD(self._id2vec.embeddings,
                     self._nbow,
                     verbosity=verbosity,
                     **(wmd_kwargs or {}))
     if wmd_cache_centroids:
         self._wmd.cache_centroids()
示例#3
0
 def test_asdf(self):
     """Run repo2nbow_entry on this repository and validate the ASDF it writes.

     Token 0 of the test id2vec model is swapped for "test" (document
     frequency 10) before conversion; the output goes to a temporary file.
     """
     here = os.path.dirname(__file__)
     model = Id2Vec().load(os.path.join(here, ID2VEC))
     replaced_token = model.tokens[0]
     del model._token2index[replaced_token]
     model.tokens[0] = "test"
     model._token2index["test"] = 0
     docfreq = DocumentFrequencies().load(os.path.join(here, DOCFREQ))
     docfreq._df["test"] = 10
     with tempfile.NamedTemporaryFile() as output:
         namespace = argparse.Namespace(
             id2vec=model,
             docfreq=docfreq,
             linguist=tests.ENRY,
             gcs_bucket=None,
             output=output.name,
             bblfsh_endpoint=None,
             timeout=None,
             repository=os.path.join(here, "..", ".."),
             prune_df=1,
         )
         repo2nbow_entry(namespace)
         self.assertTrue(os.path.isfile(output.name))
         validate_asdf_file(self, output.name)
示例#4
0
 def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
              verbosity=logging.DEBUG, wmd_cache_centroids=True, wmd_kwargs=None,
              gcs_bucket=None, repo2nbow_kwargs=None, initialize_environment=True):
     """Load the models, check that they match and build the WMD engine.

     :param id2vec: Id2Vec model; loaded from the selected backend when None.
     :param df: DocumentFrequencies model; loaded from the selected backend
         when None, disabled when False (custom repositories cannot be
         queried then).
     :param nbow: NBOW model; loaded from the selected backend when None.
     :param prune_df_threshold: passed to ``prune()`` on the document
         frequencies when they are enabled.
     :param verbosity: logging level for the "similar_repos" logger; also
         forwarded to the models, Repo2nBOW and WMD.
     :param wmd_cache_centroids: call ``cache_centroids()`` on the WMD engine.
     :param wmd_kwargs: extra keyword arguments for WMD.
     :param gcs_bucket: GCS bucket for the model backend; the default backend
         is used when it is empty.
     :param repo2nbow_kwargs: extra keyword arguments for Repo2nBOW.
     :param initialize_environment: call ``initialize()`` first.
     :raises ValueError: if the nBOW model's id2vec dependency has a different
         uuid than the id2vec model, or their vocabulary sizes differ.
     """
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("similar_repos")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if id2vec is None:
         self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(id2vec, Id2Vec)
         self._id2vec = id2vec
     self._log.info("Loaded id2vec model: %s", self._id2vec)
     # False is the explicit "disable" sentinel and must be checked before
     # None ("load the default model"): with None tested first, the disabled
     # branch can never be reached.
     if df is False:
         self._df = None
         self._log.warning("Disabled document frequencies - you will "
                           "not be able to query custom repositories.")
     elif df is None:
         self._df = DocumentFrequencies(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(df, DocumentFrequencies)
         self._df = df
     if self._df is not None:
         self._df = self._df.prune(prune_df_threshold)
     self._log.info("Loaded document frequencies: %s", self._df)
     if nbow is None:
         self._nbow = NBOW(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(nbow, NBOW)
         self._nbow = nbow
     self._log.info("Loaded nBOW model: %s", self._nbow)
     self._repo2nbow = Repo2nBOW(
         self._id2vec, self._df, log_level=verbosity, **(repo2nbow_kwargs or {}))
     # Model-compatibility checks raise instead of asserting so that they are
     # not stripped when Python runs with -O.
     nbow_id2vec_uuid = self._nbow.dep("id2vec")["uuid"]
     id2vec_uuid = self._id2vec.meta["uuid"]
     if nbow_id2vec_uuid != id2vec_uuid:
         raise ValueError("Models do not match: nbow depends on id2vec %s while the "
                          "loaded id2vec is %s" % (nbow_id2vec_uuid, id2vec_uuid))
     if len(self._id2vec) != self._nbow.matrix.shape[1]:
         raise ValueError("Models do not match: id2vec has %s tokens while nbow has %s" %
                          (len(self._id2vec), self._nbow.matrix.shape[1]))
     self._log.info("Creating the WMD engine...")
     self._wmd = WMD(self._id2vec.embeddings, self._nbow,
                     verbosity=verbosity, **(wmd_kwargs or {}))
     if wmd_cache_centroids:
         self._wmd.cache_centroids()
示例#5
0
 def test_transform(self):
     """Transform the test directory with Repo2nBOWTransformer and validate the output.

     Token 0 of the test id2vec model is swapped for "test" (document
     frequency 10); the expected output file name comes from
     ``prepare_filename`` and the transform result is expected to be 1.
     """
     here = os.path.dirname(__file__)
     model = Id2Vec().load(os.path.join(here, ID2VEC))
     del model._token2index[model.tokens[0]]
     model.tokens[0] = "test"
     model._token2index["test"] = 0
     docfreq = DocumentFrequencies().load(os.path.join(here, DOCFREQ))
     docfreq._df["test"] = 10
     with tempfile.TemporaryDirectory() as workdir:
         transformer = Repo2nBOWTransformer(id2vec=model, docfreq=docfreq,
                                            linguist=tests.ENRY, gcs_bucket=None)
         expected_file = transformer.prepare_filename(here, workdir)
         result = transformer.transform(repos=here, output=workdir)
         self.assertTrue(os.path.isfile(expected_file))
         validate_asdf_file(self, expected_file)
         self.assertEqual(result, 1)
示例#6
0
 def setUp(self):
     """Load the test id2vec model located at paths.ID2VEC relative to this file's directory."""
     model_path = os.path.join(os.path.dirname(__file__), paths.ID2VEC)
     self.model = Id2Vec().load(source=model_path)
示例#7
0
def check_postproc_results(obj, id2vec_loc):
    """Assert that the id2vec model stored at ``id2vec_loc`` has the expected size.

    :param obj: test case providing ``VOCAB`` and ``assertEqual``.
    :param id2vec_loc: location passed to ``Id2Vec().load(source=...)``.

    The model must hold ``obj.VOCAB`` tokens and an embedding matrix of shape
    ``(obj.VOCAB, 50)``.
    """
    model = Id2Vec().load(source=id2vec_loc)
    token_count = len(model.tokens)
    obj.assertEqual(token_count, obj.VOCAB)
    obj.assertEqual(model.embeddings.shape, (obj.VOCAB, 50))
示例#8
0
def main():
    """Command-line entry point: print the repositories most similar to the input.

    Parses the command line, initializes the environment, loads any models
    given by explicit location, builds SimilarRepositories and prints each
    neighbour as a right-aligned name followed by a tab and its rate.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("input", help="Repository URL or path or name.")
    add("--log-level", default="INFO", choices=logging._nameToLevel,
        help="Logging verbosity.")
    add("--id2vec", default=None, help="id2vec model URL or path.")
    add("--df", default=None, help="Document frequencies URL or path.")
    add("--nbow", default=None, help="nBOW model URL or path.")
    add("--no-cache-centroids", action="store_true",
        help="Do not cache WMD centroids.")
    add("--bblfsh", default=None, help="babelfish server address.")
    add("--timeout", type=int, default=Repo2Base.DEFAULT_BBLFSH_TIMEOUT,
        help="Babelfish timeout - longer requests are dropped.")
    add("--gcs", default=None, help="GCS bucket to use.")
    add("--linguist", default=None,
        help="Path to github/linguist or src-d/enry.")
    add("--vocabulary-min", default=50, type=int,
        help="Minimum number of words in a bag.")
    add("--vocabulary-max", default=500, type=int,
        help="Maximum number of words in a bag.")
    add("-n", "--nnn", default=10, type=int,
        help="Number of nearest neighbours.")
    add("--early-stop", default=0.1, type=float,
        help="Maximum fraction of the nBOW dataset to scan.")
    add("--max-time", default=300, type=int,
        help="Maximum time to spend scanning in seconds.")
    add("--skipped-stop", default=0.95, type=float,
        help="Minimum fraction of skipped samples to stop.")
    args = parser.parse_args()

    # Fall back to ./enry when --linguist is not given.
    args.linguist = "./enry" if args.linguist is None else args.linguist
    initialize(args.log_level, enry=args.linguist)
    backend = (create_backend(args="bucket=" + args.gcs) if args.gcs
               else create_backend())

    # Models given by explicit location are loaded here, in this order;
    # the others stay None and are left to SimilarRepositories.
    for attr, model_class in (("id2vec", Id2Vec),
                              ("df", DocumentFrequencies),
                              ("nbow", NBOW)):
        location = getattr(args, attr)
        if location is not None:
            setattr(args, attr, model_class(source=location, backend=backend))

    engine = SimilarRepositories(
        id2vec=args.id2vec,
        df=args.df,
        nbow=args.nbow,
        verbosity=args.log_level,
        wmd_cache_centroids=not args.no_cache_centroids,
        gcs_bucket=args.gcs,
        repo2nbow_kwargs=dict(linguist=args.linguist,
                              bblfsh_endpoint=args.bblfsh,
                              timeout=args.timeout),
        wmd_kwargs=dict(vocabulary_min=args.vocabulary_min,
                        vocabulary_max=args.vocabulary_max),
    )
    neighbours = engine.query(args.input, k=args.nnn,
                              early_stop=args.early_stop,
                              max_time=args.max_time,
                              skipped_stop=args.skipped_stop)
    for name, rate in neighbours:
        print("%48s\t%.2f" % (name, rate))