def test_obj(self):
    basedir = os.path.dirname(__file__)
    id2vec = Id2Vec().load(os.path.join(basedir, ID2VEC))
    df = DocumentFrequencies().load(os.path.join(basedir, DOCFREQ))
    df._df["xxyyzz"] = 10
    id2vec._token2index[id2vec.tokens[0]] = 1
    id2vec.tokens[0] = "xxyyzz"
    id2vec._token2index["xxyyzz"] = 0
    xxyyzz = Repo2nBOW(id2vec=id2vec, docfreq=df, linguist=tests.ENRY)
    nbow = xxyyzz.convert_repository(os.path.join(basedir, "..", ".."))
    self.assertIsInstance(nbow, dict)
    self.assertAlmostEqual(nbow[0], 3.192060730416365)
def __init__(self, id2vec=None, df=None, nbow=None, verbosity=logging.DEBUG,
             wmd_cache_centroids=True, wmd_kwargs=None, gcs_bucket=None,
             repo2nbow_kwargs=None, initialize_environment=True):
    if initialize_environment:
        initialize()
    self._log = logging.getLogger("similar_repos")
    self._log.setLevel(verbosity)
    if gcs_bucket:
        backend = create_backend(args="bucket=" + gcs_bucket)
    else:
        backend = create_backend()
    # id2vec: the token embeddings.
    if id2vec is None:
        self._id2vec = Id2Vec(log_level=verbosity, backend=backend)
    else:
        assert isinstance(id2vec, Id2Vec)
        self._id2vec = id2vec
    self._log.info("Loaded id2vec model: %s", self._id2vec)
    # Document frequencies: pass df=False to disable querying custom repositories.
    if df is not False:
        if df is None:
            self._df = DocumentFrequencies(log_level=verbosity, backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
        self._log.info("Loaded document frequencies: %s", self._df)
    else:
        self._df = None
        self._log.warning("Disabled document frequencies - you will "
                          "not be able to query custom repositories.")
    # nBOW: the dataset of per-repository bags of words.
    if nbow is None:
        self._nbow = NBOW(log_level=verbosity, backend=backend)
    else:
        assert isinstance(nbow, NBOW)
        self._nbow = nbow
    self._log.info("Loaded nBOW model: %s", self._nbow)
    self._repo2nbow = Repo2nBOW(self._id2vec, self._df, log_level=verbosity,
                                **(repo2nbow_kwargs or {}))
    self._log.info("Creating the WMD engine...")
    self._wmd = WMD(self._id2vec.embeddings, self._nbow, verbosity=verbosity,
                    **(wmd_kwargs or {}))
    if wmd_cache_centroids:
        self._wmd.cache_centroids()
def test_asdf(self):
    basedir = os.path.dirname(__file__)
    id2vec = Id2Vec().load(os.path.join(basedir, ID2VEC))
    del id2vec._token2index[id2vec.tokens[0]]
    id2vec.tokens[0] = "test"
    id2vec._token2index["test"] = 0
    df = DocumentFrequencies().load(os.path.join(basedir, DOCFREQ))
    df._df["test"] = 10
    with tempfile.NamedTemporaryFile() as file:
        args = argparse.Namespace(
            id2vec=id2vec, docfreq=df, linguist=tests.ENRY, gcs_bucket=None,
            output=file.name, bblfsh_endpoint=None, timeout=None,
            repository=os.path.join(basedir, "..", ".."), prune_df=1)
        repo2nbow_entry(args)
        self.assertTrue(os.path.isfile(file.name))
        validate_asdf_file(self, file.name)
def __init__(self, id2vec=None, df=None, nbow=None, prune_df_threshold=1,
             verbosity=logging.DEBUG, wmd_cache_centroids=True, wmd_kwargs=None,
             gcs_bucket=None, repo2nbow_kwargs=None, initialize_environment=True):
    if initialize_environment:
        initialize()
    self._log = logging.getLogger("similar_repos")
    self._log.setLevel(verbosity)
    if gcs_bucket:
        backend = create_backend(args="bucket=" + gcs_bucket)
    else:
        backend = create_backend()
    # id2vec: the token embeddings.
    if id2vec is None:
        self._id2vec = Id2Vec(log_level=verbosity).load(backend=backend)
    else:
        assert isinstance(id2vec, Id2Vec)
        self._id2vec = id2vec
    self._log.info("Loaded id2vec model: %s", self._id2vec)
    # Document frequencies: pass df=False to disable querying custom repositories.
    if df is not False:
        if df is None:
            self._df = DocumentFrequencies(log_level=verbosity).load(backend=backend)
        else:
            assert isinstance(df, DocumentFrequencies)
            self._df = df
    else:
        self._df = None
        self._log.warning("Disabled document frequencies - you will "
                          "not be able to query custom repositories.")
    if self._df is not None:
        self._df = self._df.prune(prune_df_threshold)
    self._log.info("Loaded document frequencies: %s", self._df)
    # nBOW: the dataset of per-repository bags of words.
    if nbow is None:
        self._nbow = NBOW(log_level=verbosity).load(backend=backend)
    else:
        assert isinstance(nbow, NBOW)
        self._nbow = nbow
    self._log.info("Loaded nBOW model: %s", self._nbow)
    self._repo2nbow = Repo2nBOW(
        self._id2vec, self._df, log_level=verbosity, **(repo2nbow_kwargs or {}))
    # The nBOW dataset must have been produced with the same id2vec model.
    assert self._nbow.dep("id2vec")["uuid"] == self._id2vec.meta["uuid"]
    if len(self._id2vec) != self._nbow.matrix.shape[1]:
        raise ValueError("Models do not match: id2vec has %s tokens while nbow has %s" %
                         (len(self._id2vec), self._nbow.matrix.shape[1]))
    self._log.info("Creating the WMD engine...")
    self._wmd = WMD(self._id2vec.embeddings, self._nbow, verbosity=verbosity,
                    **(wmd_kwargs or {}))
    if wmd_cache_centroids:
        self._wmd.cache_centroids()
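# Hedged usage sketch (not part of the original module): this constructor belongs to
# the SimilarRepositories class that main() below instantiates, so a typical
# programmatic session looks roughly like the following. The repository URL and the
# choice of k are placeholders; query() returns (index, rate) pairs as used in main().
def _example_similar_repositories_query():
    sr = SimilarRepositories(wmd_cache_centroids=False)
    for index, rate in sr.query("https://github.com/src-d/ast2vec", k=5):
        print("%s\t%.2f" % (index, rate))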
def test_transform(self):
    basedir = os.path.dirname(__file__)
    id2vec = Id2Vec().load(os.path.join(basedir, ID2VEC))
    del id2vec._token2index[id2vec.tokens[0]]
    id2vec.tokens[0] = "test"
    id2vec._token2index["test"] = 0
    df = DocumentFrequencies().load(os.path.join(basedir, DOCFREQ))
    df._df["test"] = 10
    with tempfile.TemporaryDirectory() as tmpdir:
        repo2nbow = Repo2nBOWTransformer(
            id2vec=id2vec, docfreq=df, linguist=tests.ENRY, gcs_bucket=None)
        outfile = repo2nbow.prepare_filename(basedir, tmpdir)
        status = repo2nbow.transform(repos=basedir, output=tmpdir)
        self.assertTrue(os.path.isfile(outfile))
        validate_asdf_file(self, outfile)
        self.assertEqual(status, 1)
def setUp(self):
    self.model = Id2Vec().load(
        source=os.path.join(os.path.dirname(__file__), paths.ID2VEC))
def check_postproc_results(obj, id2vec_loc):
    id2vec = Id2Vec().load(source=id2vec_loc)
    obj.assertEqual(len(id2vec.tokens), obj.VOCAB)
    obj.assertEqual(id2vec.embeddings.shape, (obj.VOCAB, 50))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Repository URL or path or name.")
    parser.add_argument("--log-level", default="INFO", choices=logging._nameToLevel,
                        help="Logging verbosity.")
    parser.add_argument("--id2vec", default=None, help="id2vec model URL or path.")
    parser.add_argument("--df", default=None, help="Document frequencies URL or path.")
    parser.add_argument("--nbow", default=None, help="nBOW model URL or path.")
    parser.add_argument("--no-cache-centroids", action="store_true",
                        help="Do not cache WMD centroids.")
    parser.add_argument("--bblfsh", default=None, help="Babelfish server address.")
    parser.add_argument("--timeout", type=int, default=Repo2Base.DEFAULT_BBLFSH_TIMEOUT,
                        help="Babelfish timeout - longer requests are dropped.")
    parser.add_argument("--gcs", default=None, help="GCS bucket to use.")
    parser.add_argument("--linguist", default=None,
                        help="Path to github/linguist or src-d/enry.")
    parser.add_argument("--vocabulary-min", default=50, type=int,
                        help="Minimum number of words in a bag.")
    parser.add_argument("--vocabulary-max", default=500, type=int,
                        help="Maximum number of words in a bag.")
    parser.add_argument("-n", "--nnn", default=10, type=int,
                        help="Number of nearest neighbours.")
    parser.add_argument("--early-stop", default=0.1, type=float,
                        help="Maximum fraction of the nBOW dataset to scan.")
    parser.add_argument("--max-time", default=300, type=int,
                        help="Maximum time to spend scanning in seconds.")
    parser.add_argument("--skipped-stop", default=0.95, type=float,
                        help="Minimum fraction of skipped samples to stop.")
    args = parser.parse_args()
    if args.linguist is None:
        args.linguist = "./enry"
    initialize(args.log_level, enry=args.linguist)
    if args.gcs:
        backend = create_backend(args="bucket=" + args.gcs)
    else:
        backend = create_backend()
    if args.id2vec is not None:
        args.id2vec = Id2Vec(source=args.id2vec, backend=backend)
    if args.df is not None:
        args.df = DocumentFrequencies(source=args.df, backend=backend)
    if args.nbow is not None:
        args.nbow = NBOW(source=args.nbow, backend=backend)
    sr = SimilarRepositories(
        id2vec=args.id2vec, df=args.df, nbow=args.nbow, verbosity=args.log_level,
        wmd_cache_centroids=not args.no_cache_centroids, gcs_bucket=args.gcs,
        repo2nbow_kwargs={"linguist": args.linguist,
                          "bblfsh_endpoint": args.bblfsh,
                          "timeout": args.timeout},
        wmd_kwargs={"vocabulary_min": args.vocabulary_min,
                    "vocabulary_max": args.vocabulary_max})
    neighbours = sr.query(args.input, k=args.nnn, early_stop=args.early_stop,
                          max_time=args.max_time, skipped_stop=args.skipped_stop)
    for index, rate in neighbours:
        print("%48s\t%.2f" % (index, rate))
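# Hedged note (not in the original excerpt): main() is written as a console entry
# point. Assuming the file is importable as the similar_repos module, it can be run
# roughly like this, with the repository URL and -n value as placeholders:
#
#     python3 -m similar_repos https://github.com/src-d/ast2vec -n 5
#
if __name__ == "__main__":
    main()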