Esempio n. 1
0
class FrameGrabber(Thread):
    '''Thread that turns queued frames into JPEG byte strings.

    Each message taken from ``messages`` is expected to be a mapping; when it
    carries an ``"ar_frame"`` entry, that frame is JPEG-encoded (quality 80),
    stored on ``self.frame`` and ``self.notifier.notify()`` is called.
    A ``None`` message ends the loop.
    '''
    def __init__(self, messages):
        super(FrameGrabber, self).__init__()
        self.to_exit = Event()          # set by shutdown() to stop run()
        self.notifier = Condition()     # notified after every new frame
        self.messages = messages        # queue-like source of messages
        self.frame = None               # most recent encoded frame

    def shutdown(self):
        '''Request the run loop to finish.
        '''
        self.to_exit.set()

    def run(self):
        '''Consume messages until shutdown is requested or None is received.
        '''
        while True:
            if self.to_exit.is_set():
                return
            try:
                # Block for at most one second so the exit flag is re-checked.
                item = self.messages.get(True, 1)
                if item is None:
                    return
                raw = item.get("ar_frame", None)
                if raw is not None:
                    quality = [cv2.IMWRITE_JPEG_QUALITY, 80]
                    _, encoded = cv2.imencode(".jpg", raw, quality)
                    self.frame = encoded.tostring()
                    self.notifier.notify()
            except Queue.Empty:
                # Nothing arrived within the timeout; poll again.
                pass
Esempio n. 2
0
    def get_data(cls, account, source_filter, limit=100, skip=0):
        """
        Streams one Google Sheets worksheet, exported as CSV, into ``cls.write``.

        GET https://docs.google.com/spreadsheets/d/[spreadsheet]/export?format=csv&gid=[worksheet]

        This is a generator: it yields the future returned by ``lock.wait`` and
        resumes when the fetch callback fires or after MAXIMUM_REQ_TIME seconds.

        :param account: enabled account; provides ``get_client()`` for the
            OAuth token and ``_id`` for log messages
        :param source_filter: object with ``spreadsheet`` and ``worksheet``
        :param limit: not used by this method
        :param skip: not used by this method
        :raises ValueError: if the account is missing/disabled or the
            spreadsheet/worksheet is not set
        """
        # Validate everything before any client object is created.
        if not account or not account.enabled:
            raise ValueError('cannot gather information without an account')
        if source_filter.spreadsheet is None:
            raise ValueError('required parameter spreadsheet missing')
        if source_filter.worksheet is None:
            raise ValueError('required parameter worksheet missing')

        client = AsyncHTTPClient()
        uri = "https://docs.google.com/spreadsheets/d/{}/export?format=csv&gid={}".format(
            source_filter.spreadsheet, source_filter.worksheet
        )

        app_log.info(
            "Start retrieval of worksheet {}/{} for {}".format(source_filter.spreadsheet, source_filter.worksheet,
                                                               account._id))

        lock = Condition()
        oauth_client = account.get_client()
        uri, headers, body = oauth_client.add_token(uri)
        # Response chunks are handed to cls.write as they arrive.
        req = HTTPRequest(uri, headers=headers, body=body, streaming_callback=lambda c: cls.write(c))

        client.fetch(req, callback=lambda r: lock.notify())
        # NOTE(review): assumes tornado.locks.Condition semantics, where the
        # wait future resolves to False on timeout and True when notified.
        completed = yield lock.wait(timeout=timedelta(seconds=MAXIMUM_REQ_TIME))

        if not completed:
            # Previously a timeout was reported as a successful retrieval.
            app_log.warning(
                "Timed out after {} seconds retrieving worksheet for {}".format(MAXIMUM_REQ_TIME, account._id))
            return

        app_log.info(
            "Finished retrieving worksheet for {}".format(account._id))
Esempio n. 3
0
    def get_data(cls, account, source_filter, limit=100, skip=0):
        """
        Streams one Google Sheets worksheet, exported as CSV, into ``cls.write``.

        GET https://docs.google.com/spreadsheets/d/[spreadsheet]/export?format=csv&gid=[worksheet]

        This is a generator: it yields the future returned by ``lock.wait`` and
        resumes when the fetch callback fires or after MAXIMUM_REQ_TIME seconds.

        :param account: enabled account; provides ``get_client()`` for the
            OAuth token and ``_id`` for log messages
        :param source_filter: object with ``spreadsheet`` and ``worksheet``
        :param limit: not used by this method
        :param skip: not used by this method
        :raises ValueError: if the account is missing/disabled or the
            spreadsheet/worksheet is not set
        """
        # Validate everything before any client object is created.
        if not account or not account.enabled:
            raise ValueError('cannot gather information without an account')
        if source_filter.spreadsheet is None:
            raise ValueError('required parameter spreadsheet missing')
        if source_filter.worksheet is None:
            raise ValueError('required parameter worksheet missing')

        client = AsyncHTTPClient()
        uri = "https://docs.google.com/spreadsheets/d/{}/export?format=csv&gid={}".format(
            source_filter.spreadsheet, source_filter.worksheet)

        app_log.info("Start retrieval of worksheet {}/{} for {}".format(
            source_filter.spreadsheet, source_filter.worksheet, account._id))

        lock = Condition()
        oauth_client = account.get_client()
        uri, headers, body = oauth_client.add_token(uri)
        # Response chunks are handed to cls.write as they arrive.
        req = HTTPRequest(uri,
                          headers=headers,
                          body=body,
                          streaming_callback=lambda c: cls.write(c))

        client.fetch(req, callback=lambda r: lock.notify())
        # NOTE(review): assumes tornado.locks.Condition semantics, where the
        # wait future resolves to False on timeout and True when notified.
        completed = yield lock.wait(
            timeout=timedelta(seconds=MAXIMUM_REQ_TIME))

        if not completed:
            # Previously a timeout was reported as a successful retrieval.
            app_log.warning(
                "Timed out after {} seconds retrieving worksheet for {}".format(
                    MAXIMUM_REQ_TIME, account._id))
            return

        app_log.info("Finished retrieving worksheet for {}".format(
            account._id))
Esempio n. 4
0
class PingHandler(firenado.tornadoweb.TornadoHandler):
    """Forward the POST body to ``ping_rpc_queue`` and write back the reply.

    The request is published with a per-handler correlation id and a
    temporary reply queue; ``post`` waits on a condition that
    ``on_response`` notifies once the matching reply has arrived.
    """

    def __init__(self, application, request, **kwargs):
        super(PingHandler, self).__init__(application, request, **kwargs)
        self.response = None
        self.callback_queue = None
        self.condition = Condition()
        self.corr_id = str(uuid.uuid4())
        rabbitmq = self.application.get_app_component().rabbitmq
        self.in_channel = rabbitmq['client'].channels['in']

    @gen.coroutine
    def post(self):
        """Declare the reply queue, wait for the reply, then write it."""
        declare_queue = self.in_channel.queue_declare
        declare_queue(exclusive=True,
                      callback=self.on_request_queue_declared)
        yield self.condition.wait()

        self.write(self.response)

    def on_request_queue_declared(self, response):
        """Start consuming the reply queue and publish the request."""
        logger.info('Request temporary queue declared.')
        reply_queue = response.method.queue
        self.callback_queue = reply_queue
        self.in_channel.basic_consume(self.on_response, no_ack=True,
                                      queue=reply_queue)
        rpc_properties = pika.BasicProperties(
            reply_to=self.callback_queue,
            correlation_id=self.corr_id,
        )
        self.in_channel.basic_publish(
            exchange='',
            routing_key='ping_rpc_queue',
            properties=rpc_properties,
            body=self.request.body)

    def on_response(self, ch, method, props, body):
        """Store the reply whose correlation id matches and wake ``post``."""
        if self.corr_id != props.correlation_id:
            # Reply belongs to another request; ignore it.
            return
        payload = body.decode("utf-8")
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.response = {
            'data': payload,
            'date': stamp,
        }
        self.in_channel.queue_delete(queue=self.callback_queue)
        self.condition.notify()
Esempio n. 5
0
class PingHandler(firenado.tornadoweb.TornadoHandler):
    """RPC-style ping handler backed by a RabbitMQ channel.

    ``post`` declares an exclusive queue, and the declaration callback
    publishes the request body to ``ping_rpc_queue``. The handler then waits
    until ``on_response`` receives the reply carrying this handler's
    correlation id.
    """

    def __init__(self, application, request, **kwargs):
        super(PingHandler, self).__init__(application, request, **kwargs)
        self.condition = Condition()
        self.corr_id = str(uuid.uuid4())
        self.callback_queue = None
        self.response = None
        component = self.application.get_app_component()
        self.in_channel = component.rabbitmq['client'].channels['in']

    @gen.coroutine
    def post(self):
        """Kick off the RPC and write the reply once it is available."""
        channel = self.in_channel
        channel.queue_declare(exclusive=True,
                              callback=self.on_request_queue_declared)
        yield self.condition.wait()
        self.write(self.response)

    def on_request_queue_declared(self, response):
        """Consume from the new reply queue, then publish the request."""
        logger.info('Request temporary queue declared.')
        self.callback_queue = response.method.queue
        channel = self.in_channel
        channel.basic_consume(self.on_response, no_ack=True,
                              queue=self.callback_queue)
        channel.basic_publish(
            exchange='',
            routing_key='ping_rpc_queue',
            properties=pika.BasicProperties(
                reply_to=self.callback_queue,
                correlation_id=self.corr_id,
            ),
            body=self.request.body)

    def on_response(self, ch, method, props, body):
        """Record the matching reply, drop the reply queue, notify ``post``."""
        if props.correlation_id != self.corr_id:
            # Not the reply to this handler's request.
            return
        text = body.decode("utf-8")
        now = datetime.datetime.now()
        self.response = {
            'data': text,
            'date': now.strftime("%Y-%m-%d %H:%M:%S"),
        }
        self.in_channel.queue_delete(queue=self.callback_queue)
        self.condition.notify()
Esempio n. 6
0
class AwaitableState:
    """Holds a ``State`` value and lets a coroutine wait for it to change."""

    def __init__(self, initial_state: State = State.REQUEST_SENT):
        self._state = initial_state
        self._cond = Condition()

    def get(self) -> State:
        """Return the current state."""
        return self._state

    def set(self, state: State):
        """Store ``state`` and notify the condition."""
        self._set(state)
        self._cond.notify()

    def _set(self, state: State):
        """Log and store ``state`` without notifying anyone."""
        message = "State changed to: {state}".format(state=state.name)
        log.debug(message)
        self._state = state

    async def wait_for_state_change(self, expires: timedelta):
        """ Returns true if state has changed, or false on timeouts """
        notified = await self._cond.wait(timeout=expires)
        if notified:
            return notified
        # Nothing changed within ``expires``: mark the state as expired.
        self._set(State.EXPIRED)
        return notified
Esempio n. 7
0
class InMemStream(Stream):
    """Stream backed by an in-memory deque of chunks.

    Writers append chunks with ``write``; readers obtain a future from
    ``read`` that resolves with the concatenated pending chunks, or waits on
    an internal condition until data arrives or the stream is closed.
    """

    def __init__(self, buf=None, auto_close=True):
        """In-Memory based stream

        :param buf: the buffer for the in memory stream
        :param auto_close: stored on the instance as ``auto_close``
        """
        self._stream = deque()
        if buf:
            self._stream.append(buf)
        self.state = StreamState.init
        self._condition = Condition()
        self.auto_close = auto_close

        self.exception = None

    def clone(self):
        """Return a new stream with a copy of the pending chunks.

        ``state`` and ``auto_close`` are copied; the stored exception and the
        condition are not.
        """
        new_stream = InMemStream()
        new_stream.state = self.state
        new_stream.auto_close = self.auto_close
        new_stream._stream = deque(self._stream)
        return new_stream

    def read(self):
        """Return a future for the next chunk of data.

        The future resolves with up to roughly ``common.MAX_PAYLOAD_SIZE``
        characters of pending data (an empty string when nothing is pending),
        or fails with the exception recorded by ``set_exception``.
        """

        def read_chunk(future):
            if self.exception:
                future.set_exception(self.exception)
                return future

            chunk = ""

            while len(self._stream) and len(chunk) < common.MAX_PAYLOAD_SIZE:
                chunk += self._stream.popleft()

            future.set_result(chunk)
            return future

        read_future = tornado.concurrent.Future()

        def on_wait_done(wait_future):
            # Propagate a failed wait instead of leaving read_future pending
            # forever, which is what silently dropping the error would do.
            error = wait_future.exception()
            if error:
                read_future.set_exception(error)
            else:
                read_chunk(read_future)

        # We're not ready yet
        if self.state != StreamState.completed and not len(self._stream):
            wait_future = self._condition.wait()
            wait_future.add_done_callback(on_wait_done)
            return read_future

        return read_chunk(read_future)

    def write(self, chunk):
        """Append ``chunk`` and wake one waiting reader.

        :raises UnexpectedError: if the stream has already been closed
        :raises Exception: the exception recorded by ``set_exception``, if any
        :returns: an already-resolved future
        """
        if self.exception:
            raise self.exception

        if self.state == StreamState.completed:
            raise UnexpectedError("Stream has been closed.")

        if chunk:
            self._stream.append(chunk)
            self._condition.notify()

        # This needs to return a future to match the async interface.
        r = tornado.concurrent.Future()
        r.set_result(None)
        return r

    def set_exception(self, exception):
        """Record ``exception`` for readers/writers and close the stream."""
        self.exception = exception
        self.close()

    def close(self):
        """Mark the stream completed and wake every waiting reader."""
        self.state = StreamState.completed
        # Once completed no further write() will notify, so waking only one
        # waiter would leave any other pending read() unresolved forever.
        self._condition.notify_all()
Esempio n. 8
0
class DeviceConnection(object):

    state_waiters = {}
    state_happened = {}

    def __init__ (self, device_server, stream, address):
        """Set up the per-connection state for one device socket.

        :param device_server: owning server object
        :param stream: stream for this connection; Nagle is disabled on it
        :param address: peer address of the connection
        """
        # owning server and transport
        self.device_server = device_server
        self.address = address
        self.stream = stream
        self.stream.set_nodelay(True)

        # identity, filled in by wait_hello
        self.sn = ""
        self.private_key = ""
        self.node_id = 0
        self.fw_version = 0.0

        # AES state, created after a valid hello
        self.iv = None
        self.cipher_down = None
        self.cipher_up = None

        # request/response plumbing
        self.recv_msg = {}
        self.recv_msg_cond = Condition()
        self.send_msg_sem = Semaphore(1)
        self.pending_request_cnt = 0

        # liveness tracking
        self.killed = False
        self.online_status = True
        self.timeout_handler_onlinecheck = None
        self.timeout_handler_offline = None

        # state_waiters / state_happened are class-level dicts and are
        # deliberately not re-created per instance.
        self.event_waiters = []
        self.event_happened = []

        # OTA bookkeeping
        self.ota_ing = False
        self.post_ota = False
        self.ota_notify_done_future = None

    @gen.coroutine
    def secure_write (self, data):
        """Pad and encrypt ``data`` with the downstream cipher and send it.

        Does nothing when the downstream cipher has not been set up yet.
        """
        if not self.cipher_down:
            return
        encrypted = self.cipher_down.encrypt(pad(data))
        yield self.stream.write(encrypted)

    @gen.coroutine
    def wait_hello (self):
        """Read and verify the hello packet, then set up the AES ciphers.

        The packet is 32 bytes of SN plus a 32-byte HMAC-SHA256 signature of
        the SN; newer firmware prefixes it with a 4-byte ``@x.y`` version tag.

        Resolves (via ``gen.Return``) to:
            0   - hello accepted, IV + encrypted "hello" sent to the node
            100 - unexpected packet length
            101 - SN not found in the ``nodes`` table
            102 - signature mismatch
            1   - timed out waiting for data
            2   - stream closed while waiting
        On every non-zero result the connection has been killed.
        """
        try:
            self._wait_hello_future = self.stream.read_bytes(64) #read 64bytes: 32bytes SN + 32bytes signature signed with private key
            str1 = yield gen.with_timeout(timedelta(seconds=10), self._wait_hello_future,
                                         io_loop=ioloop.IOLoop.current())
            self.idle_time = 0  #reset the idle time counter

            if len(str1) != 64:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.debug("receive length != 64")
                raise gen.Return(100) # length not match 64

            if re.match(r'@\d\.\d', str1[0:4]):
                #new version firmware
                self._wait_hello_future = self.stream.read_bytes(4) #read another 4bytes
                str2 = yield gen.with_timeout(timedelta(seconds=10), self._wait_hello_future, io_loop=ioloop.IOLoop.current())

                self.idle_time = 0  #reset the idle time counter

                if len(str2) != 4:
                    self.stream.write("sorry\r\n")
                    yield gen.sleep(0.1)
                    self.kill_myself()
                    gen_log.debug("receive length != 68")
                    raise gen.Return(100) # length not match 68

                str1 += str2
                self.fw_version = float(str1[1:4])
                sn = str1[4:36]
                sig = str1[36:68]
            else:
                #for version < 1.1
                sn = str1[0:32]
                sig = str1[32:64]

            gen_log.info("accepted sn: %s @fw_version %.1f" % (sn, self.fw_version))

            #query the sn from database
            node = None
            cur = self.device_server.cur
            # sn comes from an unauthenticated peer, so it must never be
            # interpolated into the SQL text.
            # NOTE(review): '?' is the sqlite3 placeholder style - confirm the
            # driver behind device_server.cur uses qmark paramstyle.
            cur.execute('select * from nodes where node_sn=?', (sn,))
            rows = cur.fetchall()
            if len(rows) > 0:
                node = rows[0]

            if not node:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.info("node sn not found")
                raise gen.Return(101) #node not found

            key = node['private_key']
            key = key.encode("ascii")

            sig0 = hmac.new(key, msg=sn, digestmod=hashlib.sha256).digest()
            gen_log.debug("sig:     "+ binascii.hexlify(sig))
            gen_log.debug("sig calc:"+ binascii.hexlify(sig0))

            # Constant-time comparison so the check does not leak how many
            # leading bytes of the signature were correct.
            if hmac.compare_digest(sig0, sig):
                #send IV + AES Key
                self.sn = sn
                self.private_key = key
                self.node_id = str(node['node_id'])
                gen_log.info("valid hello packet from node %s" % self.node_id)
                #remove the junk connection of the same sn
                ioloop.IOLoop.current().add_callback(self.device_server.remove_junk_connection, self)
                #init aes
                self.iv = Random.new().read(AES.block_size)
                self.cipher_down = AES.new(key, AES.MODE_CFB, self.iv, segment_size=128)

                if self.fw_version > 1.0:
                    # newer firmware uses an independent cipher per direction
                    self.cipher_up = AES.new(key, AES.MODE_CFB, self.iv, segment_size=128)
                else:
                    #for old version
                    self.cipher_up = self.cipher_down

                cipher_text = self.iv + self.cipher_down.encrypt(pad("hello"))
                gen_log.debug("cipher text: "+ cipher_text.encode('hex'))
                self.stream.write(cipher_text)
                raise gen.Return(0)
            else:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.error("signature not match: %s %s" % (sig, sig0))
                raise gen.Return(102) #sig not match
        except gen.TimeoutError:
            self.kill_myself()
            raise gen.Return(1)
        except iostream.StreamClosedError:
            self.kill_myself()
            raise gen.Return(2)

        #ioloop.IOLoop.current().add_future(self._serving_future, lambda future: future.result())

    @gen.coroutine
    def _loop_reading_input (self):
        """Read, decrypt and dispatch messages from the node until killed.

        Data arrives in 16-byte encrypted blocks; decrypted text is buffered
        and split on CRLF. Every complete line re-arms the 60 s online-check
        and 70 s offline timers. ``##ALIVE##`` lines are heartbeats; all other
        lines are JSON and are routed by ``msg_type`` to the OTA state
        waiters, the event waiters, or ``recv_msg`` + ``recv_msg_cond``.
        """
        line = ""
        piece = ""
        while not self.killed:
            msg = ""
            try:
                msg = yield self.stream.read_bytes(16)
                msg = unpad(self.cipher_up.decrypt(msg))
                line += msg

                while True:
                    index = line.find('\r\n')
                    if index < 0:
                        break

                    #reset the timeout
                    if self.timeout_handler_onlinecheck:
                        ioloop.IOLoop.current().remove_timeout(self.timeout_handler_onlinecheck)
                    self.timeout_handler_onlinecheck = ioloop.IOLoop.current().call_later(60, self._online_check)

                    if self.timeout_handler_offline:
                        ioloop.IOLoop.current().remove_timeout(self.timeout_handler_offline)
                    self.timeout_handler_offline = ioloop.IOLoop.current().call_later(70, self._callback_when_offline)

                    # split off the first complete line, keep the rest buffered
                    piece = line[:index+2]
                    line = line[index+2:]
                    piece = piece.strip("\r\n")

                    if piece in ['##ALIVE##']:
                        gen_log.info('Node %s alive on %s channel!' % (self.node_id, self.device_server.role))
                        continue

                    # a ValueError here is handled by the outer except below
                    json_obj = json.loads(piece)
                    gen_log.info('Node %s recv json on %s channel' % (self.node_id, self.device_server.role))
                    gen_log.debug('%s' % str(json_obj))

                    try:
                        state = None
                        event = None
                        if json_obj['msg_type'] == 'online_status':
                            if json_obj['msg'] in ['1',1,True]:
                                self.online_status = True
                            else:
                                self.online_status = False
                            continue
                        elif json_obj['msg_type'] == 'ota_trig_ack':
                            state = ('going', 'Node has been notified...')
                            self.ota_ing = True
                            if self.ota_notify_done_future:
                                self.ota_notify_done_future.set_result(1)
                        elif json_obj['msg_type'] == 'ota_status':
                            if json_obj['msg'] == 'started':
                                state = ('going', 'Downloading the firmware...')
                            else:
                                state = ('error', 'Failed to start the downloading.')
                                self.post_ota = True
                        elif json_obj['msg_type'] == 'ota_result':
                            if json_obj['msg'] == 'success':
                                state = ('done', 'Firmware updated.')
                            else:
                                state = ('error', 'Update failed. Please reboot the node and retry.')
                            self.post_ota = True
                            self.ota_ing = False
                        elif json_obj['msg_type'] == 'event':
                            event = json_obj
                            event.pop('msg_type')
                        gen_log.debug("state: ")
                        gen_log.debug(state)
                        gen_log.debug("event: ")
                        gen_log.debug(event)
                        if state:
                            # hand the state to the oldest waiter for this sn,
                            # or queue it until someone asks
                            if self.state_waiters and self.sn in self.state_waiters and len(self.state_waiters[self.sn]) > 0:
                                f = self.state_waiters[self.sn].pop(0)
                                f.set_result(state)
                                if len(self.state_waiters[self.sn]) == 0:
                                    del self.state_waiters[self.sn]
                            elif self.state_happened and self.sn in self.state_happened:
                                self.state_happened[self.sn].append(state)
                            else:
                                self.state_happened[self.sn] = [state]
                        elif event:
                            if len(self.event_waiters) == 0:
                                self.event_happened.append(event)
                            else:
                                for future in self.event_waiters:
                                    future.set_result(event)
                                self.event_waiters = []
                        else:
                            # plain response to a pending request
                            self.recv_msg = json_obj
                            self.recv_msg_cond.notify()
                            yield gen.moment
                    except Exception as e:
                        # a malformed message must not tear down the connection
                        gen_log.warning("Node %s: %s" % (self.node_id ,str(e)))

            except iostream.StreamClosedError:
                gen_log.error("StreamClosedError when reading from node %s" % self.node_id)
                self.kill_myself()
                return
            except ValueError:
                gen_log.warning("Node %s: %s can not be decoded into json" % (self.node_id, piece))
            except Exception as e:
                gen_log.error("Node %s: %s" % (self.node_id ,str(e)))
                self.kill_myself()
                return

            yield gen.moment
Esempio n. 9
0
class DeviceConnection(object):

    state_waiters = {}
    state_happened = {}

    def __init__ (self, device_server, stream, address, conn_pool):
        """Set up the per-connection state for one device socket.

        :param device_server: owning server object
        :param stream: stream for this connection; Nagle is disabled on it and
            ``self.on_close`` is registered as its close callback
        :param address: peer address of the connection
        :param conn_pool: mapping of SN -> connection shared with the server
        """
        # owning server, shared pool and transport
        self.device_server = device_server
        self.device_server_conn_pool = conn_pool
        self.address = address
        self.stream = stream
        self.stream.set_nodelay(True)
        self.stream.set_close_callback(self.on_close)

        # identity, filled in by wait_hello
        self.sn = ""
        self.private_key = ""
        self.node_id = ""
        self.user_id = ""
        self.fw_version = 0.0

        # AES state, created after a valid hello
        self.iv = None
        self.cipher_down = None
        self.cipher_up = None

        # request/response plumbing
        self.recv_msg = {}
        self.recv_msg_cond = Condition()
        self.send_msg_sem = Semaphore(1)
        self.pending_request_cnt = 0

        # liveness tracking
        self.killed = False
        self.is_junk = False
        self.online_status = True
        self.timeout_handler_onlinecheck = None
        self.timeout_handler_offline = None

        self.event_waiters = []
        self.event_happened = []

        # OTA bookkeeping
        self.ota_ing = False
        self.post_ota = False
        self.ota_notify_done_future = None

    @gen.coroutine
    def secure_write (self, data):
        """Pad and encrypt ``data`` with the downstream cipher and send it.

        Does nothing when the downstream cipher has not been set up yet.
        """
        if not self.cipher_down:
            return
        encrypted = self.cipher_down.encrypt(pad(data))
        yield self.stream.write(encrypted)

    @gen.coroutine
    def wait_hello (self):
        """Read and verify the hello packet, then set up the AES ciphers.

        The packet is 32 bytes of SN plus a 32-byte HMAC-SHA256 signature of
        the SN; newer firmware prefixes it with a 4-byte ``@x.y`` version tag.
        On success the connection replaces any existing pool entry for the
        same SN and an "online" stat event is broadcast for the node's user.

        Resolves (via ``gen.Return``) to:
            0   - hello accepted, IV + encrypted "hello" sent to the node
            100 - unexpected packet length
            101 - SN not found in the ``nodes`` table
            102 - signature mismatch
            1   - timed out waiting for data
            2   - stream closed while waiting
        On every non-zero result the connection has been killed.
        """
        try:
            self._wait_hello_future = self.stream.read_bytes(64) #read 64bytes: 32bytes SN + 32bytes signature signed with private key
            str1 = yield gen.with_timeout(timedelta(seconds=10), self._wait_hello_future,
                                          io_loop=ioloop.IOLoop.current())
            self.idle_time = 0  #reset the idle time counter

            if len(str1) != 64:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.debug("receive length != 64")
                raise gen.Return(100) # length not match 64

            if re.match(r'@\d\.\d', str1[0:4]):
                #new version firmware
                self._wait_hello_future = self.stream.read_bytes(4) #read another 4bytes
                str2 = yield gen.with_timeout(timedelta(seconds=10), self._wait_hello_future, io_loop=ioloop.IOLoop.current())

                self.idle_time = 0  #reset the idle time counter

                if len(str2) != 4:
                    self.stream.write("sorry\r\n")
                    yield gen.sleep(0.1)
                    self.kill_myself()
                    gen_log.debug("receive length != 68")
                    raise gen.Return(100) # length not match 68

                str1 += str2
                self.fw_version = float(str1[1:4])
                sn = str1[4:36]
                sig = str1[36:68]
            else:
                #for version < 1.1
                sn = str1[0:32]
                sig = str1[32:64]

            gen_log.info("accepted sn: %s @fw_version %.1f" % (sn, self.fw_version))

            #query the sn from database
            node = None
            cur = self.device_server.cur
            # sn comes from an unauthenticated peer, so it must never be
            # interpolated into the SQL text.
            # NOTE(review): '?' is the sqlite3 placeholder style - confirm the
            # driver behind device_server.cur uses qmark paramstyle.
            cur.execute('select * from nodes where node_sn=?', (sn,))
            rows = cur.fetchall()
            if len(rows) > 0:
                node = rows[0]

            if not node:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.error("node sn not found")
                raise gen.Return(101) #node not found

            key = node['private_key']
            key = key.encode("ascii")

            sig0 = hmac.new(key, msg=sn, digestmod=hashlib.sha256).digest()
            gen_log.debug("sig:     "+ binascii.hexlify(sig))
            gen_log.debug("sig calc:"+ binascii.hexlify(sig0))

            # Constant-time comparison so the check does not leak how many
            # leading bytes of the signature were correct.
            if hmac.compare_digest(sig0, sig):
                #send IV + AES Key
                self.sn = sn
                self.private_key = key
                self.node_id = str(node['node_id'])
                self.user_id = str(node['user_id'])
                gen_log.info("valid hello packet from node %s" % self.node_id)

                # remove the junk connection of the same thing
                if self.sn in self.device_server_conn_pool:
                    gen_log.info("%s device server will remove one junk connection of same sn: %s"% (self.device_server.role, self.sn))
                    self.device_server_conn_pool[self.sn].kill_junk()

                # save into conn pool
                self.device_server_conn_pool[self.sn] = self
                gen_log.info('>>>>>>>>>>>>>>>>>>>>>')
                gen_log.info('channel: %s' % self.device_server.role)
                gen_log.info('size of conn pool: %d' % len(self.device_server_conn_pool))
                gen_log.info('<<<<<<<<<<<<<<<<<<<<<')

                #init aes
                self.iv = Random.new().read(AES.block_size)
                self.cipher_down = AES.new(key, AES.MODE_CFB, self.iv, segment_size=128)

                if self.fw_version > 1.0:
                    # newer firmware uses an independent cipher per direction
                    self.cipher_up = AES.new(key, AES.MODE_CFB, self.iv, segment_size=128)
                else:
                    #for old version
                    self.cipher_up = self.cipher_down

                cipher_text = self.iv + self.cipher_down.encrypt(pad("hello"))
                gen_log.debug("cipher text: "+ cipher_text.encode('hex'))
                self.stream.write(cipher_text)

                # tell the owning user's listeners that the node came online
                user_event = {"node_sn": self.sn, "event_type": "stat", "event_data": {"online": True, "at": self.device_server.role}}
                CoEventBus().broadcast('/event/users/{}'.format(self.user_id), user_event)

                raise gen.Return(0)
            else:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.error("signature not match: %s %s" % (sig, sig0))
                raise gen.Return(102) #sig not match
        except gen.TimeoutError:
            self.kill_myself()
            raise gen.Return(1)
        except iostream.StreamClosedError:
            self.kill_myself()
            raise gen.Return(2)

    @gen.coroutine
    def _loop_reading_input (self):
        """Read, decrypt and dispatch messages from the node until killed.

        Data arrives in 16-byte encrypted blocks; decrypted text is buffered
        and split on CRLF. Every complete line re-arms the online-check and
        offline timers. ``##ALIVE##`` lines are heartbeats; all other lines
        are JSON and are routed by ``msg_type`` to the OTA state waiters, the
        event waiters (plus a user broadcast), or ``recv_msg`` +
        ``recv_msg_cond``.
        """
        line = ""
        piece = ""
        while not self.killed:
            msg = ""
            try:
                msg = yield self.stream.read_bytes(16)
                msg = unpad(self.cipher_up.decrypt(msg))
                line += msg

                while True:
                    index = line.find('\r\n')
                    if index < 0:
                        break

                    #reset the timeout
                    if self.timeout_handler_onlinecheck:
                        ioloop.IOLoop.current().remove_timeout(self.timeout_handler_onlinecheck)
                    self.timeout_handler_onlinecheck = ioloop.IOLoop.current().call_later(HEARTBEAT_PERIOD_SEC, self._online_check)

                    if self.timeout_handler_offline:
                        ioloop.IOLoop.current().remove_timeout(self.timeout_handler_offline)
                    self.timeout_handler_offline = ioloop.IOLoop.current().call_later(
                        HEARTBEAT_PERIOD_SEC + HEARTBEAT_NEGATIVE_CHECK_DELAY_SEC,
                        self._callback_when_offline)

                    # split off the first complete line, keep the rest buffered
                    piece = line[:index+2]
                    line = line[index+2:]
                    piece = piece.strip("\r\n")

                    if piece in ['##ALIVE##']:
                        gen_log.debug('Node %s alive on %s channel!' % (self.node_id, self.device_server.role))
                        continue

                    # a ValueError here is handled by the outer except below
                    json_obj = json.loads(piece)
                    gen_log.debug('Node %s recv json on %s channel' % (self.node_id, self.device_server.role))
                    gen_log.debug('%s' % str(json_obj))

                    try:
                        state = None
                        event = None
                        if json_obj['msg_type'] == 'online_status':
                            if json_obj['msg'] in ['1',1,True]:
                                self.online_status = True
                            else:
                                self.online_status = False
                            continue
                        elif json_obj['msg_type'] == 'ota_trig_ack':
                            state = ('going', 'Node has been notified...')
                            self.ota_ing = True
                            if self.ota_notify_done_future:
                                self.ota_notify_done_future.set_result(1)
                        elif json_obj['msg_type'] == 'ota_status':
                            if json_obj['msg'] == 'started':
                                state = ('going', 'Downloading the firmware...')
                            else:
                                state = ('error', 'Failed to start the downloading.')
                                self.post_ota = True
                        elif json_obj['msg_type'] == 'ota_result':
                            if json_obj['msg'] == 'success':
                                state = ('done', 'Firmware updated.')
                            else:
                                state = ('error', 'Update failed. Please reboot the node and retry.')
                            self.post_ota = True
                            self.ota_ing = False
                        elif json_obj['msg_type'] == 'event':
                            event = json_obj
                            event.pop('msg_type')
                            # forward the event payload to the owning user's listeners
                            user_event = {"node_sn": self.sn, "event_type": "grove", "event_data": event['msg']}
                            CoEventBus().broadcast('/event/users/{}'.format(self.user_id), user_event)

                        gen_log.debug("state: ")
                        gen_log.debug(state)
                        gen_log.debug("event: ")
                        gen_log.debug(event)

                        if state:
                            # hand the state to the oldest waiter for this sn,
                            # or queue it until someone asks
                            if self.sn in DeviceConnection.state_waiters and len(DeviceConnection.state_waiters[self.sn]) > 0:
                                f = self.state_waiters[self.sn].pop(0)
                                f.set_result(state)
                                if len(DeviceConnection.state_waiters[self.sn]) == 0:
                                    del DeviceConnection.state_waiters[self.sn]
                            elif self.sn in DeviceConnection.state_happened:
                                DeviceConnection.state_happened[self.sn].append(state)
                            else:
                                DeviceConnection.state_happened[self.sn] = [state]
                        elif event:
                            # events with no waiter are not queued here
                            if len(self.event_waiters) > 0:
                                for future in self.event_waiters:
                                    future.set_result(event)
                                self.event_waiters = []
                        else:
                            # plain response to a pending request
                            self.recv_msg = json_obj
                            self.recv_msg_cond.notify()
                            yield gen.moment
                    except Exception as e:
                        # a malformed message must not tear down the connection
                        gen_log.warning("Node %s: %s" % (self.node_id ,str(e)))

            except iostream.StreamClosedError:
                gen_log.error("StreamClosedError when reading from node %s on %s channel" % (self.node_id, self.device_server.role))
                self.kill_myself()
                return
            except ValueError:
                gen_log.warning("Node %s: %s can not be decoded into json" % (self.node_id, piece))
            except Exception as e:
                gen_log.error("Node %s: %s" % (self.node_id ,str(e)))
                self.kill_myself()
                return

            yield gen.moment
class HttpChunkedRecognizeHandler(tornado.web.RequestHandler):
    """
    Provides a HTTP POST/PUT interface supporting chunked transfer requests, similar to that provided by
    http://github.com/alumae/ruby-pocketsphinx-server.

    Per request, one worker is taken from ``application.available_workers``;
    body chunks are forwarded to it as binary messages, ``"EOS"`` is sent when
    the body ends, and the reply is written once the worker calls ``close``.
    """

    def prepare(self):
        """Initialise per-request state and claim a worker.

        Replies with HTTP 503 when ``available_workers.pop()`` (or any later
        step inside the ``try``) raises ``KeyError``.
        """
        self.id = str(uuid.uuid4())
        self.final_hyp = ""
        # Notified by close() when the worker has finished with this request.
        self.worker_done = Condition()
        self.user_id = self.request.headers.get("device-id", "none")
        self.content_id = self.request.headers.get("content-id", "none")
        # The format string needs one placeholder per argument; the previous
        # literal had only two for three arguments and raised TypeError.
        logging.info("%s: OPEN: user='%s', content='%s'",
                     self.id, self.user_id, self.content_id)
        self.worker = None
        self.error_status = 0
        self.error_message = None
        try:
            self.worker = self.application.available_workers.pop()
            self.application.send_status_update()
            logging.info("%s: Using worker %s", self.id, self.__str__())
            self.worker.set_client_socket(self)

            content_type = self.request.headers.get("Content-Type", None)
            if content_type:
                content_type = content_type_to_caps(content_type)
                logging.info("%s: Using content type: %s",
                             self.id, content_type)

            # First message to the worker describes the request.
            self.worker.write_message(
                json.dumps(
                    dict(id=self.id,
                         content_type=content_type,
                         user_id=self.user_id,
                         content_id=self.content_id)))
        except KeyError:
            logging.warning("%s: No worker available for client request",
                            self.id)
            self.set_status(503)
            self.finish("No workers available")

    @tornado.gen.coroutine
    def data_received(self, chunk):
        """Forward one chunk of the request body to the worker."""
        assert self.worker is not None
        logging.debug("%s: Forwarding client message of length %d to worker",
                      self.id, len(chunk))
        self.worker.write_message(chunk, binary=True)

    @tornado.gen.coroutine
    def post(self, *args, **kwargs):
        """Handle the end of a POST body."""
        yield self.end_request(*args, **kwargs)

    @tornado.gen.coroutine
    def put(self, *args, **kwargs):
        """Handle the end of a PUT body."""
        yield self.end_request(*args, **kwargs)

    @tornado.gen.coroutine
    def end_request(self, *args, **kwargs):
        """Send ``EOS`` to the worker, wait for it to finish, then reply.

        The reply is a JSON object with ``status`` and ``id`` plus either
        ``hypotheses`` (success) or ``message`` (worker-reported error).
        The positional and keyword arguments are accepted but unused.
        """
        logging.info("%s: Handling the end of chunked recognize request",
                     self.id)
        assert self.worker is not None
        self.worker.write_message("EOS", binary=False)
        logging.info("%s: Waiting for worker to finish", self.id)
        yield self.worker_done.wait()
        if self.error_status == 0:
            logging.info("%s: Final hyp: %s", self.id, self.final_hyp)
            response = {
                "status": 0,
                "id": self.id,
                "hypotheses": [{
                    "utterance": self.final_hyp
                }]
            }
        else:
            logging.info("%s: Error (status=%d) processing HTTP request: %s",
                         self.id, self.error_status, self.error_message)
            response = {
                "status": self.error_status,
                "id": self.id,
                "message": self.error_message
            }
        self.write(response)
        self.application.num_requests_processed += 1
        self.application.send_status_update()
        # Detach from the worker before closing it.
        self.worker.set_client_socket(None)
        self.worker.close()
        self.finish()
        logging.info("Everything done")

    @tornado.gen.coroutine
    def send_event(self, event):
        """Record an event coming from the worker.

        Final results append their first hypothesis transcript to
        ``final_hyp``; a non-zero ``status`` is stored as the request error.
        """
        event_str = str(event)
        if len(event_str) > 100:
            event_str = event_str[:97] + "..."
        logging.info("%s: Receiving event %s from worker",
                     self.id, event_str)
        if event["status"] == 0 and ("result" in event):
            try:
                result = event["result"]
                if len(result["hypotheses"]) > 0 and result["final"]:
                    if len(self.final_hyp) > 0:
                        self.final_hyp += " "
                    self.final_hyp += result["hypotheses"][0]["transcript"]
            except Exception as e:
                # Best effort: a malformed result is logged and skipped.
                # (The previous code concatenated a str with an exception
                # class here, which itself raised TypeError.)
                logging.warning(
                    "Failed to extract hypothesis from recognition result: %s",
                    e)
        elif event["status"] != 0:
            self.error_status = event["status"]
            self.error_message = event.get("message", "")

    @tornado.gen.coroutine
    def close(self):
        """Called when the worker is done; wakes up ``end_request``."""
        logging.info("%s: Receiving 'close' from worker", self.id)
        self.worker_done.notify()
Esempio n. 11
0
class DeviceConnection(object):
    """State and I/O loops for one node connected over an ``IOStream``.

    The node first authenticates with a 64-byte hello packet
    (``wait_hello``); afterwards every message is AES-CFB encrypted,
    ``\\r\\n``-terminated JSON handled by ``_loop_reading_input``.
    """

    # Shared by all connections and keyed by node serial number, so OTA state
    # survives a node reconnecting on a new DeviceConnection instance.
    state_waiters = {}    # sn -> list of futures waiting for the next state
    state_happened = {}   # sn -> list of states nobody has consumed yet

    def __init__(self, device_server, stream, address):
        """Wrap an accepted ``stream`` from ``address`` owned by ``device_server``."""
        self.recv_msg_cond = Condition()
        self.recv_msg = {}
        self.send_msg_sem = Semaphore(1)
        self.pending_request_cnt = 0
        self.device_server = device_server
        self.stream = stream
        self.address = address
        self.stream.set_nodelay(True)
        self.idle_time = 0
        self.killed = False
        self.sn = ""
        self.private_key = ""
        self.node_id = 0
        self.name = ""
        self.iv = None
        self.cipher = None

        self.event_waiters = []
        self.event_happened = []

        self.ota_ing = False
        self.ota_notify_done_future = None
        self.post_ota = False
        self.online_status = True

    @gen.coroutine
    def secure_write(self, data):
        """Encrypt ``data`` and write it to the stream.

        Nothing is written while no cipher has been negotiated yet.
        """
        if self.cipher:
            cipher_text = self.cipher.encrypt(pad(data))
            yield self.stream.write(cipher_text)

    @gen.coroutine
    def wait_hello(self):
        """Read and verify the hello packet, then start the encrypted session.

        The packet is 32 bytes of serial number followed by a 32-byte
        HMAC-SHA256 of that serial number keyed with the node key.

        Returns (via ``gen.Return``):
            0   valid hello, IV + encrypted "hello" sent back
            1   no hello within 10 seconds
            2   stream closed
            100 packet length is not 64
            101 serial number not in ``NODES_DATABASE``
            102 signature mismatch
        The connection is killed for every non-zero code.
        """
        try:
            # read 64 bytes: 32 bytes SN + 32 bytes signature signed with private key
            self._wait_hello_future = self.stream.read_bytes(64)
            hello = yield gen.with_timeout(timedelta(seconds=10), self._wait_hello_future, io_loop=self.stream.io_loop)
            self.idle_time = 0  # reset the idle time counter

            if len(hello) != 64:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)  # give the write a moment before closing
                self.kill_myself()
                gen_log.debug("receive length != 64")
                raise gen.Return(100)  # length not match 64

            sn = hello[0:32]
            sig = hello[32:64]

            gen_log.info("accepted sn: " + sn)

            # query the sn from database
            node = None
            for n in NODES_DATABASE:
                if n['node_sn'] == sn:
                    node = n
                    break

            if not node:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.info("node sn not found")
                raise gen.Return(101)  # node not found

            key = node['node_key']
            key = key.encode("ascii")

            sig0 = hmac.new(key, msg=sn, digestmod=hashlib.sha256).digest()
            # NOTE(review): sig0 is exactly what a client must present to be
            # accepted below, so these log lines put a usable credential in
            # the log output -- consider dropping them.
            gen_log.debug("sig:     " + binascii.hexlify(sig))
            gen_log.debug("sig calc:" + binascii.hexlify(sig0))

            # Constant-time comparison: the signature comes straight from the
            # network, so a short-circuiting == would leak timing information.
            if hmac.compare_digest(sig0, sig):
                # send IV + AES Key
                self.sn = sn
                self.private_key = key
                self.node_id = node['node_sn']
                self.name = node['name']
                gen_log.info("valid hello packet from node %s" % self.name)
                # remove the junk connection of the same sn
                self.stream.io_loop.add_callback(self.device_server.remove_junk_connection, self)
                # init aes
                self.iv = Random.new().read(AES.block_size)
                self.cipher = AES.new(key, AES.MODE_CFB, self.iv, segment_size=128)
                cipher_text = self.iv + self.cipher.encrypt(pad("hello"))
                gen_log.debug("cipher text: " + cipher_text.encode('hex'))
                self.stream.write(cipher_text)
                raise gen.Return(0)
            else:
                self.stream.write("sorry\r\n")
                yield gen.sleep(0.1)
                self.kill_myself()
                gen_log.error("signature not match: %s %s" % (sig, sig0))
                raise gen.Return(102)  # sig not match
        except gen.TimeoutError:
            self.kill_myself()
            raise gen.Return(1)
        except iostream.StreamClosedError:
            self.kill_myself()
            raise gen.Return(2)

    @gen.coroutine
    def _loop_reading_input(self):
        """Read, decrypt and dispatch node messages until the connection dies.

        Data arrives in 16-byte encrypted blocks; decrypted text is buffered
        and split on ``\\r\\n`` into JSON messages. Depending on ``msg_type``
        a message updates the online flag, produces an OTA state (delivered
        to ``state_waiters`` or queued in ``state_happened``), produces an
        event (delivered to ``event_waiters`` or queued in
        ``event_happened``), or is stored in ``recv_msg`` as the reply to a
        pending request.
        """
        line = ""
        piece = ""
        while not self.killed:
            try:
                msg = yield self.stream.read_bytes(16)
                msg = unpad(self.cipher.decrypt(msg))

                self.idle_time = 0  # reset the idle time counter

                line += msg

                while '\r\n' in line:
                    piece, _, line = line.partition('\r\n')
                    piece = piece.strip("\r\n")
                    json_obj = json.loads(piece)
                    gen_log.info('Node %s: recv json: %s' % (self.name, str(json_obj)))

                    try:
                        state = None
                        event = None
                        msg_type = json_obj['msg_type']
                        if msg_type == 'online_status':
                            self.online_status = json_obj['msg'] in ['1', 1, True]
                            continue
                        elif msg_type == 'ota_trig_ack':
                            state = ('going', 'Node has been notified...')
                            self.ota_ing = True
                            if self.ota_notify_done_future:
                                self.ota_notify_done_future.set_result(1)
                        elif msg_type == 'ota_status':
                            if json_obj['msg'] == 'started':
                                state = ('going', 'Downloading the firmware...')
                            else:
                                state = ('error', 'Failed to start the downloading.')
                                self.post_ota = True
                        elif msg_type == 'ota_result':
                            if json_obj['msg'] == 'success':
                                state = ('done', 'Firmware updated.')
                            else:
                                state = ('error', 'Update failed. Please reboot the node and retry.')
                            self.post_ota = True
                            self.ota_ing = False
                        elif msg_type == 'event':
                            event = json_obj
                            event.pop('msg_type')
                        gen_log.debug(state)
                        gen_log.debug(event)
                        if state:
                            waiters = self.state_waiters.get(self.sn)
                            if waiters:
                                # Hand the state to the oldest waiter.
                                f = waiters.pop(0)
                                f.set_result(state)
                                if not waiters:
                                    del self.state_waiters[self.sn]
                            else:
                                # Nobody is waiting: queue it for later.
                                self.state_happened.setdefault(self.sn, []).append(state)
                        elif event:
                            if len(self.event_waiters) == 0:
                                self.event_happened.append(event)
                            else:
                                for future in self.event_waiters:
                                    future.set_result(event)
                                self.event_waiters = []
                        else:
                            # Plain reply to a request: publish it and let the
                            # waiting coroutine run.
                            self.recv_msg = json_obj
                            self.recv_msg_cond.notify()
                            yield gen.moment
                    except Exception as e:
                        # A bad message must not take the connection down.
                        gen_log.warning("Node %s: %s" % (self.name, str(e)))

            except iostream.StreamClosedError:
                gen_log.error("StreamClosedError when reading from node %s" % self.name)
                self.kill_myself()
                return
            except ValueError:
                # Raised by json.loads; the offending piece has already been
                # removed from the buffer, so just keep reading.
                gen_log.warning("Node %s: %s can not be decoded into json" % (self.name, piece))
            except Exception as e:
                gen_log.error("Node %s: %s" % (self.name, str(e)))
                self.kill_myself()
                return

            yield gen.moment
Esempio n. 12
0
class DecodeRequestHandler(tornado.web.RequestHandler):
    """Accept a POSTed wav file, hand it to a worker and reply with the transcript."""

    # NOTE(review): this evaluates to the string 'POST', not a tuple, and
    # Tornado's own allow-list attribute is spelled SUPPORTED_METHODS, so this
    # attribute is presumably not consulted by the framework -- confirm before
    # renaming. The method check is done by hand in prepare().
    SUPPORTED_METHOD = ('POST')

    # Called at the beginning of a request before get/post/etc
    def prepare(self):
        """Validate the request and store the uploaded file.

        Finishes the request with 403 for non-POST requests or when the
        'wavFile' upload is missing, and with 500 when the file cannot be
        written. On success ``temp_file`` holds the raw upload and
        ``filePath`` the path it was saved to.
        """
        self.worker = None
        self.filePath = None
        self.temp_file = None
        self.uuid = str(uuid.uuid4())
        self.set_status(200, "Initial statut")
        self.waitResponse = Condition()
        self.waitWorker = Condition()

        if self.request.method != 'POST':
            logging.debug("Received a non-POST request")
            self.set_status(
                403, "Wrong request, server handles only POST requests")
            self.finish()
            # Stop here: falling through would finish() a second time or
            # index a missing upload.
            return

        # File Retrieval
        # TODO: Adapt input to existing controller API
        if 'wavFile' not in self.request.files:
            self.set_status(
                403, "POST request must contain a 'wavFile' field.")
            self.finish()
            logging.debug(
                "POST request from %s does not contain 'wavFile' field.",
                self.request.remote_ip)
            return
        temp_file = self.request.files['wavFile'][0]['body']
        self.temp_file = temp_file

        # Writing file
        path = TEMP_FILE_PATH + self.uuid + '.wav'
        try:
            # The context manager closes the handle even if write() fails.
            with open(path, 'wb') as f:
                f.write(temp_file)
        except IOError:
            logging.error("Could not write file.")
            self.set_status(
                500, "Server error: Counldn't write file on server side.")
            self.finish()
            return
        self.filePath = path
        logging.debug("File correctly received from client")

    @gen.coroutine
    def post(self, *args, **kwargs):
        """Send the upload to a worker and wait for ``receive_response``."""
        logging.debug("Allocating Worker to %s" % self.uuid)

        yield self.allocate_worker()
        self.worker.write_message(
            json.dumps({
                'uuid': self.uuid,
                'file': self.temp_file.encode('base64')
            }))
        yield self.waitResponse.wait()
        self.finish()

    @gen.coroutine
    def allocate_worker(self):
        """Take a worker from the pool, waiting on ``waitWorker`` if it is empty."""
        while self.worker is None:
            try:
                self.worker = self.application.available_workers.pop()
            except (KeyError, IndexError):
                # Empty pool: register as waiting and sleep until notified.
                self.worker = None
                self.application.waiting_client.add(self)
                self.application.display_server_status()
                yield self.waitWorker.wait()
            else:
                self.worker.client_handler = self
                logging.debug("Worker allocated to client %s" % self.uuid)
                self.application.display_server_status()

    @gen.coroutine
    def receive_response(self, message):
        """Write the worker's transcript as the reply and wake up ``post``."""
        os.remove(TEMP_FILE_PATH + self.uuid + '.wav')
        self.set_status(200, "Transcription succeded")
        self.set_header("Content-Type", "application/json")
        self.set_header("Access-Control-Allow-Origin", "*")
        self.write({'transcript': message})
        self.application.num_requests_processed += 1
        self.waitResponse.notify()

    def on_finish(self):
        """Remove the temporary upload if it is still on disk."""
        # receive_response() normally deletes the file already; this covers
        # requests that finished without a worker response.
        path = getattr(self, 'filePath', None)
        if path and os.path.exists(path):
            try:
                os.remove(path)
            except OSError:
                logging.debug("Could not remove temporary file %s" % path)
Esempio n. 13
0
class InMemStream(Stream):
    """Stream backed by an in-memory deque of chunks.

    ``write`` appends chunks, ``read`` returns a future for the next batch of
    buffered data, and ``close``/``set_exception`` end the stream.
    """

    def __init__(self, buf=None, auto_close=True):
        """In-Memory based stream

        :param buf: the buffer for the in memory stream
        :param auto_close: stored on the instance; not used by this class
        """
        self._stream = deque()
        if buf:
            self._stream.append(buf)
        self.state = StreamState.init
        # Notified whenever data arrives or the stream is closed.
        self._condition = Condition()
        self.auto_close = auto_close

        self.exception = None
        self.exc_info = None

    def clone(self):
        """Return a new stream with the same state, flag and buffered chunks.

        The stored exception, if any, is not copied.
        """
        new_stream = InMemStream()
        new_stream.state = self.state
        new_stream.auto_close = self.auto_close
        new_stream._stream = deque(self._stream)
        return new_stream

    def read(self):
        """Return a future resolving to the next chunk of buffered data.

        Buffered pieces are concatenated until ``common.MAX_PAYLOAD_SIZE`` is
        reached. If nothing is buffered and the stream is still open, the
        future resolves once data is written or the stream is closed; an
        empty chunk therefore means end of stream. A stored exception is
        propagated through the future.
        """
        def read_chunk(future):
            if self.exception:
                if self.exc_info:
                    future.set_exc_info(self.exc_info)
                else:
                    future.set_exception(self.exception)
                return future

            chunk = b""

            while self._stream and len(chunk) < common.MAX_PAYLOAD_SIZE:
                new_chunk = self._stream.popleft()
                if six.PY3 and isinstance(new_chunk, str):
                    new_chunk = new_chunk.encode('utf8')
                chunk += new_chunk

            future.set_result(chunk)
            return future

        read_future = tornado.concurrent.Future()

        # We're not ready yet
        if self.state != StreamState.completed and not self._stream:
            wait_future = self._condition.wait()
            tornado.ioloop.IOLoop.current().add_future(
                wait_future,
                lambda f: f.exception() or read_chunk(read_future))
            return read_future

        return read_chunk(read_future)

    def write(self, chunk):
        """Append ``chunk`` to the stream and wake one pending reader.

        :raises: the stored exception if one was set, or ``UnexpectedError``
            when the stream is already closed.
        """
        if self.exception:
            raise self.exception

        if self.state == StreamState.completed:
            raise UnexpectedError("Stream has been closed.")

        if chunk:
            self._stream.append(chunk)
            self._condition.notify()

        # This needs to return a future to match the async interface.
        r = tornado.concurrent.Future()
        r.set_result(None)
        return r

    def set_exception(self, exception, exc_info=None):
        """Store an error to hand to readers/writers and close the stream."""
        self.exception = exception
        self.exc_info = exc_info
        self.close()

    def close(self):
        """Mark the stream completed and release every pending reader."""
        self.state = StreamState.completed
        # Nothing will ever be written again, so all parked readers must be
        # woken; notify() would release only one and leave the rest hanging.
        self._condition.notify_all()
class TornadoCoroutineExecutor(Executor):
    """Runs submitted Tornado coroutine functions on a fixed pool of worker
    coroutines that pull ``TaskItem`` objects from ``queue``."""

    def __init__(self,
                 core_pool_size,
                 queue,
                 reject_handler,
                 coroutine_pool_name=None):
        """Create the pool and start its worker coroutines.

        :param core_pool_size: number of worker coroutines to start
        :param queue: queue offering ``put_nowait``/``get_nowait``
        :param reject_handler: called as ``reject_handler(queue, task_item)``
            when the queue is full; its return value is returned to the caller
        :param coroutine_pool_name: optional name used in log messages
        """
        self._core_pool_size = core_pool_size
        self._queue = queue
        self._reject_handler = reject_handler
        self._coroutine_pool_name = coroutine_pool_name or \
            'tornado-coroutine-pool-%s' % uuid.uuid1().hex
        # Notified once the last worker coroutine has stopped.
        self._core_coroutines_condition = Condition()
        self._core_coroutines = {}
        # Idle workers park here until a task is submitted or shutdown starts.
        self._core_coroutines_wait_condition = Condition()
        self._shutting_down = False
        self._shuted_down = False
        self._initialize_core_coroutines()

    def _initialize_core_coroutines(self):
        """Start one worker coroutine per pool slot."""
        for ind in range(self._core_pool_size):
            self._core_coroutines[ind] = self._core_coroutine_run(ind)
            LOGGER.info("core coroutine: %s is intialized" %
                        self._get_coroutine_name(ind))

    def _get_coroutine_name(self, ind):
        """Return the log name of worker ``ind``."""
        return '%s:%d' % (self._coroutine_pool_name, ind)

    @gen.coroutine
    def _core_coroutine_run(self, ind):
        """Worker loop: run queued tasks until the executor shuts down."""
        coroutine_name = self._get_coroutine_name(ind)
        while not self._shutting_down and not self._shuted_down:
            try:
                task_item = self._queue.get_nowait()
            except QueueEmpty:
                LOGGER.debug("coroutine: %s will enter into waiting pool" %
                             coroutine_name)
                if self._shutting_down or self._shuted_down:
                    break
                yield self._core_coroutines_wait_condition.wait()
                LOGGER.debug("coroutine: %s was woken up from waiting pool" %
                             coroutine_name)
                continue

            async_result = task_item.async_result
            async_result.set_time_info("consumed_from_queue_at")
            # Skip tasks whose result was cancelled before they started.
            if not async_result.set_running_or_notify_cancel():
                continue
            time_info_key = "executed_completion_at"
            try:
                result = yield task_item.function(*task_item.args,
                                                  **task_item.kwargs)
                async_result.set_time_info(time_info_key).set_result(result)
            except Exception as ex:
                async_result.set_time_info(time_info_key).set_exception(ex)

        LOGGER.info("coroutine: %s is stopped" % coroutine_name)
        # Tolerate a missing entry instead of raising KeyError from a worker
        # that is already on its way out.
        self._core_coroutines.pop(ind, None)
        if not self._core_coroutines:
            LOGGER.info("all coroutines in %s are stopped" %
                        self._coroutine_pool_name)
            self._core_coroutines_condition.notify_all()

    def submit_task(self, function, *args, **kwargs):
        """Queue ``function(*args, **kwargs)`` for execution.

        :returns: an ``AsyncResult`` for the task. It already carries an
            exception when the executor is shut(ting) down or ``function``
            is not a Tornado coroutine function. When the queue is full the
            reject handler's return value is returned instead.
        """
        async_result = AsyncResult()
        if self._shutting_down or self._shuted_down:
            async_result.set_exception(
                ShutedDownError(self._coroutine_pool_name))
            return async_result
        if not gen.is_coroutine_function(function):
            async_result.set_exception(
                RuntimeError("function must be tornado coroutine function"))
            return async_result

        task_item = TaskItem(function, args, kwargs, async_result)
        try:
            self._queue.put_nowait(task_item)
        except QueueFull:
            return self._reject_handler(self._queue, task_item)

        async_result.set_time_info("submitted_to_queue_at")
        # Wake one idle worker to pick the task up.
        self._core_coroutines_wait_condition.notify()
        return async_result

    @gen.coroutine
    def shutdown(self, wait_time=None):
        """Stop the workers and fail every task still in the queue.

        :param wait_time: how long to wait for the workers to stop: ``None``
            (wait indefinitely), a number of seconds, or a
            ``datetime.timedelta``.
        """
        import numbers
        from datetime import timedelta

        if self._shutting_down or self._shuted_down:
            raise gen.Return()

        self._shutting_down = True
        self._shuted_down = False

        LOGGER.info("begin to notify all coroutines")
        self._core_coroutines_wait_condition.notify_all()
        if self._core_coroutines:
            timeout = wait_time
            if isinstance(wait_time, numbers.Real):
                # Condition.wait() treats a bare number as an absolute
                # deadline on the IOLoop clock; a duration in seconds must be
                # passed as a timedelta or it expires immediately.
                timeout = timedelta(seconds=wait_time)
            yield self._core_coroutines_condition.wait(timeout)

        # Whatever is still queued will never run.
        while True:
            try:
                task_item = self._queue.get_nowait()
            except QueueEmpty:
                break
            else:
                task_item.async_result.set_exception(
                    ShutedDownError(self._coroutine_pool_name))

        self._shutting_down = False
        self._shuted_down = True
Esempio n. 15
0
class MockFitsWriterClient(object):
    """
    Wrapper class for a KATCP client to a EddFitsWriterServer

    Connects a non-blocking TCP socket to ``address`` and keeps receiving
    packets (header plus sections) on the IOLoop until ``stop`` is called.
    """
    def __init__(self, address):
        """
        @brief      Construct new instance

        @param      address  Address passed to ``socket.connect``
        """
        self._address = address
        self._ioloop = IOLoop.current()
        self._stop_event = Event()
        # Notified by recv_loop once it has observed the stop event.
        self._is_stopped = Condition()
        self._socket = None

    def reset_connection(self):
        """
        @brief      (Re)create the socket and start a non-blocking connect
        """
        if self._socket is not None:
            # Do not leak the previous descriptor on reconnect.
            self._socket.close()
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.setblocking(False)
        try:
            self._socket.connect(self._address)
        except socket.error as error:
            # EINPROGRESS is the normal outcome of a non-blocking connect.
            if error.args[0] != errno.EINPROGRESS:
                raise

    @coroutine
    def recv_nbytes(self, nbytes):
        """
        @brief      Receive exactly ``nbytes`` bytes from the socket

        @return     The received bytes (via ``Return``)

        Raises ``StopEvent`` when stop has been requested, ``EOFError`` when
        the peer closes the connection early, and re-raises unexpected
        socket errors.
        """
        received_bytes = 0
        data = b''
        while received_bytes < nbytes:
            if self._stop_event.is_set():
                raise StopEvent
            log.debug("Requesting {} bytes".format(nbytes - received_bytes))
            try:
                current_data = self._socket.recv(nbytes - received_bytes)
            except socket.error as error:
                error_id = error.args[0]
                if error_id == errno.EAGAIN or error_id == errno.EWOULDBLOCK:
                    # No data available yet: let the IOLoop run and retry.
                    yield sleep(0.1)
                    continue
                log.exception("Unexpected error on socket recv: {}".format(
                    str(error)))
                raise
            if not current_data:
                # recv() returning nothing means the peer closed the
                # connection; without this check the loop would spin forever
                # without ever yielding to the IOLoop.
                raise EOFError(
                    "Connection closed after {} of {} bytes".format(
                        received_bytes, nbytes))
            received_bytes += len(current_data)
            data += current_data
            log.debug("Received {} bytes ({} of {} bytes)".format(
                len(current_data), received_bytes, nbytes))
        raise Return(data)

    @coroutine
    def recv_loop(self):
        """
        @brief      Receive one packet and reschedule itself

        The loop ends when a stop is requested (waking ``stop``) or when
        receiving fails.
        """
        try:
            yield self.recv_packet()
        except StopEvent:
            log.debug("Notifying that recv calls have stopped")
            self._is_stopped.notify()
        except Exception:
            log.exception("Failure while receiving packet")
        else:
            self._ioloop.add_callback(self.recv_loop)

    def start(self):
        """
        @brief      Connect and start the receive loop
        """
        self._stop_event.clear()
        self.reset_connection()
        self._ioloop.add_callback(self.recv_loop)

    @coroutine
    def stop(self, timeout=2):
        """
        @brief      Request the receive loop to stop and wait for it

        @param      timeout  Seconds to wait before logging an error
        """
        self._stop_event.set()
        try:
            # Condition.wait takes an absolute deadline on the IOLoop clock.
            success = yield self._is_stopped.wait(timeout=self._ioloop.time() +
                                                  timeout)
            if not success:
                raise TimeoutError
        except TimeoutError:
            log.error(("Could not stop the client within "
                       "the {} second limit").format(timeout))
        except Exception:
            log.exception("Fucup")

    @coroutine
    def recv_packet(self):
        """
        @brief      Receive one packet

        @return     Tuple ``(header, sections)`` (via ``Return``) where
                    ``sections`` is a list of ``(section_header, data)``
        """
        log.debug("Receiving packet header")
        raw_header = yield self.recv_nbytes(C.sizeof(FWHeader))
        log.debug("Converting packet header")
        header = FWHeader.from_buffer_copy(raw_header)
        log.info("Received header: {}".format(header))
        # The header names the channel data type; TYPE_MAP gives the matching
        # ctypes and numpy types.
        fw_data_type = header.channel_data_type.strip().upper()
        c_data_type, np_data_type = TYPE_MAP[fw_data_type]
        sections = []
        for section in range(header.nsections):
            log.debug("Receiving section {} of {}".format(
                section + 1, header.nsections))
            raw_section_header = yield self.recv_nbytes(
                C.sizeof(FWSectionHeader))
            section_header = FWSectionHeader.from_buffer_copy(
                raw_section_header)
            log.info("Section {} header: {}".format(section, section_header))
            log.debug("Receiving section data")
            raw_bytes = yield self.recv_nbytes(
                C.sizeof(c_data_type) * section_header.nchannels)
            data = np.frombuffer(raw_bytes, dtype=np_data_type)
            log.info("Section {} data: {}".format(section, data[:10]))
            sections.append((section_header, data))
        raise Return((header, sections))
Esempio n. 16
0
class InMemStream(Stream):
    """Stream backed by an in-memory deque of chunks.

    ``write`` appends chunks, ``read`` returns a future for the next batch of
    buffered data, and ``close``/``set_exception`` end the stream.
    """

    def __init__(self, buf=None, auto_close=True):
        """In-Memory based stream

        :param buf: the buffer for the in memory stream
        :param auto_close: stored on the instance; not used by this class
        """
        self._stream = deque()
        if buf:
            self._stream.append(buf)
        self.state = StreamState.init
        # Notified whenever data arrives or the stream is closed.
        self._condition = Condition()
        self.auto_close = auto_close

        self.exception = None

    def clone(self):
        """Return a new stream with the same state, flag and buffered chunks.

        The stored exception, if any, is not copied.
        """
        new_stream = InMemStream()
        new_stream.state = self.state
        new_stream.auto_close = self.auto_close
        new_stream._stream = deque(self._stream)
        return new_stream

    def read(self):
        """Return a future resolving to the next chunk of buffered data.

        Buffered pieces are concatenated until ``common.MAX_PAYLOAD_SIZE`` is
        reached. If nothing is buffered and the stream is still open, the
        future resolves once data is written or the stream is closed; an
        empty chunk therefore means end of stream. A stored exception is
        propagated through the future.
        """
        def read_chunk(future):
            if self.exception:
                future.set_exception(self.exception)
                return future

            chunk = ""

            while self._stream and len(chunk) < common.MAX_PAYLOAD_SIZE:
                chunk += self._stream.popleft()

            future.set_result(chunk)
            return future

        read_future = tornado.concurrent.Future()

        # We're not ready yet
        if self.state != StreamState.completed and not self._stream:
            wait_future = self._condition.wait()
            wait_future.add_done_callback(
                lambda f: f.exception() or read_chunk(read_future))
            return read_future

        return read_chunk(read_future)

    def write(self, chunk):
        """Append ``chunk`` to the stream and wake one pending reader.

        :raises: the stored exception if one was set, or ``StreamingError``
            when the stream is already closed.
        """
        if self.exception:
            raise self.exception

        if self.state == StreamState.completed:
            raise StreamingError("Stream has been closed.")

        if chunk:
            self._stream.append(chunk)
            self._condition.notify()

        # This needs to return a future to match the async interface.
        r = tornado.concurrent.Future()
        r.set_result(None)
        return r

    def set_exception(self, exception):
        """Store an error to hand to readers/writers and close the stream."""
        self.exception = exception
        self.close()

    def close(self):
        """Mark the stream completed and release every pending reader."""
        self.state = StreamState.completed
        # Nothing will ever be written again, so all parked readers must be
        # woken; notify() would release only one and leave the rest hanging.
        self._condition.notify_all()