def connect(profile):
    """Initialize Databricks API client

    Args:
        profile (str): Databricks CLI profile string

    Returns:
        ApiClient: Databricks ApiClient object
    """
    config = ProfileConfigProvider(profile).get_config()
    if config is None:
        print_error("Cannot initialize ApiClient")
        bye(1)
    verify = config.insecure is None
    if config.is_valid_with_token:
        api_client = ApiClient(host=config.host,
                               token=config.token,
                               verify=verify)
        api_client.default_headers[
            "user-agent"] = "databrickslabs-jupyterlab-%s" % __version__
        return api_client
    else:
        print_error(
            "No token found for profile '%s'.\nUsername/password in .databrickscfg is not supported by databrickslabs-jupyterlab"
            % profile)
        bye(1)
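
A minimal usage sketch, assuming the hypothetical profile "demo" exists in ~/.databrickscfg and that connect() and ClusterApi are available as in the code above:

client = connect("demo")
clusters = ClusterApi(client).list_clusters().get("clusters", [])
for c in clusters:
    print(c["cluster_id"], c["cluster_name"])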
def get_db_config(profile):
    """Get Databricks configuration from ~/.databricks.cfg for given profile
    
    Args:
        profile (str): Databricks CLI profile string
    
    Returns:
        tuple: The tuple of host and personal access token from ~/.databrickscfg
    """
    config = configparser.ConfigParser()
    configs = config.read(expanduser("~/.databrickscfg"))
    if not configs:
        print_error("Cannot read ~/.databrickscfg")
        bye(1)

    profiles = config.sections()
    if profile not in profiles:
        print("   The profile '%s' is not available in ~/.databrickscfg:" %
              profile)
        for p in profiles:
            print("- %s" % p)
        bye()
    else:
        host = config[profile]["host"]
        token = config[profile]["token"]
        return host, token
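
get_db_config reads standard INI syntax. A sketch of a matching ~/.databrickscfg section (all values are hypothetical placeholders):

[demo]
host = https://example.cloud.databricks.com
token = dapi0123456789abcdef

With such a section in place, host, token = get_db_config("demo") returns the two values.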
def get_cluster(profile, cluster_id=None, status=None):
    """Get the cluster configuration from remote
    
    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str, optional): If cluster_id is given, the user will not be asked to select one.
                                    Defaults to None.
        status (Status, optional): A Status class providing set_status. Defaults to None.
    
    Returns:
        tuple: Cluster configs: cluster_id, public_ip, cluster_name and a flag whether it was started
    
    Returns:
        tuple: (cluster_id, public_ip, cluster_name, started)
    """
    try:
        with open("%s/.ssh/id_%s.pub" % (expanduser("~"), profile)) as fd:
            ssh_pub = fd.read().strip()
    except OSError:
        print_error(
            "   Error: ssh key for profile 'id_%s.pub' does not exist in %s/.ssh"
            % (profile, expanduser("~")))
        bye()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
        clusters = client.list_clusters()
    except Exception as ex:
        print_error(ex)
        return (None, None, None, None)

    clusters = clusters.get("clusters", [])

    if cluster_id is not None:
        cluster = None
        for c in clusters:
            if c["cluster_id"] == cluster_id:
                cluster = c
                break

        if cluster is None:
            print_error(
                "   Error: A cluster with id '%s' does not exist in the workspace of profile '%s'"
                % (cluster_id, profile))
            return (None, None, None, None)

        if ssh_pub not in [
                c.strip() for c in cluster.get("ssh_public_keys", [])
        ]:
            print_error(
                "   Error: Cluster with id '%s' does not have ssh key '~/.ssh/id_%s' configured"
                % (cluster_id, profile))
            return (None, None, None, None)
    else:
        my_clusters = [
            cluster for cluster in clusters if ssh_pub in
            [c.strip() for c in cluster.get("ssh_public_keys", [])]
        ]

        if not my_clusters:
            print_error(
                "    Error: There is no cluster in the workspace for profile '%s' configured with ssh key '~/.ssh/id_%s':"
                % (profile, profile))
            print(
                "    Use 'databrickslabs_jupyterlab %s -s' to configure ssh for clusters in this workspace\n"
                % profile)
            return (None, None, None, None)

        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV", None)
        found = None
        found_idx = None
        for i, c in enumerate(my_clusters):
            if c["cluster_name"].replace(" ", "_") == current_conda_env:
                found = c["cluster_name"]
                found_idx = i
                break

        if found is not None:
            print_warning(
                "\n   => The current conda environment is '%s'.\n      You might want to select cluster %d with the name '%s'?\n"
                % (current_conda_env, found_idx, found))

        cluster = select_cluster(my_clusters)

    cluster_id = cluster["cluster_id"]
    cluster_name = cluster["cluster_name"]

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:
        print_error(ex)
        return (None, None, None, None)

    state = response["state"]
    if state not in ("RUNNING", "RESIZING"):
        if state == "TERMINATED":
            print("   => Starting cluster %s" % cluster_id)
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting")
            try:
                response = client.start_cluster(cluster_id)
            except Exception as ex:
                print_error(ex)
                return (None, None, None, None)

        print(
            "   => Waiting for cluster %s to start (this can take up to 5 min)"
            % cluster_id)
        print("   ", end="", flush=True)

        while state not in ("RUNNING", "RESIZING"):
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting",
                                  False)
            print(".", end="", flush=True)
            time.sleep(5)
            try:
                response = client.get_cluster(cluster_id)
            except Exception as ex:
                print_error(ex)
                return (None, None, None, None)

            if response.get("error", None) is not None:
                print_error(response["error"])
                return (None, None, None, None)
            else:
                state = response["state"]
        print_ok("\n   => OK")

        if status is not None:
            status.set_status(profile, cluster_id, "Cluster started")

        print(
            "\n   => Waiting for libraries on cluster %s to be installed (this can take some time)"
            % cluster_id)
        print("   ", end="", flush=True)

        done = False
        while not done:
            try:
                states = get_library_state(profile, cluster_id)
            except DatabricksApiException as ex:
                print_error(ex)
                return (None, None, None, None)

            installing = any(
                s in ("PENDING", "RESOLVING", "INSTALLING") for s in states)
            if installing:
                if status is not None:
                    status.set_status(profile, cluster_id,
                                      "Installing cluster libraries", False)
                print(".", end="", flush=True)
                time.sleep(5)
            else:
                done = True
                print_ok("\n   => OK\n")

        if status is not None:
            status.set_status(profile, cluster_id,
                              "Cluster libraries installed", False)

    public_ip = response.get("driver", {}).get("public_dns", None)
    if public_ip is None:
        print_error("   Error: Cluster does not have public DNS name")
        return (None, None, None, None)

    print_ok("   => Selected cluster: %s (%s)" % (cluster_name, public_ip))

    return (cluster_id, public_ip, cluster_name, None)
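
A usage sketch for get_cluster, assuming the hypothetical profile "demo"; without a cluster_id the user is asked to select one interactively:

cluster_id, public_ip, cluster_name, _ = get_cluster("demo")
if cluster_id is None:
    bye(1)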
def configure_ssh(profile, host, token, cluster_id):
    """Configure SSH for the remote cluster
    
    Args:
        profile (str): Databricks CLI profile string
        host (str): host from databricks cli config for given profile string
        token (str): token from databricks cli config for given profile string
        cluster_id (str): cluster ID
    """
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if not os.path.exists(sshkey_file):
        print("\n   => ssh key '%s' does not exist" % sshkey_file)
        answer = input("   => Shall it be created (y/n)? (default = n): ")
        if answer.lower() == "y":
            print("   => Creating ssh key %s" % sshkey_file)
            result = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if result["returncode"] == 0:
                print_ok("   => OK")
            else:
                print_error(result["stderr"])
                bye()
        else:
            bye()
    else:
        print_ok("\n   => ssh key '%s' already exists" % sshkey_file)

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
    except Exception as ex:
        print_error(ex)
        return None

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:
        print_error(ex)
        return None

    ssh_public_keys = response.get("ssh_public_keys", [])
    if sshkey in [key.strip() for key in ssh_public_keys]:
        print_ok("   => public ssh key already configured for cluster %s" %
                 cluster_id)
        bye()

    request = {}
    for key in [
            "autotermination_minutes",
            "cluster_id",
            "cluster_name",
            "cluster_source",
            "creator_user_name",
            "default_tags",
            "driver_node_type_id",
            "enable_elastic_disk",
            "init_scripts_safe_mode",
            "node_type_id",
            "spark_version",
    ]:
        request[key] = response[key]

    if response.get("spark_env_vars", None) is not None:
        request["spark_env_vars"] = response["spark_env_vars"]

    if response.get("aws_attributes", None) is not None:
        request["aws_attributes"] = response["aws_attributes"]

    if response.get("num_workers", None) is not None:
        request["num_workers"] = response["num_workers"]

    if response.get("autoscale", None) is not None:
        request["autoscale"] = response["autoscale"]

    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning(
        "   => The ssh key will be added to the cluster. \n   Note: The cluster will be restarted immediately!"
    )
    answer = input(
        "   => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() == "y":
        try:
            response = client.edit_cluster(request)
        except DatabricksApiException as ex:
            print_error(str(ex))
            return None
        print_ok("   => OK")
    else:
        print_error("   => Cancelled")
def configure_ssh(profile, cluster_id):
    """Configure SSH for the remote cluster

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str): cluster ID
    """
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if not os.path.exists(sshkey_file):
        print("\n   => ssh key '%s' does not exist" % sshkey_file)
        answer = input("   => Shall it be created (y/n)? (default = n): ")
        if answer.lower() == "y":
            print("   => Creating ssh key %s" % sshkey_file)
            result = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if result["returncode"] == 0:
                print_ok("   => OK")
            else:
                print_error(result["stderr"])
                bye()
        else:
            bye()
    else:
        print_ok("\n   => ssh key '%s' already exists" % sshkey_file)

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    ssh_public_keys = response.get("ssh_public_keys", [])
    if sshkey in [key.strip() for key in ssh_public_keys]:
        print_ok("   => public ssh key already configured for cluster %s" %
                 cluster_id)
        bye()

    request = {}

    for key, val in response.items():
        if key not in [
                "driver",
                "executors",
                "spark_context_id",
                "state",
                "state_message",
                "start_time",
                "terminated_time",
                "last_state_loss_time",
                "last_activity_time",
                "disk_spec",
        ]:  # omit runtime attributes
            request[key] = val

    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning("   => The ssh key will be added to the cluster." +
                  "\n   Note: The cluster will be restarted immediately!")
    answer = input(
        "   => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() == "y":
        try:
            response = client.edit_cluster(request)
        except DatabricksApiException as ex:
            print_error(str(ex))
            return None
        print_ok("   => OK")
    else:
        print_error("   => Cancelled")
def install(profile, host, token, cluster_id, cluster_name, use_whitelist):
    """Create a local conda environment that mirrors a remote Databricks cluster

    Args:
        profile (str): Databricks CLI profile string
        host (str): host from databricks cli config for given profile string
        token (str): personal access token from databricks cli config for given profile string
        cluster_id (str): cluster ID
        cluster_name (str): cluster name (used as default conda environment name)
        use_whitelist (bool): if True, select packages via WHITELIST; otherwise exclude packages via BLACKLIST
    """
    print(
        "\n* Installation of local environment to mirror a remote Databricks cluster"
    )
    result = get_remote_packages(cluster_id, host, token)
    if result[0] != 0:
        print_error(result[1])
        bye(1)
    libs = json.loads(result[1])

    if use_whitelist:
        print_ok("   => Using whitelist to select packages")
        ds_libs = [lib for lib in libs if lib["name"].lower() in WHITELIST]
    else:
        print_ok("   => Using blacklist to select packages")
        ds_libs = [lib for lib in libs if lib["name"].lower() not in BLACKLIST]

    ds_yml = ""
    for lib in ds_libs:
        if lib["name"] == "python":  # just artificially added
            python_version = lib["version"]
        else:
            if lib["name"] in ["hyperopt", "torchvision"]:
                r = re.compile(r"(\d+\.\d+.\d+)(.*)")
                version = r.match(lib["version"]).groups()[0]
            elif lib["name"] in ["tensorboardx"]:
                r = re.compile(r"(\d+\.\d+)(.*)")
                version = r.match(lib["version"]).groups()[0]
            else:
                version = lib["version"]
            ds_yml += "    - %s==%s\n" % (lib["name"], version)

    module_path = os.path.dirname(databrickslabs_jupyterlab.__file__)
    env_file = os.path.join(module_path, "lib", "env.yml")
    with open(env_file, "r") as fd:
        master_yml = fd.read()
    lines = master_yml.split("\n")
    for i, line in enumerate(lines):
        if line.startswith("dependencies"):
            lines.insert(i + 1, "  - python=%s" % python_version)
            break
    master_yml = "\n".join(lines)

    print("\n Installed environment \n")
    print(master_yml)
    print("\n    # Data Science Libs\n")
    print(ds_yml + "\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        env_file = os.path.join(tmpdir, "env.yml")
        with open(env_file, "w") as fd:
            fd.write(master_yml)
            fd.write("\n    # Data Science Libs\n")
            fd.write(ds_yml)
            fd.write("\n")

        env_name = cluster_name.replace(" ", "_")
        answer = input(
            "    => Provide a conda environment name (default = %s): " %
            env_name)
        if answer != "":
            env_name = answer.replace(" ", "_")

        install_env(env_file, env_name)

        labext_file = os.path.join(module_path, "lib", "labextensions.txt")
        install_labextensions(labext_file, env_name)

    usage(env_name)
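
An end-to-end usage sketch combining the helpers above; the profile name is a hypothetical placeholder:

host, token = get_db_config("demo")
cluster_id, public_ip, cluster_name, _ = get_cluster("demo")
if cluster_id is not None:
    install("demo", host, token, cluster_id, cluster_name, use_whitelist=True)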