示例#1
0
 def handle_event(self, sock, fd, event):
     """Route one poll event to the listening socket or to a relay handler.

     Events on ``self._server_socket`` accept a pending connection and pass
     the accepted socket to a new ``TCPRelayHandler``.  Events on any other
     truthy socket are forwarded to the handler registered for ``fd`` in
     ``self._fd_to_handlers``; an fd without a registered handler is
     ignored.  A falsy ``sock`` that is not the listening socket only
     produces a warning.

     Args:
         sock: socket object the event was reported for (may be falsy).
         fd: file descriptor number of ``sock``.
         event: bitmask of ``eventloop.POLL_*`` flags.

     Raises:
         Exception: if the listening socket reports ``POLL_ERR``.
     """
     if sock:
         logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd,
                     eventloop.EVENT_NAMES.get(event, event))
     if sock == self._server_socket:
         if event & eventloop.POLL_ERR:
             # TODO: recover from listening-socket errors instead of raising
             raise Exception('server_socket error')
         try:
             logging.debug('accept')
             conn = self._server_socket.accept()
             # The handler instance is not kept here; presumably it
             # registers itself in the shared fd map -- verify in
             # TCPRelayHandler.
             TCPRelayHandler(self, self._fd_to_handlers, self._eventloop,
                             conn[0], self._config, self._dns_resolver,
                             self._is_local)
         except (OSError, IOError) as e:
             error_no = eventloop.errno_from_exception(e)
             if error_no in (errno.EAGAIN, errno.EINPROGRESS,
                             errno.EWOULDBLOCK):
                 # nothing to accept right now; not an error
                 return
             shell.print_exception(e)
             if self._config['verbose']:
                 traceback.print_exc()
     elif sock:
         handler = self._fd_to_handlers.get(fd, None)
         if handler:
             handler.handle_event(sock, event)
     else:
         # logging.warn is a deprecated alias; use logging.warning
         logging.warning('poll removed fd')
示例#2
0
    def run(self):
        """Run the event loop until ``self._stopping`` becomes true.

        Each iteration polls for events (waiting up to
        ``TIMEOUT_PRECISION``), dispatches every event to the handler stored
        at index 1 of the ``self._fdmap`` entry for its fd, and then invokes
        the periodic callbacks when at least ``TIMEOUT_PRECISION`` seconds
        have passed since ``self._last_time`` or when the poll was
        interrupted by EPIPE/EINTR.

        ``OSError``/``IOError`` raised by a handler is printed and does not
        stop the loop; other exception types propagate to the caller.
        """
        while not self._stopping:
            asap = False
            try:
                events = self.poll(TIMEOUT_PRECISION)
            except (OSError, IOError) as e:
                # poll() produced nothing this round.  Clear the list so the
                # events returned by the previous iteration are not
                # dispatched a second time.
                events = []
                if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
                    # EPIPE: Happens when the client closes the connection
                    # EINTR: Happens when received a signal
                    # handles them as soon as possible
                    asap = True
                    logging.debug('poll:%s', e)
                else:
                    logging.error('poll:%s', e)
                    import traceback
                    traceback.print_exc()
                    continue

            for sock, fd, event in events:
                handler = self._fdmap.get(fd, None)
                if handler is not None:
                    handler = handler[1]
                    try:
                        handler.handle_event(sock, fd, event)
                    except (OSError, IOError) as e:
                        shell.print_exception(e)
            now = time.time()
            if asap or now - self._last_time >= TIMEOUT_PRECISION:
                for callback in self._periodic_callbacks:
                    callback()
                self._last_time = now
示例#3
0
def write_pid_file(pid_file, pid):
    """Create/lock ``pid_file`` and store ``pid`` in it.

    The file is opened read/write (created with owner read/write permission
    if missing), marked close-on-exec and locked exclusively without
    blocking.  On success the descriptor is not closed by this function
    (presumably so the lock stays held -- confirm with callers).

    Returns:
        0 when the pid was written, -1 when the file could not be opened or
        the lock could not be taken.
    """
    import fcntl
    import stat

    open_flags = os.O_RDWR | os.O_CREAT
    owner_rw = stat.S_IRUSR | stat.S_IWUSR
    try:
        fd = os.open(pid_file, open_flags, owner_rw)
    except OSError as e:
        shell.print_exception(e)
        return -1

    # make sure the descriptor is not inherited across exec()
    fd_flags = fcntl.fcntl(fd, fcntl.F_GETFD)
    assert fd_flags != -1
    rc = fcntl.fcntl(fd, fcntl.F_SETFD, fd_flags | fcntl.FD_CLOEXEC)
    assert rc != -1

    # There is no platform independent way to implement
    # fcntl(fd, F_SETLK, &fl) via fcntl.fcntl, so lockf is used instead.
    try:
        fcntl.lockf(fd, fcntl.LOCK_EX | fcntl.LOCK_NB, 0, 0, os.SEEK_SET)
    except IOError:
        # lock not obtained: report the pid recorded in the file, if any
        owner = os.read(fd, 32)
        if owner:
            logging.error('already started at pid %s' % common.to_str(owner))
        else:
            logging.error('already started')
        os.close(fd)
        return -1

    os.ftruncate(fd, 0)
    os.write(fd, common.to_bytes(str(pid)))
    return 0
示例#4
0
    def _on_remote_read(self):
        """Read from the remote socket and forward the data to the local one.

        A receive that fails with ETIMEDOUT, EAGAIN or EWOULDBLOCK returns
        without doing anything.  Any other receive error, or an empty read,
        destroys this handler.  Received data is decrypted when
        ``self._is_local`` is true and encrypted otherwise before being
        written to ``self._local_sock``; a failure while writing is printed
        and also destroys the handler.
        """
        try:
            received = self._remote_sock.recv(BUF_SIZE)
        except (OSError, IOError) as e:
            retry_later = (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK)
            if eventloop.errno_from_exception(e) in retry_later:
                return
            # any other error is handled like an empty read below
            received = None

        if not received:
            self.destroy()
            return

        self._update_activity(len(received))
        if self._is_local:
            transform = self._encryptor.decrypt
        else:
            transform = self._encryptor.encrypt
        payload = transform(received)

        try:
            self._write_to_sock(payload, self._local_sock)
        except Exception as e:
            shell.print_exception(e)
            if self._config['verbose']:
                traceback.print_exc()
            # TODO use logging when debug completed
            self.destroy()
示例#5
0
    def run_server():
        """Install signal handlers, then build and run the event loop.

        ``tcp_servers``, ``udp_servers``, ``dns_resolver`` and ``config`` are
        not defined in this block; they are presumably provided by the
        enclosing scope -- confirm against the surrounding function.

        SIGQUIT (or SIGTERM where the ``signal`` module has no ``SIGQUIT``)
        asks every server to close on the next tick; SIGINT exits with
        status 1.  Any exception raised while setting up or running the loop
        is printed and the process exits with status 1.
        """
        def child_handler(signum, _):
            # logging.warn is a deprecated alias; use logging.warning
            logging.warning('received SIGQUIT, doing graceful shutting down..')
            for server in tcp_servers + udp_servers:
                server.close(next_tick=True)

        signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM),
                      child_handler)

        def int_handler(signum, _):
            sys.exit(1)

        signal.signal(signal.SIGINT, int_handler)

        try:
            loop = eventloop.EventLoop()
            dns_resolver.add_to_loop(loop)
            for server in tcp_servers + udp_servers:
                server.add_to_loop(loop)

            daemon.set_user(config.get('user', None))
            loop.run()
        except Exception as e:
            shell.print_exception(e)
            sys.exit(1)
示例#6
0
 def _send_control_data(self, data):
     if self._control_client_addr:
         try:
             self._control_socket.sendto(data, self._control_client_addr)
         except (socket.error, OSError, IOError) as e:
             error_no = eventloop.errno_from_exception(e)
             if error_no in (errno.EAGAIN, errno.EINPROGRESS,
                             errno.EWOULDBLOCK):
                 return
             else:
                 shell.print_exception(e)
                 if self._config['verbose']:
                     traceback.print_exc()
示例#7
0
    def _handle_dns_resolved(self, result, error):
        if error:
            self._log_error(error)
            self.destroy()
            return
        if result:
            ip = result[1]
            if ip:

                try:
                    self._stage = STAGE_CONNECTING
                    remote_addr = ip
                    if self._is_local:
                        remote_port = self._chosen_server[1]
                    else:
                        remote_port = self._remote_address[1]

                    if self._is_local and self._config['fast_open']:
                        # for fastopen:
                        # wait for more data to arrive and send them in one SYN
                        self._stage = STAGE_CONNECTING
                        # we don't have to wait for remote since it's not
                        # created
                        self._update_stream(STREAM_UP, WAIT_STATUS_READING)
                        # TODO when there is already data in this packet
                    else:
                        # else do connect
                        remote_sock = self._create_remote_socket(
                            remote_addr, remote_port)
                        try:
                            remote_sock.connect((remote_addr, remote_port))
                        except (OSError, IOError) as e:
                            if eventloop.errno_from_exception(e) == \
                                    errno.EINPROGRESS:
                                pass
                        self._loop.add(remote_sock,
                                       eventloop.POLL_ERR | eventloop.POLL_OUT,
                                       self._server)
                        self._stage = STAGE_CONNECTING
                        self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
                        self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
                    return
                except Exception as e:
                    shell.print_exception(e)
                    if self._config['verbose']:
                        traceback.print_exc()
        self.destroy()
示例#8
0
def daemon_stop(pid_file):
    """Stop the daemon whose pid is recorded in ``pid_file``.

    Sends SIGTERM to the recorded pid, then polls for up to 10 seconds until
    the process is gone, prints ``stopped`` and removes the pid file.

    Returns normally (without exiting) whenever the daemon is known not to
    be running: the pid file does not exist, the pid file is empty, or the
    recorded process no longer exists.

    Raises:
        SystemExit: with status 1 when the pid file cannot be read, the
            recorded pid is not positive, the signal cannot be delivered,
            or the process is still alive after the timeout.
        ValueError: when the pid file content is not an integer.
    """
    import errno
    try:
        with open(pid_file) as f:
            buf = f.read()
    except IOError as e:
        shell.print_exception(e)
        if e.errno == errno.ENOENT:
            # always exit 0 if we are sure daemon is not running
            logging.error('not running')
            return
        sys.exit(1)

    if not buf:
        # An empty pid file holds no pid to signal; stop here instead of
        # failing in int('') below.
        logging.error('not running')
        return

    pid = int(common.to_str(buf))
    if pid <= 0:
        # os.kill() with a pid <= 0 addresses process groups rather than a
        # single process, so never signal or poll such a value.
        logging.error('pid is not positive: %d', pid)
        sys.exit(1)

    try:
        os.kill(pid, signal.SIGTERM)
    except OSError as e:
        if e.errno == errno.ESRCH:
            logging.error('not running')
            # always exit 0 if we are sure daemon is not running
            return
        shell.print_exception(e)
        sys.exit(1)

    # sleep for maximum 10s
    for i in range(0, 200):
        try:
            # query for the pid
            os.kill(pid, 0)
        except OSError as e:
            if e.errno == errno.ESRCH:
                break
        time.sleep(0.05)
    else:
        logging.error('timed out when stopping pid %d', pid)
        sys.exit(1)
    print('stopped')
    os.unlink(pid_file)
示例#9
0
def main():
    """Entry point: load the config, build the relays and run the loop.

    After ``daemon.daemon_exec(config)`` has been called, a DNS resolver, a
    TCP relay and a UDP relay are created in local mode (third constructor
    argument ``True``) and added to a new event loop.  SIGQUIT (or SIGTERM
    where the ``signal`` module has no ``SIGQUIT``) asks both relays to
    close on the next tick; SIGINT exits with status 1.  Any exception
    raised during setup or while the loop runs is printed and the process
    exits with status 1.
    """
    shell.check_python()

    # fix py2exe
    if hasattr(sys, "frozen") and sys.frozen in \
            ("windows_exe", "console_exe"):
        p = os.path.dirname(os.path.abspath(sys.executable))
        os.chdir(p)

    config = shell.get_config(True)

    daemon.daemon_exec(config)

    try:
        logging.info("starting local at %s:%d" %
                     (config['local_address'], config['local_port']))

        dns_resolver = asyncdns.DNSResolver()
        tcp_server = tcprelay.TCPRelay(config, dns_resolver, True)
        udp_server = udprelay.UDPRelay(config, dns_resolver, True)
        loop = eventloop.EventLoop()
        dns_resolver.add_to_loop(loop)
        tcp_server.add_to_loop(loop)
        udp_server.add_to_loop(loop)

        def handler(signum, _):
            # logging.warn is a deprecated alias; use logging.warning
            logging.warning('received SIGQUIT, doing graceful shutting down..')
            tcp_server.close(next_tick=True)
            udp_server.close(next_tick=True)

        signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler)

        def int_handler(signum, _):
            sys.exit(1)

        signal.signal(signal.SIGINT, int_handler)

        daemon.set_user(config.get('user', None))
        loop.run()
    except Exception as e:
        shell.print_exception(e)
        sys.exit(1)
示例#10
0
 def _write_to_sock(self, data, sock):
     # write data to sock
     # if only some of the data are written, put remaining in the buffer
     # and update the stream to wait for writing
     if not data or not sock:
         return False
     uncomplete = False
     try:
         l = len(data)
         s = sock.send(data)
         if s < l:
             data = data[s:]
             uncomplete = True
     except (OSError, IOError) as e:
         error_no = eventloop.errno_from_exception(e)
         if error_no in (errno.EAGAIN, errno.EINPROGRESS,
                         errno.EWOULDBLOCK):
             uncomplete = True
         else:
             shell.print_exception(e)
             self.destroy()
             return False
     if uncomplete:
         if sock == self._local_sock:
             self._data_to_write_to_local.append(data)
             self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
         elif sock == self._remote_sock:
             self._data_to_write_to_remote.append(data)
             self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
         else:
             logging.error('write_all_to_sock:unknown socket')
     else:
         if sock == self._local_sock:
             self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
         elif sock == self._remote_sock:
             self._update_stream(STREAM_UP, WAIT_STATUS_READING)
         else:
             logging.error('write_all_to_sock:unknown socket')
     return True
示例#11
0
def parse_response(data):
    """Parse a raw DNS response into a ``DNSResponse``.

    Returns ``None`` when ``data`` is shorter than 12 bytes, when
    ``parse_header`` yields nothing, or when any exception is raised while
    parsing (the exception is printed).  Otherwise returns a
    ``DNSResponse`` whose ``hostname`` is element 0 of the first question
    record (if any) and whose ``questions`` / ``answers`` hold elements 1-3
    of every parsed question / answer record.
    """
    try:
        if len(data) < 12:
            return None
        header = parse_header(data)
        if not header:
            return None
        # only the four record counts are used below
        _res_id, _res_qr, _res_tc, _res_ra, _res_rcode, qd_count, \
            an_count, ns_count, ar_count = header

        offset = 12
        questions = []
        for _ in range(0, qd_count):
            length, record = parse_record(data, offset, True)
            offset += length
            if record:
                questions.append(record)

        answers = []
        for _ in range(0, an_count):
            length, record = parse_record(data, offset)
            offset += length
            if record:
                answers.append(record)

        # The remaining two sections are still parsed (so malformed data
        # raises and is rejected) but their records are discarded.
        for section_count in (ns_count, ar_count):
            for _ in range(0, section_count):
                length, _record = parse_record(data, offset)
                offset += length

        response = DNSResponse()
        if questions:
            response.hostname = questions[0][0]
        for question in questions:
            response.questions.append(
                (question[1], question[2], question[3]))
        for answer in answers:
            response.answers.append((answer[1], answer[2], answer[3]))
        return response
    except Exception as e:
        shell.print_exception(e)
        return None
示例#12
0
def daemon_start(pid_file, log_file):
    """Fork into the background and record the child's pid.

    The parent installs a handler that exits with status 0 on SIGTERM and 1
    on any other handled signal, then sleeps for up to 5 seconds before
    exiting with status 0.  The child writes its pid with
    ``write_pid_file``; on failure it sends SIGINT to the parent and exits
    with status 1.  Otherwise it starts a new session, ignores SIGHUP,
    prints ``started``, sends SIGTERM to the parent, closes stdin and calls
    ``freopen`` for stdout and stderr with ``log_file`` (presumably
    redirecting them to the log -- see ``freopen``).

    Raises:
        SystemExit: in the parent always; in the child when the pid file
            cannot be written or ``freopen`` raises ``IOError``.
    """
    def handle_exit(signum, _):
        if signum == signal.SIGTERM:
            sys.exit(0)
        sys.exit(1)

    signal.signal(signal.SIGINT, handle_exit)
    signal.signal(signal.SIGTERM, handle_exit)

    # fork only once because we are sure parent will exit
    pid = os.fork()
    assert pid != -1

    if pid > 0:
        # parent waits for its child
        time.sleep(5)
        sys.exit(0)

    # child signals its parent to exit
    ppid = os.getppid()
    pid = os.getpid()
    if write_pid_file(pid_file, pid) != 0:
        os.kill(ppid, signal.SIGINT)
        sys.exit(1)

    os.setsid()
    # signal.signal() takes (signalnum, handler).  The previous call passed
    # them reversed and only worked because SIGHUP and SIG_IGN are both 1.
    signal.signal(signal.SIGHUP, signal.SIG_IGN)

    print('started')
    os.kill(ppid, signal.SIGTERM)

    sys.stdin.close()
    try:
        freopen(log_file, 'a', sys.stdout)
        freopen(log_file, 'a', sys.stderr)
    except IOError as e:
        shell.print_exception(e)
        sys.exit(1)
示例#13
0
 def _handle_stage_connecting(self, data):
     if self._is_local:
         data = self._encryptor.encrypt(data)
     self._data_to_write_to_remote.append(data)
     if self._is_local and not self._fastopen_connected and \
             self._config['fast_open']:
         # for sslocal and fastopen, we basically wait for data and use
         # sendto to connect
         try:
             # only connect once
             self._fastopen_connected = True
             remote_sock = \
                 self._create_remote_socket(self._chosen_server[0],
                                            self._chosen_server[1])
             self._loop.add(remote_sock, eventloop.POLL_ERR, self._server)
             data = b''.join(self._data_to_write_to_remote)
             l = len(data)
             s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server)
             if s < l:
                 data = data[s:]
                 self._data_to_write_to_remote = [data]
             else:
                 self._data_to_write_to_remote = []
             self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
         except (OSError, IOError) as e:
             if eventloop.errno_from_exception(e) == errno.EINPROGRESS:
                 # in this case data is not sent at all
                 self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
             elif eventloop.errno_from_exception(e) == errno.ENOTCONN:
                 logging.error('fast open not supported on this OS')
                 self._config['fast_open'] = False
                 self.destroy()
             else:
                 shell.print_exception(e)
                 if self._config['verbose']:
                     traceback.print_exc()
                 self.destroy()
示例#14
0
    def _handle_server(self):
        """Handle one datagram received on the listening UDP socket.

        Local side: byte index 2 of the packet is read as the fragment
        field; packets with a non-zero fragment are dropped and the first
        three bytes are stripped.  Server side: the packet is first passed
        through ``encrypt.encrypt_all(..., 0, data)`` to decrypt it.

        The address header is then parsed, the destination (or the chosen
        upstream server on the local side) is resolved with a per-host
        cache, and a non-blocking client socket keyed by the sender address
        and address family is created on first use and registered with the
        event loop.  Finally the payload is sent: encrypted as a whole on
        the local side, with the address header stripped on the server side.

        Empty datagrams and local packets shorter than three bytes are
        dropped.  ``IOError`` from the send is ignored for EINPROGRESS and
        EAGAIN and printed otherwise.
        """
        server = self._server_socket
        data, r_addr = server.recvfrom(BUF_SIZE)
        if not data:
            # nothing to relay; stop before indexing into the packet
            logging.debug('UDP handle_server: data is empty')
            return
        if self._stat_callback:
            self._stat_callback(self._listen_port, len(data))
        if self._is_local:
            if len(data) < 3:
                # too short to contain the fragment byte read below
                logging.warning('drop a message since it is too short')
                return
            frag = common.ord(data[2])
            if frag != 0:
                logging.warning('drop a message since frag is not 0')
                return
            data = data[3:]
        else:
            # decrypt data
            data = encrypt.encrypt_all(self._password, self._method, 0, data)
            if not data:
                logging.debug('UDP handle_server: data is empty after decrypt')
                return
        header_result = parse_header(data)
        if header_result is None:
            return
        addrtype, dest_addr, dest_port, header_length = header_result

        if self._is_local:
            server_addr, server_port = self._get_a_server()
        else:
            server_addr, server_port = dest_addr, dest_port

        # NOTE(review): the cache is keyed by host only, so the port inside
        # a cached sockaddr may differ from server_port; the send below
        # therefore uses (server_addr, server_port), not the cached sockaddr.
        addrs = self._dns_cache.get(server_addr, None)
        if addrs is None:
            addrs = socket.getaddrinfo(server_addr, server_port, 0,
                                       socket.SOCK_DGRAM, socket.SOL_UDP)
            if not addrs:
                # drop
                return
            self._dns_cache[server_addr] = addrs

        af, socktype, proto, canonname, sa = addrs[0]
        key = client_key(r_addr, af)
        client = self._cache.get(key, None)
        if not client:
            # TODO async getaddrinfo
            if self._forbidden_iplist:
                if common.to_str(sa[0]) in self._forbidden_iplist:
                    logging.debug('IP %s is in forbidden list, drop' %
                                  common.to_str(sa[0]))
                    # drop
                    return
            client = socket.socket(af, socktype, proto)
            client.setblocking(False)
            self._cache[key] = client
            # remember which sender this client socket belongs to
            self._client_fd_to_server_addr[client.fileno()] = r_addr

            self._sockets.add(client.fileno())
            self._eventloop.add(client, eventloop.POLL_IN, self)

        if self._is_local:
            data = encrypt.encrypt_all(self._password, self._method, 1, data)
            if not data:
                return
        else:
            data = data[header_length:]
        if not data:
            return
        try:
            client.sendto(data, (server_addr, server_port))
        except IOError as e:
            err = eventloop.errno_from_exception(e)
            if err in (errno.EINPROGRESS, errno.EAGAIN):
                pass
            else:
                shell.print_exception(e)