Example #1
 def __init__(self,
              id2vec=None,
              df=None,
              nbow=None,
              verbosity=logging.DEBUG,
              wmd_cache_centroids=True,
              wmd_kwargs=None,
              gcs_bucket=None,
              repo2nbow_kwargs=None,
              initialize_environment=True):
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("similar_repos")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if id2vec is None:
         self._id2vec = Id2Vec(log_level=verbosity, backend=backend)
     else:
         assert isinstance(id2vec, Id2Vec)
         self._id2vec = id2vec
     self._log.info("Loaded id2vec model: %s", self._id2vec)
     if df is None:
         self._df = DocumentFrequencies(log_level=verbosity, backend=backend)
     elif df is False:
         self._df = None
         self._log.warning("Disabled document frequencies - you will "
                           "not be able to query custom repositories.")
     else:
         assert isinstance(df, DocumentFrequencies)
         self._df = df
     self._log.info("Loaded document frequencies: %s", self._df)
     if nbow is None:
         self._nbow = NBOW(log_level=verbosity, backend=backend)
     else:
         assert isinstance(nbow, NBOW)
         self._nbow = nbow
     self._log.info("Loaded nBOW model: %s", self._nbow)
     self._repo2nbow = Repo2nBOW(self._id2vec,
                                 self._df,
                                 log_level=verbosity,
                                 **(repo2nbow_kwargs or {}))
     self._log.info("Creating the WMD engine...")
     self._wmd = WMD(self._id2vec.embeddings,
                     self._nbow,
                     verbosity=verbosity,
                     **(wmd_kwargs or {}))
     if wmd_cache_centroids:
         self._wmd.cache_centroids()
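The df argument is effectively tri-state: None loads the bundled model, False disables document frequencies (queries against custom repositories then raise an error), and an existing DocumentFrequencies instance is used as-is. A minimal sketch of the three call styles, assuming the surrounding class is SimilarRepositories as in Example #5:

# Hedged sketch: the three accepted values of df.
sr_default = SimilarRepositories()          # df=None: load the bundled model
sr_no_df = SimilarRepositories(df=False)    # disabled: only indexed repos are queryable
my_df = DocumentFrequencies(log_level=logging.INFO, backend=create_backend())
sr_custom = SimilarRepositories(df=my_df)   # preloaded instance, used as-is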
Example #2
 def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
              verbosity=logging.DEBUG, wmd_cache_centroids=True, wmd_kwargs=None,
              gcs_bucket=None, repo2nbow_kwargs=None, initialize_environment=True):
     if initialize_environment:
         initialize()
     self._log = logging.getLogger("similar_repos")
     self._log.setLevel(verbosity)
     if gcs_bucket:
         backend = create_backend(args="bucket=" + gcs_bucket)
     else:
         backend = create_backend()
     if id2vec is None:
         self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(id2vec, Id2Vec)
         self._id2vec = id2vec
     self._log.info("Loaded id2vec model: %s", self._id2vec)
     if df is None:
         self._df = DocumentFrequencies(log_level=verbosity).load(backend=backend)
     elif df is False:
         self._df = None
         self._log.warning("Disabled document frequencies - you will "
                           "not be able to query custom repositories.")
     else:
         assert isinstance(df, DocumentFrequencies)
         self._df = df
     if self._df is not None:
         self._df = self._df.prune(prune_df_threshold)
     self._log.info("Loaded document frequencies: %s", self._df)
     if nbow is None:
         self._nbow = NBOW(log_level=verbosity).load(backend=backend)
     else:
         assert isinstance(nbow, NBOW)
         self._nbow = nbow
     self._log.info("Loaded nBOW model: %s", self._nbow)
     self._repo2nbow = Repo2nBOW(
         self._id2vec, self._df, log_level=verbosity, **(repo2nbow_kwargs or {}))
     assert self._nbow.dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
     if len(self._id2vec) != self._nbow.matrix.shape[1]:
         raise ValueError("Models do not match: id2vec has %s tokens while nbow has %s" %
                          (len(self._id2vec), self._nbow.matrix.shape[1]))
     self._log.info("Creating the WMD engine...")
     self._wmd = WMD(self._id2vec.embeddings, self._nbow,
                     verbosity=verbosity, **(wmd_kwargs or {}))
     if wmd_cache_centroids:
         self._wmd.cache_centroids()
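Example #2 extends Example #1 with document-frequency pruning (prune_df_threshold) and two model-compatibility checks. A hedged sketch of the pruning step in isolation, assuming prune(threshold) returns a model that keeps only tokens seen in at least that many documents (an assumption inferred from the parameter name):

# Hedged sketch: prune rare tokens before building the WMD engine.
df = DocumentFrequencies(log_level=logging.INFO).load(backend=create_backend())
df = df.prune(2)  # assumed semantics: keep tokens seen in >= 2 documents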
Example #3
class NBOWTests(unittest.TestCase):
    def setUp(self):
        self.model = NBOW().load(
            source=os.path.join(os.path.dirname(__file__), paths.NBOW))

    def test_getitem(self):
        repo_name, indices, weights = self.model[0]
        self.assertEqual(repo_name, "ikizir/HohhaDynamicXOR")
        self.assertIsInstance(indices, numpy.ndarray)
        self.assertIsInstance(weights, numpy.ndarray)
        self.assertEqual(indices.shape, weights.shape)
        self.assertEqual(indices.shape, (85, ))

    def test_iter(self):
        pumped = list(self.model)
        self.assertEqual(len(pumped), 1000)
        self.assertEqual(pumped, list(range(1000)))

    def test_len(self):
        self.assertEqual(len(self.model), 1000)

    def test_repository_index_by_name(self):
        self.assertEqual(
            self.model.repository_index_by_name("ikizir/HohhaDynamicXOR"), 0)
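Taken together, the accessors exercised by these tests suggest a simple lookup round-trip. A hedged sketch, loading the model the same way as setUp does:

# Sketch based only on the accessors used in the tests above.
model = NBOW().load(source=os.path.join(os.path.dirname(__file__), paths.NBOW))
index = model.repository_index_by_name("ikizir/HohhaDynamicXOR")  # -> 0
name, indices, weights = model[index]  # (repo name, token indices, weights)
print(name, indices.shape)             # ikizir/HohhaDynamicXOR (85,)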
Example #4
def validate_asdf_file(obj, filename):
    model = NBOW().load(filename)
    obj.assertIsInstance(model.repos, list)
    obj.assertGreater(len(model.repos[0]), 1)
    obj.assertEqual(model._matrix.shape, (1, 1000))
    obj.assertEqual(2, len(model.meta["dependencies"]))
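validate_asdf_file receives a unittest.TestCase as obj, so it is presumably shared by serialization tests. A hedged sketch of how such a test might call it; the test class and file path below are hypothetical:

class NBOWSerializationTests(unittest.TestCase):  # hypothetical test class
    def test_save(self):
        path = "/tmp/nbow.asdf"  # hypothetical output path
        # ... write the nBOW model to `path` here ...
        validate_asdf_file(self, path)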
Example #5
class SimilarRepositories:
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

    def __init__(self,
                 id2vec=None,
                 df=None,
                 nbow=None,
                 prune_df_threshold=1,
                 verbosity=logging.DEBUG,
                 wmd_cache_centroids=True,
                 wmd_kwargs=None,
                 gcs_bucket=None,
                 repo2nbow_kwargs=None,
                 initialize_environment=True):
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("similar_repos")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        if df is None:
            self._df = DocumentFrequencies(log_level=verbosity).load(
                backend=backend)
        elif df is False:
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._nbow = NBOW(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(nbow, NBOW)
            self._nbow = nbow
        self._log.info("Loaded nBOW model: %s", self._nbow)
        self._repo2nbow = Repo2nBOW(self._id2vec,
                                    self._df,
                                    log_level=verbosity,
                                    **(repo2nbow_kwargs or {}))
        assert self._nbow.get_dependency(
            "id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._nbow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s" %
                (len(self._id2vec), self._nbow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings,
                        self._nbow,
                        verbosity=verbosity,
                        **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()

    def query(self, url_or_path_or_name, **kwargs):
        try:
            repo_index = self._nbow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._nbow.repository_index_by_name(name)
                except KeyError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        neighbours = [(self._nbow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    @staticmethod
    def unicorn_query(repo_name,
                      id2vec=None,
                      nbow=None,
                      wmd_kwargs=None,
                      query_wmd_kwargs=None):
        sr = SimilarRepositories(id2vec=id2vec,
                                 df=False,
                                 nbow=nbow,
                                 wmd_kwargs=wmd_kwargs or {
                                     "vocabulary_min": 50,
                                     "vocabulary_max": 500
                                 })
        return sr.query(
            repo_name,
            **(query_wmd_kwargs or {
                "early_stop": 0.1,
                "max_time": 180,
                "skipped_stop": 0.95
            }))

    def _query_domestic(self, repo_index, **kwargs):
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path, **kwargs):
        if self._df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        nbow_dict = self._repo2nbow.convert_repository(url_or_path)
        words = sorted(nbow_dict.keys())
        weights = [nbow_dict[k] for k in words]
        return self._wmd.nearest_neighbors((words, weights), **kwargs)
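A minimal end-to-end sketch of the class above, assuming the default models are reachable through the configured backend; the repository name comes from the tests and the k keyword mirrors the CLI in Example #7:

# Hedged usage sketch: top-10 neighbours of an already-indexed repository.
sr = SimilarRepositories(verbosity=logging.INFO)
for name, similarity in sr.query("ikizir/HohhaDynamicXOR", k=10):
    print("%48s\t%.2f" % (name, similarity))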
Example #6
 def setUp(self):
     self.model = NBOW().load(
         source=os.path.join(os.path.dirname(__file__), paths.NBOW))
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Repository URL or path or name.")
    parser.add_argument("--log-level",
                        default="INFO",
                        choices=logging._nameToLevel,
                        help="Logging verbosity.")
    parser.add_argument("--id2vec",
                        default=None,
                        help="id2vec model URL or path.")
    parser.add_argument("--df",
                        default=None,
                        help="Document frequencies URL or path.")
    parser.add_argument("--nbow", default=None, help="nBOW model URL or path.")
    parser.add_argument("--no-cache-centroids",
                        action="store_true",
                        help="Do not cache WMD centroids.")
    parser.add_argument("--bblfsh",
                        default=None,
                        help="babelfish server address.")
    parser.add_argument(
        "--timeout",
        type=int,
        default=Repo2Base.DEFAULT_BBLFSH_TIMEOUT,
        help="Babelfish timeout - longer requests are dropped.")
    parser.add_argument("--gcs", default=None, help="GCS bucket to use.")
    parser.add_argument("--linguist",
                        default=None,
                        help="Path to github/linguist or src-d/enry.")
    parser.add_argument("--vocabulary-min",
                        default=50,
                        type=int,
                        help="Minimum number of words in a bag.")
    parser.add_argument("--vocabulary-max",
                        default=500,
                        type=int,
                        help="Maximum number of words in a bag.")
    parser.add_argument("-n",
                        "--nnn",
                        default=10,
                        type=int,
                        help="Number of nearest neighbours.")
    parser.add_argument("--early-stop",
                        default=0.1,
                        type=float,
                        help="Maximum fraction of the nBOW dataset to scan.")
    parser.add_argument("--max-time",
                        default=300,
                        type=int,
                        help="Maximum time to spend scanning in seconds.")
    parser.add_argument("--skipped-stop",
                        default=0.95,
                        type=float,
                        help="Minimum fraction of skipped samples to stop.")
    args = parser.parse_args()
    if args.linguist is None:
        args.linguist = "./enry"
    initialize(args.log_level, enry=args.linguist)
    if args.gcs:
        backend = create_backend(args="bucket=" + args.gcs)
    else:
        backend = create_backend()
    if args.id2vec is not None:
        args.id2vec = Id2Vec(source=args.id2vec, backend=backend)
    if args.df is not None:
        args.df = DocumentFrequencies(source=args.df, backend=backend)
    if args.nbow is not None:
        args.nbow = NBOW(source=args.nbow, backend=backend)
    sr = SimilarRepositories(id2vec=args.id2vec,
                             df=args.df,
                             nbow=args.nbow,
                             verbosity=args.log_level,
                             wmd_cache_centroids=not args.no_cache_centroids,
                             gcs_bucket=args.gcs,
                             repo2nbow_kwargs={
                                 "linguist": args.linguist,
                                 "bblfsh_endpoint": args.bblfsh,
                                 "timeout": args.timeout
                             },
                             wmd_kwargs={
                                 "vocabulary_min": args.vocabulary_min,
                                 "vocabulary_max": args.vocabulary_max
                             })
    neighbours = sr.query(args.input,
                          k=args.nnn,
                          early_stop=args.early_stop,
                          max_time=args.max_time,
                          skipped_stop=args.skipped_stop)
    for name, similarity in neighbours:
        print("%48s\t%.2f" % (name, similarity))