def test_multiple_python_java_RDD_conversions(self):
    # Regression test for SPARK-5361: converting an RDD between the
    # Python and Java sides more than once used to throw exceptions.
    data = [("1", {"director": "David Lean"}),
            ("2", {"director": "Andrew Dominik"})]
    converted_rdd = self.sc.parallelize(data)
    # Round-trip Python -> Java -> Python twice; each trip must preserve
    # the element count.
    for _ in range(2):
        java_rdd = converted_rdd._to_java_object_rdd()
        python_rdd = self.sc._jvm.SerDeUtil.javaToPython(java_rdd)
        converted_rdd = RDD(python_rdd, self.sc)
        self.assertEqual(2, converted_rdd.count())
def predict(self, x):
    """
    Predict the label of one or more examples.

    :param x: Data point (feature vector), or an RDD of data points
              (feature vectors).
    :return: An RDD of predictions for an RDD input, or a single
             prediction for a single data point.
    """
    ser = PickleSerializer()
    if isinstance(x, RDD):
        # Bulk prediction
        first = x.take(1)
        if not first:
            return self._sc.parallelize([])
        if not isinstance(first[0], Vector):
            x = x.map(_convert_to_vector)
        jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD()
        jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred)
        return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024))
    else:
        # Assume x is a single data point: pickle it, let the JVM-side
        # SerDe rebuild the vector, and predict on that.
        # (Fixes: removed the unused local `SerDe` alias and stopped
        # shadowing the builtin `bytes`.)
        payload = bytearray(ser.dumps(_convert_to_vector(x)))
        vec = self._sc._jvm.SerDe.loads(payload)
        return self._java_model.predict(vec)
def SpatialRangeQuery(self, spatialRDD: SpatialRDD, rangeQueryWindow: BaseGeometry, considerBoundaryIntersection: bool, usingIndex: bool):
    """
    Run a spatial range query against a SpatialRDD.

    :param spatialRDD: SpatialRDD to query
    :param rangeQueryWindow: geometry describing the query window
    :param considerBoundaryIntersection: whether boundary-touching geometries match
    :param usingIndex: whether to use the RDD's spatial index
    :return: pyspark RDD of the matching geometries
    """
    jvm = spatialRDD._jvm
    window_geometry = GeometryAdapter.create_jvm_geometry_from_base_geometry(
        jvm, rangeQueryWindow)
    result_srdd = jvm.RangeQuery.SpatialRangeQuery(
        spatialRDD._srdd,
        window_geometry,
        considerBoundaryIntersection,
        usingIndex
    )
    python_rdd = JvmGeoSparkPythonConverter(jvm).translate_spatial_rdd_to_python(result_srdd)
    return RDD(python_rdd, spatialRDD._sc, GeoSparkPickler())
class ProjectedExtentSchemaTest(BaseTestClass):
    """Round-trip test for the ProjectedExtent Avro schema.

    Pulls a test RDD of projected extents from the JVM-side wrapper and
    checks both the decoded Python dicts and a re-serialization round trip.
    """

    # Expected decoded form of the extents produced by the Scala wrapper.
    projected_extents = [
        {'epsg': 2004, 'extent': {'xmax': 1.0, 'xmin': 0.0, 'ymax': 1.0, 'ymin': 0.0}, 'proj4': None},
        {'epsg': 2004, 'extent': {'xmax': 3.0, 'xmin': 1.0, 'ymax': 4.0, 'ymin': 2.0}, 'proj4': None},
        {'epsg': 2004, 'extent': {'xmax': 7.0, 'xmin': 5.0, 'ymax': 8.0, 'ymin': 6.0}, 'proj4': None}]

    # Built once at class-creation time: the JVM wrapper's testOut()
    # returns a (java_rdd, schema) pair.
    sc = BaseTestClass.geopysc.pysc._jsc.sc()
    ew = BaseTestClass.geopysc.pysc._jvm.geopyspark.geotrellis.tests.schemas.ProjectedExtentWrapper

    tup = ew.testOut(sc)
    java_rdd = tup._1()

    ser = AvroSerializer(tup._2())
    rdd = RDD(java_rdd, BaseTestClass.geopysc.pysc, AutoBatchedSerializer(ser))
    collected = rdd.collect()

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # Close the py4j gateway after each test.
        yield
        BaseTestClass.geopysc.pysc._gateway.close()

    def result_checker(self, actual_pe, expected_pe):
        # Compare extent dicts pairwise.
        for actual, expected in zip(actual_pe, expected_pe):
            self.assertDictEqual(actual, expected)

    def test_encoded_pextents(self):
        # An identity map forces a (de)serialization round trip through
        # the AvroSerializer.
        encoded = self.rdd.map(lambda s: s)
        actual_encoded = encoded.collect()

        self.result_checker(actual_encoded, self.projected_extents)

    def test_decoded_pextents(self):
        self.result_checker(self.collected, self.projected_extents)
def transformToRDD(cursor, sc, parallelism=1):
    """
    Transform a StellarCursor into a Python RDD object.

    param cursor: StellarCursor
    param sc: SparkContext
    param parallelism: Parallelism of RDD
    """
    # Pull every row out of the cursor up front.
    rows = cursor.fetchall()

    # Clamp parallelism to at least one slice.
    num_slices = max(1, parallelism)

    def read_from_file(temp_filename):
        # JVM-side reader used when the data is shipped via a temp file.
        return sc._jvm.PythonRDD.readRDDFromFile(sc._jsc, temp_filename, num_slices)

    def start_parallelize_server():
        # JVM-side socket server used when the data is streamed directly.
        return sc._jvm.PythonParallelizeServer(sc._jsc.sc(), num_slices)

    # Batch size: spread rows across the slices, capped at 1024, floor 1.
    batch = min(len(rows) // num_slices, 1024)
    serializer = BatchedSerializer(sc._unbatched_serializer, max(1, batch))

    jrdd = sc._serialize_to_jvm(rows, serializer, read_from_file, start_parallelize_server)
    return RDD(jrdd, sc, serializer)
class SpatialKeySchemaTest(BaseTestClass):
    """Round-trip test for the SpatialKey Avro schema."""

    # Expected decoded form of the key produced by the Scala wrapper.
    expected_keys = {'col': 7, 'row': 3}

    # Built once at class-creation time: the JVM wrapper's testOut()
    # returns a (java_rdd, schema) pair.
    sc = BaseTestClass.geopysc.pysc._jsc.sc()
    ew = BaseTestClass.geopysc.pysc._jvm.geopyspark.geotrellis.tests.schemas.SpatialKeyWrapper

    tup = ew.testOut(sc)
    java_rdd = tup._1()

    ser = AvroSerializer(tup._2())
    rdd = RDD(java_rdd, BaseTestClass.geopysc.pysc, AutoBatchedSerializer(ser))
    collected = rdd.first()

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # Close the py4j gateway after each test.
        yield
        BaseTestClass.geopysc.pysc._gateway.close()

    def result_checker(self, actual_keys, expected_keys):
        self.assertDictEqual(actual_keys, expected_keys)

    def test_encoded_keyss(self):
        # An identity map forces a (de)serialization round trip through
        # the AvroSerializer.
        encoded = self.rdd.map(lambda s: s)
        actual_encoded = encoded.first()

        self.result_checker(actual_encoded, self.expected_keys)

    def test_decoded_extents(self):
        self.assertDictEqual(self.collected, self.expected_keys)
class ByteTileSchemaTest(BaseTestClass):
    """Round-trip test for the byte-tile ProtoBuf schema."""

    # Three 2x2 int8 tiles with a no-data value of -128.
    tiles = [
        Tile.from_numpy_array(np.int8([0, 0, 1, 1]).reshape(2, 2), -128),
        Tile.from_numpy_array(np.int8([1, 2, 3, 4]).reshape(2, 2), -128),
        Tile.from_numpy_array(np.int8([5, 6, 7, 8]).reshape(2, 2), -128)
    ]

    sc = BaseTestClass.pysc._jsc.sc()
    tw = BaseTestClass.pysc._jvm.geopyspark.geotrellis.tests.schemas.ByteArrayTileWrapper

    java_rdd = tw.testOut(sc)
    ser = ProtoBufSerializer(tile_decoder, tile_encoder)
    rdd = RDD(java_rdd, BaseTestClass.pysc, AutoBatchedSerializer(ser))
    collected = rdd.collect()

    def test_encoded_tiles(self):
        # Re-encode the collected tiles and compare their protobuf
        # metadata against the expected fixture tiles.
        expected_encoded = [to_pb_tile(x) for x in self.collected]

        for actual, expected in zip(self.tiles, expected_encoded):
            cells = actual.cells
            rows, cols = cells.shape

            self.assertEqual(expected.cols, cols)
            self.assertEqual(expected.rows, rows)
            self.assertEqual(expected.cellType.nd, actual.no_data_value)
            self.assertEqual(expected.cellType.dataType,
                             mapped_data_types[actual.cell_type])

    def test_decoded_tiles(self):
        for actual, expected in zip(self.collected, self.tiles):
            self.assertTrue((actual.cells == expected.cells).all())
            self.assertTrue(actual.cells.dtype == expected.cells.dtype)
            # BUG FIX: the original compared actual.cells.shape to itself
            # (always true); compare against the expected tile instead.
            self.assertEqual(actual.cells.shape, expected.cells.shape)
def create_python_rdd(self, jrdd, serializer):
    """Creates a Python RDD from a RDD from Scala.

    Args:
        jrdd (org.apache.spark.api.java.JavaRDD): The RDD that came from Scala.
        serializer (:class:`~geopyspark.AvroSerializer` or pyspark.serializers.AutoBatchedSerializer(AvroSerializer)):
            An instance of ``AvroSerializer`` that is either alone, or wrapped by
            ``AutoBatchedSerializer``.

    Returns:
        ``pyspark.RDD``
    """
    # Wrap the serializer unless the caller already batched it.
    if not isinstance(serializer, AutoBatchedSerializer):
        serializer = AutoBatchedSerializer(serializer)
    return RDD(jrdd, self.pysc, serializer)
def SpatialJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> RDD:
    """
    Run a flat spatial join between two SpatialRDDs.

    Based on the two parameters
    - useIndex
    - considerBoundaryIntersection
    this builds the RDD resulting from the spatial join operation,
    returning RDD[GeoData, GeoData].

    :param spatialRDD: SpatialRDD
    :param queryRDD: SpatialRDD
    :param useIndex: bool
    :param considerBoundaryIntersection: bool
    :return: RDD

    >> spatial_join_result = JoinQuery.SpatialJoinQueryFlat(
    >>      spatialRDD, queryRDD, useIndex, considerBoundaryIntersection
    >> )

    >> spatial_join_result.collect()
    [[GeoData(Polygon, ), GeoData()], [GeoData(), GeoData()], [GeoData(), GeoData()]]
    """
    jvm = spatialRDD._jvm
    joined_srdd = jvm.JoinQuery.SpatialJoinQueryFlat(
        spatialRDD._srdd,
        queryRDD._srdd,
        useIndex,
        considerBoundaryIntersection
    )
    serialized = jvm.GeoSerializerData.serializeToPython(joined_srdd)
    return RDD(serialized, spatialRDD._sc, GeoSparkPickler())
class ByteTileSchemaTest(BaseTestClass):
    """Round-trip test for the byte-tile Avro schema."""

    # Expected decoded tiles: 2x2 arrays plus a no-data value.
    tiles = [
        {'data': np.array([0, 0, 1, 1]).reshape(2, 2), 'no_data_value': -128},
        {'data': np.array([1, 2, 3, 4]).reshape(2, 2), 'no_data_value': -128},
        {'data': np.array([5, 6, 7, 8]).reshape(2, 2), 'no_data_value': -128}
    ]

    # Built once at class-creation time: the JVM wrapper's testOut()
    # returns a (java_rdd, schema) pair.
    sc = BaseTestClass.geopysc.pysc._jsc.sc()
    tw = BaseTestClass.geopysc.pysc._jvm.geopyspark.geotrellis.tests.schemas.ByteArrayTileWrapper

    tup = tw.testOut(sc)
    java_rdd = tup._1()

    ser = AvroSerializer(tup._2(), AvroRegistry.tile_decoder, AvroRegistry.tile_encoder)
    rdd = RDD(java_rdd, BaseTestClass.geopysc.pysc, AutoBatchedSerializer(ser))
    collected = rdd.collect()

    def test_encoded_tiles(self):
        # Re-encode the decoded tiles and compare against the expected
        # Avro record structure.
        encoded = self.rdd.map(lambda s: AvroRegistry.tile_encoder(s))
        actual_encoded = encoded.collect()

        expected_encoded = [
            {'bands': [{'cols': 2, 'rows': 2, 'cells': bytearray([0, 0, 1, 1]), 'noDataValue': -128}]},
            {'bands': [{'cols': 2, 'rows': 2, 'cells': bytearray([1, 2, 3, 4]), 'noDataValue': -128}]},
            {'bands': [{'cols': 2, 'rows': 2, 'cells': bytearray([5, 6, 7, 8]), 'noDataValue': -128}]}
        ]

        for actual, expected in zip(actual_encoded, expected_encoded):
            self.assertEqual(actual, expected)

    def test_decoded_tiles(self):
        for actual, expected in zip(self.collected, self.tiles):
            self.assertTrue((actual['data'] == expected['data']).all())
def _java2py(sc, r, encoding="bytes"):
    """Convert a py4j result from the JVM into the matching Python object.

    JavaRDDs become pyspark RDDs, Datasets/DataFrames become pyspark
    DataFrames, and picklable JVM objects are round-tripped through
    BigDL's SerDe. Anything else is returned unchanged.
    """
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # Normalize any concrete *RDD into a JavaRDD first.
        if clsName.endswith("RDD") and clsName != 'JavaRDD':
            r = r.toJavaRDD()
            clsName = 'JavaRDD'

        if clsName == 'JavaRDD':
            return RDD(sc._jvm.SerDe.javaToPython(r), sc)

        if clsName in ('DataFrame', 'Dataset'):
            return DataFrame(r, get_spark_sql_context(sc))

        if clsName in _picklable_classes:
            r = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(r)
    elif isinstance(r, (JavaArray, JavaList, JavaMap)):
        try:
            r = sc._jvm.org.apache.spark.bigdl.api.python.BigDLSerDe.dumps(r)
        except Py4JJavaError:
            pass  # not picklable

    # Anything the JVM pickled comes back as bytes; unpickle it here.
    if isinstance(r, (bytearray, bytes)):
        r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r
def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:
    """Convert a py4j object returned from the JVM into a Python value.

    JavaRDDs become pyspark RDDs, Datasets become DataFrames, and known
    picklable classes are round-tripped through MLSerDe; other values
    are returned unchanged.
    """
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if clsName != "JavaRDD" and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = "JavaRDD"

        assert sc._jvm is not None

        if clsName == "JavaRDD":
            jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(
                r)  # type: ignore[attr-defined]
            return RDD(jrdd, sc)

        if clsName == "Dataset":
            return DataFrame(r, SparkSession(sc)._wrapped)

        if clsName in _picklable_classes:
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(
                r)  # type: ignore[attr-defined]
    elif isinstance(r, (JavaArray, JavaList)):
        try:
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(
                r)  # type: ignore[attr-defined]
        except Py4JJavaError:
            pass  # not picklable

    # Anything the JVM pickled comes back as bytes; unpickle it here.
    if isinstance(r, (bytearray, bytes)):
        r = CPickleSerializer().loads(bytes(r), encoding=encoding)
    return r
class SpatialKeySchemaTest(BaseTestClass):
    """Round-trip test for the SpatialKey ProtoBuf schema."""

    # Expected decoded form of the key produced by the Scala wrapper.
    expected_keys = {'col': 7, 'row': 3}

    sc = BaseTestClass.pysc._jsc.sc()
    ew = BaseTestClass.pysc._jvm.geopyspark.geotrellis.tests.schemas.SpatialKeyWrapper

    java_rdd = ew.testOut(sc)
    ser = ProtoBufSerializer(spatial_key_decoder, spatial_key_encoder)
    rdd = RDD(java_rdd, BaseTestClass.pysc, AutoBatchedSerializer(ser))
    collected = rdd.first()._asdict()

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # Close the py4j gateway after each test.
        yield
        BaseTestClass.pysc._gateway.close()

    def result_checker(self, actual_keys, expected_keys):
        self.assertDictEqual(actual_keys, expected_keys)

    def test_encoded_keyss(self):
        # Encode each collected key and compare against a hand-built
        # ProtoSpatialKey message.
        actual_encoded = [spatial_key_encoder(x) for x in self.rdd.collect()]
        proto_spatial_key = keyMessages_pb2.ProtoSpatialKey()

        proto_spatial_key.col = 7
        proto_spatial_key.row = 3

        expected_encoded = proto_spatial_key.SerializeToString()

        self.assertEqual(actual_encoded[0], expected_encoded)

    def test_decoded_extents(self):
        self.assertDictEqual(self.collected, self.expected_keys)
def SpatialRangeQuery(self, spatialRDD: SpatialRDD, rangeQueryWindow: Envelope, considerBoundaryIntersection: bool, usingIndex: bool):
    """
    Run a spatial range query against a SpatialRDD using an envelope window.

    :param spatialRDD: SpatialRDD to query
    :param rangeQueryWindow: Envelope describing the query window
    :param considerBoundaryIntersection: whether boundary-touching geometries match
    :param usingIndex: whether to use the RDD's spatial index
    :return: pyspark RDD of the matching geometries
    """
    jvm = spatialRDD._jvm
    window_envelope = rangeQueryWindow.create_jvm_instance(jvm)
    result_srdd = jvm.RangeQuery.SpatialRangeQuery(
        spatialRDD._srdd,
        window_envelope,
        considerBoundaryIntersection,
        usingIndex
    )
    serialized = jvm.GeoSerializerData.serializeToPython(result_srdd)
    return RDD(serialized, spatialRDD._sc, GeoSparkPickler())
def createRDD(sc, logServiceProject, logStoreName, accessKeyId, accessKeySecret,
              loghubEndpoint, startTime, endTime=None):
    """
    Create an RDD backed by Aliyun LogService (loghub) data.

    :param sc: SparkContext object.
    :param logServiceProject: The name of `LogService` project.
    :param logStoreName: The name of logStore.
    :param accessKeyId: Aliyun Access Key ID.
    :param accessKeySecret: Aliyun Access Key Secret.
    :param loghubEndpoint: The endpoint of loghub.
    :param startTime: Set user defined startTime (Unix Timestamp).
    :param endTime: Set user defined endTime (Unix Timestamp); when None
                    the range is open on the right.
    :return: A RDD object.
    """
    try:
        helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
            .loadClass("org.apache.spark.streaming.aliyun.logservice.LoghubUtilsHelper")
        helper = helperClass.newInstance()
        # Build the argument list once instead of duplicating the call.
        # Compare against None explicitly so an endTime of 0 (the epoch)
        # is still honored (the original `if endTime:` dropped it).
        args = [sc._jsc, logServiceProject, logStoreName, accessKeyId,
                accessKeySecret, loghubEndpoint, startTime]
        if endTime is not None:
            args.append(endTime)
        jrdd = helper.createRDD(*args)
    except Py4JJavaError as e:
        # A missing helper class means the loghub jar is not on the
        # classpath; print guidance before re-raising.
        if 'ClassNotFoundException' in str(e.java_exception):
            LoghubUtils._printErrorMsg()
        raise e
    return RDD(jrdd, sc, UTF8Deserializer())
def DistanceJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> RDD:
    """
    Run a flat distance join between two SpatialRDDs.

    :param spatialRDD: SpatialRDD
    :param queryRDD: SpatialRDD
    :param useIndex: bool
    :param considerBoundaryIntersection: bool
    :return: RDD of joined GeoData pairs

    >> spatial_rdd =
    >> query_rdd =
    >> spatial_join_result = JoinQuery.DistanceJoinQueryFlat(spatial_rdd, query_rdd, True, True)

    >> spatial_join_result.collect()
    [GeoData(), GeoData()]
    """
    jvm = spatialRDD._jvm
    joined_srdd = jvm.JoinQuery.DistanceJoinQueryFlat(
        spatialRDD._srdd,
        queryRDD._srdd,
        useIndex,
        considerBoundaryIntersection
    )
    serialized = jvm.GeoSerializerData.serializeToPython(joined_srdd)
    return RDD(serialized, spatialRDD._sc, GeoSparkPickler())
def DistanceJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> RDD:
    """
    Run a flat distance join between two SpatialRDDs.

    :param spatialRDD: SpatialRDD
    :param queryRDD: SpatialRDD
    :param useIndex: bool
    :param considerBoundaryIntersection: bool
    :return: RDD of joined GeoData pairs

    >> spatial_rdd =
    >> query_rdd =
    >> spatial_join_result = JoinQuery.DistanceJoinQueryFlat(spatial_rdd, query_rdd, True, True)

    >> spatial_join_result.collect()
    [GeoData(), GeoData()]
    """
    jvm = spatialRDD._jvm
    joined_srdd = jvm.JoinQuery.DistanceJoinQueryFlat(
        spatialRDD._srdd,
        queryRDD._srdd,
        useIndex,
        considerBoundaryIntersection
    )
    python_rdd = JvmSedonaPythonConverter(jvm).translate_spatial_pair_rdd_to_python(joined_srdd)
    return RDD(python_rdd, spatialRDD._sc, SedonaPickler())
def _java2py(sc, r, encoding="bytes"):
    """Convert a py4j object returned from the JVM into a Python value.

    JavaRDDs become pyspark RDDs, Datasets become DataFrames, and known
    picklable classes are round-tripped through MLSerDe; other values
    are returned unchanged.
    """
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # Normalize any concrete *RDD into a JavaRDD first.
        if clsName.endswith("RDD") and clsName != "JavaRDD":
            r = r.toJavaRDD()
            clsName = "JavaRDD"

        if clsName == "JavaRDD":
            return RDD(sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r), sc)

        if clsName == "Dataset":
            return DataFrame(r, SQLContext.getOrCreate(sc))

        if clsName in _picklable_classes:
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
    elif isinstance(r, (JavaArray, JavaList)):
        try:
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
        except Py4JJavaError:
            pass  # not picklable

    # Anything the JVM pickled comes back as bytes; unpickle it here.
    if isinstance(r, (bytearray, bytes)):
        r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r
class TupleSchemaTest(BaseTestClass):
    """Round-trip test for the (ProjectedExtent, MultibandTile) ProtoBuf tuple schema."""

    # Expected decoded extent record.
    extent = {
        'epsg': 2004,
        'extent': {
            'xmax': 1.0,
            'xmin': 0.0,
            'ymax': 1.0,
            'ymin': 0.0
        },
        'proj4': None
    }

    # Expected tile: three identical 2x2 byte bands.
    arr = np.int8([0, 0, 1, 1]).reshape(2, 2)
    bands = [arr, arr, arr]
    multiband_tile = np.array(bands)
    multiband_dict = Tile(multiband_tile, 'BYTE', -128)

    sc = BaseTestClass.pysc._jsc.sc()
    ew = BaseTestClass.pysc._jvm.geopyspark.geotrellis.tests.schemas.TupleWrapper

    java_rdd = ew.testOut(sc)

    decoder = create_partial_tuple_decoder(key_type="ProjectedExtent")
    encoder = create_partial_tuple_encoder(key_type="ProjectedExtent")

    ser = ProtoBufSerializer(decoder, encoder)
    rdd = RDD(java_rdd, BaseTestClass.pysc, AutoBatchedSerializer(ser))
    collected = rdd.collect()

    @pytest.mark.skipif(
        'TRAVIS' in os.environ,
        reason="Encoding using methods in Main causes issues on Travis")
    def test_encoded_tuples(self):
        proto_tuple = tupleMessages_pb2.ProtoTuple()

        # NOTE(review): this rewrites the 'extent' entry on the *class*
        # attribute in place (dict -> Extent), so later tests observe the
        # mutated value — the assertions depend on test ordering.
        self.extent['extent'] = Extent(**self.extent['extent'])
        proto_extent = to_pb_projected_extent(ProjectedExtent(**self.extent))
        proto_multiband = to_pb_multibandtile(self.multiband_dict)

        proto_tuple.projectedExtent.CopyFrom(proto_extent)
        proto_tuple.tiles.CopyFrom(proto_multiband)

        bs = proto_tuple.SerializeToString()
        # Every collected tuple should serialize to the same bytes.
        expected_encoded = [self.ser.dumps(x) for x in self.collected]

        for expected in expected_encoded:
            self.assertEqual(bs, expected)

    def test_decoded_tuples(self):
        expected_tuples = [(self.extent, self.multiband_dict),
                           (self.extent, self.multiband_dict),
                           (self.extent, self.multiband_dict)]

        for actual, expected in zip(self.collected, expected_tuples):
            (actual_extent, actual_tile) = actual
            (expected_extent, expected_tile) = expected

            self.assertTrue((actual_tile.cells == expected_tile.cells).all())
            self.assertDictEqual(actual_extent._asdict(), expected_extent)
def getRawSpatialRDD(self):
    """
    Return the raw spatial RDD as a Python ``pyspark.RDD``.

    :return: RDD of serialized geometries
    """
    jrdd = self._jvm.GeoSerializerData.serializeToPython(self._srdd.getRawSpatialRDD())
    return RDD(jrdd, self._sc, GeoSparkPickler())
def slice(self, begin, end):
    """
    Return all the RDDs between 'begin' to 'end' (both included)

    `begin`, `end` could be datetime.datetime() or unix_timestamp
    """
    java_rdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
    deserializer = self._jrdd_deserializer
    # Wrap each Java RDD in a Python RDD with this stream's deserializer.
    result = []
    for jrdd in java_rdds:
        result.append(RDD(jrdd, self._sc, deserializer))
    return result
class TemporalProjectedExtentSchemaTest(BaseTestClass):
    """Round-trip test for the TemporalProjectedExtent ProtoBuf schema."""

    extents = [
        Extent(0.0, 0.0, 1.0, 1.0),
        Extent(1.0, 2.0, 3.0, 4.0),
        Extent(5.0, 6.0, 7.0, 8.0),
    ]

    # All three expected records share the same instant.
    time = datetime.datetime.strptime("2016-08-24T09:00:00Z", '%Y-%m-%dT%H:%M:%SZ')

    expected_tpextents = [
        TemporalProjectedExtent(epsg=2004, extent=extents[0], instant=time)._asdict(),
        TemporalProjectedExtent(epsg=2004, extent=extents[1], instant=time)._asdict(),
        TemporalProjectedExtent(epsg=2004, extent=extents[2], instant=time)._asdict()
    ]

    sc = BaseTestClass.pysc._jsc.sc()
    ew = BaseTestClass.pysc._jvm.geopyspark.geotrellis.tests.schemas.TemporalProjectedExtentWrapper

    java_rdd = ew.testOut(sc)
    ser = ProtoBufSerializer(temporal_projected_extent_decoder,
                             temporal_projected_extent_encoder)

    rdd = RDD(java_rdd, BaseTestClass.pysc, AutoBatchedSerializer(ser))
    collected = [tpex._asdict() for tpex in rdd.collect()]

    @pytest.fixture(scope='class', autouse=True)
    def tearDown(self):
        # Close the py4j gateway once the whole class is done.
        yield
        BaseTestClass.pysc._gateway.close()

    def result_checker(self, actual_tpe, expected_tpe):
        for actual, expected in zip(actual_tpe, expected_tpe):
            self.assertDictEqual(actual, expected)

    def test_encoded_tpextents(self):
        actual_encoded = [
            temporal_projected_extent_encoder(x) for x in self.rdd.collect()
        ]

        # NOTE(review): this rewrites the 'extent' entries on the class
        # attribute in place (dict -> Extent), so test ordering matters.
        for x in range(0, len(self.expected_tpextents)):
            self.expected_tpextents[x]['extent'] = Extent(
                **self.expected_tpextents[x]['extent'])

        expected_encoded = [
            to_pb_temporal_projected_extent(TemporalProjectedExtent(**ex)).SerializeToString() \
            for ex in self.expected_tpextents
        ]

        for actual, expected in zip(actual_encoded, expected_encoded):
            self.assertEqual(actual, expected)

    def test_decoded_tpextents(self):
        self.result_checker(self.collected, self.expected_tpextents)
def indexedRawRDD(self):
    """Python view over the JVM-side ``indexedRawRDD`` field.

    The wrapping ``pyspark.RDD`` is created once and cached on the
    instance; later calls just refresh its underlying Java RDD.
    """
    jrdd = get_field(self._srdd, "indexedRawRDD")
    if not hasattr(self, "_indexed_raw_rdd"):
        # NOTE(review): these patch the RDD *class* (not the instance),
        # delegating the two operations straight to the Java RDD; this
        # affects every RDD in the process, not just this wrapper.
        RDD.saveAsObjectFile = lambda x, path: x._jrdd.saveAsObjectFile(path)
        RDD.count = lambda x: x._jrdd.count()
        setattr(self, "_indexed_raw_rdd", RDD(jrdd, self._sc))
    else:
        # Refresh the cached wrapper with the latest Java RDD.
        self._indexed_raw_rdd._jrdd = jrdd
    return getattr(self, "_indexed_raw_rdd")
def to_instants(self):
    """
    Returns an RDD of instants, each a horizontal slice of this TimeSeriesRDD at a time.

    This essentially transposes the TimeSeriesRDD, producing an RDD of tuples of datetime and
    a numpy array containing all the observations that occurred at that time.
    """
    # Map each JVM-side instant through the byte converter, then wrap the
    # result with the matching Python-side deserializer.
    instant_to_bytes = self.ctx._jvm.com.cloudera.sparkts.InstantToBytes()
    jrdd = self._jtsrdd.toInstants(-1).map(instant_to_bytes)
    return RDD(jrdd, self.ctx, _InstantDeserializer())
class MultibandSchemaTest(BaseTestClass):
    """Round-trip test for the ArrayMultibandTile ProtoBuf schema."""

    arr = np.int8([0, 0, 1, 1]).reshape(2, 2)
    no_data = -128
    arr_dict = Tile(arr, 'BYTE', no_data)
    band_dicts = [arr_dict, arr_dict, arr_dict]

    # Expected decoded multiband tile: three identical 2x2 byte bands.
    bands = [arr, arr, arr]
    multiband_tile = np.array(bands)
    multiband_dict = Tile(multiband_tile, 'BYTE', no_data)

    sc = BaseTestClass.pysc._jsc.sc()
    mw = BaseTestClass.pysc._jvm.geopyspark.geotrellis.tests.schemas.ArrayMultibandTileWrapper

    java_rdd = mw.testOut(sc)
    ser = ProtoBufSerializer(multibandtile_decoder, multibandtile_encoder)
    rdd = RDD(java_rdd, BaseTestClass.pysc, AutoBatchedSerializer(ser))
    collected = rdd.collect()

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # Close the py4j gateway after each test.
        yield
        BaseTestClass.pysc._gateway.close()

    def test_encoded_multibands(self):
        # Encode each collected tile and compare against a hand-built
        # ProtoMultibandTile of three identical tiles.
        actual_encoded = [multibandtile_encoder(x) for x in self.collected]

        proto_tile = tileMessages_pb2.ProtoTile()
        cell_type = tileMessages_pb2.ProtoCellType()

        cell_type.nd = self.no_data
        cell_type.hasNoData = True
        # NOTE(review): presumably 1 is the byte cell-type enum value in
        # tileMessages.proto — confirm against the schema.
        cell_type.dataType = 1

        proto_tile.cols = 2
        proto_tile.rows = 2
        proto_tile.sint32Cells.extend(self.arr.flatten().tolist())
        proto_tile.cellType.CopyFrom(cell_type)

        proto_multiband = tileMessages_pb2.ProtoMultibandTile()
        proto_multiband.tiles.extend([proto_tile, proto_tile, proto_tile])
        bs = proto_multiband.SerializeToString()

        expected_encoded = [bs, bs, bs]

        for actual, expected in zip(actual_encoded, expected_encoded):
            self.assertEqual(actual, expected)

    def test_decoded_multibands(self):
        expected_multibands = [
            self.multiband_dict,
            self.multiband_dict,
            self.multiband_dict
        ]

        for actual, expected in zip(self.collected, expected_multibands):
            self.assertTrue((actual.cells == expected.cells).all())
class FeatureCellValueSchemaTest(BaseTestClass):
    """Round-trip test for the Feature/CellValue ProtoBuf schema."""

    sc = BaseTestClass.pysc._jsc.sc()
    fw = BaseTestClass.pysc._jvm.geopyspark.geotrellis.tests.schemas.FeatureCellValueWrapper

    java_rdd = fw.testOut(sc)
    ser = ProtoBufSerializer(feature_cellvalue_decoder, feature_cellvalue_encoder)
    rdd = RDD(java_rdd, BaseTestClass.pysc, AutoBatchedSerializer(ser))

    # Expected features: a point, a line, and a multi-line, each paired
    # with a CellValue.
    point = Point(0, 2)
    line_1 = LineString(
        [point, Point(1, 3), Point(2, 4), Point(3, 5), Point(4, 6)])
    line_2 = LineString(
        [Point(5, 7), Point(6, 8), Point(7, 9), Point(8, 10), Point(9, 11)])
    multi_line = MultiLineString([line_1, line_2])

    features = [
        Feature(point, CellValue(2, 1)),
        Feature(line_1, CellValue(1, 0)),
        Feature(multi_line, CellValue(1, 0))
    ]

    # collect() already returns a list; the original copied it
    # element-by-element with a redundant comprehension.
    collected = rdd.collect()

    @pytest.fixture(autouse=True)
    def tearDown(self):
        # Close the py4j gateway after each test.
        yield
        BaseTestClass.pysc._gateway.close()

    def test_decoder(self):
        # Every expected geometry and property set must appear in the
        # decoded output (order is not guaranteed).
        geoms = [g.geometry for g in self.collected]
        ms = [m.properties for m in self.collected]

        for x in self.features:
            self.assertTrue(x.geometry in geoms)
            self.assertTrue(x.properties in ms)

    def test_encoder(self):
        expected_encoded = [
            to_pb_feature_cellvalue(f).SerializeToString() for f in self.features
        ]
        actual_encoded = [feature_cellvalue_encoder(f) for f in self.collected]

        for x in expected_encoded:
            self.assertTrue(x in actual_encoded)
def spatialPartitionedRDD(self):
    """
    Python view over the JVM-side ``spatialPartitionedRDD`` field.

    The wrapping ``pyspark.RDD`` is cached on the instance; later calls
    just refresh its underlying serialized Java RDD.

    :return: pyspark RDD of spatially partitioned geometries
    """
    jrdd = GeoSparkPythonConverter(self._jvm).translate_spatial_rdd_to_python(
        get_field(self._srdd, "spatialPartitionedRDD"))
    if hasattr(self, "_spatial_partitioned_rdd"):
        # Refresh the cached wrapper with the freshly serialized RDD.
        self._spatial_partitioned_rdd._jrdd = jrdd
    else:
        self._spatial_partitioned_rdd = RDD(jrdd, self._sc, GeoSparkPickler())
    return self._spatial_partitioned_rdd
def _java2py(sc, r):
    """Convert a py4j object returned from the JVM into a Python value.

    RDDs/JavaRDDs are wrapped as pyspark RDDs; known picklable JVM
    objects are serialized on the JVM side and unpickled here. Anything
    else is returned unchanged.
    """
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        if clsName in ("RDD", "JavaRDD"):
            if clsName == "RDD":
                r = r.toJavaRDD()
            jrdd = sc._jvm.SerDe.javaToPython(r)
            return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer()))
        elif clsName in _picklable_classes:
            r = sc._jvm.SerDe.dumps(r)
    if isinstance(r, bytearray):
        # BUG FIX: str(bytearray) yields "bytearray(b'...')" on Python 3,
        # corrupting the pickle payload; convert with bytes() as the
        # other _java2py helpers in this file do.
        r = PickleSerializer().loads(bytes(r))
    return r
def spatialPartitionedRDD(self):
    """
    Python view over the JVM-side ``spatialPartitionedRDD`` field.

    The wrapping ``pyspark.RDD`` is cached on the instance; later calls
    just refresh its underlying serialized Java RDD.

    :return: pyspark RDD of spatially partitioned geometries
    """
    jrdd = self._jvm.GeoSerializerData.serializeToPython(
        get_field(self._srdd, "spatialPartitionedRDD"))
    if hasattr(self, "_spatial_partitioned_rdd"):
        # Refresh the cached wrapper with the freshly serialized RDD.
        self._spatial_partitioned_rdd._jrdd = jrdd
    else:
        self._spatial_partitioned_rdd = RDD(jrdd, self._sc, GeoSparkPickler())
    return self._spatial_partitioned_rdd
def getRawSpatialRDD(self):
    """
    Python view over the raw spatial RDD, cached on the instance.

    Later calls refresh the cached wrapper's underlying Java RDD.

    :return: pyspark RDD of serialized geometries
    """
    jrdd = GeoSparkPythonConverter(self._jvm).translate_spatial_rdd_to_python(
        self._srdd.getRawSpatialRDD())
    if hasattr(self, "_raw_spatial_rdd"):
        self._raw_spatial_rdd._jrdd = jrdd
    else:
        # Delegate saveAsObjectFile straight to the Java RDD. Note this
        # patches the RDD *class*, matching the original behavior.
        RDD.saveAsObjectFile = lambda x, path: x._jrdd.saveAsObjectFile(path)
        self._raw_spatial_rdd = RDD(jrdd, self._sc, GeoSparkPickler())
    return self._raw_spatial_rdd