def handle_event(self, sock, fd, event):
    """Dispatch one polled event to the listening socket or a relay handler.

    Args:
        sock: socket the event was reported for; when falsy (and different
            from the listening socket) only a warning is logged.
        fd: file descriptor, used as the key into ``self._fd_to_handlers``.
        event: event bitmask; tested against ``eventloop.POLL_ERR``.

    Raises:
        Exception: when the listening socket reports ``POLL_ERR``.
    """
    # handle events and dispatch to handlers
    if sock:
        logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd,
                    eventloop.EVENT_NAMES.get(event, event))
    if sock == self._server_socket:
        if event & eventloop.POLL_ERR:
            # TODO
            raise Exception('server_socket error')
        try:
            logging.debug('accept')
            conn = self._server_socket.accept()
            # The handler object is not kept here; presumably it registers
            # itself in the fd-to-handler map it is given -- TODO confirm.
            TCPRelayHandler(self, self._fd_to_handlers,
                            self._eventloop, conn[0], self._config,
                            self._dns_resolver, self._is_local)
        except (OSError, IOError) as e:
            error_no = eventloop.errno_from_exception(e)
            if error_no in (errno.EAGAIN, errno.EINPROGRESS,
                            errno.EWOULDBLOCK):
                # nothing to accept right now
                return
            shell.print_exception(e)
            if self._config['verbose']:
                traceback.print_exc()
    elif sock:
        handler = self._fd_to_handlers.get(fd, None)
        if handler:
            handler.handle_event(sock, event)
    else:
        # logging.warn is a deprecated alias; use logging.warning
        logging.warning('poll removed fd')
def run(self):
    """Poll and dispatch events until ``self._stopping`` becomes true.

    Each iteration polls with ``TIMEOUT_PRECISION``, hands every returned
    ``(sock, fd, event)`` to the handler stored at index 1 of the
    ``self._fdmap`` entry for ``fd``, and runs the periodic callbacks once
    at least ``TIMEOUT_PRECISION`` has elapsed since the last run (or
    immediately after an EPIPE/EINTR poll failure).
    """
    events = []
    while not self._stopping:
        asap = False
        try:
            events = self.poll(TIMEOUT_PRECISION)
        except (OSError, IOError) as e:
            if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
                # EPIPE: Happens when the client closes the connection
                # EINTR: Happens when received a signal
                # handles them as soon as possible
                asap = True
                logging.debug('poll:%s', e)
                # poll() returned nothing this round; without this reset
                # the previous iteration's events would be dispatched again
                events = []
            else:
                logging.error('poll:%s', e)
                import traceback
                traceback.print_exc()
                continue

        for sock, fd, event in events:
            handler = self._fdmap.get(fd, None)
            if handler is not None:
                handler = handler[1]
                try:
                    handler.handle_event(sock, fd, event)
                except (OSError, IOError) as e:
                    shell.print_exception(e)
        now = time.time()
        if asap or now - self._last_time >= TIMEOUT_PRECISION:
            for callback in self._periodic_callbacks:
                callback()
            self._last_time = now
def write_pid_file(pid_file, pid):
    """Create/lock *pid_file* and store *pid* in it.

    Returns 0 on success and -1 when the file cannot be opened or is
    already locked. On success the descriptor is left open, so the lock
    taken here stays held.
    """
    import fcntl
    import stat

    open_flags = os.O_RDWR | os.O_CREAT
    owner_rw = stat.S_IRUSR | stat.S_IWUSR
    try:
        fd = os.open(pid_file, open_flags, owner_rw)
    except OSError as e:
        shell.print_exception(e)
        return -1

    # mark the descriptor close-on-exec
    fd_flags = fcntl.fcntl(fd, fcntl.F_GETFD)
    assert fd_flags != -1
    fd_flags |= fcntl.FD_CLOEXEC
    set_result = fcntl.fcntl(fd, fcntl.F_SETFD, fd_flags)
    assert set_result != -1

    # There is no platform independent way to implement fcntl(fd, F_SETLK, &fl)
    # via fcntl.fcntl. So use lockf instead
    try:
        fcntl.lockf(fd, fcntl.LOCK_EX | fcntl.LOCK_NB, 0, 0, os.SEEK_SET)
    except IOError:
        # lock not acquired: report whatever pid the file currently holds
        existing = os.read(fd, 32)
        if not existing:
            logging.error('already started')
        else:
            logging.error('already started at pid %s' %
                          common.to_str(existing))
        os.close(fd)
        return -1

    os.ftruncate(fd, 0)
    os.write(fd, common.to_bytes(str(pid)))
    return 0
def _on_remote_read(self):
    """Read from the remote socket, transform the data, and forward it.

    Data is decrypted when ``self._is_local`` is true and encrypted
    otherwise, then written to the local socket. An empty or failed read
    destroys the handler, except for ETIMEDOUT/EAGAIN/EWOULDBLOCK which
    simply return.
    """
    # handle all remote read events
    received = None
    try:
        received = self._remote_sock.recv(BUF_SIZE)
    except (OSError, IOError) as e:
        if eventloop.errno_from_exception(e) in \
                (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK):
            return
        # any other error leaves `received` as None and falls through
    if not received:
        self.destroy()
        return
    self._update_activity(len(received))
    transform = (self._encryptor.decrypt if self._is_local
                 else self._encryptor.encrypt)
    payload = transform(received)
    try:
        self._write_to_sock(payload, self._local_sock)
    except Exception as e:
        shell.print_exception(e)
        if self._config['verbose']:
            traceback.print_exc()
        # TODO use logging when debug completed
        self.destroy()
def run_server():
    """Install signal handlers, attach all servers to a new loop and run it.

    ``tcp_servers``, ``udp_servers``, ``dns_resolver`` and ``config`` are
    not defined in this block; they are read from the surrounding scope.
    Any exception raised during setup or while the loop runs is printed
    and the process exits with status 1.
    """
    def child_handler(signum, _):
        # logging.warn is a deprecated alias; use logging.warning
        logging.warning('received SIGQUIT, doing graceful shutting down..')
        for server in tcp_servers + udp_servers:
            server.close(next_tick=True)
    # falls back to SIGTERM where the signal module has no SIGQUIT
    signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM),
                  child_handler)

    def int_handler(signum, _):
        sys.exit(1)
    signal.signal(signal.SIGINT, int_handler)

    try:
        loop = eventloop.EventLoop()
        dns_resolver.add_to_loop(loop)
        for server in tcp_servers + udp_servers:
            server.add_to_loop(loop)

        daemon.set_user(config.get('user', None))
        loop.run()
    except Exception as e:
        shell.print_exception(e)
        sys.exit(1)
def _send_control_data(self, data): if self._control_client_addr: try: self._control_socket.sendto(data, self._control_client_addr) except (socket.error, OSError, IOError) as e: error_no = eventloop.errno_from_exception(e) if error_no in (errno.EAGAIN, errno.EINPROGRESS, errno.EWOULDBLOCK): return else: shell.print_exception(e) if self._config['verbose']: traceback.print_exc()
def _handle_dns_resolved(self, result, error): if error: self._log_error(error) self.destroy() return if result: ip = result[1] if ip: try: self._stage = STAGE_CONNECTING remote_addr = ip if self._is_local: remote_port = self._chosen_server[1] else: remote_port = self._remote_address[1] if self._is_local and self._config['fast_open']: # for fastopen: # wait for more data to arrive and send them in one SYN self._stage = STAGE_CONNECTING # we don't have to wait for remote since it's not # created self._update_stream(STREAM_UP, WAIT_STATUS_READING) # TODO when there is already data in this packet else: # else do connect remote_sock = self._create_remote_socket( remote_addr, remote_port) try: remote_sock.connect((remote_addr, remote_port)) except (OSError, IOError) as e: if eventloop.errno_from_exception(e) == \ errno.EINPROGRESS: pass self._loop.add(remote_sock, eventloop.POLL_ERR | eventloop.POLL_OUT, self._server) self._stage = STAGE_CONNECTING self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) return except Exception as e: shell.print_exception(e) if self._config['verbose']: traceback.print_exc() self.destroy()
def daemon_stop(pid_file):
    """Stop the daemon whose pid is recorded in *pid_file*.

    Sends SIGTERM to the recorded pid, waits up to 10 seconds for the
    process to disappear, then prints 'stopped' and removes the pid file.

    Returns normally (without exiting) when the daemon is known not to be
    running: the pid file is missing or empty, or the pid does not exist.
    Calls ``sys.exit(1)`` when the pid file is unreadable or holds an
    invalid pid, when signalling fails, or when the process does not exit
    in time.
    """
    import errno
    try:
        with open(pid_file) as f:
            buf = f.read()
    except IOError as e:
        shell.print_exception(e)
        if e.errno == errno.ENOENT:
            # always exit 0 if we are sure daemon is not running
            logging.error('not running')
            return
        sys.exit(1)

    if not buf.strip():
        # An empty pid file used to fall through to int('') and crash
        # with ValueError; treat it as "not running" instead.
        logging.error('not running')
        return
    try:
        pid = int(common.to_str(buf))
    except ValueError:
        logging.error('invalid pid file content: %r', buf)
        sys.exit(1)
    if pid <= 0:
        # Never poll os.kill() with a non-positive pid: 0 and negative
        # values address process groups, not the daemon.
        logging.error('pid is not positive: %d', pid)
        sys.exit(1)

    try:
        os.kill(pid, signal.SIGTERM)
    except OSError as e:
        if e.errno == errno.ESRCH:
            logging.error('not running')
            # always exit 0 if we are sure daemon is not running
            return
        shell.print_exception(e)
        sys.exit(1)

    # sleep for maximum 10s
    for _ in range(200):
        try:
            # query for the pid
            os.kill(pid, 0)
        except OSError as e:
            if e.errno == errno.ESRCH:
                break
        time.sleep(0.05)
    else:
        logging.error('timed out when stopping pid %d', pid)
        sys.exit(1)
    print('stopped')
    os.unlink(pid_file)
def main():
    """Entry point: load the config, build the relays and run the loop.

    Creates a DNS resolver plus one TCP and one UDP relay (each
    constructed with ``True`` as the last argument), registers them on a
    new event loop, installs SIGQUIT/SIGINT handlers and runs the loop.
    Any exception raised during that setup or run is printed and the
    process exits with status 1.
    """
    shell.check_python()

    # fix py2exe
    if hasattr(sys, "frozen") and sys.frozen in \
            ("windows_exe", "console_exe"):
        p = os.path.dirname(os.path.abspath(sys.executable))
        os.chdir(p)

    config = shell.get_config(True)

    daemon.daemon_exec(config)

    try:
        logging.info("starting local at %s:%d" %
                     (config['local_address'], config['local_port']))

        dns_resolver = asyncdns.DNSResolver()
        tcp_server = tcprelay.TCPRelay(config, dns_resolver, True)
        udp_server = udprelay.UDPRelay(config, dns_resolver, True)
        loop = eventloop.EventLoop()
        dns_resolver.add_to_loop(loop)
        tcp_server.add_to_loop(loop)
        udp_server.add_to_loop(loop)

        def handler(signum, _):
            # logging.warn is a deprecated alias; use logging.warning
            logging.warning('received SIGQUIT, doing graceful shutting down..')
            tcp_server.close(next_tick=True)
            udp_server.close(next_tick=True)
        # falls back to SIGTERM where the signal module has no SIGQUIT
        signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler)

        def int_handler(signum, _):
            sys.exit(1)
        signal.signal(signal.SIGINT, int_handler)

        daemon.set_user(config.get('user', None))
        loop.run()
    except Exception as e:
        shell.print_exception(e)
        sys.exit(1)
def _write_to_sock(self, data, sock): # write data to sock # if only some of the data are written, put remaining in the buffer # and update the stream to wait for writing if not data or not sock: return False uncomplete = False try: l = len(data) s = sock.send(data) if s < l: data = data[s:] uncomplete = True except (OSError, IOError) as e: error_no = eventloop.errno_from_exception(e) if error_no in (errno.EAGAIN, errno.EINPROGRESS, errno.EWOULDBLOCK): uncomplete = True else: shell.print_exception(e) self.destroy() return False if uncomplete: if sock == self._local_sock: self._data_to_write_to_local.append(data) self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) elif sock == self._remote_sock: self._data_to_write_to_remote.append(data) self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) else: logging.error('write_all_to_sock:unknown socket') else: if sock == self._local_sock: self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) elif sock == self._remote_sock: self._update_stream(STREAM_UP, WAIT_STATUS_READING) else: logging.error('write_all_to_sock:unknown socket') return True
def parse_response(data):
    """Parse *data* into a ``DNSResponse``.

    Returns None when *data* is shorter than 12 bytes, when the header
    cannot be parsed, or when any exception is raised while parsing (the
    exception is printed). Otherwise returns a response whose ``hostname``
    is taken from the first question record and whose ``questions`` and
    ``answers`` hold items 1..3 of each parsed record.
    """
    try:
        if len(data) < 12:
            return None
        header = parse_header(data)
        if not header:
            return None
        res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \
            res_ancount, res_nscount, res_arcount = header

        questions = []
        answers = []
        offset = 12  # records start right after the 12 bytes checked above
        for _ in range(0, res_qdcount):
            consumed, record = parse_record(data, offset, True)
            offset += consumed
            if record:
                questions.append(record)
        for _ in range(0, res_ancount):
            consumed, record = parse_record(data, offset)
            offset += consumed
            if record:
                answers.append(record)
        # the remaining two sections are parsed only to advance the
        # offset; their records are discarded
        for _ in range(0, res_nscount):
            consumed, record = parse_record(data, offset)
            offset += consumed
        for _ in range(0, res_arcount):
            consumed, record = parse_record(data, offset)
            offset += consumed

        response = DNSResponse()
        if questions:
            response.hostname = questions[0][0]
        for record in questions:
            response.questions.append((record[1], record[2], record[3]))
        for record in answers:
            response.answers.append((record[1], record[2], record[3]))
        return response
    except Exception as e:
        shell.print_exception(e)
        return None
def daemon_start(pid_file, log_file):
    """Fork into the background and redirect stdout/stderr to *log_file*.

    The parent sleeps until the child signals it: SIGTERM makes it exit
    with status 0, SIGINT with status 1. The child writes its pid to
    *pid_file*, starts a new session, ignores SIGHUP, closes stdin and
    reopens stdout/stderr on *log_file* in append mode.
    """
    def handle_exit(signum, _):
        if signum == signal.SIGTERM:
            sys.exit(0)
        sys.exit(1)

    signal.signal(signal.SIGINT, handle_exit)
    signal.signal(signal.SIGTERM, handle_exit)

    # fork only once because we are sure parent will exit
    pid = os.fork()
    assert pid != -1

    if pid > 0:
        # parent waits for its child
        time.sleep(5)
        sys.exit(0)

    # child signals its parent to exit
    ppid = os.getppid()
    pid = os.getpid()
    if write_pid_file(pid_file, pid) != 0:
        os.kill(ppid, signal.SIGINT)
        sys.exit(1)

    os.setsid()
    # signal.signal takes (signalnum, handler); the arguments used to be
    # passed the other way round.
    signal.signal(signal.SIGHUP, signal.SIG_IGN)

    print('started')
    os.kill(ppid, signal.SIGTERM)

    sys.stdin.close()
    try:
        freopen(log_file, 'a', sys.stdout)
        freopen(log_file, 'a', sys.stderr)
    except IOError as e:
        shell.print_exception(e)
        sys.exit(1)
def _handle_stage_connecting(self, data): if self._is_local: data = self._encryptor.encrypt(data) self._data_to_write_to_remote.append(data) if self._is_local and not self._fastopen_connected and \ self._config['fast_open']: # for sslocal and fastopen, we basically wait for data and use # sendto to connect try: # only connect once self._fastopen_connected = True remote_sock = \ self._create_remote_socket(self._chosen_server[0], self._chosen_server[1]) self._loop.add(remote_sock, eventloop.POLL_ERR, self._server) data = b''.join(self._data_to_write_to_remote) l = len(data) s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server) if s < l: data = data[s:] self._data_to_write_to_remote = [data] else: self._data_to_write_to_remote = [] self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) except (OSError, IOError) as e: if eventloop.errno_from_exception(e) == errno.EINPROGRESS: # in this case data is not sent at all self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) elif eventloop.errno_from_exception(e) == errno.ENOTCONN: logging.error('fast open not supported on this OS') self._config['fast_open'] = False self.destroy() else: shell.print_exception(e) if self._config['verbose']: traceback.print_exc() self.destroy()
def _handle_server(self):
    """Receive one datagram on the listening socket and forward it.

    In local mode the first three bytes are stripped (byte 2 is the frag
    field and must be 0) and the payload is passed through
    ``encrypt.encrypt_all(..., 1, data)`` before being sent to a server
    chosen by ``self._get_a_server()``. Otherwise the datagram is passed
    through ``encrypt.encrypt_all(..., 0, data)``, the parsed header is
    stripped, and the rest is sent to the destination named in the header.

    One outgoing socket is kept per ``client_key(r_addr, af)`` in
    ``self._cache``. The packet is dropped silently when the header cannot
    be parsed, no address is found, the target IP is in the forbidden
    list, or nothing is left to send.
    """
    server = self._server_socket
    data, r_addr = server.recvfrom(BUF_SIZE)
    if not data:
        logging.debug('UDP handle_server: data is empty')
    if self._stat_callback:
        self._stat_callback(self._listen_port, len(data))
    if self._is_local:
        if len(data) < 3:
            # too short to hold the frag byte at index 2; indexing it
            # would raise IndexError
            logging.warning('drop a message since it is too short')
            return
        frag = common.ord(data[2])
        if frag != 0:
            # logging.warn is a deprecated alias; use logging.warning
            logging.warning('drop a message since frag is not 0')
            return
        data = data[3:]
    else:
        data = encrypt.encrypt_all(self._password, self._method, 0, data)
        # decrypt data
        if not data:
            logging.debug('UDP handle_server: data is empty after decrypt')
            return
    header_result = parse_header(data)
    if header_result is None:
        return
    addrtype, dest_addr, dest_port, header_length = header_result

    if self._is_local:
        server_addr, server_port = self._get_a_server()
    else:
        server_addr, server_port = dest_addr, dest_port

    addrs = self._dns_cache.get(server_addr, None)
    if addrs is None:
        addrs = socket.getaddrinfo(server_addr, server_port, 0,
                                   socket.SOCK_DGRAM, socket.SOL_UDP)
        if not addrs:
            # drop
            return
        self._dns_cache[server_addr] = addrs

    af, socktype, proto, canonname, sa = addrs[0]
    key = client_key(r_addr, af)
    client = self._cache.get(key, None)
    if not client:
        # TODO async getaddrinfo
        if self._forbidden_iplist:
            if common.to_str(sa[0]) in self._forbidden_iplist:
                logging.debug('IP %s is in forbidden list, drop' %
                              common.to_str(sa[0]))
                # drop
                return
        client = socket.socket(af, socktype, proto)
        client.setblocking(False)
        self._cache[key] = client
        # remember which sender this outgoing socket belongs to
        self._client_fd_to_server_addr[client.fileno()] = r_addr

        self._sockets.add(client.fileno())
        self._eventloop.add(client, eventloop.POLL_IN, self)

    if self._is_local:
        data = encrypt.encrypt_all(self._password, self._method, 1, data)
        if not data:
            return
    else:
        data = data[header_length:]
    if not data:
        return
    try:
        client.sendto(data, (server_addr, server_port))
    except IOError as e:
        err = eventloop.errno_from_exception(e)
        if err in (errno.EINPROGRESS, errno.EAGAIN):
            pass
        else:
            shell.print_exception(e)