Beispiel #1
0
 def test_write(self):
     """Append ``new_host``, write to a scratch file, reload it, check server2.

     The scratch file is removed through ``addCleanup`` so it is deleted even
     when the reload or the assertion fails.
     """
     configs = SSHConfig.load(sample)
     configs.append(new_host)
     new_sample_path = os.path.join(os.path.dirname(__file__), "sample_new")
     configs.write(filename=new_sample_path)
     # Register removal as soon as the file exists, so a failing load or
     # assertion below does not leave "sample_new" in the test directory.
     self.addCleanup(os.remove, new_sample_path)
     new_config = SSHConfig.load(new_sample_path)
     self.assertEqual("server2", new_config.get("server2").name)
Beispiel #2
0
def set_host(args):
    '''
    Add/Modify ssh_config entry
    Usage:
        Update Field:
            !set <host> <field_name> <value>
        Add Entry:
            !set <host> <hostname> <user> <port>
    '''
    # The docstring above doubles as the usage text logged on bad input, so
    # it stays user-facing. With no arguments the known hosts are listed as
    # a table instead.
    #
    # Expand "~" the same way load_config() does, so every command reads
    # and writes the same file.
    config_path = expanduser(SSH_CONFIG_PATH)
    try:
        config = SSHConfig.load(config_path)
    except ssh_config.client.EmptySSHConfig:
        config = SSHConfig(config_path)

    # str.split() without a separator already drops surrounding whitespace,
    # so whitespace-only input yields no tokens (and lists the hosts)
    # instead of raising IndexError on the first token.
    tokens = args.split() if args else []

    # print out current hosts
    if not tokens:
        table = PrettyTable()
        table.field_names = field_names
        table.border = False
        table.align = 'l'

        for host in config:
            table.add_row([
                host.name, host.HostName, host.User, host.Port,
                host.IdentityFile, host.ProxyJump
            ])
        logger.info(table)
        return

    current_hosts = {host.name for host in config}
    name = tokens[0]

    # update entry: !set <host> <field_name> <value>
    if name in current_hosts and len(tokens) == 3:
        field, value = tokens[1], tokens[2]
        if field not in field_names:
            logger.error('Use Designated Field Names')
            logger.error(set_host.__doc__)
            return

        # The literal "none" (any case) is passed on as None (presumably to
        # unset the field - depends on SSHConfig.update).
        if value.lower() == 'none':
            value = None

        config.update(name, {field: value})
    # add entry: !set <host> <hostname> <user> <port>
    elif name not in current_hosts and len(tokens) == 4:
        new_host = Host(name, {
            'HostName': tokens[1],
            'User': tokens[2],
            'Port': tokens[3]
        })
        config.append(new_host)
    # invalid
    else:
        logger.error(set_host.__doc__)
        return

    config.write()
    # Refresh the module-level host cache from the file just written.
    load_config()
Beispiel #3
0
 def test_bastion(self):
     """Add ``bastion1`` with ``add -b``, then run ``bastion bastion1 server1``.

     After each CLI call the file is reloaded and ``bastion1`` is expected to
     hold the given HostName/User plus ``ProxyCommand none`` and
     ``ForwardAgent yes``.
     """
     self.maxDiff = None
     new_sample = os.path.join(os.path.dirname(__file__), "sample.update")
     shutil.copy(sample, new_sample)
     # Remove the scratch copy even if a CLI call or an assertion fails.
     self.addCleanup(os.remove, new_sample)
     # "User" must echo the "User=ssh_user" argument passed to ``add`` below.
     expected = {
         "HostName": "bastion.example.com",
         "User": "ssh_user",
         "ProxyCommand": "none",
         "ForwardAgent": "yes",
     }
     cli.main(
         [
             "ssh_config",
             "-f",
             new_sample,
             "add",
             "-b",
             "-y",
             "bastion1",
             "HostName=bastion.example.com",
             "User=ssh_user",
         ]
     )
     configs = SSHConfig.load(new_sample)
     self.assertEqual(configs.get("bastion1").attributes(), expected)
     cli.main(
         [
             "ssh_config",
             "-f",
             new_sample,
             "bastion",
             "-y",
             "bastion1",
             "server1",
         ]
     )
     configs = SSHConfig.load(new_sample)
     # Running ``bastion`` must leave the bastion's own entry unchanged.
     self.assertEqual(configs.get("bastion1").attributes(), expected)
Beispiel #4
0
 def test_asdict(self):
     """``asdict()`` maps each Host pattern of ``sample`` to its attributes."""
     configs = SSHConfig.load(sample)
     host_ip = "203.0.113.76"
     expected = {
         "*": {"ServerAliveInterval": 40},
         "server1": {"HostName": host_ip, "ServerAliveInterval": 200},
         "server_cmd_1": {"HostName": host_ip, "Port": 2202},
         "server_cmd_2": {
             "HostName": host_ip,
             "Port": 22,
             # NOTE(review): "******" looks like a redaction artifact of the
             # copy this came from - confirm the real User in `sample`.
             "User": "******",
         },
         "server_cmd_3": {
             "HostName": host_ip,
             "Port": 2202,
             "User": "******",
         },
     }
     self.assertEqual(expected, configs.asdict())
Beispiel #5
0
    def test_update(self):
        """``update()`` overwrites host attributes; a list argument raises."""
        configs = SSHConfig.load(sample)
        configs.update("server1", {"IdentityFile": "~/.ssh/id_rsa_new"})
        # A plain list is not an attribute mapping.
        self.assertRaises(AttributeError, configs.update, "server1", [])
        self.assertEqual(
            configs.get("server1").IdentityFile, "~/.ssh/id_rsa_new")

        # NOTE(review): "CanonialDomains" / "CnonicalizeFallbackLocal" look
        # misspelled; the loop below only passes if Host.attrs uses the very
        # same spelling - confirm against the library.
        new_values = {
            "HostName": "example.com",
            "User": "******",
            "Port": 22,
            "IdentityFile": "~/.ssh/id_rsa",
            "ProxyCommand": "",
            "LocalCommand": "",
            "LocalForward": "",
            "Match": "",
            "AddKeysToAgent": "",
            "AddressFamily": "",
            "BatchMode": "",
            "BindAddress": "",
            "BindInterface": "",
            "CanonialDomains": "",
            "CnonicalizeFallbackLocal": "",
            "IdentityAgent": "",
            "LogLevel": "",
            "PreferredAuthentications": "",
            "ServerAliveInterval": 10,
            "ForwardAgent": "",
        }
        configs.update("server1", new_values)
        # Every attribute Host declares must read back what was written.
        for attr_name, _attr_type in Host.attrs:
            actual = getattr(configs.get("server1"), attr_name)
            self.assertEqual(actual, new_values[attr_name])
Beispiel #6
0
def prepare_ssh_config(cluster_id, profile, public_dns):
    """Add/edit the ssh configuration belonging to the given cluster in ~/.ssh/config

    An existing ``~/.ssh/config`` is first copied to a timestamped file under
    ``~/.databrickslabs_jupyterlab/ssh_config_backup``. If a Host entry named
    ``cluster_id`` already exists, its HostName and keep-alive/timeout
    settings are updated; otherwise a new entry is appended. Finally
    ``public_dns`` is passed to ``add_known_host``.

    Args:
        cluster_id (str): Cluster ID
        profile (str): Databricks CLI profile string
        public_dns (str): Public DNS/IP address
    """
    backup_path = "~/.databrickslabs_jupyterlab/ssh_config_backup"
    os.makedirs(os.path.expanduser(backup_path), exist_ok=True)

    config = os.path.expanduser("~/.ssh/config")

    print_warning("   => ~/.ssh/config will be changed")
    # Only back up a file that exists: shutil.copy raises FileNotFoundError
    # otherwise, although the code below copes with a missing config by
    # starting an empty one.
    if os.path.exists(config):
        backup = "%s/config.%s" % (backup_path, time.strftime("%Y-%m-%d_%H-%M-%S"))
        shutil.copy(config, os.path.expanduser(backup))
        # Report the backup only once it has actually been written.
        print_warning(
            "   => A backup of the current ~/.ssh/config has been created")
        print_warning("   => at %s" % backup)

    try:
        sc = SSHConfig.load(config)
    except Exception:  # load failed (e.g. no/empty config): start empty
        sc = SSHConfig(config)
    hosts = [h.name for h in sc.hosts()]
    if cluster_id in hosts:
        host = sc.get(cluster_id)
        host.set("HostName", public_dns)
        host.set("ServerAliveInterval", 30)
        host.set("ServerAliveCountMax", 5760)
        host.set("ConnectTimeout", 5)
        print("   => Added ssh config entry or modified IP address:\n")
        print(textwrap.indent(str(host), "      "))
    else:
        # ServerAliveInterval * ServerAliveCountMax = 48h
        attrs = {
            "HostName": public_dns,
            "IdentityFile": "~/.ssh/id_%s" % profile,
            "Port": 2200,
            # NOTE(review): "******" looks like a redaction artifact of the
            # copy this came from - confirm the intended remote user.
            "User": "******",
            "ServerAliveInterval": 30,
            "ServerAliveCountMax": 5760,
            "ConnectTimeout": 5,
        }
        host = Host(name=cluster_id, attrs=attrs)
        print("   => Adding ssh config to ~/.ssh/config:\n")
        print(textwrap.indent(str(host), "      "))
        sc.append(host)
    sc.write()

    add_known_host(public_dns)
Beispiel #7
0
    def test_other(self):
        """Spot-check ``server1``'s HostName and the ``*`` ServerAliveInterval."""
        configs = SSHConfig(sample)
        for entry in configs:
            if entry.name == "server1":
                self.assertEqual(entry.HostName, "203.0.113.76")
            elif entry.name == "*":
                self.assertEqual(entry.ServerAliveInterval, 40)
Beispiel #8
0
 def test_new(self):
     """Build a config with ``SSHConfig.create``, add ``new_host``, write it,
     and compare the file's text with ``new_data``."""
     empty_sample = os.path.join(os.path.dirname(__file__), "sample_empty")
     try:
         config = SSHConfig.create(empty_sample)
         config.add(new_host)
         config.write()
         with open(empty_sample, "r") as f:
             self.assertEqual(new_data, f.read())
     finally:
         # Always remove the scratch file (if it got created) so a failing
         # run does not leave it behind for the next one.
         if os.path.exists(empty_sample):
             os.remove(empty_sample)
Beispiel #9
0
 def test_rm(self):
     """``rm`` via the CLI deletes ``server1`` from a copy of ``sample``."""
     new_sample = os.path.join(os.path.dirname(__file__), "sample.rm")
     shutil.copy(sample, new_sample)
     # Delete the scratch copy even when the CLI call or assertion fails.
     self.addCleanup(os.remove, new_sample)
     cli.main(["ssh_config", "-f", new_sample, "rm", "-y", "server1"])
     sshconfig = SSHConfig.load(new_sample)
     host = sshconfig.get("server1", raise_exception=False)
     self.assertIsNone(host)
Beispiel #10
0
 def test_host_command(self):
     """``command()`` renders the ssh command line for each sample host."""
     configs = SSHConfig.load(sample)
     # NOTE(review): "[email protected]" looks like what an e-mail
     # obfuscation filter leaves behind; the real expectation is presumably
     # "ssh <user>@203.0.113.76" - restore it from the upstream test.
     cases = [
         ("server1", "ssh 203.0.113.76"),
         ("server_cmd_1", "ssh -p 2202 203.0.113.76"),
         ("server_cmd_2", "ssh [email protected]"),
         ("server_cmd_3", "ssh -p 2202 [email protected]"),
     ]
     for host_name, expected_command in cases:
         self.assertEqual(expected_command, configs.get(host_name).command())
Beispiel #11
0
    def test_asdict(self):
        """``asdict()`` yields one dict per host, each carrying a "Host" key.

        Both sides are sorted by "Host" so the comparison ignores order.
        """
        def by_host(entry):
            return entry['Host']

        configs = SSHConfig(sample)
        expected = [
            {"Host": "*", "ServerAliveInterval": 40},
            {
                "Host": "server1",
                "HostName": "203.0.113.76",
                "ServerAliveInterval": 200,
            },
            {"Host": "server_cmd_1", "HostName": "203.0.113.76", "Port": 2202},
            {
                "Host": "server_cmd_2",
                "HostName": "203.0.113.76",
                "Port": 22,
                "User": "******",
            },
            {
                "Host": "server_cmd_3",
                "HostName": "203.0.113.76",
                "Port": 2202,
                "User": "******",
            },
            {
                "Host": "host_1 host_2",
                "HostName": "%h.test.com",
                "Port": 2202,
                "User": "******",
            },
        ]
        expected.sort(key=by_host)
        actual = sorted(configs.asdict(), key=by_host)
        self.assertEqual(expected, actual)
 def test_remove_ssh_config(self):
     """Remove every test cluster's entry from ``~/.ssh/config`` and save."""
     config = os.path.expanduser("~/.ssh/config")
     try:
         sc = SSHConfig.load(config)
     except Exception:  # pylint: disable=broad-except
         # Loading failed (e.g. no usable config): fall back to an empty one.
         # Exception (not a bare except) keeps Ctrl-C/SystemExit working.
         sc = SSHConfig(config)
     for host in self.clusters.values():
         sc.remove(host)
     sc.write()
Beispiel #13
0
 def test_import(self):
     """``import`` via the CLI adds hosts ``import1``/``import2`` from import.csv."""
     new_sample = os.path.join(os.path.dirname(__file__), "sample.import")
     shutil.copy(sample, new_sample)
     # Delete the scratch copy even when the CLI call or an assertion fails.
     self.addCleanup(os.remove, new_sample)
     import_csv = os.path.join(os.path.dirname(__file__), "import.csv")
     cli.main(["ssh_config", "-f", new_sample, "import", "-q", "-y", import_csv])
     sshconfig = SSHConfig.load(new_sample)
     # The check relies on remove() returning a truthy value for a host
     # that was present.
     import_1 = sshconfig.remove("import1")
     import_2 = sshconfig.remove("import2")
     sshconfig.write()
     self.assertTrue(import_1)
     self.assertTrue(import_2)
    def test_remove_known_host_entries(self):
        """Drop ``~/.ssh/known_hosts`` lines belonging to the test clusters.

        Test clusters are found by Host name in ``~/.ssh/config``. A
        known_hosts line is dropped when its ``[address]:2200`` prefix
        matches a cluster's HostName or the IP derived from it. Blank lines
        are dropped as well.
        """
        known_hosts_file = os.path.expanduser("~/.ssh/known_hosts")
        with open(known_hosts_file, "r") as fd:
            known_hosts = fd.read()

        config = os.path.expanduser("~/.ssh/config")
        try:
            sc = SSHConfig.load(config)
        except Exception:  # pylint: disable=broad-except
            sc = SSHConfig(config)

        # Build the lookup set once rather than a new list for every host.
        cluster_names = set(self.clusters.values())
        # Skip entries without a HostName; they cannot be matched and would
        # break the IP derivation below.
        test_addresses = [
            h.get("HostName") for h in sc.hosts()
            if h.name in cluster_names and h.get("HostName")
        ]
        # Presumably AWS-style names, e.g. "ec2-1-2-3-4.<domain>" -> "1.2.3.4".
        test_ips = [".".join(a.split(".")[0].split("-")[1:]) for a in test_addresses]
        keep_lines = []
        for line in known_hosts.split("\n"):
            if line.strip() == "":
                continue
            # Entries look like "[address]:2200 <keytype> <key>".
            address = line.split("]:2200")[0][1:]
            if address not in test_addresses and address not in test_ips:
                keep_lines.append(line)
        with open(known_hosts_file, "w") as fd:
            # Terminate every line: known_hosts is line-oriented, and tools
            # that append entries could otherwise glue them onto the last one.
            fd.write("".join(line + "\n" for line in keep_lines))
Beispiel #15
0
    def test_update(self):
        """``update()`` overwrites host attributes; a list argument raises."""
        configs = SSHConfig(sample)
        configs.update("server1", {"IdentityFile": "~/.ssh/id_rsa_new"})
        with self.assertRaises(AttributeError):
            configs.update("server1", [])
        self.assertEqual(
            configs.get("server1").IdentityFile, "~/.ssh/id_rsa_new")

        new_values = {
            "HostName": "example.com",
            "User": "******",
            "Port": 22,
            "IdentityFile": "~/.ssh/id_rsa",
            "ServerAliveInterval": 10,
        }
        configs.update("server1", new_values)
        # Each written attribute must read back unchanged.
        for attr_name in new_values:
            self.assertEqual(
                getattr(configs.get("server1"), attr_name),
                new_values[attr_name],
            )
Beispiel #16
0
def load_config():
    """Refill the module-level ``HOSTS`` list from the ssh config file.

    If the file does not exist yet, its directory and an empty file are
    created and ``HOSTS`` is left empty. An empty config
    (``EmptySSHConfig``) also leaves ``HOSTS`` empty.
    """
    # HOSTS is only mutated in place, never rebound, so no `global` is needed.
    HOSTS.clear()

    # Expand "~" once and use it for every filesystem call; otherwise
    # makedirs()/touch() would act on a literal "~" path relative to the
    # current directory while isfile() checks the home directory.
    config_path = expanduser(SSH_CONFIG_PATH)
    if not path.isfile(config_path):
        config_dir = path.dirname(config_path)
        # dirname() is "" for a bare filename; there is nothing to create then.
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        Path(config_path).touch()
        return

    try:
        HOSTS.extend(SSHConfig.load(config_path))
    except ssh_config.client.EmptySSHConfig:
        pass
Beispiel #17
0
 def test_add(self):
     """``add`` via the CLI creates host ``test_add`` with the given HostName."""
     sample_add = os.path.join(os.path.dirname(__file__), "sample.add")
     shutil.copy(sample, sample_add)
     # Delete the scratch copy even when the CLI call or an assertion fails.
     self.addCleanup(os.remove, sample_add)
     cli.main([
         "ssh_config",
         "-f",
         sample_add,
         "add",
         "-y",
         "test_add",
         "HostName=238.0.4.1",
     ])
     sshconfig = SSHConfig.load(sample_add)
     host = sshconfig.get("test_add", raise_exception=False)
     self.assertIsNotNone(host)
     self.assertEqual(host.HostName, "238.0.4.1")
 def test_update(self):
     """``update`` via the CLI changes IdentityFile and keeps HostName."""
     new_sample = os.path.join(os.path.dirname(__file__), "sample.update")
     shutil.copy(sample, new_sample)
     # Delete the scratch copy even when the CLI call or an assertion fails.
     self.addCleanup(os.remove, new_sample)
     cli.main([
         "ssh_config",
         "-f",
         new_sample,
         "update",
         "-y",
         "server1",
         "IdentityFile=~/.ssh/id_rsa_test",
     ])
     sshconfig = SSHConfig.load(new_sample)
     host = sshconfig.get("server1", raise_exception=False)
     # Fail with a clear message instead of an AttributeError on None.
     self.assertIsNotNone(host)
     self.assertEqual("203.0.113.76", host.HostName)
     self.assertEqual("~/.ssh/id_rsa_test", host.IdentityFile)
    def test_configure_ssh(self, name, cluster_id):
        """Write the cluster's ssh entry, verify it, and ssh into the cluster."""
        cluster_id2, public_ip, cluster_name, _ = get_cluster(
            self.profile, cluster_id)
        assert cluster_id2 == cluster_id
        assert cluster_name == name
        assert public_ip is not None

        prepare_ssh_config(cluster_id, self.profile, public_ip)
        # Deliberately not called `ssh_config`, which is also the name of
        # the ssh_config library package.
        config_path = os.path.expanduser("~/.ssh/config")
        entry = SSHConfig.load(config_path).get(cluster_id)
        # Values are compared as strings after reloading the file.
        expected = [
            ("ConnectTimeout", "5"),
            ("ServerAliveCountMax", "5760"),
            ("IdentityFile", "~/.ssh/id_{}".format(self.profile)),
        ]
        for key, value in expected:
            assert entry.get(key) == value

        assert is_reachable(public_dns=public_ip)

        # The Host alias itself must be usable by the ssh client.
        subprocess.check_output([SSH, cluster_id, "hostname"])
Beispiel #20
0
 def test_update_with_pattern(self):
     """``add -p server_*`` sets IdentityFile on every matching host.

     The ``server_cmd*`` hosts must keep their HostName and gain the new
     IdentityFile.
     """
     new_sample = os.path.join(os.path.dirname(__file__), "sample.update")
     shutil.copy(sample, new_sample)
     # Delete the scratch copy even when the CLI call or an assertion fails.
     self.addCleanup(os.remove, new_sample)
     cli.main([
         "ssh_config",
         "-f",
         new_sample,
         "add",
         "-y",
         "-p",
         "server_*",
         "IdentityFile=~/.ssh/id_rsa_test",
     ])
     sshconfig = SSHConfig.load(new_sample)
     matched = [host for host in sshconfig if "server_cmd" in host.name]
     # Without this, the test would pass vacuously if nothing matched.
     self.assertTrue(matched)
     for host in matched:
         self.assertEqual("203.0.113.76", host.HostName)
         self.assertEqual("~/.ssh/id_rsa_test", host.IdentityFile)
Beispiel #21
0
def remove_host(host):
    '''
    Remove entry from SSH config
    Usage:
        !remove <host>

    '''
    # The docstring above is logged verbatim as the usage message.
    if not host or len(host.split()) != 1:
        logger.error(remove_host.__doc__)
        return

    # Exactly one token was given; drop any surrounding whitespace so the
    # lookup and the removal both use the bare host name.
    host = host.strip()

    # lookup host infos
    if not host_lookup(host):
        logger.error('Invalid Host')
        return

    # Expand "~" the same way load_config() does.
    try:
        config = SSHConfig.load(expanduser(SSH_CONFIG_PATH))
    except ssh_config.client.EmptySSHConfig:
        return  # nothing to remove

    config.remove(host)
    config.write()
Beispiel #22
0
 def test_remove(self):
     """After ``remove()``, looking the host up raises NameError."""
     config = SSHConfig(sample)
     config.remove("server1")
     self.assertRaises(NameError, config.get, "server1")
Beispiel #23
0
 def test_set_host(self):
     """``add()`` places the new host last in ``hosts``."""
     configs = SSHConfig(sample)
     configs.add(new_host)
     last_host = configs.hosts[-1]
     self.assertEqual(new_host, last_host)
Beispiel #24
0
 def test_get_host(self):
     """``get()`` finds ``server1`` and raises NameError for unknown names."""
     configs = SSHConfig(sample)
     found = configs.get("server1")
     self.assertEqual("server1", found.name)
     self.assertRaises(NameError, configs.get, "NoExist")
Beispiel #25
0
 def test_set(self):
     """The first two entries of ``hosts`` are Host instances."""
     configs = SSHConfig(sample)
     first, second = configs.hosts[0], configs.hosts[1]
     self.assertIsInstance(first, Host)
     self.assertIsInstance(second, Host)
Beispiel #26
0
 def test_load(self):
     """The first host parsed from ``sample`` is ``server1`` or ``*``."""
     configs = SSHConfig(sample)
     # Only the first entry is checked; an empty config checks nothing.
     first = next(iter(configs), None)
     if first is not None:
         self.assertIn(first.name, ["server1", "*"])
Beispiel #27
0
 def test_remove(self):
     """After ``remove()``, looking the host up raises KeyError."""
     configs = SSHConfig.load(sample)
     configs.remove("server1")
     with self.assertRaises(KeyError):
         configs.get("server1")
Beispiel #28
0
 def test_set_host(self):
     """``append()`` places the new host last in the config."""
     configs = SSHConfig.load(sample)
     configs.append(new_host)
     last_host = configs[-1]
     self.assertEqual(new_host, last_host)
Beispiel #29
0
 def test_get_host(self):
     """``get()`` finds ``server1`` and raises KeyError for unknown names."""
     configs = SSHConfig.load(sample)
     server1 = configs.get("server1")
     self.assertEqual("server1", server1.name)
     with self.assertRaises(KeyError):
         configs.get("NoExist")
Beispiel #30
0
 def test_set(self):
     configs = SSHConfig.load(sample)
     host_0 = configs[0]
     host_1 = configs[1]