Example #1
def _test():
    import doctest
    import os
    import tempfile
    import py4j
    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession, Row
    import pyspark.sql.readwriter

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.sql.readwriter.__dict__.copy()
    sc = SparkContext("local[4]", "PythonTest")
    try:
        spark = SparkSession.builder.enableHiveSupport().getOrCreate()
    except py4j.protocol.Py4JError:
        spark = SparkSession(sc)

    globs["tempfile"] = tempfile
    globs["os"] = os
    globs["sc"] = sc
    globs["spark"] = spark
    globs["df"] = spark.read.parquet("python/test_support/sql/parquet_partitioned")
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.readwriter,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
    )
    sc.stop()
    if failure_count:
        exit(-1)
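For context, runners like this are normally wired up at the end of the module so the doctests execute when the file is run directly; a typical guard (not shown in the snippet above) looks like:

if __name__ == "__main__":
    _test()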
Example #2
class PMMLTest(TestCase):
    def setUp(self):
        self.sc = SparkContext()
        self.sqlContext = SQLContext(self.sc)

    def tearDown(self):
        self.sc.stop()

    def testWorkflow(self):
        df = self.sqlContext.read.csv(os.path.join(os.path.dirname(__file__),
                                                   "resources/Iris.csv"),
                                      header=True,
                                      inferSchema=True)

        formula = RFormula(formula="Species ~ .")
        classifier = DecisionTreeClassifier()
        pipeline = Pipeline(stages=[formula, classifier])
        pipelineModel = pipeline.fit(df)

        pmmlBuilder = PMMLBuilder(self.sc, df, pipelineModel) \
         .putOption(classifier, "compact", True)
        pmmlBytes = pmmlBuilder.buildByteArray()
        pmmlString = pmmlBytes.decode("UTF-8")
        self.assertTrue(
            pmmlString.find(
                "<PMML xmlns=\"http://www.dmg.org/PMML-4_3\" version=\"4.3\">")
            > -1)
Example #3
class SparkTestingBaseTestCase(unittest2.TestCase):

    """Basic common test case for Spark. Provides a Spark context as sc.
    For non-local mode testing you can either override sparkMaster
    or set the environment property SPARK_MASTER."""

    @classmethod
    def getMaster(cls):
        return os.getenv('SPARK_MASTER', "local[4]")

    def setUp(self):
        """Setup a basic Spark context for testing"""
        self.sc = SparkContext(self.getMaster())
        self.sql_context = HiveContext(self.sc)
        quiet_py4j()

    def tearDown(self):
        """
        Tear down the basic panda spark test case. This stops the running
        context and does a hack to prevent Akka rebinding on the same port.
        """
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")
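As the docstring above notes, the master URL comes from the SPARK_MASTER environment variable (defaulting to local[4]). A minimal sketch of how such a base class might be used; the subclass and test below are illustrative, not part of the original example:

# Hypothetical subclass; self.sc is created by SparkTestingBaseTestCase.setUp().
class WordLengthTest(SparkTestingBaseTestCase):
    def test_word_lengths(self):
        rdd = self.sc.parallelize(["spark", "test"])
        self.assertEqual(rdd.map(len).collect(), [5, 4])

Exporting SPARK_MASTER=spark://host:7077 before running the suite would point the same tests at a standalone cluster instead of local mode.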
Example #4
def _test():
    import doctest
    import os
    import tempfile
    import py4j
    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession, Row
    import pyspark.sql.readwriter

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.sql.readwriter.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    try:
        spark = SparkSession.builder.enableHiveSupport().getOrCreate()
    except py4j.protocol.Py4JError:
        spark = SparkSession(sc)

    globs['tempfile'] = tempfile
    globs['os'] = os
    globs['sc'] = sc
    globs['spark'] = spark
    globs['df'] = spark.read.parquet(
        'python/test_support/sql/parquet_partitioned')
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.readwriter,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
        | doctest.REPORT_NDIFF)
    sc.stop()
    if failure_count:
        exit(-1)
Example #5
class SparkTestCase(unittest.TestCase):
    def resourceFile(self, filename, module='adam-core'):

        adamRoot = os.path.dirname(os.getcwd())
        return os.path.join(
            os.path.join(adamRoot, "%s/src/test/resources" % module), filename)

    def tmpFile(self):

        tempFile = tempfile.NamedTemporaryFile(delete=True)
        tempFile.close()
        return tempFile.name

    def checkFiles(self, file1, file2):

        f1 = open(file1)
        f2 = open(file2)

        try:
            self.assertEqual(f1.read(), f2.read())
        finally:
            f1.close()
            f2.close()

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
Example #6
def _test():
    import doctest
    import os
    import tempfile
    import py4j
    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession, Row
    import pyspark.sql.readwriter

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.sql.readwriter.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    try:
        spark = SparkSession.builder.enableHiveSupport().getOrCreate()
    except py4j.protocol.Py4JError:
        spark = SparkSession(sc)

    globs['tempfile'] = tempfile
    globs['os'] = os
    globs['sc'] = sc
    globs['spark'] = spark
    globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
    globs['sdf'] = \
        spark.read.format('text').stream('python/test_support/sql/streaming')

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.readwriter, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    sc.stop()
    if failure_count:
        exit(-1)
Example #7
class PySparkTestCase(unittest.TestCase):
    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name, batchSize=2)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
Example #8
class PyVertexRDDTestCase(unittest.TestCase):
    """
    Test collect, take, count, mapValues, diff,
    filter, mapVertexPartitions, innerJoin and leftJoin
    for VertexRDD
    """

    def setUp(self):
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.default.parallelism", 1)
        self.sc = SparkContext(appName=class_name, conf=conf)
        self.sc.setCheckpointDir("/tmp")

    def tearDown(self):
        self.sc.stop()

    def collect(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.take(1)
        self.assertEqual(results, [(3, ("rxin", "student"))])

    def take(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, [(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])

    def count(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.count()
        self.assertEqual(results, 2)

    def mapValues(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.mapValues(lambda x: x + ":" + x)
        self.assertEqual(results, [(3, ("rxin:rxin", "student:student")),
                                   (7, ("jgonzal:jgonzal", "postdoc:postdoc"))])

    def innerJoin(self):
        vertexData0 = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertexData1 = self.sc.parallelize([(1, ("rxin", "student")), (2, ("jgonzal", "postdoc"))])
        vertices0 = VertexRDD(vertexData0)
        vertices1 = VertexRDD(vertexData1)
        results = vertices0.innerJoin(vertices1).collect()
        self.assertEqual(results, [])

    def leftJoin(self):
        vertexData0 = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertexData1 = self.sc.parallelize([(1, ("rxin", "student")), (2, ("jgonzal", "postdoc"))])
        vertices0 = VertexRDD(vertexData0)
        vertices1 = VertexRDD(vertexData1)
        results = vertices0.diff(vertices1)
        self.assertEqual(results, 2)
Example #9
def main():
    parser = argparse.ArgumentParser(description="Find Dependency inclusions")
    parser.add_argument('--path', type=str)
    parser.add_argument('--cores', type=str)
    args = parser.parse_args()

    # Spark properties must be set on a SparkConf before the SparkContext is
    # created; calling sc.getConf().set(...) afterwards has no effect.
    conf = SparkConf() \
        .set("spark.executor.cores", args.cores) \
        .set("spark.driver.cores", args.cores) \
        .set("spark.worker.cores", args.cores) \
        .set("spark.deploy.defaultCores", args.cores) \
        .set("spark.driver.memory", "15g")
    sc = SparkContext(appName="DDM", conf=conf)
    global number_of_columns
    data = []
    file_headers = []
    for file in os.listdir(args.path):
        if file.endswith(".csv"):
            rdd = sc.textFile(os.path.join(args.path, file)).map(lambda line: line[1:-1].split("\";\""))

            file_data = rdd.collect()
            file_header = file_data[0]
            del file_data[0]
            file_data = [(number_of_columns, x) for x in file_data]
            data += file_data
            file_headers += file_header
            number_of_columns = number_of_columns + len(file_header)

    header_dummies = list(range(0, number_of_columns))
    rdd = sc.parallelize(data)
    values_as_key = rdd.flatMap(lambda el: list(zip(el[1], range(el[0], el[0] + len(el[1])))))
    unique_values = values_as_key.map(lambda x: (x[0], x[1])).groupByKey().mapValues(set)
    unique_values = unique_values.map(lambda x: (tuple(x[1]), 0)).reduceByKey(sum_func)
    matrix_per_key = unique_values.map(lambda x: make_candidate_matrix(x[0]))
    result_matrix = matrix_per_key.reduce(lambda x, y: matrix_and(x, y))

    assert len(result_matrix) == number_of_columns

    output = []
    for i in range(0, number_of_columns):
        assert len(result_matrix[i]) == number_of_columns
        output.append([])

    for i in range(0, len(result_matrix)):
        for j in range(0, len(result_matrix[i])):
            if i != j and result_matrix[i][j]:
                output[j].append(file_headers[i])

    for i in range(0, len(output)):
        row = output[i]
        if len(row) != 0:
            output_string = str(row[0])
            for j in range(1, len(row)):
                output_string += (", " + str(row[j]))
            print(str(file_headers[i]) + " < " + output_string)

    sc.stop()
Example #10
def main(inputfolderpath, outputfolderpath, jobname):
    #inputfolderpath = "hdfs://santa-fe:47001/Source-Recommendation-System/FakeNewsCorpus/news_cleaned_2018_02_13.csv"
    #inputfolderpath = "hdfs://santa-fe:47001/FakeNewsCorpus/news_cleaned_2018_02_13.csv"
    #inputfolderpath = "hdfs://santa-fe:47001/FakeNewsCorpus-Outputs/news_cleaned_partitioned/news_cleaned_2018_02_1300000"
    #inputfolderpath = "hdfs://santa-fe:47001/Source-Recommendation-System/FakeNewsCorpus/news_sample.csv"
    #outputfolderpath = "hdfs://santa-fe:47001/Source-Recommendation-System/FakeNewsCorpus-Outputs"
    #outputfolderpath = "hdfs://santa-fe:47001/FakeNewsCorpus-Outputs/KeywordsFromPartitions/news_cleaned_partitioned/news_cleaned_2018_02_1300000temp"
    title_score = 10
    keywords_score = 13
    meta_keywords_score = 13
    meta_description_score = 13
    tags_score = 13
    summary_score = 10
    #spark = SparkSession.builder.appName(jobname).getOrCreate()
    sc = SparkContext(master="spark://santa-fe.cs.colostate.edu:47002", appName=jobname)
    delete_path(sc, outputfolderpath)
    sqlContext = SQLContext(sc)
    inputfile_rdd = sqlContext.read.csv(inputfolderpath, header=True,sep=",", multiLine = True, quote='"', escape='"')\
        .rdd.repartition(29)
    keywords_from_content = inputfile_rdd\
        .filter(lambda row : row["content"] is not None and row["content"] != "null")\
        .map(lambda  row : extract_with_row_id(row["id"], row["content"]))\
        .flatMap(lambda xs: [(x) for x in xs])
    keywords_from_title = inputfile_rdd\
        .filter(lambda row : row["title"] is not None and row["title"] != "null")\
        .map(lambda row : [(x,"(" + str(row["id"]) + "," + str(title_score) + ")") for x in get_processed_words(row["title"])])\
        .flatMap(lambda xs: [(x) for x in xs])
    keywords_from_keywords_col = inputfile_rdd\
        .filter(lambda row : row["keywords"] is not None and row["keywords"] != "null")\
        .map(lambda row : [(x.lower(),"(" + str(row["id"]) + "," + str(keywords_score) + ")") for x in get_keywords_from_keywords_col(row["keywords"])])\
        .flatMap(lambda xs: [(x) for x in xs])
    keywords_from_meta_keywords = inputfile_rdd\
        .filter(lambda row : row["meta_keywords"] is not None and row["meta_keywords"] != "null")\
        .map(lambda row : [(x.lower(),"(" + str(row["id"]) + "," + str(meta_keywords_score) + ")") for x in parse_meta_keywords(row["meta_keywords"]) if len(x) > 1 ])\
        .flatMap(lambda xs: [(x) for x in xs])
    keywords_from_meta_description = inputfile_rdd\
        .filter(lambda row : row["meta_description"] is not None and row["meta_description"] != "null")\
        .map(lambda row : [(x, "(" + str(row["id"]) + "," + str(meta_description_score) + ")") for x in get_processed_words(row["meta_description"])])\
        .flatMap(lambda xs: [(x) for x in xs])
    keywords_from_tags = inputfile_rdd\
        .filter(lambda row : row["tags"] is not None and row["tags"] != "null")\
        .map(lambda row : [(x.lower(), "(" + str(row["id"]) + "," + str(tags_score) + ")") for x in str(row["tags"].encode('ascii', "ignore")).split(",") ])\
        .flatMap(lambda xs: [(x) for x in xs])
    keywords_from_summary = inputfile_rdd\
        .filter(lambda row : row["summary"] is not None and row["summary"] != "null")\
        .map(lambda  row : extract_with_row_id(row["id"], row["summary"]))\
        .flatMap(lambda xs: [(x) for x in xs])
    all_keywords_list = [keywords_from_content, keywords_from_title, keywords_from_keywords_col, keywords_from_meta_keywords,
        keywords_from_meta_description, keywords_from_tags, keywords_from_summary]
    all_keywords_rdd = sc.union(all_keywords_list)
    all_keywords_rdd = all_keywords_rdd\
        .filter(lambda row: len(row[0]) > 2)\
        .reduceByKey(concat)
    all_keywords_df = all_keywords_rdd.toDF(["Keyword", "RowId & Score"])
    all_keywords_df.write.csv(outputfolderpath, header=True, quote='"', escape='"')
    sc.stop()
Example #11
File: tests.py  Project: fireflyc/spark
class PySparkTestCase(unittest.TestCase):

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name, batchSize=2)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
Example #12
class PyEdgeRDDTestCase(unittest.TestCase):
    """
    Test collect, take, count, mapValues,
    filter and innerJoin for EdgeRDD
    """

    def setUp(self):
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.default.parallelism", 1)
        self.sc = SparkContext(appName=class_name, conf=conf)
        self.sc.setCheckpointDir("/tmp")

    def tearDown(self):
        self.sc.stop()

    # TODO
    def collect(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, [(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])

    # TODO
    def take(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, [(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])

    # TODO
    def count(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, 2)

    # TODO
    def mapValues(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, 2)

    # TODO
    def filter(self):
        return

    # TODO
    def innerJoin(self):
        vertexData0 = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertexData1 = self.sc.parallelize([(1, ("rxin", "student")), (2, ("jgonzal", "postdoc"))])
        vertices0 = VertexRDD(vertexData0)
        vertices1 = VertexRDD(vertexData1)
        results = vertices0.diff(vertices1)
        self.assertEqual(results, 2)
Example #13
class PySparkTestCase(unittest.TestCase):
 
    def setUp(self):
        self._old_sys_path = list(sys.path)
        conf = SparkConf().setMaster("local[2]") \
            .setAppName(self.__class__.__name__)
        self.sc = SparkContext(conf=conf)
 
    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
Example #14
class Converter():
    def __init__(self, **kwargs):
        print(kwargs)
        self.setUp()
        self.input = kwargs.get('input')
        self.output = kwargs.get('output')
        self.in_format = self.input.split('.')[-1]
        self.out_format = self.output.split('.')[-1]
        self.mode = kwargs.get('mode', 'overwrite')
        self.compression = kwargs.get('compression', None)
        self.partitionBy = kwargs.get('partitionBy', None)
        if self.in_format == 'csv':
            self.df = self.sqlCtx.read.csv(self.input, header=True)
        elif self.in_format == 'parquet':
            self.df = self.sqlCtx.read.parquet(self.input)
        else:
            raise ValueError('Unsupported source format: %s' % self.in_format)

    def getMaster(self):
        return os.getenv('SPARK_MASTER', 'local[2]')

    def setUp(self):
        self.sc = SparkContext(self.getMaster())
        #quiet_py4j()
        try:
            from pyspark.sql import SparkSession
            self.sqlCtx = SparkSession.builder.getOrCreate()
        except ImportError:
            self.sqlCtx = SQLContext(self.sc)

    def tearDown(self):
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")
 
    def head(self):
        return self.df.head()

    def take(self, n):
        return self.df.take(n)

    def write(self):
        if self.out_format == 'csv':
            self.df.write.csv(self.output, mode=self.mode, compression=self.compression, header=True)
        elif self.out_format == 'parquet':
            self.df.write.parquet(self.output, mode=self.mode, compression=self.compression)

    def validate(self):
        df_out = self.sqlCtx.read.format(self.out_format).load(self.output)
        df_out_cnt = df_out.count()
        df_cnt = self.df.count()
        return df_cnt == df_out_cnt
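A sketch of how the Converter above might be driven; the file paths are placeholders and the compression codec is just an assumption:

# Hypothetical usage of the Converter above; paths are placeholders.
conv = Converter(input="data/events.csv",      # format inferred from the .csv extension
                 output="out/events.parquet",  # format inferred from the .parquet extension
                 mode="overwrite",
                 compression="snappy")
conv.write()            # writes Parquet because out_format == 'parquet'
assert conv.validate()  # re-reads the output and compares row counts
conv.tearDown()         # stops the SparkContext created in setUp()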
Example #15
File: tests.py  Project: shivaram/spark
class PySparkTestCase(unittest.TestCase):
    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext("local[4]", class_name, batchSize=2)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")
Example #16
class PySparkTestCase(unittest.TestCase):
    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name, batchSize=2)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")
Example #17
class PySparkTestCase(unittest.TestCase):
    def setUp(self):
        class_name = self.__class__.__name__
        self.sc = SparkContext('local', class_name)
        self.sc._jvm.System.setProperty("spark.ui.showConsoleProgress", "false")
        log4j = self.sc._jvm.org.apache.log4j
        log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)

    def tearDown(self):
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")
Example #18
class PySparkTestCase(unittest.TestCase):
    def setUp(self):
        class_name = self.__class__.__name__
        self.sc = SparkContext('local', class_name)

    def tearDown(self):
        self.sc.stop()

    def test_should_be_able_to_word_count(self):
        rdd = self.sc.parallelize(["This is a text", "Another text", "More text", "a text"])
        result = python_word_count.wordcount(rdd)
        expected = [('a', 2), ('This', 1), ('text', 4), ('is', 1), ('Another', 1), ('More', 1)]
        self.assertEqual(expected, result.collect())
Example #19
def main(args):

    sc = SparkContext(appName="PGM")

    graph1 = sc.textFile(args.IN[0]).map(line_to_edge)
    graph2 = sc.textFile(args.IN[1]).map(line_to_edge)

    graph_name = args.IN[0]

    seed_num = args.sn
    PARTS = args.PARTS

    G1 = deep_copy(graph1, PARTS)
    G2 = deep_copy(graph2, PARTS)

    IsBucket = ""
    matchtype = ""

    if args.inseeds:
        seeds = sc.textFile(args.inseeds).map(line_to_edge)
        matchtype = "_seeded_"
        ETA = 0

    else:
        matchtype = "_seedless_"
        start = time()
        seeds = dinoise.seed_generator(sc, G1, G2, seed_num, PARTS)
        stop = time()
        ETA = round(float(stop - start) / 60, 4)
        stats = evaluate_output(graph_name + matchtype + IsBucket, G1, G2,
                                seeds, "seeds_log.csv", ETA, PARTS)

    if not args.bucketing:

        start = time()
        res = dinoise.distributed_noisy_seeds(sc, G1, G2, seeds, PARTS)
        stop = time()

    else:

        start = time()
        res = dinoise_w_bucketing.distributed_noisy_seeds(
            sc, G1, G2, seeds, PARTS)
        IsBucket = "_bucket_"
        stop = time()

    ETB = round(float(stop - start) / 60, 4)
    stats = evaluate_output(graph_name + matchtype + IsBucket, G1, G2, res,
                            "results_log.csv", ETB, PARTS)

    sc.stop()
Example #20
def _test() -> None:
    import doctest
    import sys
    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession
    import pyspark.sql.observation
    globs = pyspark.sql.observation.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['spark'] = SparkSession(sc)

    (failure_count, test_count) = doctest.testmod(pyspark.sql.observation, globs=globs)
    sc.stop()
    if failure_count:
        sys.exit(-1)
Example #21
class SparkBaseTestCase(TestCase):
    def setUp(self):
        """Setup a basic Spark context for testing"""
        super(SparkBaseTestCase, self).setUp()
        quiet_py4j()
        self.sc = SparkContext(os.getenv('SPARK_MASTER', 'local[4]'))
        self.sql_context = SQLContext(self.sc)

    def tearDown(self):
        """ Stops the running Spark context and does a hack to prevent Akka rebinding on the same port. """
        super(SparkBaseTestCase, self).tearDown()
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
        self.sc._jvm.System.clearProperty('spark.driver.port')
Example #22
class PySparkTestCase(unittest.TestCase):
    def setUp(self):
        class_name = self.__class__.__name__
        self.sc = SparkContext('local', class_name)
        self.sc._jvm.System.setProperty("spark.ui.showConsoleProgress",
                                        "false")
        log4j = self.sc._jvm.org.apache.log4j
        log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)

    def tearDown(self):
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")
Example #23
class PMMLTest(TestCase):
    def setUp(self):
        self.sc = SparkContext()
        self.sqlContext = SQLContext(self.sc)

    def tearDown(self):
        self.sc.stop()

    def testWorkflow(self):
        df = self.sqlContext.read.csv(os.path.join(os.path.dirname(__file__),
                                                   "resources/Iris.csv"),
                                      header=True,
                                      inferSchema=True)

        formula = RFormula(formula="Species ~ .")
        classifier = DecisionTreeClassifier()
        pipeline = Pipeline(stages=[formula, classifier])
        pipelineModel = pipeline.fit(df)

        pmmlBuilder = PMMLBuilder(self.sc, df, pipelineModel) \
         .verify(df.sample(False, 0.1))

        pmml = pmmlBuilder.build()
        self.assertIsInstance(pmml, JavaObject)

        pmmlByteArray = pmmlBuilder.buildByteArray()
        self.assertTrue(
            isinstance(pmmlByteArray, bytes)
            or isinstance(pmmlByteArray, bytearray))

        pmmlString = pmmlByteArray.decode("UTF-8")
        self.assertTrue(
            "<PMML xmlns=\"http://www.dmg.org/PMML-4_3\" xmlns:data=\"http://jpmml.org/jpmml-model/InlineTable\" version=\"4.3\">"
            in pmmlString)
        self.assertTrue("<VerificationFields>" in pmmlString)

        pmmlBuilder = pmmlBuilder.putOption(classifier, "compact", False)
        nonCompactFile = tempfile.NamedTemporaryFile(prefix="pyspark2pmml-",
                                                     suffix=".pmml")
        nonCompactPmmlPath = pmmlBuilder.buildFile(nonCompactFile.name)

        pmmlBuilder = pmmlBuilder.putOption(classifier, "compact", True)
        compactFile = tempfile.NamedTemporaryFile(prefix="pyspark2pmml-",
                                                  suffix=".pmml")
        compactPmmlPath = pmmlBuilder.buildFile(compactFile.name)

        self.assertGreater(os.path.getsize(nonCompactPmmlPath),
                           os.path.getsize(compactPmmlPath) + 100)
Example #24
class SparkTestingBaseTestCase(unittest2.TestCase):

    """Basic common test case for Spark. Provides a Spark context as sc.
    For non-local mode testing you can either override sparkMaster
    or set the environment property SPARK_MASTER."""

    @classmethod
    def getMaster(cls):
        return os.getenv('SPARK_MASTER', "local[4]")

    def setUp(self):
        """Setup a basic Spark context for testing"""
        self.sc = SparkContext(self.getMaster())
        quiet_py4j()

    def tearDown(self):
        """
        Tear down the basic panda spark test case. This stops the running
        context and does a hack to prevent Akka rebinding on the same port.
        """
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")

    def assertRDDEquals(self, expected, result):
        return self.compareRDD(expected, result) == []

    def compareRDD(self, expected, result):
        expectedKeyed = expected.map(lambda x: (x, 1))\
                                .reduceByKey(lambda x, y: x + y)
        resultKeyed = result.map(lambda x: (x, 1))\
                            .reduceByKey(lambda x, y: x + y)
        return expectedKeyed.cogroup(resultKeyed)\
                            .map(lambda x: tuple(map(list, x[1])))\
                            .filter(lambda x: x[0] != x[1]).take(1)

    def assertRDDEqualsWithOrder(self, expected, result):
        return self.compareRDDWithOrder(expected, result) == []

    def compareRDDWithOrder(self, expected, result):
        def indexRDD(rdd):
            return rdd.zipWithIndex().map(lambda x: (x[1], x[0]))
        indexExpected = indexRDD(expected)
        indexResult = indexRDD(result)
        return indexExpected.cogroup(indexResult)\
                            .map(lambda x: tuple(map(list, x[1])))\
                            .filter(lambda x: x[0] != x[1]).take(1)
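The compareRDD helper above keys each element with its multiplicity and cogroups the two sides, so an empty result means both RDDs hold the same elements regardless of order. A hypothetical test exercising it (the subclass and data are illustrative, not part of the original):

class RDDComparisonTest(SparkTestingBaseTestCase):
    def test_unordered_equality(self):
        expected = self.sc.parallelize(["a", "b", "b"])
        result = self.sc.parallelize(["b", "a", "b"])
        # compareRDD returns [] when both RDDs contain the same elements
        # with the same counts, irrespective of ordering.
        self.assertEqual(self.compareRDD(expected, result), [])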
Example #25
class SparkTestingBaseTestCase(unittest2.TestCase):

    """Basic common test case for Spark. Provides a Spark context as sc.
    For non-local mode testing you can either override sparkMaster
    or set the environment property SPARK_MASTER."""

    @classmethod
    def getMaster(cls):
        return os.getenv('SPARK_MASTER', "local[4]")

    def setUp(self):
        """Setup a basic Spark context for testing"""
        self.sc = SparkContext(self.getMaster())
        quiet_py4j()

    def tearDown(self):
        """
        Tear down the basic panda spark test case. This stops the running
        context and does a hack to prevent Akka rebinding on the same port.
        """
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")

    def assertRDDEquals(self, expected, result):
        return self.compareRDD(expected, result) == []

    def compareRDD(self, expected, result):
        expectedKeyed = expected.map(lambda x: (x, 1))\
                                .reduceByKey(lambda x, y: x + y)
        resultKeyed = result.map(lambda x: (x, 1))\
                            .reduceByKey(lambda x, y: x + y)
        return expectedKeyed.cogroup(resultKeyed)\
                            .map(lambda x: tuple(map(list, x[1])))\
                            .filter(lambda x: x[0] != x[1]).take(1)

    def assertRDDEqualsWithOrder(self, expected, result):
        return self.compareRDDWithOrder(expected, result) == []
    
    def compareRDDWithOrder(self, expected, result):
        def indexRDD(rdd):
            return rdd.zipWithIndex().map(lambda x: (x[1], x[0]))
        indexExpected = indexRDD(expected)
        indexResult = indexRDD(result)
        return indexExpected.cogroup(indexResult)\
                            .map(lambda x: tuple(map(list, x[1])))\
                            .filter(lambda x: x[0] != x[1]).take(1)
Example #26
def spark_context(_spark_session):
    """Return a SparkContext instance with reduced logging
    (session scope).
    """

    if _spark_session is None:
        from pyspark import SparkContext

        # pyspark 1.x: create SparkContext instance
        sc = SparkContext(conf=SparkConfigBuilder().get())
    else:
        # pyspark 2.x: get SparkContext from SparkSession fixture
        sc = _spark_session.sparkContext

    reduce_logging(sc)
    yield sc

    if _spark_session is None:
        sc.stop()  # pyspark 1.x: stop SparkContext instance
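Assuming the fixture above is registered in a pytest conftest (e.g. with session scope), a test simply declares it as an argument; the test below is illustrative only:

# Hypothetical test consuming the spark_context fixture above.
def test_parallelize_count(spark_context):
    rdd = spark_context.parallelize(range(100))
    assert rdd.count() == 100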
Example #27
class JPMMLTest(TestCase):

	def setUp(self):
		self.sc = SparkContext()
		self.sqlContext = SQLContext(self.sc)

	def tearDown(self):
		self.sc.stop()

	def testWorkflow(self):
		df = self.sqlContext.read.csv(irisCsvFile, header = True, inferSchema = True)
		
		formula = RFormula(formula = "Species ~ .")
		classifier = DecisionTreeClassifier()
		pipeline = Pipeline(stages = [formula, classifier])
		pipelineModel = pipeline.fit(df)
		
		pmmlBytes = toPMMLBytes(self.sc, df, pipelineModel)
		pmmlString = pmmlBytes.decode("UTF-8")
		self.assertTrue(pmmlString.find("<PMML xmlns=\"http://www.dmg.org/PMML-4_3\" version=\"4.3\">") > -1)
Example #28
def spark_context(request):
    """ fixture for creating a spark context
    Args:
        request: pytest.FixtureRequest object
    """
    conf = (SparkConf().setMaster("local[2]").setAppName(
        "pytest-pyspark-local-testing"))
    sc = SparkContext(conf=conf)
    request.addfinalizer(lambda: sc.stop())

    quiet_py4j()
    return sc
Example #29
def main(args):

    sc = SparkContext(appName="PGM")

    graph1 = sc.textFile(args.IN[0]).map(line_to_edge)
    graph2 = sc.textFile(args.IN[1]).map(line_to_edge)

    graph_name = args.IN[0]

    seed_num = args.sn
    PARTS = args.PARTS

    G1 = deep_copy(graph1, PARTS)
    G2 = deep_copy(graph2, PARTS)

    if args.inseeds:
        seeds = sc.textFile(args.inseeds).map(line_to_edge)
        matchtype = "seeded"
        ETA = 0

    else:
        matchtype = "seedless"
        start = time()
        seeds = seed_generator(sc, G1, G2, seed_num, PARTS)
        stop = time()
        ETA = round(float(stop - start) / 60, 4)
        seeds2 = seeds.map(lambda pair: str(pair[0]) + " " + str(pair[1]))
        seeds2.coalesce(1).saveAsTextFile(args.OUT + "bucketing_segen_seeds")

    start = time()
    res = distributed_noisy_seeds(sc, G1, G2, seeds, PARTS)
    stop = time()
    ETB = round(float(stop - start) / 60, 4)
    res2 = res.map(lambda pair: str(pair[0]) + " " + str(pair[1]))
    res2.coalesce(1).saveAsTextFile(args.OUT + matchtype +
                                    "_bucketing_matching")
    print("\nSeGen   time :" + str(ETA) + " min ")
    print("DiNoiSe time :" + str(ETB) + " min\n")
    sc.stop()
    return [ETA, ETB]
Example #30
class SparkTestingBaseTestCase(unittest2.TestCase):
    """Basic common test case for Spark. Provides a Spark context as sc.
    For non-local mode testing you can either override sparkMaster
    or set the environment property SPARK_MASTER."""
    @classmethod
    def getMaster(cls):
        return os.getenv('SPARK_MASTER', "local[4]")

    def setUp(self):
        """Setup a basic Spark context for testing"""
        self.sc = SparkContext(self.getMaster())
        quiet_py4j()

    def tearDown(self):
        """
        Tear down the basic panda spark test case. This stops the running
        context and does a hack to prevent Akka rebinding on the same port.
        """
        self.sc.stop()
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc._jvm.System.clearProperty("spark.driver.port")
Example #31
class SparkTestCase(unittest.TestCase):


    def resourceFile(self, filename, module='adam-core'):

        adamRoot = os.path.dirname(os.getcwd())
        return os.path.join(os.path.join(adamRoot,
                                         "%s/src/test/resources" % module),
                            filename)


    def tmpFile(self):

        tempFile = tempfile.NamedTemporaryFile(delete=True)
        tempFile.close()
        return tempFile.name


    def checkFiles(self, file1, file2):

        f1 = open(file1)
        f2 = open(file2)

        try:
            self.assertEqual(f1.read(), f2.read())
        finally:
            f1.close()
            f2.close()


    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name)

        
    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
Example #32
import atexit
import os
import platform
import pyspark
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel

# this is the equivalent of ADD_JARS
add_files = (os.environ.get("ADD_FILES").split(',')
             if os.environ.get("ADD_FILES") is not None else None)

if os.environ.get("SPARK_EXECUTOR_URI"):
    SparkContext.setSystemProperty("spark.executor.uri",
                                   os.environ["SPARK_EXECUTOR_URI"])

sc = SparkContext(appName="PySparkShell", pyFiles=add_files)
atexit.register(lambda: sc.stop())

print("""Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 1.0.0-SNAPSHOT
      /_/
""")
print("Using Python version %s (%s, %s)" %
      (platform.python_version(), platform.python_build()[0],
       platform.python_build()[1]))
print("SparkContext available as sc.")

if add_files is not None:
    print("Adding files: [%s]" % ", ".join(add_files))
Example #33
def spark_context(request):
    sc = SparkContext('local', 'tests_practicas_spark')
    request.addfinalizer(lambda: sc.stop())
    logger = logging.getLogger('py4j')
    logger.setLevel(logging.WARN)
    return sc
Example #34
File: mrgeo.py  Project: hzmarrou/mrgeo
class MrGeo(object):
    operators = {"+": ["__add__", "__radd__", "__iadd__"],
                 "-": ["__sub__", "__rsub__", "__isub__"],
                 "*": ["__mul__", "__rmul__", "__imul__"],
                 "/": ["__div__", "__truediv__", "__rdiv__", "__rtruediv__", "__idiv__", "__itruediv__"],
                 "//": [],  # floor div
                 "**": ["__pow__", "__rpow__", "__ipow__"],  # pow
                 "=": [],  # assignment, can't do!
                 "<": ["__lt__"],
                 "<=": ["__le__"],
                 ">": ["__lt__"],
                 ">=": ["__ge__"],
                 "==": ["__eq__"],
                 "!=": ["__ne__"],
                 "<>": [],
                 "!": [],
                 "&&": ["__and__", "__rand__", "__iand__"],
                 "&": [],
                 "||": ["__or__", "__ror__", "__ior__"],
                 "|": [],
                 "^": ["__xor__", "__rxor__", "__ixor__"],
                 "^=": []}
    reserved = ["or", "and", "str", "int", "long", "float", "bool"]

    gateway = None
    lock = Lock()

    sparkPyContext = None
    sparkContext = None
    job = None

    def __init__(self, gateway=None):

        MrGeo.ensure_gateway_initialized(self, gateway=gateway)
        try:
            self.initialize()
        except:
            # If an error occurs, clean up in order to allow future SparkContext creation:
            self.stop()
            raise

    @classmethod
    def ensure_gateway_initialized(cls, instance=None, gateway=None):
        """
        Checks whether a SparkContext is initialized or not.
        Throws error if a SparkContext is already running.
        """
        with MrGeo.lock:
            if not MrGeo.gateway:
                MrGeo.gateway = gateway or launch_gateway()
                MrGeo.jvm = MrGeo.gateway.jvm

    def _create_job(self):
        jvm = self.gateway.jvm
        java_import(jvm, "org.mrgeo.data.DataProviderFactory")
        java_import(jvm, "org.mrgeo.job.*")
        java_import(jvm, "org.mrgeo.utils.DependencyLoader")
        java_import(jvm, "org.mrgeo.utils.StringUtils")

        appname = "PyMrGeo"

        self.job = jvm.JobArguments()
        set_field(self.job, "name", appname)

        # Yarn in the default
        self.useyarn()

    def initialize(self):

        self._create_job()
        self._load_mapops()

    def _load_mapops(self):
        jvm = self.gateway.jvm
        client = self.gateway._gateway_client
        java_import(jvm, "org.mrgeo.job.*")
        java_import(jvm, "org.mrgeo.mapalgebra.MapOpFactory")
        java_import(jvm, "org.mrgeo.mapalgebra.raster.RasterMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.raster.MrsPyramidMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.ExportMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.vector.VectorMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.MapOp")
        java_import(jvm, "org.mrgeo.utils.SparkUtils")

        java_import(jvm, "org.mrgeo.data.*")

        mapops = jvm.MapOpFactory.getMapOpClasses()

        for rawmapop in mapops:
            mapop = str(rawmapop.getCanonicalName().rstrip('$'))

            java_import(jvm, mapop)

            cls = JavaClass(mapop, gateway_client=client)

            if self.is_instance_of(cls, jvm.RasterMapOp):
                instance = 'RasterMapOp'
            elif self.is_instance_of(cls, jvm.VectorMapOp):
                instance = 'VectorMapOp'
            elif self.is_instance_of(cls, jvm.MapOp):
                instance = "MapOp"
            else:
                # raise Exception("mapop (" + mapop + ") is not a RasterMapOp, VectorMapOp, or MapOp")
                print("mapop (" + mapop + ") is not a RasterMapOp, VectorMapOp, or MapOp")
                continue

            signatures = jvm.MapOpFactory.getSignatures(mapop)

            for method in cls.register():
                codes = None
                if method is not None:
                    name = method.strip().lower()
                    if len(name) > 0:
                        if name in self.reserved:
                            # print("reserved: " + name)
                            continue
                        elif name in self.operators:
                            # print("operator: " + name)
                            codes = self._generate_operator_code(mapop, name, signatures, instance)
                        else:
                            # print("method: " + name)
                            codes = self._generate_method_code(mapop, name, signatures, instance)

                if codes is not None:
                    for method_name, code in codes.iteritems():
                        # print(code)

                        compiled = {}
                        exec code in compiled

                        if instance == 'RasterMapOp':
                            setattr(RasterMapOp, method_name, compiled.get(method_name))
                        elif instance == "VectorMapOp":
                            #  setattr(VectorMapOp, method_name, compiled.get(method_name))
                            pass
                        elif self.is_instance_of(cls, jvm.MapOp):
                            setattr(RasterMapOp, method_name, compiled.get(method_name))
                            #  setattr(VectorMapOp, method_name, compiled.get(method_name))

    def _generate_operator_code(self, mapop, name, signatures, instance):
        methods = self._generate_methods(instance, signatures)

        if len(methods) == 0:
            return None

        # need to change the parameter names to "other" for all except us
        corrected_methods = []
        for method in methods:
            new_method = []
            if len(method) > 2:
                raise Exception("The parameters for an operator can only have 1 or 2 parameters")
            for param in method:
                lst = list(param)
                if lst[1].lower() == 'string' or \
                    lst[1].lower() == 'double' or \
                    lst[1].lower() == 'float' or \
                    lst[1].lower() == 'long' or \
                    lst[1].lower() == 'int' or \
                    lst[1].lower() == 'short' or \
                    lst[1].lower() == 'char' or \
                    lst[1].lower() == 'boolean':
                    lst[0] = "other"
                    lst[2] = "other"
                    # need to add this to the start of the list (in case we eventually check other.mapop from the elif
                elif lst[2] != "self":
                    lst[0] = "other"
                    lst[2] = "other"
                new_method.append(tuple(lst))

            corrected_methods.append(new_method)

        codes = {}
        for method_name in self.operators[name]:
            code = ""

            # Signature
            code += "def " + method_name + "(self, other):" + "\n"
            # code += "    print('" + name + "')\n"

            code += self._generate_imports(mapop)
            code += self._generate_calls(corrected_methods)
            code += self._generate_run()

            codes[method_name] = code
        return codes

    def _generate_method_code(self, mapop, name, signatures, instance):
        methods = self._generate_methods(instance, signatures)

        jvm = self.gateway.jvm
        client = self.gateway._gateway_client
        cls = JavaClass(mapop, gateway_client=client)

        is_export = is_remote() and self.is_instance_of(cls, jvm.ExportMapOp)

        if len(methods) == 0:
            return None

        signature = self._generate_signature(methods)

        code = ""
        # Signature
        code += "def " + name + "(" + signature + "):" + "\n"

        # code += "    print('" + name + "')\n"
        code += self._generate_imports(mapop, is_export)
        code += self._generate_calls(methods, is_export)
        code += self._generate_run(is_export)
        # print(code)

        return {name: code}

    def _generate_run(self, is_export=False):
        code = ""

        # Run the MapOp
        code += "    if (op.setup(self.job, self.context.getConf()) and\n"
        code += "        op.execute(self.context) and\n"
        code += "        op.teardown(self.job, self.context.getConf())):\n"
        # copy the Raster/VectorMapOp (so we got all the monkey patched code) and return it as the new mapop
        # TODO:  Add VectorMapOp!
        code += "        new_resource = copy.copy(self)\n"
        code += "        new_resource.mapop = op\n"

        if is_export:
            code += self._generate_saveraster()

        code += "        return new_resource\n"
        code += "    return None\n"
        return code

    def _generate_saveraster(self):
        code = ""
        # code += "        \n"
        code += "        cls = JavaClass('org.mrgeo.mapalgebra.ExportMapOp', gateway_client=self.gateway._gateway_client)\n"
        code += "        if hasattr(self, 'mapop') and self.is_instance_of(self.mapop, 'org.mrgeo.mapalgebra.raster.RasterMapOp') and type(name) is str and isinstance(singleFile, (int, long, float, str)) and isinstance(zoom, (int, long, float)) and isinstance(numTiles, (int, long, float)) and isinstance(mosaic, (int, long, float)) and type(format) is str and isinstance(randomTiles, (int, long, float, str)) and isinstance(tms, (int, long, float, str)) and type(colorscale) is str and type(tileids) is str and type(bounds) is str and isinstance(allLevels, (int, long, float, str)) and isinstance(overridenodata, (int, long, float)):\n"
        code += "            op = cls.create(self.mapop, str(name), True if singleFile else False, int(zoom), int(numTiles), int(mosaic), str(format), True if randomTiles else False, True if tms else False, str(colorscale), str(tileids), str(bounds), True if allLevels else False, float(overridenodata))\n"
        code += "        else:\n"
        code += "            raise Exception('input types differ (TODO: expand this message!)')\n"
        code += "        if (op.setup(self.job, self.context.getConf()) and\n"
        code += "                op.execute(self.context) and\n"
        code += "                op.teardown(self.job, self.context.getConf())):\n"
        code += "            new_resource = copy.copy(self)\n"
        code += "            new_resource.mapop = op\n"
        code += "            gdalutils = JavaClass('org.mrgeo.utils.GDALUtils', gateway_client=self.gateway._gateway_client)\n"
        code += "            java_image = op.image()\n"
        code += "            width = java_image.getRasterXSize()\n"
        code += "            height = java_image.getRasterYSize()\n"
        code += "            options = []\n"
        code += "            if format == 'jpg' or format == 'jpeg':\n"
        code += "                driver_name = 'jpeg'\n"
        code += "                extension = 'jpg'\n"
        code += "            elif format == 'tif' or format == 'tiff' or format == 'geotif' or format == 'geotiff' or format == 'gtif'  or format == 'gtiff':\n"
        code += "                driver_name = 'GTiff'\n"
        code += "                options.append('INTERLEAVE=BAND')\n"
        code += "                options.append('COMPRESS=DEFLATE')\n"
        code += "                options.append('PREDICTOR=1')\n"
        code += "                options.append('ZLEVEL=6')\n"
        code += "                options.append('TILES=YES')\n"
        code += "                if width < 2048:\n"
        code += "                    options.append('BLOCKXSIZE=' + str(width))\n"
        code += "                else:\n"
        code += "                    options.append('BLOCKXSIZE=2048')\n"
        code += "                if height < 2048:\n"
        code += "                    options.append('BLOCKYSIZE=' + str(height))\n"
        code += "                else:\n"
        code += "                    options.append('BLOCKYSIZE=2048')\n"
        code += "                extension = 'tif'\n"
        code += "            else:\n"
        code += "                driver_name = format\n"
        code += "                extension = format\n"
        code += "            datatype = java_image.GetRasterBand(1).getDataType()\n"
        code += "            if not local_name.endswith(extension):\n"
        code += "                local_name += '.' + extension\n"
        code += "            driver = gdal.GetDriverByName(driver_name)\n"
        code += "            local_image = driver.Create(local_name, width, height, java_image.getRasterCount(), datatype, options)\n"
        code += "            local_image.SetProjection(str(java_image.GetProjection()))\n"
        code += "            local_image.SetGeoTransform(java_image.GetGeoTransform())\n"
        code += "            java_nodatas = gdalutils.getnodatas(java_image)\n"
        code += "            print('saving image to ' + local_name)\n"
        code += "            print('downloading data... (' + str(gdalutils.getRasterBytes(java_image, 1) * local_image.RasterCount / 1024) + ' kb uncompressed)')\n"
        code += "            for i in xrange(1, local_image.RasterCount + 1):\n"
        code += "                start = time.time()\n"
        code += "                raw_data = gdalutils.getRasterDataAsCompressedBase64(java_image, i, 0, 0, width, height)\n"
        code += "                print('compressed/encoded data ' + str(len(raw_data)))\n"
        code += "                decoded_data = base64.b64decode(raw_data)\n"
        code += "                print('decoded data ' + str(len(decoded_data)))\n"
        code += "                decompressed_data = zlib.decompress(decoded_data, 16 + zlib.MAX_WBITS)\n"
        code += "                print('decompressed data ' + str(len(decompressed_data)))\n"
        code += "                byte_data = numpy.frombuffer(decompressed_data, dtype='b')\n"
        code += "                print('byte data ' + str(len(byte_data)))\n"
        code += "                image_data = byte_data.view(gdal_array.GDALTypeCodeToNumericTypeCode(datatype))\n"
        code += "                print('gdal-type data ' + str(len(image_data)))\n"
        code += "                image_data = image_data.reshape((-1, width))\n"
        code += "                print('reshaped ' + str(len(image_data)) + ' x ' + str(len(image_data[0])))\n"
        code += "                band = local_image.GetRasterBand(i)\n"
        code += "                print('writing band ' + str(i))\n"
        code += "                band.WriteArray(image_data)\n"
        code += "                end = time.time()\n"
        code += "                print('elapsed time: ' + str(end - start) + ' sec.')\n"
        code += "                band.SetNoDataValue(java_nodatas[i - 1])\n"
        code += "            local_image.FlushCache()\n"
        code += "            print('flushed cache')\n"

        return code

    def _generate_imports(self, mapop, is_export=False):
        code = ""
        # imports
        code += "    import copy\n"
        code += "    from numbers import Number\n"
        if is_export:
            code += "    import base64\n"
            code += "    import numpy\n"
            code += "    from osgeo import gdal, gdal_array\n"
            code += "    import time\n"
            code += "    import zlib\n"

        code += "    from py4j.java_gateway import JavaClass\n"
        # Get the Java class
        code += "    cls = JavaClass('" + mapop + "', gateway_client=self.gateway._gateway_client)\n"
        return code

    def _generate_calls(self, methods, is_export=False):

        # Check the input params and call the appropriate create() method
        firstmethod = True
        varargcode = ""
        code = ""

        if is_export:
            code += "    local_name = name\n"
            code += "    name = 'In-Memory'\n"

        for method in methods:
            iftest = ""
            call = []

            firstparam = True
            for param in method:
                var_name = param[0]
                type_name = param[1]
                call_name = param[2]

                if param[4]:
                    call_name, it, et = self.method_name(type_name, "arg")

                    if len(varargcode) == 0:
                        varargcode += "    array = self.gateway.new_array(self.gateway.jvm." + type_name + ", len(args))\n"
                        varargcode += "    cnt = 0\n"
                        call_name = "array"

                    varargcode += "    for arg in args:\n"
                    varargcode += "        if not(" + it + "):\n"
                    varargcode += "            raise Exception('input types differ (TODO: expand this message!)')\n"
                    varargcode += "    for arg in args:\n"
                    varargcode += "        array[cnt] = arg.mapop\n"
                    varargcode += "        cnt += 1\n"
                else:
                    if firstparam:
                        firstparam = False
                        if firstmethod:
                            firstmethod = False
                            iftest += "if"
                        else:
                            iftest += "elif"
                    else:
                        iftest += " and"

                    if call_name == "self":
                        var_name = call_name

                    call_name, it, et = self.method_name(type_name, var_name)
                    iftest += it

                call += [call_name]

            if len(iftest) > 0:
                iftest += ":\n"
                code += "    " + iftest

            code += "        op = cls.create(" + ", ".join(call) + ')\n'

        code += "    else:\n"
        code += "        raise Exception('input types differ (TODO: expand this message!)')\n"
        # code += "    import inspect\n"
        # code += "    method = inspect.stack()[0][3]\n"
        # code += "    print(method)\n"

        if len(varargcode) > 0:
            code = varargcode + code

        return code

    def method_name(self, type_name, var_name):
        if type_name == "String":
            iftest = " type(" + var_name + ") is str"
            call_name = "str(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Double" or type_name == "Float":
            iftest = " isinstance(" + var_name + ", (int, long, float))"
            call_name = "float(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Long":
            iftest = " isinstance(" + var_name + ", (int, long, float))"
            call_name = "long(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Int" or type_name == "Short" or type_name == "Char":
            iftest = " isinstance(" + var_name + ", (int, long, float))"
            call_name = "int(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Boolean":
            iftest = " isinstance(" + var_name + ", (int, long, float, str))"
            call_name = "True if " + var_name + " else False"
            excepttest = "not" + iftest
        elif type_name.endswith("MapOp"):
            base_var = var_name
            var_name += ".mapop"
            iftest = " hasattr(" + base_var + ", 'mapop') and self.is_instance_of(" + var_name + ", '" + type_name + "')"
            call_name = var_name
            excepttest = " hasattr(" + base_var + ", 'mapop') and not self.is_instance_of(" + var_name + ", '" + type_name + "')"
        else:
            iftest = " self.is_instance_of(" + var_name + ", '" + type_name + "')"
            call_name = var_name
            excepttest = "not" + iftest

        return call_name, iftest, excepttest

    def _generate_methods(self, instance, signatures):
        methods = []
        for sig in signatures:
            found = False
            method = []
            for variable in sig.split(","):
                names = re.split("[:=]+", variable)
                new_name = names[0]
                new_type = names[1]

                # var args?
                varargs = False
                if new_type.endswith("*"):
                    new_type = new_type[:-1]
                    new_name = "args"
                    varargs = True

                if len(names) == 3:
                    if names[2].lower() == "true":
                        new_value = "True"
                    elif names[2].lower() == "false":
                        new_value = "False"
                    elif names[2].lower() == "infinity":
                        new_value = "float('inf')"
                    elif names[2].lower() == "-infinity":
                        new_value = "float('-inf')"
                    elif names[2].lower() == "null":
                        new_value = "None"
                    else:
                        new_value = names[2]
                else:
                    new_value = None

                if ((not found) and
                        (new_type.endswith("MapOp") or
                             (instance is "RasterMapOp" and new_type.endswith("RasterMapOp")) or
                             (instance is "VectorMapOp" and new_type.endswith("VectorMapOp")))):
                    found = True
                    new_call = "self"
                else:
                    new_call = new_name

                tup = (new_name, new_type, new_call, new_value, varargs)
                method.append(tup)

            methods.append(method)
        return methods

    def _in_signature(self, param, signature):
        for s in signature:
            if s[0] == param[0]:
                if s[1] == param[1]:
                    if s[3] == param[3]:
                        return True
                    else:
                        raise Exception("only default values differ: " + str(s) + ": " + str(param))
                else:
                    raise Exception("type parameters differ: " + str(s) + ": " + str(param))
        return False

    def _generate_signature(self, methods):
        signature = []
        dual = len(methods) > 1
        for method in methods:
            for param in method:
                if not param[2] == "self" and not self._in_signature(param, signature):
                    signature.append(param)
                    if param[4]:
                        # var args must be the last parameter
                        break

        sig = ["self"]
        for s in signature:
            if s[4]:
                sig += ["*args"]
            else:
                if s[3] is not None:
                    sig += [s[0] + "=" + s[3]]
                elif dual:
                    sig += [s[0] + "=None"]
                else:
                    sig += [s[0]]

        return ",".join(sig)

    # @staticmethod
    # def _generate_code(mapop, name, signatures, instance):
    #
    #     signature, call, types, values = MrGeo._generate_params(instance, signatures)
    #
    #     sig = ""
    #     for s, d in zip(signature, values):
    #         if len(sig) > 0:
    #             sig += ", "
    #         sig += s
    #         if d is not None:
    #             sig += "=" + str(d)
    #
    #     code = ""
    #     code += "def " + name + "(" + sig + "):" + "\n"
    #     code += "    from py4j.java_gateway import JavaClass\n"
    #     code += "    #from rastermapop import RasterMapOp\n"
    #     code += "    import copy\n"
    #     code += "    print('" + name + "')\n"
    #     code += "    cls = JavaClass('" + mapop + "', gateway_client=self.gateway._gateway_client)\n"
    #     code += "    newop = cls.apply(" + ", ".join(call) + ')\n'
    #     code += "    if (newop.setup(self.job, self.context.getConf()) and\n"
    #     code += "        newop.execute(self.context) and\n"
    #     code += "        newop.teardown(self.job, self.context.getConf())):\n"
    #     code += "        new_raster = copy.copy(self)\n"
    #     code += "        new_raster.mapop = newop\n"
    #     code += "        return new_raster\n"
    #     code += "    return None\n"
    #
    #     # print(code)
    #
    #     return code

    def is_instance_of(self, java_object, java_class):
        if isinstance(java_class, basestring):
            name = java_class
        elif isinstance(java_class, JavaClass):
            name = java_class._fqn
        elif isinstance(java_class, JavaObject):
            name = java_class.getClass()
        else:
            raise Exception("java_class must be a string, a JavaClass, or a JavaObject")

        jvm = self.gateway.jvm
        name = jvm.Class.forName(name).getCanonicalName()

        if isinstance(java_object, JavaClass):
            cls = jvm.Class.forName(java_object._fqn)
        elif isinstance(java_object, JavaObject):
            cls = java_object.getClass()
        else:
            raise Exception("java_object must be a JavaClass, or a JavaObject")

        if cls.getCanonicalName() == name:
            return True

        return self._is_instance_of(cls.getSuperclass(), name)

    def _is_instance_of(self, clazz, name):
        if clazz:
            if clazz.getCanonicalName() == name:
                return True

            return self._is_instance_of(clazz.getSuperclass(), name)

        return False

    def usedebug(self):
        self.job.useDebug()

    def useyarn(self):
        self.job.useYarn()

    def start(self):
        jvm = self.gateway.jvm

        self.job.addMrGeoProperties()
        dpf_properties = jvm.DataProviderFactory.getConfigurationFromProviders()

        for prop in dpf_properties:
            self.job.setSetting(prop, dpf_properties[prop])

        if self.job.isDebug():
            master = "local"
        elif self.job.isSpark():
            # TODO:  get the master for spark
            master = ""
        elif self.job.isYarn():
            master = "yarn-client"
        else:
            cpus = (multiprocessing.cpu_count() / 4) * 3
            if cpus < 2:
                master = "local"
            else:
                master = "local[" + str(cpus) + "]"

        set_field(self.job, "jars",
                  jvm.StringUtils.concatUnique(
                      jvm.DependencyLoader.getAndCopyDependencies("org.mrgeo.mapalgebra.MapAlgebra", None),
                      jvm.DependencyLoader.getAndCopyDependencies(jvm.MapOpFactory.getMapOpClassNames(), None)))

        conf = jvm.MrGeoDriver.prepareJob(self.job)

        # need to override the yarn mode to "yarn-client" for python
        if self.job.isYarn():
            conf.set("spark.master", "yarn-client")

            mem = jvm.SparkUtils.humantokb(conf.get("spark.executor.memory"))
            workers = int(conf.get("spark.executor.instances")) + 1  # one for the driver

            conf.set("spark.executor.memory", jvm.SparkUtils.kbtohuman(long(mem / workers), "m"))

        # for a in conf.getAll():
        #     print(a._1(), a._2())

        # jsc = jvm.JavaSparkContext(master, appName, sparkHome, jars)
        jsc = jvm.JavaSparkContext(conf)
        self.sparkContext = jsc.sc()
        self.sparkPyContext = SparkContext(master=master, appName=self.job.name(), jsc=jsc, gateway=self.gateway)

        # print("started")

    def stop(self):
        if self.sparkContext:
            self.sparkContext.stop()
            self.sparkContext = None

        if self.sparkPyContext:
            self.sparkPyContext.stop()
            self.sparkPyContext = None

    def list_images(self):
        jvm = self.gateway.jvm

        pstr = self.job.getSetting(constants.provider_properties, "")
        pp = jvm.ProviderProperties.fromDelimitedString(pstr)

        rawimages = jvm.DataProviderFactory.listImages(pp)

        images = []
        for image in rawimages:
            images.append(str(image))

        return images

    def load_image(self, name):
        jvm = self.gateway.jvm

        pstr = self.job.getSetting(constants.provider_properties, "")
        pp = jvm.ProviderProperties.fromDelimitedString(pstr)

        dp = jvm.DataProviderFactory.getMrsImageDataProvider(name, jvm.DataProviderFactory.AccessMode.READ, pp)

        mapop = jvm.MrsPyramidMapOp.apply(dp)
        mapop.context(self.sparkContext)

        # print("loaded " + name)

        return RasterMapOp(mapop=mapop, gateway=self.gateway, context=self.sparkContext, job=self.job)
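To make the string building above concrete: pieced together from `_generate_calls`, `method_name`, and the setup/execute/teardown boilerplate, a generated operator wrapper for a hypothetical binary map op taking (RasterMapOp, Double) might look roughly like the sketch below (the Java class name is illustrative, and the generated code targets Python 2, hence `long`):

# Hedged sketch of generator output; 'org.mrgeo.mapalgebra.SomeBinaryMapOp' is illustrative
def __add__(self, other):
    import copy
    from numbers import Number
    from py4j.java_gateway import JavaClass
    cls = JavaClass('org.mrgeo.mapalgebra.SomeBinaryMapOp',
                    gateway_client=self.gateway._gateway_client)
    # one branch per registered signature, built by _generate_calls / method_name
    if hasattr(self, 'mapop') and self.is_instance_of(self.mapop, 'RasterMapOp') and \
            isinstance(other, (int, long, float)):
        op = cls.create(self.mapop, float(other))
    elif hasattr(self, 'mapop') and self.is_instance_of(self.mapop, 'RasterMapOp') and \
            hasattr(other, 'mapop') and self.is_instance_of(other.mapop, 'RasterMapOp'):
        op = cls.create(self.mapop, other.mapop)
    else:
        raise Exception('input types differ (TODO: expand this message!)')
    # the setup/execute/teardown boilerplate: run the map op and wrap the result
    if (op.setup(self.job, self.context.getConf()) and
            op.execute(self.context) and
            op.teardown(self.job, self.context.getConf())):
        new_resource = copy.copy(self)
        new_resource.mapop = op
        return new_resource
    return None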
Example #35
0
class MrGeo(object):
    operators = {
        "+": ["__add__", "__radd__", "__iadd__"],
        "-": ["__sub__", "__rsub__", "__isub__"],
        "*": ["__mul__", "__rmul__", "__imul__"],
        "/": [
            "__div__", "__truediv__", "__rdiv__", "__rtruediv__", "__idiv__",
            "__itruediv__"
        ],
        "//": [],  # floor div
        "**": ["__pow__", "__rpow__", "__ipow__"],  # pow
        "=": [],  # assignment, can't do!
        "<": ["__lt__"],
        "<=": ["__le__"],
        ">": ["__lt__"],
        ">=": ["__ge__"],
        "==": ["__eq__"],
        "!=": ["__ne__"],
        "<>": [],
        "!": [],
        "&&": ["__and__", "__rand__", "__iand__"],
        "&": [],
        "||": ["__or__", "__ror__", "__ior__"],
        "|": [],
        "^": ["__xor__", "__rxor__", "__ixor__"],
        "^=": []
    }
    reserved = ["or", "and", "str", "int", "long", "float", "bool"]

    gateway = None
    lock = Lock()

    sparkPyContext = None
    sparkContext = None
    job = None

    def __init__(self, gateway=None):

        MrGeo.ensure_gateway_initialized(self, gateway=gateway)
        try:
            self.initialize()
        except:
            # If an error occurs, clean up in order to allow future SparkContext creation:
            self.stop()
            raise

    @classmethod
    def ensure_gateway_initialized(cls, instance=None, gateway=None):
        """
        Checks whether a SparkContext is initialized or not.
        Throws error if a SparkContext is already running.
        """
        with MrGeo.lock:
            if not MrGeo.gateway:
                MrGeo.gateway = gateway or launch_gateway()
                MrGeo.jvm = MrGeo.gateway.jvm

    def _create_job(self):
        jvm = self.gateway.jvm
        java_import(jvm, "org.mrgeo.data.DataProviderFactory")
        java_import(jvm, "org.mrgeo.job.*")
        java_import(jvm, "org.mrgeo.utils.DependencyLoader")
        java_import(jvm, "org.mrgeo.utils.StringUtils")

        appname = "PyMrGeo"

        self.job = jvm.JobArguments()
        set_field(self.job, "name", appname)

        # Yarn is the default
        self.useyarn()

    def initialize(self):

        self._create_job()
        self._load_mapops()

    def _load_mapops(self):
        jvm = self.gateway.jvm
        client = self.gateway._gateway_client
        java_import(jvm, "org.mrgeo.job.*")
        java_import(jvm, "org.mrgeo.mapalgebra.MapOpFactory")
        java_import(jvm, "org.mrgeo.mapalgebra.raster.RasterMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.raster.MrsPyramidMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.ExportMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.vector.VectorMapOp")
        java_import(jvm, "org.mrgeo.mapalgebra.MapOp")
        java_import(jvm, "org.mrgeo.utils.SparkUtils")

        java_import(jvm, "org.mrgeo.data.*")

        mapops = jvm.MapOpFactory.getMapOpClasses()

        for rawmapop in mapops:
            mapop = str(rawmapop.getCanonicalName().rstrip('$'))

            java_import(jvm, mapop)

            cls = JavaClass(mapop, gateway_client=client)

            if self.is_instance_of(cls, jvm.RasterMapOp):
                instance = 'RasterMapOp'
            elif self.is_instance_of(cls, jvm.VectorMapOp):
                instance = 'VectorMapOp'
            elif self.is_instance_of(cls, jvm.MapOp):
                instance = "MapOp"
            else:
                # raise Exception("mapop (" + mapop + ") is not a RasterMapOp, VectorMapOp, or MapOp")
                print("mapop (" + mapop +
                      ") is not a RasterMapOp, VectorMapOp, or MapOp")
                continue

            signatures = jvm.MapOpFactory.getSignatures(mapop)

            for method in cls.register():
                codes = None
                if method is not None:
                    name = method.strip().lower()
                    if len(name) > 0:
                        if name in self.reserved:
                            # print("reserved: " + name)
                            continue
                        elif name in self.operators:
                            # print("operator: " + name)
                            codes = self._generate_operator_code(
                                mapop, name, signatures, instance)
                        else:
                            # print("method: " + name)
                            codes = self._generate_method_code(
                                mapop, name, signatures, instance)

                if codes is not None:
                    for method_name, code in codes.iteritems():
                        # print(code)

                        compiled = {}
                        exec code in compiled

                        if instance == 'RasterMapOp':
                            setattr(RasterMapOp, method_name,
                                    compiled.get(method_name))
                        elif instance == "VectorMapOp":
                            #  setattr(VectorMapOp, method_name, compiled.get(method_name))
                            pass
                        elif self.is_instance_of(cls, jvm.MapOp):
                            setattr(RasterMapOp, method_name,
                                    compiled.get(method_name))
                            #  setattr(VectorMapOp, method_name, compiled.get(method_name))

    def _generate_operator_code(self, mapop, name, signatures, instance):
        methods = self._generate_methods(instance, signatures)

        if len(methods) == 0:
            return None

        # need to change the parameter names to "other" for all except us
        corrected_methods = []
        for method in methods:
            new_method = []
            if len(method) > 2:
                raise Exception(
                    "The parameters for an operator can only have 1 or 2 parameters"
                )
            for param in method:
                lst = list(param)
                if lst[1].lower() == 'string' or \
                    lst[1].lower() == 'double' or \
                    lst[1].lower() == 'float' or \
                    lst[1].lower() == 'long' or \
                    lst[1].lower() == 'int' or \
                    lst[1].lower() == 'short' or \
                    lst[1].lower() == 'char' or \
                    lst[1].lower() == 'boolean':
                    lst[0] = "other"
                    lst[2] = "other"
                    # need to add this to the start of the list (in case we eventually check other.mapop from the elif)
                elif lst[2] != "self":
                    lst[0] = "other"
                    lst[2] = "other"
                new_method.append(tuple(lst))

            corrected_methods.append(new_method)

        codes = {}
        for method_name in self.operators[name]:
            code = ""

            # Signature
            code += "def " + method_name + "(self, other):" + "\n"
            # code += "    print('" + name + "')\n"

            code += self._generate_imports(mapop)
            code += self._generate_calls(corrected_methods)
            code += self._generate_run()

            codes[method_name] = code
        return codes

    def _generate_method_code(self, mapop, name, signatures, instance):

        methods = self._generate_methods(instance, signatures)

        jvm = self.gateway.jvm
        client = self.gateway._gateway_client
        cls = JavaClass(mapop, gateway_client=client)

        is_export = self.is_instance_of(cls, jvm.ExportMapOp)

        if len(methods) == 0:
            return None

        signature = self._generate_signature(methods)

        code = ""
        # Signature
        code += "def " + name + "(" + signature + "):" + "\n"

        # code += "    print('" + name + "')\n"
        code += self._generate_imports(mapop, is_export)
        code += self._generate_calls(methods, is_export)
        code += self._generate_run(is_export)
        # print(code)

        return {name: code}

    def _generate_run(self, is_export=False):
        code = ""

        # Run the MapOp
        code += "    if (op.setup(self.job, self.context.getConf()) and\n"
        code += "        op.execute(self.context) and\n"
        code += "        op.teardown(self.job, self.context.getConf())):\n"
        # copy the Raster/VectorMapOp (so we got all the monkey patched code) and return it as the new mapop
        # TODO:  Add VectorMapOp!
        code += "        new_resource = copy.copy(self)\n"
        code += "        new_resource.mapop = op\n"

        if is_export:
            code += self._generate_saveraster()

        code += "        return new_resource\n"
        code += "    return None\n"
        return code

    def _generate_saveraster(self):
        code = ""
        # code += "        \n"
        code += "        cls = JavaClass('org.mrgeo.mapalgebra.ExportMapOp', gateway_client=self.gateway._gateway_client)\n"
        code += "        if hasattr(self, 'mapop') and self.is_instance_of(self.mapop, 'org.mrgeo.mapalgebra.raster.RasterMapOp') and type(name) is str and isinstance(singleFile, (int, long, float, str)) and isinstance(zoom, (int, long, float)) and isinstance(numTiles, (int, long, float)) and isinstance(mosaic, (int, long, float)) and type(format) is str and isinstance(randomTiles, (int, long, float, str)) and isinstance(tms, (int, long, float, str)) and type(colorscale) is str and type(tileids) is str and type(bounds) is str and isinstance(allLevels, (int, long, float, str)) and isinstance(overridenodata, (int, long, float)):\n"
        code += "            op = cls.create(self.mapop, str(name), True if singleFile else False, int(zoom), int(numTiles), int(mosaic), str(format), True if randomTiles else False, True if tms else False, str(colorscale), str(tileids), str(bounds), True if allLevels else False, float(overridenodata))\n"
        code += "        else:\n"
        code += "            raise Exception('input types differ (TODO: expand this message!)')\n"
        code += "        if (op.setup(self.job, self.context.getConf()) and\n"
        code += "                op.execute(self.context) and\n"
        code += "                op.teardown(self.job, self.context.getConf())):\n"
        code += "            new_resource = copy.copy(self)\n"
        code += "            new_resource.mapop = op\n"
        code += "            gdalutils = JavaClass('org.mrgeo.utils.GDALUtils', gateway_client=self.gateway._gateway_client)\n"
        code += "            java_image = op.image()\n"
        code += "            width = java_image.getRasterXSize()\n"
        code += "            height = java_image.getRasterYSize()\n"
        code += "            options = []\n"
        code += "            if format == 'jpg' or format == 'jpeg':\n"
        code += "                driver_name = 'jpeg'\n"
        code += "                extension = 'jpg'\n"
        code += "            elif format == 'tif' or format == 'tiff' or format == 'geotif' or format == 'geotiff' or format == 'gtif'  or format == 'gtiff':\n"
        code += "                driver_name = 'GTiff'\n"
        code += "                options.append('INTERLEAVE=BAND')\n"
        code += "                options.append('COMPRESS=DEFLATE')\n"
        code += "                options.append('PREDICTOR=1')\n"
        code += "                options.append('ZLEVEL=6')\n"
        code += "                options.append('TILES=YES')\n"
        code += "                if width < 2048:\n"
        code += "                    options.append('BLOCKXSIZE=' + str(width))\n"
        code += "                else:\n"
        code += "                    options.append('BLOCKXSIZE=2048')\n"
        code += "                if height < 2048:\n"
        code += "                    options.append('BLOCKYSIZE=' + str(height))\n"
        code += "                else:\n"
        code += "                    options.append('BLOCKYSIZE=2048')\n"
        code += "                extension = 'tif'\n"
        code += "            else:\n"
        code += "                driver_name = format\n"
        code += "                extension = format\n"
        code += "            datatype = java_image.GetRasterBand(1).getDataType()\n"
        code += "            if not local_name.endswith(extension):\n"
        code += "                local_name += '.' + extension\n"
        code += "            driver = gdal.GetDriverByName(driver_name)\n"
        code += "            local_image = driver.Create(local_name, width, height, java_image.getRasterCount(), datatype, options)\n"
        code += "            local_image.SetProjection(str(java_image.GetProjection()))\n"
        code += "            local_image.SetGeoTransform(java_image.GetGeoTransform())\n"
        code += "            java_nodatas = gdalutils.getnodatas(java_image)\n"
        code += "            print('saving image to ' + local_name)\n"
        code += "            print('downloading data... (' + str(gdalutils.getRasterBytes(java_image, 1) * local_image.RasterCount / 1024) + ' kb uncompressed)')\n"
        code += "            for i in xrange(1, local_image.RasterCount + 1):\n"
        code += "                start = time.time()\n"
        code += "                raw_data = gdalutils.getRasterDataAsCompressedBase64(java_image, i, 0, 0, width, height)\n"
        code += "                print('compressed/encoded data ' + str(len(raw_data)))\n"
        code += "                decoded_data = base64.b64decode(raw_data)\n"
        code += "                print('decoded data ' + str(len(decoded_data)))\n"
        code += "                decompressed_data = zlib.decompress(decoded_data, 16 + zlib.MAX_WBITS)\n"
        code += "                print('decompressed data ' + str(len(decompressed_data)))\n"
        code += "                byte_data = numpy.frombuffer(decompressed_data, dtype='b')\n"
        code += "                print('byte data ' + str(len(byte_data)))\n"
        code += "                image_data = byte_data.view(gdal_array.GDALTypeCodeToNumericTypeCode(datatype))\n"
        code += "                print('gdal-type data ' + str(len(image_data)))\n"
        code += "                image_data = image_data.reshape((-1, width))\n"
        code += "                print('reshaped ' + str(len(image_data)) + ' x ' + str(len(image_data[0])))\n"
        code += "                band = local_image.GetRasterBand(i)\n"
        code += "                print('writing band ' + str(i))\n"
        code += "                band.WriteArray(image_data)\n"
        code += "                end = time.time()\n"
        code += "                print('elapsed time: ' + str(end - start) + ' sec.')\n"
        code += "                band.SetNoDataValue(java_nodatas[i - 1])\n"
        code += "            local_image.FlushCache()\n"
        code += "            print('flushed cache')\n"

        return code

    def _generate_imports(self, mapop, is_export=False):
        code = ""
        # imports
        code += "    import copy\n"
        code += "    from numbers import Number\n"
        if is_export:
            code += "    import base64\n"
            code += "    import numpy\n"
            code += "    from osgeo import gdal, gdal_array\n"
            code += "    import time\n"
            code += "    import zlib\n"

        code += "    from py4j.java_gateway import JavaClass\n"
        # Get the Java class
        code += "    cls = JavaClass('" + mapop + "', gateway_client=self.gateway._gateway_client)\n"
        return code

    def _generate_calls(self, methods, is_export=False):

        # Check the input params and call the appropriate create() method
        firstmethod = True
        varargcode = ""
        code = ""

        if is_export:
            code += "    local_name = name\n"
            code += "    name = 'In-Memory'\n"

        for method in methods:
            iftest = ""
            call = []

            firstparam = True
            for param in method:
                var_name = param[0]
                type_name = param[1]
                call_name = param[2]

                if param[4]:
                    call_name, it, et = self.method_name(type_name, "arg")

                    if len(varargcode) == 0:
                        varargcode += "    array = self.gateway.new_array(self.gateway.jvm." + type_name + ", len(args))\n"
                        varargcode += "    cnt = 0\n"
                        call_name = "array"

                    varargcode += "    for arg in args:\n"
                    varargcode += "        if not(" + it + "):\n"
                    varargcode += "            raise Exception('input types differ (TODO: expand this message!)')\n"
                    varargcode += "    for arg in args:\n"
                    varargcode += "        array[cnt] = arg.mapop\n"
                    varargcode += "        cnt += 1\n"
                else:
                    if firstparam:
                        firstparam = False
                        if firstmethod:
                            firstmethod = False
                            iftest += "if"
                        else:
                            iftest += "elif"
                    else:
                        iftest += " and"

                    if call_name == "self":
                        var_name = call_name

                    call_name, it, et = self.method_name(type_name, var_name)
                    iftest += it

                call += [call_name]

            if len(iftest) > 0:
                iftest += ":\n"
                code += "    " + iftest

            code += "        op = cls.create(" + ", ".join(call) + ')\n'

        code += "    else:\n"
        code += "        raise Exception('input types differ (TODO: expand this message!)')\n"
        # code += "    import inspect\n"
        # code += "    method = inspect.stack()[0][3]\n"
        # code += "    print(method)\n"

        if len(varargcode) > 0:
            code = varargcode + code

        return code

    def method_name(self, type_name, var_name):
        if type_name == "String":
            iftest = " type(" + var_name + ") is str"
            call_name = "str(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Double" or type_name == "Float":
            iftest = " isinstance(" + var_name + ", (int, long, float))"
            call_name = "float(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Long":
            iftest = " isinstance(" + var_name + ", (int, long, float))"
            call_name = "long(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Int" or type_name == "Short" or type_name == "Char":
            iftest = " isinstance(" + var_name + ", (int, long, float))"
            call_name = "int(" + var_name + ")"
            excepttest = "not" + iftest
        elif type_name == "Boolean":
            iftest = " isinstance(" + var_name + ", (int, long, float, str))"
            call_name = "True if " + var_name + " else False"
            excepttest = "not" + iftest
        elif type_name.endswith("MapOp"):
            base_var = var_name
            var_name += ".mapop"
            iftest = " hasattr(" + base_var + ", 'mapop') and self.is_instance_of(" + var_name + ", '" + type_name + "')"
            call_name = var_name
            excepttest = " hasattr(" + base_var + ", 'mapop') and not self.is_instance_of(" + var_name + ", '" + type_name + "')"
        else:
            iftest = " self.is_instance_of(" + var_name + ", '" + type_name + "')"
            call_name = var_name
            excepttest = "not" + iftest

        return call_name, iftest, excepttest

    def _generate_methods(self, instance, signatures):
        methods = []
        for sig in signatures:
            found = False
            method = []
            for variable in sig.split(","):
                names = re.split("[:=]+", variable)
                new_name = names[0]
                new_type = names[1]

                # var args?
                varargs = False
                if new_type.endswith("*"):
                    new_type = new_type[:-1]
                    new_name = "args"
                    varargs = True

                if len(names) == 3:
                    if names[2].lower() == "true":
                        new_value = "True"
                    elif names[2].lower() == "false":
                        new_value = "False"
                    elif names[2].lower() == "infinity":
                        new_value = "float('inf')"
                    elif names[2].lower() == "-infinity":
                        new_value = "float('-inf')"
                    elif names[2].lower() == "null":
                        new_value = "None"
                    else:
                        new_value = names[2]
                else:
                    new_value = None

                if ((not found) and (new_type.endswith("MapOp") or
                                     (instance is "RasterMapOp"
                                      and new_type.endswith("RasterMapOp")) or
                                     (instance is "VectorMapOp"
                                      and new_type.endswith("VectorMapOp")))):
                    found = True
                    new_call = "self"
                else:
                    new_call = new_name

                tup = (new_name, new_type, new_call, new_value, varargs)
                method.append(tup)

            methods.append(method)
        return methods

    def _in_signature(self, param, signature):
        for s in signature:
            if s[0] == param[0]:
                if s[1] == param[1]:
                    if s[3] == param[3]:
                        return True
                    else:
                        raise Exception("only default values differ: " +
                                        str(s) + ": " + str(param))
                else:
                    raise Exception("type parameters differ: " + str(s) +
                                    ": " + str(param))
        return False

    def _generate_signature(self, methods):
        signature = []
        dual = len(methods) > 1
        for method in methods:
            for param in method:
                if not param[2] == "self" and not self._in_signature(
                        param, signature):
                    signature.append(param)
                    if param[4]:
                        # var args must be the last parameter
                        break

        sig = ["self"]
        for s in signature:
            if s[4]:
                sig += ["*args"]
            else:
                if s[3] is not None:
                    sig += [s[0] + "=" + s[3]]
                elif dual:
                    sig += [s[0] + "=None"]
                else:
                    sig += [s[0]]

        return ",".join(sig)

    # @staticmethod
    # def _generate_code(mapop, name, signatures, instance):
    #
    #     signature, call, types, values = MrGeo._generate_params(instance, signatures)
    #
    #     sig = ""
    #     for s, d in zip(signature, values):
    #         if len(sig) > 0:
    #             sig += ", "
    #         sig += s
    #         if d is not None:
    #             sig += "=" + str(d)
    #
    #     code = ""
    #     code += "def " + name + "(" + sig + "):" + "\n"
    #     code += "    from py4j.java_gateway import JavaClass\n"
    #     code += "    #from rastermapop import RasterMapOp\n"
    #     code += "    import copy\n"
    #     code += "    print('" + name + "')\n"
    #     code += "    cls = JavaClass('" + mapop + "', gateway_client=self.gateway._gateway_client)\n"
    #     code += "    newop = cls.apply(" + ", ".join(call) + ')\n'
    #     code += "    if (newop.setup(self.job, self.context.getConf()) and\n"
    #     code += "        newop.execute(self.context) and\n"
    #     code += "        newop.teardown(self.job, self.context.getConf())):\n"
    #     code += "        new_raster = copy.copy(self)\n"
    #     code += "        new_raster.mapop = newop\n"
    #     code += "        return new_raster\n"
    #     code += "    return None\n"
    #
    #     # print(code)
    #
    #     return code

    def is_instance_of(self, java_object, java_class):
        if isinstance(java_class, basestring):
            name = java_class
        elif isinstance(java_class, JavaClass):
            name = java_class._fqn
        elif isinstance(java_class, JavaObject):
            name = java_class.getClass()
        else:
            raise Exception(
                "java_class must be a string, a JavaClass, or a JavaObject")

        jvm = self.gateway.jvm
        name = jvm.Class.forName(name).getCanonicalName()

        if isinstance(java_object, JavaClass):
            cls = jvm.Class.forName(java_object._fqn)
        elif isinstance(java_object, JavaObject):
            cls = java_object.getClass()
        else:
            raise Exception("java_object must be a JavaClass, or a JavaObject")

        if cls.getCanonicalName() == name:
            return True

        return self._is_instance_of(cls.getSuperclass(), name)

    def _is_instance_of(self, clazz, name):
        if clazz:
            if clazz.getCanonicalName() == name:
                return True

            return self._is_instance_of(clazz.getSuperclass(), name)

        return False

    def usedebug(self):
        self.job.useDebug()

    def useyarn(self):
        self.job.useYarn()

    def start(self):
        jvm = self.gateway.jvm

        self.job.addMrGeoProperties()
        dpf_properties = jvm.DataProviderFactory.getConfigurationFromProviders()

        for prop in dpf_properties:
            self.job.setSetting(prop, dpf_properties[prop])

        if self.job.isDebug():
            master = "local"
        elif self.job.isSpark():
            # TODO:  get the master for spark
            master = ""
        elif self.job.isYarn():
            master = "yarn-client"
        else:
            cpus = (multiprocessing.cpu_count() / 4) * 3
            if cpus < 2:
                master = "local"
            else:
                master = "local[" + str(cpus) + "]"

        set_field(
            self.job, "jars",
            jvm.StringUtils.concatUnique(
                jvm.DependencyLoader.getAndCopyDependencies(
                    "org.mrgeo.mapalgebra.MapAlgebra", None),
                jvm.DependencyLoader.getAndCopyDependencies(
                    jvm.MapOpFactory.getMapOpClassNames(), None)))

        conf = jvm.MrGeoDriver.prepareJob(self.job)

        # need to override the yarn mode to "yarn-client" for python
        if self.job.isYarn():
            conf.set("spark.master", "yarn-client")

            mem = jvm.SparkUtils.humantokb(conf.get("spark.executor.memory"))
            workers = int(
                conf.get("spark.executor.instances")) + 1  # one for the driver

            conf.set("spark.executor.memory",
                     jvm.SparkUtils.kbtohuman(long(mem / workers), "m"))

        # for a in conf.getAll():
        #     print(a._1(), a._2())

        # jsc = jvm.JavaSparkContext(master, appName, sparkHome, jars)
        jsc = jvm.JavaSparkContext(conf)
        self.sparkContext = jsc.sc()
        self.sparkPyContext = SparkContext(master=master,
                                           appName=self.job.name(),
                                           jsc=jsc,
                                           gateway=self.gateway)

        # print("started")

    def stop(self):
        if self.sparkContext:
            self.sparkContext.stop()
            self.sparkContext = None

        if self.sparkPyContext:
            self.sparkPyContext.stop()
            self.sparkPyContext = None

    def list_images(self):
        jvm = self.gateway.jvm

        pstr = self.job.getSetting(constants.provider_properties, "")
        pp = jvm.ProviderProperties.fromDelimitedString(pstr)

        rawimages = jvm.DataProviderFactory.listImages(pp)

        images = []
        for image in rawimages:
            images.append(str(image))

        return images

    def load_image(self, name):
        jvm = self.gateway.jvm

        pstr = self.job.getSetting(constants.provider_properties, "")
        pp = jvm.ProviderProperties.fromDelimitedString(pstr)

        dp = jvm.DataProviderFactory.getMrsImageDataProvider(
            name, jvm.DataProviderFactory.AccessMode.READ, pp)

        mapop = jvm.MrsPyramidMapOp.apply(dp)
        mapop.context(self.sparkContext)

        # print("loaded " + name)

        return RasterMapOp(mapop=mapop,
                           gateway=self.gateway,
                           context=self.sparkContext,
                           job=self.job)
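
A minimal, hedged usage sketch for the MrGeo wrapper above; the image name is hypothetical, while all method names come from the class itself:

mrgeo = MrGeo()
mrgeo.usedebug()                       # run Spark locally instead of on YARN
mrgeo.start()                          # builds the JavaSparkContext / pyspark SparkContext pair

print(mrgeo.list_images())             # image names known to the DataProviderFactory
elev = mrgeo.load_image("elevation")   # hypothetical image name; returns a RasterMapOp
# the generated operators/methods are now monkey patched onto RasterMapOp, so
# expressions such as (elev + 100) work if the matching map op was registered

mrgeo.stop()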
Example #36
0
class PyGraphXTestCase(unittest.TestCase):
    """
    Test vertices, edges, partitionBy, numEdges, numVertices,
    inDegrees, outDegrees, degrees, triplets, mapVertices,
    mapEdges, mapTriplets, reverse, subgraph, groupEdges,
    joinVertices, outerJoinVertices, collectNeighborIds,
    collectNeighbors, mapReduceTriplets, triangleCount for Graph
    """

    def setUp(self):
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.default.parallelism", 1)
        self.sc = SparkContext(appName=class_name, conf=conf)
        self.sc.setCheckpointDir("/tmp")

    def tearDown(self):
        self.sc.stop()

    def collect(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, [(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])

    def take(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, [(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])

    def count(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.count()
        self.assertEqual(results, 2)

    def mapValues(self):
        vertexData = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertices = VertexRDD(vertexData)
        results = vertices.collect()
        self.assertEqual(results, 2)

    def diff(self):
        vertexData0 = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertexData1 = self.sc.parallelize([(1, ("rxin", "student")), (2, ("jgonzal", "postdoc"))])
        vertices0 = VertexRDD(vertexData0)
        vertices1 = VertexRDD(vertexData1)
        results = vertices0.diff(vertices1)
        self.assertEqual(results, 2)

    def innerJoin(self):
        vertexData0 = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertexData1 = self.sc.parallelize([(1, ("rxin", "student")), (2, ("jgonzal", "postdoc"))])
        vertices0 = VertexRDD(vertexData0)
        vertices1 = VertexRDD(vertexData1)
        results = vertices0.diff(vertices1)
        self.assertEqual(results, 2)

    def leftJoin(self):
        vertexData0 = self.sc.parallelize([(3, ("rxin", "student")), (7, ("jgonzal", "postdoc"))])
        vertexData1 = self.sc.parallelize([(1, ("rxin", "student")), (2, ("jgonzal", "postdoc"))])
        vertices0 = VertexRDD(vertexData0)
        vertices1 = VertexRDD(vertexData1)
        results = vertices0.diff(vertices1)
        self.assertEqual(results, 2)
Example #37
0
class StreamingContext(object):
    """
    Main entry point for Spark Streaming functionality. A StreamingContext represents the
    connection to a Spark cluster, and can be used to create L{DStream}s and
    broadcast variables on that cluster.
    """
    def __init__(self,
                 master=None,
                 appName=None,
                 sparkHome=None,
                 pyFiles=None,
                 environment=None,
                 batchSize=1024,
                 serializer=PickleSerializer(),
                 conf=None,
                 gateway=None,
                 sparkContext=None,
                 duration=None):
        """
        Create a new StreamingContext. At least the master, app name, and duration
        should be set, either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.
        @param gateway: Use an existing gateway and JVM, otherwise a new JVM
               will be instantiated.
        @param sparkContext: L{SparkContext} object.
        @param duration: A L{Duration} object for SparkStreaming.

        """

        if not isinstance(duration, Duration):
            raise TypeError(
                "Input should be pyspark.streaming.duration.Duration object")

        if sparkContext is None:
            # Create the Python SparkContext
            self._sc = SparkContext(master=master,
                                    appName=appName,
                                    sparkHome=sparkHome,
                                    pyFiles=pyFiles,
                                    environment=environment,
                                    batchSize=batchSize,
                                    serializer=serializer,
                                    conf=conf,
                                    gateway=gateway)
        else:
            self._sc = sparkContext

        # Start the py4j callback server.
        # The callback server is needed only by Spark Streaming; therefore it
        # is started in StreamingContext.
        SparkContext._gateway.restart_callback_server()
        self._set_clean_up_handler()
        self._jvm = self._sc._jvm
        self._jssc = self._initialize_context(self._sc._jsc,
                                              duration._jduration)

    # Initialize StreamingContext in a function to allow subclass-specific initialization
    def _initialize_context(self, jspark_context, jduration):
        return self._jvm.JavaStreamingContext(jspark_context, jduration)

    def _set_clean_up_handler(self):
        """ set clean up hander using atexit """
        def clean_up_handler():
            SparkContext._gateway.shutdown()

        atexit.register(clean_up_handler)
        # atexit handlers are not called when the program is killed by a signal
        # that Python does not handle.
        for sig in (SIGINT, SIGTERM):
            signal(sig, clean_up_handler)

    @property
    def sparkContext(self):
        """
        Return SparkContext which is associated with this StreamingContext.
        """
        return self._sc

    def start(self):
        """
        Start the execution of the streams.
        """
        self._jssc.start()

    def awaitTermination(self, timeout=None):
        """
        Wait for the execution to stop.
        @param timeout: time to wait in milliseconds
        """
        if timeout is None:
            self._jssc.awaitTermination()
        else:
            self._jssc.awaitTermination(timeout)

    def remember(self, duration):
        """
        Set each DStream in this context to remember the RDDs it generated in the last given duration.
        DStreams remember RDDs only for a limited duration of time and release them for garbage
        collection. This method allows the developer to specify how long to remember the RDDs
        (if the developer wishes to query old data outside the DStream computation).
        @param duration: pyspark.streaming.duration.Duration object.
               Minimum duration that each DStream should remember its RDDs
        """
        if not isinstance(duration, Duration):
            raise TypeError(
                "Input should be pyspark.streaming.duration.Duration object")

        self._jssc.remember(duration._jduration)

    # TODO: add storageLevel
    def socketTextStream(self, hostname, port):
        """
        Create an input stream from a TCP source hostname:port. Data is received
        using a TCP socket and the received bytes are interpreted as UTF-8 encoded,
        '\n' delimited lines.
        """
        return DStream(self._jssc.socketTextStream(hostname, port), self,
                       UTF8Deserializer())

    def textFileStream(self, directory):
        """
        Create an input stream that monitors a Hadoop-compatible file system
        for new files and reads them as text files. Files must be written to the
        monitored directory by "moving" them from another location within the same
        file system. File names starting with . are ignored.
        """
        return DStream(self._jssc.textFileStream(directory), self,
                       UTF8Deserializer())

    def stop(self, stopSparkContext=True, stopGraceFully=False):
        """
        Stop the execution of the streams immediately (does not wait for all received data
        to be processed).
        """
        self._jssc.stop(stopSparkContext, stopGraceFully)
        if stopSparkContext:
            self._sc.stop()

        # Shut down only the callback server here; the py4j clients are shut down
        # by the clean up handler.
        SparkContext._gateway._shutdown_callback_server()

    def _testInputStream(self, test_inputs, numSlices=None):
        """
        This function is only for unittest.
        It requires a list as input, and returns the i_th element at the i_th batch
        under manual clock.
        """
        test_rdds = list()
        test_rdd_deserializers = list()
        for test_input in test_inputs:
            test_rdd = self._sc.parallelize(test_input, numSlices)
            test_rdds.append(test_rdd._jrdd)
            test_rdd_deserializers.append(test_rdd._jrdd_deserializer)
        # All deserializers have to be the same.
        # TODO: add deserializer validation
        jtest_rdds = ListConverter().convert(
            test_rdds, SparkContext._gateway._gateway_client)
        jinput_stream = self._jvm.PythonTestInputStream(
            self._jssc, jtest_rdds).asJavaDStream()

        return DStream(jinput_stream, self, test_rdd_deserializers[0])
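
A hedged usage sketch for the StreamingContext wrapper above; the Duration constructor argument (milliseconds) and the host/port are assumptions, while the method calls come from the class itself:

from pyspark.streaming.duration import Duration

ssc = StreamingContext(master="local[2]", appName="NetworkTextStream",
                       duration=Duration(1000))   # assumed: 1000 ms batch duration
lines = ssc.socketTextStream("localhost", 9999)   # hypothetical TCP source
ssc.remember(Duration(60000))                     # keep roughly the last minute of RDDs
ssc.start()
ssc.awaitTermination(timeout=30000)               # wait up to 30 s
ssc.stop(stopSparkContext=True)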
Example #38
0
File: context.py  Project: giworld/spark
class StreamingContext(object):
    """
    Main entry point for Spark Streaming functionality. A StreamingContext represents the
    connection to a Spark cluster, and can be used to create L{DStream}s and
    broadcast variables on that cluster.
    """

    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
                 gateway=None, sparkContext=None, duration=None):
        """
        Create a new StreamingContext. At least the master, app name, and duration
        should be set, either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.
        @param gateway: Use an existing gateway and JVM, otherwise a new JVM
               will be instantiated.
        @param sparkContext: L{SparkContext} object.
        @param duration: A L{Duration} object for SparkStreaming.

        """

        if not isinstance(duration, Duration):
            raise TypeError("Input should be pyspark.streaming.duration.Duration object")

        if sparkContext is None:
            # Create the Python SparkContext
            self._sc = SparkContext(master=master, appName=appName, sparkHome=sparkHome,
                                    pyFiles=pyFiles, environment=environment, batchSize=batchSize,
                                    serializer=serializer, conf=conf, gateway=gateway)
        else:
            self._sc = sparkContext

        # Start the py4j callback server.
        # The callback server is needed only by Spark Streaming; therefore it
        # is started in StreamingContext.
        SparkContext._gateway.restart_callback_server()
        self._set_clean_up_handler()
        self._jvm = self._sc._jvm
        self._jssc = self._initialize_context(self._sc._jsc, duration._jduration)

    # Initialize StreamingContext in a function to allow subclass-specific initialization
    def _initialize_context(self, jspark_context, jduration):
        return self._jvm.JavaStreamingContext(jspark_context, jduration)

    def _set_clean_up_handler(self):
        """ set clean up hander using atexit """

        def clean_up_handler():
            SparkContext._gateway.shutdown()

        atexit.register(clean_up_handler)
        # atexit handlers are not called when the program is killed by a signal
        # that Python does not handle.
        for sig in (SIGINT, SIGTERM):
            signal(sig, clean_up_handler)

    @property
    def sparkContext(self):
        """
        Return SparkContext which is associated with this StreamingContext.
        """
        return self._sc

    def start(self):
        """
        Start the execution of the streams.
        """
        self._jssc.start()

    def awaitTermination(self, timeout=None):
        """
        Wait for the execution to stop.
        @param timeout: time to wait in milliseconds
        """
        if timeout is None:
            self._jssc.awaitTermination()
        else:
            self._jssc.awaitTermination(timeout)

    def remember(self, duration):
        """
        Set each DStream in this context to remember the RDDs it generated in the last given duration.
        DStreams remember RDDs only for a limited duration of time and release them for garbage
        collection. This method allows the developer to specify how long to remember the RDDs
        (if the developer wishes to query old data outside the DStream computation).
        @param duration: pyspark.streaming.duration.Duration object.
               Minimum duration that each DStream should remember its RDDs
        """
        if not isinstance(duration, Duration):
            raise TypeError("Input should be pyspark.streaming.duration.Duration object")

        self._jssc.remember(duration._jduration)

    # TODO: add storageLevel
    def socketTextStream(self, hostname, port):
        """
        Create an input stream from a TCP source hostname:port. Data is received
        using a TCP socket and the received bytes are interpreted as UTF-8 encoded,
        '\n' delimited lines.
        """
        return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer())

    def textFileStream(self, directory):
        """
        Create an input stream that monitors a Hadoop-compatible file system
        for new files and reads them as text files. Files must be written to the
        monitored directory by "moving" them from another location within the same
        file system. File names starting with . are ignored.
        """
        return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

    def stop(self, stopSparkContext=True, stopGraceFully=False):
        """
        Stop the execution of the streams immediately (does not wait for all received data
        to be processed).
        """
        self._jssc.stop(stopSparkContext, stopGraceFully)
        if stopSparkContext:
            self._sc.stop()

        # Shut down only the callback server here; the py4j clients are shut down
        # by the clean up handler.
        SparkContext._gateway._shutdown_callback_server()
        
    def _testInputStream(self, test_inputs, numSlices=None):
        """
        This function is only for unittest.
        It requires a list as input, and returns the i_th element at the i_th batch
        under manual clock.
        """
        test_rdds = list()
        test_rdd_deserializers = list()
        for test_input in test_inputs:
            test_rdd = self._sc.parallelize(test_input, numSlices)
            test_rdds.append(test_rdd._jrdd)
            test_rdd_deserializers.append(test_rdd._jrdd_deserializer)
        # All deserializers have to be the same.
        # TODO: add deserializer validation
        jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
        jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()

        return DStream(jinput_stream, self, test_rdd_deserializers[0])
Example #39
0
conf = (SparkConf()
         .setMaster(os.environ["SPARK_MASTER"]))

# set the UI port
conf.set("spark.ui.port", ui_get_available_port())

# configure docker containers as executors
conf.setSparkHome(os.environ.get("SPARK_HOME"))
conf.set("spark.mesos.executor.docker.image", "lab41/spark-mesos-dockerworker-ipython")
conf.set("spark.mesos.executor.home", "/usr/local/spark-1.4.1-bin-hadoop2.4")
conf.set("spark.executorEnv.MESOS_NATIVE_LIBRARY", "/usr/local/lib/libmesos.so")
conf.set("spark.network.timeout", "100")

# establish config-based context
sc = SparkContext(appName="DockerIPythonShell", pyFiles=add_files, conf=conf)
atexit.register(lambda: sc.stop())

try:
    # Try to access HiveConf, it will raise exception if Hive is not added
    sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
    sqlCtx = sqlContext = HiveContext(sc)
except py4j.protocol.Py4JError:
    sqlCtx = sqlContext = SQLContext(sc)

print("""Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version %s
      /_/
""" % sc.version)
Example #40
0
    def predictQuantiles(self, features):
        """
        Predicted Quantiles
        """
        return self._call_java("predictQuantiles", features)

    def predict(self, features):
        """
        Predicted value
        """
        return self._call_java("predict", features)


if __name__ == "__main__":
    import doctest
    import pyspark.ml.regression
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = pyspark.ml.regression.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext("local[2]", "ml.regression tests")
    sqlContext = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlContext'] = sqlContext
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failure_count:
        exit(-1)
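
The two helpers above belong to a model wrapper in pyspark.ml.regression whose class is not shown in the snippet; a hedged, doctest-style sketch, assuming the class is AFTSurvivalRegressionModel, might be:

from pyspark.mllib.linalg import Vectors          # pyspark.ml.linalg on Spark 2.x
from pyspark.ml.regression import AFTSurvivalRegression

df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0), 1.0),
     (1e-40, Vectors.sparse(1, [], []), 0.0)],
    ["label", "features", "censor"])
aft = AFTSurvivalRegression(quantileProbabilities=[0.3, 0.6])
model = aft.fit(df)
print(model.predict(Vectors.dense(6.3)))           # single predicted value
print(model.predictQuantiles(Vectors.dense(6.3)))  # vector of predicted quantiles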
Example #41
0
    """
    def predictQuantiles(self, features):
        """
        Predicted Quantiles
        """
        return self._call_java("predictQuantiles", features)

    def predict(self, features):
        """
        Predicted value
        """
        return self._call_java("predict", features)


if __name__ == "__main__":
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext("local[2]", "ml.regression tests")
    sqlContext = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlContext'] = sqlContext
    (failure_count, test_count) = doctest.testmod(globs=globs,
                                                  optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failure_count:
        exit(-1)
Example #42
0
class TestBed(unittest.TestCase):

    def setUp(self):
        self.sc = SparkContext('local[1]')

    def tearDown(self):
        self.sc.stop()

    def test_overlaps_any(self):
        self.assertEqual(overlaps_any(('chr1', 10, 20), ('chr1', 10, 20)), True)
        self.assertEqual(overlaps_any(('chr1', 10, 20), ('chr2', 10, 20)), False)

        a = [('chr1', 10, 20), ('chr2', 10, 20), ('chr2', 100, 200)]
        b = [('chr11', 10, 20), ('chr2', 50, 150), ('chr2', 150, 160), ('chr2', 155, 265)]

        res_ab = [overlaps_any(x, b) for x in a]
        res_ba = [overlaps_any(x, a) for x in b]

        self.assertEqual(res_ab, [False, False, True])
        self.assertEqual(res_ba, [False, True, True, True])

    def test_leftjoin_overlap(self):
        a = ('chr1', 10, 20)
        b = ('chr1', 15, 25)
        self.assertEqual(leftjoin_overlap(a, [b]), a + b)
        self.assertEqual(leftjoin_overlap(a, [b, b]), [a + b, a + b])
        self.assertEqual(leftjoin_overlap(a, [a, b]), [a + a, a + b])

        a = ('chr2', 100, 200)
        b = [('chr11', 10, 20), ('chr2', 50, 150), ('chr2', 150, 160), ('chr2', 155, 265)]
        self.assertEqual(leftjoin_overlap(a, b), [a + b[1], a + b[2], a + b[3]])

    def test_leftjoin_overlap_window(self):
        rdd1 = self.sc.parallelize([('chr1', 10, 20), ('chr2', 10, 20), ('chr2', 100, 200)])
        rdd2 = self.sc.parallelize([('chr11', 10, 20), ('chr2', 50, 150), ('chr2', 150, 160), ('chr2', 155, 265)])

        result = leftjoin_overlap_window(rdd1, rdd2, bin_func=coords2bin, bin_size=10000)

        output = result.collect()

        # It is not possible to test the exact results, for two reasons:
        # - the order of the output may vary
        # - the uid assigned to each entry is random
        self.assertEqual(len(output), 5)
        self.assertEqual(all([type(x) is tuple for x in output]), True)  # all elements of output are tuples
        self.assertEqual(sorted([len(x) for x in output]), [4, 4, 8, 8, 8])  # sorted length sizes are deterministic

    def test_disjoint(self):
        self.assertEqual(disjoint((10, 20, 30), (13, 40, 41)),
                         [[10, 11, 13, 14, 20, 21, 30, 31, 40, 41],
                          [10, 12, 13, 19, 20, 29, 30, 39, 40, 41]])

        # A naive implementation of disjoint(10, 15) would raise
        # "TypeError: 'int' object is not iterable"; scalar inputs must be handled too:
        self.assertEqual(disjoint((10, ), (15, )), [[10, 11, 15], [10, 14, 15]])
        self.assertEqual(disjoint(10, 15), [[10, 11, 15], [10, 14, 15]])
        self.assertEqual(disjoint((), ()), [[], []])
        self.assertEqual(disjoint([10, ], [15, ]), [[10, 11, 15], [10, 14, 15]])
        self.assertEqual(disjoint([10], [15]), [[10, 11, 15], [10, 14, 15]])
        self.assertEqual(disjoint([], []), [[], []])
        self.assertRaises(RuntimeError, lambda: disjoint((10, ), (15, 20)))

    def test_count_overlaps(self):
        self.assertEqual(count_overlaps((10, ), (20, ), (50, ), (60, )), [0])
        self.assertEqual(count_overlaps((10, ), (20, ), (20, ), (60, )), [1])
        self.assertEqual(count_overlaps((60, ), (80, ), (20, ), (60, )), [1])
        self.assertEqual(count_overlaps((10, ), (80, ), (20, ), (60, )), [1])
        self.assertEqual(count_overlaps((30, ), (40, ), (20, ), (60, )), [1])
        self.assertEqual(count_overlaps((30, 35), (40, 36), (20, ), (60, )), [1, 1])
        self.assertEqual(count_overlaps((30, 135), (40, 136), (20, ), (60, )), [1, 0])
        self.assertEqual(count_overlaps((30, 135), (40, 136), (20, 20), (60, 60)), [2, 0])
        self.assertRaises(Exception, lambda: count_overlaps((30, 135), (40, 136), (20, ), (60, 60)))

    def test_overlaps_any2(self):
        self.assertEqual(overlaps_any2((), (), (50, ), (60, )), [])
        self.assertEqual(overlaps_any2((10, ), (20, ), (50, ), (60, )), [False])
        self.assertEqual(overlaps_any2((10, ), (20, ), (20, ), (60, )), [True])
        self.assertEqual(overlaps_any2((60, ), (80, ), (20, ), (60, )), [True])
        self.assertEqual(overlaps_any2((10, ), (80, ), (20, ), (60, )), [True])
        self.assertEqual(overlaps_any2((30, ), (40, ), (20, ), (60, )), [True])
        self.assertEqual(overlaps_any2((30, 35), (40, 36), (20, ), (60, )), [True, True])
        self.assertEqual(overlaps_any2((30, 135), (40, 136), (20, ), (60, )), [True, False])
        self.assertEqual(overlaps_any2((30, 135), (40, 136), (20, 20), (60, 60)), [True, False])
        self.assertRaises(Exception, lambda: overlaps_any2((30, 135), (40, 136), (20, ), (60, 60)))
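
# overlaps_any and leftjoin_overlap themselves are not shown in this snippet.
# Minimal sketches consistent with the assertions above (hypothetical
# reconstructions; they assume closed intervals of the form (chrom, start, end)):
def overlaps_any(query, targets):
    # Accept either a single (chrom, start, end) tuple or a list of them.
    if isinstance(targets, tuple):
        targets = [targets]
    chrom, start, end = query
    return any(chrom == t_chrom and start <= t_end and t_start <= end
               for t_chrom, t_start, t_end in targets)


def leftjoin_overlap(query, targets):
    # Concatenate the query with every overlapping target; return a bare tuple
    # when there is exactly one hit, otherwise a list of tuples.
    hits = [query + t for t in targets if overlaps_any(query, t)]
    return hits[0] if len(hits) == 1 else hits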
Example #43
0
class SparkClient:
    def __init__(self,
                 spark_home,
                 spark_master="local",
                 exec_memory="8g",
                 app_name="SparkClient"):
        """
        Initialize sparkcontext, sqlcontext
        :param spark_master: target spark master
        :param exec_memory: size of memory per executor
        """
        self._spark_master = spark_master
        self._exec_memory = exec_memory
        self._app_name = app_name
        self._spark_home = spark_home
        # Point SPARK_HOME at the local Spark installation
        os.environ['SPARK_HOME'] = self._spark_home
        self._spark_url = spark_master
        if spark_master != "local":
            os.environ['SPARK_MASTER_IP'] = spark_master
            self._spark_url = "spark://" + self._spark_master + ":7077"
        # Append the Spark home to the Python path
        sys.path.append(self._spark_home)
        # define the spark configuration
        conf = (SparkConf()
                .setMaster(self._spark_url)
                .setAppName(self._app_name)
                .set("spark.executor.memory", self._exec_memory)
                .set("spark.core.connection.ack.wait.timeout", "600")
                .set("spark.akka.frameSize", "512")
                .set("spark.cassandra.output.batch.size.bytes", "131072"))
        # create the Spark context, reusing an already-active one if it exists
        if SparkContext._active_spark_context is None:
            self._spark_ctx = SparkContext(conf=conf)
        else:
            self._spark_ctx = SparkContext._active_spark_context
        # create the SQL context
        self._sql = SQLContext(self._spark_ctx)

    def close(self):
        """"
        Close the spark context
        """
        self._spark_ctx.stop()

    @property
    def sc(self):
        return self._spark_ctx

    def save_nda_j1_deces_from_df(self, df, file_dir, vois):
        """
        Save dataframe as Parquet file for nda_j1_deces table
        :param df: source dataframe
        :param file_dir: path to database Parquet files
        :return: None
        """
        # Transform result pandas DataFrame to Spark DataFrame
        pd.options.mode.chained_assignment = None  # to avoid pandas warnings

        #df_src_pd = df['id_ndaj1'].str.extract('(^[0-9]{7})([0-9]{4}-[0-9]{2}-[0-9]{2})', expand=False)
        df_src_pd = df['id_ndaj1'].str.extract(
            '(^[0-9]{10})([0-9]{4}-[0-9]{2}-[0-9]{2})')
        df_src_pd.columns = ['id_nda', 'j1']
        for voi in vois:
            df_src_pd[voi] = df[voi]
        df_src_pd['dt_deces'] = df.dt_deces.apply(str)
        df_src_pd['dt_min'] = df.dt_min.apply(str)
        df_src_pd['dt_max'] = df.dt_max.apply(str)
        df_src_pd['dpt'] = df.dpt
        df_src_pd['cd_sex_tr'] = df.cd_sex_tr
        df_src_pd['stay_len'] = df.stay_len.apply(str)
        spark_df = self._sql.createDataFrame(df_src_pd)
        spark_df.write.parquet(file_dir + "/nda_j1_deces", mode='overwrite')
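
# A minimal usage sketch (the paths and column names below are hypothetical):
# client = SparkClient(spark_home="/usr/local/spark")
# df = ...  # pandas DataFrame with an id_ndaj1 column plus the columns of interest
# client.save_nda_j1_deces_from_df(df, "/data/warehouse", vois=["voi_a", "voi_b"])
# client.close()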
Example #44
0
    }

    hbase_df = sqlc.read.format('org.apache.hadoop.hbase.spark')\
        .option('hbase.table', 'mDNA_Biomarker_Unstructured')\
        .option('hbase.columns.mapping', 'KEY_FIELD STRING :key, Txt STRING OCR:Text')\
        .option('hbase.use.hbase.context', False)\
        .option('hbase.config.resources', 'file:///etc/hbase/conf/hbase-site.xml')\
        .load()

    hbase_df.registerTempTable("mDNA_Biomarker_ILS")

    data_df = sqlc.sql(
        'select * from mDNA_Biomarker_ILS where lower(Txt) LIKE "%brca%"')
    extract_data_rdd = data_df.rdd.map(get_attr_code)
    hbase_data_rdd = extract_data_rdd.flatMap(prepare_data_hbase)
    hbase_data_rdd.saveAsNewAPIHadoopDataset(conf=conf,
                                             keyConverter=keyConv,
                                             valueConverter=valueConv)
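    # keyConv, valueConv and conf are defined earlier in the original script and
    # are not shown in this snippet. For HBase writes they are commonly the
    # converters bundled with the Spark examples, e.g. (an assumption, not
    # confirmed by this snippet):
    # keyConv = "org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter"
    # valueConv = "org.apache.spark.examples.pythonconverters.StringListToPutConverter"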


if __name__ == "__main__":
    # Initialize the Spark context
    spark = SparkContext(appName="HBaseRead")
    ## Call the startup function to initiate the HBase download
    start_HBASE_download(spark)
    ## Process files from NFS directories
    ##start_NFS_download(spark)
    ## Stop the application once it is done
    spark.stop()