Example #1
    def test_create_spark(self):
        spark = create_spark("test_1")
        self.assertEqual(spark.read.parquet(PARQUET_DIR).count(), 6)
        spark.stop()
        spark = create_spark("test_2",
                             dep_zip=False,
                             config=["spark.rpc.retry.wait=10s"])
        self.assertEqual(spark.conf.get("spark.rpc.retry.wait"), "10s")
        spark.stop()
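
The body of create_spark itself is not shown in these examples. A minimal sketch consistent with the keyword arguments exercised above (a session name, dep_zip, and config entries of the form "key=value") could look like the following; the name create_spark_sketch, the default master, and the log level are assumptions, not the library's actual implementation.

from pyspark.sql import SparkSession


def create_spark_sketch(session_name, spark="local[*]", config=(),
                        spark_log_level="WARN", **_ignored):
    # Hypothetical, simplified stand-in for create_spark: builds a SparkSession
    # and applies "key=value" config strings such as "spark.rpc.retry.wait=10s".
    builder = SparkSession.builder.appName(session_name).master(spark)
    for entry in config:
        key, value = entry.split("=", 1)
        builder = builder.config(key, value)
    session = builder.getOrCreate()
    session.sparkContext.setLogLevel(spark_log_level)
    return session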
Example #2
def merge_coocc_spark(df, filepaths, log, args):
    session_name = "merge_coocc-%s" % uuid4()
    session = create_spark(session_name, **filter_kwargs(args.__dict__, create_spark))
    spark_context = session.sparkContext
    global_index = spark_context.broadcast(df.order)

    coocc_rdds = []

    def local_to_global(local_index):
        """
        Converts token index of co-occurrence matrix to the common index.
        For example index, 5 correspond to `get` token for a current model.
        And `get` have index 7 in the result.
        So we convert 5 to `get` via tokens list and `get` to 7 via global_index mapping.
        If global_index do not have `get` token, it returns -1.
        """
        return global_index.value.get(tokens.value[local_index], -1)

    for path, coocc in load_and_check(filepaths, log):
        rdd = coocc.matrix_to_rdd(spark_context)  # rdd structure: ((row, col), weight)
        log.info("Broadcasting tokens order for %s model...", path)
        tokens = spark_context.broadcast(coocc.tokens)
        coocc_rdds.append(
            rdd.map(lambda row: ((local_to_global(row[0][0]),
                                  local_to_global(row[0][1])),
                                 np.uint32(row[1])))
               .filter(lambda row: row[0][0] >= 0))

    log.info("Calculating the union of cooccurrence matrices...")
    rdd = spark_context \
        .union(coocc_rdds) \
        .reduceByKey(lambda x, y: min(MAX_INT32, x + y))
    CooccModelSaver(args.output, df)(rdd)
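
filter_kwargs is not defined in this excerpt. Given how it is called above (narrowing an argparse namespace to the keyword arguments that create_spark accepts), a hypothetical helper might be:

import inspect


def filter_kwargs(kwargs, func):
    # Hypothetical helper: keep only the entries whose keys match parameters of
    # `func`, so a full argparse namespace can be forwarded as **kwargs safely.
    accepted = inspect.signature(func).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}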
Example #3
def create_engine(session_name,
                  repositories,
                  repository_format=EngineDefault.REPOSITORY_FORMAT,
                  bblfsh=EngineDefault.BBLFSH,
                  engine=EngineDefault.VERSION,
                  config=SparkDefault.CONFIG,
                  packages=SparkDefault.JAR_PACKAGES,
                  spark=SparkDefault.MASTER_ADDRESS,
                  spark_local_dir=SparkDefault.LOCAL_DIR,
                  spark_log_level=SparkDefault.LOG_LEVEL,
                  dep_zip=SparkDefault.DEP_ZIP,
                  memory=SparkDefault.MEMORY):

    config += (get_bblfsh_dependency(bblfsh), )
    packages += (get_engine_package(engine), )
    session = create_spark(session_name,
                           spark=spark,
                           spark_local_dir=spark_local_dir,
                           config=config,
                           packages=packages,
                           spark_log_level=spark_log_level,
                           dep_zip=dep_zip,
                           memory=memory)
    logging.getLogger("engine").info("Initializing engine on %s", repositories)
    return Engine(session, repositories, repository_format)
Example #4
def create_engine(session_name,
                  repositories,
                  repository_format="siva",
                  bblfsh=None,
                  engine=None,
                  config=SparkDefault.CONFIG,
                  packages=SparkDefault.PACKAGES,
                  spark=SparkDefault.MASTER_ADDRESS,
                  spark_local_dir=SparkDefault.LOCAL_DIR,
                  spark_log_level=SparkDefault.LOG_LEVEL,
                  memory=SparkDefault.MEMORY,
                  dep_zip=False):
    if not bblfsh:
        bblfsh = "localhost"
    if not engine:
        engine = get_engine_version()
    config = assemble_spark_config(config=config, memory=memory)
    add_engine_dependencies(engine=engine, config=config, packages=packages)
    add_bblfsh_dependencies(bblfsh=bblfsh, config=config)
    session = create_spark(session_name,
                           spark=spark,
                           spark_local_dir=spark_local_dir,
                           config=config,
                           packages=packages,
                           spark_log_level=spark_log_level,
                           dep_zip=dep_zip)
    log = logging.getLogger("engine")
    log.info("Initializing on %s", repositories)
    engine = Engine(session, repositories, repository_format)
    return engine
Example #5
    def test_error(self):
        with self.assertRaises(ValueError):
            create_or_load_ordered_df(argparse.Namespace(docfreq_in=None), 10, None)

        with self.assertRaises(ValueError):
            session = create_spark("test_df_util")
            uast_extractor = ParquetLoader(session, paths.PARQUET_DIR) \
                .link(Moder("file")) \
                .link(UastRow2Document()) \
                .link(UastDeserializer()) \
                .link(Uast2BagFeatures(IdentifiersBagExtractor()))
            create_or_load_ordered_df(argparse.Namespace(docfreq_in=None), None, uast_extractor)
Example #6
    def test_create(self):
        session = create_spark("test_df_util")
        uast_extractor = ParquetLoader(session, paths.PARQUET_DIR) \
            .link(Moder("file")) \
            .link(UastRow2Document())
        ndocs = uast_extractor.link(Counter()).execute()
        uast_extractor = uast_extractor.link(UastDeserializer()) \
            .link(Uast2BagFeatures(IdentifiersBagExtractor()))
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = os.path.join(tmpdir, "df.asdf")
            args = argparse.Namespace(docfreq_in=None, docfreq_out=tmp_path, min_docfreq=1,
                                      vocabulary_size=1000)
            df_model = create_or_load_ordered_df(args, ndocs, uast_extractor)
            self.assertEqual(df_model.docs, ndocs)
            self.assertTrue(os.path.exists(tmp_path))
Example #7
    def test_create(self):
        session = create_spark("test_quant_util")
        extractor = ChildrenBagExtractor()
        with tempfile.NamedTemporaryFile(mode="r+b", suffix="-quant.asdf") as tmp:
            path = tmp.name
            uast_extractor = ParquetLoader(session, paths.PARQUET_DIR) \
                .link(Moder("file")) \
                .link(UastRow2Document()) \
                .link(UastDeserializer())
            create_or_apply_quant(path, [extractor], uast_extractor)
            self.assertIsNotNone(extractor.levels)
            self.assertTrue(os.path.exists(path))
            model_levels = QuantizationLevels().load(source=path)._levels["children"]
            for key in model_levels:
                self.assertListEqual(list(model_levels[key]), list(extractor.levels[key]))
Example #8
def create_parquet_loader(session_name, repositories,
                          config=SparkDefault.CONFIG,
                          packages=SparkDefault.JAR_PACKAGES,
                          spark=SparkDefault.MASTER_ADDRESS,
                          spark_local_dir=SparkDefault.LOCAL_DIR,
                          spark_log_level=SparkDefault.LOG_LEVEL,
                          memory=SparkDefault.MEMORY,
                          dep_zip=SparkDefault.DEP_ZIP):
    config += get_spark_memory_config(memory)
    session = create_spark(session_name, spark=spark, spark_local_dir=spark_local_dir,
                           config=config, packages=packages, spark_log_level=spark_log_level,
                           dep_zip=dep_zip)
    log = logging.getLogger("parquet")
    log.info("Initializing on %s", repositories)
    parquet = ParquetLoader(session, repositories)
    return parquet
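
The returned ParquetLoader is normally chained with the transformers shown in the test examples above; a hedged usage sketch, with a placeholder repository path:

# Hypothetical usage; the path is illustrative only.
loader = create_parquet_loader("parquet-session", "/data/uasts.parquet")
pipeline = loader.link(Moder("file")) \
                 .link(UastRow2Document()) \
                 .link(UastDeserializer())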
Example #9
def create_engine(session_name,
                  repositories,
                  repository_format="siva",
                  bblfsh=None,
                  engine=None,
                  config=SparkDefault.CONFIG,
                  packages=SparkDefault.PACKAGES,
                  spark=SparkDefault.MASTER_ADDRESS,
                  spark_local_dir=SparkDefault.LOCAL_DIR,
                  spark_log_level=SparkDefault.LOG_LEVEL,
                  memory=SparkDefault.MEMORY,
                  dep_zip=False):
    if not bblfsh:
        bblfsh = "localhost"
    if not engine:
        try:
            engine = get_distribution("sourced-engine").version
        except DistributionNotFound:
            log = logging.getLogger("engine_version")
            engine = requests.get("https://api.github.com/repos/src-d/engine/releases/latest") \
                .json()["tag_name"].replace("v", "")
            log.warning(
                "Engine not found, queried GitHub to get the latest release tag (%s)",
                engine)
    config = assemble_spark_config(config=config, memory=memory)
    add_engine_dependencies(engine=engine, config=config, packages=packages)
    add_bblfsh_dependencies(bblfsh=bblfsh, config=config)
    session = create_spark(session_name,
                           spark=spark,
                           spark_local_dir=spark_local_dir,
                           config=config,
                           packages=packages,
                           spark_log_level=spark_log_level,
                           dep_zip=dep_zip)
    log = logging.getLogger("engine")
    log.info("Initializing on %s", repositories)
    engine = Engine(session, repositories, repository_format)
    return engine
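
A typical invocation of this create_engine variant, with placeholder arguments (the repository path is illustrative and not taken from the examples):

# Hypothetical call; repository_format and bblfsh mirror the defaults in the signature above.
engine = create_engine("engine-session", "/data/siva-repositories",
                       repository_format="siva", bblfsh="localhost")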
Example #10
def create_spark_for_test(name="test"):
    if sys.version_info >= (3, 7):
        raise SkipTest("Python 3.7 is not yet supported.")
    packages = (get_engine_package(get_engine_version()),)
    config = (get_bblfsh_dependency("localhost"),)
    return create_spark(name, config=config, packages=packages)
    def setUp(self):
        self.spark = create_spark("test")