def __init__(self, id2vec=None, df=None, nbow=None, verbosity=logging.DEBUG,
             wmd_cache_centroids=True, wmd_kwargs=None, gcs_bucket=None,
             repo2nbow_kwargs=None, initialize_environment=True):
    if initialize_environment:
        initialize()
    self._log = logging.getLogger("similar_repos")
    self._log.setLevel(verbosity)
    if gcs_bucket:
        backend = create_backend(args="bucket=" + gcs_bucket)
    else:
        backend = create_backend()
    if id2vec is None:
        self._id2vec = Id2Vec(log_level=verbosity, backend=backend)
    else:
        assert isinstance(id2vec, Id2Vec)
        self._id2vec = id2vec
    self._log.info("Loaded id2vec model: %s", self._id2vec)
    if df is None:
        self._df = DocumentFrequencies(log_level=verbosity, backend=backend)
    elif df is False:
        # df=False explicitly disables document frequencies.
        self._df = None
        self._log.warning("Disabled document frequencies - you will "
                          "not be able to query custom repositories.")
    else:
        assert isinstance(df, DocumentFrequencies)
        self._df = df
    self._log.info("Loaded document frequencies: %s", self._df)
    if nbow is None:
        self._nbow = NBOW(log_level=verbosity, backend=backend)
    else:
        assert isinstance(nbow, NBOW)
        self._nbow = nbow
    self._log.info("Loaded nBOW model: %s", self._nbow)
    self._repo2nbow = Repo2nBOW(self._id2vec, self._df, log_level=verbosity,
                                **(repo2nbow_kwargs or {}))
    self._log.info("Creating the WMD engine...")
    self._wmd = WMD(self._id2vec.embeddings, self._nbow, verbosity=verbosity,
                    **(wmd_kwargs or {}))
    if wmd_cache_centroids:
        self._wmd.cache_centroids()

def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
             verbosity=logging.DEBUG, wmd_cache_centroids=True,
             wmd_kwargs=None, gcs_bucket=None, repo2nbow_kwargs=None,
             initialize_environment=True):
    if initialize_environment:
        initialize()
    self._log = logging.getLogger("similar_repos")
    self._log.setLevel(verbosity)
    if gcs_bucket:
        backend = create_backend(args="bucket=" + gcs_bucket)
    else:
        backend = create_backend()
    if id2vec is None:
        self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
    else:
        assert isinstance(id2vec, Id2Vec)
        self._id2vec = id2vec
    self._log.info("Loaded id2vec model: %s", self._id2vec)
    if df is None:
        self._df = DocumentFrequencies(log_level=verbosity).load(backend=backend)
    elif df is False:
        # df=False explicitly disables document frequencies.
        self._df = None
        self._log.warning("Disabled document frequencies - you will "
                          "not be able to query custom repositories.")
    else:
        assert isinstance(df, DocumentFrequencies)
        self._df = df
    if self._df is not None:
        self._df = self._df.prune(prune_df_threshold)
    self._log.info("Loaded document frequencies: %s", self._df)
    if nbow is None:
        self._nbow = NBOW(log_level=verbosity).load(backend=backend)
    else:
        assert isinstance(nbow, NBOW)
        self._nbow = nbow
    self._log.info("Loaded nBOW model: %s", self._nbow)
    self._repo2nbow = Repo2nBOW(self._id2vec, self._df, log_level=verbosity,
                                **(repo2nbow_kwargs or {}))
    # The nBOW model must have been built with this exact id2vec model.
    assert self._nbow.dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
    if len(self._id2vec) != self._nbow.matrix.shape[1]:
        raise ValueError("Models do not match: id2vec has %s tokens while nbow has %s"
                         % (len(self._id2vec), self._nbow.matrix.shape[1]))
    self._log.info("Creating the WMD engine...")
    self._wmd = WMD(self._id2vec.embeddings, self._nbow, verbosity=verbosity,
                    **(wmd_kwargs or {}))
    if wmd_cache_centroids:
        self._wmd.cache_centroids()

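# The prune_df_threshold knob above trims rare tokens from the document
# frequencies before they are used for weighting. A minimal sketch of the
# effect, assuming prune() returns a new, smaller model as the assignment
# above implies (DocumentFrequencies and create_backend as in the snippet):
df = DocumentFrequencies().load(backend=create_backend())
pruned = df.prune(2)  # drop tokens seen in fewer than 2 documents (assumed semantics)
assert len(pruned) <= len(df)  # assumes DocumentFrequencies defines __len__
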
import os
import unittest

import numpy

# Project-level imports; exact module paths are assumed here.
from ast2vec import NBOW
from ast2vec.tests import models as paths  # fixture locations (path assumed)


class NBOWTests(unittest.TestCase):
    def setUp(self):
        self.model = NBOW().load(
            source=os.path.join(os.path.dirname(__file__), paths.NBOW))

    def test_getitem(self):
        repo_name, indices, weights = self.model[0]
        self.assertEqual(repo_name, "ikizir/HohhaDynamicXOR")
        self.assertIsInstance(indices, numpy.ndarray)
        self.assertIsInstance(weights, numpy.ndarray)
        self.assertEqual(indices.shape, weights.shape)
        self.assertEqual(indices.shape, (85,))

    def test_iter(self):
        pumped = list(self.model)
        self.assertEqual(len(pumped), 1000)
        self.assertEqual(pumped, list(range(1000)))

    def test_len(self):
        self.assertEqual(len(self.model), 1000)

    def test_repository_index_by_name(self):
        self.assertEqual(
            self.model.repository_index_by_name("ikizir/HohhaDynamicXOR"), 0)

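# Conventional entry point so the test module above can be run directly;
# assumes the standard unittest runner.
if __name__ == "__main__":
    unittest.main()
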
def validate_asdf_file(obj, filename):
    # Shared assertion helper: obj is the calling unittest.TestCase instance.
    model = NBOW().load(filename)
    obj.assertIsInstance(model.repos, list)
    obj.assertGreater(len(model.repos[0]), 1)
    obj.assertEqual(model._matrix.shape, (1, 1000))
    obj.assertEqual(2, len(model.meta["dependencies"]))

import logging
import re

# Project-level imports; exact module paths are assumed here.
from ast2vec import DocumentFrequencies, Id2Vec, NBOW, Repo2nBOW
from modelforge.backends import create_backend
from wmd import WMD

from vecino.environment import initialize


class SimilarRepositories:
    # Matches GitHub repository URLs; group(2) is the "github.com/owner/repo"
    # part. The "." in ".git" must be escaped to match only the literal suffix.
    GITHUB_URL_RE = re.compile(
        r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|\.git|/)")

    def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
                 verbosity=logging.DEBUG, wmd_cache_centroids=True,
                 wmd_kwargs=None, gcs_bucket=None, repo2nbow_kwargs=None,
                 initialize_environment=True):
        if initialize_environment:
            initialize()
        self._log = logging.getLogger("similar_repos")
        self._log.setLevel(verbosity)
        if gcs_bucket:
            backend = create_backend(args="bucket=" + gcs_bucket)
        else:
            backend = create_backend()
        if id2vec is None:
            self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(id2vec, Id2Vec)
            self._id2vec = id2vec
        self._log.info("Loaded id2vec model: %s", self._id2vec)
        if df is None:
            self._df = DocumentFrequencies(log_level=verbosity).load(
                backend=backend)
        elif df is False:
            # df=False explicitly disables document frequencies.
            self._df = None
            self._log.warning("Disabled document frequencies - you will "
                              "not be able to query custom repositories.")
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        if self._df is not None:
            self._df = self._df.prune(prune_df_threshold)
        self._log.info("Loaded document frequencies: %s", self._df)
        if nbow is None:
            self._nbow = NBOW(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(nbow, NBOW)
            self._nbow = nbow
        self._log.info("Loaded nBOW model: %s", self._nbow)
        self._repo2nbow = Repo2nBOW(self._id2vec, self._df,
                                    log_level=verbosity,
                                    **(repo2nbow_kwargs or {}))
        # The nBOW model must have been built with this exact id2vec model.
        assert self._nbow.get_dependency(
            "id2vec")["uuid"] == self._id2vec.meta["uuid"]
        if len(self._id2vec) != self._nbow.matrix.shape[1]:
            raise ValueError(
                "Models do not match: id2vec has %s tokens while nbow has %s"
                % (len(self._id2vec), self._nbow.matrix.shape[1]))
        self._log.info("Creating the WMD engine...")
        self._wmd = WMD(self._id2vec.embeddings, self._nbow,
                        verbosity=verbosity, **(wmd_kwargs or {}))
        if wmd_cache_centroids:
            self._wmd.cache_centroids()

    def query(self, url_or_path_or_name, **kwargs):
        try:
            repo_index = self._nbow.repository_index_by_name(
                url_or_path_or_name)
        except KeyError:
            repo_index = -1
        if repo_index == -1:
            match = self.GITHUB_URL_RE.match(url_or_path_or_name)
            if match is not None:
                name = match.group(2)
                try:
                    repo_index = self._nbow.repository_index_by_name(name)
                except KeyError:
                    pass
        if repo_index >= 0:
            neighbours = self._query_domestic(repo_index, **kwargs)
        else:
            neighbours = self._query_foreign(url_or_path_or_name, **kwargs)
        neighbours = [(self._nbow[n[0]][0], n[1]) for n in neighbours]
        return neighbours

    @staticmethod
    def unicorn_query(repo_name, id2vec=None, nbow=None,
                      wmd_kwargs=None, query_wmd_kwargs=None):
        sr = SimilarRepositories(
            id2vec=id2vec, df=False, nbow=nbow,
            wmd_kwargs=wmd_kwargs or {"vocabulary_min": 50,
                                      "vocabulary_max": 500})
        return sr.query(
            repo_name,
            **(query_wmd_kwargs or {"early_stop": 0.1, "max_time": 180,
                                    "skipped_stop": 0.95}))

    def _query_domestic(self, repo_index, **kwargs):
        return self._wmd.nearest_neighbors(repo_index, **kwargs)

    def _query_foreign(self, url_or_path, **kwargs):
        if self._df is None:
            raise ValueError("Cannot query custom repositories if the "
                             "document frequencies are disabled.")
        nbow_dict = self._repo2nbow.convert_repository(url_or_path)
        words = sorted(nbow_dict.keys())
        weights = [nbow_dict[k] for k in words]
        return self._wmd.nearest_neighbors((words, weights), **kwargs)

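# A minimal usage sketch. It assumes the default id2vec/df/nbow models resolve
# through the default backend; query() forwards its keyword arguments to
# WMD.nearest_neighbors (k, early_stop, max_time, skipped_stop, as used in
# main() below). "ikizir/HohhaDynamicXOR" is the repository from the test
# fixture above.
sr = SimilarRepositories()
for repo, similarity in sr.query("ikizir/HohhaDynamicXOR", k=10):
    print("%48s\t%.2f" % (repo, similarity))
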
import argparse
import logging

# Project-level imports; exact module paths are assumed here.
from ast2vec import DocumentFrequencies, Id2Vec, NBOW
from ast2vec.repo2.base import Repo2Base
from modelforge.backends import create_backend

from vecino.environment import initialize
from vecino.similar_repositories import SimilarRepositories


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Repository URL or path or name.")
    parser.add_argument("--log-level", default="INFO",
                        choices=logging._nameToLevel,
                        help="Logging verbosity.")
    parser.add_argument("--id2vec", default=None,
                        help="id2vec model URL or path.")
    parser.add_argument("--df", default=None,
                        help="Document frequencies URL or path.")
    parser.add_argument("--nbow", default=None,
                        help="nBOW model URL or path.")
    parser.add_argument("--no-cache-centroids", action="store_true",
                        help="Do not cache WMD centroids.")
    parser.add_argument("--bblfsh", default=None,
                        help="Babelfish server address.")
    parser.add_argument("--timeout", type=int,
                        default=Repo2Base.DEFAULT_BBLFSH_TIMEOUT,
                        help="Babelfish timeout - longer requests are dropped.")
    parser.add_argument("--gcs", default=None, help="GCS bucket to use.")
    parser.add_argument("--linguist", default=None,
                        help="Path to github/linguist or src-d/enry.")
    parser.add_argument("--vocabulary-min", default=50, type=int,
                        help="Minimum number of words in a bag.")
    parser.add_argument("--vocabulary-max", default=500, type=int,
                        help="Maximum number of words in a bag.")
    parser.add_argument("-n", "--nnn", default=10, type=int,
                        help="Number of nearest neighbours.")
    parser.add_argument("--early-stop", default=0.1, type=float,
                        help="Maximum fraction of the nBOW dataset to scan.")
    parser.add_argument("--max-time", default=300, type=int,
                        help="Maximum time to spend scanning in seconds.")
    parser.add_argument("--skipped-stop", default=0.95, type=float,
                        help="Minimum fraction of skipped samples to stop.")
    args = parser.parse_args()
    if args.linguist is None:
        args.linguist = "./enry"
    initialize(args.log_level, enry=args.linguist)
    if args.gcs:
        backend = create_backend(args="bucket=" + args.gcs)
    else:
        backend = create_backend()
    if args.id2vec is not None:
        args.id2vec = Id2Vec(source=args.id2vec, backend=backend)
    if args.df is not None:
        args.df = DocumentFrequencies(source=args.df, backend=backend)
    if args.nbow is not None:
        args.nbow = NBOW(source=args.nbow, backend=backend)
    sr = SimilarRepositories(
        id2vec=args.id2vec, df=args.df, nbow=args.nbow,
        verbosity=args.log_level,
        wmd_cache_centroids=not args.no_cache_centroids,
        gcs_bucket=args.gcs,
        repo2nbow_kwargs={"linguist": args.linguist,
                          "bblfsh_endpoint": args.bblfsh,
                          "timeout": args.timeout},
        wmd_kwargs={"vocabulary_min": args.vocabulary_min,
                    "vocabulary_max": args.vocabulary_max})
    neighbours = sr.query(args.input, k=args.nnn, early_stop=args.early_stop,
                          max_time=args.max_time,
                          skipped_stop=args.skipped_stop)
    for index, rate in neighbours:
        print("%48s\t%.2f" % (index, rate))


if __name__ == "__main__":
    main()
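# For one-shot lookups, unicorn_query (defined on SimilarRepositories above)
# wraps the same flow with document frequencies disabled, so only
# repositories already present in the nBOW model can be queried. A sketch:
for repo, similarity in SimilarRepositories.unicorn_query(
        "ikizir/HohhaDynamicXOR"):
    print(repo, similarity)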