def add_known_host(address, port, known_hosts="~/.ssh/known_hosts"):
    """Replace the host key entry for address:port in a known_hosts file.

    Any existing entry for ``[address]:port`` is removed with ``ssh-keygen -R``,
    then the server's current key(s) are fetched with ``ssh-keyscan`` and
    appended to the same file.

    Args:
        address (str): Host name or IP address of the ssh server
        port (int|str): ssh port of the server
        known_hosts (str, optional): Path of the known_hosts file ("~" is expanded).
                                     Defaults to "~/.ssh/known_hosts".
    """
    known_hosts_file = os.path.expanduser(known_hosts)

    # Remove a stale entry from the same file we append to below; without "-f"
    # ssh-keygen would always edit the default ~/.ssh/known_hosts.
    # The result is ignored on purpose: there may be no entry (or no file) yet.
    execute(["ssh-keygen", "-f", known_hosts_file, "-R", "[%s]:%s" % (address, port)])

    result = execute(["ssh-keyscan", "-p", str(port), address])
    # ssh-keyscan can exit with 0 without printing any key, so check the output too
    fingerprint = (result["stdout"] or "").strip() if result["returncode"] == 0 else ""
    if fingerprint:
        with open(known_hosts_file, "a") as fd:
            # Leading newline guards against a file that does not end with one
            fd.write("\n%s\n" % fingerprint)
        print("   => Known hosts fingerprint added for %s:%s\n" % (address, port))
    else:
        print_warning("   => Could not add known_hosts fingerprint for %s:%s\n" % (address, port))
Beispiel #2
0
def prepare_ssh_config(cluster_id, profile, public_dns):
    """Add/edit the ssh configuration belonging to the given cluster in ~/.ssh/config

    If ~/.ssh/config exists, it is first copied to a timestamped backup in
    ~/.databrickslabs_jupyterlab/ssh_config_backup. Then a host entry named after
    the cluster id is created, or, if one exists, updated with the new address
    and keep-alive settings. Finally the cluster's host key is added to the
    known hosts.

    Args:
        cluster_id (str): Cluster ID
        profile (str): Databricks CLI profile string
        public_dns (str): Public DNS/IP address
    """
    # ssh port written to new host entries and used for the known_hosts lookup
    ssh_port = 2200
    backup_path = "~/.databrickslabs_jupyterlab/ssh_config_backup"

    config = os.path.expanduser("~/.ssh/config")
    os.makedirs(os.path.expanduser(backup_path), exist_ok=True)

    # Only back up a config that exists; shutil.copy would raise otherwise
    if os.path.exists(config):
        backup = "%s/config.%s" % (backup_path, time.strftime("%Y-%m-%d_%H-%M-%S"))
        shutil.copy(config, os.path.expanduser(backup))
        print_warning("   => ~/.ssh/config will be changed")
        print_warning(
            "   => A backup of the current ~/.ssh/config has been created")
        print_warning("   => at %s" % backup)

    try:
        sc = SSHConfig.load(config)
    except Exception:  # pylint: disable=broad-except
        # presumably a missing or unparsable ~/.ssh/config: start with an empty one
        sc = SSHConfig(config)
    hosts = [h.name for h in sc.hosts()]
    if cluster_id in hosts:
        host = sc.get(cluster_id)
        host.set("HostName", public_dns)
        host.set("ServerAliveInterval", 30)
        host.set("ServerAliveCountMax", 5760)
        host.set("ConnectTimeout", 5)
        print("   => Added ssh config entry or modified IP address:\n")
        print(textwrap.indent(str(host), "      "))
    else:
        # ServerAliveInterval * ServerAliveCountMax = 48h
        attrs = {
            "HostName": public_dns,
            "IdentityFile": "~/.ssh/id_%s" % profile,
            "Port": ssh_port,
            "User": "ubuntu",
            "ServerAliveInterval": 30,
            "ServerAliveCountMax": 5760,
            "ConnectTimeout": 5,
        }
        host = Host(name=cluster_id, attrs=attrs)
        print("   => Adding ssh config to ~/.ssh/config:\n")
        print(textwrap.indent(str(host), "      "))
        sc.append(host)
    sc.write()

    # add_known_host requires the port; the entries above use ssh_port
    add_known_host(public_dns, ssh_port)
def prepare_ssh_config(cluster_id, profile, public_dns):
    """Add/edit the ssh configuration belonging to the given cluster in ~/.ssh/config

    If ~/.ssh/config exists, it is first copied to a timestamped backup in
    ~/.databrickslabs_jupyterlab/ssh_config_backup. Then a host entry named after
    the cluster id is created with the full set of parameters, or, if one exists,
    pointed at the new address with refreshed keep-alive settings. The result is
    written back to ~/.ssh/config and the cluster's host key is added to the
    known hosts.

    Args:
        cluster_id (str): Cluster ID
        profile (str): Databricks CLI profile string
        public_dns (str): Public DNS/IP address
    """
    # ssh port written to new host entries and used for the known_hosts lookup
    ssh_port = 2200
    backup_path = "~/.databrickslabs_jupyterlab/ssh_config_backup"

    config = os.path.expanduser("~/.ssh/config")
    os.makedirs(os.path.expanduser(backup_path), exist_ok=True)

    data = ""
    if os.path.exists(config):
        backup = "%s/config.%s" % (backup_path,
                                   time.strftime("%Y-%m-%d_%H-%M-%S"))
        shutil.copy(config, os.path.expanduser(backup))
        print_warning("   => ~/.ssh/config will be changed")
        print_warning(
            "   => A backup of the current ~/.ssh/config has been created")
        print_warning("   => at %s" % backup)
        with open(config, "r") as fd:
            data = fd.read()

    ssh_config = SshConfig(data)

    # ServerAliveInterval * ServerAliveCountMax = 48h
    keep_alive_params = (
        ("ServerAliveInterval", 30),
        ("ServerAliveCountMax", 5760),
        ("ConnectTimeout", 5),
    )

    host = ssh_config.get_host(cluster_id)
    if host is None:
        host = ssh_config.add_host(cluster_id)
        params = (
            ("HostName", public_dns),
            ("IdentityFile", "~/.ssh/id_%s" % profile),
            ("Port", ssh_port),
            ("User", "ubuntu"),
        ) + keep_alive_params
    else:
        # Existing entry: point it at the given address (previously it kept
        # whatever HostName was stored before)
        params = (("HostName", public_dns),) + keep_alive_params
    for key, value in params:
        host.set_param(key, value)

    print(
        f"   => Jupyterlab Integration made the following changes to {config}:"
    )
    ssh_config.dump()

    # Writing would fail if ~/.ssh does not exist yet
    os.makedirs(os.path.dirname(config), mode=0o700, exist_ok=True)
    with open(config, "w") as fd:
        fd.write(ssh_config.to_string())

    # add_known_host requires the port; new entries above use ssh_port
    add_known_host(public_dns, ssh_port)
def get_cluster(profile, cluster_id=None, status=None):
    """Get the cluster configuration from remote

    The cluster must have the public key ~/.ssh/id_<profile>.pub configured.
    If the cluster is not RUNNING/RESIZING, it is started when TERMINATED.
    This function then blocks, polling every 5 seconds, until the cluster is
    RUNNING/RESIZING and none of its libraries is PENDING/RESOLVING/INSTALLING.

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str, optional): If cluster_id is given, the user will not be asked to select one.
                                    Defaults to None.
        status (Status, optional): A Status class providing set_status. Defaults to None.

    Returns:
        tuple: (cluster_id, public_ip, cluster_name, started). On any error all four
               elements are None; on success the fourth element is currently always None.
    """
    pub_key_file = "%s/.ssh/id_%s.pub" % (expanduser("~"), profile)
    try:
        # open() must be inside the try block: a missing key file raises in
        # open(), not in read()
        with open(pub_key_file) as fd:
            ssh_pub = fd.read().strip()
    except (OSError, ValueError):
        print_error(
            "   Error: ssh key for profile 'id_%s.pub' does not exist in %s/.ssh"
            % (profile, expanduser("~")))
        bye()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
        clusters = client.list_clusters()
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return (None, None, None, None)

    # A workspace without clusters may return a response without "clusters"
    clusters = clusters.get("clusters", [])

    if cluster_id is not None:
        cluster = next((c for c in clusters if c["cluster_id"] == cluster_id),
                       None)

        if cluster is None:
            print_error(
                "   Error: A cluster with id '%s' does not exist in the workspace of profile '%s'"
                % (cluster_id, profile))
            return (None, None, None, None)

        if ssh_pub not in [
                c.strip() for c in cluster.get("ssh_public_keys", [])
        ]:
            print_error(
                "   Error: Cluster with id '%s' does not have ssh key '~/.ssh/id_%s' configured"
                % (cluster_id, profile))
            return (None, None, None, None)
    else:
        my_clusters = [
            cluster for cluster in clusters if ssh_pub in
            [c.strip() for c in cluster.get("ssh_public_keys", [])]
        ]

        if not my_clusters:
            print_error(
                "    Error: There is no cluster in the workspace for profile '%s' configured with ssh key '~/.ssh/id_%s':"
                % (profile, profile))
            print(
                "    Use 'databrickslabs_jupyterlab %s -s' to configure ssh for clusters in this workspace\n"
                % profile)
            return (None, None, None, None)

        # Hint at the cluster whose name (spaces -> "_") matches the conda env
        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV", None)
        for i, c in enumerate(my_clusters):
            if c["cluster_name"].replace(" ", "_") == current_conda_env:
                print_warning(
                    "\n   => The current conda environment is '%s'.\n      You might want to select cluster %d with the name '%s'?\n"
                    % (current_conda_env, i, c["cluster_name"]))
                break

        cluster = select_cluster(my_clusters)

    cluster_id = cluster["cluster_id"]
    cluster_name = cluster["cluster_name"]

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return (None, None, None, None)

    state = response["state"]
    if state not in ["RUNNING", "RESIZING"]:
        if state == "TERMINATED":
            print("   => Starting cluster %s" % cluster_id)
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting")
            try:
                response = client.start_cluster(cluster_id)
            except Exception as ex:  # pylint: disable=broad-except
                print_error(ex)
                return (None, None, None, None)

        print(
            "   => Waiting for cluster %s being started (this can take up to 5 min)"
            % cluster_id)
        print("   ", end="", flush=True)

        while state not in ("RUNNING", "RESIZING"):
            if status is not None:
                status.set_status(profile, cluster_id, "Cluster Starting",
                                  False)
            print(".", end="", flush=True)
            time.sleep(5)
            try:
                response = client.get_cluster(cluster_id)
            except Exception as ex:  # pylint: disable=broad-except
                print_error(ex)
                return (None, None, None, None)

            if response.get("error", None) is not None:
                print_error(response["error"])
                return (None, None, None, None)
            state = response["state"]
        print_ok("\n   => OK")

        if status is not None:
            status.set_status(profile, cluster_id, "Cluster started")

        print(
            "\n   => Waiting for libraries on cluster %s being installed (this can take some time)"
            % cluster_id)
        print("   ", end="", flush=True)

        done = False
        while not done:
            try:
                states = get_library_state(profile, cluster_id)
            except DatabricksApiException as ex:
                print_error(ex)
                return (None, None, None, None)

            installing = any(
                s in ["PENDING", "RESOLVING", "INSTALLING"] for s in states)
            if installing:
                if status is not None:
                    status.set_status(profile, cluster_id,
                                      "Installing cluster libraries", False)
                print(".", end="", flush=True)
                time.sleep(5)
            else:
                done = True
                print_ok("\n   => OK\n")

        if status is not None:
            status.set_status(profile, cluster_id,
                              "Cluster libraries installed", False)

    # "response" is the latest get_cluster result; tolerate missing driver info
    driver = response.get("driver") or {}
    public_ip = driver.get("public_dns", None)
    if public_ip is None:
        print_error("   Error: Cluster does not have public DNS name")
        return (None, None, None, None)

    print_ok("   => Selected cluster: %s (%s)" % (cluster_name, public_ip))

    return (cluster_id, public_ip, cluster_name, None)
def configure_ssh(profile, host, token, cluster_id):
    """Configure SSH for the remote cluster

    Ensures the key pair ~/.ssh/id_<profile> exists, offering to create it if it
    is missing. Then, after confirmation, adds its public key to the cluster's
    ssh_public_keys with an edit request; the user is warned that this restarts
    the cluster.

    Args:
        profile (str): Databricks CLI profile string
        host (str): host from databricks cli config for given profile string
                    (not used by this function; kept for interface compatibility)
        token (str): token from databricks cli config for given profile string
                     (not used by this function; kept for interface compatibility)
        cluster_id (str): cluster ID

    Returns:
        None
    """
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if not os.path.exists(sshkey_file):
        print("\n   => ssh key '%s' does not exist" % sshkey_file)
        answer = input("   => Shall it be created (y/n)? (default = n): ")
        if answer.lower() == "y":
            print("   => Creating ssh key %s" % sshkey_file)
            result = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if result["returncode"] == 0:
                print_ok("   => OK")
            else:
                print_error(result["stderr"])
                bye()
        else:
            bye()
    else:
        print_ok("\n   => ssh key '%s' already exists" % sshkey_file)

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    ssh_public_keys = response.get("ssh_public_keys", [])
    if sshkey in [key.strip() for key in ssh_public_keys]:
        print_ok("   => public ssh key already configured for cluster %s" %
                 cluster_id)
        bye()

    # Runtime attributes reported by get_cluster that are not cluster settings
    runtime_attributes = frozenset([
        "driver",
        "executors",
        "spark_context_id",
        "state",
        "state_message",
        "start_time",
        "terminated_time",
        "last_state_loss_time",
        "last_activity_time",
        "disk_spec",
    ])
    # Copy every remaining attribute instead of a fixed allow-list: a fixed list
    # raised KeyError for absent keys and dropped settings such as spark_conf,
    # custom_tags or init_scripts from the edit request.
    request = {
        key: val
        for key, val in response.items() if key not in runtime_attributes
    }

    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning(
        "   => The ssh key will be added to the cluster. \n   Note: The cluster will be restarted immediately!"
    )
    answer = input(
        "   => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() == "y":
        try:
            response = client.edit_cluster(request)
        except DatabricksApiException as ex:
            print_error(str(ex))
            return None
        print_ok("   => OK")
    else:
        print_error("   => Cancelled")
def configure_ssh(profile, cluster_id):
    """Configure SSH for the remote cluster

    Ensures the key pair ~/.ssh/id_<profile> exists, offering to create it if it
    is missing. Then, after confirmation, adds its public key to the cluster's
    ssh_public_keys with an edit request that carries all of the cluster's
    non-runtime attributes; the user is warned that this restarts the cluster.

    Args:
        profile (str): Databricks CLI profile string
        cluster_id (str): cluster ID

    Returns:
        None
    """
    sshkey_file = os.path.expanduser("~/.ssh/id_%s" % profile)
    if not os.path.exists(sshkey_file):
        print("\n   => ssh key '%s' does not exist" % sshkey_file)
        answer = input("   => Shall it be created (y/n)? (default = n): ")
        if answer.lower() == "y":
            print("   => Creating ssh key %s" % sshkey_file)
            result = execute(
                ["ssh-keygen", "-b", "2048", "-N", "", "-f", sshkey_file])
            if result["returncode"] == 0:
                print_ok("   => OK")
            else:
                print_error(result["stderr"])
                bye()
        else:
            bye()
    else:
        print_ok("\n   => ssh key '%s' already exists" % sshkey_file)

    with open(sshkey_file + ".pub", "r") as fd:
        sshkey = fd.read().strip()

    try:
        apiclient = connect(profile)
        client = ClusterApi(apiclient)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    try:
        response = client.get_cluster(cluster_id)
    except Exception as ex:  # pylint: disable=broad-except
        print_error(ex)
        return None

    ssh_public_keys = response.get("ssh_public_keys", [])
    if sshkey in [key.strip() for key in ssh_public_keys]:
        print_ok("   => public ssh key already configured for cluster %s" %
                 cluster_id)
        bye()

    # Runtime attributes reported by get_cluster that must not be sent back
    # in the edit request (built once instead of on every loop iteration)
    runtime_attributes = frozenset([
        "driver",
        "executors",
        "spark_context_id",
        "state",
        "state_message",
        "start_time",
        "terminated_time",
        "last_state_loss_time",
        "last_activity_time",
        "disk_spec",
    ])
    request = {
        key: val
        for key, val in response.items() if key not in runtime_attributes
    }

    request["ssh_public_keys"] = ssh_public_keys + [sshkey]

    print_warning("   => The ssh key will be added to the cluster." +
                  "\n   Note: The cluster will be restarted immediately!")
    answer = input(
        "   => Shall the ssh key be added and the cluster be restarted (y/n)? (default = n): "
    )
    if answer.lower() == "y":
        try:
            response = client.edit_cluster(request)
        except DatabricksApiException as ex:
            print_error(str(ex))
            return None
        print_ok("   => OK")
    else:
        print_error("   => Cancelled")