Пример #1
0
def set_host(args):
    '''
    Add/Modify ssh_config entry
    Usage:
        Update Field:
            !set <host> <field_name> <value>
        Add Entry:
            !set <host> <hostname> <user> <port>
    '''
    # NOTE: the docstring above doubles as the usage text that is logged on
    # invalid input (via set_host.__doc__), so it is kept verbatim and the
    # maintainer-facing notes live in comments instead.
    #
    # args: the raw argument string that followed the command.
    #   * no tokens (empty / None / whitespace only): log a table of the
    #     current entries and return without writing anything.
    #   * "<host> <field_name> <value>" for a host already in the config:
    #     update that single field. A value of "none" (any case) is handed
    #     to config.update() as None -- presumably to clear the field;
    #     confirm against the ssh_config library.
    #   * "<host> <hostname> <user> <port>" for a host not yet in the config:
    #     append a new entry.
    #   * anything else: log the usage text and return.
    # After a successful update/add the config is written back and
    # load_config() is called.
    try:
        config = SSHConfig.load(SSH_CONFIG_PATH)
    except ssh_config.client.EmptySSHConfig:
        # Loading reported an empty config; start from a fresh object bound
        # to the same path.
        config = SSHConfig(SSH_CONFIG_PATH)

    # str.split() without a separator already discards surrounding
    # whitespace, so no per-token strip is needed. Tokenizing up front also
    # makes a whitespace-only argument behave like "no arguments" instead of
    # raising IndexError further down.
    tokens = args.split() if args else []

    # print out current hosts
    if not tokens:
        table = PrettyTable()
        table.field_names = field_names
        table.border = False
        table.align = 'l'

        for host in config:
            table.add_row([
                host.name, host.HostName, host.User, host.Port,
                host.IdentityFile, host.ProxyJump
            ])
        logger.info(table)
        return

    host_name = tokens[0]
    is_known_host = any(host.name == host_name for host in config)

    # update entry
    if is_known_host and len(tokens) == 3:
        _, field, value = tokens
        if field not in field_names:
            logger.error('Use Designated Field Names')
            logger.error(set_host.__doc__)
            return

        if value.lower() == 'none':
            value = None

        config.update(host_name, {field: value})
    # add entry
    elif not is_known_host and len(tokens) == 4:
        _, hostname, user, port = tokens
        config.append(Host(host_name, {
            'HostName': hostname,
            'User': user,
            'Port': port
        }))
    # invalid
    else:
        logger.error(set_host.__doc__)
        return

    config.write()
    load_config()
Пример #2
0
def prepare_ssh_config(cluster_id, profile, public_dns):
    """Add/edit the ssh configuration belonging to the given cluster in ~/.ssh/config

    If ~/.ssh/config already exists, a timestamped copy is stored under
    ~/.databrickslabs_jupyterlab/ssh_config_backup before anything is changed.
    An existing entry named ``cluster_id`` gets its HostName and keep-alive
    settings refreshed; otherwise a new entry is appended. The config is then
    written back and ``add_known_host(public_dns)`` is called.

    Args:
        cluster_id (str): Cluster ID
        profile (str): Databricks CLI profile string
        public_dns (str): Public DNS/IP address
    """
    backup_path = "~/.databrickslabs_jupyterlab/ssh_config_backup"

    config = os.path.expanduser("~/.ssh/config")
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(os.path.expanduser(backup_path), exist_ok=True)

    print_warning("   => ~/.ssh/config will be changed")
    # Only back up when there is something to back up: on a machine without
    # ~/.ssh/config the copy would raise FileNotFoundError, although the code
    # below is able to start from a fresh configuration.
    if os.path.exists(config):
        backup = "%s/config.%s" % (backup_path, time.strftime("%Y-%m-%d_%H-%M-%S"))
        shutil.copy(config, os.path.expanduser(backup))
        # Report the backup only after the copy has actually succeeded.
        print_warning(
            "   => A backup of the current ~/.ssh/config has been created")
        print_warning("   => at %s" % backup)

    try:
        sc = SSHConfig.load(config)
    except Exception:
        # Best effort: any load failure falls back to a fresh config object.
        # ``Exception`` (instead of a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        sc = SSHConfig(config)

    # ServerAliveInterval * ServerAliveCountMax = 48h
    keep_alive = {
        "ServerAliveInterval": 30,
        "ServerAliveCountMax": 5760,
        "ConnectTimeout": 5,
    }

    if any(h.name == cluster_id for h in sc.hosts()):
        host = sc.get(cluster_id)
        host.set("HostName", public_dns)
        for key, value in keep_alive.items():
            host.set(key, value)
        print("   => Added ssh config entry or modified IP address:\n")
        print(textwrap.indent(str(host), "      "))
    else:
        attrs = {
            "HostName": public_dns,
            "IdentityFile": "~/.ssh/id_%s" % profile,
            "Port": 2200,
            "User": "******",
        }
        attrs.update(keep_alive)
        host = Host(name=cluster_id, attrs=attrs)
        print("   => Adding ssh config to ~/.ssh/config:\n")
        print(textwrap.indent(str(host), "      "))
        sc.append(host)
    sc.write()

    add_known_host(public_dns)
Пример #3
0
import logging
import os
import sys
import unittest
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock

import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from ssh_config import SSHConfig, Host
from ssh_config.errors import EmptySSHConfig, WrongSSHConfig, HostExistsError

# Emit INFO-level log records while the tests run.
logging.basicConfig(level=logging.INFO)

# Path of the "sample" fixture located next to this test module.
sample = os.path.join(os.path.dirname(__file__), "sample")

# Host object shared by the tests; key order is kept as originally written.
new_host = Host(
    "server2",
    {"ServerAliveInterval": 200, "HostName": "203.0.113.77"},
)

# Text form of a "server2" entry (same name and values as new_host above).
new_data = (
    "Host server2\n"
    "    HostName 203.0.113.77\n"
    "    ServerAliveInterval 200\n"
)


class TestSSHConfig(unittest.TestCase):
    def test_load(self):
        """Check the name of the first host yielded from the sample config.

        Only the first entry produced by iterating ``SSHConfig(sample)`` is
        inspected; the loop exits right after that single assertion.
        """
        known_names = ["server1", "*"]
        for entry in SSHConfig(sample):
            self.assertIn(entry.name, known_names)
            break