def __init__(self,
             id2vec=None,
             df=None,
             nbow=None,
             verbosity=logging.DEBUG,
             wmd_cache_centroids=True,
             wmd_kwargs=None,
             gcs_bucket=None,
             repo2nbow_kwargs=None,
             initialize_environment=True):
    """
    Initialize the similar-repositories engine.

    :param id2vec: Pre-loaded Id2Vec model, or None to load the default one.
    :param df: Pre-loaded DocumentFrequencies model, None to load the default
        one, or False to disable document frequencies (custom-repository
        queries become unavailable).
    :param nbow: Pre-loaded NBOW model, or None to load the default one.
    :param verbosity: Logging level applied to this object and loaded models.
    :param wmd_cache_centroids: Whether to pre-compute WMD centroids up front.
    :param wmd_kwargs: Extra keyword arguments forwarded to the WMD engine.
    :param gcs_bucket: GCS bucket name to load models from, or None for the
        default backend.
    :param repo2nbow_kwargs: Extra keyword arguments forwarded to Repo2nBOW.
    :param initialize_environment: Whether to call initialize() first.
    """
    if initialize_environment:
        initialize()
    self._log = logging.getLogger("similar_repos")
    self._log.setLevel(verbosity)
    if gcs_bucket:
        backend = create_backend(args="bucket=" + gcs_bucket)
    else:
        backend = create_backend()
    if id2vec is None:
        self._id2vec = Id2Vec(log_level=verbosity, backend=backend)
    else:
        assert isinstance(id2vec, Id2Vec)
        self._id2vec = id2vec
    self._log.info("Loaded id2vec model: %s", self._id2vec)
    # BUG FIX: the original nested "if df is not False" inside "if df is
    # None", which is always true there - the "disabled" branch was dead
    # code, and passing df=False crashed on the isinstance assert instead
    # of disabling document frequencies as the warning message promises.
    if df is False:
        self._df = None
        self._log.warning("Disabled document frequencies - you will "
                          "not be able to query custom repositories.")
    elif df is None:
        self._df = DocumentFrequencies(log_level=verbosity, backend=backend)
    else:
        assert isinstance(df, DocumentFrequencies)
        self._df = df
    self._log.info("Loaded document frequencies: %s", self._df)
    if nbow is None:
        self._nbow = NBOW(log_level=verbosity, backend=backend)
    else:
        assert isinstance(nbow, NBOW)
        self._nbow = nbow
    self._log.info("Loaded nBOW model: %s", self._nbow)
    self._repo2nbow = Repo2nBOW(self._id2vec,
                                self._df,
                                log_level=verbosity,
                                **(repo2nbow_kwargs or {}))
    self._log.info("Creating the WMD engine...")
    self._wmd = WMD(self._id2vec.embeddings,
                    self._nbow,
                    verbosity=verbosity,
                    **(wmd_kwargs or {}))
    if wmd_cache_centroids:
        # Pre-computing centroids trades startup time for faster queries.
        self._wmd.cache_centroids()
# Example #2
 def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
              verbosity=logging.DEBUG, wmd_cache_centroids=True, wmd_kwargs=None,
              gcs_bucket=None, repo2nbow_kwargs=None, initialize_environment=True):
     """
     Initialize the similar-repositories engine.

     :param id2vec: Pre-loaded Id2Vec model, or None to load the default one.
     :param df: Pre-loaded DocumentFrequencies model, None to load the default
         one, or False to disable document frequencies (custom-repository
         queries become unavailable).
     :param nbow: Pre-loaded NBOW model, or None to load the default one.
     :param prune_df_threshold: Minimum document frequency kept after pruning.
     :param verbosity: Logging level applied to this object and loaded models.
     :param wmd_cache_centroids: Whether to pre-compute WMD centroids up front.
     :param wmd_kwargs: Extra keyword arguments forwarded to the WMD engine.
     :param gcs_bucket: GCS bucket name to load models from, or None for the
         default backend.
     :param repo2nbow_kwargs: Extra keyword arguments forwarded to Repo2nBOW.
     :param initialize_environment: Whether to call initialize() first.
     :raises ValueError: If the id2vec and nbow models have mismatching sizes.
     """
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("similar_repos")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if id2vec is None:
         self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(id2vec, Id2Vec)
         self._id2vec = id2vec
     self._log.info("Loaded id2vec model: %s", self._id2vec)
     # BUG FIX: the original nested "if df is not False" inside "if df is
     # None", which is always true there - the "disabled" branch was dead
     # code, and passing df=False crashed on the isinstance assert instead
     # of disabling document frequencies as the warning message promises.
     if df is False:
         self._df = None
         self._log.warning("Disabled document frequencies - you will "
                           "not be able to query custom repositories.")
     elif df is None:
         self._df = DocumentFrequencies(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(df, DocumentFrequencies)
         self._df = df
     if self._df is not None:
         # Drop rare tokens to shrink the model; threshold 1 keeps everything.
         self._df = self._df.prune(prune_df_threshold)
     self._log.info("Loaded document frequencies: %s", self._df)
     if nbow is None:
         self._nbow = NBOW(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(nbow, NBOW)
         self._nbow = nbow
     self._log.info("Loaded nBOW model: %s", self._nbow)
     self._repo2nbow = Repo2nBOW(
         self._id2vec, self._df, log_level=verbosity, **(repo2nbow_kwargs or {}))
     # The nBOW model must have been produced from this exact id2vec model.
     assert self._nbow.dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
     if len(self._id2vec) != self._nbow.matrix.shape[1]:
         raise ValueError("Models do not match: id2vec has %s tokens while nbow has %s" %
                          (len(self._id2vec), self._nbow.matrix.shape[1]))
     self._log.info("Creating the WMD engine...")
     self._wmd = WMD(self._id2vec.embeddings, self._nbow,
                     verbosity=verbosity, **(wmd_kwargs or {}))
     if wmd_cache_centroids:
         # Pre-computing centroids trades startup time for faster queries.
         self._wmd.cache_centroids()
# Example #3
def validate_asdf_file(obj, filename):
    """Run the shared assertions on an NBOW model serialized at *filename*.

    *obj* is the unittest TestCase providing the assert* methods.
    """
    loaded = NBOW().load(filename)
    obj.assertIsInstance(loaded.repos, list)
    # Repository names are expected to be longer than a single character.
    obj.assertGreater(len(loaded.repos[0]), 1)
    obj.assertEqual(loaded._matrix.shape, (1, 1000))
    obj.assertEqual(2, len(loaded.meta["dependencies"]))
# Example #4
 def setUp(self):
     """Load the reference nBOW fixture shipped next to this test module."""
     fixture_path = os.path.join(os.path.dirname(__file__), paths.NBOW)
     self.model = NBOW().load(source=fixture_path)
# Example #5
def main():
    """CLI entry point: print the repositories most similar to the input.

    Parses command-line options, loads the requested models (optionally from
    a GCS bucket), builds a SimilarRepositories engine, and prints the k
    nearest neighbours of the input repository with their similarity rates.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Repository URL or path or name.")
    parser.add_argument("--log-level",
                        default="INFO",
                        # Explicit public level names instead of the private
                        # logging._nameToLevel mapping (same accepted set).
                        choices=("CRITICAL", "FATAL", "ERROR", "WARN",
                                 "WARNING", "INFO", "DEBUG", "NOTSET"),
                        help="Logging verbosity.")
    parser.add_argument("--id2vec",
                        default=None,
                        help="id2vec model URL or path.")
    parser.add_argument("--df",
                        default=None,
                        help="Document frequencies URL or path.")
    parser.add_argument("--nbow", default=None, help="nBOW model URL or path.")
    parser.add_argument("--no-cache-centroids",
                        action="store_true",
                        help="Do not cache WMD centroids.")
    parser.add_argument("--bblfsh",
                        default=None,
                        help="babelfish server address.")
    parser.add_argument(
        "--timeout",
        type=int,
        default=Repo2Base.DEFAULT_BBLFSH_TIMEOUT,
        help="Babelfish timeout - longer requests are dropped.")
    parser.add_argument("--gcs", default=None, help="GCS bucket to use.")
    # The default was previously patched in after parse_args(); declaring it
    # here is equivalent and keeps all defaults in one place.
    parser.add_argument("--linguist",
                        default="./enry",
                        help="Path to github/linguist or src-d/enry.")
    parser.add_argument("--vocabulary-min",
                        default=50,
                        type=int,
                        help="Minimum number of words in a bag.")
    parser.add_argument("--vocabulary-max",
                        default=500,
                        type=int,
                        help="Maximum number of words in a bag.")
    parser.add_argument("-n",
                        "--nnn",
                        default=10,
                        type=int,
                        help="Number of nearest neighbours.")
    parser.add_argument("--early-stop",
                        default=0.1,
                        type=float,
                        help="Maximum fraction of the nBOW dataset to scan.")
    parser.add_argument("--max-time",
                        default=300,
                        type=int,
                        help="Maximum time to spend scanning in seconds.")
    parser.add_argument("--skipped-stop",
                        default=0.95,
                        type=float,
                        help="Minimum fraction of skipped samples to stop.")
    args = parser.parse_args()
    initialize(args.log_level, enry=args.linguist)
    if args.gcs:
        backend = create_backend(args="bucket=" + args.gcs)
    else:
        backend = create_backend()
    # Replace the CLI string paths with loaded model instances in place;
    # None values are forwarded so SimilarRepositories loads the defaults.
    if args.id2vec is not None:
        args.id2vec = Id2Vec(source=args.id2vec, backend=backend)
    if args.df is not None:
        args.df = DocumentFrequencies(source=args.df, backend=backend)
    if args.nbow is not None:
        args.nbow = NBOW(source=args.nbow, backend=backend)
    sr = SimilarRepositories(id2vec=args.id2vec,
                             df=args.df,
                             nbow=args.nbow,
                             verbosity=args.log_level,
                             wmd_cache_centroids=not args.no_cache_centroids,
                             gcs_bucket=args.gcs,
                             repo2nbow_kwargs={
                                 "linguist": args.linguist,
                                 "bblfsh_endpoint": args.bblfsh,
                                 "timeout": args.timeout
                             },
                             wmd_kwargs={
                                 "vocabulary_min": args.vocabulary_min,
                                 "vocabulary_max": args.vocabulary_max
                             })
    neighbours = sr.query(args.input,
                          k=args.nnn,
                          early_stop=args.early_stop,
                          max_time=args.max_time,
                          skipped_stop=args.skipped_stop)
    for index, rate in neighbours:
        print("%48s\t%.2f" % (index, rate))