def connect(profile):
    """Create a Databricks ApiClient for the given CLI profile

    The profile's configuration is loaded via ProfileConfigProvider. Only
    token based authentication is accepted; for any other configuration an
    error is printed and bye(1) is called.

    Args:
        profile (str): Databricks CLI profile string

    Returns:
        ApiClient: Databricks ApiClient object with a project specific
            "user-agent" default header
    """
    config = ProfileConfigProvider(profile).get_config()
    if config is None:
        print_error("Cannot initialize ApiClient")
        bye(1)

    # TLS verification stays enabled unless the profile sets "insecure"
    verify_tls = config.insecure is None

    if not config.is_valid_with_token:
        print_error(
            "No token found for profile '%s'.\nUsername/password in .databrickscfg is not supported by databrickslabs-jupyterlab"
            % profile)
        bye(1)
        return None

    client = ApiClient(host=config.host, token=config.token, verify=verify_tls)
    user_agent = "databrickslabs-jupyterlab-%s" % __version__
    client.default_headers["user-agent"] = user_agent
    return client
Exemple #2
0
def get_db_config(profile):
    """Read host and token of a profile from ~/.databrickscfg

    If the file cannot be read, an error is printed and bye(1) is called.
    If the profile is not a section of the file, the available profiles
    are listed and bye() is called.

    Args:
        profile (str): Databricks CLI profile string

    Returns:
        tuple: (host, token) as stored in the profile's section of
            ~/.databrickscfg
    """
    parser = configparser.ConfigParser()
    loaded_files = parser.read(expanduser("~/.databrickscfg"))
    if not loaded_files:
        print_error("Cannot read ~/.databrickscfg")
        bye(1)

    available = parser.sections()
    if profile in available:
        section = parser[profile]
        host = section["host"]
        token = section["token"]
        return host, token

    # Unknown profile: show the user which profiles could be used instead
    print(" The profile '%s' is not available in ~/.databrickscfg:" % profile)
    for name in available:
        print("- %s" % name)
    bye()
Exemple #3
0
def conda_version():
    """Return the version reported by "conda --version"

    When the command fails, its stderr is printed as an error and the
    process exits with status 1.

    Returns:
        str: second space separated field of the command's stdout
    """
    outcome = execute(["conda", "--version"])
    if outcome["returncode"] != 0:
        print_error(outcome["stderr"])
        sys.exit(1)

    fields = outcome["stdout"].strip().split(" ")
    return fields[1]
def get_local_libs():
    """Get installed libraries of the current conda environment

    Returns:
        The parsed JSON output of "conda list --json", or False if the
        command failed or its output could not be parsed (the error is
        printed in both cases)
    """
    outcome = execute(["conda", "list", "--json"])
    if outcome["returncode"] == 0:
        try:
            parsed = json.loads(outcome["stdout"])
        except Exception as ex:  # pylint: disable=broad-except
            print_error(ex)
            return False
        return parsed

    print_error(outcome["stderr"])
    return False
def version_check(cluster_id, host, token, flag):
    """Compare local and remote library versions and print a report

    Package names are compared case-insensitively with "-" and "_" treated
    as equal. Each reported line is printed via print_ok when both versions
    match and via print_error otherwise; a package missing on one side is
    shown as "--" and counts as a difference.

    Nothing is printed by this function when either the local or the remote
    package list cannot be retrieved.

    Args:
        cluster_id (str): Cluster ID
        host (str): host from databricks cli config for given profile string
        token (str): token from databricks cli config for given profile string
        flag (str): "all" reports every package, "same" only packages with
            identical versions, any other value only differing packages
    """
    def normalize(key):
        return key.lower().replace("-", "_")

    packages = get_local_libs()
    if packages is False:
        # get_local_libs signals failure with False (and has already printed
        # the error); iterating over it would raise a TypeError
        return
    deps = {normalize(p["name"]): p["version"] for p in packages}

    remote_result = get_remote_packages(cluster_id, host, token)
    if remote_result[0] != 0:
        return
    remote_packages = json.loads(remote_result[1])
    remote_deps = {normalize(p["name"]): p["version"] for p in remote_packages}

    joint_keys = sorted(set(deps) | set(remote_deps))

    # Same column layout as the rows below so that the header lines up
    print("%-30s %-10s  %-10s" % ("Package", "local", "remote"))

    selection = str(flag)
    if selection == "all":
        scope = joint_keys
    elif selection == "same":
        scope = [
            key for key in joint_keys if deps.get(key) == remote_deps.get(key)
        ]
    else:
        scope = [
            key for key in joint_keys if deps.get(key) != remote_deps.get(key)
        ]

    for key in scope:
        line = "%-30s %-10s  %-10s" % (key, deps.get(key, "--"),
                                       remote_deps.get(key, "--"))
        if deps.get(key) == remote_deps.get(key):
            print_ok(line)
        else:
            print_error(line)
def get_library_state(profile, cluster_id):
    """Get the state of the library installation on the remote cluster

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str): Cluster ID

    Returns:
        list: the "status" value of every entry in the cluster's
            "library_statuses", an empty list if the cluster reports no
            library statuses, or None if the request failed (the error is
            printed)
    """
    try:
        cluster_libs = LibrariesApi(connect(profile)).cluster_status(cluster_id)
    except Exception as ex:
        print_error(ex)
        return None

    entries = cluster_libs.get("library_statuses", None)
    if entries is None:
        return []
    return [entry["status"] for entry in entries]
def get_cluster(profile, cluster_id=None, status=None):
    """Get the cluster configuration from remote

    Only clusters that have the public key ~/.ssh/id_<profile>.pub configured
    are accepted. If the chosen cluster is neither RUNNING nor RESIZING, the
    function starts it when it is TERMINATED, then polls every 5 seconds until
    it is up and until no library is PENDING, RESOLVING or INSTALLING any more.

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str, optional): If cluster_id is given, the user will not be asked to select one.
                                    Defaults to None.
        status (Status, optional): A Status class providing set_status. Defaults to None.

    Returns:
        tuple: (cluster_id, public_ip, cluster_name, started). The fourth
            element is currently always None. On any error the error is
            printed and (None, None, None, None) is returned.
    """
    failure = (None, None, None, None)

    # open() itself has to be inside the try block: a missing key file raises
    # on open, not on read
    try:
        with open("%s/.ssh/id_%s.pub" % (expanduser("~"), profile)) as fd:
            ssh_pub = fd.read().strip()
    except (OSError, UnicodeError):
        print_error(
            "   Error: ssh key for profile 'id_%s.pub' does not exist in %s/.ssh"
            % (profile, expanduser("~")))
        bye()
        return failure

    def has_ssh_key(candidate):
        """Check whether the profile's public key is configured for the cluster"""
        return ssh_pub in [
            key.strip() for key in candidate.get("ssh_public_keys", [])
        ]

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
        clusters = client.list_clusters()
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return failure

    clusters = clusters["clusters"]

    if cluster_id is not None:
        cluster = None
        for c in clusters:
            if c["cluster_id"] == cluster_id:
                cluster = c
                break

        if cluster is None:
            print_error(
                "   Error: A cluster with id '%s' does not exist in the workspace of profile '%s'"
                % (cluster_id, profile))
            return failure

        if not has_ssh_key(cluster):
            print_error(
                "   Error: Cluster with id '%s' does not have ssh key '~/.ssh/id_%s' configured"
                % (cluster_id, profile))
            return failure
    else:
        my_clusters = [c for c in clusters if has_ssh_key(c)]

        if not my_clusters:
            print_error(
                "    Error: There is no cluster in the workspace for profile '%s' configured with ssh key '~/.ssh/id_%s':"
                % (profile, profile))
            print(
                "    Use 'databrickslabs_jupyterlab %s -s' to configure ssh for clusters in this workspace\n"
                % profile)
            return failure

        # Hint at the cluster whose name (blanks replaced by "_") equals the
        # name of the active conda environment
        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV", None)
        found = None
        for i, c in enumerate(my_clusters):
            if c["cluster_name"].replace(" ", "_") == current_conda_env:
                found = c["cluster_name"]
                break

        if found is not None:
            print_warning(
                "\n   => The current conda environment is '%s'.\n      You might want to select cluster %d with the name '%s'?\n"
                % (current_conda_env, i, found))

        cluster = select_cluster(my_clusters)

    cluster_id = cluster["cluster_id"]
    cluster_name = cluster["cluster_name"]

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return failure

    state = response["state"]
    if state not in ("RUNNING", "RESIZING"):
        if state == "TERMINATED":
            print("   => Starting cluster %s" % cluster_id)
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting")
            try:
                response = client.start_cluster(cluster_id)
            except Exception as ex:  # pylint: disable=broad-except
                print_error(ex)
                return failure

        print(
            "   => Waiting for cluster %s being started (this can take up to 5 min)"
            % cluster_id)
        print("   ", end="", flush=True)

        # NOTE(review): there is no timeout; the loop only ends when the
        # cluster reaches RUNNING/RESIZING or a request fails
        while state not in ("RUNNING", "RESIZING"):
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting",
                                  False)
            print(".", end="", flush=True)
            time.sleep(5)
            try:
                response = client.get_cluster(cluster_id)
            except Exception as ex:  # pylint: disable=broad-except
                print_error(ex)
                return failure

            if response.get("error", None) is not None:
                print_error(response["error"])
                return failure
            state = response["state"]
        print_ok("\n   => OK")

        if status is not None:
            status.set_status(profile, cluster_id, "Cluster started")

        print(
            "\n   => Waiting for libraries on cluster %s being installed (this can take some time)"
            % cluster_id)
        print("   ", end="", flush=True)

        done = False
        while not done:
            try:
                states = get_library_state(profile, cluster_id)
            except DatabricksApiException as ex:
                print_error(ex)
                return failure

            if states is None:
                # get_library_state reports failures by returning None after
                # printing the error
                return failure

            installing = any(
                s in ("PENDING", "RESOLVING", "INSTALLING") for s in states)
            if installing:
                if status is not None:
                    status.set_status(profile, cluster_id,
                                      "Installing cluster libraries", False)
                print(".", end="", flush=True)
                time.sleep(5)
            else:
                done = True
                print_ok("\n   => OK\n")

        if status is not None:
            status.set_status(profile, cluster_id,
                              "Cluster libraries installed", False)

    # A response without driver information is treated like a missing DNS name
    driver = response.get("driver", None) or {}
    public_ip = driver.get("public_dns", None)
    if public_ip is None:
        print_error("   Error: Cluster does not have public DNS name")
        return failure

    print_ok("   => Selected cluster: %s (%s)" % (cluster_name, public_ip))

    return (cluster_id, public_ip, cluster_name, None)
def configure_ssh(profile, host, token, cluster_id):
    """Configure SSH for the remote cluster

    Ensures the key pair ~/.ssh/id_<profile> exists (offering to create it
    with ssh-keygen), then adds the public key to the cluster's
    "ssh_public_keys" by editing the cluster after the user has confirmed.
    If the key is already configured for the cluster, bye() is called.

    Args:
        profile (str): Databricks CLI profile string
        host (str): host from databricks cli config for given profile string
            (not used by this function, the connection is created from profile)
        token (str): token from databricks cli config for given profile string
            (not used by this function, the connection is created from profile)
        cluster_id (str): cluster ID

    Returns:
        None
    """
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if not os.path.exists(sshkey_file):
        print("\n   => ssh key '%s' does not exist" % sshkey_file)
        answer = input("   => Shall it be created (y/n)? (default = n): ")
        if answer.lower() == "y":
            print("   => Creating ssh key %s" % sshkey_file)
            result = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if result["returncode"] == 0:
                print_ok("   => OK")
            else:
                print_error(result["stderr"])
                bye()
        else:
            bye()
    else:
        print_ok("\n   => ssh key '%s' already exists" % sshkey_file)

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    ssh_public_keys = response.get("ssh_public_keys", [])
    if sshkey in [key.strip() for key in ssh_public_keys]:
        print_ok("   => public ssh key already configured for cluster %s" %
                 cluster_id)
        bye()

    # Attributes taken over into the edit request. Not every cluster response
    # contains all of them, so only the ones actually present are copied
    # instead of failing with a KeyError.
    copied_keys = (
        "autotermination_minutes",
        "cluster_id",
        "cluster_name",
        "cluster_source",
        "creator_user_name",
        "default_tags",
        "driver_node_type_id",
        "enable_elastic_disk",
        "init_scripts_safe_mode",
        "node_type_id",
        "spark_version",
    )
    request = {key: response[key] for key in copied_keys if key in response}

    # These attributes are only taken over when they carry a value
    for key in ("spark_env_vars", "aws_attributes", "num_workers",
                "autoscale"):
        if response.get(key, None) is not None:
            request[key] = response[key]

    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning(
        "   => The ssh key will be added to the cluster. \n   Note: The cluster will be restarted immediately!"
    )
    answer = input(
        "   => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() == "y":
        try:
            client.edit_cluster(request)
        except DatabricksApiException as ex:
            print_error(str(ex))
            return None
        print_ok("   => OK")
    else:
        print_error("   => Cancelled")
    return None
def configure_ssh(profile, cluster_id):
    """Configure SSH for the remote cluster

    Ensures the key pair ~/.ssh/id_<profile> exists (offering to create it
    with ssh-keygen), then adds the public key to the cluster's
    "ssh_public_keys" by editing the cluster after the user has confirmed.
    If the key is already configured for the cluster, bye() is called.

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str): cluster ID

    Returns:
        None
    """
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if os.path.exists(sshkey_file):
        print_ok("\n   => ssh key '%s' already exists" % sshkey_file)
    else:
        print("\n   => ssh key '%s' does not exist" % sshkey_file)
        answer = input("   => Shall it be created (y/n)? (default = n): ")
        if answer.lower() != "y":
            bye()
        else:
            print("   => Creating ssh key %s" % sshkey_file)
            keygen = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if keygen["returncode"] != 0:
                print_error(keygen["stderr"])
                bye()
            else:
                print_ok("   => OK")

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        client = ClusterApi(connect(profile))
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    ssh_public_keys = response.get("ssh_public_keys", [])
    configured_keys = [key.strip() for key in ssh_public_keys]
    if sshkey in configured_keys:
        print_ok("   => public ssh key already configured for cluster %s" %
                 cluster_id)
        bye()

    # Runtime attributes of the cluster response that are left out of the
    # edit request
    runtime_attributes = {
        "driver",
        "executors",
        "spark_context_id",
        "state",
        "state_message",
        "start_time",
        "terminated_time",
        "last_state_loss_time",
        "last_activity_time",
        "disk_spec",
    }
    request = {
        key: val
        for key, val in response.items() if key not in runtime_attributes
    }
    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning("   => The ssh key will be added to the cluster."
                  "\n   Note: The cluster will be restarted immediately!")
    answer = input(
        "   => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() != "y":
        print_error("   => Cancelled")
        return None

    try:
        client.edit_cluster(request)
    except DatabricksApiException as ex:
        print_error(str(ex))
        return None
    print_ok("   => OK")
    return None