Exemplo n.º 1
0
def fork(shm_address: str,
         all_non_child_fds: Iterable[int]) -> Tuple[int, int]:
    """Fork a child process described by a command stored in shared memory.

    The segment named ``shm_address`` holds a size-prefixed JSON command.
    If the command has a non-zero ``stdin_size``, that many bytes of stdin
    data follow the command in the same segment. In that case the segment is
    kept (not unlinked) until the child has mapped it.

    In the parent, this blocks until the child closes its end of a
    synchronisation pipe, which it does after ``setsid()`` and controlling-tty
    setup. It then returns ``(child_pid, ready_fd_write)``. The child passes
    the matching ``ready_fd_read`` to ``child_main``. Presumably the caller
    uses ``ready_fd_write`` to release the child; confirm against
    ``child_main``.

    In the child, this removes signal handlers, closes every fd in
    ``all_non_child_fds`` and runs ``child_main(cmd, ready_fd_read)``. When
    there is non-empty stdin data, ``sys.stdin`` is temporarily replaced by a
    reader over that data.

    :param shm_address: name of the shared memory segment with the command.
    :param all_non_child_fds: fds that must not be inherited by the child.
    :returns: in the parent, ``(child_pid, ready_fd_write)``.
    :raises OSError: if ``os.fork()`` fails. The pipe fds created here and
        the shared memory segment are released before re-raising.
    """
    sz = pos = 0
    with SharedMemory(name=shm_address, unlink_on_exit=True) as shm:
        data = shm.read_data_with_size()
        cmd = json.loads(data)
        sz = cmd.get('stdin_size', 0)
        if sz:
            # The stdin bytes start right after the command; keep the segment
            # alive so the child can map it after the fork.
            pos = shm.tell()
            shm.unlink_on_exit = False

    r, w = safe_pipe()
    ready_fd_read, ready_fd_write = safe_pipe()
    try:
        child_pid = os.fork()
    except OSError:
        # No child will ever use the pipes or the stdin data, so release them
        # and propagate the error. Otherwise execution would fall through
        # with child_pid unbound.
        for fd in (r, w, ready_fd_read, ready_fd_write):
            os.close(fd)
        if sz:
            with SharedMemory(shm_address, unlink_on_exit=True):
                pass
        raise
    if child_pid:
        # master process
        os.close(w)
        os.close(ready_fd_read)
        # Wait until the child closes its copy of w (after its tty setup).
        poll = select.poll()
        poll.register(r, select.POLLIN)
        tuple(poll.poll())
        os.close(r)
        return child_pid, ready_fd_write
    # child process
    remove_signal_handlers()
    os.close(r)
    os.close(ready_fd_write)
    for fd in all_non_child_fds:
        os.close(fd)
    os.setsid()
    tty_name = cmd.get('tty_name')
    if tty_name:
        from kitty.fast_data_types import establish_controlling_tty
        # Flush buffered Python output before the std fds are re-pointed.
        sys.__stdout__.flush()
        sys.__stderr__.flush()
        establish_controlling_tty(tty_name, sys.__stdin__.fileno(),
                                  sys.__stdout__.fileno(),
                                  sys.__stderr__.fileno())
    # Closing the last write end wakes the parent's poll() above.
    os.close(w)
    if shm.unlink_on_exit:
        # No stdin data: the segment was already released when the first
        # with-block exited.
        child_main(cmd, ready_fd_read)
    else:
        with SharedMemory(shm_address, unlink_on_exit=True) as shm:
            stdin_data = memoryview(shm.mmap)[pos:pos + sz]
            if stdin_data:
                sys.stdin = MemoryViewReadWrapper(stdin_data)
            try:
                child_main(cmd, ready_fd_read)
            finally:
                stdin_data.release()
                sys.stdin = sys.__stdin__
Exemplo n.º 2
0
def read_data_from_shared_memory(shm_name: str) -> Any:
    """Consume a one-shot JSON payload from the shared memory segment *shm_name*.

    The segment's name is removed immediately after opening, so it cannot be
    opened again. The payload is only trusted if the segment is owned by the
    current effective uid/gid and its permission bits are exactly
    owner-read-only.

    :raises ValueError: if the ownership or permission checks fail.
    """
    with SharedMemory(shm_name, readonly=True) as segment:
        segment.unlink()
        owned_by_us = (segment.stats.st_uid == os.geteuid()
                       and segment.stats.st_gid == os.getegid())
        if not owned_by_us:
            raise ValueError('Incorrect owner on pwfile')
        if stat.S_IMODE(segment.stats.st_mode) != stat.S_IREAD:
            raise ValueError('Incorrect permissions on pwfile')
        payload = segment.read_data_with_size()
        return json.loads(payload)
Exemplo n.º 3
0
def bootstrap_script(
        ssh_opts: SSHOptions,
        script_type: str = 'sh',
        remote_args: Sequence[str] = (),
        test_script: str = '',
        request_id: Optional[str] = None,
        cli_hostname: str = '',
        cli_uname: str = '',
        request_data: bool = False,
        echo_on: bool = True) -> Tuple[str, Dict[str, str], SharedMemory]:
    """Build the remote bootstrap script and its data payload.

    Reads ``bootstrap.<script_type>`` from the ``ssh`` directory under
    ``shell_integration_dir``. It stores a JSON payload in a shared memory
    segment created with owner-read-only mode, and registers the segment to
    be unlinked at interpreter exit. The payload holds a fresh random
    password, the ssh options, the CLI hostname and uname, and a
    base64-encoded tarfile.

    :param script_type: template suffix, e.g. ``'sh'`` or ``'py'``.
        ``'py'`` selects Python-flavoured export/exec commands. The tarfile
        is gz-compressed for ``'sh'`` and bz2-compressed otherwise.
    :param remote_args: if non-empty, used to build ``EXEC_CMD``.
    :param request_id: defaults to ``"$KITTY_PID-$KITTY_WINDOW_ID"``; raises
        ``KeyError`` if those variables are missing.
    :param request_data: if true, the sensitive values (request id, password,
        password file name) are substituted into the script itself.
    :returns: ``(script, replacements, shm)``. ``script`` is the text
        rendered by ``prepare_script``. ``replacements`` is the full map,
        always including the sensitive keys. ``shm`` is the segment holding
        the payload.
    """
    if request_id is None:
        request_id = os.environ['KITTY_PID'] + '-' + os.environ[
            'KITTY_WINDOW_ID']
    is_python = script_type == 'py'
    export_home_cmd = prepare_export_home_cmd(
        ssh_opts, is_python) if 'HOME' in ssh_opts.env else ''
    exec_cmd = prepare_exec_cmd(remote_args, is_python) if remote_args else ''
    # Read the template as UTF-8 explicitly so the result does not depend on
    # the user's locale.
    with open(
            os.path.join(shell_integration_dir, 'ssh',
                         f'bootstrap.{script_type}'),
            encoding='utf-8') as f:
        ans = f.read()
    pw = secrets.token_hex()
    tfd = standard_b64encode(
        make_tarfile(ssh_opts, dict(os.environ),
                     'gz' if script_type == 'sh' else 'bz2')).decode('ascii')
    data = {
        'pw': pw,
        'opts': ssh_opts._asdict(),
        'hostname': cli_hostname,
        'uname': cli_uname,
        'tarfile': tfd
    }
    # json.dumps escapes non-ASCII by default, so len(db) is the byte length.
    db = json.dumps(data)
    with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size,
                      mode=stat.S_IREAD,
                      prefix=f'kssh-{os.getpid()}-') as shm:
        shm.write_data_with_size(db)
        shm.flush()
        atexit.register(shm.unlink)
    sensitive_data = {
        'REQUEST_ID': request_id,
        'DATA_PASSWORD': pw,
        'PASSWORD_FILENAME': shm.name
    }
    replacements = {
        'EXPORT_HOME_CMD': export_home_cmd,
        'EXEC_CMD': exec_cmd,
        'TEST_SCRIPT': test_script,
        'REQUEST_DATA': '1' if request_data else '0',
        'ECHO_ON': '1' if echo_on else '0',
    }
    # sd: what actually goes into the script text; secrets only on request.
    sd = replacements.copy()
    if request_data:
        sd.update(sensitive_data)
    replacements.update(sensitive_data)
    return prepare_script(ans, sd, script_type), replacements, shm
Exemplo n.º 4
0
 def __call__(
     self,
     tty_fd: int,
     argv: List[str],
     cwd: str = '',
     env: Optional[Dict[str, str]] = None,
     stdin_data: Optional[Union[str, bytes]] = None,
     timeout: float = TIMEOUT,
 ) -> Child:
     """Ask the prewarm process to fork a child running ``argv``.

     The command is written to a fresh shared memory segment as size-prefixed
     JSON with keys ``tty_name``, ``cwd``, ``argv``, ``env`` and, when given,
     ``stdin_size``. Any stdin bytes are written immediately after the
     command. The segment name is then sent to the prewarm process as a
     ``fork:<name>`` line.

     :param tty_fd: fd of a terminal; its name is sent as ``tty_name``.
     :param argv: command line for the child.
     :param cwd: working directory; defaults to ``os.getcwd()``.
     :param env: environment; defaults to a copy of ``os.environ``.
     :param stdin_data: optional stdin contents; ``str`` is UTF-8 encoded.
     :param timeout: seconds to wait for the prewarm process's reply.
     :returns: the ``Child`` registered via ``self.add_child`` from the
         ``CHILD:<cid>:<pid>`` reply line.
     :raises PrewarmProcessFailed: on a poll error event, an ``ERR:`` reply,
         EOF from the prewarm process, or when ``timeout`` elapses.
     """
     tty_name = os.ttyname(tty_fd)
     if isinstance(stdin_data, str):
         stdin_data = stdin_data.encode()
     if env is None:
         env = dict(os.environ)
     cmd: Dict[str, Union[int, List[str], str, Dict[str, str]]] = {
         'tty_name': tty_name,
         'cwd': cwd or os.getcwd(),
         'argv': argv,
         'env': env,
     }
     total_size = 0
     if stdin_data is not None:
         cmd['stdin_size'] = len(stdin_data)
         total_size += len(stdin_data)
     data = json.dumps(cmd).encode()
     total_size += len(data) + SharedMemory.num_bytes_for_size
     with SharedMemory(size=total_size, unlink_on_exit=True) as shm:
         shm.write_data_with_size(data)
         if stdin_data:
             shm.write(stdin_data)
         shm.flush()
         self.send_to_prewarm_process(f'fork:{shm.name}\n')
         input_buf = b''
         st = time.monotonic()
         while time.monotonic() - st < timeout:
             for (fd, event) in self.poll.poll(2):
                 if event & error_events:
                     raise PrewarmProcessFailed(
                         'Failed doing I/O with prewarm process')
                 if fd == self.read_from_process_fd and event & select.POLLIN:
                     d = os.read(self.read_from_process_fd,
                                 io.DEFAULT_BUFFER_SIZE)
                     if not d:
                         # EOF: no reply can arrive any more; fail now rather
                         # than spinning until the timeout.
                         raise PrewarmProcessFailed(
                             'Prewarm process closed its output unexpectedly')
                     input_buf += d
                     while (idx := input_buf.find(b'\n')) > -1:
                         line = input_buf[:idx].decode()
                         input_buf = input_buf[idx + 1:]
                         if line.startswith('CHILD:'):
                             _, cid, pid = line.split(':')
                             child = self.add_child(int(cid), int(pid))
                             # Keep the segment: presumably the prewarm side
                             # now owns it and unlinks it after reading.
                             shm.unlink_on_exit = False
                             return child
                         if line.startswith('ERR:'):
                             raise PrewarmProcessFailed(
                                 line.split(':', 1)[-1])
         # Falling off the loop previously returned None despite "-> Child".
         raise PrewarmProcessFailed(
             'Timed out waiting for I/O with prewarm process')
Exemplo n.º 5
0
def create_shared_memory(data: Any, prefix: str) -> str:
    """Store *data* as size-prefixed UTF-8 JSON in a new owner-read-only segment.

    The segment's name starts with *prefix*. It is registered to be unlinked
    when the interpreter exits.

    :returns: the name of the created shared memory segment.
    """
    import atexit
    import json
    import stat

    from kitty.shm import SharedMemory

    encoded = json.dumps(data).encode('utf-8')
    needed = len(encoded) + SharedMemory.num_bytes_for_size
    with SharedMemory(size=needed, mode=stat.S_IREAD, prefix=prefix) as segment:
        segment.write_data_with_size(encoded)
        segment.flush()
        atexit.register(segment.unlink)
    return segment.name
Exemplo n.º 6
0
def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]:
    """Yield the response to an ssh data request, chunk by chunk.

    *msg* is untrusted: base64-encoded UTF-8 text of ``key=value`` pairs
    separated by ``:``. The keys ``pw``, ``pwfile`` and ``id`` are required.
    ``pwfile`` names a shared memory segment. It is unlinked as soon as it is
    opened, then checked before its JSON payload is trusted: it must be owned
    by the effective uid/gid and have mode exactly owner-read-only.

    Output is a start marker, then one of:

    * a single error line; or
    * ``OK``, the payload's base64 tarfile in chunks of at most 254 bytes
      (each followed by a newline), and an end marker.

    :param msg: the encoded request message.
    :param request_id: the request id this handler expects.
    """
    from hmac import compare_digest

    yield b'\nKITTY_DATA_START\n'  # to discard leading data
    try:
        msg = standard_b64decode(msg).decode('utf-8')
        md = dict(x.split('=', 1) for x in msg.split(':'))
        pw = md['pw']
        pwfilename = md['pwfile']
        rq_id = md['id']
    except Exception:
        traceback.print_exc()
        yield b'invalid ssh data request message\n'
    else:
        try:
            with SharedMemory(pwfilename, readonly=True) as shm:
                # Unlink first so the segment is consumed even if a check
                # below fails.
                shm.unlink()
                if shm.stats.st_uid != os.geteuid(
                ) or shm.stats.st_gid != os.getegid():
                    raise ValueError('Incorrect owner on pwfile')
                mode = stat.S_IMODE(shm.stats.st_mode)
                if mode != stat.S_IREAD:
                    raise ValueError('Incorrect permissions on pwfile')
                env_data = json.loads(shm.read_data_with_size())
                # Constant-time comparison: pw comes from untrusted input, and
                # a plain != can leak how much of the password matched.
                if not compare_digest(pw.encode('utf-8'),
                                      env_data['pw'].encode('utf-8')):
                    raise ValueError('Incorrect password')
                if rq_id != request_id:
                    raise ValueError('Incorrect request id')
        except Exception as e:
            traceback.print_exc()
            yield f'{e}\n'.encode('utf-8')
        else:
            yield b'OK\n'
            ssh_opts = SSHOptions(env_data['opts'])
            ssh_opts.copy = {
                k: CopyInstruction(*v)
                for k, v in ssh_opts.copy.items()
            }
            encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
            # macOS has a 255 byte limit on its input queue as per man stty.
            # Not clear if that applies to canonical mode input as well, but
            # better to be safe.
            line_sz = 254
            while encoded_data:
                yield encoded_data[:line_sz]
                yield b'\n'
                encoded_data = encoded_data[line_sz:]
            yield b'KITTY_DATA_END\n'
Exemplo n.º 7
0
import time

from kitty.shm import SharedMemory

# Ask kitty to show ssh's password/confirmation prompt. The prompt text is
# taken from the last command-line argument (presumably ssh invokes this
# program as its askpass helper with the prompt as the argument).
msg = sys.argv[-1]
# NOTE(review): per OpenSSH's ssh(1), SSH_ASKPASS_PROMPT=confirm requests a
# yes/no confirmation rather than a line of input; confirm.
prompt = os.environ.get('SSH_ASKPASS_PROMPT', '')
is_confirm = prompt == 'confirm'
# Host-key fingerprint questions are not secret, so their answer is not
# marked as a password.
is_fingerprint_check = '(yes/no/[fingerprint])' in msg
q = {
    'message': msg,
    'type': 'confirm' if is_confirm else 'get_line',
    'is_password': not is_fingerprint_check,
}

data = json.dumps(q)
# Shared memory layout:
# - byte 0 is a status flag, written as 0 here;
# - after it comes the size-prefixed JSON request (the "+ 1" in the size
#   accounts for the flag byte);
# - the responder is expected to set the flag to 1 once its size-prefixed
#   reply has been written after the flag byte.
# NOTE(review): the outer ``tty`` handle is never used in the visible code;
# the escape sequence below is written through a second handle ``f``.
with SharedMemory(
    size=len(data) + 1 + SharedMemory.num_bytes_for_size, unlink_on_exit=True, prefix=f'askpass-{os.getpid()}-') as shm, \
        open(os.ctermid(), 'wb') as tty:
    shm.write(b'\0')
    shm.write_data_with_size(data)
    shm.flush()
    # Send the segment name to the terminal in a DCS escape sequence
    # (ESC P @kitty-ask|<name> ESC \).
    with open(os.ctermid(), 'wb') as f:
        f.write(f'\x1bP@kitty-ask|{shm.name}\x1b\\'.encode('ascii'))
        f.flush()
    # Poll the flag byte every 50 ms. There is no timeout, so this waits
    # forever if no reply arrives.
    while True:
        # TODO: Replace sleep() with a mutex and condition variable created in the shared memory
        time.sleep(0.05)
        shm.seek(0)
        if shm.read(1) == b'\x01':
            break
    # read(1) left the position just past the flag byte, where the reply
    # starts.
    response = json.loads(shm.read_data_with_size())
if is_confirm: