def testErrorInPythonCallback(self):
        """Check how errors raised in a Python callback come back to the caller.

        With ``propagate_java_exceptions=True``:

        * a plain Python exception raised by the callback reaches the caller
          as a ``Py4JJavaError`` wrapping ``py4j.Py4JException`` whose text
          includes the Python message;
        * a ``Py4JJavaError`` raised by the callback reaches the caller
          wrapping the original Java exception (here
          ``java.lang.IllegalStateException``).

        The ClientServer is shut down in ``finally`` so that a failing
        assertion does not leave it running.
        """
        with clientserver_example_app_process():
            client_server = ClientServer(
                JavaParameters(),
                PythonParameters(propagate_java_exceptions=True))
            try:
                example = client_server.entry_point.getNewExample()

                # Case 1: an ordinary Python exception raised by the callback.
                with self.assertRaises(Py4JJavaError) as ctx:
                    example.callHello(
                        IHelloFailingImpl(
                            ValueError('My interesting Python exception')))
                self.assertTrue(
                    is_instance_of(client_server, ctx.exception.java_exception,
                                   'py4j.Py4JException'))
                self.assertIn(
                    'interesting Python exception', str(ctx.exception))

                # Case 2: the callback raises a Java exception. Build it up
                # front so only callHello sits in the expected-failure region.
                java_error = client_server.jvm.java.lang.IllegalStateException(
                    'My IllegalStateException')
                with self.assertRaises(Py4JJavaError) as ctx:
                    example.callHello(
                        IHelloFailingImpl(Py4JJavaError('', java_error)))
                self.assertTrue(
                    is_instance_of(client_server, ctx.exception.java_exception,
                                   'java.lang.IllegalStateException'))
            finally:
                # Always release the callback server, even on failure.
                client_server.shutdown()
Example #2
0
def convert_exception(e: Py4JJavaError) -> CapturedException:
    """Translate a JVM throwable into the matching PySpark captured exception.

    Known Spark/Java exception classes are matched by fully-qualified class
    name, in order, and wrapped with ``origin=e``. Failing that, if the cause
    is an ``org.apache.spark.api.python.PythonException`` whose stack trace has
    a frame from ``org.apache.spark.sql.execution.python``, a
    ``PythonException`` carrying the worker's message and the stack trace is
    returned. Everything else becomes an ``UnknownException``.

    NOTE(review): although annotated as ``Py4JJavaError``, ``e`` is used as a
    Py4J proxy of the Java throwable itself (``is_instance_of``,
    ``getCause``, ``toString``). Presumably callers pass
    ``Py4JJavaError.java_exception``; confirm.

    Requires ``SparkContext._jvm`` and ``SparkContext._gateway`` to be set.
    """
    assert e is not None
    assert SparkContext._jvm is not None
    assert SparkContext._gateway is not None

    gateway = SparkContext._gateway
    jvm = SparkContext._jvm

    # The first match wins, so order matters: ParseException inherits
    # AnalysisException and must be tested before it.
    direct_mappings = (
        ("org.apache.spark.sql.catalyst.parser.ParseException", ParseException),
        ("org.apache.spark.sql.AnalysisException", AnalysisException),
        ("org.apache.spark.sql.streaming.StreamingQueryException", StreamingQueryException),
        ("org.apache.spark.sql.execution.QueryExecutionException", QueryExecutionException),
        ("java.lang.IllegalArgumentException", IllegalArgumentException),
        ("org.apache.spark.SparkUpgradeException", SparkUpgradeException),
    )
    for java_class_name, captured_type in direct_mappings:
        if is_instance_of(gateway, e, java_class_name):
            return captured_type(origin=e)

    cause = e.getCause()
    stacktrace: str = jvm.org.apache.spark.util.Utils.exceptionString(e)

    # Only claim the failure for Python UDFs: the cause must be a Spark
    # PythonException raised from Spark SQL's Python execution package.
    is_python_udf_failure = (
        cause is not None
        and is_instance_of(gateway, cause, "org.apache.spark.api.python.PythonException")
        and any(
            "org.apache.spark.sql.execution.python" in frame.toString()
            for frame in cause.getStackTrace()
        )
    )
    if is_python_udf_failure:
        msg = (
            "\n  An exception was thrown from the Python worker. "
            "Please see the stack trace below.\n%s" % cause.getMessage()
        )
        return PythonException(msg, stacktrace)

    return UnknownException(desc=e.toString(), stackTrace=stacktrace, cause=cause)
    def testProxyError(self):
        """Check how a Java exception re-raised from a Python callback surfaces.

        The callback raises a ``Py4JJavaError`` wrapping a
        ``java.lang.IllegalStateException``. The caller is expected to receive
        a ``Py4JJavaError`` whose Java exception is a ``py4j.Py4JException``
        rather than the original class, and whose text still mentions
        ``My IllegalStateException``.

        NOTE(review): presumably this is because ``self.gateway`` does not
        enable ``propagate_java_exceptions``. Confirm against the fixture.
        """
        sleep()
        example = self.gateway.entry_point.getNewExample()

        # Build the Java exception up front so only callHello sits inside the
        # expected-failure region.
        java_error = self.gateway.jvm.java.lang.IllegalStateException(
            'My IllegalStateException')

        with self.assertRaises(Py4JJavaError) as ctx:
            example.callHello(IHelloFailingImpl(Py4JJavaError('', java_error)))

        self.assertTrue(is_instance_of(
            self.gateway, ctx.exception.java_exception,
            'py4j.Py4JException'))
        self.assertIn('My IllegalStateException', str(ctx.exception))
Example #4
0
    def zip(self, other):
        """Pair each element of this RDD with the positionally matching element of ``other``.

        ``other`` is flattened across its partitions and its elements are
        handed out in order, so the result keeps *this* RDD's partition layout.
        The nested list of pairs is wrapped via ``self._toRDD``.

        Raises:
            Py4JJavaError: if ``self.count()`` and ``other.count()`` differ.

        NOTE(review): only the total counts are compared, although the error
        message talks about per-partition counts. Confirm the looser check is
        intended.
        """
        counts_match = self.count() == other.count()
        if not counts_match:
            raise Py4JJavaError(
                "Can only zip RDDs with same number of elements in each partition",
                JavaException(''))

        right_items = [item for part in other.partitions for item in part]

        pairs_by_partition = []
        offset = 0  # position in right_items where the current partition starts
        for part in self.partitions:
            pairs = [(left, right_items[offset + k]) for k, left in enumerate(part)]
            offset += len(pairs)
            pairs_by_partition.append(pairs)

        return self._toRDD(pairs_by_partition)
Example #5
0
 def get(self):
     """Return ``"failure-reason"`` if ``failure_reason`` is truthy, else raise.

     ``failure_reason`` is not a parameter. It is resolved from an enclosing
     or module scope that is not visible here.

     Raises:
         Py4JJavaError: when ``failure_reason`` is falsy.
     """
     if not failure_reason:
         raise Py4JJavaError("msg", JavaException())
     return "failure-reason"