def test_als_ratings_serialize(self):
    """Round-trip a ``Rating`` through the JVM SerDe and compare its fields."""
    pickler = CPickleSerializer()
    original = Rating(7, 1123, 3.14)
    serde = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe

    # Python -> pickle bytes -> Java object -> pickle bytes -> Python
    java_rating = serde.loads(bytearray(pickler.dumps(original)))
    restored = pickler.loads(bytes(serde.dumps(java_rating)))

    self.assertEqual(original.user, restored.user)
    self.assertEqual(original.product, restored.product)
    # The rating is only compared to two decimal places.
    self.assertAlmostEqual(original.rating, restored.rating, 2)
def test_als_ratings_id_long_error(self):
    """Unpickling a ``Rating`` whose user id overflows an int must fail in the JVM."""
    pickler = CPickleSerializer()
    # rating user id exceeds max int value, should fail when pickled
    oversized = Rating(1205640308657491975, 50233468418, 1.0)

    # Resolve the callable and build the payload outside the guarded region,
    # so only the JVM call itself is expected to raise.
    jvm_loads = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads
    payload = bytearray(pickler.dumps(oversized))
    with self.assertRaises(Py4JJavaError):
        jvm_loads(payload)
def test_hash_serializer(self):
    """Check that every serializer flavour can be hashed without raising."""
    # Factories are invoked one at a time so that each serializer is built
    # immediately before it is hashed.
    factories = (
        NoOpSerializer,
        UTF8Deserializer,
        CPickleSerializer,
        MarshalSerializer,
        AutoSerializer,
        lambda: BatchedSerializer(CPickleSerializer()),
        lambda: AutoBatchedSerializer(MarshalSerializer()),
        lambda: PairDeserializer(NoOpSerializer(), UTF8Deserializer()),
        lambda: CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()),
        lambda: CompressedSerializer(CPickleSerializer()),
        lambda: FlattenedValuesSerializer(CPickleSerializer()),
    )
    for make_serializer in factories:
        hash(make_serializer())
def test_take_on_jrdd_with_large_rows_should_not_cause_deadlock(self):
    # Regression test for SPARK-38677.
    #
    # A wide DataFrame is mapped through a Python function and only the first
    # result row is taken. The resulting large rows used to deadlock three
    # parties:
    #
    # 1. The Scala task executor thread: the task finished early, so it no
    #    longer reads the Python process's output and instead waits for the
    #    Scala WriterThread to exit.
    #
    # 2. The Scala WriterThread: blocked sending a large row to the Python
    #    process, waiting for that process to read it.
    #
    # 3. The Python process: blocked sending a large output to the Scala task
    #    executor thread, waiting for that thread to read it.
    #
    # The test passes (instead of hanging) only if the Scala MonitorThread
    # detects the deadlock and kills the Python worker.
    import numpy as np
    import pandas as pd

    num_rows = 100000
    num_columns = 134
    zeros = np.zeros((num_rows, num_columns))
    column_names = map(str, range(num_columns))

    pandas_df = pd.DataFrame(zeros, columns=column_names)
    spark_df = SparkSession(self.sc).createDataFrame(pandas_df)

    first_pickled = spark_df.rdd.map(list)._jrdd.first()
    actual = CPickleSerializer().loads(first_pickled)
    expected = [list(zeros[0])]
    self.assertEqual(expected, actual)
def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:
    """Convert a Java object (or a pickle dump received from the JVM) to Python.

    Java RDDs become :class:`RDD`, a ``Dataset`` becomes a :class:`DataFrame`,
    picklable Java objects are dumped by ``MLSerDe`` and unpickled here, and
    anything else is returned unchanged.
    """
    if isinstance(r, JavaObject):
        java_class = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if java_class.endswith("RDD") and java_class != "JavaRDD":
            r = r.toJavaRDD()
            java_class = "JavaRDD"

        assert sc._jvm is not None

        def ml_serde() -> Any:
            # Looked up lazily so the JVM view is only touched on the branches
            # that actually need the SerDe.
            return sc._jvm.org.apache.spark.ml.python.MLSerDe  # type: ignore[union-attr]

        if java_class == "JavaRDD":
            return RDD(ml_serde().javaToPython(r), sc)

        if java_class == "Dataset":
            return DataFrame(r, SparkSession(sc)._wrapped)

        if java_class in _picklable_classes:
            r = ml_serde().dumps(r)
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = ml_serde().dumps(r)
            except Py4JJavaError:
                pass  # not picklable

    if not isinstance(r, (bytearray, bytes)):
        return r
    return CPickleSerializer().loads(bytes(r), encoding=encoding)
def _to_java_object_rdd(rdd: RDD) -> JavaObject:
    """Return a JavaRDD of Object by unpickling.

    Each Python object is converted into a Java object through Pickle,
    whether or not the RDD is serialized in batches.
    """
    pickled = rdd._reserialize(  # type: ignore[attr-defined]
        AutoBatchedSerializer(CPickleSerializer())
    )
    serde = pickled.ctx._jvm.org.apache.spark.mllib.api.python.SerDe  # type: ignore[attr-defined]
    return serde.pythonToJava(pickled._jrdd, True)
def _test_serialize(self, v):
    """Assert that ``v`` survives pickling locally, via the JVM, and in a batch of 100."""
    pickler = CPickleSerializer()

    def via_jvm(obj):
        # Python -> pickle -> Java object -> pickle -> Python
        java_obj = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(
            bytearray(pickler.dumps(obj))
        )
        return pickler.loads(
            bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(java_obj))
        )

    self.assertEqual(v, pickler.loads(pickler.dumps(v)))
    self.assertEqual(v, via_jvm(v))

    repeated = [v] * 100
    self.assertEqual(repeated, via_jvm(repeated))
def _to_java_object_rdd(rdd: RDD) -> JavaObject:
    """Return a JavaRDD of Object by unpickling.

    Each Python object is converted into a Java object through Pickle,
    whether or not the RDD is serialized in batches.
    """
    pickled = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer()))
    jvm = pickled.ctx._jvm
    assert jvm is not None
    return jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(pickled._jrdd, True)
def _open_file(self):
    """Open the backing file for this object and set up its serializer.

    A directory is picked from ``_get_local_dirs("objects")`` using
    ``id(self)``, created if needed, and a file named after ``id(self)`` is
    opened in it in binary read/write mode with a 64 KiB buffer.
    ``self._file`` holds the open handle and ``self._ser`` a batched,
    compressed pickle serializer. The path is unlinked straight after opening.
    """
    dirs = _get_local_dirs("objects")
    d = dirs[id(self) % len(dirs)]
    # ``exist_ok=True`` removes the check-then-create race: a separate
    # ``os.path.exists`` test can pass just before another creator makes the
    # same directory, after which a plain ``makedirs`` raises FileExistsError.
    os.makedirs(d, exist_ok=True)
    p = os.path.join(d, str(id(self)))
    self._file = open(p, "w+b", 65536)
    self._ser = BatchedSerializer(CompressedSerializer(CPickleSerializer()), 1024)
    # Remove the name immediately; the data stays reachable through the open
    # handle (POSIX semantics assumed -- NOTE(review): confirm for other OSes).
    os.unlink(p)
def test_compressed_serializer(self):
    """Dump two batches into one in-memory stream and read both back."""
    from io import BytesIO

    ser = CompressedSerializer(CPickleSerializer())
    buf = BytesIO()

    ser.dump_stream(["abc", "123", range(5)], buf)
    buf.seek(0)
    self.assertEqual(["abc", "123", range(5)], list(ser.load_stream(buf)))

    # Reading left the cursor at the end, so this batch is appended.
    ser.dump_stream(range(1000), buf)
    buf.seek(0)
    everything = ["abc", "123", range(5)] + list(range(1000))
    self.assertEqual(everything, list(ser.load_stream(buf)))
    buf.close()
def test_zip_with_different_serializers(self):
    """``zip`` must work when the two RDDs use different serializers."""
    expected_pairs = [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]

    left = self.sc.parallelize(range(5))
    right = self.sc.parallelize(range(100, 105))
    self.assertEqual(left.zip(right).collect(), expected_pairs)

    left = left._reserialize(BatchedSerializer(CPickleSerializer(), 2))
    right = right._reserialize(MarshalSerializer())
    self.assertEqual(left.zip(right).collect(), expected_pairs)

    # regression test for SPARK-4841
    hello_path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
    lines = self.sc.textFile(hello_path)
    line_count = lines.count()
    self.assertEqual(line_count, lines.zip(lines).count())

    as_str = lines.map(str)
    self.assertEqual(line_count, lines.zip(as_str).count())
    # regression test for bug in _reserializer()
    self.assertEqual(line_count, lines.zip(as_str).count())
def _py2java(sc: SparkContext, obj: Any) -> JavaObject:
    """Convert Python object into Java.

    RDDs, DataFrames and SparkContexts are mapped to their Java counterparts,
    lists are converted element by element, Java objects and plain scalars are
    passed through unchanged, and anything else is pickled and loaded on the
    JVM side through ``SerDe``.
    """
    if isinstance(obj, RDD):
        return _to_java_object_rdd(obj)
    if isinstance(obj, DataFrame):
        return obj._jdf
    if isinstance(obj, SparkContext):
        return obj._jsc  # type: ignore[attr-defined]
    if isinstance(obj, list):
        return [_py2java(sc, item) for item in obj]
    if isinstance(obj, (JavaObject, int, float, bool, bytes, str)):
        # Already usable on the Java side as-is.
        return obj

    pickled = bytearray(CPickleSerializer().dumps(obj))
    return sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(pickled)  # type: ignore[attr-defined]
def test_serialize(self):
    """SciPy sparse matrices convert to the same ``SparseVector`` and pickle cleanly."""
    from scipy.sparse import lil_matrix

    ser = CPickleSerializer()
    lil = lil_matrix((4, 1))
    lil[1, 0] = 1
    lil[3, 0] = 2
    sv = SparseVector(4, {1: 1, 3: 2})

    def in_format(method_name):
        # ``None`` stands for the LIL matrix itself; otherwise call the named
        # conversion method right when the matrix is needed.
        return lil if method_name is None else getattr(lil, method_name)()

    for fmt in (None, "tocsc", "tocoo", "tocsr", "todok"):
        self.assertEqual(sv, _convert_to_vector(in_format(fmt)))

    def serialize(d):
        return ser.loads(ser.dumps(_convert_to_vector(d)))

    for fmt in (None, "tocsc", "tocsr", "todok"):
        self.assertEqual(sv, serialize(in_format(fmt)))
def read_udfs(pickleSer, infile, eval_type):
    """Read UDF definitions from ``infile`` and build the function that applies them.

    Parameters
    ----------
    pickleSer
        Serializer forwarded unchanged to ``read_single_udf``.
    infile
        Input stream. For the pandas/arrow eval types listed below the runner
        configuration is read from it first; the UDF count and the UDFs
        themselves follow.
    eval_type
        One of the ``PythonEvalType`` constants. It selects both the
        serializer and the shape of the returned function.

    Returns
    -------
    tuple
        ``(func, None, ser, ser)``. ``func(_, iterator)`` applies the UDF(s)
        to the items of ``iterator``; the second element is ``None`` because
        profiling is not supported for UDFs; ``ser`` is returned twice
        (presumably as deserializer and serializer -- confirm against the
        caller).
    """
    runner_conf = {}

    if eval_type in (
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
        PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
    ):

        # Load conf used for pandas_udf evaluation
        # (a count followed by that many UTF-8 key/value pairs)
        num_conf = read_int(infile)
        for i in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        # Config values arrive as strings; only a case-insensitive "true" enables the flag.
        safecheck = (
            runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
            == "true"
        )
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = (
            runner_conf.get(
                "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
            ).lower()
            == "true"
        )

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
        elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
            ser = ArrowStreamUDFSerializer()
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (
                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
                or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
                or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
            )
            ser = ArrowStreamPandasUDFSerializer(
                timezone, safecheck, assign_cols_by_name, df_for_struct
            )
    else:
        # Every other eval type uses pickled batches of 100 items.
        ser = BatchedSerializer(CPickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
    is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF

    if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter:
        # Iterator-style UDFs: exactly one UDF, which consumes an iterator of
        # batches and produces an iterator of (batch, type) results.
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_pandas_iter:
            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
        if is_map_arrow_iter:
            assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)

        def func(_, iterator):
            # Running total of input rows, updated by map_batch as the UDF
            # pulls batches from the wrapped iterator.
            num_input_rows = 0

            def map_batch(batch):
                nonlocal num_input_rows

                # Select the UDF's argument columns from the batch.
                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows += len(udf_args[0])
                # A single argument is passed bare, several as a tuple.
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This assert is for Scalar Iterator UDF to fail fast.
                # The length of the entire input can only be explicitly known
                # by consuming the input iterator in user side. Therefore,
                # it's very unlikely the output length is higher than
                # input length.
                assert (
                    is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows
                ), "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                # The UDF must have drained its input: one more next() has to
                # hit StopIteration, otherwise input was left unread.
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError("pandas iterator UDF should exhaust the input " "iterator.")

                if num_output_rows != num_input_rows:
                    raise RuntimeError(
                        "The length of output in Scalar iterator pandas UDF should be "
                        "the same with the input's; however, the length of output was %d and the "
                        "length of input was %d." % (num_output_rows, num_input_rows)
                    )

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        Parameters
        ----------
        grouped_arg_offsets:  list
            List containing the key and value indexes of columns of the
            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
            number of DataFrames.  Each group has the following format:
                group[0]: number of entries that follow in this group
                group[1]: number of key indexes
                group[2 .. group[1] + 1]: key indexes
                group[group[1] + 2 .. group[0]]: value indexes

        Returns
        -------
        list
            One ``[key_indexes, value_indexes]`` pair per group.
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            # ``offsets`` is the group without its leading length entry, so
            # offsets[0] is the key count and the keys start at offsets[1].
            offsets = grouped_arg_offsets[idx : idx + offsets_len]
            split_index = offsets[0] + 1
            offset_keys = offsets[1:split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed

    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # Create function like this:
        #   mapper a: f([a[0]], [a[0], a[1]])
        def mapper(a):
            keys = [a[o] for o in parsed_offsets[0][0]]
            vals = [a[o] for o in parsed_offsets[0][1]]
            return f(keys, vals)

    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)

        parsed_offsets = extract_key_value_indexes(arg_offsets)

        def mapper(a):
            # ``a`` holds the two cogrouped inputs; parsed_offsets has one
            # [keys, values] entry for each of them.
            df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
            df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
            df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
            df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
            return f(df1_keys, df1_vals, df2_keys, df2_vals)

    else:
        udfs = []
        for i in range(num_udfs):
            udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))

        def mapper(a):
            # Evaluate every UDF on its own argument columns of ``a``.
            result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
            # In the special case of a single UDF this will return a single result rather
            # than a tuple of results; this is the format that the JVM side expects.
            if len(result) == 1:
                return result[0]
            else:
                return result

    def func(_, it):
        return map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
SpecialLengths, UTF8Deserializer, CPickleSerializer, BatchedSerializer, )
from pyspark.sql.pandas.serializers import (
    ArrowStreamPandasUDFSerializer,
    CogroupUDFSerializer,
    ArrowStreamUDFSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StructType
from pyspark.util import fail_on_stopiteration, try_simplify_traceback
from pyspark import shuffle

# Module-level serializer instances shared by the functions in this file.
pickleSer = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()


def report_times(outfile, boot, init, finish):
    """Write the timing marker followed by three timestamps to ``outfile``.

    ``boot``, ``init`` and ``finish`` are each multiplied by 1000 and
    truncated to an int before being written as longs (presumably seconds
    converted to milliseconds -- confirm against the caller).
    """
    write_int(SpecialLengths.TIMING_DATA, outfile)
    write_long(int(1000 * boot), outfile)
    write_long(int(1000 * init), outfile)
    write_long(int(1000 * finish), outfile)


def add_path(path):
    """Insert ``path`` into ``sys.path`` at index 1 unless it is already present."""
    # worker can be reused, so do not add path multiple times
    if path not in sys.path:
        # overwrite system packages
        sys.path.insert(1, path)
def _compressed_serializer(self, serializer=None):
    """Return an auto-batched, compressed pickle serializer.

    The ``serializer`` argument is accepted but not used.
    """
    # always use CPickleSerializer to simplify implementation
    compressed = CompressedSerializer(CPickleSerializer())
    return AutoBatchedSerializer(compressed)
def foreach(
        self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter":
    """
    Sets the output of the streaming query to be processed using the provided writer ``f``.
    This is often used to write the output of a streaming query to arbitrary storage systems.
    The processing logic can be specified in two ways.

    #. A **function** that takes a row as input.
        This is a simple way to express your processing logic. Note that this does
        not allow you to deduplicate generated data when failures cause reprocessing of
        some input data. That would require you to specify the processing logic in the next
        way.

    #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
        The object can have the following methods.

        * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
            (for example, open a connection, start a transaction, etc). Additionally, you can
            use the `partition_id` and `epoch_id` to deduplicate regenerated data
            (discussed later).

        * ``process(row)``: *Non-optional* method that processes each :class:`Row`.

        * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
            close connection, commit transaction, etc.) after all rows have been processed.

        The object will be used by Spark in the following way.

        * A single copy of this object is responsible for all the data generated by a
            single task in a query. In other words, one instance is responsible for
            processing one partition of the data generated in a distributed manner.

        * This object must be serializable because each task will get a fresh
            serialized-deserialized copy of the provided object. Hence, it is strongly
            recommended that any initialization for writing data (e.g. opening a
            connection or starting a transaction) is done after the `open(...)`
            method has been called, which signifies that the task is ready to generate data.

        * The lifecycle of the methods is as follows.

            For each partition with ``partition_id``:

            ... For each batch/epoch of streaming data with ``epoch_id``:

            ....... Method ``open(partitionId, epochId)`` is called.

            ....... If ``open(...)`` returns true, for each row in the partition and
                    batch/epoch, method ``process(row)`` is called.

            ....... Method ``close(errorOrNull)`` is called with error (if any) seen while
                    processing rows.

        Important points to note:

        * The `partitionId` and `epochId` can be used to deduplicate generated data when
            failures cause reprocessing of some input data. This depends on the execution
            mode of the query. If the streaming query is being executed in the micro-batch
            mode, then every partition represented by a unique tuple (partition_id, epoch_id)
            is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
            to deduplicate and/or transactionally commit data and achieve exactly-once
            guarantees. However, if the streaming query is being executed in the continuous
            mode, then this guarantee does not hold and therefore should not be used for
            deduplication.

        * The ``close()`` method (if it exists) will be called if the `open()` method exists
            and returns successfully (irrespective of the return value), except if the Python
            process crashes in the middle.

    .. versionadded:: 2.4.0

    Notes
    -----
    This API is evolving.

    Examples
    --------
    >>> # Print every row using a function
    >>> def print_row(row):
    ...     print(row)
    ...
    >>> writer = sdf.writeStream.foreach(print_row)
    >>> # Print every row using an object with process() method
    >>> class RowPrinter:
    ...     def open(self, partition_id, epoch_id):
    ...         print("Opened %d, %d" % (partition_id, epoch_id))
    ...         return True
    ...     def process(self, row):
    ...         print(row)
    ...     def close(self, error):
    ...         print("Closed with error: %s" % str(error))
    ...
    >>> writer = sdf.writeStream.foreach(RowPrinter())
    """

    from pyspark.rdd import _wrap_function
    from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer
    from pyspark.taskcontext import TaskContext

    if callable(f):
        # The provided object is a callable function that is supposed to be called on each row.
        # Construct a function that takes an iterator and calls the provided function on each
        # row.
        def func_without_process(_: Any, iterator: Iterator) -> Iterator:
            for x in iterator:
                f(x)  # type: ignore[operator]
            # Nothing is produced; an empty iterator is returned.
            return iter([])

        func = func_without_process

    else:
        # The provided object is not a callable function. Then it is expected to have a
        # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
        # 'close(error)' methods.

        if not hasattr(f, "process"):
            raise AttributeError(
                "Provided object does not have a 'process' method")

        if not callable(getattr(f, "process")):
            raise TypeError(
                "Attribute 'process' in provided object is not callable")

        def doesMethodExist(method_name: str) -> bool:
            # Returns whether ``f`` has the attribute; raises TypeError when the
            # attribute exists but is not callable.
            exists = hasattr(f, method_name)
            if exists and not callable(getattr(f, method_name)):
                raise TypeError(
                    "Attribute '%s' in provided object is not callable" % method_name)
            return exists

        # Evaluated once here; the closure below only reads the booleans.
        open_exists = doesMethodExist("open")
        close_exists = doesMethodExist("close")

        def func_with_open_process_close(partition_id: Any, iterator: Iterator) -> Iterator:
            # The epoch (batch) id is read from the task's local properties.
            epoch_id = cast(TaskContext, TaskContext.get()).getLocalProperty(
                "streaming.sql.batchId")
            if epoch_id:
                int_epoch_id = int(epoch_id)
            else:
                raise RuntimeError(
                    "Could not get batch id from TaskContext")

            # Check if the data should be processed
            should_process = True
            if open_exists:
                # open() runs outside the try block below, so an exception
                # raised here propagates without close() being called.
                should_process = f.open(
                    partition_id, int_epoch_id)  # type: ignore[union-attr]

            error = None

            try:
                if should_process:
                    for x in iterator:
                        cast("SupportsProcess", f).process(x)
            except Exception as ex:
                # Remember the failure so close() can be told about it.
                error = ex
            finally:
                # close() is invoked whenever it exists, with the captured
                # error (or None); the error is re-raised only afterwards.
                if close_exists:
                    f.close(error)  # type: ignore[union-attr]
                if error:
                    raise error

            return iter([])

        func = func_with_open_process_close  # type: ignore[assignment]

    # The same serializer instance is passed twice to _wrap_function.
    serializer = AutoBatchedSerializer(CPickleSerializer())
    wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
    assert self._spark._sc._jvm is not None
    jForeachWriter = (self._spark._sc._jvm.org.apache.spark.sql.execution.
                      python.PythonForeachWriter(wrapped_func, self._df._jdf.schema()))
    self._jwrite.foreach(jForeachWriter)
    return self