Example #1
import logging
import os
import sys
from glob import glob

_log = logging.getLogger(__name__)


def setup_local_spark():
    from pyspark import find_spark_home, SparkContext

    spark_python = os.path.join(find_spark_home._find_spark_home(), 'python')
    py4j = glob(os.path.join(spark_python, 'lib', 'py4j-*.zip'))[0]
    sys.path[:0] = [spark_python, py4j]
    _log.debug('sys.path: {p!r}'.format(p=sys.path))
    if 'TRAVIS' in os.environ:
        master_str = "local[2]"
    else:
        master_str = "local[*]"

    from geopyspark import geopyspark_conf
    conf = geopyspark_conf(master=master_str, appName="openeo-geotrellis-local")
    conf.set('spark.kryoserializer.buffer.max', value='1G')
    conf.set('spark.ui.enabled', True)
    # Some options to allow attaching a Java debugger to running Spark driver
    conf.set('spark.driver.extraJavaOptions', '-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=5009')

    if 'TRAVIS' in os.environ:
        conf.set(key='spark.driver.memory', value='2G')
        conf.set(key='spark.executor.memory', value='2G')

    if 'PYSPARK_PYTHON' not in os.environ:
        os.environ['PYSPARK_PYTHON'] = sys.executable

    _log.info('Creating Spark context with config:')
    for k, v in conf.getAll():
        _log.info("Spark config: {k!r}: {v!r}".format(k=k, v=v))
    pysc = SparkContext.getOrCreate(conf)
    _log.info('Created Spark Context {s}'.format(s=pysc))
    _log.info('Spark web UI: http://localhost:{p}/'.format(p=pysc.getConf().get('spark.ui.port') or 4040))

    return pysc
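A minimal sketch of how this helper might be driven, assuming geopyspark and a local Spark installation are actually present (not part of the original snippet):

if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.DEBUG)
    sc = setup_local_spark()
    print(sc.parallelize(range(10)).sum())  # quick smoke test on the returned SparkContext
    sc.stop()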
Example #2
    def test_find_spark_home(self):
        # SPARK-38827: Test find_spark_home without the `SPARK_HOME` environment variable set.
        origin = os.environ["SPARK_HOME"]
        try:
            del os.environ["SPARK_HOME"]
            self.assertEqual(origin, _find_spark_home())
        finally:
            os.environ["SPARK_HOME"] = origin
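The helper under test can also be called directly; a minimal sketch (note the leading underscore: _find_spark_home is a private PySpark utility, not a stable public API):

import os

from pyspark.find_spark_home import _find_spark_home

spark_home = _find_spark_home()
print(spark_home)
print(os.path.isdir(os.path.join(spark_home, "bin")))  # typically True for a working Spark install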
Example #3
def require_test_compiled() -> None:
    """Raise Exception if test classes are not compiled"""
    import os
    import glob
    from pyspark.find_spark_home import _find_spark_home

    test_class_path = os.path.join(_find_spark_home(), "sql", "core", "target", "*", "test-classes")
    paths = glob.glob(test_class_path)

    if len(paths) == 0:
        raise RuntimeError(
            "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path
        )
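A hedged sketch of how such a guard could be used in a test module, converting the RuntimeError into a skip (the skip behaviour is an assumption, not part of the helper above):

import unittest


class JVMBackedTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        try:
            require_test_compiled()  # helper from the example above
        except RuntimeError as e:
            raise unittest.SkipTest(str(e))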
Example #4
import sys
from pathlib import Path

from _pytest.terminal import TerminalReporter


def _ensure_geopyspark(out: TerminalReporter):
    """Make sure GeoPySpark knows where to find Spark (SPARK_HOME) and py4j"""
    try:
        import geopyspark
        out.write_line("Succeeded to import geopyspark automatically: {p!r}".format(p=geopyspark))
    except KeyError as e:
        # Geopyspark failed to detect Spark home and py4j, let's fix that.
        from pyspark import find_spark_home
        pyspark_home = Path(find_spark_home._find_spark_home())
        out.write_line("Failed to import geopyspark automatically. "
                       "Will set up py4j path using Spark home: {h}".format(h=pyspark_home))
        py4j_zip = next((pyspark_home / 'python' / 'lib').glob('py4j-*-src.zip'))
        out.write_line("py4j zip: {z!r}".format(z=py4j_zip))
        sys.path.append(str(py4j_zip))
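A sketch of how this could be hooked into a pytest run from a conftest.py, assuming pytest's built-in terminalreporter plugin is available at session start (this wiring is an assumption, not part of the excerpt):

# conftest.py (hypothetical)
def pytest_sessionstart(session):
    reporter = session.config.pluginmanager.get_plugin("terminalreporter")
    _ensure_geopyspark(reporter)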
Example #5
    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "SparkSessionExtensionSuite.class")
        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
            raise unittest.SkipTest(
                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
                "available. Will skip the related tests.")

        # Note that 'spark.sql.extensions' is a static immutable configuration.
        cls.spark = (SparkSession.builder.master("local[4]").appName(
            cls.__name__).config(
                "spark.sql.extensions",
                "org.apache.spark.sql.MyExtensions").getOrCreate())
Example #6
    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class")
        cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))

        if cls.has_listener:
            # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
            cls.spark = SparkSession.builder \
                .master("local[4]") \
                .appName(cls.__name__) \
                .config(
                    "spark.sql.queryExecutionListeners",
                    "org.apache.spark.sql.TestQueryExecutionListener") \
                .getOrCreate()
Example #7
    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class")
        cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))

        if cls.has_listener:
            # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
            cls.spark = SparkSession.builder \
                .master("local[4]") \
                .appName(cls.__name__) \
                .config(
                    "spark.sql.queryExecutionListeners",
                    "org.apache.spark.sql.TestQueryExecutionListener") \
                .getOrCreate()
Example #8
    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "SparkSessionExtensionSuite.class")
        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
            raise unittest.SkipTest(
                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
                "available. Will skip the related tests.")

        # Note that 'spark.sql.extensions' is a static immutable configuration.
        cls.spark = SparkSession.builder \
            .master("local[4]") \
            .appName(cls.__name__) \
            .config(
                "spark.sql.extensions",
                "org.apache.spark.sql.MyExtensions") \
            .getOrCreate()
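The SPARK_HOME-plus-glob check recurs in several of these setUpClass methods; a small helper along these lines could factor it out (the name and signature are assumptions):

import glob
import os

from pyspark.find_spark_home import _find_spark_home


def compiled_test_class_exists(relative_pattern):
    """Return True if a compiled Spark test class matching the glob pattern exists under SPARK_HOME."""
    return bool(glob.glob(os.path.join(_find_spark_home(), relative_pattern)))


# Usage, mirroring the checks above:
# compiled_test_class_exists(
#     "sql/core/target/scala-*/test-classes/org/apache/spark/sql/SparkSessionExtensionSuite.class")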
Example #9
def launch_gateway(conf=None):
    """
    Launch a JVM gateway for Spark and connect to it via Py4J.

    :param conf: Spark configuration passed to spark-submit
    :return: a connected Py4J JavaGateway
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            while not proc.poll() and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
                                             auto_convert=True))

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
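A hedged sketch of driving this launch_gateway variant by hand (normally SparkContext does this internally; it assumes bin/spark-submit exists under the Spark home that _find_spark_home resolves):

from pyspark import SparkConf

conf = SparkConf().set("spark.driver.memory", "1g")
gateway = launch_gateway(conf)
# The java_import calls above expose e.g. org.apache.spark.SparkConf directly on the JVM view:
jconf = gateway.jvm.SparkConf()
print(jconf.toDebugString())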
Example #10
def launch_gateway(conf=None):
    """
    Launch a JVM gateway for Spark and connect to it via Py4J.

    :param conf: Spark configuration passed to spark-submit
    :return: a connected Py4J JavaGateway
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            while not proc.poll() and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
                                             auto_convert=True))

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
Example #11
def launch_gateway(conf=None):
    """
    Launch a JVM gateway for Spark and connect to it via Py4J.

    :param conf: Spark configuration passed to spark-submit
    :return: a connected Py4J JavaGateway
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        callback_socket.bind(('127.0.0.1', 0))
        callback_socket.listen(1)
        callback_host, callback_port = callback_socket.getsockname()
        env = dict(os.environ)
        env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
        env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

        # Launch the Java gateway.
        # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
        if not on_windows:
            # Don't send ctrl-c / SIGINT to the Java gateway:
            def preexec_func():
                signal.signal(signal.SIGINT, signal.SIG_IGN)
            proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
        else:
            # preexec_fn not supported on Windows
            proc = Popen(command, stdin=PIPE, env=env)

        gateway_port = None
        # We use select() here in order to avoid blocking indefinitely if the subprocess dies
        # before connecting
        while gateway_port is None and proc.poll() is None:
            timeout = 1  # (seconds)
            readable, _, _ = select.select([callback_socket], [], [], timeout)
            if callback_socket in readable:
                gateway_connection = callback_socket.accept()[0]
                # Determine which ephemeral port the server started on:
                gateway_port = read_int(gateway_connection.makefile(mode="rb"))
                gateway_connection.close()
                callback_socket.close()
        if gateway_port is None:
            raise Exception("Java gateway process exited before sending the driver its port number")

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
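Unlike the file-based handshake in Example #9, this older variant learns the gateway port over a callback socket. A self-contained sketch of that handshake pattern, with a thread standing in for the launched JVM (purely illustrative, not PySpark code):

import socket
import struct
import threading

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))
listener.listen(1)
host, port = listener.getsockname()  # in PySpark these go into _PYSPARK_DRIVER_CALLBACK_HOST/PORT


def fake_gateway_server():
    # Stand-in for PythonGatewayServer: connect back and report a (fake) port as a big-endian int.
    with socket.create_connection((host, port)) as s:
        s.sendall(struct.pack("!i", 25333))


threading.Thread(target=fake_gateway_server).start()
conn, _ = listener.accept()
print(struct.unpack("!i", conn.makefile(mode="rb").read(4))[0])  # -> 25333
conn.close()
listener.close()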
Example #12
#!/usr/bin/env python

import argparse
import multiprocessing
import os
import subprocess

try:
    from pyspark.find_spark_home import _find_spark_home
    default_spark_home = _find_spark_home()
except Exception:  # fall back to $SPARK_HOME if pyspark cannot locate a Spark installation
    default_spark_home = os.environ.get("SPARK_HOME")

p = argparse.ArgumentParser()
p.add_argument(
    "--spark-home",
    default=default_spark_home,
    help=
    "The local spark directory (default: $SPARK_HOME). Required for --run-locally"
)
p.add_argument(
    "--cpu-limit",
    help=
    "How many CPUs to use when running locally. Defaults to all available CPUs.",
    type=int)
p.add_argument("--driver-memory",
               help="Spark driver memory limit when running locally")
p.add_argument("--executor-memory",
               help="Spark executor memory limit when running locally")
p.add_argument("--num-executors",
               help="Spark number of executors",
Example #13
import struct

from pyspark.find_spark_home import _find_spark_home

have_scipy = False
have_numpy = False

try:
    import scipy.sparse  # noqa: F401

    have_scipy = True
except ImportError:
    # No SciPy, but that's okay, we'll skip those tests
    pass
try:
    import numpy as np  # noqa: F401

    have_numpy = True
except ImportError:
    # No NumPy, but that's okay, we'll skip those tests
    pass

SPARK_HOME = _find_spark_home()


def read_int(b):
    return struct.unpack("!i", b)[0]


def write_int(i):
    return struct.pack("!i", i)


def eventually(condition, timeout=30.0, catch_assertions=False):
    """
    Wait a given amount of time for a condition to pass, else fail with an error.
    This is a helper utility for PySpark tests.
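The read_int/write_int helpers above use the 4-byte big-endian framing that PySpark's Python/JVM sockets rely on; a quick standalone round-trip check using only the functions defined in this example:

framed = write_int(4040)
print(framed)            # b'\x00\x00\x0f\xc8'
print(read_int(framed))  # 4040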
Example #14
def launch_gateway(conf=None, popen_kwargs=None):
    """
    Launch a JVM gateway for Spark and connect to it via Py4J.

    Parameters
    ----------
    conf : :py:class:`pyspark.SparkConf`
        spark configuration passed to spark-submit
    popen_kwargs : dict
        Dictionary of kwargs to pass to Popen when spawning
        the py4j JVM. This is a developer feature intended for use in
        customizing how pyspark interacts with the py4j JVM (e.g., capturing
        stdout/stderr).

    Returns
    -------
    ClientServer or JavaGateway
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
        # Process already exists
        proc = None
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ["--conf", "%s=%s" % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = " ".join(
                ["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            popen_kwargs = {} if popen_kwargs is None else popen_kwargs
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            popen_kwargs["stdin"] = PIPE
            # We always set the necessary environment variables.
            popen_kwargs["env"] = env
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)

                popen_kwargs["preexec_fn"] = preexec_func
                proc = Popen(command, **popen_kwargs)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, **popen_kwargs)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            while not proc.poll() and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise RuntimeError(
                    "Java gateway process exited before sending its port number"
                )

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen([
                    "cmd", "/c", "taskkill", "/f", "/t", "/pid",
                    str(proc.pid)
                ])

            atexit.register(killChild)

    # Connect to the gateway (or client server to pin the thread between JVM and Python)
    if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
        gateway = ClientServer(
            java_parameters=JavaParameters(port=gateway_port,
                                           auth_token=gateway_secret,
                                           auto_convert=True),
            python_parameters=PythonParameters(port=0, eager_load=False),
        )
    else:
        gateway = JavaGateway(gateway_parameters=GatewayParameters(
            port=gateway_port, auth_token=gateway_secret, auto_convert=True))

    # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
    gateway.proc = proc

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.resource.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
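A hedged sketch of the popen_kwargs hook mentioned in the docstring, used here to capture the spark-submit/JVM console output (assumes calling the function defined above against a working Spark installation):

import subprocess

gateway = launch_gateway(popen_kwargs={"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT})
print(gateway.jvm.java.lang.System.getProperty("java.version"))
# gateway.proc is the Popen handle stored above (None when PYSPARK_GATEWAY_PORT pointed at an
# existing gateway); its stdout stream now carries the child process output.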
Example #15
def main():
    if is_dataproc_VM():
        logging.info(
            "This is a Dataproc VM which should already have the GCS cloud connector installed. Exiting..."
        )
        return

    args = parse_args()

    spark_home = _find_spark_home()

    # download GCS connector jar
    local_jar_path = os.path.join(spark_home, "jars",
                                  os.path.basename(GCS_CONNECTOR_URL))
    try:
        logging.info(f"Downloading {GCS_CONNECTOR_URL}")
        logging.info(f"   to {local_jar_path}")
        urllib.request.urlretrieve(GCS_CONNECTOR_URL, local_jar_path)
    except Exception as e:
        logging.error(
            f"Unable to download GCS connector to {local_jar_path}. {e}")
        return

    # update spark-defaults.conf
    spark_config_dir = os.path.join(spark_home, "conf")
    if not os.path.exists(spark_config_dir):
        os.mkdir(spark_config_dir)
    spark_config_file_path = os.path.join(spark_config_dir,
                                          "spark-defaults.conf")
    logging.info(f"Updating {spark_config_file_path} json.keyfile")
    logging.info(f"Setting json.keyfile = {args.key_file_path}")

    spark_config_lines = [
        "spark.hadoop.google.cloud.auth.service.account.enable true\n",
        f"spark.hadoop.google.cloud.auth.service.account.json.keyfile {args.key_file_path}\n",
    ]

    if args.gcs_requester_pays_project:
        spark_config_lines.extend([
            "spark.hadoop.fs.gs.requester.pays.mode AUTO\n",
            f"spark.hadoop.fs.gs.requester.pays.project.id {args.gcs_requester_pays_project}\n",
        ])

    try:
        # spark hadoop options docs @ https://github.com/GoogleCloudDataproc/hadoop-connectors/blob/master/gcs/CONFIGURATION.md#cloud-storage-requester-pays-feature-configuration
        if os.path.isfile(spark_config_file_path):
            with open(spark_config_file_path, "rt") as f:
                for line in f:
                    # avoid duplicating options
                    if any([
                            option.split(' ')[0] in line
                            for option in spark_config_lines
                    ]):
                        continue

                    spark_config_lines.append(line)

        with open(spark_config_file_path, "wt") as f:
            for line in spark_config_lines:
                f.write(line)

    except Exception as e:
        logging.error(
            f"Unable to update spark config {spark_config_file_path}. {e}")
        return
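parse_args() is not shown in this excerpt; a hypothetical version matching the attributes the function uses (args.key_file_path, args.gcs_requester_pays_project) might look like this:

import argparse


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--key-file-path", required=True,
                   help="Path to the GCP service account JSON key file")
    p.add_argument("--gcs-requester-pays-project",
                   help="GCP project to bill for requester-pays bucket access")
    return p.parse_args()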
Example #16
def launch_gateway(conf=None):
    """
    Launch a JVM gateway for Spark and connect to it via Py4J.

    :param conf: Spark configuration passed to spark-submit
    :return: a connected Py4J JavaGateway
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join(
                ["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        callback_socket.bind(('127.0.0.1', 0))
        callback_socket.listen(1)
        callback_host, callback_port = callback_socket.getsockname()
        env = dict(os.environ)
        env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
        env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

        # Launch the Java gateway.
        # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
        if not on_windows:
            # Don't send ctrl-c / SIGINT to the Java gateway:
            def preexec_func():
                signal.signal(signal.SIGINT, signal.SIG_IGN)

            proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
        else:
            # preexec_fn not supported on Windows
            proc = Popen(command, stdin=PIPE, env=env)

        gateway_port = None
        # We use select() here in order to avoid blocking indefinitely if the subprocess dies
        # before connecting
        while gateway_port is None and proc.poll() is None:
            timeout = 1  # (seconds)
            readable, _, _ = select.select([callback_socket], [], [], timeout)
            if callback_socket in readable:
                gateway_connection = callback_socket.accept()[0]
                # Determine which ephemeral port the server started on:
                gateway_port = read_int(gateway_connection.makefile(mode="rb"))
                gateway_connection.close()
                callback_socket.close()
        if gateway_port is None:
            raise Exception(
                "Java gateway process exited before sending the driver its port number"
            )

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen([
                    "cmd", "/c", "taskkill", "/f", "/t", "/pid",
                    str(proc.pid)
                ])

            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
Example #17
    def run(self):
        install.run(self)

        if is_dataproc_VM():
            self.announce(
                "Running on Dataproc VM. Skipping GCS cloud connector installation.",
                level=3)
            return  # cloud connector is installed automatically on dataproc VMs

        spark_home = _find_spark_home()

        # download GCS connector jar
        local_jar_path = os.path.join(spark_home, "jars",
                                      os.path.basename(GCS_CONNECTOR_URL))
        try:
            self.announce("Downloading %s to %s" %
                          (GCS_CONNECTOR_URL, local_jar_path),
                          level=3)
            urllib.request.urlretrieve(GCS_CONNECTOR_URL, local_jar_path)
        except Exception as e:
            self.warn("Unable to download GCS connector to %s. %s" %
                      (local_jar_path, e))
            return

        # look for existing key files in the ~/.config. If there's more than one, select the newest.
        try:
            key_file_regexp = "~/.config/gcloud/legacy_credentials/*/adc.json"
            key_file_sort = lambda file_path: -1 * os.path.getctime(file_path)
            key_file_path = next(
                iter(
                    sorted(glob.glob(os.path.expanduser(key_file_regexp)),
                           key=key_file_sort)))
            self.announce("Using key file: %s" % key_file_path, level=3)
        except Exception as e:
            self.warn("No keys found in %s. %s" % (key_file_regexp, e))
            key_file_path = None

        # if existing keys not found, download generic key that allows access to public (bucket-owner-pays) buckets.
        if key_file_path is None:
            local_key_dir = os.path.expanduser("~/.hail/gcs-keys")
            try:
                if not os.path.exists(local_key_dir):
                    os.makedirs(local_key_dir)
            except Exception as e:
                self.warn("Unable to create directory %s. %s" %
                          (local_key_dir, e))
                return

            key_file_path = os.path.join(local_key_dir,
                                         "gcs-connector-key.json")
            try:
                self.announce("Downloading %s to %s" %
                              (GENERIC_KEY_FILE_URL, key_file_path),
                              level=3)
                urllib.request.urlretrieve(GENERIC_KEY_FILE_URL, key_file_path)
            except Exception as e:
                self.warn("Unable to download shared key from %s to %s. %s" %
                          (GENERIC_KEY_FILE_URL, key_file_path, e))
                return

        # update spark-defaults.conf
        spark_config_dir = os.path.join(spark_home, "conf")
        if not os.path.exists(spark_config_dir):
            os.mkdir(spark_config_dir)
        spark_config_file_path = os.path.join(spark_config_dir,
                                              "spark-defaults.conf")
        self.announce("Setting json.keyfile to %s in %s" %
                      (key_file_path, spark_config_file_path),
                      level=3)

        spark_config_lines = [
            "spark.hadoop.google.cloud.auth.service.account.enable true\n",
            "spark.hadoop.google.cloud.auth.service.account.json.keyfile %s\n"
            % key_file_path,
        ]
        try:
            if os.path.isfile(spark_config_file_path):
                with open(spark_config_file_path, "rt") as f:
                    for line in f:
                        if "spark.hadoop.google.cloud.auth.service.account.enable" in line:
                            continue
                        if "spark.hadoop.google.cloud.auth.service.account.json.keyfile" in line:
                            continue

                        spark_config_lines.append(line)

            with open(spark_config_file_path, "wt") as f:
                for line in spark_config_lines:
                    f.write(line)

        except Exception as e:
            self.warn("Unable to update spark config %s. %s" %
                      (spark_config_file_path, e))
            return
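This run() override belongs to a custom setuptools install command; a hedged sketch of how such a class is typically wired into setup.py (class and package names here are placeholders, not the original project's):

from setuptools import setup
from setuptools.command.install import install


class PostInstallCommand(install):
    def run(self):
        install.run(self)
        # ... GCS connector download and spark-defaults.conf update, as in the run() above ...


setup(
    name="my-package",  # placeholder
    cmdclass={"install": PostInstallCommand},
)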