def test_window_increments_appropriately(self):
    e = Encoder()
    h = HeadersFrame(1)
    h.data = e.encode({':status': 200, 'content-type': 'foo/bar'})
    h.flags = set(['END_HEADERS'])
    d = DataFrame(1)
    d.data = b'hi there sir'
    d2 = DataFrame(1)
    d2.data = b'hi there sir again'
    d2.flags = set(['END_STREAM'])

    sock = DummySocket()
    sock.buffer = BytesIO(h.serialize() + d.serialize() + d2.serialize())

    c = HTTP20Connection('www.google.com')
    c._sock = sock
    c.window_manager.window_size = 1000
    c.window_manager.initial_window_size = 1000
    c.request('GET', '/')
    resp = c.get_response()
    resp.read()

    queue = list(map(decode_frame, map(memoryview, sock.queue)))
    assert len(queue) == 3  # one headers frame, two window update frames
    assert isinstance(queue[1], WindowUpdateFrame)
    assert queue[1].window_increment == len(b'hi there sir')
    assert isinstance(queue[2], WindowUpdateFrame)
    assert queue[2].window_increment == len(b'hi there sir again')
def __init__(self, ssl_sock, close_cb, retry_task_cb):
    super(HTTP2_worker, self).__init__(ssl_sock, close_cb, retry_task_cb)

    self.max_concurrent = 20
    self.network_buffer_size = 128 * 1024

    # Google http/2 time out is 4 mins.
    ssl_sock.settimeout(240)
    self._sock = BufferedSocket(ssl_sock, self.network_buffer_size)

    self.next_stream_id = 1
    self.streams = {}
    self.last_ping_time = time.time()

    # Count of pings that have not been ACKed:
    # increased when a ping is sent,
    # decreased when a ping ACK is received.
    # If this is not 0, don't accept new requests.
    self.ping_on_way = 0

    # request_lock
    self.request_lock = threading.Lock()

    # Every frame to be sent must be put on this queue and is written out by
    # send_loop. Every frame on this queue must already be allowed by both the
    # stream window and the connection window; any data frame blocked by the
    # connection window goes to self.blocked_send_frames instead.
    self.send_queue = Queue.Queue()

    # Data frames allowed by the stream window but blocked by the connection
    # window are kept in this buffer. They are sent when the connection
    # window opens.
    self.blocked_send_frames = []

    self.encoder = Encoder()
    self.decoder = Decoder()

    # Values for the settings used on an HTTP/2 connection.
    # These will be sent to the remote peer in a SETTINGS frame.
    self.local_settings = {
        SettingsFrame.INITIAL_WINDOW_SIZE: 1 * 1024 * 1024,
        SettingsFrame.SETTINGS_MAX_FRAME_SIZE: 256 * 1024
    }
    self.local_connection_initial_windows = 2 * 1024 * 1024
    self.local_window_manager = FlowControlManager(self.local_connection_initial_windows)

    # Changed by the server via SETTINGS frames.
    self.remote_settings = {
        SettingsFrame.INITIAL_WINDOW_SIZE: DEFAULT_WINDOW_SIZE,
        SettingsFrame.SETTINGS_MAX_FRAME_SIZE: DEFAULT_MAX_FRAME,
        SettingsFrame.MAX_CONCURRENT_STREAMS: 100
    }
    self.remote_window_size = DEFAULT_WINDOW_SIZE

    # Send the SETTINGS frame before accepting any task.
    self._send_preamble()

    threading.Thread(target=self.send_loop).start()
    threading.Thread(target=self.recv_loop).start()
def test_headers_with_continuation(self):
    e = Encoder()
    header_data = e.encode(
        {':status': 200, 'content-type': 'foo/bar', 'content-length': '0'}
    )

    h = HeadersFrame(1)
    h.data = header_data[0:int(len(header_data) / 2)]
    c = ContinuationFrame(1)
    c.data = header_data[int(len(header_data) / 2):]
    c.flags |= set(['END_HEADERS', 'END_STREAM'])

    sock = DummySocket()
    sock.buffer = BytesIO(h.serialize() + c.serialize())

    c = HTTP20Connection('www.google.com')
    c._sock = sock
    r = c.request('GET', '/')

    assert set(c.get_response(r).headers.iter_raw()) == set(
        [(b'content-type', b'foo/bar'), (b'content-length', b'0')]
    )
def test_read_headers_out_of_order(self):
    # If header blocks aren't decoded in the same order they're received,
    # regardless of the stream they belong to, the decoder state will
    # become corrupted.
    e = Encoder()
    h1 = HeadersFrame(1)
    h1.data = e.encode({':status': 200, 'content-type': 'foo/bar'})
    h1.flags |= set(['END_HEADERS', 'END_STREAM'])
    h3 = HeadersFrame(3)
    h3.data = e.encode({':status': 200, 'content-type': 'baz/qux'})
    h3.flags |= set(['END_HEADERS', 'END_STREAM'])

    sock = DummySocket()
    sock.buffer = BytesIO(h1.serialize() + h3.serialize())

    c = HTTP20Connection('www.google.com')
    c._sock = sock
    r1 = c.request('GET', '/a')
    r3 = c.request('GET', '/b')

    assert c.get_response(r3).headers == HTTPHeaderMap(
        [('content-type', 'baz/qux')]
    )
    assert c.get_response(r1).headers == HTTPHeaderMap(
        [('content-type', 'foo/bar')]
    )
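# A minimal sketch (not part of the test suite above) of why header blocks must
# be decoded in arrival order. It assumes the standalone `hpack` package, whose
# Encoder/Decoder API mirrors the one used in these tests: the encoder keeps a
# dynamic table, so a later block may reference entries added by an earlier one,
# and a decoder that skips the earlier block can no longer resolve those
# references. The function name below is illustrative only.
from hpack import Encoder as HpackEncoder, Decoder as HpackDecoder
from hpack.exceptions import HPACKDecodingError

def _demonstrate_hpack_order_dependence():
    enc = HpackEncoder()
    first = enc.encode({':status': '200', 'x-custom': 'aaaa'})
    # The second block refers to the dynamic-table entry added by the first.
    second = enc.encode({':status': '200', 'x-custom': 'aaaa'})

    in_order = HpackDecoder()
    in_order.decode(first)
    assert dict(in_order.decode(second))['x-custom'] == 'aaaa'

    out_of_order = HpackDecoder()
    try:
        # Decoding the second block first fails: the referenced dynamic-table
        # entry was never added to this decoder's table.
        out_of_order.decode(second)
    except HPACKDecodingError:
        pass  # expected: decoder state no longer matches the encoder's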
class Http2Worker(HttpWorker):
    version = "2"

    def __init__(self, logger, ip_manager, config, ssl_sock, close_cb, retry_task_cb, idle_cb):
        super(Http2Worker, self).__init__(
            logger, ip_manager, config, ssl_sock, close_cb, retry_task_cb, idle_cb)

        self.network_buffer_size = 128 * 1024

        # Google http/2 time out is 4 mins.
        self.ssl_sock.settimeout(240)
        self._sock = BufferedSocket(ssl_sock, self.network_buffer_size)

        self.next_stream_id = 1
        self.streams = {}
        self.last_ping_time = time.time()
        self.last_active_time = self.ssl_sock.create_time - 1
        self.continue_timeout = 0

        # Count of pings that have not been ACKed:
        # increased when a ping is sent,
        # decreased when a ping ACK is received.
        # If this is not 0, don't accept new requests.
        self.ping_on_way = 0
        self.accept_task = False

        # request_lock
        self.request_lock = threading.Lock()

        # Every frame to be sent must be put on this queue and is written out
        # by send_loop. Every frame on this queue must already be allowed by
        # both the stream window and the connection window; any data frame
        # blocked by the connection window goes to self.blocked_send_frames.
        self.send_queue = Queue.Queue()
        self.encoder = Encoder()
        self.decoder = Decoder()

        # Data frames allowed by the stream window but blocked by the
        # connection window are kept in this buffer. They are sent when the
        # connection window opens.
        self.blocked_send_frames = []

        # Values for the settings used on an HTTP/2 connection.
        # These will be sent to the remote peer in a SETTINGS frame.
        self.local_settings = {
            SettingsFrame.INITIAL_WINDOW_SIZE: 16 * 1024 * 1024,
            SettingsFrame.SETTINGS_MAX_FRAME_SIZE: 256 * 1024
        }
        self.local_connection_initial_windows = 32 * 1024 * 1024
        self.local_window_manager = FlowControlManager(self.local_connection_initial_windows)

        # Changed by the server via SETTINGS frames.
        self.remote_settings = {
            SettingsFrame.INITIAL_WINDOW_SIZE: DEFAULT_WINDOW_SIZE,
            SettingsFrame.SETTINGS_MAX_FRAME_SIZE: DEFAULT_MAX_FRAME,
            SettingsFrame.MAX_CONCURRENT_STREAMS: 100
        }
        #self.remote_window_size = DEFAULT_WINDOW_SIZE
        self.remote_window_size = 32 * 1024 * 1024

        # Send the SETTINGS frame before accepting any task.
        self._send_preamble()

        threading.Thread(target=self.send_loop).start()
        threading.Thread(target=self.recv_loop).start()

    # export api
    def request(self, task):
        if not self.keep_running:
            # race condition
            self.retry_task_cb(task)
            return

        if len(self.streams) > self.config.http2_max_concurrent:
            self.accept_task = False

        task.set_state("h2_req")
        self.request_task(task)

    def encode_header(self, headers):
        with self.request_lock:
            return self.encoder.encode(headers)

    def request_task(self, task):
        with self.request_lock:
            # Create a stream to process the task.
            stream_id = self.next_stream_id

            # HTTP/2 clients use odd stream ids.
            self.next_stream_id += 2

            stream = Stream(
                self.logger, self.config, self, self.ip, stream_id, task,
                self._send_cb, self._close_stream_cb, self.encode_header, self.decoder,
                FlowControlManager(self.local_settings[SettingsFrame.INITIAL_WINDOW_SIZE]),
                self.remote_settings[SettingsFrame.INITIAL_WINDOW_SIZE],
                self.remote_settings[SettingsFrame.SETTINGS_MAX_FRAME_SIZE])
            self.streams[stream_id] = stream

    def send_loop(self):
        while self.keep_running:
            frame = self.send_queue.get(True)
            if not frame:
                # A None frame means exit.
                break

            # self.logger.debug("%s Send:%s", self.ip, str(frame))
            data = frame.serialize()
            try:
                self._sock.send(data, flush=False)
                # Don't flush for small packets; reduces send api calls.
                if self.send_queue._qsize():
                    continue

                # Wait briefly for a payload frame, so header and payload can
                # be combined into one TCP packet.
                time.sleep(0.01)
                if not self.send_queue._qsize():
                    self._sock.flush()
            except socket.error as e:
                if e.errno not in (errno.EPIPE, errno.ECONNRESET):
                    self.logger.warn("%s http2 send fail:%r", self.ip, e)
                else:
                    self.logger.exception("send error:%r", e)

                self.close("send fail:%r" % e)
            except Exception as e:
                self.logger.debug("http2 %s send error:%r", self.ip, e)
                self.close("send fail:%r" % e)

    def recv_loop(self):
        while self.keep_running:
            try:
                self._consume_single_frame()
            except Exception as e:
                self.logger.exception("recv fail:%r", e)
                self.close("recv fail:%r" % e)

    def get_rtt_rate(self):
        return self.rtt + len(self.streams) * 3000

    def close(self, reason="conn close"):
        self.keep_running = False
        self.accept_task = False
        # Notify the loops to exit.
        # This function may also be called from outside http2,
        # when gae_proxy finds the appid or ip is wrong.
        self.send_queue.put(None)

        for stream in self.streams.values():
            if stream.task.responsed:
                # The response has already been sent to the client; can't retry.
                stream.close(reason=reason)
            else:
                self.retry_task_cb(stream.task)
        self.streams = {}
        super(Http2Worker, self).close(reason)

    def send_ping(self):
        p = PingFrame(0)
        p.opaque_data = struct.pack("!d", time.time())
        self.send_queue.put(p)
        self.last_ping_time = time.time()
        self.ping_on_way += 1

    def _send_preamble(self):
        self.send_queue.put(RawFrame(b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'))

        f = SettingsFrame(0)
        f.settings[SettingsFrame.ENABLE_PUSH] = 0
        f.settings[SettingsFrame.INITIAL_WINDOW_SIZE] = self.local_settings[SettingsFrame.INITIAL_WINDOW_SIZE]
        f.settings[SettingsFrame.SETTINGS_MAX_FRAME_SIZE] = self.local_settings[SettingsFrame.SETTINGS_MAX_FRAME_SIZE]
        self._send_cb(f)

        # Update the local connection window size.
        f = WindowUpdateFrame(0)
        f.window_increment = self.local_connection_initial_windows - DEFAULT_WINDOW_SIZE
        self._send_cb(f)

    def increase_remote_window_size(self, inc_size):
        # Check and send blocked frames if the window allows.
        self.remote_window_size += inc_size
        #self.logger.debug("%s increase send win:%d result:%d", self.ip, inc_size, self.remote_window_size)
        while len(self.blocked_send_frames):
            frame = self.blocked_send_frames[0]
            if len(frame.data) > self.remote_window_size:
                return

            self.remote_window_size -= len(frame.data)
            self.send_queue.put(frame)
            self.blocked_send_frames.pop(0)

        if self.keep_running and \
                self.accept_task == False and \
                len(self.streams) < self.config.http2_max_concurrent and \
                self.remote_window_size > 10000:
            self.accept_task = True
            self.idle_cb()

    def _send_cb(self, frame):
        # May be called by a stream.
        # Put the frame on blocked_send_frames if the connection window
        # doesn't allow it.
        if frame.type == DataFrame.type:
            if len(frame.data) > self.remote_window_size:
                self.blocked_send_frames.append(frame)
                self.accept_task = False
                return
            else:
                self.remote_window_size -= len(frame.data)
                self.send_queue.put(frame)
        else:
            self.send_queue.put(frame)

    def _close_stream_cb(self, stream_id, reason):
        # Called by a stream to remove itself from the streams list.
        # self.logger.debug("%s close stream:%d %s", self.ssl_sock.ip, stream_id, reason)
        try:
            del self.streams[stream_id]
        except KeyError:
            pass

        if self.keep_running and \
                len(self.streams) < self.config.http2_max_concurrent and \
                self.remote_window_size > 10000:
            self.accept_task = True
            self.idle_cb()

        self.processed_tasks += 1

    def _consume_single_frame(self):
        try:
            header = self._sock.recv(9)
        except Exception as e:
            self.logger.debug("%s _consume_single_frame:%r, inactive time:%d",
                              self.ip, e, time.time() - self.last_active_time)
            self.close("disconnect:%r" % e)
            return

        # Parse the header. We can use the returned memoryview directly here.
        frame, length = Frame.parse_frame_header(header.tobytes())

        if length > FRAME_MAX_ALLOWED_LEN:
            self.logger.error("%s Frame size exceeded on stream %d (received: %d, max: %d)",
                              self.ip, frame.stream_id, length, FRAME_MAX_ALLOWED_LEN)
            # self._send_rst_frame(frame.stream_id, 6)  # 6 = FRAME_SIZE_ERROR

        data = self._recv_payload(length)
        self.last_active_time = time.time()
        self._consume_frame_payload(frame, data.tobytes())

    def _recv_payload(self, length):
        if not length:
            return memoryview(b'')

        buffer = bytearray(length)
        buffer_view = memoryview(buffer)
        index = 0
        data_length = -1
        # _sock.recv(length) might not read out all the data if the given
        # length is very large, so read from the socket repeatedly.
        while length and data_length:
            data = self._sock.recv(length)
            data_length = len(data)
            end = index + data_length
            buffer_view[index:end] = data[:]
            length -= data_length
            index = end

        return buffer_view[:end]

    def _consume_frame_payload(self, frame, data):
        frame.parse_body(data)

        # self.logger.debug("%s Recv:%s", self.ip, str(frame))

        # Maintain our flow control window. We do this by delegating to the
        # chosen WindowManager.
        if frame.type == DataFrame.type:
            size = frame.flow_controlled_length

            increment = self.local_window_manager._handle_frame(size)
            if increment < 0:
                self.logger.warn("increment:%d", increment)
            elif increment:
                #self.logger.debug("%s frame size:%d increase win:%d", self.ip, size, increment)
                w = WindowUpdateFrame(0)
                w.window_increment = increment
                self._send_cb(w)
        elif frame.type == PushPromiseFrame.type:
            self.logger.error("%s receive push frame", self.ip)

        # Work out to whom this frame should go.
        if frame.stream_id != 0:
            try:
                stream = self.streams[frame.stream_id]
                stream.receive_frame(frame)
            except KeyError as e:
                if frame.type != WindowUpdateFrame.type:
                    self.logger.exception("%s Unexpected stream identifier %d, frame.type:%s e:%r",
                                          self.ip, frame.stream_id, frame, e)
        else:
            self.receive_frame(frame)

    def receive_frame(self, frame):
        if frame.type == WindowUpdateFrame.type:
            self.logger.debug("WindowUpdateFrame %d", frame.window_increment)
            self.increase_remote_window_size(frame.window_increment)
        elif frame.type == PingFrame.type:
            if 'ACK' in frame.flags:
                ping_time = struct.unpack("!d", frame.opaque_data)[0]
                time_now = time.time()
                rtt = (time_now - ping_time) * 1000
                if rtt < 0:
                    self.logger.error("rtt:%f ping_time:%f now:%f", rtt, ping_time, time_now)
                self.rtt = rtt
                self.ping_on_way -= 1
                #self.logger.debug("RTT:%d, on_way:%d", self.rtt, self.ping_on_way)
                if self.keep_running and self.ping_on_way == 0:
                    self.accept_task = True
            else:
                # The spec requires us to reply with PING+ACK and identical data.
                p = PingFrame(0)
                p.flags.add('ACK')
                p.opaque_data = frame.opaque_data
                self._send_cb(p)
            # self.last_active_time = time.time()
        elif frame.type == SettingsFrame.type:
            if 'ACK' not in frame.flags:
                # Send the ACK as soon as possible.
                f = SettingsFrame(0)
                f.flags.add('ACK')
                self._send_cb(f)

                # This may trigger sending DataFrames blocked by the remote window.
                self._update_settings(frame)
            else:
                self.accept_task = True
                self.idle_cb()
        elif frame.type == GoAwayFrame.type:
            # If we get a GOAWAY with error code zero, we are doing a graceful
            # shutdown and all is well. Otherwise, throw an exception.
            # If an error occurred, try to read the error description from the
            # code registry, otherwise use the frame's additional data.
            error_string = frame._extra_info()
            time_cost = time.time() - self.last_active_time
            if frame.additional_data != "session_timed_out":
                self.logger.warn("goaway:%s, t:%d", error_string, time_cost)

            self.close("GoAway:%s inactive time:%d" % (error_string, time_cost))
        elif frame.type == BlockedFrame.type:
            self.logger.warn("%s get BlockedFrame", self.ip)
        elif frame.type in FRAMES:
            # This frame isn't valid at this point.
            #raise ValueError("Unexpected frame %s." % frame)
            self.logger.error("%s Unexpected frame %s.", self.ip, frame)
        else:  # pragma: no cover
            # Unexpected frames belong to extensions. Just drop them on the
            # floor, but log so that users know that something happened.
            self.logger.error("%s Received unknown frame, type %d", self.ip, frame.type)

    def _update_settings(self, frame):
        if SettingsFrame.HEADER_TABLE_SIZE in frame.settings:
            new_size = frame.settings[SettingsFrame.HEADER_TABLE_SIZE]

            self.remote_settings[SettingsFrame.HEADER_TABLE_SIZE] = new_size
            #self.encoder.header_table_size = new_size

        if SettingsFrame.INITIAL_WINDOW_SIZE in frame.settings:
            newsize = frame.settings[SettingsFrame.INITIAL_WINDOW_SIZE]
            oldsize = self.remote_settings[SettingsFrame.INITIAL_WINDOW_SIZE]
            delta = newsize - oldsize

            for stream in self.streams.values():
                stream.remote_window_size += delta

            self.remote_settings[SettingsFrame.INITIAL_WINDOW_SIZE] = newsize

        if SettingsFrame.SETTINGS_MAX_FRAME_SIZE in frame.settings:
            new_size = frame.settings[SettingsFrame.SETTINGS_MAX_FRAME_SIZE]
            if not (FRAME_MAX_LEN <= new_size <= FRAME_MAX_ALLOWED_LEN):
                self.logger.error("%s Frame size %d is outside of allowed range", self.ip, new_size)

                # Tear the connection down with error code PROTOCOL_ERROR
                self.close("bad max frame size")
                #error_string = ("Advertised frame size %d is outside of range" % (new_size))
                #raise ConnectionError(error_string)
                return

            self.remote_settings[SettingsFrame.SETTINGS_MAX_FRAME_SIZE] = new_size

            for stream in self.streams.values():
                stream.max_frame_size = new_size

    def get_trace(self):
        out_list = []
        out_list.append(" processed:%d" % self.processed_tasks)
        out_list.append(" h2.stream_num:%d" % len(self.streams))
        out_list.append(" sni:%s, host:%s" % (self.ssl_sock.sni, self.ssl_sock.host))
        return ",".join(out_list)

    def get_host(self, task_host):
        return task_host
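# The FlowControlManager used above is imported from elsewhere and is not shown
# in this section. As a rough, self-contained illustration only, the sketch
# below shows the behaviour this code relies on: _handle_frame() shrinks the
# local receive window by the flow-controlled length of an incoming DATA frame
# and returns the WINDOW_UPDATE increment to send (0 means "not yet"), and
# _blocked() re-opens the window fully when the peer reports it is blocked.
# The class name and the exact top-up policy are assumptions, not the real
# implementation.
class SketchFlowControlManager(object):
    def __init__(self, initial_window_size):
        self.initial_window_size = initial_window_size
        self.window_size = initial_window_size

    def _handle_frame(self, frame_size):
        # Shrink the receive window by the size of the received DATA frame.
        self.window_size -= frame_size
        # Top the window back up once it gets small, returning the credit
        # the caller should advertise in a WINDOW_UPDATE frame.
        if self.window_size <= max(self.initial_window_size // 2, 1000):
            increment = self.initial_window_size - self.window_size
            self.window_size = self.initial_window_size
            return increment
        return 0

    def _blocked(self):
        # The peer sent BLOCKED: restore the window completely.
        increment = self.initial_window_size - self.window_size
        self.window_size = self.initial_window_size
        return increment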
class Stream(object):
    """
    A single HTTP/2 stream.

    A stream is an independent, bi-directional sequence of HTTP headers and
    data. Each stream is identified by a single integer. From a HTTP
    perspective, a stream _approximately_ matches a single request-response
    pair.
    """
    def __init__(self, logger, config, connection, ip, stream_id, task,
                 send_cb, close_cb, receive_window_manager, remote_window_size,
                 max_frame_size):
        self.logger = logger
        self.config = config
        self.connection = connection
        self.ip = ip
        self.stream_id = stream_id
        self.task = task
        self.state = STATE_IDLE
        self.get_head_time = None

        # There are two flow control windows: one for data we're sending,
        # one for data being sent to us.
        self.receive_window_manager = receive_window_manager
        self.remote_window_size = remote_window_size
        self.max_frame_size = max_frame_size

        # This is the callback handed to the stream by its parent connection.
        # It is called when the stream wants to send data. It expects to
        # receive a list of frames that will be automatically serialized.
        self._send_cb = send_cb

        # This is the callback to be called when the stream is closed.
        self._close_cb = close_cb

        # A reference to the header encoder and decoder objects belonging to
        # the parent connection.
        self._encoder = Encoder()
        self._decoder = Decoder()

        self.request_headers = HTTPHeaderMap()

        # Convert the body to bytes if needed.
        self.request_body = to_bytestring(self.task.body)

        # The part of the request body not yet sent because it is blocked by
        # the send window; it will be sent when the send window opens.
        self.request_body_left = len(self.request_body)
        self.request_body_sended = False

        # Header block fragments received but not yet decoded.
        self.response_header_datas = []

        # Set to a key-value set of the response headers once their
        # HEADERS..CONTINUATION frame sequence finishes.
        self.response_headers = None

        # Unconsumed response data chunks
        self.response_body = []
        self.response_body_len = 0

        threading.Thread(target=self.start_request).start()

    def start_request(self):
        """
        Open the stream. Does this by encoding and sending the headers: no
        more calls to ``add_header`` are allowed after this method is called.
        """
        # Strip any headers invalid in H2.
        #headers = h2_safe_headers(self.request_headers)

        self.add_header(":Method", self.task.method)
        self.add_header(":Scheme", "https")
        self.add_header(":Authority", self.task.host)
        self.add_header(":Path", self.task.path)

        default_headers = (':method', ':scheme', ':authority', ':path')
        #headers = h2_safe_headers(self.task.headers)
        for name, value in self.task.headers.items():
            is_default = to_native_string(name) in default_headers
            self.add_header(name, value, replace=is_default)

        # Encode the headers.
        encoded_headers = self._encoder.encode(self.request_headers)

        # It's possible that there is a substantial amount of data here. The
        # data needs to go into one HEADERS frame, followed by a number of
        # CONTINUATION frames. For now, for ease of implementation, let's just
        # assume that's never going to happen (16kB of headers is lots!).
        # Additionally, since this is so unlikely, there's no point writing a
        # test for this: it's just so simple.
        if len(encoded_headers) > FRAME_MAX_LEN:  # pragma: no cover
            raise ValueError("Header block too large.")

        header_frame = HeadersFrame(self.stream_id)
        header_frame.data = encoded_headers

        # Due to the restriction above this is definitely the end of the
        # headers.
        header_frame.flags.add('END_HEADERS')

        # Send the header frame.
        self.task.set_state("start send header")
        self._send_cb(header_frame)

        # Transition the stream state appropriately.
        self.state = STATE_OPEN

        self.task.set_state("start send left body")
        self.send_left_body()
        self.task.set_state("end send left body")

        self.timeout_response()

    def add_header(self, name, value, replace=False):
        """
        Adds a single HTTP header to the headers to be sent on the request.
        """
        if not replace:
            self.request_headers[name] = value
        else:
            self.request_headers.replace(name, value)

    def send_left_body(self):
        while self.remote_window_size and not self.request_body_sended:
            send_size = min(self.remote_window_size, self.request_body_left, self.max_frame_size)

            f = DataFrame(self.stream_id)
            data_start = len(self.request_body) - self.request_body_left
            f.data = self.request_body[data_start:data_start + send_size]

            self.remote_window_size -= send_size
            self.request_body_left -= send_size

            # If this is the end of the data, mark the frame with END_STREAM.
            if self.request_body_left == 0:
                f.flags.add('END_STREAM')

            # Send the frame and decrement the flow control window.
            self._send_cb(f)

            # If no more data is to be sent on this stream, transition our state.
            if self.request_body_left == 0:
                self.request_body_sended = True
                self._close_local()

    def receive_frame(self, frame):
        """
        Handle a frame received on this stream. Called by the connection.
        """
        # self.logger.debug("stream %d recved frame %r", self.stream_id, frame)
        if frame.type == WindowUpdateFrame.type:
            self.remote_window_size += frame.window_increment
            self.send_left_body()
        elif frame.type == HeadersFrame.type:
            # Begin the header block for the response headers.
            self.response_header_datas = [frame.data]
        elif frame.type == PushPromiseFrame.type:
            self.logger.error("%s receive PushPromiseFrame:%d", self.ip, frame.stream_id)
        elif frame.type == ContinuationFrame.type:
            # Continue a header block begun with either HEADERS or PUSH_PROMISE.
            self.response_header_datas.append(frame.data)
        elif frame.type == DataFrame.type:
            # Append the data to the buffer.
            if not self.task.finished:
                self.task.put_data(frame.data)

            if 'END_STREAM' not in frame.flags:
                # Increase the window size. Only do this if the data frame
                # contains actual data. Don't do it if the stream is closed.
                size = frame.flow_controlled_length
                increment = self.receive_window_manager._handle_frame(size)
                #if increment:
                #    self.logger.debug("stream:%d frame size:%d increase win:%d", self.stream_id, size, increment)

                #content_len = int(self.request_headers.get("Content-Length")[0])
                #self.logger.debug("%s get:%d s:%d", self.ip, self.response_body_len, size)

                if increment and not self._remote_closed:
                    w = WindowUpdateFrame(self.stream_id)
                    w.window_increment = increment
                    self._send_cb(w)
        elif frame.type == BlockedFrame.type:
            # If we've been blocked we may want to fix up the window.
            increment = self.receive_window_manager._blocked()
            if increment:
                w = WindowUpdateFrame(self.stream_id)
                w.window_increment = increment
                self._send_cb(w)
        elif frame.type == RstStreamFrame.type:
            # The stream was reset by the server.
            inactive_time = time.time() - self.connection.last_active_time
            self.logger.debug("%s Stream %d reset by server, inactive:%d, error code:%d",
                              self.ip, self.stream_id, inactive_time, frame.error_code)
            self.connection.close("RESET")
        elif frame.type in FRAMES:
            # This frame isn't valid at this point.
            #raise ValueError("Unexpected frame %s." % frame)
            self.logger.error("%s Unexpected frame %s.", self.ip, frame)
        else:  # pragma: no cover
            # Unknown frames belong to extensions. Just drop them on the
            # floor, but log so that users know that something happened.
            self.logger.error("%s Received unknown frame, type %d", self.ip, frame.type)

        if 'END_HEADERS' in frame.flags:
            # Begin by decoding the header block. If this fails, we need to
            # tear down the entire connection.
            headers = self._decoder.decode(b''.join(self.response_header_datas))
            self._handle_header_block(headers)

            # We've handled the headers, zero them out.
            self.response_header_datas = None

            self.get_head_time = time.time()

            length = self.response_headers.get("Content-Length", None)
            if isinstance(length, list):
                length = int(length[0])

            if not self.task.finished:
                self.task.content_length = length
                self.task.set_state("h2_get_head")
                self.send_response()

        if 'END_STREAM' in frame.flags:
            #self.logger.debug("%s Closing remote side of stream:%d", self.ip, self.stream_id)
            time_now = time.time()
            time_cost = time_now - self.get_head_time
            if time_cost > 0 and \
                    isinstance(self.task.content_length, int) and \
                    not self.task.finished:
                speed = self.task.content_length / time_cost
                self.task.set_state("h2_finish[SP:%d]" % speed)

            self._close_remote()

            self.close("end stream")
            if not self.task.finished:
                self.connection.continue_timeout = 0

    def send_response(self):
        if self.task.responsed:
            self.logger.error("http2_stream send_response but responsed.%s", self.task.url)
            self.close("h2 stream send_response but sended.")
            return

        self.task.responsed = True
        status = int(self.response_headers[b':status'][0])
        strip_headers(self.response_headers)
        response = simple_http_client.BaseResponse(status=status, headers=self.response_headers)
        response.ssl_sock = self.connection.ssl_sock
        response.worker = self.connection
        response.task = self.task
        self.task.queue.put(response)
        if status in self.config.http2_status_to_close:
            self.connection.close("status %d" % status)

    def close(self, reason="close"):
        if not self.task.responsed:
            self.connection.retry_task_cb(self.task, reason)
        else:
            self.task.finish()
            # An empty block means the task failed or was closed.
        self._close_cb(self.stream_id, reason)

    def _handle_header_block(self, headers):
        """
        Handles the logic for receiving a completed headers block.

        A headers block is an uninterrupted sequence of one HEADERS frame
        followed by zero or more CONTINUATION frames, and is terminated by a
        frame bearing the END_HEADERS flag.

        HTTP/2 allows receipt of up to three such blocks on a stream. The
        first is optional, and contains a 1XX response. The second is
        mandatory, and must contain a final response (200 or higher). The
        third is optional, and may contain 'trailers', headers that are sent
        after a chunk-encoded body is sent.

        Here we only process the simple state: no push, one header frame.
        """
        if self.response_headers is None:
            self.response_headers = HTTPHeaderMap(headers)
        else:
            # Received too many header blocks.
            raise ProtocolError("Too many header blocks.")
        return

    @property
    def _local_closed(self):
        return self.state in (STATE_CLOSED, STATE_HALF_CLOSED_LOCAL)

    @property
    def _remote_closed(self):
        return self.state in (STATE_CLOSED, STATE_HALF_CLOSED_REMOTE)

    @property
    def _local_open(self):
        return self.state in (STATE_OPEN, STATE_HALF_CLOSED_REMOTE)

    def _close_local(self):
        self.state = (STATE_HALF_CLOSED_LOCAL if self.state == STATE_OPEN
                      else STATE_CLOSED)

    def _close_remote(self):
        self.state = (STATE_HALF_CLOSED_REMOTE if self.state == STATE_OPEN
                      else STATE_CLOSED)

    def timeout_response(self):
        start_time = time.time()
        while time.time() - start_time < self.task.timeout:
            time.sleep(1)
            if self._remote_closed:
                return

        self.logger.warn("h2 timeout %s task_trace:%s worker_trace:%s",
                         self.connection.ssl_sock.ip,
                         self.task.get_trace(),
                         self.connection.get_trace())
        self.task.set_state("timeout")

        if self.task.responsed:
            self.task.finish()
        else:
            self.task.response_fail("timeout")

        self.connection.continue_timeout += 1
        if self.connection.continue_timeout > self.connection.config.http2_max_timeout_tasks and \
                time.time() - self.connection.last_active_time > 60:
            self.connection.close("down fail")
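# A standalone sketch (not used by the classes above) of the chunking rule that
# send_left_body() applies: every DATA frame is capped by both the remote
# flow-control window and the negotiated maximum frame size, and END_STREAM is
# set on the frame that carries the last byte of the body. The function name
# and return shape are illustrative only; the real method also handles the
# empty-body case by sending a single empty END_STREAM frame.
def plan_data_frames(body_len, remote_window_size, max_frame_size):
    """Return (chunk_size, end_stream) tuples for a body of body_len bytes."""
    frames = []
    body_left = body_len
    while remote_window_size and body_left:
        send_size = min(remote_window_size, body_left, max_frame_size)
        remote_window_size -= send_size
        body_left -= send_size
        frames.append((send_size, body_left == 0))
    return frames

# Example: a 40000-byte body with a 32768-byte send window and 16384-byte
# frames yields two full frames; the remaining 7232 bytes stay queued until a
# WINDOW_UPDATE from the peer reopens the window.
assert plan_data_frames(40000, 32768, 16384) == [(16384, False), (16384, False)]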
class TestServerPush(object):
    def setup_method(self, method):
        self.frames = []
        self.encoder = Encoder()
        self.conn = None

    def add_push_frame(self, stream_id, promised_stream_id, headers, end_block=True):
        frame = PushPromiseFrame(stream_id)
        frame.promised_stream_id = promised_stream_id
        frame.data = self.encoder.encode(headers)
        if end_block:
            frame.flags.add('END_HEADERS')
        self.frames.append(frame)

    def add_headers_frame(self, stream_id, headers, end_block=True, end_stream=False):
        frame = HeadersFrame(stream_id)
        frame.data = self.encoder.encode(headers)
        if end_block:
            frame.flags.add('END_HEADERS')
        if end_stream:
            frame.flags.add('END_STREAM')
        self.frames.append(frame)

    def add_data_frame(self, stream_id, data, end_stream=False):
        frame = DataFrame(stream_id)
        frame.data = data
        if end_stream:
            frame.flags.add('END_STREAM')
        self.frames.append(frame)

    def request(self):
        self.conn = HTTP20Connection('www.google.com', enable_push=True)
        self.conn._sock = DummySocket()
        self.conn._sock.buffer = BytesIO(
            b''.join([frame.serialize() for frame in self.frames])
        )
        self.conn.request('GET', '/')

    def assert_response(self):
        self.response = self.conn.get_response()
        assert self.response.status == 200
        assert dict(self.response.headers) == {b'content-type': [b'text/html']}

    def assert_pushes(self):
        self.pushes = list(self.conn.get_pushes())
        assert len(self.pushes) == 1
        assert self.pushes[0].method == b'GET'
        assert self.pushes[0].scheme == b'https'
        assert self.pushes[0].authority == b'www.google.com'
        assert self.pushes[0].path == b'/'
        expected_headers = {b'accept-encoding': [b'gzip']}
        assert dict(self.pushes[0].request_headers) == expected_headers

    def assert_push_response(self):
        push_response = self.pushes[0].get_response()
        assert push_response.status == 200
        assert dict(push_response.headers) == {
            b'content-type': [b'application/javascript']
        }
        assert push_response.read() == b'bar'

    def test_promise_before_headers(self):
        self.add_push_frame(1, 2, [
            (':method', 'GET'), (':path', '/'), (':authority', 'www.google.com'),
            (':scheme', 'https'), ('accept-encoding', 'gzip')
        ])
        self.add_headers_frame(1, [(':status', '200'), ('content-type', 'text/html')])
        self.add_data_frame(1, b'foo', end_stream=True)
        self.add_headers_frame(2, [(':status', '200'), ('content-type', 'application/javascript')])
        self.add_data_frame(2, b'bar', end_stream=True)

        self.request()
        assert len(list(self.conn.get_pushes())) == 0
        self.assert_response()
        self.assert_pushes()
        assert self.response.read() == b'foo'
        self.assert_push_response()

    def test_promise_after_headers(self):
        self.add_headers_frame(1, [(':status', '200'), ('content-type', 'text/html')])
        self.add_push_frame(1, 2, [
            (':method', 'GET'), (':path', '/'), (':authority', 'www.google.com'),
            (':scheme', 'https'), ('accept-encoding', 'gzip')
        ])
        self.add_data_frame(1, b'foo', end_stream=True)
        self.add_headers_frame(2, [(':status', '200'), ('content-type', 'application/javascript')])
        self.add_data_frame(2, b'bar', end_stream=True)

        self.request()
        assert len(list(self.conn.get_pushes())) == 0
        self.assert_response()
        assert len(list(self.conn.get_pushes())) == 0
        assert self.response.read() == b'foo'
        self.assert_pushes()
        self.assert_push_response()

    def test_promise_after_data(self):
        self.add_headers_frame(1, [(':status', '200'), ('content-type', 'text/html')])
        self.add_data_frame(1, b'fo')
        self.add_push_frame(1, 2, [
            (':method', 'GET'), (':path', '/'), (':authority', 'www.google.com'),
            (':scheme', 'https'), ('accept-encoding', 'gzip')
        ])
        self.add_data_frame(1, b'o', end_stream=True)
        self.add_headers_frame(2, [(':status', '200'), ('content-type', 'application/javascript')])
        self.add_data_frame(2, b'bar', end_stream=True)

        self.request()
        assert len(list(self.conn.get_pushes())) == 0
        self.assert_response()
        assert len(list(self.conn.get_pushes())) == 0
        assert self.response.read() == b'foo'
        self.assert_pushes()
        self.assert_push_response()

    def test_capture_all_promises(self):
        self.add_push_frame(1, 2, [
            (':method', 'GET'), (':path', '/one'), (':authority', 'www.google.com'),
            (':scheme', 'https'), ('accept-encoding', 'gzip')
        ])
        self.add_headers_frame(1, [(':status', '200'), ('content-type', 'text/html')])
        self.add_push_frame(1, 4, [
            (':method', 'GET'), (':path', '/two'), (':authority', 'www.google.com'),
            (':scheme', 'https'), ('accept-encoding', 'gzip')
        ])
        self.add_data_frame(1, b'foo', end_stream=True)
        self.add_headers_frame(4, [(':status', '200'), ('content-type', 'application/javascript')])
        self.add_headers_frame(2, [(':status', '200'), ('content-type', 'application/javascript')])
        self.add_data_frame(4, b'two', end_stream=True)
        self.add_data_frame(2, b'one', end_stream=True)

        self.request()
        assert len(list(self.conn.get_pushes())) == 0
        pushes = list(self.conn.get_pushes(capture_all=True))
        assert len(pushes) == 2
        assert pushes[0].path == b'/one'
        assert pushes[1].path == b'/two'
        assert pushes[0].get_response().read() == b'one'
        assert pushes[1].get_response().read() == b'two'
        self.assert_response()
        assert self.response.read() == b'foo'

    def test_cancel_push(self):
        self.add_push_frame(1, 2, [
            (':method', 'GET'), (':path', '/'), (':authority', 'www.google.com'),
            (':scheme', 'https'), ('accept-encoding', 'gzip')
        ])
        self.add_headers_frame(1, [(':status', '200'), ('content-type', 'text/html')])

        self.request()
        self.conn.get_response()
        list(self.conn.get_pushes())[0].cancel()

        f = RstStreamFrame(2)
        f.error_code = 8
        assert self.conn._sock.queue[-1] == f.serialize()

    def test_reset_pushed_streams_when_push_disabled(self):
        self.add_push_frame(1, 2, [
            (':method', 'GET'), (':path', '/'), (':authority', 'www.google.com'),
            (':scheme', 'https'), ('accept-encoding', 'gzip')
        ])
        self.add_headers_frame(1, [(':status', '200'), ('content-type', 'text/html')])

        self.request()
        self.conn._enable_push = False
        self.conn.get_response()

        f = RstStreamFrame(2)
        f.error_code = 7
        assert self.conn._sock.queue[-1] == f.serialize()

    def test_pushed_requests_ignore_unexpected_headers(self):
        headers = HTTPHeaderMap([
            (':scheme', 'http'),
            (':method', 'get'),
            (':authority', 'google.com'),
            (':path', '/'),
            (':reserved', 'no'),
            ('no', 'no'),
        ])
        p = HTTP20Push(headers, DummyStream(b''))

        assert p.request_headers == HTTPHeaderMap([('no', 'no')])