def setup_local_spark():
    from pyspark import find_spark_home, SparkContext

    spark_python = os.path.join(find_spark_home._find_spark_home(), 'python')
    py4j = glob(os.path.join(spark_python, 'lib', 'py4j-*.zip'))[0]
    sys.path[:0] = [spark_python, py4j]
    _log.debug('sys.path: {p!r}'.format(p=sys.path))

    if 'TRAVIS' in os.environ:
        master_str = "local[2]"
    else:
        master_str = "local[*]"

    from geopyspark import geopyspark_conf
    conf = geopyspark_conf(master=master_str, appName="openeo-geotrellis-local")
    conf.set('spark.kryoserializer.buffer.max', value='1G')
    conf.set('spark.ui.enabled', True)
    # Some options to allow attaching a Java debugger to running Spark driver
    conf.set('spark.driver.extraJavaOptions', '-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=5009')

    if 'TRAVIS' in os.environ:
        conf.set(key='spark.driver.memory', value='2G')
        conf.set(key='spark.executor.memory', value='2G')

    if 'PYSPARK_PYTHON' not in os.environ:
        os.environ['PYSPARK_PYTHON'] = sys.executable

    _log.info('Creating Spark context with config:')
    for k, v in conf.getAll():
        _log.info("Spark config: {k!r}: {v!r}".format(k=k, v=v))
    pysc = SparkContext.getOrCreate(conf)
    _log.info('Created Spark Context {s}'.format(s=pysc))
    _log.info('Spark web UI: http://localhost:{p}/'.format(p=pysc.getConf().get('spark.ui.port') or 4040))

    return pysc
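For context, a minimal usage sketch (the fixture name and scope are illustrative, not from the original project): a helper like the one above is typically invoked once per test session, with the resulting SparkContext shared across tests.

import pytest

@pytest.fixture(scope="session")
def spark_context():
    # Hypothetical wiring: build the local SparkContext once and reuse it.
    return setup_local_spark()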
def test_find_spark_home(self):
    # SPARK-38827: Test find_spark_home without `SPARK_HOME` environment variable set.
    origin = os.environ["SPARK_HOME"]
    try:
        del os.environ["SPARK_HOME"]
        self.assertEqual(origin, _find_spark_home())
    finally:
        os.environ["SPARK_HOME"] = origin
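As a rough illustration of the behaviour this test exercises, here is a simplified sketch (not the actual pyspark implementation) of resolving a Spark home when `SPARK_HOME` is unset: prefer the environment variable, otherwise fall back to the location of the installed pyspark package.

import os

def find_spark_home_sketch():
    # Prefer an explicitly configured SPARK_HOME.
    if "SPARK_HOME" in os.environ:
        return os.environ["SPARK_HOME"]
    # Otherwise derive it from where the pyspark package is installed.
    import pyspark
    return os.path.dirname(os.path.abspath(pyspark.__file__))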
def require_test_compiled() -> None:
    """Raise Exception if test classes are not compiled"""
    import os
    import glob

    test_class_path = os.path.join(
        _find_spark_home(), "sql", "core", "target", "*", "test-classes")
    paths = glob.glob(test_class_path)

    if len(paths) == 0:
        raise RuntimeError(
            "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path
        )
def _ensure_geopyspark(out: TerminalReporter):
    """Make sure GeoPySpark knows where to find Spark (SPARK_HOME) and py4j"""
    try:
        import geopyspark
        out.write_line("Succeeded to import geopyspark automatically: {p!r}".format(p=geopyspark))
    except KeyError as e:
        # Geopyspark failed to detect Spark home and py4j, let's fix that.
        from pyspark import find_spark_home
        pyspark_home = Path(find_spark_home._find_spark_home())
        out.write_line("Failed to import geopyspark automatically. "
                       "Will set up py4j path using Spark home: {h}".format(h=pyspark_home))
        py4j_zip = next((pyspark_home / 'python' / 'lib').glob('py4j-*-src.zip'))
        out.write_line("py4j zip: {z!r}".format(z=py4j_zip))
        sys.path.append(str(py4j_zip))
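Presumably this helper runs from a pytest hook that has access to the terminal reporter; a hedged sketch of that wiring (the hook body is illustrative, not taken from the original conftest):

def pytest_configure(config):
    # Standard pytest hook: fetch the terminal reporter plugin and hand it to the
    # helper so its progress messages show up in the test session output.
    terminal_reporter = config.pluginmanager.get_plugin("terminalreporter")
    _ensure_geopyspark(terminal_reporter)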
def setUpClass(cls):
    import glob
    from pyspark.find_spark_home import _find_spark_home

    SPARK_HOME = _find_spark_home()
    filename_pattern = (
        "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
        "SparkSessionExtensionSuite.class")
    if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
        raise unittest.SkipTest(
            "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
            "available. Will skip the related tests.")

    # Note that 'spark.sql.extensions' is a static immutable configuration.
    cls.spark = (SparkSession.builder.master("local[4]").appName(cls.__name__)
                 .config("spark.sql.extensions", "org.apache.spark.sql.MyExtensions")
                 .getOrCreate())
def setUpClass(cls):
    import glob
    from pyspark.find_spark_home import _find_spark_home

    SPARK_HOME = _find_spark_home()
    filename_pattern = (
        "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
        "TestQueryExecutionListener.class")
    cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))

    if cls.has_listener:
        # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
        cls.spark = SparkSession.builder \
            .master("local[4]") \
            .appName(cls.__name__) \
            .config(
                "spark.sql.queryExecutionListeners",
                "org.apache.spark.sql.TestQueryExecutionListener") \
            .getOrCreate()
def launch_gateway(conf=None):
    """
    launch jvm gateway
    :param conf: spark configuration passed to spark-submit
    :return:
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]

        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join(["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            while not proc.poll() and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
                                             auto_convert=True))

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
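For illustration, a hedged sketch of the connection-info handshake the function above relies on. Judging only from the reader side (`read_int` followed by `UTF8Deserializer().loads`), the JVM appears to write a 4-byte big-endian port and then a length-prefixed UTF-8 secret; the file path and secret below are made up for the example.

import struct

def write_conn_info(path, port, secret):
    # Mimics what the JVM side appears to write: port, then length-prefixed secret.
    payload = secret.encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("!i", port))
        f.write(struct.pack("!i", len(payload)))
        f.write(payload)

def read_conn_info(path):
    # Mirrors read_int() / UTF8Deserializer().loads() on the Python side.
    with open(path, "rb") as f:
        port = struct.unpack("!i", f.read(4))[0]
        length = struct.unpack("!i", f.read(4))[0]
        secret = f.read(length).decode("utf-8")
    return port, secret

write_conn_info("/tmp/conn_info_example", 25333, "not-a-real-secret")
print(read_conn_info("/tmp/conn_info_example"))  # (25333, 'not-a-real-secret')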
def launch_gateway(conf=None):
    """
    launch jvm gateway
    :param conf: spark configuration passed to spark-submit
    :return:
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]

        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join(["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        callback_socket.bind(('127.0.0.1', 0))
        callback_socket.listen(1)
        callback_host, callback_port = callback_socket.getsockname()
        env = dict(os.environ)
        env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
        env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

        # Launch the Java gateway.
        # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
        if not on_windows:
            # Don't send ctrl-c / SIGINT to the Java gateway:
            def preexec_func():
                signal.signal(signal.SIGINT, signal.SIG_IGN)
            proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
        else:
            # preexec_fn not supported on Windows
            proc = Popen(command, stdin=PIPE, env=env)

        gateway_port = None
        # We use select() here in order to avoid blocking indefinitely if the subprocess dies
        # before connecting
        while gateway_port is None and proc.poll() is None:
            timeout = 1  # (seconds)
            readable, _, _ = select.select([callback_socket], [], [], timeout)
            if callback_socket in readable:
                gateway_connection = callback_socket.accept()[0]
                # Determine which ephemeral port the server started on:
                gateway_port = read_int(gateway_connection.makefile(mode="rb"))
                gateway_connection.close()
                callback_socket.close()
        if gateway_port is None:
            raise Exception("Java gateway process exited before sending the driver its port number")

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
#!/usr/bin/env python

import argparse
import multiprocessing
import os
import subprocess

try:
    from pyspark.find_spark_home import _find_spark_home
    default_spark_home = _find_spark_home()
except:
    default_spark_home = os.environ.get("SPARK_HOME")

p = argparse.ArgumentParser()
p.add_argument(
    "--spark-home",
    default=default_spark_home,
    help="The local spark directory (default: $SPARK_HOME). Required for --run-locally")
p.add_argument(
    "--cpu-limit",
    help="How many CPUs to use when running locally. Defaults to all available CPUs.",
    type=int)
p.add_argument("--driver-memory", help="Spark driver memory limit when running locally")
p.add_argument("--executor-memory", help="Spark executor memory limit when running locally")
p.add_argument("--num-executors", help="Spark number of executors",
# Default the feature flags to False so they are defined even when the imports fail.
have_scipy = False
have_numpy = False
try:
    import scipy.sparse  # noqa: F401
    have_scipy = True
except ImportError:
    # No SciPy, but that's okay, we'll skip those tests
    pass
try:
    import numpy as np  # noqa: F401
    have_numpy = True
except ImportError:
    # No NumPy, but that's okay, we'll skip those tests
    pass

SPARK_HOME = _find_spark_home()


def read_int(b):
    return struct.unpack("!i", b)[0]


def write_int(i):
    return struct.pack("!i", i)


def eventually(condition, timeout=30.0, catch_assertions=False):
    """
    Wait a given amount of time for a condition to pass, else fail with an error.
    This is a helper utility for PySpark tests.
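A quick check of the `read_int` / `write_int` helpers defined above: `write_int` packs an int into four big-endian bytes and `read_int` unpacks them again.

import struct

def write_int(i):
    return struct.pack("!i", i)

def read_int(b):
    return struct.unpack("!i", b)[0]

# Round-trip: 4040 survives packing and unpacking, and 1 packs to four bytes.
assert read_int(write_int(4040)) == 4040
assert write_int(1) == b"\x00\x00\x00\x01"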
def launch_gateway(conf=None, popen_kwargs=None):
    """
    launch jvm gateway

    Parameters
    ----------
    conf : :py:class:`pyspark.SparkConf`
        spark configuration passed to spark-submit
    popen_kwargs : dict
        Dictionary of kwargs to pass to Popen when spawning
        the py4j JVM. This is a developer feature intended for use in
        customizing how pyspark interacts with the py4j JVM (e.g., capturing
        stdout/stderr).

    Returns
    -------
    ClientServer or JavaGateway
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
        # Process already exists
        proc = None
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]

        if conf:
            for k, v in conf.getAll():
                command += ["--conf", "%s=%s" % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = " ".join(["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            popen_kwargs = {} if popen_kwargs is None else popen_kwargs
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            popen_kwargs["stdin"] = PIPE
            # We always set the necessary environment variables.
            popen_kwargs["env"] = env
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                popen_kwargs["preexec_fn"] = preexec_func
                proc = Popen(command, **popen_kwargs)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, **popen_kwargs)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            while not proc.poll() and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise RuntimeError(
                    "Java gateway process exited before sending its port number"
                )

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway (or client server to pin the thread between JVM and Python)
    if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
        gateway = ClientServer(
            java_parameters=JavaParameters(
                port=gateway_port, auth_token=gateway_secret, auto_convert=True),
            python_parameters=PythonParameters(port=0, eager_load=False),
        )
    else:
        gateway = JavaGateway(
            gateway_parameters=GatewayParameters(
                port=gateway_port, auth_token=gateway_secret, auto_convert=True))

    # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
    gateway.proc = proc

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.resource.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
def main():
    if is_dataproc_VM():
        logging.info("This is a Dataproc VM which should already have the GCS cloud connector installed. Exiting...")
        return

    args = parse_args()

    spark_home = _find_spark_home()

    # download GCS connector jar
    local_jar_path = os.path.join(spark_home, "jars", os.path.basename(GCS_CONNECTOR_URL))
    try:
        logging.info(f"Downloading {GCS_CONNECTOR_URL}")
        logging.info(f" to {local_jar_path}")
        urllib.request.urlretrieve(GCS_CONNECTOR_URL, local_jar_path)
    except Exception as e:
        logging.error(f"Unable to download GCS connector to {local_jar_path}. {e}")
        return

    # update spark-defaults.conf
    spark_config_dir = os.path.join(spark_home, "conf")
    if not os.path.exists(spark_config_dir):
        os.mkdir(spark_config_dir)
    spark_config_file_path = os.path.join(spark_config_dir, "spark-defaults.conf")

    logging.info(f"Updating {spark_config_file_path} json.keyfile")
    logging.info(f"Setting json.keyfile = {args.key_file_path}")
    spark_config_lines = [
        "spark.hadoop.google.cloud.auth.service.account.enable true\n",
        f"spark.hadoop.google.cloud.auth.service.account.json.keyfile {args.key_file_path}\n",
    ]
    if args.gcs_requester_pays_project:
        spark_config_lines.extend([
            "spark.hadoop.fs.gs.requester.pays.mode AUTO\n",
            f"spark.hadoop.fs.gs.requester.pays.project.id {args.gcs_requester_pays_project}\n",
        ])

    try:
        # spark hadoop options docs @ https://github.com/GoogleCloudDataproc/hadoop-connectors/blob/master/gcs/CONFIGURATION.md#cloud-storage-requester-pays-feature-configuration
        if os.path.isfile(spark_config_file_path):
            with open(spark_config_file_path, "rt") as f:
                for line in f:
                    # avoid duplicating options
                    if any([option.split(' ')[0] in line for option in spark_config_lines]):
                        continue
                    spark_config_lines.append(line)

        with open(spark_config_file_path, "wt") as f:
            for line in spark_config_lines:
                f.write(line)
    except Exception as e:
        logging.error(f"Unable to update spark config {spark_config_file_path}. {e}")
        return
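For illustration, these are the spark-defaults.conf entries the function above prepends, shown with a made-up key file path and requester-pays project id:

key_file_path = "/path/to/service-account-key.json"  # hypothetical
gcs_requester_pays_project = "my-gcp-project"        # hypothetical

print("".join([
    "spark.hadoop.google.cloud.auth.service.account.enable true\n",
    f"spark.hadoop.google.cloud.auth.service.account.json.keyfile {key_file_path}\n",
    "spark.hadoop.fs.gs.requester.pays.mode AUTO\n",
    f"spark.hadoop.fs.gs.requester.pays.project.id {gcs_requester_pays_project}\n",
]))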
def run(self):
    install.run(self)

    if is_dataproc_VM():
        self.announce("Running on Dataproc VM. Skipping GCS cloud connector installation.", level=3)
        return  # cloud connector is installed automatically on dataproc VMs

    spark_home = _find_spark_home()

    # download GCS connector jar
    local_jar_path = os.path.join(spark_home, "jars", os.path.basename(GCS_CONNECTOR_URL))
    try:
        self.announce("Downloading %s to %s" % (GCS_CONNECTOR_URL, local_jar_path), level=3)
        urllib.request.urlretrieve(GCS_CONNECTOR_URL, local_jar_path)
    except Exception as e:
        self.warn("Unable to download GCS connector to %s. %s" % (local_jar_path, e))
        return

    # look for existing key files in the ~/.config. If there's more than one, select the newest.
    try:
        key_file_regexp = "~/.config/gcloud/legacy_credentials/*/adc.json"
        key_file_sort = lambda file_path: -1 * os.path.getctime(file_path)
        key_file_path = next(
            iter(sorted(glob.glob(os.path.expanduser(key_file_regexp)), key=key_file_sort)))
        self.announce("Using key file: %s" % key_file_path, level=3)
    except Exception as e:
        self.warn("No keys found in %s. %s" % (key_file_regexp, e))
        key_file_path = None

    # if existing keys not found, download generic key that allows access to public (bucket-owner-pays) buckets.
    if key_file_path is None:
        local_key_dir = os.path.expanduser("~/.hail/gcs-keys")
        try:
            if not os.path.exists(local_key_dir):
                os.makedirs(local_key_dir)
        except Exception as e:
            self.warn("Unable to create directory %s. %s" % (local_key_dir, e))
            return

        key_file_path = os.path.join(local_key_dir, "gcs-connector-key.json")
        try:
            self.announce("Downloading %s to %s" % (GENERIC_KEY_FILE_URL, key_file_path), level=3)
            urllib.request.urlretrieve(GENERIC_KEY_FILE_URL, key_file_path)
        except Exception as e:
            self.warn("Unable to download shared key from %s to %s. %s" % (GENERIC_KEY_FILE_URL, key_file_path, e))
            return

    # update spark-defaults.conf
    spark_config_dir = os.path.join(spark_home, "conf")
    if not os.path.exists(spark_config_dir):
        os.mkdir(spark_config_dir)
    spark_config_file_path = os.path.join(spark_config_dir, "spark-defaults.conf")
    self.announce("Setting json.keyfile to %s in %s" % (key_file_path, spark_config_file_path), level=3)

    spark_config_lines = [
        "spark.hadoop.google.cloud.auth.service.account.enable true\n",
        "spark.hadoop.google.cloud.auth.service.account.json.keyfile %s\n" % key_file_path,
    ]
    try:
        if os.path.isfile(spark_config_file_path):
            with open(spark_config_file_path, "rt") as f:
                for line in f:
                    if "spark.hadoop.google.cloud.auth.service.account.enable" in line:
                        continue
                    if "spark.hadoop.google.cloud.auth.service.account.json.keyfile" in line:
                        continue
                    spark_config_lines.append(line)

        with open(spark_config_file_path, "wt") as f:
            for line in spark_config_lines:
                f.write(line)
    except Exception as e:
        self.warn("Unable to update spark config %s. %s" % (spark_config_file_path, e))
        return