Example #1
    def test_als_ratings_serialize(self):
        ser = CPickleSerializer()
        r = Rating(7, 1123, 3.14)
        jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
        nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
        self.assertEqual(r.user, nr.user)
        self.assertEqual(r.product, nr.product)
        self.assertAlmostEqual(r.rating, nr.rating, 2)
Example #2
    def test_als_ratings_id_long_error(self):
        ser = CPickleSerializer()
        r = Rating(1205640308657491975, 50233468418, 1.0)
        # rating user id exceeds max int value, should fail when pickled
        self.assertRaises(
            Py4JJavaError,
            self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
            bytearray(ser.dumps(r)),
        )
Example #3
    def test_hash_serializer(self):
        hash(NoOpSerializer())
        hash(UTF8Deserializer())
        hash(CPickleSerializer())
        hash(MarshalSerializer())
        hash(AutoSerializer())
        hash(BatchedSerializer(CPickleSerializer()))
        hash(AutoBatchedSerializer(MarshalSerializer()))
        hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
        hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
        hash(CompressedSerializer(CPickleSerializer()))
        hash(FlattenedValuesSerializer(CPickleSerializer()))
Example #4
    def test_take_on_jrdd_with_large_rows_should_not_cause_deadlock(self):
        # Regression test for SPARK-38677.
        #
        # Create a DataFrame with many columns, call a Python function on each row, and take only
        # the first result row.
        #
        # This produces large rows that trigger a deadlock involving the following three threads:
        #
        # 1. The Scala task executor thread. During task execution, this is responsible for reading
        #    output produced by the Python process. However, in this case the task has finished
        #    early, and this thread is no longer reading output produced by the Python process.
        #    Instead, it is waiting for the Scala WriterThread to exit so that it can finish the
        #    task.
        #
        # 2. The Scala WriterThread. This is trying to send a large row to the Python process, and
        #    is waiting for the Python process to read that row.
        #
        # 3. The Python process. This is trying to send a large output to the Scala task executor
        #    thread, and is waiting for that thread to read that output.
        #
        # For this test to succeed rather than hanging, the Scala MonitorThread must detect this
        # deadlock and kill the Python worker.
        import numpy as np
        import pandas as pd

        num_rows = 100000
        num_columns = 134
        data = np.zeros((num_rows, num_columns))
        columns = map(str, range(num_columns))
        df = SparkSession(self.sc).createDataFrame(pd.DataFrame(data, columns=columns))
        actual = CPickleSerializer().loads(df.rdd.map(list)._jrdd.first())
        expected = [list(data[0])]
        self.assertEqual(expected, actual)
Example #5
def _java2py(sc: SparkContext,
             r: "JavaObjectOrPickleDump",
             encoding: str = "bytes") -> Any:
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if clsName != "JavaRDD" and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = "JavaRDD"

        assert sc._jvm is not None

        if clsName == "JavaRDD":
            jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(
                r)  # type: ignore[attr-defined]
            return RDD(jrdd, sc)

        if clsName == "Dataset":
            return DataFrame(r, SparkSession(sc)._wrapped)

        if clsName in _picklable_classes:
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(
                r)  # type: ignore[attr-defined]
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(
                    r)  # type: ignore[attr-defined]
            except Py4JJavaError:
                pass  # not picklable

    if isinstance(r, (bytearray, bytes)):
        r = CPickleSerializer().loads(bytes(r), encoding=encoding)
    return r
Example #6
def _to_java_object_rdd(rdd: RDD) -> JavaObject:
    """Return a JavaRDD of Object by unpickling

    It will convert each Python object into a Java object via Pickle, whether or
    not the RDD is serialized in batches.
    """
    rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer()))  # type: ignore[attr-defined]
    return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True)  # type: ignore[attr-defined]
Example #7
    def _test_serialize(self, v):
        ser = CPickleSerializer()
        self.assertEqual(v, ser.loads(ser.dumps(v)))
        jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
        nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
        self.assertEqual(v, nv)
        vs = [v] * 100
        jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs)))
        nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs)))
        self.assertEqual(vs, nvs)
Example #8
def _to_java_object_rdd(rdd: RDD) -> JavaObject:
    """Return a JavaRDD of Object by unpickling

    It will convert each Python object into a Java object via Pickle, whether or
    not the RDD is serialized in batches.
    """
    rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer()))
    assert rdd.ctx._jvm is not None
    return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(
        rdd._jrdd, True)
Example #9
    def _open_file(self):
        dirs = _get_local_dirs("objects")
        d = dirs[id(self) % len(dirs)]
        if not os.path.exists(d):
            os.makedirs(d)
        p = os.path.join(d, str(id(self)))
        self._file = open(p, "w+b", 65536)
        self._ser = BatchedSerializer(
            CompressedSerializer(CPickleSerializer()), 1024)
        # Unlink the path right away: the data stays accessible through the open
        # file handle, and the space is reclaimed automatically once the handle
        # is closed or the process exits.
        os.unlink(p)
Example #10
    def test_compressed_serializer(self):
        ser = CompressedSerializer(CPickleSerializer())
        from io import BytesIO as StringIO

        io = StringIO()
        ser.dump_stream(["abc", "123", range(5)], io)
        io.seek(0)
        self.assertEqual(["abc", "123", range(5)], list(ser.load_stream(io)))
        ser.dump_stream(range(1000), io)
        io.seek(0)
        self.assertEqual(["abc", "123", range(5)] + list(range(1000)),
                         list(ser.load_stream(io)))
        io.close()
Example #11
    def test_zip_with_different_serializers(self):
        a = self.sc.parallelize(range(5))
        b = self.sc.parallelize(range(100, 105))
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
        a = a._reserialize(BatchedSerializer(CPickleSerializer(), 2))
        b = b._reserialize(MarshalSerializer())
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
        # regression test for SPARK-4841
        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        t = self.sc.textFile(path)
        cnt = t.count()
        self.assertEqual(cnt, t.zip(t).count())
        rdd = t.map(str)
        self.assertEqual(cnt, t.zip(rdd).count())
        # regression test for bug in _reserializer()
        self.assertEqual(cnt, t.zip(rdd).count())
Example #12
def _py2java(sc: SparkContext, obj: Any) -> JavaObject:
    """Convert Python object into Java"""
    if isinstance(obj, RDD):
        obj = _to_java_object_rdd(obj)
    elif isinstance(obj, DataFrame):
        obj = obj._jdf
    elif isinstance(obj, SparkContext):
        obj = obj._jsc  # type: ignore[attr-defined]
    elif isinstance(obj, list):
        obj = [_py2java(sc, x) for x in obj]
    elif isinstance(obj, JavaObject):
        pass
    elif isinstance(obj, (int, float, bool, bytes, str)):
        pass
    else:
        data = bytearray(CPickleSerializer().dumps(obj))
        obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data)  # type: ignore[attr-defined]
    return obj
Example #13
    def test_serialize(self):
        from scipy.sparse import lil_matrix

        ser = CPickleSerializer()
        lil = lil_matrix((4, 1))
        lil[1, 0] = 1
        lil[3, 0] = 2
        sv = SparseVector(4, {1: 1, 3: 2})
        self.assertEqual(sv, _convert_to_vector(lil))
        self.assertEqual(sv, _convert_to_vector(lil.tocsc()))
        self.assertEqual(sv, _convert_to_vector(lil.tocoo()))
        self.assertEqual(sv, _convert_to_vector(lil.tocsr()))
        self.assertEqual(sv, _convert_to_vector(lil.todok()))

        def serialize(d):
            return ser.loads(ser.dumps(_convert_to_vector(d)))

        self.assertEqual(sv, serialize(lil))
        self.assertEqual(sv, serialize(lil.tocsc()))
        self.assertEqual(sv, serialize(lil.tocsr()))
        self.assertEqual(sv, serialize(lil.todok()))
Example #14
def read_udfs(pickleSer, infile, eval_type):
    runner_conf = {}

    if eval_type in (
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
        PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
    ):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for i in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = (
            runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
            == "true"
        )
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = (
            runner_conf.get(
                "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
            ).lower()
            == "true"
        )

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
        elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
            ser = ArrowStreamUDFSerializer()
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (
                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
                or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
                or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
            )
            ser = ArrowStreamPandasUDFSerializer(
                timezone, safecheck, assign_cols_by_name, df_for_struct
            )
    else:
        ser = BatchedSerializer(CPickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
    is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF

    if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_pandas_iter:
            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
        if is_map_arrow_iter:
            assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)

        def func(_, iterator):
            num_input_rows = 0

            def map_batch(batch):
                nonlocal num_input_rows

                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This assert is for Scalar Iterator UDF to fail fast.
                # The length of the entire input can only be explicitly known
                # by consuming the input iterator on the user side. Therefore,
                # it's very unlikely that the output length is higher than the
                # input length.
                assert (
                    is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows
                ), "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError("pandas iterator UDF should exhaust the input iterator.")

                if num_output_rows != num_input_rows:
                    raise RuntimeError(
                        "The length of output in Scalar iterator pandas UDF should be "
                        "the same as the input's; however, the length of output was %d and the "
                        "length of input was %d." % (num_output_rows, num_input_rows)
                    )

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        Parameters
        ----------
        grouped_arg_offsets:  list
            List containing the key and value indexes of columns of the
            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
            number of DataFrames.  Each group has the following format:
                group[0]: length of group
                group[1]: length of key indexes
                group[2 .. group[1] + 1]: key attributes
                group[group[1] + 2 .. group[0]]: value attributes
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            offsets = grouped_arg_offsets[idx : idx + offsets_len]
            split_index = offsets[0] + 1
            offset_keys = offsets[1:split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed
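    # For example, extract_key_value_indexes([3, 1, 0, 1]) parses a single group of
    # length 3 with one key column at offset 0 and one value column at offset 1,
    # and returns [[[0], [1]]].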

    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # Create a function like this:
        #   mapper = lambda a: f([a[0]], [a[0], a[1]])
        def mapper(a):
            keys = [a[o] for o in parsed_offsets[0][0]]
            vals = [a[o] for o in parsed_offsets[0][1]]
            return f(keys, vals)

    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)

        parsed_offsets = extract_key_value_indexes(arg_offsets)

        def mapper(a):
            df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
            df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
            df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
            df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
            return f(df1_keys, df1_vals, df2_keys, df2_vals)

    else:
        udfs = []
        for i in range(num_udfs):
            udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))

        def mapper(a):
            result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
            # In the special case of a single UDF this will return a single result rather
            # than a tuple of results; this is the format that the JVM side expects.
            if len(result) == 1:
                return result[0]
            else:
                return result

    def func(_, it):
        return map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
Example #15
from pyspark.serializers import (
    SpecialLengths,
    UTF8Deserializer,
    CPickleSerializer,
    BatchedSerializer,
)
from pyspark.sql.pandas.serializers import (
    ArrowStreamPandasUDFSerializer,
    CogroupUDFSerializer,
    ArrowStreamUDFSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StructType
from pyspark.util import fail_on_stopiteration, try_simplify_traceback
from pyspark import shuffle

pickleSer = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()


def report_times(outfile, boot, init, finish):
    write_int(SpecialLengths.TIMING_DATA, outfile)
    write_long(int(1000 * boot), outfile)
    write_long(int(1000 * init), outfile)
    write_long(int(1000 * finish), outfile)


def add_path(path):
    # worker can be reused, so do not add the path multiple times
    if path not in sys.path:
        # overwrite system packages
        sys.path.insert(1, path)
Example #16
def _compressed_serializer(self, serializer=None):
    # always use CPickleSerializer to simplify implementation
    ser = CPickleSerializer()
    return AutoBatchedSerializer(CompressedSerializer(ser))
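
For reference, the serializer built here follows the usual pyspark.serializers stream
protocol; a minimal round-trip sketch (the in-memory buffer below is illustrative and not
part of the snippet above):

from io import BytesIO
from pyspark.serializers import AutoBatchedSerializer, CompressedSerializer, CPickleSerializer

ser = AutoBatchedSerializer(CompressedSerializer(CPickleSerializer()))
buf = BytesIO()
ser.dump_stream(range(100), buf)   # pickle, compress, and write length-prefixed batches
buf.seek(0)
assert list(ser.load_stream(buf)) == list(range(100))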
Example #17
    def foreach(
        self, f: Union[Callable[[Row], None],
                       "SupportsProcess"]) -> "DataStreamWriter":
        """
        Sets the output of the streaming query to be processed using the provided writer ``f``.
        This is often used to write the output of a streaming query to arbitrary storage systems.
        The processing logic can be specified in two ways.

        #. A **function** that takes a row as input.
            This is a simple way to express your processing logic. Note that this does
            not allow you to deduplicate generated data when failures cause reprocessing of
            some input data. That would require you to specify the processing logic in the next
            way.

        #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
            The object can have the following methods.

            * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
                (for example, open a connection, start a transaction, etc). Additionally, you can
                use the `partition_id` and `epoch_id` to deduplicate regenerated data
                (discussed later).

            * ``process(row)``: *Non-optional* method that processes each :class:`Row`.

            * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
                close connection, commit transaction, etc.) after all rows have been processed.

            The object will be used by Spark in the following way.

            * A single copy of this object is responsible for all the data generated by a
                single task in a query. In other words, one instance is responsible for
                processing one partition of the data generated in a distributed manner.

            * This object must be serializable because each task will get a fresh
                serialized-deserialized copy of the provided object. Hence, it is strongly
                recommended that any initialization for writing data (e.g. opening a
                connection or starting a transaction) is done after the `open(...)`
                method has been called, which signifies that the task is ready to generate data.

            * The lifecycle of the methods is as follows.

                For each partition with ``partition_id``:

                ... For each batch/epoch of streaming data with ``epoch_id``:

                ....... Method ``open(partitionId, epochId)`` is called.

                ....... If ``open(...)`` returns true, for each row in the partition and
                        batch/epoch, method ``process(row)`` is called.

                ....... Method ``close(errorOrNull)`` is called with error (if any) seen while
                        processing rows.

            Important points to note:

            * The `partitionId` and `epochId` can be used to deduplicate generated data when
                failures cause reprocessing of some input data. This depends on the execution
                mode of the query. If the streaming query is being executed in the micro-batch
                mode, then every partition represented by a unique tuple (partition_id, epoch_id)
                is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
                to deduplicate and/or transactionally commit data and achieve exactly-once
                guarantees. However, if the streaming query is being executed in the continuous
                mode, then this guarantee does not hold and therefore should not be used for
                deduplication.

            * The ``close()`` method (if it exists) will be called if the ``open()`` method
                exists and returns successfully (irrespective of the return value), except if
                the Python process crashes in the middle.

        .. versionadded:: 2.4.0

        Notes
        -----
        This API is evolving.

        Examples
        --------
        >>> # Print every row using a function
        >>> def print_row(row):
        ...     print(row)
        ...
        >>> writer = sdf.writeStream.foreach(print_row)
        >>> # Print every row using an object with a process() method
        >>> class RowPrinter:
        ...     def open(self, partition_id, epoch_id):
        ...         print("Opened %d, %d" % (partition_id, epoch_id))
        ...         return True
        ...     def process(self, row):
        ...         print(row)
        ...     def close(self, error):
        ...         print("Closed with error: %s" % str(error))
        ...
        >>> writer = sdf.writeStream.foreach(RowPrinter())
        """

        from pyspark.rdd import _wrap_function
        from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer
        from pyspark.taskcontext import TaskContext

        if callable(f):
            # The provided object is a callable function that is supposed to be called on each row.
            # Construct a function that takes an iterator and calls the provided function on each
            # row.
            def func_without_process(_: Any, iterator: Iterator) -> Iterator:
                for x in iterator:
                    f(x)  # type: ignore[operator]
                return iter([])

            func = func_without_process

        else:
            # The provided object is not a callable function. Then it is expected to have a
            # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
            # 'close(error)' methods.

            if not hasattr(f, "process"):
                raise AttributeError(
                    "Provided object does not have a 'process' method")

            if not callable(getattr(f, "process")):
                raise TypeError(
                    "Attribute 'process' in provided object is not callable")

            def doesMethodExist(method_name: str) -> bool:
                exists = hasattr(f, method_name)
                if exists and not callable(getattr(f, method_name)):
                    raise TypeError(
                        "Attribute '%s' in provided object is not callable" %
                        method_name)
                return exists

            open_exists = doesMethodExist("open")
            close_exists = doesMethodExist("close")

            def func_with_open_process_close(partition_id: Any,
                                             iterator: Iterator) -> Iterator:
                epoch_id = cast(TaskContext,
                                TaskContext.get()).getLocalProperty(
                                    "streaming.sql.batchId")
                if epoch_id:
                    int_epoch_id = int(epoch_id)
                else:
                    raise RuntimeError(
                        "Could not get batch id from TaskContext")

                # Check if the data should be processed
                should_process = True
                if open_exists:
                    should_process = f.open(
                        partition_id, int_epoch_id)  # type: ignore[union-attr]

                error = None

                try:
                    if should_process:
                        for x in iterator:
                            cast("SupportsProcess", f).process(x)
                except Exception as ex:
                    error = ex
                finally:
                    if close_exists:
                        f.close(error)  # type: ignore[union-attr]
                    if error:
                        raise error

                return iter([])

            func = func_with_open_process_close  # type: ignore[assignment]

        serializer = AutoBatchedSerializer(CPickleSerializer())
        wrapped_func = _wrap_function(self._spark._sc, func, serializer,
                                      serializer)
        assert self._spark._sc._jvm is not None
        jForeachWriter = (self._spark._sc._jvm.org.apache.spark.sql.execution.
                          python.PythonForeachWriter(wrapped_func,
                                                     self._df._jdf.schema()))
        self._jwrite.foreach(jForeachWriter)
        return self
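
The deduplication contract described in the foreach docstring above (the open/process/close
lifecycle keyed by (partition_id, epoch_id)) can be sketched roughly as follows; here
already_committed and commit_rows are hypothetical helpers standing in for whatever external
store tracks committed batches, so treat this as an illustrative sketch rather than Spark API:

class IdempotentWriter:
    """Sketch: skip (partition_id, epoch_id) pairs that were already committed."""

    def open(self, partition_id, epoch_id):
        self.partition_id = partition_id
        self.epoch_id = epoch_id
        self.rows = []
        # already_committed() is a hypothetical lookup against external storage;
        # returning False tells Spark to skip process() for this partition/epoch.
        return not already_committed(partition_id, epoch_id)

    def process(self, row):
        self.rows.append(row)

    def close(self, error):
        if error is None:
            # commit_rows() is a hypothetical transactional write keyed by
            # (partition_id, epoch_id), so reprocessing cannot duplicate data.
            commit_rows(self.partition_id, self.epoch_id, self.rows)

# writer = sdf.writeStream.foreach(IdempotentWriter())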