def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
    """
    Create a Spark DataFrame from a pandas DataFrame.

    The Arrow-based conversion is attempted first when it is enabled in the
    session configuration and ``data`` is non-empty; otherwise (or when the
    Arrow attempt fails and fallback is enabled) the pandas data is converted
    without Arrow and handed to ``self._create_dataframe``.

    :param data: the pandas DataFrame to convert
    :param schema: schema to apply; when ``None`` the column names of ``data``
        are used
    :param samplingRatio: forwarded unchanged to ``self._create_dataframe``
    :param verifySchema: forwarded unchanged to ``self._create_dataframe``
    :return: whatever ``self._create_from_pandas_with_arrow`` or
        ``self._create_dataframe`` returns
    :raises Exception: the original Arrow failure is re-raised (after a
        warning) when automatic fallback is disabled
    """
    from pyspark.sql import SparkSession

    assert isinstance(self, SparkSession)

    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    if self._wrapped._conf.pandasRespectSessionTimeZone():
        timezone = self._wrapped._conf.sessionLocalTimeZone()
    else:
        timezone = None

    # If no schema supplied by user then get the names of columns only
    if schema is None:
        schema = [
            str(x) if not isinstance(x, basestring) else
            (x.encode('utf-8') if not isinstance(x, str) else x)
            for x in data.columns
        ]

    if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
        try:
            return self._create_from_pandas_with_arrow(data, schema, timezone)
        except Exception as e:
            from pyspark.util import _exception_message

            if self._wrapped._conf.arrowPySparkFallbackEnabled():
                msg = (
                    "createDataFrame attempted Arrow optimization because "
                    "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                    "failed by the reason below:\n %s\n"
                    "Attempting non-optimization as "
                    "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                    "true." % _exception_message(e))
                warnings.warn(msg)
            else:
                msg = (
                    "createDataFrame attempted Arrow optimization because "
                    "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                    "reached the error below and will not continue because automatic "
                    "fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
                    "has been set to false.\n %s" % _exception_message(e))
                warnings.warn(msg)
                raise
    data = self._convert_from_pandas(data, schema, timezone)
    # The caller's verifySchema flag must be forwarded here; passing
    # samplingRatio a second time silently discarded the caller's choice.
    return self._create_dataframe(data, schema, samplingRatio, verifySchema)
def local_connect_and_auth(port, auth_secret):
    """
    Connect to the gateway host through a temporary tunnel, authenticate with it,
    and return a (sockfile, sock) for that connection.
    Handles IPV4 & IPV6, does some error handling.

    :param port: the port handed to ``PysparkGateway.open_tmp_tunnel``
    :param auth_secret: the secret passed to ``_do_server_auth``
    :return: a tuple with (sockfile, sock)
    :raises Exception: when none of the resolved addresses could be connected to;
        the message lists every attempt and its error
    """
    tmp_port = PysparkGateway.open_tmp_tunnel(port)
    sock = None
    errors = []
    # Support for both IPv4 and IPv6.
    # On most of IPv6-ready systems, IPv6 will take precedence.
    for res in socket.getaddrinfo(PysparkGateway.host, tmp_port,
                                  socket.AF_UNSPEC, socket.SOCK_STREAM):
        af, socktype, proto, _, sa = res
        try:
            sock = socket.socket(af, socktype, proto)
            sock.settimeout(15)
            sock.connect(sa)
            sockfile = sock.makefile("rwb", 65536)
            _do_server_auth(sockfile, auth_secret)
            return (sockfile, sock)
        except socket.error as e:
            emsg = _exception_message(e)
            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
            # socket.socket() itself may have failed, in which case there is
            # nothing to close and calling close() on None would mask the error.
            if sock is not None:
                sock.close()
                sock = None
    # Every candidate address failed (a successful attempt returns above).
    raise Exception("could not open socket: %s" % errors)
def log_failure(
    self,
    class_name: str,
    name: str,
    ex: Exception,
    duration: float,
    signature: Optional[Signature] = None,
) -> None:
    """
    Emit a warning-level record describing a failed function or property call.

    Nothing is formatted or logged unless ``self.logger`` is enabled for
    ``logging.WARNING``.

    :param class_name: the target class name
    :param name: the target function or property name
    :param ex: the exception causing the failure
    :param duration: the duration until the function or property call fails;
        multiplied by 1000 before being reported as milliseconds
    :param signature: the signature if the target is a function, else None
    """
    if not self.logger.isEnabledFor(logging.WARNING):
        return

    template = (
        "A {function} `{class_name}.{name}{signature}` was failed "
        "after {duration:.3f} ms: {msg}"
    )
    rendered_signature = _format_signature(signature)
    failure_text = _exception_message(ex)
    elapsed_ms = duration * 1000
    # A missing signature means the target was a property rather than a function.
    target_kind = "property" if signature is None else "function"

    self.logger.warning(
        template.format(
            function=target_kind,
            class_name=class_name,
            name=name,
            signature=rendered_signature,
            duration=elapsed_ms,
            msg=failure_text,
        )
    )
def test_toPandas_fallback_enabled(self):
    """
    With the PySpark Arrow fallback switched on, ``toPandas`` on a map-typed
    column must emit an "Attempting non-optimization" UserWarning and still
    return the expected pandas frame.
    """
    fallback_conf = {"spark.sql.execution.arrow.pyspark.fallback.enabled": True}
    with self.sql_conf(fallback_conf):
        map_field = StructField("map", MapType(StringType(), IntegerType()), True)
        schema = StructType([map_field])
        df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
        with QuietTest(self.sc):
            with self.warnings_lock:
                with warnings.catch_warnings(record=True) as captured:
                    # we want the warnings to appear even if this test is run from a subclass
                    warnings.simplefilter("always")
                    pdf = df.toPandas()

                    # Collect the UserWarning instances; only the last one is inspected.
                    user_warns = []
                    for record in captured:
                        if isinstance(record.message, UserWarning):
                            user_warns.append(record.message)

                    self.assertTrue(len(user_warns) > 0)
                    last_text = _exception_message(user_warns[-1])
                    self.assertTrue("Attempting non-optimization" in last_text)

                    expected = pd.DataFrame({u'map': [{u'a': 1}]})
                    assert_frame_equal(pdf, expected)
def local_connect_and_auth(port, auth_secret):
    """
    Connect to local host, authenticate with it, and return a (sockfile, sock)
    for that connection.
    Handles IPV4 & IPV6, does some error handling.

    :param port: the local port to connect to
    :param auth_secret: the secret passed to ``_do_server_auth``
    :return: a tuple with (sockfile, sock)
    :raises Exception: when none of the resolved addresses could be connected to;
        the message lists every attempt and its error
    """
    sock = None
    errors = []
    # Support for both IPv4 and IPv6.
    # On most of IPv6-ready systems, IPv6 will take precedence.
    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
        af, socktype, proto, _, sa = res
        try:
            sock = socket.socket(af, socktype, proto)
            sock.settimeout(15)
            sock.connect(sa)
            sockfile = sock.makefile("rwb", 65536)
            _do_server_auth(sockfile, auth_secret)
            return (sockfile, sock)
        except socket.error as e:
            emsg = _exception_message(e)
            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
            # socket.socket() itself may have failed, in which case there is
            # nothing to close and calling close() on None would mask the error.
            if sock is not None:
                sock.close()
                sock = None
    raise Exception("could not open socket: %s" % errors)
def _auto_patch():
    """
    Apply environment-driven start-up hooks.

    1. When ``KOALAS_USAGE_LOGGER`` is set, try to attach that usage logger;
       any failure is reported as a warning and otherwise ignored.
    2. Unless ``SPARK_KOALAS_AUTOPATCH`` is set to something other than
       ``true`` / ``1`` / ``enabled`` (case-insensitive), attach
       ``DataFrame.to_koalas`` to the PySpark DataFrame class.
    """
    import os
    import logging

    # Attach a usage logger.
    usage_logger_name = os.environ.get("KOALAS_USAGE_LOGGER")
    if usage_logger_name is not None:
        try:
            from databricks.koalas import usage_logging
            usage_logging.attach(usage_logger_name)
        except Exception as err:
            from pyspark.util import _exception_message
            text = 'Tried to attach usage logger `{}`, but an exception was raised: {}'.format(
                usage_logger_name, _exception_message(err))
            logging.getLogger('databricks.koalas.usage_logger').warning(text)

    # Autopatching is on by default.
    autopatch_flag = os.environ.get("SPARK_KOALAS_AUTOPATCH", "true").lower()
    if autopatch_flag not in ("true", "1", "enabled"):
        return

    logging.getLogger('spark').info(
        "Patching spark automatically. You can disable it by setting "
        "SPARK_KOALAS_AUTOPATCH=false in your environment")

    from pyspark.sql import dataframe as df
    df.DataFrame.to_koalas = DataFrame.to_koalas
def test_py4j_exception_message(self):
    """The message extracted from a Py4JJavaError must mention the Java NPE."""
    from pyspark.util import _exception_message

    with self.assertRaises(Py4JJavaError) as raised:
        # This attempts java.lang.String(null) which throws an NPE.
        self.sc._jvm.java.lang.String(None)

    extracted = _exception_message(raised.exception)
    self.assertTrue('NullPointerException' in extracted)
def test_py4j_exception_message(self):
    """The message extracted from a Py4JJavaError must mention the Java NPE."""
    from pyspark.util import _exception_message

    with self.assertRaises(Py4JJavaError) as raised:
        # This attempts java.lang.String(null) which throws an NPE.
        self.sc._jvm.java.lang.String(None)

    extracted = _exception_message(raised.exception)
    self.assertTrue('NullPointerException' in extracted)
def dump(self, value, f):
    """
    Pickle ``value`` into the file object ``f`` and close ``f``.

    :param value: the object to serialize
    :param f: a writable file object; it is closed before this method
        returns or raises, so a failed serialization does not leak the handle
    :raises pickle.PickleError: re-raised unchanged when pickling raises one
    :raises pickle.PicklingError: wraps any other exception raised while
        pickling; the traceback is printed to stderr first
    """
    try:
        pickle.dump(value, f, pickle_protocol)
    except pickle.PickleError:
        raise
    except Exception as e:
        msg = "Could not serialize broadcast: %s: %s" \
            % (e.__class__.__name__, _exception_message(e))
        print_exec(sys.stderr)
        raise pickle.PicklingError(msg)
    finally:
        # Previously the file was only closed on success, leaking it on failure.
        f.close()
def dump(self, value, f):
    """
    Pickle ``value`` into the file object ``f`` and close ``f``.

    :param value: the object to serialize
    :param f: a writable file object; it is closed before this method
        returns or raises, so a failed serialization does not leak the handle
    :raises pickle.PickleError: re-raised unchanged when pickling raises one
    :raises pickle.PicklingError: wraps any other exception raised while
        pickling; the traceback is printed to stderr first
    """
    try:
        pickle.dump(value, f, pickle_protocol)
    except pickle.PickleError:
        raise
    except Exception as e:
        msg = "Could not serialize broadcast: %s: %s" \
            % (e.__class__.__name__, _exception_message(e))
        print_exec(sys.stderr)
        raise pickle.PicklingError(msg)
    finally:
        # Previously the file was only closed on success, leaking it on failure.
        f.close()
def dumps(self, obj):
    """
    Serialize ``obj`` with ``cloudpickle.dumps`` using protocol 2.

    :param obj: the object to serialize
    :return: the value returned by ``cloudpickle.dumps``
    :raises pickle.PickleError: re-raised unchanged when cloudpickle raises one
    :raises pickle.PicklingError: wraps any other exception; the traceback is
        printed to stderr first
    """
    try:
        return cloudpickle.dumps(obj, 2)
    except pickle.PickleError:
        # Already a pickle-specific error: let it through untouched.
        raise
    except Exception as err:
        detail = _exception_message(err)
        if "'i' format requires" not in detail:
            text = "Could not serialize object: %s: %s" % (err.__class__.__name__, detail)
        else:
            text = "Object too large to serialize: %s" % detail
        cloudpickle.print_exec(sys.stderr)
        raise pickle.PicklingError(text)
def dumps(self, obj):
    """
    Serialize ``obj`` with ``cloudpickle.dumps`` using protocol 2.

    :param obj: the object to serialize
    :return: the value returned by ``cloudpickle.dumps``
    :raises pickle.PickleError: re-raised unchanged when cloudpickle raises one
    :raises pickle.PicklingError: wraps any other exception; the traceback is
        printed to stderr first
    """
    try:
        return cloudpickle.dumps(obj, 2)
    except pickle.PickleError:
        # Already a pickle-specific error: let it through untouched.
        raise
    except Exception as err:
        detail = _exception_message(err)
        if "'i' format requires" not in detail:
            text = "Could not serialize object: %s: %s" % (err.__class__.__name__, detail)
        else:
            text = "Object too large to serialize: %s" % detail
        cloudpickle.print_exec(sys.stderr)
        raise pickle.PicklingError(text)
def test_createDataFrame_fallback_enabled(self):
    """
    With the Arrow fallback switched on, ``createDataFrame`` on a map-typed
    pandas column must emit an "Attempting non-optimization" UserWarning and
    still produce the expected rows.
    """
    fallback_conf = {"spark.sql.execution.arrow.fallback.enabled": True}
    with QuietTest(self.sc):
        with self.sql_conf(fallback_conf):
            with warnings.catch_warnings(record=True) as captured:
                # we want the warnings to appear even if this test is run from a subclass
                warnings.simplefilter("always")
                source_pdf = pd.DataFrame([[{u'a': 1}]])
                df = self.spark.createDataFrame(source_pdf, "a: map<string, int>")

                # Collect the UserWarning instances; only the last one is inspected.
                user_warns = []
                for record in captured:
                    if isinstance(record.message, UserWarning):
                        user_warns.append(record.message)

                self.assertTrue(len(user_warns) > 0)
                last_text = _exception_message(user_warns[-1])
                self.assertTrue("Attempting non-optimization" in last_text)

                expected_rows = [Row(a={u'a': 1})]
                self.assertEqual(df.collect(), expected_rows)
def test_createDataFrame_fallback_enabled(self):
    """
    With the Arrow fallback switched on, ``createDataFrame`` on a map-typed
    pandas column must emit an "Attempting non-optimization" UserWarning and
    still produce the expected rows.
    """
    fallback_conf = {"spark.sql.execution.arrow.fallback.enabled": True}
    with QuietTest(self.sc):
        with self.sql_conf(fallback_conf):
            with warnings.catch_warnings(record=True) as captured:
                # we want the warnings to appear even if this test is run from a subclass
                warnings.simplefilter("always")
                source_pdf = pd.DataFrame([[{u'a': 1}]])
                df = self.spark.createDataFrame(source_pdf, "a: map<string, int>")

                # Collect the UserWarning instances; only the last one is inspected.
                user_warns = []
                for record in captured:
                    if isinstance(record.message, UserWarning):
                        user_warns.append(record.message)

                self.assertTrue(len(user_warns) > 0)
                last_text = _exception_message(user_warns[-1])
                self.assertTrue("Attempting non-optimization" in last_text)

                expected_rows = [Row(a={u'a': 1})]
                self.assertEqual(df.collect(), expected_rows)
def test_toPandas_fallback_enabled(self):
    """
    With the Arrow fallback switched on, ``toPandas`` on a map-typed column
    must emit an "Attempting non-optimization" UserWarning and still return
    the expected pandas frame.
    """
    fallback_conf = {"spark.sql.execution.arrow.fallback.enabled": True}
    with self.sql_conf(fallback_conf):
        map_field = StructField("map", MapType(StringType(), IntegerType()), True)
        schema = StructType([map_field])
        df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
        with QuietTest(self.sc):
            with self.warnings_lock:
                with warnings.catch_warnings(record=True) as captured:
                    # we want the warnings to appear even if this test is run from a subclass
                    warnings.simplefilter("always")
                    pdf = df.toPandas()

                    # Collect the UserWarning instances; only the last one is inspected.
                    user_warns = []
                    for record in captured:
                        if isinstance(record.message, UserWarning):
                            user_warns.append(record.message)

                    self.assertTrue(len(user_warns) > 0)
                    last_text = _exception_message(user_warns[-1])
                    self.assertTrue("Attempting non-optimization" in last_text)

                    expected = pd.DataFrame({u'map': [{u'a': 1}]})
                    assert_frame_equal(pdf, expected)
def dump(self, obj):
    """
    Pickle ``obj`` through ``Pickler.dump`` after calling ``self.inject_addons()``.

    :param obj: the object to serialize
    :return: the value returned by ``Pickler.dump``
    :raises pickle.PicklingError: when pickling hits a recursion-related
        ``RuntimeError``, or wraps any other non-pickle, non-RuntimeError
        exception (the traceback is printed to stderr first)
    :raises RuntimeError: re-raised unchanged when it is not recursion-related
    :raises pickle.PickleError: re-raised unchanged
    """
    self.inject_addons()
    try:
        return Pickler.dump(self, obj)
    except RuntimeError as e:
        # Guard against empty or non-string args before inspecting the message.
        if e.args and 'recursion' in str(e.args[0]):
            msg = """Could not pickle object as excessively deep recursion required."""
            raise pickle.PicklingError(msg)
        # Any other RuntimeError used to be swallowed here, making the call
        # return None as if pickling had succeeded; surface it instead.
        raise
    except pickle.PickleError:
        raise
    except Exception as e:
        emsg = _exception_message(e)
        if "'i' format requires" in emsg:
            msg = "Object too large to serialize: %s" % emsg
        else:
            msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
        print_exec(sys.stderr)
        raise pickle.PicklingError(msg)
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
    """
    Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

    When ``schema`` is a list of column names, the type of each column
    will be inferred from ``data``.

    When ``schema`` is ``None``, it will try to infer the schema (column names and types)
    from ``data``, which should be an RDD of :class:`Row`,
    or :class:`namedtuple`, or :class:`dict`.

    When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must match
    the real data, or an exception will be thrown at runtime. If the given schema is not
    :class:`pyspark.sql.types.StructType`, it will be wrapped into a
    :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value",
    each record will also be wrapped into a tuple, which can be converted to row later.

    If schema inference is needed, ``samplingRatio`` is used to determine the ratio of
    rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.

    When ``data`` is a :class:`pandas.DataFrame` and ``spark.sql.execution.arrow.enabled``
    is "true", an Arrow-based conversion is tried first for non-empty data. If it fails, a
    warning is issued and, depending on ``spark.sql.execution.arrow.fallback.enabled``, either
    the non-Arrow path is used or the original exception is re-raised.

    :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean,
        etc.), or :class:`list`, or :class:`pandas.DataFrame`.
    :param schema: a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
        column names, default is ``None``.  The data type string format equals to
        :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can
        omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use
        ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
        ``int`` as a short name for ``IntegerType``.
    :param samplingRatio: the sample ratio of rows used for inferring
    :param verifySchema: verify data types of every row against schema.
    :return: :class:`DataFrame`
    :raises TypeError: if ``data`` is already a :class:`DataFrame`

    .. versionchanged:: 2.1
       Added verifySchema.

    .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental.

    >>> l = [('Alice', 1)]
    >>> spark.createDataFrame(l).collect()
    [Row(_1=u'Alice', _2=1)]
    >>> spark.createDataFrame(l, ['name', 'age']).collect()
    [Row(name=u'Alice', age=1)]

    >>> d = [{'name': 'Alice', 'age': 1}]
    >>> spark.createDataFrame(d).collect()
    [Row(age=1, name=u'Alice')]

    >>> rdd = sc.parallelize(l)
    >>> spark.createDataFrame(rdd).collect()
    [Row(_1=u'Alice', _2=1)]
    >>> df = spark.createDataFrame(rdd, ['name', 'age'])
    >>> df.collect()
    [Row(name=u'Alice', age=1)]

    >>> from pyspark.sql import Row
    >>> Person = Row('name', 'age')
    >>> person = rdd.map(lambda r: Person(*r))
    >>> df2 = spark.createDataFrame(person)
    >>> df2.collect()
    [Row(name=u'Alice', age=1)]

    >>> from pyspark.sql.types import *
    >>> schema = StructType([
    ...    StructField("name", StringType(), True),
    ...    StructField("age", IntegerType(), True)])
    >>> df3 = spark.createDataFrame(rdd, schema)
    >>> df3.collect()
    [Row(name=u'Alice', age=1)]

    >>> spark.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
    [Row(name=u'Alice', age=1)]
    >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
    [Row(0=1, 1=2)]

    >>> spark.createDataFrame(rdd, "a: string, b: int").collect()
    [Row(a=u'Alice', b=1)]
    >>> rdd = rdd.map(lambda row: row[1])
    >>> spark.createDataFrame(rdd, "int").collect()
    [Row(value=1)]
    >>> spark.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...
    """
    if isinstance(data, DataFrame):
        raise TypeError("data is already a DataFrame")

    # Normalize the schema argument: datatype strings are parsed, name lists are re-encoded.
    if isinstance(schema, basestring):
        schema = _parse_datatype_string(schema)
    elif isinstance(schema, (list, tuple)):
        # Must re-encode any unicode strings to be consistent with StructField names
        schema = [
            x.encode('utf-8') if not isinstance(x, str) else x
            for x in schema
        ]

    # pandas is optional: any failure to import it simply disables the pandas path.
    try:
        import pandas
        has_pandas = True
    except Exception:
        has_pandas = False
    if has_pandas and isinstance(data, pandas.DataFrame):
        from pyspark.sql.utils import require_minimum_pandas_version
        require_minimum_pandas_version()

        if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
                == "true":
            timezone = self.conf.get("spark.sql.session.timeZone")
        else:
            timezone = None

        # If no schema supplied by user then get the names of columns only
        if schema is None:
            schema = [
                str(x) if not isinstance(x, basestring) else
                (x.encode('utf-8') if not isinstance(x, str) else x)
                for x in data.columns
            ]

        # The Arrow path is only attempted for non-empty data.
        if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
                and len(data) > 0:
            try:
                return self._create_from_pandas_with_arrow(
                    data, schema, timezone)
            except Exception as e:
                from pyspark.util import _exception_message

                if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \
                        .lower() == "true":
                    # Fallback enabled: warn and continue with the non-Arrow conversion below.
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
                        "failed by the reason below:\n %s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.fallback.enabled' is set to "
                        "true."
                        % _exception_message(e))
                    warnings.warn(msg)
                else:
                    # Fallback disabled: warn, then re-raise the original exception.
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.enabled' is set to true, but has reached "
                        "the error below and will not continue because automatic fallback "
                        "with 'spark.sql.execution.arrow.fallback.enabled' has been set to "
                        "false.\n %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise
        data = self._convert_from_pandas(data, schema, timezone)

    # Build the per-record ``prepare`` hook according to the kind of schema supplied.
    if isinstance(schema, StructType):
        verify_func = _make_type_verifier(
            schema) if verifySchema else lambda _: True

        def prepare(obj):
            verify_func(obj)
            return obj
    elif isinstance(schema, DataType):
        # A non-struct type is wrapped as the single field "value"; the verifier
        # is built from the original (unwrapped) type.
        dataType = schema
        schema = StructType().add("value", schema)

        verify_func = _make_type_verifier(
            dataType, name="field value") if verifySchema else lambda _: True

        def prepare(obj):
            verify_func(obj)
            # Trailing comma: each record is wrapped into a 1-tuple.
            return obj,
    else:
        prepare = lambda obj: obj

    if isinstance(data, RDD):
        rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
    else:
        rdd, schema = self._createFromLocal(map(prepare, data), schema)
    jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
    jdf = self._jsparkSession.applySchemaToPythonRDD(
        jrdd.rdd(), schema.json())
    df = DataFrame(jdf, self._wrapped)
    # Attach the resolved schema to the returned DataFrame.
    df._schema = schema
    return df
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
    """
    Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

    When ``schema`` is a list of column names, the type of each column
    will be inferred from ``data``.

    When ``schema`` is ``None``, it will try to infer the schema (column names and types)
    from ``data``, which should be an RDD of :class:`Row`,
    or :class:`namedtuple`, or :class:`dict`.

    When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must match
    the real data, or an exception will be thrown at runtime. If the given schema is not
    :class:`pyspark.sql.types.StructType`, it will be wrapped into a
    :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value",
    each record will also be wrapped into a tuple, which can be converted to row later.

    If schema inference is needed, ``samplingRatio`` is used to determine the ratio of
    rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.

    This session is registered as the active session (on both the Python and the JVM side)
    before anything else is done.

    When ``data`` is a :class:`pandas.DataFrame` and Arrow is enabled in the session
    configuration, an Arrow-based conversion is tried first for non-empty data. If it fails,
    a warning is issued and, depending on the Arrow fallback setting, either the non-Arrow
    path is used or the original exception is re-raised.

    :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean,
        etc.), or :class:`list`, or :class:`pandas.DataFrame`.
    :param schema: a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
        column names, default is ``None``.  The data type string format equals to
        :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can
        omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use
        ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
        ``int`` as a short name for ``IntegerType``.
    :param samplingRatio: the sample ratio of rows used for inferring
    :param verifySchema: verify data types of every row against schema.
    :return: :class:`DataFrame`
    :raises TypeError: if ``data`` is already a :class:`DataFrame`

    .. versionchanged:: 2.1
       Added verifySchema.

    .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental.

    >>> l = [('Alice', 1)]
    >>> spark.createDataFrame(l).collect()
    [Row(_1=u'Alice', _2=1)]
    >>> spark.createDataFrame(l, ['name', 'age']).collect()
    [Row(name=u'Alice', age=1)]

    >>> d = [{'name': 'Alice', 'age': 1}]
    >>> spark.createDataFrame(d).collect()
    [Row(age=1, name=u'Alice')]

    >>> rdd = sc.parallelize(l)
    >>> spark.createDataFrame(rdd).collect()
    [Row(_1=u'Alice', _2=1)]
    >>> df = spark.createDataFrame(rdd, ['name', 'age'])
    >>> df.collect()
    [Row(name=u'Alice', age=1)]

    >>> from pyspark.sql import Row
    >>> Person = Row('name', 'age')
    >>> person = rdd.map(lambda r: Person(*r))
    >>> df2 = spark.createDataFrame(person)
    >>> df2.collect()
    [Row(name=u'Alice', age=1)]

    >>> from pyspark.sql.types import *
    >>> schema = StructType([
    ...    StructField("name", StringType(), True),
    ...    StructField("age", IntegerType(), True)])
    >>> df3 = spark.createDataFrame(rdd, schema)
    >>> df3.collect()
    [Row(name=u'Alice', age=1)]

    >>> spark.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
    [Row(name=u'Alice', age=1)]
    >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
    [Row(0=1, 1=2)]

    >>> spark.createDataFrame(rdd, "a: string, b: int").collect()
    [Row(a=u'Alice', b=1)]
    >>> rdd = rdd.map(lambda row: row[1])
    >>> spark.createDataFrame(rdd, "int").collect()
    [Row(value=1)]
    >>> spark.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...
    """
    # Mark this session as the active one on the Python side and in the JVM.
    SparkSession._activeSession = self
    self._jvm.SparkSession.setActiveSession(self._jsparkSession)
    if isinstance(data, DataFrame):
        raise TypeError("data is already a DataFrame")

    # Normalize the schema argument: datatype strings are parsed, name lists are re-encoded.
    if isinstance(schema, basestring):
        schema = _parse_datatype_string(schema)
    elif isinstance(schema, (list, tuple)):
        # Must re-encode any unicode strings to be consistent with StructField names
        schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]

    # pandas is optional: any failure to import it simply disables the pandas path.
    try:
        import pandas
        has_pandas = True
    except Exception:
        has_pandas = False
    if has_pandas and isinstance(data, pandas.DataFrame):
        from pyspark.sql.utils import require_minimum_pandas_version
        require_minimum_pandas_version()

        if self._wrapped._conf.pandasRespectSessionTimeZone():
            timezone = self._wrapped._conf.sessionLocalTimeZone()
        else:
            timezone = None

        # If no schema supplied by user then get the names of columns only
        if schema is None:
            schema = [str(x) if not isinstance(x, basestring) else
                      (x.encode('utf-8') if not isinstance(x, str) else x)
                      for x in data.columns]

        # The Arrow path is only attempted for non-empty data.
        if self._wrapped._conf.arrowEnabled() and len(data) > 0:
            try:
                return self._create_from_pandas_with_arrow(data, schema, timezone)
            except Exception as e:
                from pyspark.util import _exception_message

                if self._wrapped._conf.arrowFallbackEnabled():
                    # Fallback enabled: warn and continue with the non-Arrow conversion below.
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
                        "failed by the reason below:\n %s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.fallback.enabled' is set to "
                        "true."
                        % _exception_message(e))
                    warnings.warn(msg)
                else:
                    # Fallback disabled: warn, then re-raise the original exception.
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.enabled' is set to true, but has reached "
                        "the error below and will not continue because automatic fallback "
                        "with 'spark.sql.execution.arrow.fallback.enabled' has been set to "
                        "false.\n %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise
        data = self._convert_from_pandas(data, schema, timezone)

    # Build the per-record ``prepare`` hook according to the kind of schema supplied.
    if isinstance(schema, StructType):
        verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True

        def prepare(obj):
            verify_func(obj)
            return obj
    elif isinstance(schema, DataType):
        # A non-struct type is wrapped as the single field "value"; the verifier
        # is built from the original (unwrapped) type.
        dataType = schema
        schema = StructType().add("value", schema)

        verify_func = _make_type_verifier(
            dataType, name="field value") if verifySchema else lambda _: True

        def prepare(obj):
            verify_func(obj)
            # Trailing comma: each record is wrapped into a 1-tuple.
            return obj,
    else:
        prepare = lambda obj: obj

    if isinstance(data, RDD):
        rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
    else:
        rdd, schema = self._createFromLocal(map(prepare, data), schema)
    jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
    jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
    df = DataFrame(jdf, self._wrapped)
    # Attach the resolved schema to the returned DataFrame.
    df._schema = schema
    return df
import tempfile
from contextlib import contextmanager

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase
from pyspark.util import _exception_message


# Each ``*_message`` below stays ``None`` when its requirement check passes and
# otherwise holds the text of the error raised by the check.

pandas_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    # If Pandas version requirement is not satisfied, skip related tests.
    pandas_requirement_message = _exception_message(e)

pyarrow_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pyarrow_version
    require_minimum_pyarrow_version()
except ImportError as e:
    # If Arrow version requirement is not satisfied, skip related tests.
    pyarrow_requirement_message = _exception_message(e)

test_not_compiled_message = None
try:
    from pyspark.sql.utils import require_test_compiled
    require_test_compiled()
except Exception as e:
    # Unlike the two probes above, any exception type is recorded here.
    test_not_compiled_message = _exception_message(e)
def toPandas(self):
    """
    Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

    This is only available if Pandas is installed and available.

    When Arrow is enabled in the session configuration, the schema is first checked for
    Arrow support. If that check fails, a warning is issued and, depending on the fallback
    setting, either the non-Arrow path is used or the exception is re-raised. A failure
    during the Arrow collection itself always warns and re-raises.

    .. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is
        expected to be small, as all the data is loaded into the driver's memory.

    .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.

    >>> df.toPandas()  # doctest: +SKIP
       age   name
    0    2  Alice
    1    5    Bob
    """
    from pyspark.sql.dataframe import DataFrame

    assert isinstance(self, DataFrame)

    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    import numpy as np
    import pandas as pd

    timezone = self.sql_ctx._conf.sessionLocalTimeZone()

    if self.sql_ctx._conf.arrowPySparkEnabled():
        use_arrow = True
        try:
            from pyspark.sql.pandas.types import to_arrow_schema
            from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

            require_minimum_pyarrow_version()
            to_arrow_schema(self.schema)
        except Exception as e:

            if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
                msg = (
                    "toPandas attempted Arrow optimization because "
                    "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                    "failed by the reason below:\n %s\n"
                    "Attempting non-optimization as "
                    "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                    "true." % _exception_message(e))
                warnings.warn(msg)
                use_arrow = False
            else:
                msg = (
                    "toPandas attempted Arrow optimization because "
                    "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                    "reached the error below and will not continue because automatic fallback "
                    "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
                    "false.\n %s" % _exception_message(e))
                warnings.warn(msg)
                raise

        # Try to use Arrow optimization when the schema is supported and the required version
        # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
        if use_arrow:
            try:
                from pyspark.sql.pandas.types import _check_series_localize_timestamps
                import pyarrow
                # Rename columns to avoid duplicated column names.
                tmp_column_names = [
                    'col_{}'.format(i) for i in range(len(self.columns))
                ]
                batches = self.toDF(*tmp_column_names)._collect_as_arrow()
                if len(batches) > 0:
                    table = pyarrow.Table.from_batches(batches)
                    # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
                    # values, but we should use datetime.date to match the behavior with when
                    # Arrow optimization is disabled.
                    pdf = table.to_pandas(date_as_object=True)
                    # Rename back to the original column names.
                    pdf.columns = self.columns
                    for field in self.schema:
                        if isinstance(field.dataType, TimestampType):
                            pdf[field.name] = \
                                _check_series_localize_timestamps(pdf[field.name], timezone)
                    return pdf
                else:
                    return pd.DataFrame.from_records([], columns=self.columns)
            except Exception as e:
                # We might have to allow fallback here as well but multiple Spark jobs can
                # be executed. So, simply fail in this case for now.
                msg = (
                    "toPandas attempted Arrow optimization because "
                    "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                    "reached the error below and can not continue. Note that "
                    "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
                    "effect on failures in the middle of "
                    "computation.\n %s" % _exception_message(e))
                warnings.warn(msg)
                raise

    # Below is toPandas without Arrow optimization.
    pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
    column_counter = Counter(self.columns)

    dtype = [None] * len(self.schema)
    for fieldIdx, field in enumerate(self.schema):
        # For duplicate column name, we use `iloc` to access it.
        if column_counter[field.name] > 1:
            pandas_col = pdf.iloc[:, fieldIdx]
        else:
            pandas_col = pdf[field.name]

        pandas_type = PandasConversionMixin._to_corrected_pandas_type(
            field.dataType)
        # SPARK-21766: if an integer field is nullable and has null values, it can be
        # inferred by pandas as float column. Once we convert the column with NaN back
        # to integer type e.g., np.int16, we will hit exception. So we use the inferred
        # float type, not the corrected type from the schema in this case.
        if pandas_type is not None and \
            not(isinstance(field.dataType, IntegralType) and field.nullable and
                pandas_col.isnull().any()):
            dtype[fieldIdx] = pandas_type
        # Ensure we fall back to nullable numpy types, even when whole column is null:
        if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
            dtype[fieldIdx] = np.float64
        if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
            # Use the builtin ``object``: the ``np.object`` alias is deprecated and
            # no longer exists in recent NumPy releases.
            dtype[fieldIdx] = object

    df = pd.DataFrame()
    for index, t in enumerate(dtype):
        column_name = self.schema[index].name

        # For duplicate column name, we use `iloc` to access it.
        if column_counter[column_name] > 1:
            series = pdf.iloc[:, index]
        else:
            series = pdf[column_name]

        if t is not None:
            series = series.astype(t, copy=False)

        # `insert` API makes copy of data, we only do it for Series of duplicate column names.
        # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work because `iloc` could
        # return a view or a copy depending by context.
        if column_counter[column_name] > 1:
            df.insert(index, column_name, series, allow_duplicates=True)
        else:
            df[column_name] = series

    pdf = df

    if timezone is None:
        return pdf
    else:
        from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
        for field in self.schema:
            # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
            if isinstance(field.dataType, TimestampType):
                pdf[field.name] = \
                    _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
        return pdf
import shutil
import tempfile
from contextlib import contextmanager

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase
from pyspark.util import _exception_message


# Each ``*_message`` below stays ``None`` when its requirement check passes and
# otherwise holds the text of the error raised by the check.

pandas_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    # If Pandas version requirement is not satisfied, skip related tests.
    pandas_requirement_message = _exception_message(e)

pyarrow_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pyarrow_version
    require_minimum_pyarrow_version()
except ImportError as e:
    # If Arrow version requirement is not satisfied, skip related tests.
    pyarrow_requirement_message = _exception_message(e)

test_not_compiled_message = None
try:
    from pyspark.sql.utils import require_test_compiled
    require_test_compiled()
except Exception as e:
    # Unlike the two probes above, any exception type is recorded here.
    test_not_compiled_message = _exception_message(e)