コード例 #1
0
def prepare_ssh_config(cluster_id, profile, public_dns):
    """Add/edit the ssh configuration belonging to the given cluster in ~/.ssh/config

    If ~/.ssh/config exists, a timestamped copy is first saved to
    ~/.databrickslabs_jupyterlab/ssh_config_backup/config.<YYYY-mm-dd_HH-MM-SS>.
    If a ``Host`` entry named ``cluster_id`` already exists, its HostName and
    keep-alive/timeout options are updated; otherwise a new entry is appended.
    The config is then written back and ``add_known_host(public_dns)`` is called.

    Args:
        cluster_id (str): Cluster ID, used as the ssh ``Host`` alias
        profile (str): Databricks CLI profile string; new entries use the
            identity file ``~/.ssh/id_<profile>``
        public_dns (str): Public DNS/IP address, stored as ``HostName``
    """
    backup_path = "~/.databrickslabs_jupyterlab/ssh_config_backup"
    config = os.path.expanduser("~/.ssh/config")

    # ServerAliveInterval * ServerAliveCountMax = 30 s * 5760 = 48h
    keep_alive = {
        "ServerAliveInterval": 30,
        "ServerAliveCountMax": 5760,
        "ConnectTimeout": 5,
    }

    print_warning("   => ~/.ssh/config will be changed")
    if os.path.exists(config):
        backup = "%s/config.%s" % (backup_path, time.strftime("%Y-%m-%d_%H-%M-%S"))
        os.makedirs(os.path.expanduser(backup_path), exist_ok=True)
        # Copy before reporting, so the message is only printed for a backup that exists.
        shutil.copy(config, os.path.expanduser(backup))
        print_warning("   => A backup of the current ~/.ssh/config has been created")
        print_warning("   => at %s" % backup)
    else:
        # Nothing to back up (shutil.copy would raise FileNotFoundError here).
        # Make sure ~/.ssh exists so that the config file can be created below.
        os.makedirs(os.path.dirname(config), mode=0o700, exist_ok=True)
        print_warning("   => ~/.ssh/config does not exist yet and will be created")

    try:
        sc = SSHConfig.load(config)
    except Exception:  # pylint: disable=broad-except
        # Fall back to an empty config bound to the same path (e.g. the file is missing).
        # NOTE(review): if an existing file fails to parse, sc.write() below presumably
        # replaces it; the backup above is then the only copy of the old content.
        sc = SSHConfig(config)

    if cluster_id in (h.name for h in sc.hosts()):
        host = sc.get(cluster_id)
        host.set("HostName", public_dns)
        for key, value in keep_alive.items():
            host.set(key, value)
        print("   => Added ssh config entry or modified IP address:\n")
        print(textwrap.indent(str(host), "      "))
    else:
        attrs = {
            "HostName": public_dns,
            "IdentityFile": "~/.ssh/id_%s" % profile,
            "Port": 2200,
            "User": "******",
        }
        # Appended after the fixed keys, preserving the original attribute order.
        attrs.update(keep_alive)
        host = Host(name=cluster_id, attrs=attrs)
        print("   => Adding ssh config to ~/.ssh/config:\n")
        print(textwrap.indent(str(host), "      "))
        sc.append(host)
    sc.write()

    add_known_host(public_dns)
コード例 #2
0
    def test_remove_known_host_entries(self):
        """Remove the ~/.ssh/known_hosts entries that belong to the test clusters.

        Collects the ``HostName`` of every ~/.ssh/config host whose name is one of
        ``self.clusters.values()``, plus the dotted address derived from it
        (``"ec2-1-2-3-4.<rest>"`` -> ``"1.2.3.4"``), and drops every known_hosts
        line whose leading ``[<address>]:2200`` entry matches one of them.
        Blank lines are dropped; all other lines are written back unchanged,
        newline-terminated.
        """
        known_hosts_file = os.path.expanduser("~/.ssh/known_hosts")
        with open(known_hosts_file, "r") as fd:
            known_hosts = fd.read()

        config = os.path.expanduser("~/.ssh/config")
        try:
            sc = SSHConfig.load(config)
        except Exception:  # pylint: disable=broad-except
            sc = SSHConfig(config)

        cluster_ids = set(self.clusters.values())
        test_addresses = {h.get("HostName") for h in sc.hosts() if h.name in cluster_ids}
        # "ec2-1-2-3-4.<rest>" -> "1.2.3.4"
        test_ips = {".".join(a.split(".")[0].split("-")[1:]) for a in test_addresses}
        remove = test_addresses | test_ips
        # A HostName without "-" in its first label (e.g. a plain IP) yields "",
        # which would otherwise match unrelated one-character lines.
        remove.discard("")

        keep_lines = []
        for line in known_hosts.split("\n"):
            if not line.strip():
                continue
            # Strip the leading "[" and everything from "]:2200" on to get the address.
            address = line.split("]:2200")[0][1:]
            if address not in remove:
                keep_lines.append(line)

        with open(known_hosts_file, "w") as fd:
            # Terminate every line so entries appended later start on their own line.
            fd.write("".join(line + "\n" for line in keep_lines))