def connect(profile):
    """Create a Databricks API client for the given CLI profile.

    Args:
        profile (str): Databricks CLI profile string

    Returns:
        ApiClient: Databricks ApiClient object
    """
    config = ProfileConfigProvider(profile).get_config()
    if config is None:
        print_error("Cannot initialize ApiClient")
        bye(1)

    # TLS verification stays on unless the profile explicitly sets "insecure".
    verify = config.insecure is None

    if config.is_valid_with_token:
        client = ApiClient(host=config.host, token=config.token, verify=verify)
        # Tag all requests so they are attributable to this tool.
        client.default_headers["user-agent"] = "databrickslabs-jupyterlab-%s" % __version__
        return client

    # Only token-based auth is supported; anything else is a hard stop.
    print_error(
        "No token found for profile '%s'.\nUsername/password in .databrickscfg is not supported by databrickslabs-jupyterlab"
        % profile)
    bye(1)
def get_db_config(profile):
    """Get Databricks configuration from ~/.databrickscfg for given profile

    Args:
        profile (str): Databricks CLI profile string

    Returns:
        tuple: The tuple of host and personal access token from ~/.databrickscfg
    """
    config = configparser.ConfigParser()
    # config.read returns the list of files successfully parsed; empty means unreadable/missing.
    configs = config.read(expanduser("~/.databrickscfg"))
    if not configs:
        print_error("Cannot read ~/.databrickscfg")
        bye(1)

    profiles = config.sections()
    if profile not in profiles:  # fixed: `not profile in` -> idiomatic `not in`
        print(" The profile '%s' is not available in ~/.databrickscfg:" % profile)
        for p in profiles:
            print("- %s" % p)
        bye()
    else:
        host = config[profile]["host"]
        token = config[profile]["token"]
        return host, token
def get_cluster(profile, cluster_id=None, status=None):
    """Get the cluster configuration from remote

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str, optional): If cluster_id is given, the user will not be asked
            to select one. Defaults to None.
        status (Status, optional): A Status class providing set_status. Defaults to None.

    Returns:
        tuple: (cluster_id, public_ip, cluster_name, started); all None on error
    """
    # Fixed: `open()` is what raises when the key file is missing, so it must be
    # inside the try (the original only wrapped `fd.read()`, so the intended
    # error message could never fire). Also narrowed the bare `except:`.
    try:
        with open("%s/.ssh/id_%s.pub" % (expanduser("~"), profile)) as fd:
            ssh_pub = fd.read().strip()
    except OSError:
        print_error(" Error: ssh key for profile 'id_%s.pub' does not exist in %s/.ssh" %
                    (profile, expanduser("~")))
        bye()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
        clusters = client.list_clusters()
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return (None, None, None, None)

    clusters = clusters["clusters"]

    if cluster_id is not None:
        # Explicit cluster requested: it must exist and have our ssh key configured.
        cluster = None
        for c in clusters:
            if c["cluster_id"] == cluster_id:
                cluster = c
                break
        if cluster is None:
            print_error(
                " Error: A cluster with id '%s' does not exist in the workspace of profile '%s'"
                % (cluster_id, profile))
            return (None, None, None, None)
        if ssh_pub not in [c.strip() for c in cluster.get("ssh_public_keys", [])]:
            print_error(
                " Error: Cluster with id '%s' does not have ssh key '~/.ssh/id_%s' configured"
                % (cluster_id, profile))
            return (None, None, None, None)
    else:
        # Interactive selection among clusters that already carry our ssh key.
        my_clusters = [
            cluster for cluster in clusters
            if ssh_pub in [c.strip() for c in cluster.get("ssh_public_keys", [])]
        ]
        if not my_clusters:
            print_error(
                " Error: There is no cluster in the workspace for profile '%s' configured with ssh key '~/.ssh/id_%s':"
                % (profile, profile))
            print(
                " Use 'databrickslabs_jupyterlab %s -s' to configure ssh for clusters in this workspace\n"
                % profile)
            return (None, None, None, None)

        # Hint at the cluster whose name matches the active conda environment.
        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV", None)
        found = None
        for i, c in enumerate(my_clusters):
            if c["cluster_name"].replace(" ", "_") == current_conda_env:
                found = c["cluster_name"]
                break
        if found is not None:
            print_warning(
                "\n => The current conda environment is '%s'.\n You might want to select cluster %d with the name '%s'?\n"
                % (current_conda_env, i, found))

        cluster = select_cluster(my_clusters)

    cluster_id = cluster["cluster_id"]
    cluster_name = cluster["cluster_name"]

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return (None, None, None, None)

    state = response["state"]
    if state not in ["RUNNING", "RESIZING"]:  # fixed: `not state in` -> `state not in`
        if state == "TERMINATED":
            print(" => Starting cluster %s" % cluster_id)
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting")
            try:
                response = client.start_cluster(cluster_id)
            except Exception as ex:  # pylint: disable=broad-except
                print_error(ex)
                return (None, None, None, None)

        print(" => Waiting for cluster %s being started (this can take up to 5 min)" %
              cluster_id)
        print(" ", end="", flush=True)

        # Poll every 5s until the cluster reports a running/resizing state.
        while state not in ("RUNNING", "RESIZING"):
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting", False)
            print(".", end="", flush=True)
            time.sleep(5)
            try:
                response = client.get_cluster(cluster_id)
            except Exception as ex:  # pylint: disable=broad-except
                print_error(ex)
                return (None, None, None, None)
            if response.get("error", None) is not None:
                print_error(response["error"])
                return (None, None, None, None)
            else:
                state = response["state"]
        print_ok("\n => OK")

        if status is not None:
            status.set_status(profile, cluster_id, "Cluster started")

        print(
            "\n => Waiting for libraries on cluster %s being installed (this can take some time)"
            % cluster_id)
        print(" ", end="", flush=True)

        # Poll library state until nothing is pending/resolving/installing.
        done = False
        while not done:
            try:
                states = get_library_state(profile, cluster_id)
            except DatabricksApiException as ex:
                print_error(ex)
                return (None, None, None, None)
            installing = any(s in ["PENDING", "RESOLVING", "INSTALLING"] for s in states)
            if installing:
                if status is not None:
                    status.set_status(profile, cluster_id, "Installing cluster libraries",
                                      False)
                print(".", end="", flush=True)
                time.sleep(5)
            else:
                done = True
        print_ok("\n => OK\n")

        if status is not None:
            status.set_status(profile, cluster_id, "Cluster libraries installed", False)

    public_ip = response["driver"].get("public_dns", None)
    if public_ip is None:
        print_error(" Error: Cluster does not have public DNS name")
        return (None, None, None, None)

    print_ok(" => Selected cluster: %s (%s)" % (cluster_name, public_ip))
    return (cluster_id, public_ip, cluster_name, None)
# NOTE(review): this 4-argument configure_ssh is shadowed by the 2-argument
# redefinition further down in this file — at import time only the later one
# survives. Consider deleting this version; `host` and `token` are unused here.
def configure_ssh(profile, host, token, cluster_id):
    """Configure SSH for the remote cluster

    Args:
        profile (str): Databricks CLI profile string
        host (str): host from databricks cli config for given profile string
        token (str): token from databricks cli config for given profile string
        cluster_id (str): cluster ID
    """
    # Ensure a profile-specific keypair exists (create interactively if not).
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if not os.path.exists(sshkey_file):
        print("\n => ssh key '%s' does not exist" % sshkey_file)
        answer = input(" => Shall it be created (y/n)? (default = n): ")
        if answer.lower() == "y":
            print(" => Creating ssh key %s" % sshkey_file)
            # No passphrase (-N "") so the key can be used non-interactively.
            result = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if result["returncode"] == 0:
                print_ok(" => OK")
            else:
                print_error(result["stderr"])
                bye()
        else:
            bye()
    else:
        print_ok("\n => ssh key '%s' already exists" % sshkey_file)

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
    except Exception as ex:
        print_error(ex)
        return None
    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:
        print_error(ex)
        return None

    # Nothing to do if the key is already registered on the cluster.
    ssh_public_keys = response.get("ssh_public_keys", [])
    if sshkey in [key.strip() for key in ssh_public_keys]:
        print_ok(" => public ssh key already configured for cluster %s" % cluster_id)
        bye()

    # Build the edit request from a fixed whitelist of cluster attributes.
    # NOTE(review): response[key] raises KeyError if any whitelisted attribute is
    # absent; the later redefinition below uses a blacklist instead — TODO confirm.
    request = {}
    for key in [
            "autotermination_minutes",
            "cluster_id",
            "cluster_name",
            "cluster_source",
            "creator_user_name",
            "default_tags",
            "driver_node_type_id",
            "enable_elastic_disk",
            "init_scripts_safe_mode",
            "node_type_id",
            "spark_version",
    ]:
        request[key] = response[key]

    # Optional attributes are copied only when present.
    if response.get("spark_env_vars", None) is not None:
        request["spark_env_vars"] = response["spark_env_vars"]
    if response.get("aws_attributes", None) is not None:
        request["aws_attributes"] = response["aws_attributes"]
    if response.get("num_workers", None) is not None:
        request["num_workers"] = response["num_workers"]
    if response.get("autoscale", None) is not None:
        request["autoscale"] = response["autoscale"]

    # Append (never replace) the new public key.
    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning(
        " => The ssh key will be added to the cluster. \n Note: The cluster will be restarted immediately!"
    )
    answer = input(
        " => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() == "y":
        try:
            response = client.edit_cluster(request)
        except DatabricksApiException as ex:
            print_error(str(ex))
            return None
        print_ok(" => OK")
    else:
        print_error(" => Cancelled")
def configure_ssh(profile, cluster_id):
    """Configure SSH for the remote cluster: register the profile's public key
    on the cluster and restart it.

    Fixed: the docstring previously documented `host` and `token` parameters
    that are not part of this signature.

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str): cluster ID
    """
    # Ensure a profile-specific keypair exists (create interactively if not).
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if not os.path.exists(sshkey_file):
        print("\n => ssh key '%s' does not exist" % sshkey_file)
        answer = input(" => Shall it be created (y/n)? (default = n): ")
        if answer.lower() == "y":
            print(" => Creating ssh key %s" % sshkey_file)
            # No passphrase (-N "") so the key can be used non-interactively.
            result = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if result["returncode"] == 0:
                print_ok(" => OK")
            else:
                print_error(result["stderr"])
                bye()
        else:
            bye()
    else:
        print_ok("\n => ssh key '%s' already exists" % sshkey_file)

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None
    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    # Nothing to do if the key is already registered on the cluster.
    ssh_public_keys = response.get("ssh_public_keys", [])
    if sshkey in [key.strip() for key in ssh_public_keys]:
        print_ok(" => public ssh key already configured for cluster %s" % cluster_id)
        bye()

    # Runtime-only attributes must not be echoed back in an edit request;
    # hoisted into a frozenset so membership is O(1) per key.
    runtime_attributes = frozenset([
        "driver",
        "executors",
        "spark_context_id",
        "state",
        "state_message",
        "start_time",
        "terminated_time",
        "last_state_loss_time",
        "last_activity_time",
        "disk_spec",
    ])
    request = {
        key: val
        for key, val in response.items() if key not in runtime_attributes
    }  # omit runtime attributes

    # Append (never replace) the new public key.
    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning(" => The ssh key will be added to the cluster." +
                  "\n Note: The cluster will be restarted immediately!")
    answer = input(
        " => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() == "y":
        try:
            response = client.edit_cluster(request)
        except DatabricksApiException as ex:
            print_error(str(ex))
            return None
        print_ok(" => OK")
    else:
        print_error(" => Cancelled")
def install(profile, host, token, cluster_id, cluster_name, use_whitelist):
    """Create a local conda environment mirroring the remote cluster's packages.

    Args:
        profile (str): Databricks CLI profile string
        host (str): Databricks workspace host URL
        token (str): personal access token for the workspace
        cluster_id (str): cluster ID to mirror
        cluster_name (str): cluster name (default conda env name, spaces -> "_")
        use_whitelist (bool): if True select packages via WHITELIST, else
            exclude packages via BLACKLIST
    """
    print(
        "\n* Installation of local environment to mirror a remote Databricks cluster"
    )
    result = get_remote_packages(cluster_id, host, token)
    if result[0] != 0:
        print_error(result[1])
        bye(1)
    libs = json.loads(result[1])

    if use_whitelist:
        print_ok(" => Using whitelist to select packages")
        ds_libs = [lib for lib in libs if lib["name"].lower() in WHITELIST]
    else:
        print_ok(" => Using blacklist to select packages")
        ds_libs = [lib for lib in libs if lib["name"].lower() not in BLACKLIST]

    # Build the "Data Science Libs" section; some packages carry local version
    # suffixes that conda/pip would reject, so strip them down to plain versions.
    # NOTE(review): python_version stays unbound if no "python" entry is present
    # in the remote package list — TODO confirm get_remote_packages always adds it.
    ds_yml = ""
    for lib in ds_libs:
        if lib["name"] == "python":  # just artificially added
            python_version = lib["version"]
        else:
            if lib["name"] in ["hyperopt", "torchvision"]:
                # fixed: middle dot was unescaped (r"\d+\.\d+.\d+" matched any char)
                r = re.compile(r"(\d+\.\d+\.\d+)(.*)")
                version = r.match(lib["version"]).groups()[0]
            elif lib["name"] in ["tensorboardx"]:
                r = re.compile(r"(\d+\.\d+)(.*)")
                version = r.match(lib["version"]).groups()[0]
            else:
                version = lib["version"]
            ds_yml += " - %s==%s\n" % (lib["name"], version)

    module_path = os.path.dirname(databrickslabs_jupyterlab.__file__)
    env_file = os.path.join(module_path, "lib", "env.yml")
    with open(env_file, "r") as fd:
        master_yml = fd.read()

    # Pin the remote python version right after the "dependencies" key.
    lines = master_yml.split("\n")
    for i, line in enumerate(lines):  # fixed: range(len(...)) -> enumerate
        if line.startswith("dependencies"):
            lines.insert(i + 1, " - python=%s" % python_version)
            break
    master_yml = "\n".join(lines)

    print("\n Installed environment \n")
    print(master_yml)
    print("\n # Data Science Libs\n")
    print(ds_yml + "\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        env_file = os.path.join(tmpdir, "env.yml")
        with open(env_file, "w") as fd:
            fd.write(master_yml)
            fd.write("\n # Data Science Libs\n")
            fd.write(ds_yml)
            fd.write("\n")

        env_name = cluster_name.replace(" ", "_")
        answer = input(
            " => Provide a conda environment name (default = %s): " % env_name)
        if answer != "":
            env_name = answer.replace(" ", "_")
        # Must run while tmpdir (and env.yml) still exists.
        install_env(env_file, env_name)

    labext_file = os.path.join(module_path, "lib", "labextensions.txt")
    install_labextensions(labext_file, env_name)
    usage(env_name)