class NotificationQueue(Queue):
    """A queue that signals waiting threads whenever an item is added.

    Besides the normal queue operations it exposes ``acquire``/``release``/
    ``wait`` on an internal Condition (``self.lock``); :meth:`put` notifies one
    thread blocked in :meth:`wait`.
    """

    def __init__(self):
        """Create the queue (base ``Queue`` defaults) and its condition."""
        Queue.__init__(self)
        self.lock = Condition()

    def get(self):
        """Remove and return the first item without blocking.

        Must be called on a non-empty queue.

        Returns:
            The first item in the queue.

        Raises:
            Empty: (the base Queue's ``Empty``) if the queue is empty.
        """
        return Queue.get(self, False)

    def put(self, item):
        """Add ``item`` and wake one thread blocked in :meth:`wait`.

        Args:
            item: new item to add to the queue.
        """
        # ``with`` guarantees the condition is released even if the base
        # put() raises; a manual acquire/release pair would leak it.
        with self.lock:
            Queue.put(self, item)
            self.lock.notify()

    def acquire(self):
        """Acquire the notification condition's lock."""
        self.lock.acquire()

    def release(self):
        """Release the notification condition's lock."""
        self.lock.release()

    def wait(self):
        """Wait for a notification from :meth:`put`; caller must hold the lock."""
        self.lock.wait()
示例#2
0
文件: interface.py 项目: Aerobota/rce
    class _EventRef(object):
        """ Helper class which acts as a threading.Event, but which can be used
            to pass a reference when signaling the event.
        """
        def __init__(self):
            self._cond = Condition(Lock())
            self._flag = False
            self._ref = None

        def isSet(self):
            return self._flag

        def set(self, ref):
            with self._cond:
                assert self._ref is None
                self._ref = ref
                self._flag = True
                self._cond.notifyAll()

        def get(self, timeout=None):
            with self._cond:
                if not self._flag:
                    self._cond.wait(timeout)

                if not self._flag:
                    raise TimeoutExceeded('Could not get the reference.')

                return self._ref

        def clear(self, ref):
            with self._cond:
                self._ref = None
                self._flag = False
示例#3
0
class _CoreScheduleThread(Thread):
    def __init__(self,threadpool):
        self.scheduletasks = [];
        self.tasklock = Lock();
        self.condition = Condition(Lock())
        self.threadpool = threadpool
        Thread.__init__(self)

    def run(self):
        while True:
            self.condition.acquire()
            if len(self.scheduletasks) == 0:
                self.condition.wait();
            else:
                task = self.scheduletasks.pop(0)
                if dates.current_timestamps()>=task.nexttime: 
                    self.threadpool.execute(task.function,*task.args,**task.kwargs)
                    task.nexttime = dates.current_timestamps()+task.period;
                else:
                    self.condition.wait(task.nexttime-dates.current_timestamps())
                self.addtask(task)
            self.condition.release()

    
    def addtask(self,task): # copy on write
        self.tasklock.acquire()
        tasks = [ t for t in self.scheduletasks ]
        tasks.append(task)
        tasks.sort(key=lambda task:task.nexttime)
        self.scheduletasks = tasks
        self.tasklock.release()
    def wait_statement(self, wtime, N=5):
        '''

        '''

        start_time = time.time()
        if self.debug:
            time.sleep(1)

        else:
            if isinstance(wtime, str):
                wtime = float(self._get_interpolation_value(wtime))

            if wtime > N:
                c = Condition()
                c.acquire()
                wd = WaitDialog(wtime=wtime,
                                    condition=c,
                                    parent=self)
                do_later(wd.edit_traits, kind='livemodal')
                c.wait()
                c.release()
                do_later(wd.close)
            else:
                time.sleep(wtime)

        msg = 'actual wait time %0.3f s' % (time.time() - start_time)
        self.log_statement(msg)
class SimpleBlockingQueue(object):
    """FIFO queue whose :meth:`take` can optionally block until an item arrives.

    A blocked :meth:`take` can be woken early by :meth:`interrupt`.
    """

    def __init__(self):
        self.items = deque()
        self.condition = Condition()
        self.interrupted = False

    def put(self, item):
        """Enqueue ``item`` and wake one thread blocked in :meth:`take`."""
        with self.condition:
            # deque has no ``pushleft``; ``appendleft`` is the real method.
            self.items.appendleft(item)
            # Without this notify a blocking take() would never wake up.
            self.condition.notify()

    def take(self, block=False):
        """Dequeue the oldest item.

        Args:
            block: if True, wait until an item is available or
                :meth:`interrupt` is called.

        Returns:
            The oldest item; None if the queue is empty and ``block`` is
            False, or if a blocking wait observed an interrupt.
        """
        with self.condition:
            if block:
                while not self.items:
                    self.condition.wait()
                    if self.interrupted:
                        self.interrupted = False
                        return None
            elif not self.items:
                return None

            # deque has no ``popright``; ``pop`` removes from the right end,
            # which holds the oldest item because put() appends on the left.
            return self.items.pop()

    def interrupt(self):
        """Wake blocked :meth:`take` calls; the first to notice returns None."""
        with self.condition:
            self.interrupted = True
            self.condition.notifyAll()
示例#6
0
class _InterruptableLock(object):
    def __init__(self, lock):
        """Creates a new lock."""
        self.__lock = lock
        self.__condition = Condition(Lock())

    def __enter__(self):
        return self.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def acquire(self, blocking=True, timeout=-1):
        """Acquire a lock, blocking or non-blocking.
        Blocks until timeout, if timeout a positive float and blocking=True. A timeout
        value of -1 blocks indefinitely, unless blocking=False."""
        if not isinstance(timeout, (int, float)):
            raise TypeError("a float is required")

        if blocking:
            # blocking indefinite
            if timeout == -1:
                with self.__condition:
                    while not self.__lock.acquire(False):
                        # condition with timeout is interruptable
                        self.__condition.wait(60)
                return True

            # same as non-blocking
            elif timeout == 0:
                return self.__lock.acquire(False)

            elif timeout < 0:
                raise ValueError("timeout value must be strictly positive (or -1)")

            # blocking finite
            else:
                start = time()
                waited_time = 0
                with self.__condition:
                    while waited_time < timeout:
                        if self.__lock.acquire(False):
                            return True
                        else:
                            self.__condition.wait(timeout - waited_time)
                            waited_time = time() - start
                return False

        elif timeout != -1:
            raise ValueError("can't specify a timeout for a non-blocking call")

        else:
            # non-blocking
            return self.__lock.acquire(False)

    def release(self):
        """Release a lock."""
        self.__lock.release()
        with self.__condition:
            self.__condition.notify()
示例#7
0
class Event(object):
    """Minimal event flag built on a Condition, modelled on threading.Event."""

    def __init__(self):
        self.__cond = Condition(Lock())
        self.__flag = False

    def isSet(self):
        """Return the current value of the flag."""
        return self.__flag

    is_set = isSet

    def set(self):
        """Raise the flag and wake every waiting thread."""
        with self.__cond:
            self.__flag = True
            self.__cond.notify_all()

    def clear(self):
        """Lower the flag."""
        with self.__cond:
            self.__flag = False

    def wait(self, timeout=None):
        """Wait (at most once, up to ``timeout`` seconds) for the flag.

        Returns:
            The flag's value after the wait.
        """
        with self.__cond:
            if self.__flag:
                return True
            self.__cond.wait(timeout)
            return self.__flag
示例#8
0
class DirtyBit:
    """A flag that can be set ("dirty"), cleared, and waited on until dirty."""

    def __init__(self):
        self._mon = RLock()
        self._rc = Condition(self._mon)
        self._dirty = 0

    def clear(self):
        """Mark the bit clean.

        Waiters are not notified: only becoming dirty is signalled.
        """
        with self._mon:
            self._dirty = 0

    def set(self):
        """Mark the bit dirty and wake one waiting thread."""
        with self._mon:
            self._dirty = 1
            self._rc.notify()

    def value(self):
        """Return 1 if dirty, 0 if clean."""
        return self._dirty

    def wait(self):
        """Block until the bit is dirty; does not clear it.

        Returns immediately if the bit is already dirty, so a set() that
        happened before wait() is not lost.
        """
        with self._mon:
            while not self._dirty:
                self._rc.wait()
示例#9
0
    def process_sequence(self, sequence):
        """
        Execute a JSON-encoded movement sequence, one step at a time.

        @param sequence: JSON text whose ``"sequence"`` key holds a list of
            steps; each step is a list of moves, and every move in a step is
            started concurrently in its own ``ServoThread``.
        @return: None
        """
        self.log.debug('process sequence')
        # parse the json message
        jseq = json.loads(sequence)
        moveseq = jseq['sequence']

        for moves in moveseq:
            self.log.debug('processing move %s' % moves)
            condition = Condition()
            thread_count = 0
            # Hold the condition while the servo threads are started.  A thread
            # must own the lock to call notify(), so it cannot signal before
            # this thread is inside wait().  Acquiring only after start() let a
            # fast thread notify first and left wait() blocked forever.
            # Condition() uses an RLock, so re-acquiring it from this thread
            # is safe.
            with condition:
                for thread_count, move in enumerate(moves, 1):
                    servo_instruction = self.process_move(move)
                    move_servo_thread = ServoThread(servo_instruction, self.pwm_queue, condition, thread_count, self.update_current_pulse)
                    move_servo_thread.setDaemon(False)
                    move_servo_thread.setName('Servo %d' % servo_instruction['channel'])
                    move_servo_thread.start()
                if thread_count > 0:
                    # Wait for the servo threads before doing the next step.
                    # NOTE(review): a single wait() returns on the first
                    # notify; confirm ServoThread signals only once the whole
                    # step has finished.
                    condition.wait()
示例#10
0
class BinaryServer(Thread):

    """
    Socket server forwarding request to internal server
    """

    def __init__(self, internal_server, hostname, port):
        Thread.__init__(self)
        self.socket_server = None
        self.hostname = hostname
        self.port = port
        self.iserver = internal_server
        self._cond = Condition()
        # Set by run() once server creation has finished (successfully or
        # not), giving start() a predicate to wait on.
        self._startup_done = False

    def start(self):
        """Start the server thread and block until the listening server exists.

        Raises:
            RuntimeError: if the server could not be created (for example the
                address is in use); the underlying error propagates in the
                server thread.
        """
        with self._cond:
            Thread.start(self)
            while not self._startup_done:
                self._cond.wait()
        if self.socket_server is None:
            raise RuntimeError("BinaryServer failed to start on %s:%s" % (self.hostname, self.port))

    def run(self):
        """Thread body: create the TCP server, signal start(), then serve forever."""
        try:
            logger.warning("Listening on %s:%s", self.hostname, self.port)
            socketserver.TCPServer.allow_reuse_address = True  # get rid of address already in used warning
            server = ThreadingTCPServer((self.hostname, self.port), UAHandler)
            # server.daemon_threads = True # this will force a shutdown of all threads, maybe too hard
            server.internal_server = self.iserver  # allow handler to access server properties
            self.socket_server = server
        finally:
            # Always wake start(), even when creating the server failed;
            # otherwise start() would wait forever.
            with self._cond:
                self._startup_done = True
                self._cond.notify_all()
        self.socket_server.serve_forever()

    def stop(self):
        """Ask serve_forever() to exit and wait until it has stopped."""
        logger.warning("server shutdown request")
        if self.socket_server is None:
            # Never started (or failed to start): nothing to shut down.
            return
        self.socket_server.shutdown()
示例#11
0
class Node(object):
    """A node whose lifecycle state other threads can wait on."""

    STATE_INIT = 0
    STATE_STARTING = 1
    STATE_RUNNING = 2
    STATE_STOPPED = 3
    STATE_PARTITIONED = 4
    STATE_FAILED = 5

    # Human-readable names; must cover every STATE_* constant above.
    state_str = {STATE_INIT: "Initial",
                 STATE_STARTING: "Starting",
                 STATE_RUNNING: "Running",
                 STATE_STOPPED: "Stopped",
                 STATE_PARTITIONED: "Partitioned",
                 STATE_FAILED: "Failed"}

    def __init__(self, node_id):
        self.node_id = node_id
        self.state = Node.STATE_INIT

        self.cv_lock = Lock()
        self.cv = Condition(self.cv_lock)

    def wait_for_state(self, state):
        """Block until the node is in ``state`` or in STATE_FAILED.

        Raises:
            ChistributedException: if the node entered STATE_FAILED while
                waiting for a different state.
        """
        with self.cv:
            while self.state not in (state, Node.STATE_FAILED):
                self.cv.wait()
            # Capture under the lock: once released, another thread may move
            # the node on and cause a spurious failure report.
            reached = self.state

        if reached != state:
            raise ChistributedException("Node entered failed state while waiting for state %s" % Node.state_str.get(state, str(state)))

    def set_state(self, state):
        """Change the node's state and wake every thread in wait_for_state()."""
        with self.cv:
            self.state = state
            self.cv.notify_all()
class Consumer(Thread):
    """Worker thread that prints a running count and average of sent numbers."""

    def __init__(self):
        """Create the thread; call start() to begin consuming."""

        super().__init__()
        self.condition = Condition()
        # Most recently sent value.
        self.received_data = None
        # Values sent but not yet processed by run(), oldest first.
        self._pending = []

    def run(self):
        """Thread body: process each sent value and print running statistics."""

        running_sum = 0
        count = 0

        while True:
            with self.condition:
                # Wait until data is available.  Looping on the predicate
                # means a value sent before we started waiting is not lost.
                while not self._pending:
                    self.condition.wait()
                # Take the data.
                data = self._pending.pop(0)

            # Process the data outside the lock so send() is never blocked.
            running_sum += data
            count += 1
            print('Got data: {}\nTotal count: {}\nAverage: {}\n'.format(
                data, count, running_sum / count))

    def send(self, data):
        """Hand ``data`` to the consumer thread and wake it."""
        with self.condition:
            self.received_data = data
            self._pending.append(data)
            # Without this notify run() would wait forever.
            self.condition.notify()
示例#13
0
class SendBuffer(object):
    """Bounded two-stage buffer that feeds a socket.

    Producers call :meth:`write`, which blocks while the back buffer would
    grow past ``max_size``.  The sender calls :meth:`to_sock`, which moves
    the back buffer into the front buffer once the front is drained, then
    sends as much of the front buffer as the socket accepts.
    """

    def __init__(self, max_size):
        self.size = 0
        self.frontbuffer = ''
        self.full = Condition()
        self.backbuffer = StringIO()
        self.max_size = max_size

    def __len__(self):
        """Amount of unsent data: back-buffer size plus remaining front data."""
        return self.size + len(self.frontbuffer)

    def write(self, data):
        """Append ``data``, waiting while it would overflow the back buffer.

        A write at least ``max_size`` long is accepted without waiting, since
        no amount of draining could make it fit.
        """
        incoming = len(data)
        with self.full:
            while incoming < self.max_size and self.size + incoming > self.max_size:
                self.full.wait()
            self.backbuffer.write(data)
            self.size += incoming

    def to_sock(self, sock):
        """Send pending data with a single ``sock.send`` call."""
        if not self.frontbuffer:
            with self.full:
                drained = self.backbuffer
                self.backbuffer = StringIO()
                self.size = 0
                # The back buffer is empty again: wake blocked writers.
                self.full.notify_all()
            self.frontbuffer = drained.getvalue()

        sent = sock.send(self.frontbuffer)
        self.frontbuffer = self.frontbuffer[sent:]
示例#14
0
文件: atlas.py 项目: adamsutton/PyEPG
class DataQueue ( Queue ):
  """Queue that expects a fixed total number of put() calls.

  ``count`` is how many items are still expected.  Once it reaches zero and
  the queue is empty, get() raises Empty instead of blocking forever.
  """

  def __init__ ( self, count ):
    Queue.__init__(self)
    self._count = count
    self._cond  = Condition()

  def get ( self, block = True, timeout = None ):
    """Return the next item; ``block``/``timeout`` are passed to Queue.get.

    Raises:
      Empty: if the queue is empty and no more puts are expected.
    """
    # ``with`` releases the condition on every path, including the first
    # ``raise Empty()``, which previously left the lock held.
    with self._cond:
      if self.empty():
        if not self._count:
          raise Empty()
        elif block:
          self._cond.wait(timeout)
          if self.empty() and not self._count:
            raise Empty()
    # NOTE(review): if another consumer takes the item first, this can still
    # block (for ``timeout`` seconds or forever); pre-existing behaviour.
    return Queue.get(self, block, timeout)

  def put ( self, data ):
    """Enqueue ``data``, decrement the expected count and wake consumers."""
    with self._cond:
      self._count = self._count - 1
      Queue.put(self, data)
      if self._count:
        self._cond.notify()
      else:
        # Last expected item: wake everyone so idle consumers can see the end.
        self._cond.notifyAll()

  def remain ( self ):
    """Return expected-but-not-yet-put items plus unfinished queued tasks."""
    return self._count + self.unfinished_tasks
示例#15
0
class MessageQueue(object):
    """The message queue used internally by Gossip.

    Messages enter through :meth:`appendleft` and leave through :meth:`pop`
    at the opposite end, so delivery is FIFO; :meth:`pop` blocks while the
    queue is empty.
    """

    def __init__(self):
        self._queue = deque()
        self._condition = Condition()

    def pop(self):
        """Remove and return the oldest message, waiting until one exists."""
        with self._condition:
            while not self._queue:
                self._condition.wait()
            return self._queue.pop()

    def __len__(self):
        """Number of queued messages."""
        return len(self._queue)

    def __deepcopy__(self, memo):
        """Deep-copy the messages into a new queue that has its own lock."""
        clone = MessageQueue()
        clone._queue = copy.deepcopy(self._queue, memo)
        return clone

    def appendleft(self, msg):
        """Add ``msg`` and wake one thread waiting in :meth:`pop`."""
        with self._condition:
            self._queue.appendleft(msg)
            self._condition.notify()
示例#16
0
class PostBox(list):
    """A list that accepts items until closed, plus a one-shot result slot.

    Items are appended until :meth:`close`; :meth:`wait` blocks until
    :meth:`wakeup` delivers a result.
    """

    def __init__(self, *a):
        list.__init__(self, *a)
        self.lever = Condition()
        self.open = True
        self.done = False

    def wait(self):
        """Block until :meth:`wakeup` has been called, then return its data."""
        with self.lever:
            while not self.done:
                self.lever.wait()
        return self.result

    def wakeup(self, data):
        """Store ``data`` as the result and wake every thread in :meth:`wait`."""
        self.result = data
        with self.lever:
            self.done = True
            self.lever.notifyAll()

    def append(self, e):
        """Append ``e`` to the box.

        Raises:
            BoxClosedErr: if :meth:`close` has already been called.
        """
        # ``with`` releases the lock when raising; the manual acquire/release
        # pair left it held, blocking every later caller.
        with self.lever:
            if not self.open:
                raise BoxClosedErr
            list.append(self, e)

    def close(self):
        """Refuse further appends."""
        with self.lever:
            self.open = False
示例#17
0
class Future(object):
    """Deferred call of ``func(*args, **kwargs)`` whose result can be awaited.

    Calling the instance runs the function in the caller's thread.  Any
    number of threads may block in :meth:`result` until it has run.
    """

    def __init__(self, func, *args, **kwargs):
        self.__done = False
        self.__result = None
        self.__cond = Condition()
        self.__func = func
        self.__args = args
        self.__kwargs = kwargs
        self.__except = None

    def __call__(self):
        """Run the wrapped function and publish its result or exception."""
        with self.__cond:
            try:
                self.__result = self.__func(*self.__args, **self.__kwargs)
            except:  # noqa: E722 - capture everything so result() can re-raise it
                self.__result = None
                self.__except = sys.exc_info()
            self.__done = True
            # notify() would wake only one of several result() callers and
            # leave the rest blocked forever.
            self.__cond.notify_all()

    def result(self):
        """Wait for the call to complete and return a deep copy of its result.

        Re-raises the exception raised by the function, if any.
        """
        with self.__cond:
            while not self.__done:
                self.__cond.wait()
        if self.__except:
            exc = self.__except
            reraise(exc[0], exc[1], exc[2])
        result = self.__result
        return copy.deepcopy(result)
示例#18
0
class Heartbeater(object):
    """Polls a jobstep's status until it finishes or :meth:`close` is called."""

    interval = 5

    def __init__(self, api, jobstep_id, interval=None):
        self.api = api
        self.jobstep_id = jobstep_id
        self.cv = Condition()
        self.finished = Event()

        if interval is not None:
            self.interval = interval

    def wait(self):
        """Poll ``api.get_jobstep`` every ``interval`` seconds until done.

        Returns when the jobstep reports status id ``'finished'`` or another
        thread calls :meth:`close`.  The finished flag is cleared on entry, so
        a close() made before wait() starts is not remembered.
        """
        with self.cv:
            self.finished.clear()
            while not self.finished.is_set():
                status_id = self.api.get_jobstep(self.jobstep_id)['status']['id']
                if status_id == 'finished':
                    self.finished.set()
                    return
                self.cv.wait(self.interval)

    def close(self):
        """Stop a running :meth:`wait` early."""
        with self.cv:
            self.finished.set()
            self.cv.notify_all()
示例#19
0
class ZEventLoopThread(object):
    """Runs a ZEventLoop in a dedicated thread and hands the loop to the starter."""

    def __init__(self, cb=None):
        self._event_loop = None
        self._thr = Thread(target=self._thread_func)
        self._guard = Condition(Lock())
        self._thr_callback = cb
        # True once _thread_func has finished creating the loop (or failed).
        self._startup_done = False

    def start_loop(self):
        """Start the thread and return its event loop once it is created.

        Raises:
            RuntimeError: if creating the loop or running the callback failed
                in the new thread (the error itself propagates there).
        """
        self._thr.start()
        # ``with`` releases the guard; returning while still holding it (as a
        # bare acquire() did) would block any later use of the guard.
        with self._guard:
            while not self._startup_done:
                self._guard.wait()
            if self._event_loop is None:
                raise RuntimeError("event loop thread failed to start")
            return self._event_loop

    def _thread_func(self):
        # Note: create event loop in the new thread
        # instead in the constructor
        event_loop = None
        try:
            loop = ZEventLoop()
            if self._thr_callback:
                self._thr_callback(loop)
            event_loop = loop
        finally:
            # Always wake start_loop(), even on failure, so it cannot hang.
            with self._guard:
                self._event_loop = event_loop
                self._startup_done = True
                self._guard.notify()

        event_loop.loop()
示例#20
0
class IOManager(Thread):
    """Thread that hands queued I/O requests to ``so.handlerIO``."""

    def __init__(self,so):
        Thread.__init__(self)
        self.mutexIO = Condition()
        self.listIO = []
        self.handlerIO = so.handlerIO

    def process(self):
        """Handle every queued request in order, then empty the queue."""
        for irq in self.listIO:
            print("Ejecuto IOManager")
            self.handlerIO.handle(irq)
        self.removeIOFromListIO()

    def run(self):
        """Thread body: wait for requests and process them as they arrive."""
        while True:
            with self.mutexIO:
                # Loop on the predicate so a request queued before we started
                # waiting is not missed.
                while not self.listIO:
                    self.mutexIO.wait()
                self.process()

    def removeIOFromListIO(self):
        """Remove all queued requests (the list object itself is kept)."""
        # Calling remove() while iterating the same list skipped every other
        # element; clearing the slice removes them all.
        del self.listIO[:]

    def add(self,iOReq):
        """Queue ``iOReq`` and wake the manager thread."""
        with self.mutexIO:
            self.listIO.append(iOReq)
            self.mutexIO.notify()
示例#21
0
class SharedCell(object):
    """Shared data for the producer/consumer problem.

    Holds a single value.  setData() waits until the previous value has been
    read; getData() waits until a new value has been written.
    """

    def __init__(self):
        self._data = -1
        self._writeable = True
        self._condition = Condition()

    def setData(self, data):
        """Producer's method to write to shared data."""
        # ``with`` releases the condition even if formatting ``data`` with %d
        # raises; the acquire/release pair leaked the lock in that case.
        with self._condition:
            while not self._writeable:
                self._condition.wait()
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print("%s setting data to %d" %
                  (currentThread().getName(), data))
            self._data = data
            self._writeable = False
            self._condition.notify()

    def getData(self):
        """Consumer's method to read from shared data."""
        with self._condition:
            while self._writeable:
                self._condition.wait()
            print("%s accessing data %d" %
                  (currentThread().getName(), self._data))
            self._writeable = True
            self._condition.notify()
        return self._data
示例#22
0
class CountLatch(object):
    """Synchronization primitive with the capacity to count and latch.
    Counting up latches, while reaching zero when counting down un-latches.

    .. note::
        The CountLatch class has been included in Sardana
        on a provisional basis. Backwards incompatible changes
        (up to and including removal of the module) may occur if
        deemed necessary by the core developers.
    """

    def __init__(self):
        self.count = 0
        self.condition = Condition()

    def count_up(self):
        """Count up."""
        with self.condition:
            self.count += 1

    def count_down(self, _=None):
        """Count down; wake all waiters once the counter is zero or below.

        The ignored optional argument presumably lets this method be used
        directly as a callback.
        """
        with self.condition:
            self.count -= 1
            if self.count > 0:
                return
            self.condition.notifyAll()

    def wait(self):
        """Wait until the counter reaches zero."""
        with self.condition:
            while self.count > 0:
                self.condition.wait()
示例#23
0
    def request(self, request_type, request, sequential=False):
        """Send a request and block until its response or error arrives.

        Turns the callback-based ``self.request_async`` into a blocking call.

        Returns:
            ``(response_type, response)`` as passed to the success callback.

        Raises:
            The exception passed to the error callback.
        """
        cond = Condition()
        # dict is used as reference type to allow result to be returned from
        # separate thread
        response_cont = {}

        def on_success(response_type, response):
            with cond:
                response_cont.update({
                    "success": True,
                    "type": response_type,
                    "response": response,
                })
                cond.notify()

        def on_error(exception):
            with cond:
                response_cont.update({
                    "success": False,
                    "exception": exception,
                })
                cond.notify()

        with cond:
            self.request_async(request_type, request,
                               on_success, on_error,
                               sequential)
            # Wait on the predicate, not a bare wait(): if request_async runs
            # a callback synchronously in this thread (Condition() is
            # re-entrant), its notify happens before we wait and a bare wait()
            # would block forever.
            while not response_cont:
                cond.wait()

        if not response_cont["success"]:
            raise response_cont["exception"]

        return response_cont["type"], response_cont["response"]
示例#24
0
class Queue(object):
    """Batching queue: put() appends values, get() takes everything queued."""

    def __init__(self):
        self.lock = Lock()
        self.cond = Condition(self.lock)
        self.data = []
        self.live = True

    def put(self, values):
        """Append the items of ``values`` and wake one waiting getter."""
        with self.cond:
            self.data.extend(values)
            self.cond.notify()

    def get(self):
        """Remove and return all queued items as a list.

        Returns [] immediately if the lock is currently held by another
        thread.  Otherwise waits until data is queued or the queue is closed;
        after close() the remaining (possibly empty) batch is returned.
        """
        if not self.cond.acquire(False):
            return []
        try:
            # Wait only when there is nothing to hand out.  An unconditional
            # wait() held back data queued before get() until the next put(),
            # and blocked forever once the queue was closed.
            while not self.data and self.live:
                self.cond.wait()
            result = self.data
            self.data = []
            return result
        finally:
            self.cond.release()

    def close(self):
        """Mark the queue closed and wake every waiting getter."""
        with self.cond:
            self.live = False
            self.cond.notify_all()
示例#25
0
class AsyncChannel:
    """Bounded FIFO channel; get() blocks while empty, put() while full."""
    # yes, i believe this is just Queue ... i was new to python and couldn't find it

    def __init__(self, buffer_size=1):
        self.buffer = []
        self.max_items = buffer_size
        self.lock = Lock()
        # Both conditions share one lock, so a single acquisition protects
        # the buffer for either direction.
        self.not_empty = Condition(self.lock)
        self.not_full = Condition(self.lock)

    def get(self):
        """Remove and return the oldest value, waiting while the channel is empty."""
        with self.lock:
            while not self.buffer:
                self.not_empty.wait()
            head = self.buffer.pop(0)
            self.not_full.notify_all()
        return head

    def put(self, val):
        """Append ``val``, waiting while the channel holds ``max_items`` values."""
        with self.lock:
            while self.max_items == len(self.buffer):
                self.not_full.wait()
            self.buffer.append(val)
            self.not_empty.notify_all()
示例#26
0
class MockSock:
    """In-memory socket stand-in whose send/recv behaviour is scripted.

    queue_send() supplies how many bytes each send() may accept, and
    queue_recv() supplies the data recv() returns.  send() and recv() block
    until such an entry is available.
    """

    def __init__(self):
        self.sent = six.binary_type()
        self.send_q = deque()
        self.recv_q = deque()
        self.closed = False
        self.lock = Lock()
        self.cond = Condition(self.lock)
        self.timeout = 0.5

    def gettimeout(self):
        """Return the fixed timeout value."""
        return self.timeout

    def close(self):
        """Mark the socket closed (blocked callers are not woken)."""
        with self.lock:
            self.closed = True

    def queue_send(self, num_bytes):
        """Allow a future send() to accept up to ``num_bytes`` bytes."""
        with self.cond:
            self.send_q.append(num_bytes)
            self.cond.notify_all()

    def send(self, message):
        """Accept part of ``message`` according to the next queued budget."""
        with self.cond:
            while not self.send_q:
                self.cond.wait()
            budget = self.send_q.popleft()
            accepted = min(budget, len(message))
            if budget < len(message):
                self.send_q.appendleft(len(message) - budget)
            self.sent += message[:accepted]
            return accepted

    def queue_recv(self, message):
        """Make ``message`` available to future recv() calls."""
        with self.cond:
            self.recv_q.append(message)
            self.cond.notify_all()

    def recv(self, max_len):
        """Return up to ``max_len`` bytes of the next queued message."""
        with self.cond:
            while not self.recv_q:
                self.cond.wait()
            chunk = self.recv_q.popleft()
            take = min(max_len, len(chunk))
            if take < len(chunk):
                # Keep the unread tail at the front for the next recv().
                self.recv_q.appendleft(chunk[take:])
            return chunk[:take]
示例#27
0
文件: workers.py 项目: gcovr/gcovr
class LockedDirectories(object):
    """
    Class that keeps a list of locked directories
    """
    def __init__(self):
        self.dirs = set()
        self.cv = Condition()

    def run_in(self, dir_):
        """
        Start running in the directory and lock it

        Blocks while ``dir_`` is locked by another caller.
        """
        with self.cv:
            while dir_ in self.dirs:
                self.cv.wait()
            self.dirs.add(dir_)

    def done(self, dir_):
        """
        Finished with the directory, unlock it

        Raises KeyError if ``dir_`` is not locked; the condition's lock is
        released in that case too, so other threads are not blocked.
        """
        with self.cv:
            self.dirs.remove(dir_)
            self.cv.notify_all()
示例#28
0
class RepceJob(object):

    """Status of a sent message; lets the sender wait for the reply."""

    def __init__(self, cbk):
        """
        - .rid: (process-wise) unique id
        - .cbk: what we do upon receiving reply
        """
        self.rid = (os.getpid(), thread.get_ident(), time.time())
        self.cbk = cbk
        self.lever = Condition()
        self.done = False

    def __repr__(self):
        return ':'.join(map(str, self.rid))

    def wait(self):
        """Block until wakeup() has delivered the reply, then return it."""
        with self.lever:
            if not self.done:
                self.lever.wait()
        return self.result

    def wakeup(self, data):
        """Record the reply ``data`` and wake one waiting thread."""
        self.result = data
        with self.lever:
            self.done = True
            self.lever.notify()
示例#29
0
class Broadcaster(object):
    """Append-only message log that receivers poll by message id."""

    def __init__(self):
        self._condition = Condition()
        self._last_id = 0
        self._queue = []

    @acquire_then_notify
    def send(self, item):
        """Record ``item`` under the next id.

        Locking and notification are presumably done by the
        ``acquire_then_notify`` decorator (confirm).
        """
        # Single-argument print(...) gives the same output on Python 2 and 3.
        print("Sending message %s" % (item,))
        self._last_id += 1
        self._queue.append((self._last_id, item))

    def _find_items(self, since_id=None):
        """Return a new list of ``(id, item)`` pairs newer than ``since_id`` (all if falsy)."""
        if not since_id:
            # Copy so callers never alias the internal log.
            return list(self._queue)
        return [(msg_id, item) for (msg_id, item) in self._queue if msg_id > since_id]

    @acquire
    def recv(self, since_id=None, timeout=10):
        """Return messages newer than ``since_id``, waiting up to ``timeout`` seconds.

        Returns an empty list if nothing arrives in time (including when
        ``timeout`` <= 0).
        """
        end_time = time() + timeout
        found = self._find_items(since_id)
        while not found:
            remaining = end_time - time()
            if remaining <= 0:
                break
            print("Waiting")
            # Wait only for the time left; waiting the full ``timeout`` after
            # every wakeup could overshoot the deadline.
            self._condition.wait(remaining)
            found = self._find_items(since_id)
        print("found %d" % len(found))
        return found
class ResourcePool(object):
    """Counting pool of a divisible resource; acquire() waits until enough is free."""

    def __init__(self, initial_value):
        super(ResourcePool, self).__init__()
        self.condition = Condition()
        self.value = initial_value

    def acquire(self, amount):
        """Take ``amount`` units, blocking until that many are available."""
        with self.condition:
            while self.value < amount:
                self.condition.wait()
            self.value -= amount
            self.__validate()

    def release(self, amount):
        """Return ``amount`` units and wake every waiting acquirer."""
        with self.condition:
            self.value += amount
            self.__validate()
            self.condition.notifyAll()

    def __validate(self):
        # Sanity check only (stripped under -O): the pool may never go negative.
        assert self.value >= 0

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return "ResourcePool(%i)" % (self.value,)

    @contextmanager
    def acquisitionOf(self, amount):
        """Context manager holding ``amount`` units for the body's duration."""
        self.acquire(amount)
        try:
            yield
        finally:
            self.release(amount)
示例#31
0
class Runner:
    """ROS node that loads performance timelines from disk and plays them.

    Performances are read from ``robots_config/<robot>/performances`` (or
    ``robots_config/common/performances`` for ids starting with ``shared``).
    A daemon worker thread plays the loaded performance while ROS services
    (``~run``, ``~pause``, ``~resume``, ``~stop``, ...) control it.

    Locking: ``self.lock`` guards the run state (``running``, ``paused``,
    timestamps, ``running_performance``). ``self.run_condition`` is held by
    the worker except while it waits for the next ``run()``; wherever both
    are held, ``run_condition`` is taken before ``self.lock``.
    """

    # Seconds the worker sleeps between checks while a performance is
    # paused, instead of spinning on ``self.lock``.
    PAUSED_POLL_INTERVAL = 0.01

    def __init__(self):
        """Set up state, ROS publishers, services and subscribers.

        Starts the worker thread and then blocks in ``rospy.spin()``, so the
        constructor only returns when the node shuts down.
        """
        logger.info('Starting performances node')

        self.robot_name = rospy.get_param('/robot_name')
        self.running = False
        self.paused = False
        self.autopause = False
        self.pause_time = 0
        self.start_time = 0
        self.start_timestamp = 0
        self.lock = Lock()
        self.run_condition = Condition()
        self.running_performance = None
        self.unload_finished = False
        # in memory set of properties with priority over params
        self.variables = {}
        # References to event subscribing node callbacks
        self.observers = {}
        # Performances that already played as alternatives. Used to maximize different performance in single demo
        self.performances_played = {}
        # NOTE: this rebinds ``self.worker`` from the method to the Thread
        # object; the bound method is captured as the target first.
        self.worker = Thread(target=self.worker)
        self.worker.daemon = True
        rospy.init_node('performances')
        self.services = {
            'head_pau_mux': rospy.ServiceProxy('/' + self.robot_name + '/head_pau_mux/select', MuxSelect),
            'neck_pau_mux': rospy.ServiceProxy('/' + self.robot_name + '/neck_pau_mux/select', MuxSelect)
        }
        self.topics = {
            'running_performance': rospy.Publisher('~running_performance', String, queue_size=1),
            'look_at': rospy.Publisher('/blender_api/set_face_target', Target, queue_size=1),
            'gaze_at': rospy.Publisher('/blender_api/set_gaze_target', Target, queue_size=1),
            'head_rotation': rospy.Publisher('/blender_api/set_head_rotation', Float32, queue_size=1),
            'emotion': rospy.Publisher('/blender_api/set_emotion_state', EmotionState, queue_size=3),
            'gesture': rospy.Publisher('/blender_api/set_gesture', SetGesture, queue_size=3),
            'expression': rospy.Publisher('/' + self.robot_name + '/make_face_expr', MakeFaceExpr, queue_size=3),
            'kfanimation': rospy.Publisher('/' + self.robot_name + '/play_animation', PlayAnimation, queue_size=3),
            'interaction': rospy.Publisher('/behavior_switch', String, queue_size=1),
            'bt_control': rospy.Publisher('/behavior_control', Int32, queue_size=1),
            'events': rospy.Publisher('~events', Event, queue_size=1),
            'chatbot': rospy.Publisher('/' + self.robot_name + '/speech', ChatMessage, queue_size=1),
            'speech_events': rospy.Publisher('/' + self.robot_name + '/speech_events', String, queue_size=1),
            'soma_state': rospy.Publisher("/blender_api/set_soma_state", SomaState, queue_size=2),
            'tts': {
                'en': rospy.Publisher('/' + self.robot_name + '/tts_en', String, queue_size=1),
                'zh': rospy.Publisher('/' + self.robot_name + '/tts_zh', String, queue_size=1),
                'default': rospy.Publisher('/' + self.robot_name + '/tts', String, queue_size=1),
            },
            'tts_control': rospy.Publisher('/' + self.robot_name + '/tts_control', String, queue_size=1)
        }
        self.load_properties()
        rospy.Service('~reload_properties', Trigger, self.reload_properties_callback)
        rospy.Service('~set_properties', srv.SetProperties, self.set_properties_callback)
        rospy.Service('~load', srv.Load, self.load_callback)
        rospy.Service('~load_performance', srv.LoadPerformance, self.load_performance_callback)
        rospy.Service('~unload', Trigger, self.unload_callback)
        rospy.Service('~run', srv.Run, self.run_callback)
        rospy.Service('~run_by_name', srv.RunByName, self.run_by_name_callback)
        rospy.Service('~run_full_performance', srv.RunByName, self.run_full_performance_callback)
        rospy.Service('~resume', srv.Resume, self.resume_callback)
        rospy.Service('~pause', srv.Pause, self.pause_callback)
        rospy.Service('~stop', srv.Stop, self.stop)
        rospy.Service('~current', srv.Current, self.current_callback)
        # Shared subscribers for nodes
        rospy.Subscriber('/' + self.robot_name + '/speech_events', String,
                         lambda msg: self.notify('speech_events', msg))
        rospy.Subscriber('/' + self.robot_name + '/speech', ChatMessage, self.speech_callback)
        # Shared subscribers for nodes
        rospy.Subscriber('/hand_events', String, self.hand_callback)
        Server(PerformancesConfig, self.reconfig)
        rospy.Subscriber('/face_training_event', String, self.training_callback)
        self.worker.start()
        rospy.spin()

    def reconfig(self, config, level):
        """dynamic_reconfigure callback: update ``autopause``."""
        with self.lock:
            self.autopause = config.autopause

        return config

    def reload_properties_callback(self, request):
        self.load_properties()
        return TriggerResponse(success=True)

    def unload_callback(self, request):
        self.unload()
        return TriggerResponse(success=True)

    def unload(self):
        """Stop playback and drop the loaded performance (publishing ``null``)."""
        self.stop()
        with self.lock:
            if self.running_performance:
                self.running_performance = None
                self.topics['running_performance'].publish(String(json.dumps(None)))

    def set_properties_callback(self, request):
        self.set_variable(request.id, json.loads(request.properties))
        return srv.SetPropertiesResponse(success=True)

    def load_callback(self, request):
        return srv.LoadResponse(success=True, performance=json.dumps(self.load(request.id)))

    def load_performance_callback(self, request):
        self.load_performance(json.loads(request.performance))
        return srv.LoadPerformanceResponse(True)

    def run_by_name_callback(self, request):
        self.stop()
        if not self.load(request.id):
            return srv.RunByNameResponse(False)
        return srv.RunByNameResponse(self.run(0.0))

    def run_full_performance_callback(self, request):
        """Load a folder (or single performance) and play it, unloading when done."""
        self.stop()
        performances = self.load_folder(request.id) or self.load(request.id)
        if not performances:
            return srv.RunByNameResponse(False)
        return srv.RunByNameResponse(self.run(0.0, unload_finished=True))

    def load_folder(self, id):
        """Load the performance folder ``id``.

        If the folder contains ``*.yaml`` timelines it is loaded as a single
        performance via :meth:`load`. Otherwise each sub-directory is an
        alternative sub-performance: one not played yet in the current round
        is picked at random and loaded recursively.

        :return: the loaded performance, or ``[]`` if nothing could be loaded.
        """
        if id.startswith('shared'):
            robot_name = 'common'
        else:
            robot_name = rospy.get_param('/robot_name')
        dir_path = os.path.join(rospack.get_path('robots_config'), robot_name, 'performances', id)
        if not os.path.isdir(dir_path):
            return []
        _, dirs, files = next(os.walk(dir_path))

        if fnmatch.filter(files, "*.yaml"):
            # make names in folder/file format
            return self.load(id)

        # No timelines here: sub-directories are counted as sub-performances.
        if not dirs:
            return []
        played = self.performances_played.setdefault(id, [])
        candidates = [d for d in dirs if d not in played]
        if not candidates:
            # Every alternative was played: start a new round, avoiding the
            # last one played unless it is the only choice (random.choice
            # must never receive an empty list).
            last = played[-1] if played else None
            candidates = [d for d in dirs if d != last] or dirs
            del played[:]
        p = random.choice(candidates)
        played.append(p)
        return self.load_folder(os.path.join(id, p))

    def load(self, id):
        """Load performance ``id`` (a folder of timelines or one timeline file).

        On success the performance is also made the running performance via
        :meth:`load_performance`.

        :return: the performance dict, or None if nothing was found.
        """
        robot_name = 'common' if id.startswith('shared') else rospy.get_param('/robot_name')
        p = os.path.join(rospack.get_path('robots_config'), robot_name, 'performances', id)

        if os.path.isdir(p):
            _, _, files = next(os.walk(p))
            files = natsorted(fnmatch.filter(files, "*.yaml"), key=lambda f: f.lower())
            # Strip the ".yaml" suffix to get timeline ids.
            ids = ["{}/{}".format(id, f[:-5]) for f in files]
            timelines = [self.get_timeline(i) for i in ids]
            timelines = [t for t in timelines if t]
            performance = {'id': id, 'name': os.path.basename(id), 'path': os.path.dirname(id), 'timelines': timelines,
                           'nodes': self.get_merged_timeline_nodes(timelines)}
        else:
            performance = self.get_timeline(id)

        if performance:
            self.load_performance(performance)
            return performance
        else:
            return None

    def get_timeline(self, id):
        """Read and validate the timeline file ``<id>.yaml``.

        :return: the timeline dict (with ``id``, ``name``, ``path`` filled in),
            or None if the file is missing or does not contain a mapping.
        """
        robot_name = 'common' if id.startswith('shared') else rospy.get_param('/robot_name')
        p = os.path.join(rospack.get_path('robots_config'), robot_name, 'performances', id) + '.yaml'

        if not os.path.isfile(p):
            return None
        with open(p, 'r') as f:
            # safe_load: timelines are plain data, and yaml.load() without an
            # explicit Loader is deprecated (PyYAML 5.1) and an error in 6.0.
            timeline = yaml.safe_load(f)
        if not isinstance(timeline, dict):
            # Empty or malformed file.
            return None
        timeline['id'] = id
        timeline['name'] = os.path.basename(id)
        timeline['path'] = os.path.dirname(id)
        self.validate_timeline(timeline)
        return timeline

    def get_timeline_duration(self, timeline):
        """Return the end time of the last-ending node (0 if there are no nodes)."""
        duration = 0

        if 'nodes' in timeline and isinstance(timeline['nodes'], list):
            for node in timeline['nodes']:
                duration = max(duration, node.get('duration', 0) + node.get('start_time', 0))

        return duration

    def get_merged_timeline_nodes(self, timelines):
        """Concatenate the nodes of ``timelines`` into one sequence.

        Each timeline's nodes are deep-copied and shifted by the total
        duration of the timelines before it; the inputs are not modified.
        Offsets use :meth:`get_timeline_duration`, the same measure the
        worker uses when playing timelines one after another.
        """
        merged = []
        offset = 0

        for timeline in timelines:
            nodes = copy.deepcopy(timeline.get('nodes', []))
            for node in nodes:
                node['start_time'] = node.get('start_time', 0) + offset
            merged += nodes
            offset += self.get_timeline_duration(timeline)

        return merged

    def validate_performance(self, performance):
        """Normalize a performance and each of its sub-timelines in place."""
        self.validate_timeline(performance)
        if 'timelines' in performance:
            for timeline in performance['timelines']:
                self.validate_timeline(timeline)
        return performance

    def validate_timeline(self, timeline):
        """Normalize ``timeline`` in place and return it.

        Ensures ``nodes`` is a list and every node has ``start_time`` and
        ``duration``; ``pause`` nodes always get a duration of 0.1.
        """
        if 'nodes' not in timeline or not isinstance(timeline['nodes'], list):
            timeline['nodes'] = []

        for node in timeline['nodes']:
            if 'start_time' not in node:
                node['start_time'] = 0
            # .get(): tolerate nodes without a name instead of raising KeyError.
            if node.get('name') == 'pause':
                node['duration'] = 0.1
            if 'duration' not in node:
                node['duration'] = 0

        return timeline

    def load_performance(self, performance):
        """Validate ``performance``, make it current and publish it as JSON."""
        with self.lock:
            self.validate_performance(performance)
            self.running_performance = performance
            self.topics['running_performance'].publish(String(json.dumps(performance)))

    def run_callback(self, request):
        return srv.RunResponse(self.run(request.startTime))

    def run(self, start_time, unload_finished=False):
        """Start playing the loaded performance from ``start_time`` seconds.

        :return: True if a non-empty performance was loaded and started.
        """
        self.stop()
        # The worker holds run_condition except while waiting in
        # run_condition.wait(), so taking it here waits for the previous
        # performance to wind down. ``with`` guarantees release on error.
        with self.run_condition:
            with self.lock:
                success = bool(self.running_performance)
                if success:
                    self.unload_finished = unload_finished
                    self.running = True
                    self.start_time = start_time
                    self.start_timestamp = time.time()
                    # notify worker thread
                    self.run_condition.notify()
                return success

    def resume_callback(self, request):
        success = self.resume()
        with self.lock:
            run_time = self.get_run_time()

        return srv.ResumeResponse(success, run_time)

    def resume(self):
        """Resume a paused performance; return True if it was paused."""
        success = False
        with self.lock:
            if self.running and self.paused:
                run_time = self.get_run_time()
                self.paused = False
                # Rebase the clock so get_run_time() continues from run_time.
                self.start_timestamp = time.time() - run_time
                self.start_time = 0
                self.topics['events'].publish(Event('resume', run_time))
                success = True

        return success

    def stop(self, request=None):
        """Stop playback (also the ``~stop`` service handler).

        :return: ``srv.StopResponse`` carrying the run time at which it stopped.
        """
        stop_time = 0

        with self.lock:
            if self.running:
                stop_time = self.get_run_time()
                self.running = False
                self.paused = False
                self.topics['tts_control'].publish('shutup')

        return srv.StopResponse(True, stop_time)

    def pause_callback(self, request):
        if self.pause():
            with self.lock:
                return srv.PauseResponse(True, self.get_run_time())
        else:
            return srv.PauseResponse(False, 0)

    # Pauses current
    def pause(self):
        """Pause a running performance; return True if it was running."""
        with self.lock:
            if self.running and not self.paused:
                self.pause_time = time.time()
                self.paused = True
                self.topics['events'].publish(Event('paused', self.get_run_time()))
                return True
            else:
                return False

    # Returns current performance
    def current_callback(self, request):
        with self.lock:
            current_time = self.get_run_time()
            running = self.running and not self.paused
            return srv.CurrentResponse(performance=json.dumps(self.running_performance),
                                       current_time=current_time,
                                       running=running)

    def worker(self):
        """Worker thread loop: wait for run(), then play each timeline in turn.

        Holds ``run_condition`` for its whole life except while waiting in
        ``run_condition.wait()``.
        """
        self.run_condition.acquire()
        while True:
            with self.lock:
                self.paused = False
                self.running = False

            self.topics['events'].publish(Event('idle', 0))
            self.run_condition.wait()
            self.topics['events'].publish(Event('running', self.start_time))

            with self.lock:
                if not self.running_performance:
                    continue

            behavior = True
            offset = 0
            timelines = self.running_performance['timelines'] if 'timelines' in self.running_performance else [
                self.running_performance]

            for i, timeline in enumerate(timelines):
                # check if performance is finished without starting
                running = True
                nodes = [Node.createNode(node, self, self.start_time - offset, timeline.get('id', '')) for node in
                         timeline['nodes']]
                pid = timeline.get('id', '')
                finished = None
                pause = pid and self.get_property(os.path.dirname(pid), 'pause_behavior')
                # Pause must be either enabled or not set (by default all performances are
                # pausing behavior if its not set)

                if (pause or pause is None) and behavior:
                    # Only pause behavior if its already running. Otherwise Pause behavior have no effect
                    behavior_enabled = False

                    try:
                        behavior_enabled = rospy.get_param("/behavior_enabled")
                    except KeyError:
                        pass

                    if behavior_enabled:
                        self.topics['interaction'].publish('btree_off')
                        behavior = False

                with self.lock:
                    if not self.running:
                        break

                while running:
                    with self.lock:
                        run_time = self.get_run_time()

                        if not self.running:
                            self.topics['events'].publish(Event('finished', run_time))
                            break

                        paused = self.paused

                    if paused:
                        # Sleep outside the lock so pause/resume/stop can take
                        # it, rather than busy-spinning until resumed.
                        time.sleep(self.PAUSED_POLL_INTERVAL)
                        continue

                    running = False
                    # checks if any nodes still running
                    for node in nodes:
                        running = node.run(run_time - offset) or running
                    if finished is None:
                        # true if all performance nodes are already finished
                        finished = not running

                offset += self.get_timeline_duration(timeline)

                with self.lock:
                    autopause = self.autopause and finished is False and i < len(timelines) - 1

                if autopause:
                    self.pause()

            if not behavior:
                self.topics['interaction'].publish('btree_on')

            if self.unload_finished:
                self.unload_finished = False
                self.unload()

    def get_run_time(self):
        """
        Must acquire self.lock in order to safely use this method
        :return: seconds of performance time elapsed (0 when not running);
            frozen at the pause instant while paused.
        """
        run_time = 0

        if self.running:
            run_time = self.start_time
            if self.paused:
                run_time += self.pause_time - self.start_timestamp
            else:
                run_time += time.time() - self.start_timestamp

        return run_time

    # Notifies register nodes on the events from ROS.
    def notify(self, event, msg):
        """Call every callback registered for ``event`` with ``msg``.

        Callbacks raising TypeError are removed. This assumes WeakMethod
        raises TypeError once its target is gone - NOTE(review): it also
        drops live callbacks that raise TypeError themselves.
        """
        observers = self.observers.get(event)
        if not observers:
            return
        # Iterate backwards so deleting index i does not shift pending ones.
        for i in reversed(range(len(observers))):
            try:
                observers[i](msg)
            except TypeError:
                # Remove dead methods
                del observers[i]

    # Registers callbacks for specific events. Uses weak reference to allow nodes cleanup after finish.
    def register(self, event, cb):
        m = WeakMethod(cb)
        self.observers.setdefault(event, []).append(m)
        return m

    # Allows nodes to unsubscribe from events
    def unregister(self, event, ref):
        if event in self.observers:
            if ref in self.observers[event]:
                self.observers[event].remove(ref)

    def hand_callback(self, msg):
        self.notify('HAND', msg)
        self.notify(msg.data, msg)

    def load_properties(self):
        """Publish every ``.properties`` YAML file under the common and robot
        performance trees as ROS params
        ``/<robot>/webui/performances/<dir>/properties``."""
        robot_name = rospy.get_param('/robot_name')
        robot_path = os.path.join(rospack.get_path('robots_config'), robot_name, 'performances')
        common_path = os.path.join(rospack.get_path('robots_config'), 'common', 'performances')
        for path in [common_path, robot_path]:
            for root, dirnames, filenames in os.walk(path):
                if '.properties' in filenames:
                    filename = os.path.join(root, '.properties')
                    if os.path.isfile(filename):
                        with open(filename) as f:
                            # safe_load: plain data; yaml.load() without a
                            # Loader is deprecated/an error in PyYAML 6.
                            properties = yaml.safe_load(f)
                        rel_dir = os.path.relpath(root, path)
                        rospy.set_param('/' + os.path.join(self.robot_name, 'webui/performances', rel_dir).strip(
                            "/.") + '/properties', properties)

    def get_property(self, path, name):
        param_name = os.path.join('/', self.robot_name, 'webui/performances', path, 'properties', name)
        return rospy.get_param(param_name, None)

    def set_variable(self, id, properties):
        """Store in-memory variables for performance ``id`` (override params)."""
        scoped = self.variables.setdefault(id, {})
        for key, val in properties.items():
            # Trace output: debug level, not error level.
            rospy.logdebug("id {} key {} val {}".format(id, key, val))
            scoped[key] = val

    def get_variable(self, id, name):
        """Resolve variable ``name`` for the performance folder containing ``id``.

        Order: a truthy in-memory value from :meth:`set_variable`, then the
        param ``/<robot>/webui/performances/<dir>/properties/variables/<name>``.
        If that value is accepted by :meth:`is_param`, it is dereferenced
        (as given, then prefixed with ``/<robot>``) and returned as a string,
        or None if neither param exists.
        """
        scope = os.path.dirname(id)
        if scope in self.variables and name in self.variables[scope] \
                and self.variables[scope][name]:
            return self.variables[scope][name]
        val = None
        param_name = os.path.join('/', self.robot_name, 'webui/performances', scope,
                                  'properties/variables', name)
        if rospy.has_param(param_name):
            val = rospy.get_param(param_name)
            if self.is_param(val):
                if rospy.has_param(val):
                    return str(rospy.get_param(val))
                if rospy.has_param("/{}{}".format(self.robot_name, val)):
                    return str(rospy.get_param("/{}{}".format(self.robot_name, val)))

                return None
        return val

    def speech_callback(self, msg):
        self.notify('SPEECH', msg.utterance)

    @staticmethod
    def is_param(param):
        """ Checks if value is valid param.
        Has to start with slash
        """
        validator = rospy.names.global_name("param_name")
        try:
            validator(param, False)
            return True
        except rospy.names.ParameterInvalid:
            return False

    def training_callback(self, msg):
        self.notify('FACE_TRAINING', msg.data)
示例#32
0
class Sampler(object):
    """ Sampler used to play, stop and mix multiple sounds.

        .. warning:: A single sampler instance should be used at a time.

    """
    def __init__(self, sr=22050, backend='sounddevice', timeout=1):
        """
        :param int sr: samplerate used for playback. Every sound given to :meth:`play` must already use this samplerate (see ``Sound.resample``); other sounds are rejected with a ``ValueError``.
        :param str backend: backend used for playing sound. Can be either 'sounddevice' or 'dummy'.
        :param timeout: seconds the play loop waits for a new chunk before it stops and lets the play thread exit.
        :raises ValueError: if ``backend`` is not one of the supported values.

        """
        self.sr = sr
        self.sounds = []

        self.chunks = Queue(1)
        self.chunk_available = Condition()
        # Set each time a sound finishes playing (see next_chunks).
        self.is_done = Event()
        self.timeout = timeout  # timeout value for graceful exit of the BackendStream

        if backend == 'dummy':
            from .dummy_stream import DummyStream
            self.BackendStream = DummyStream
        elif backend == 'sounddevice':
            from sounddevice import OutputStream
            self.BackendStream = OutputStream
        else:
            raise ValueError("Backend can either be 'sounddevice' or 'dummy'")

        # TODO: use a process instead?
        self.play_thread = Thread(target=self.run)
        self.play_thread.daemon = True
        self.play_thread.start()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Returns once the play loop has timed out waiting for chunks.
        self.play_thread.join()

    def play(self, sound):
        """ Adds and plays a new Sound to the Sampler.

            :param sound: sound to play
            :raises ValueError: if ``sound.sr`` differs from the sampler's samplerate.

            .. note:: If the sound is already playing, it will restart from the beginning.

        """
        if self.sr != sound.sr:
            raise ValueError(
                'You can only play sound with a samplerate of {} (here {}). '
                'Use the Sound.resample method for instance.'.format(self.sr, sound.sr))

        if sound in self.sounds:
            self.remove(sound)

        with self.chunk_available:
            self.sounds.append(sound)
            sound.playing = True

            # Wake the producer if it is waiting for something to play.
            self.chunk_available.notify()

    def remove(self, sound):
        """ Remove a currently played sound. """
        with self.chunk_available:
            sound.playing = False
            self.sounds.remove(sound)

    # Play loop

    def next_chunks(self):
        """ Gets a new chunk from all played sound and mix them together.

        Blocks until at least one playing sound yields a chunk. Sounds whose
        chunk iterator is exhausted are removed and ``is_done`` is set.
        The mix is the element-wise mean of the collected chunks.
        """
        with self.chunk_available:
            while True:
                playing_sounds = [s for s in self.sounds if s.playing]

                chunks = []
                for s in playing_sounds:
                    try:
                        chunks.append(next(s.chunks))
                    except StopIteration:
                        s.playing = False
                        self.sounds.remove(s)
                        self.is_done.set()  # sound was played, release is_done to end the wait in play

                if chunks:
                    break

                self.chunk_available.wait()

            return numpy.mean(chunks, axis=0)

    def run(self):
        """ Play loop, i.e. send all sound chunk by chunk to the soundcard.

        Stops when no chunk arrives within ``self.timeout`` seconds.
        """
        self.running = True

        def chunks_producer():
            while self.running:
                self.chunks.put(self.next_chunks())

        t = Thread(target=chunks_producer)
        # The producer may block forever in next_chunks() or chunks.put()
        # after the play loop stops; as a daemon it cannot keep the
        # interpreter from exiting.
        t.daemon = True
        t.start()

        with self.BackendStream(samplerate=self.sr, channels=1) as stream:
            while self.running:
                try:
                    # timeout so stream.write() thread can exit
                    stream.write(self.chunks.get(timeout=self.timeout))
                except Empty:
                    self.running = False  # let play_thread exit
示例#33
0
class ConnectionPool(object):
    """Pool of CephFS client connections, one per filesystem name.

    Every public method serializes on ``self.lock``. ``self.cond`` shares
    that lock, so a caller waiting for in-flight operations releases it
    while it waits. A recurring timer closes connections that have been
    idle for at least ``CONNECTION_IDLE_INTERVAL`` seconds.
    """
    class Connection(object):
        """One CephFS mount plus a count of operations using its handle.

        Not synchronized itself; the owning ConnectionPool calls it with the
        pool lock held.
        """
        def __init__(self, mgr, fs_name):
            self.fs = None
            self.mgr = mgr
            self.fs_name = fs_name
            self.ops_in_progress = 0
            self.last_used = time.time()
            # Remember the fs id so a removed-and-recreated filesystem with
            # the same name can be detected later (is_connection_valid).
            self.fs_id = self.get_fs_id()

        def get_fs_id(self):
            """Return the id of ``fs_name`` from the mgr's fs_map.

            :raises VolumeException: (-ENOENT) if the filesystem is not listed.
            """
            fs_map = self.mgr.get('fs_map')
            for fs in fs_map['filesystems']:
                if fs['mdsmap']['fs_name'] == self.fs_name:
                    return fs['id']
            raise VolumeException(
                -errno.ENOENT, "Volume '{0}' not found".format(self.fs_name))

        def get_fs_handle(self):
            """Count one more in-flight operation and return the mount handle."""
            self.last_used = time.time()
            self.ops_in_progress += 1
            return self.fs

        def put_fs_handle(self, notify):
            """Finish one operation; call ``notify()`` when none remain."""
            assert self.ops_in_progress > 0
            self.ops_in_progress -= 1
            if self.ops_in_progress == 0:
                notify()

        def del_fs_handle(self, waiter):
            """Close the mount, first calling ``waiter()`` until no ops remain
            (if a waiter is given). Disconnects cleanly when the filesystem
            still exists, otherwise aborts the connection."""
            if waiter:
                while self.ops_in_progress != 0:
                    waiter()
            if self.is_connection_valid():
                self.disconnect()
            else:
                self.abort()

        def is_connection_valid(self):
            """True if the filesystem still exists with the id seen at creation."""
            fs_id = None
            try:
                fs_id = self.get_fs_id()
            except Exception:
                # the filesystem does not exist now (or the fs_map could not
                # be read) -- connection is not valid. Exception, not a bare
                # except, so KeyboardInterrupt/SystemExit still propagate.
                pass
            return self.fs_id == fs_id

        def is_connection_idle(self, timeout):
            """True if no ops are in flight and the handle is unused for ``timeout`` s."""
            return (self.ops_in_progress == 0
                    and ((time.time() - self.last_used) >= timeout))

        def connect(self):
            assert self.ops_in_progress == 0
            log.debug("Connecting to cephfs '{0}'".format(self.fs_name))
            self.fs = cephfs.LibCephFS(rados_inst=self.mgr.rados)
            log.debug(
                "Setting user ID and group ID of CephFS mount as root...")
            self.fs.conf_set("client_mount_uid", "0")
            self.fs.conf_set("client_mount_gid", "0")
            log.debug("CephFS initializing...")
            self.fs.init()
            log.debug("CephFS mounting...")
            self.fs.mount(filesystem_name=self.fs_name.encode('utf-8'))
            log.debug("Connection to cephfs '{0}' complete".format(
                self.fs_name))
            self.mgr._ceph_register_client(self.fs.get_addrs())

        def disconnect(self):
            assert self.ops_in_progress == 0
            log.info("disconnecting from cephfs '{0}'".format(self.fs_name))
            addrs = self.fs.get_addrs()
            self.fs.shutdown()
            self.mgr._ceph_unregister_client(addrs)
            self.fs = None

        def abort(self):
            assert self.ops_in_progress == 0
            log.info("aborting connection from cephfs '{0}'".format(
                self.fs_name))
            self.fs.abort_conn()
            self.fs = None

    class RTimer(Timer):
        """
        recurring timer variant of Timer
        """
        def run(self):
            try:
                # Event.wait() returns True once cancel() sets the flag, so
                # the function is never invoked again after cancellation.
                while not self.finished.wait(self.interval):
                    self.function(*self.args, **self.kwargs)
                self.finished.set()
            except Exception as e:
                log.error("ConnectionPool.RTimer: %s", e)
                raise

    # TODO: make this configurable
    TIMER_TASK_RUN_INTERVAL = 30.0  # seconds
    CONNECTION_IDLE_INTERVAL = 60.0  # seconds

    def __init__(self, mgr):
        self.mgr = mgr
        self.connections = {}
        self.lock = Lock()
        # Shares self.lock so waiting in del_all_handles releases the pool.
        self.cond = Condition(self.lock)
        self.timer_task = ConnectionPool.RTimer(
            ConnectionPool.TIMER_TASK_RUN_INTERVAL, self.cleanup_connections)
        self.timer_task.start()

    def cleanup_connections(self):
        """Timer task: close every connection idle past CONNECTION_IDLE_INTERVAL."""
        with self.lock:
            log.info("scanning for idle connections..")
            idle_fs = [
                fs_name for fs_name, conn in self.connections.items()
                if conn.is_connection_idle(
                    ConnectionPool.CONNECTION_IDLE_INTERVAL)
            ]
            for fs_name in idle_fs:
                log.info("cleaning up connection for '{}'".format(fs_name))
                self._del_fs_handle(fs_name)

    def get_fs_handle(self, fs_name):
        """Return a mount handle for ``fs_name``, connecting if needed.

        Every successful call must be paired with :meth:`put_fs_handle`.

        :raises VolumeException: if the volume does not exist or connecting fails.
        """
        with self.lock:
            conn = None
            try:
                conn = self.connections.get(fs_name, None)
                if conn:
                    if conn.is_connection_valid():
                        return conn.get_fs_handle()
                    else:
                        # filesystem id changed beneath us (or the filesystem does not exist).
                        # this is possible if the filesystem got removed (and recreated with
                        # same name) via "ceph fs rm/new" mon command.
                        log.warning(
                            "filesystem id changed for volume '{0}', reconnecting..."
                            .format(fs_name))
                        self._del_fs_handle(fs_name)
                conn = ConnectionPool.Connection(self.mgr, fs_name)
                conn.connect()
            except cephfs.Error as e:
                # try to provide a better error string if possible
                if e.args[0] == errno.ENOENT:
                    raise VolumeException(
                        -errno.ENOENT,
                        "Volume '{0}' not found".format(fs_name))
                raise VolumeException(-e.args[0], e.args[1])
            self.connections[fs_name] = conn
            return conn.get_fs_handle()

    def put_fs_handle(self, fs_name):
        """Release a handle obtained from :meth:`get_fs_handle`."""
        with self.lock:
            conn = self.connections.get(fs_name, None)
            if conn:
                # Wake del_all_handles() waiters when the last op finishes.
                conn.put_fs_handle(notify=self.cond.notify_all)

    def _del_fs_handle(self, fs_name, wait=False):
        """Drop and close the connection for ``fs_name``; caller holds self.lock.

        With ``wait=True`` it blocks on ``self.cond`` until in-flight ops finish.
        """
        conn = self.connections.pop(fs_name, None)
        if conn:
            conn.del_fs_handle(waiter=self.cond.wait if wait else None)

    def del_fs_handle(self, fs_name, wait=False):
        with self.lock:
            self._del_fs_handle(fs_name, wait)

    def del_all_handles(self):
        """Close every connection, waiting for each one's pending ops."""
        with self.lock:
            for fs_name in list(self.connections.keys()):
                log.info("waiting for pending ops for '{}'".format(fs_name))
                self._del_fs_handle(fs_name, wait=True)
                log.info("pending ops completed for '{}'".format(fs_name))
            # no new connections should have been initialized since its
            # guarded on shutdown.
            assert len(self.connections) == 0
示例#34
0
class SecureDocXMLRPCServer(CustomThreadingMixIn, DocXMLRPCServer):
    """Documenting XML-RPC server that talks HTTPS through a pyOpenSSL socket.

    ``startup`` runs ``handle_request`` asynchronously via ``start`` and then
    waits on ``rCondition`` until ``get_request`` has counted an accepted
    connection, before spawning the next handler.
    """
    def __init__(self, registerInstance, server_address, keyFile=DEFAULTKEYFILE, certFile=DEFAULTCERTFILE, logRequests=True):
        """Secure Documenting XML-RPC server.
        It is very similar to DocXMLRPCServer but it uses HTTPS for transporting XML data.
        """
        DocXMLRPCServer.__init__(self, server_address, SecureDocXMLRpcRequestHandler, logRequests)
        self.logRequests = logRequests

        # stuff for doc server
        try: self.set_server_title(registerInstance.title)
        except AttributeError: self.set_server_title('default title')
        try: self.set_server_name(registerInstance.name)
        except AttributeError: self.set_server_name('default name')
        if registerInstance.__doc__: self.set_server_documentation(registerInstance.__doc__)
        else: self.set_server_documentation('default documentation')
        self.register_introspection_functions()

        # init stuff, handle different versions:
        try:
            SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self)
        except TypeError:
            # An exception is raised in Python 2.5 as the prototype of the __init__
            # method has changed and now has 3 arguments (self, allow_none, encoding)
            SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, False, None)
        SocketServer.BaseServer.__init__(self, server_address, SecureDocXMLRpcRequestHandler)
        self.register_instance(registerInstance) # for some reason, have to register instance down here!

        # SSL socket stuff
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        ctx.use_privatekey_file(keyFile)
        ctx.use_certificate_file(certFile)
        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
        self.server_bind()
        self.server_activate()

        # requests count and condition, to allow for keyboard quit via CTL-C
        self.requests = 0
        self.rCondition = Condition()


    def startup(self):
        """Serve requests until CTRL-C is pressed.

        ``handle_request`` blocks, so each call is started asynchronously via
        ``start``; this loop then waits until ``get_request`` has counted an
        accepted connection. The 3 s wait timeout presumably keeps the wait
        interruptible by CTRL-C.
        """
        print('server starting; hit CTRL-C to quit...')
        while True:
            try:
                self.rCondition.acquire()
                try:
                    start(self.handle_request, ()) # we do this async, because handle_request blocks!
                    while not self.requests:
                        self.rCondition.wait(timeout=3.0)
                    # the loop above only exits once requests is non-zero
                    self.requests -= 1
                finally:
                    # Release even when CTRL-C interrupts the wait; otherwise the
                    # handler thread blocks forever in get_request().
                    self.rCondition.release()
            except KeyboardInterrupt:
                print("quit signaled, i'm done.")
                return

    def get_request(self):
        """Accept one connection and signal ``startup`` that it arrived."""
        request, client_address = self.socket.accept()
        self.rCondition.acquire()
        try:
            self.requests += 1
            self.rCondition.notifyAll()
        finally:
            self.rCondition.release()
        return (request, client_address)

    def listMethods(self):
        """Return the registered method names (strings), sorted."""
        return sorted(self.funcs.keys())

    def methodHelp(self, methodName):
        """Return the docstring of *methodName*.

        Raises:
          Exception: if *methodName* is not registered.
        """
        if methodName in self.funcs:
            return self.funcs[methodName].__doc__
        else:
            raise Exception('method "%s" is not supported' % methodName)
示例#35
0
class GbnServer:
    """Go-Back-N (GBN) reliable transport on top of a UDP socket.

    Data packets built by ``__mk_pkt`` are laid out as follows:
      * byte 0 ``0``
      * byte 1 sequence number
      * bytes 2-5 big-endian payload length
      * payload, plus an optional zero pad byte
      * 2-byte checksum

    ACKs sent by ``__ack_thread`` are the two bytes ``255, seq``. The peer
    address is learned from the first datagram received. Two daemon threads
    do the actual I/O: ``__send_thread`` and ``__ack_thread``.
    """
    def __init__(self, addr):
        # Sequence numbers run 0..255 and wrap around.
        self.__BUFSIZE = 255
        self.__send_cache = []
        self.__seq = 0
        self.__pkg_sign = 255
        self.__ack_sign = 0
        self.__sock = socket(AF_INET, SOCK_DGRAM)
        self.__sock.bind(addr)
        self.__client_addr = None
        self.__window_size = 10
        self.__lock = Lock()
        # Next in-order sequence number expected from the peer.
        self.__recv_eseq = 1
        self.__recv_cache = []
        self.__notify_sender = Condition()
        self.__notify_recver = Condition()
        self.__notify_ack = Condition()
        t = Thread(target=self.__send_thread, daemon=True)
        t.start()
        t = Thread(target=self.__ack_thread, daemon=True)
        t.start()

    def __get_seq(self):
        """Advance and return the outgoing sequence number (wraps 255 -> 0)."""
        self.__seq += 1
        if self.__seq > self.__BUFSIZE:
            self.__seq = 0
        return self.__seq

    def __mk_pkt(self, data):
        """Build a data packet for *data*.

        The packet is padded to an even length and ends with the inverted
        ``lib.checksum`` value, so that ``lib.checksum`` over the whole packet
        yields ``0xffff``.
        """
        temp = bytes([0, self.__get_seq()]) + len(data).to_bytes(4, byteorder='big') + data
        if len(temp) % 2 != 0:
            temp += bytes([0])
        return temp + (lib.checksum(temp) ^ 0xffff).to_bytes(2, 'big')

    def send(self, data):
        """Send *data* with the GBN protocol. Called by the application layer.

        Nothing is transmitted here: the data is packed into a packet, the
        packet is appended to the send cache, and the sender thread is woken
        up to transmit it.

        :param data: bytes to send
        :return: None
        """
        with self.__notify_sender:
            self.__send_cache.append(self.__mk_pkt(data))
            self.__notify_sender.notify()

    def recv(self):
        """Return the next in-order payload from the peer.

        Blocks until data is available.

        :return: payload bytes
        """
        # ``with`` guarantees the condition is released on return; without it,
        # __ack_thread would block forever on this lock after the first recv().
        with self.__notify_recver:
            while len(self.__recv_cache) == 0:
                self.__notify_recver.wait()
            return self.__recv_cache.pop(0)

    def __send_thread(self):
        """Background sender loop (daemon thread started by ``__init__``).

        Sleeps until the send cache has data. It then repeatedly transmits up
        to ``window_size`` cached packets and waits on ``__notify_ack`` for
        progress (4 s first, then 10 s while ACKs keep arriving). After that
        it retransmits whatever is still cached.
        """
        while True:
            with self.__notify_sender:
                # Suspend while the send cache is empty.
                while len(self.__send_cache) == 0:
                    self.__notify_sender.wait()
            while len(self.__send_cache) > 0:
                # The peer address is only learned from its first datagram.
                # sendto(pkt, None) would raise TypeError and kill this thread,
                # so nothing is sent until the address is known; the timed ACK
                # wait below then paces the retries.
                if self.__client_addr is not None:
                    with self.__lock:
                        for pkt in self.__send_cache[:self.__window_size]:
                            self.__sock.sendto(pkt, self.__client_addr)
                with self.__notify_ack:
                    self.__notify_ack.notify()
                    flag = self.__notify_ack.wait(4)
                    while flag and len(self.__send_cache) > 0:
                        self.__notify_ack.notify()
                        flag = self.__notify_ack.wait(10)

    def __ack_thread(self):
        """Background receiver loop handling incoming data and ACKs.

        An in-order data packet with a valid checksum is ACKed and queued for
        ``recv``. Any other datagram is treated as a cumulative ACK and pops
        acknowledged packets from the send cache. On some platforms
        ``recvfrom`` raises ConnectionResetError when the peer has gone away.
        """
        while True:
            data, addr = self.__sock.recvfrom(1024)
            if self.__client_addr is None:
                self.__client_addr = addr
            if len(data) < 2:
                # Too short to carry a type byte and a sequence number.
                continue
            if data[0] == 0 and data[1] == self.__recv_eseq and lib.checksum(data) == 0xffff:
                self.__recv_eseq += 1
                if self.__recv_eseq > self.__BUFSIZE:
                    self.__recv_eseq = 0
                data_len = int.from_bytes(data[2:6], 'big')
                ack = bytes([255, data[1]])
                with self.__notify_recver:
                    # Append under the condition lock so recv() never observes
                    # a half-updated cache.
                    self.__recv_cache.append(data[6:6 + data_len])
                    self.__notify_recver.notify()
                self.__sock.sendto(ack, self.__client_addr)
            elif len(self.__send_cache) > 0:
                # NOTE(review): any packet that is not the expected data packet
                # (including out-of-order/duplicate data packets) lands here and
                # is treated as an ACK. The ``<=`` comparison also ignores
                # sequence-number wraparound. Confirm the peer's ACK format
                # before tightening this check.
                seq = data[1]
                while len(self.__send_cache) > 0 and self.__send_cache[0][1] <= seq:
                    with self.__lock:
                        self.__send_cache.pop(0)
                    with self.__notify_ack:
                        self.__notify_ack.notify()
示例#36
0
class CephfsConnectionPool(object):
    """Pool of libcephfs connections, keyed by filesystem name.

    At most ``MAX_CONCURRENT_CONNECTIONS`` connections are opened per
    filesystem. Beyond that, the connection with the fewest operations in
    flight is shared.

    An ``RTimer`` runs ``cleanup_connections`` every
    ``TIMER_TASK_RUN_INTERVAL`` seconds to tear down connections that have
    been idle for ``CONNECTION_IDLE_INTERVAL`` seconds.

    Pool state is guarded by ``self.lock``. ``self.cond`` shares that lock
    and is used to wait for in-flight operations to drain.
    """
    class Connection(object):
        """One libcephfs mount of a filesystem plus its usage bookkeeping."""

        def __init__(self, mgr: Module_T, fs_name: str):
            self.fs: Optional["cephfs.LibCephFS"] = None
            self.mgr = mgr
            self.fs_name = fs_name
            # handles given out by get_fs_handle() and not yet returned
            self.ops_in_progress = 0
            self.last_used = time.time()
            # remembered so a filesystem removed and recreated under the same
            # name can be detected (see is_connection_valid)
            self.fs_id = self.get_fs_id()

        def get_fs_id(self) -> int:
            """Return the id of ``fs_name`` from the manager's fs_map.

            Raises:
                CephfsConnectionException: -ENOENT if no such filesystem.
            """
            fs_map = self.mgr.get('fs_map')
            for fs in fs_map['filesystems']:
                if fs['mdsmap']['fs_name'] == self.fs_name:
                    return fs['id']
            raise CephfsConnectionException(
                -errno.ENOENT, "FS '{0}' not found".format(self.fs_name))

        def get_fs_handle(self) -> "cephfs.LibCephFS":
            """Hand out the mount and count one more operation in flight."""
            self.last_used = time.time()
            self.ops_in_progress += 1
            return self.fs

        def put_fs_handle(self, notify: Callable) -> None:
            """Return a handle; call *notify* once no operation is in flight."""
            assert self.ops_in_progress > 0
            self.ops_in_progress -= 1
            if self.ops_in_progress == 0:
                notify()

        def del_fs_handle(self, waiter: Optional[Callable]) -> None:
            """Tear the mount down.

            If *waiter* is given, it is called repeatedly until no operation
            is in flight. A still-valid connection is shut down cleanly;
            otherwise (filesystem gone or recreated) it is aborted.
            """
            if waiter:
                while self.ops_in_progress != 0:
                    waiter()
            if self.is_connection_valid():
                self.disconnect()
            else:
                self.abort()

        def is_connection_valid(self) -> bool:
            """True if the filesystem still exists with the id seen at creation."""
            fs_id = None
            try:
                fs_id = self.get_fs_id()
            except Exception:
                # the filesystem does not exist now -- connection is not valid.
                pass
            logger.debug("self.fs_id={0}, fs_id={1}".format(self.fs_id, fs_id))
            return self.fs_id == fs_id

        def is_connection_idle(self, timeout: float) -> bool:
            """True if unused and untouched for at least *timeout* seconds."""
            return (self.ops_in_progress == 0
                    and ((time.time() - self.last_used) >= timeout))

        def connect(self) -> None:
            """Create, configure (mount as uid/gid 0), init and mount libcephfs."""
            assert self.ops_in_progress == 0
            logger.debug("Connecting to cephfs '{0}'".format(self.fs_name))
            self.fs = cephfs.LibCephFS(rados_inst=self.mgr.rados)
            logger.debug(
                "Setting user ID and group ID of CephFS mount as root...")
            self.fs.conf_set("client_mount_uid", "0")
            self.fs.conf_set("client_mount_gid", "0")
            self.fs.conf_set("client_check_pool_perm", "false")
            logger.debug("CephFS initializing...")
            self.fs.init()
            logger.debug("CephFS mounting...")
            self.fs.mount(filesystem_name=self.fs_name.encode('utf-8'))
            logger.debug("Connection to cephfs '{0}' complete".format(
                self.fs_name))
            self.mgr._ceph_register_client(self.fs.get_addrs())

        def disconnect(self) -> None:
            """Cleanly shut the mount down and unregister its client addrs."""
            try:
                assert self.fs
                assert self.ops_in_progress == 0
                logger.info("disconnecting from cephfs '{0}'".format(
                    self.fs_name))
                addrs = self.fs.get_addrs()
                self.fs.shutdown()
                self.mgr._ceph_unregister_client(addrs)
                self.fs = None
            except Exception as e:
                logger.debug("disconnect: ({0})".format(e))
                raise

        def abort(self) -> None:
            """Abort the mount without a clean shutdown."""
            assert self.fs
            assert self.ops_in_progress == 0
            logger.info("aborting connection from cephfs '{0}'".format(
                self.fs_name))
            self.fs.abort_conn()
            logger.info("abort done from cephfs '{0}'".format(self.fs_name))
            self.fs = None

    # TODO: make this configurable
    TIMER_TASK_RUN_INTERVAL = 30.0  # seconds
    CONNECTION_IDLE_INTERVAL = 60.0  # seconds
    MAX_CONCURRENT_CONNECTIONS = 5  # max number of concurrent connections per volume

    def __init__(self, mgr: Module_T):
        self.mgr = mgr
        self.connections: Dict[str, List[CephfsConnectionPool.Connection]] = {}
        self.lock = Lock()
        self.cond = Condition(self.lock)
        self.timer_task = RTimer(CephfsConnectionPool.TIMER_TASK_RUN_INTERVAL,
                                 self.cleanup_connections)
        self.timer_task.start()

    def cleanup_connections(self) -> None:
        """Timer callback: tear down every idle connection in the pool."""
        with self.lock:
            logger.info("scanning for idle connections..")
            idle_conns = []
            for fs_name, connections in self.connections.items():
                logger.debug(
                    f'fs_name ({fs_name}) connections ({connections})')
                for connection in connections:
                    if connection.is_connection_idle(
                            CephfsConnectionPool.CONNECTION_IDLE_INTERVAL):
                        idle_conns.append((fs_name, connection))
            logger.info(f'cleaning up connections: {idle_conns}')
            for idle_conn in idle_conns:
                self._del_connection(idle_conn[0], idle_conn[1])

    def get_fs_handle(self, fs_name: str) -> "cephfs.LibCephFS":
        """Return a libcephfs handle for *fs_name*.

        An unused, still-valid connection is reused first. Failing that, a
        new connection is opened while below ``MAX_CONCURRENT_CONNECTIONS``;
        otherwise the least-busy connection is shared. Every handle must be
        returned with ``put_fs_handle``.

        Raises:
            CephfsConnectionException: on libcephfs errors (-ENOENT for an
                unknown filesystem).
        """
        with self.lock:
            try:
                min_shared = 0
                shared_connection = None
                connections = self.connections.setdefault(fs_name, [])
                logger.debug(
                    f'[get] volume: ({fs_name}) connection: ({connections})')
                if connections:
                    min_shared = connections[0].ops_in_progress
                    shared_connection = connections[0]
                for connection in list(connections):
                    logger.debug(
                        f'[get] connection: {connection} usage: {connection.ops_in_progress}'
                    )
                    if connection.ops_in_progress == 0:
                        if connection.is_connection_valid():
                            logger.debug(
                                f'[get] connection ({connection}) can be reused'
                            )
                            return connection.get_fs_handle()
                        else:
                            # filesystem id changed beneath us (or the filesystem does not exist).
                            # this is possible if the filesystem got removed (and recreated with
                            # same name) via "ceph fs rm/new" mon command.
                            logger.warning(
                                f'[get] filesystem id changed for volume ({fs_name}), disconnecting ({connection})'
                            )
                            # note -- this will mutate @connections too
                            self._del_connection(fs_name, connection)
                    else:
                        if connection.ops_in_progress < min_shared:
                            min_shared = connection.ops_in_progress
                            shared_connection = connection
                # when we end up here, there are no "free" connections. so either spin up a new
                # one or share it.
                if len(connections
                       ) < CephfsConnectionPool.MAX_CONCURRENT_CONNECTIONS:
                    logger.debug(
                        '[get] spawning new connection since no connection is unused and we still have room for more'
                    )
                    connection = CephfsConnectionPool.Connection(
                        self.mgr, fs_name)
                    connection.connect()
                    self.connections[fs_name].append(connection)
                    return connection.get_fs_handle()
                else:
                    assert shared_connection is not None
                    logger.debug(
                        f'[get] using shared connection ({shared_connection})')
                    return shared_connection.get_fs_handle()
            except cephfs.Error as e:
                # try to provide a better error string if possible
                if e.args[0] == errno.ENOENT:
                    raise CephfsConnectionException(
                        -errno.ENOENT, "FS '{0}' not found".format(fs_name))
                raise CephfsConnectionException(-e.args[0], e.args[1])

    def put_fs_handle(self, fs_name: str, fs_handle: cephfs.LibCephFS) -> None:
        """Return *fs_handle* (from ``get_fs_handle``) for *fs_name*.

        Wakes any thread waiting on ``self.cond`` once the owning connection
        has no operation in flight. Unknown handles are ignored.
        """
        with self.lock:
            connections = self.connections.get(fs_name, [])
            for connection in connections:
                if connection.fs == fs_handle:
                    logger.debug(
                        f'[put] connection: {connection} usage: {connection.ops_in_progress}'
                    )
                    connection.put_fs_handle(notify=self.cond.notify_all)
                    # a handle belongs to exactly one connection
                    break

    def _del_connection(self,
                        fs_name: str,
                        connection: Connection,
                        wait: bool = False) -> None:
        """Remove *connection* from the pool and tear it down (lock held).

        With *wait*, ``self.cond.wait`` is used to block, releasing the lock,
        until the connection's in-flight operations finish.
        """
        self.connections[fs_name].remove(connection)
        connection.del_fs_handle(waiter=self.cond.wait if wait else None)

    def _del_connections(self, fs_name: str, wait: bool = False) -> None:
        """Tear down all connections of *fs_name* (lock held)."""
        for connection in list(self.connections.get(fs_name, [])):
            self._del_connection(fs_name, connection, wait)
        # Drop the now-empty entry so the pool really forgets *fs_name*;
        # otherwise del_all_connections()' emptiness assertion always fails.
        # Keep it if a connection was added while cond.wait() released the lock.
        if not self.connections.get(fs_name):
            self.connections.pop(fs_name, None)

    def del_connections(self, fs_name: str, wait: bool = False) -> None:
        """Locked wrapper around ``_del_connections``."""
        with self.lock:
            self._del_connections(fs_name, wait)

    def del_all_connections(self) -> None:
        """Tear down every connection, waiting for pending operations."""
        with self.lock:
            for fs_name in list(self.connections.keys()):
                logger.info("waiting for pending ops for '{}'".format(fs_name))
                self._del_connections(fs_name, wait=True)
                logger.info("pending ops completed for '{}'".format(fs_name))
            # no new connections should have been initialized since its
            # guarded on shutdown.
            assert len(self.connections) == 0
示例#37
0
class ContinuousPagingSession(object):
    """Collects pages pushed by a connection and exposes them as a row stream.

    ``on_message``/``on_page``/``on_error`` are fed by the connection.
    ``results`` is a generator consumed by the caller. Pages are queued with
    ``appendleft`` and consumed with ``pop``, so they come out in FIFO order.
    """
    def __init__(self, stream_id, decoder, row_factory, connection):
        self.stream_id = stream_id
        self.decoder = decoder
        self.row_factory = row_factory
        self.connection = connection
        # guards _stop and _page_queue
        self._condition = Condition()
        self._stop = False
        self._page_queue = deque()

    def on_message(self, result):
        """Dispatch a ResultMessage page or an ErrorMessage."""
        if isinstance(result, ResultMessage):
            self.on_page(result)
        elif isinstance(result, ErrorMessage):
            self.on_error(result)

    def on_page(self, result):
        """Queue one page; unregister the session after the last one."""
        with self._condition:
            self._page_queue.appendleft(
                (result.column_names, result.parsed_rows, None))
            self._stop |= result.continuous_paging_last
            self._condition.notify()

        if result.continuous_paging_last:
            self.connection.remove_continuous_paging_session(self.stream_id)

    def on_error(self, error):
        """Queue the error's exception, stop the stream and unregister."""
        with self._condition:
            self._page_queue.appendleft((None, None, error.to_exception()))
            self._stop = True
            self._condition.notify()

        self.connection.remove_continuous_paging_session(self.stream_id)

    def results(self):
        """Yield rows (built by ``row_factory``) page by page.

        Raises the queued exception when an error page is reached. Finishes
        once the stream has been stopped and every queued page is consumed.
        """
        with self._condition:
            while True:
                while not self._page_queue and not self._stop:
                    # TODO: need to timeout here somehow
                    self._condition.wait()
                while self._page_queue:
                    names, rows, err = self._page_queue.pop()
                    if err:
                        raise err
                    # Don't hold the lock while the consumer processes rows.
                    self._condition.release()
                    try:
                        for row in self.row_factory(names, rows):
                            yield row
                    finally:
                        # Re-acquire even if the generator is closed early or
                        # row_factory raises; the enclosing ``with`` releases
                        # the lock again on exit and would otherwise raise
                        # RuntimeError on an un-held lock.
                        self._condition.acquire()
                if self._stop:
                    break

    def cancel(self):
        """Ask the server to cancel the stream and stop local iteration."""
        log.debug("Canceling paging session %s from %s", self.stream_id,
                  self.connection.host)
        self.connection.send_msg(
            CancelMessage(CONTINUOUS_PAGING_OP_TYPE, self.stream_id),
            self.connection.get_request_id(), self._on_cancel_response)
        with self._condition:
            self._stop = True
            self._condition.notify()

    def _on_cancel_response(self, response):
        """Log the cancel outcome and unregister the session."""
        if isinstance(response, ResultMessage):
            log.debug("Paging session %s canceled.", self.stream_id)
        else:
            log.error(
                "Failed canceling streaming session %s from %s: %s",
                self.stream_id, self.connection.host,
                response.to_exception()
                if hasattr(response, 'to_exception') else response)
        self.connection.remove_continuous_paging_session(self.stream_id)
示例#38
0
class FileWatcher(Thread):
    """Looks for new files, and queues them.

    Each pass globs the template(s) and queues unseen files in
    modification-time order. It then waits (interruptibly, via ``self.cond``)
    before scanning again.
    """
    def __init__(self, filename_template, file_queue, frequency):
        """Looks for new files arriving at the given *frequency* (minutes), and
        queues them. *filename_template* is a glob pattern or a list/tuple of
        patterns.
        """
        Thread.__init__(self)
        self.queue = file_queue
        self.template = filename_template
        self.frequency = datetime.timedelta(minutes=frequency)
        self.running = True
        self.cond = Condition()

    def terminate(self):
        """Ask the watcher to stop and wake it if it is waiting."""
        self.cond.acquire()
        try:
            self.running = False
            self.cond.notify()
        finally:
            self.cond.release()
        LOG.debug("Termination request received in FileWatcher")

    def wait(self, secs):
        """Wait up to *secs* seconds unless already terminated.

        Must be called with ``self.cond`` held (``run`` holds it).
        """
        if self.running:
            self.cond.wait(secs)

    def _list_files(self):
        """Return the set of paths currently matching the template(s)."""
        templates = self.template
        if not isinstance(templates, (list, tuple)):
            templates = [templates]
        found = set()
        for template in templates:
            found.update(glob.glob(template))
        return found

    def run(self):
        """Run the file watcher.
        """

        filelist = set()
        sleep_time = 8

        while self.running:
            self.cond.acquire()
            try:
                new_filelist = self._list_files()
                new_files = new_filelist - filelist
                filelist = new_filelist

                # Stat each new file once. A file removed between the glob and
                # the stat is skipped instead of killing the thread.
                stats = {}
                for fil in new_files:
                    try:
                        stats[fil] = os.stat(fil)
                    except OSError:
                        LOG.debug("%s disappeared, skipping" % fil)
                files_to_process = sorted(stats,
                                          key=lambda f: stats[f].st_mtime)

                if len(files_to_process) != 0 and self.running:
                    sleep_time = 8
                    for i in files_to_process:
                        LOG.debug("queueing %s..." % i)
                        self.queue.put(i)
                    newest = max(stats[f].st_ctime for f in files_to_process)

                    since_creation = datetime.timedelta(seconds=time.time() -
                                                        newest)
                    if (self.frequency > since_creation):
                        to_wait = self.frequency - since_creation

                        LOG.info("Waiting at least " + str(to_wait) +
                                 " for next file")
                        # include .days so long frequencies are not truncated
                        sleep_time = (to_wait.days * 86400 +
                                      to_wait.seconds +
                                      to_wait.microseconds / 1000000.0)
                        self.wait(sleep_time)
                        sleep_time = 8
                elif self.running:
                    LOG.info("no new file has come, waiting %s secs" %
                             str(sleep_time))
                    self.wait(sleep_time)
                    if sleep_time < 60:
                        sleep_time *= 2
            finally:
                # never leave the condition held, or terminate() would hang
                self.cond.release()
        LOG.info("FileWatcher terminated.")
示例#39
0
class RDMTestThread(Thread):
    """Worker thread that runs the RDM responder tests or the model collector.

    The RDMResponder tests are closely coupled to the Wrapper (yuck!). So we
    need to run this all in a separate thread. This is all a bit of a hack and
    you'll get into trouble if multiple things are running at once...

    At most one request can be pending at a time. Schedule*/Stat/Stop may be
    called from any thread.
    """
    RUNNING, COMPLETED, ERROR = range(3)
    TESTS, COLLECTOR = range(2)

    def __init__(self, pid_store, logs_directory):
        super(RDMTestThread, self).__init__()
        self._pid_store = pid_store
        self._logs_directory = logs_directory
        self._terminate = False
        self._request = None
        # Guards _terminate and _request
        self._cv = Condition()
        self._wrapper = None
        self._test_state_lock = Lock()  # Guards _test_state
        self._test_state = {}

    def Stop(self):
        """Ask the worker loop to exit. Callable by any thread."""
        with self._cv:
            self._terminate = True
            self._cv.notify()

    def _Schedule(self, request):
        """Install *request* as the pending job unless one is already queued.

        Returns:
          An error message, or None if the request was queued.
        """
        with self._cv:
            if self._request is not None:
                return 'Existing request pending'
            self._request = request
            self._cv.notify()
        return None

    def ScheduleTests(self, universe, uid, test_filter, broadcast_write_delay,
                      inter_test_delay, dmx_frame_rate, slot_count):
        """Schedule the tests to be run. Callable by any thread.

        Returns:
          An error message, or None if the tests were scheduled.
        """
        if not self._CheckIfConnected():
            return 'Lost connection to OLAD'
        return self._Schedule(lambda: self._RunTests(
            universe, uid, test_filter, broadcast_write_delay,
            inter_test_delay, dmx_frame_rate, slot_count))

    def ScheduleCollector(self, universe, skip_queued_messages):
        """Schedule the collector to run on a universe. Callable by any thread.

        Returns:
          An error message, or None if the collection was scheduled.
        """
        if not self._CheckIfConnected():
            return 'Lost connection to OLAD'
        return self._Schedule(lambda: self._RunCollector(universe,
                                                         skip_queued_messages))

    def Stat(self):
        """Check the state of the tests. Callable by any thread.

        Returns:
          A shallow copy of the current test-state dict.
        """
        with self._test_state_lock:
            return dict(self._test_state)

    def run(self):
        """Thread body: execute scheduled requests until Stop() is called."""
        self._wrapper = ClientWrapper()
        self._collector = ModelCollector(self._wrapper, self._pid_store)
        while True:
            with self._cv:
                if self._terminate:
                    logging.info('quitting test thread')
                    return
                request = self._request
                self._request = None
                if request is None:
                    # Nothing to do, go into the wait
                    self._cv.wait()
                    continue
            # Run outside the lock so Stop()/Schedule*() are not blocked.
            request()

    def _UpdateStats(self, tests_completed, total_tests):
        """Progress callback handed to TestRunner.RunTests."""
        with self._test_state_lock:
            self._test_state['tests_completed'] = tests_completed
            self._test_state['total_tests'] = total_tests

    def _SetError(self, e):
        """Record *e* and the traceback currently being handled as an ERROR."""
        with self._test_state_lock:
            self._test_state['state'] = self.ERROR
            self._test_state['exception'] = str(e)
            self._test_state['traceback'] = traceback.format_exc()

    def _RunTests(self, universe, uid, test_filter, broadcast_write_delay,
                  inter_test_delay, dmx_frame_rate, slot_count):
        """Run the responder tests and record the outcome in _test_state."""
        with self._test_state_lock:
            self._test_state = {
                'action': self.TESTS,
                'tests_completed': 0,
                'total_tests': None,
                'state': self.RUNNING,
                'duration': 0,
            }
            start_time = datetime.now()

        dmx_sender = None
        # Runner/DMXSender setup lives inside the try as well: a failure there
        # used to escape, leaving the state stuck at RUNNING and killing the
        # worker thread.
        try:
            runner = TestRunner.TestRunner(universe, uid, broadcast_write_delay,
                                           inter_test_delay, self._pid_store,
                                           self._wrapper)

            for test in TestRunner.GetTestClasses(TestDefinitions):
                runner.RegisterTest(test)

            if dmx_frame_rate > 0 and slot_count > 0:
                logging.info(
                    'Starting DMXSender with slot count %d and FPS of %d' %
                    (slot_count, dmx_frame_rate))
                dmx_sender = DMXSender(self._wrapper, universe, dmx_frame_rate,
                                       slot_count)

            tests, device = runner.RunTests(test_filter, False,
                                            self._UpdateStats)
        except Exception as e:
            self._SetError(e)
            return
        finally:
            if dmx_sender is not None:
                dmx_sender.Stop()

        timestamp = int(time())
        end_time = datetime.now()
        test_parameters = {
            'broadcast_write_delay': broadcast_write_delay,
            'inter_test_delay': inter_test_delay,
            'dmx_frame_rate': dmx_frame_rate,
            'dmx_slot_count': slot_count,
        }
        log_saver = TestLogger.TestLogger(self._logs_directory)
        logs_saved = True
        try:
            log_saver.SaveLog(uid, timestamp, end_time, tests, device,
                              test_parameters)
        except TestLogger.TestLoggerException:
            logs_saved = False

        with self._test_state_lock:
            # We can't use total_seconds() since it requires Python 2.7
            time_delta = end_time - start_time
            self._test_state['duration'] = (time_delta.seconds +
                                            time_delta.days * 24 * 3600)
            self._test_state['state'] = self.COMPLETED
            self._test_state['tests'] = tests
            self._test_state['logs_saved'] = logs_saved
            self._test_state['timestamp'] = timestamp
            self._test_state['uid'] = uid

    def _RunCollector(self, universe, skip_queued_messages):
        """Run the device model collector for a universe."""
        logging.info('Collecting for universe %d' % universe)
        with self._test_state_lock:
            self._test_state = {
                'action': self.COLLECTOR,
                'state': self.RUNNING,
            }

        try:
            output = self._collector.Run(universe, skip_queued_messages)
        except Exception as e:
            self._SetError(e)
            return

        with self._test_state_lock:
            self._test_state['state'] = self.COMPLETED
            self._test_state['output'] = output

    def _CheckIfConnected(self):
        """Check if the client is connected to olad.

        Returns:
          True if connected, False otherwise.
        """
        # TODO(simon): add this check, remember it needs locking.
        return True
示例#40
0
class OrderedQueueDispatcherPool:
    """
    A thread pool that dispatches messages to a list of receivers.

    The number of threads is usually smaller than the number of receivers and
    for each receiver it is guaranteed that messages arrive in the order they
    were published. No guarantees are given between different receivers.  All
    methods except #start and #stop are reentrant.

    The pool can be stopped and restarted at any time during the processing but
    these calls must be single-threaded.

    Assumptions:
     - same subscriptions for multiple receivers unlikely, hence filtering done
       per receiver thread

    .. codeauthor:: jwienke
    """
    class _Receiver:
        """Bookkeeping for one registration: the receiver, its pending
        messages, and whether a worker is currently delivering to it."""

        def __init__(self, receiver):
            self.receiver = receiver
            self.queue = Queue()
            # Set to True by _next_job (under the pool condition) and back to
            # False by _finished_work (also under processing_condition), so
            # at most one worker delivers to this registration at a time.
            self.processing = False
            # NOTE(review): not referenced anywhere in this class.
            self.processing_mutex = Lock()
            self.processing_condition = Condition()

    def _true_filter(self, receiver, message):
        """Default filter: accept every message."""
        return True

    def __init__(self, thread_pool_size, del_func, filter_func=None):
        """
        Construct a new pool.

        Args:
            thread_pool_size (int >= 1):
                number of threads for this pool
            del_func (callable):
                the strategy used to deliver messages of type M to receivers of
                type R. This will most likely be a simple delegate function
                mapping to a concrete method call.  Must be reentrant. callable
                with two arguments. First is the receiver of a message, second
                is the message to deliver
            filter_func (callable):
                Reentrant function used to filter messages per receiver.
                Default accepts every message. callable with two arguments.
                First is the receiver of a message, second is the message to
                filter. Must return a bool, true means to deliver the message,
                false rejects it.

        Raises:
            ValueError:
                if thread_pool_size is smaller than 1
        """

        self._logger = get_logger_by_class(self.__class__)

        if thread_pool_size < 1:
            raise ValueError("Thread pool size must be at least 1, "
                             "{} was given.".format(thread_pool_size))
        self._thread_pool_size = int(thread_pool_size)

        self._del_func = del_func
        if filter_func is not None:
            self._filter_func = filter_func
        else:
            self._filter_func = self._true_filter

        # Guards _receivers, _jobsAvailable and _interrupted and wakes idle
        # workers.
        self._condition = Condition()
        self._receivers = []

        self._jobsAvailable = False

        self._started = False
        self._interrupted = False

        self._threadPool = []

        # Round-robin cursor used by _next_job to pick the next receiver.
        self._currentPosition = 0

    def __del__(self):
        # __init__ may have raised (e.g. invalid pool size) before the pool
        # state existed; stop() would then fail with AttributeError.
        if hasattr(self, "_threadPool"):
            self.stop()

    def register_receiver(self, receiver):
        """
        Register a new receiver at the pool.

        Multiple registrations of the same receiver are possible resulting in
        being called multiple times for the same message (but effectively this
        destroys the guarantee about ordering given above because multiple
        message queues are used for every subscription).

        Args:
            receiver:
                new receiver
        """

        with self._condition:
            self._receivers.append(self._Receiver(receiver))

        self._logger.info("Registered receiver %s", receiver)

    def unregister_receiver(self, receiver):
        """
        Unregister all registrations of one receiver.

        Blocks until no worker is delivering to any of the removed
        registrations anymore.

        Args:
            receiver:
                receiver to unregister

        Returns:
            True if one or more receivers were unregistered, else False
        """

        removed = []
        with self._condition:
            kept = []
            for r in self._receivers:
                if r.receiver == receiver:
                    removed.append(r)
                else:
                    kept.append(r)
            self._receivers = kept

        # Wait for every removed registration, not just the last one, so no
        # delivery to this receiver is still running when we return.
        for r in removed:
            with r.processing_condition:
                while r.processing:
                    self._logger.info("Waiting for receiver %s to finish",
                                      receiver)
                    r.processing_condition.wait()
        return bool(removed)

    def push(self, message):
        """
        Push a new message to be dispatched to all receivers in this pool.

        Args:
            message:
                message to dispatch
        """

        with self._condition:
            for receiver in self._receivers:
                receiver.queue.put(message)
            self._jobsAvailable = True
            self._condition.notify()

        # XXX: This is disabled because it can trigger this bug for protocol
        # buffers payloads:
        # http://code.google.com/p/protobuf/issues/detail?id=454
        # See also #1331
        # self._logger.debug("Got new message to dispatch: %s", message)

    def _next_job(self, worker_num):
        """
        Return the next job to process for worker threads.

        Blocks if there is no job. The returned receiver is marked as
        processing and must be handed back via _finished_work.

        Args:
            worker_num:
                number of the worker requesting a new job

        Returns:
            the receiver to work on

        Raises:
            _InterruptedError:
                if the pool is being stopped
        """

        receiver = None
        with self._condition:

            got_job = False
            while not got_job:

                while (not self._jobsAvailable) and (not self._interrupted):
                    self._logger.debug("Worker %d: no jobs available, waiting",
                                       worker_num)
                    self._condition.wait()

                if self._interrupted:
                    raise _InterruptedError("Processing was interrupted")

                # Round-robin search for an idle receiver with pending messages.
                for _ in range(len(self._receivers)):

                    self._currentPosition = self._currentPosition + 1
                    real_pos = self._currentPosition % len(self._receivers)

                    if (not self._receivers[real_pos].processing) and \
                            (not self._receivers[real_pos].queue.empty()):

                        receiver = self._receivers[real_pos]
                        receiver.processing = True
                        got_job = True
                        break

                if not got_job:
                    self._jobsAvailable = False

            # Wake another worker: there may be further idle receivers with work.
            self._condition.notify()
            return receiver

    def _finished_work(self, receiver, worker_num):
        """Release a receiver obtained from _next_job and wake a worker if it
        still has pending messages."""

        with self._condition:

            with receiver.processing_condition:
                receiver.processing = False
                receiver.processing_condition.notify_all()
            if not receiver.queue.empty():
                self._jobsAvailable = True
                self._logger.debug(
                    "Worker %d: new jobs available, "
                    "notifying one", worker_num)
                self._condition.notify()

    def _worker(self, worker_num):
        """
        Threaded worker method.

        Delivers one message per job. Exceptions raised by the filter or
        delivery function are logged and do not stop the worker; the receiver
        is always released so it can be served again.

        Args:
            worker_num:
                number of this worker thread
        """

        try:

            while True:

                receiver = self._next_job(worker_num)
                try:
                    message = receiver.queue.get(True, None)
                    self._logger.debug(
                        "Worker %d: got message %s for receiver %s",
                        worker_num, message, receiver.receiver)
                    if self._filter_func(receiver.receiver, message):
                        self._logger.debug(
                            "Worker %d: delivering message %s for receiver %s",
                            worker_num, message, receiver.receiver)
                        self._del_func(receiver.receiver, message)
                        self._logger.debug(
                            "Worker %d: delivery for receiver %s finished",
                            worker_num, receiver.receiver)
                except Exception:
                    # Otherwise the thread would die with the receiver still
                    # marked as processing, blocking it (and
                    # unregister_receiver) forever.
                    self._logger.exception(
                        "Worker %d: error while dispatching to receiver %s",
                        worker_num, receiver.receiver)
                finally:
                    self._finished_work(receiver, worker_num)

        except _InterruptedError:
            pass

    def start(self):
        """
        Start processing and return immediately.

        Raises:
            RuntimeError:
                if the pool was already started and is running
        """

        with self._condition:

            if self._started:
                raise RuntimeError("Pool already running")

            self._interrupted = False

            for i in range(self._thread_pool_size):
                worker = Thread(target=self._worker, args=[i])
                worker.daemon = True
                worker.start()
                self._threadPool.append(worker)

            self._started = True

        self._logger.info("Started pool with %d threads",
                          self._thread_pool_size)

    def stop(self):
        """Block until every thread has stopped working."""

        self._logger.info("Starting to stop thread pool by wating for workers")

        with self._condition:
            self._interrupted = True
            self._condition.notify_all()

        for worker in self._threadPool:
            self._logger.debug("Joining worker %s", worker)
            worker.join()

        self._threadPool = []

        self._started = False

        self._logger.info("Stopped thread pool")
示例#41
0
class IPAllocatorDHCP(IPAllocator):
    def __init__(self,
                 assigned_ip_blocks: Set[ip_network],
                 ip_state_map: IpDescriptorMap,
                 dhcp_store: MutableMapping[str, DHCPDescriptor],
                 gw_info: UplinkGatewayInfo,
                 retry_limit: int = 300,
                 iface: str = "dhcp0"):
        """
        Allocate IP address for SID using DHCP server.
        SID is mapped to MAC address using function defined in mac.py
        then this mac address used in DHCP request to allocate new IP
        from DHCP server.
        This IP is also cached to improve performance in case of
        reallocation for same SID in short period of time.

        Args:
            assigned_ip_blocks: set of IP blocks, populated from DHCP.
            ip_state_map: maintains state of IP allocation to UE.
            dhcp_store: maintains DHCP transaction for each active MAC address
            gw_info: maintains uplink GW info
            retry_limit: maximum number of wait iterations per DHCP allocation
            iface: DHCP interface.
        """
        self._ip_state_map = ip_state_map  # {state=>{ip=>ip_desc}}
        self._assigned_ip_blocks = assigned_ip_blocks
        # Shared with the DHCP client; _alloc_ip_address_from_dhcp waits on it.
        self.dhcp_wait = Condition()
        self._dhcp_client = DHCPClient(dhcp_wait=self.dhcp_wait,
                                       dhcp_store=dhcp_store,
                                       gw_info=gw_info,
                                       iface=iface)
        # NOTE(review): the total wait is roughly
        # retry_limit * DEFAULT_DHCP_REQUEST_RETRY_DELAY; an earlier comment
        # claimed "two minutes" for the default -- confirm against that constant.
        self._retry_limit = retry_limit
        self._dhcp_client.run()

    def add_ip_block(self, ipblock: ip_network):
        """No-op: blocks are learned from DHCP replies, not added manually."""
        logging.warning("No need to allocate block for DHCP allocator: %s",
                        ipblock)

    def remove_ip_blocks(self,
                         *ipblocks: List[ip_network],
                         _force: bool = False) -> List[ip_network]:
        """No-op: logs the request and reports that nothing was removed."""
        logging.warning("trying to delete ipblock from DHCP allocator: %s",
                        ipblocks)
        return []

    def list_added_ip_blocks(self) -> List[ip_network]:
        """Return a deep copy of the IP blocks learned from DHCP so far."""
        return list(deepcopy(self._assigned_ip_blocks))

    def list_allocated_ips(self, ipblock: ip_network) -> List[ip_address]:
        """ List IP addresses allocated from a given IP block

        Args:
            ipblock (ipaddress.ip_network): ip network to list allocated
            addresses from, e.g. ipaddress.ip_network("10.0.0.0/24")

        Return:
            list of IP addresses (ipaddress.ip_address)

        """
        return [
            ip for ip in self._ip_state_map.list_ips(IPState.ALLOCATED)
            if ip in ipblock
        ]

    def alloc_ip_address(self, sid: str) -> IPDesc:
        """
        Assumption: one-to-one mappings between SID and IP.

        A descriptor already held by the DHCP client for the SID's MAC is
        reused; otherwise a new DHCP exchange is started.

        Args:
            sid (string): universal subscriber id

        Returns:
            IPDesc: descriptor of the allocated IP (state ALLOCATED, type DHCP)

        Raises:
            NoAvailableIPError: if run out of available IP addresses
        """
        mac = create_mac_from_sid(sid)
        LOG.debug("allocate IP for %s mac %s", sid, mac)

        dhcp_desc = self._dhcp_client.get_dhcp_desc(mac)
        LOG.debug("got IP from redis: %s", dhcp_desc)

        if dhcp_allocated_ip(dhcp_desc) is not True:
            dhcp_desc = self._alloc_ip_address_from_dhcp(mac)

        if dhcp_allocated_ip(dhcp_desc):
            ip_block = ip_network(dhcp_desc.subnet)
            ip_desc = IPDesc(ip=ip_address(dhcp_desc.ip),
                             state=IPState.ALLOCATED,
                             sid=sid,
                             ip_block=ip_block,
                             ip_type=IPType.DHCP)
            LOG.debug("Got IP after sending DHCP requests: %s", ip_desc)
            self._assigned_ip_blocks.add(ip_block)

            return ip_desc
        else:
            raise NoAvailableIPError("No available IP addresses From DHCP")

    def release_ip(self, ip_desc: IPDesc):
        """
        Release IP address, this involves following steps.
        1. send DHCP protocol packet to release the IP.
        2. update IP block list.
        3. update IP from ip-state.

        Args:
            ip_desc, release needs following info from IPDesc.
            sid: SID, used to get mac address.
            ip: IP assigned to this SID
            ip_block: IP block of the IP address.

        Returns: None
        """
        self._dhcp_client.release_ip_address(create_mac_from_sid(ip_desc.sid))
        # Remove the IP from free IP list, since DHCP is the
        # owner of this IP
        self._ip_state_map.remove_ip_from_state(ip_desc.ip, IPState.FREE)

        # Normalize ip_block once and use the same network object for both the
        # membership test and the removal. Previously only the removal was
        # normalized, and `addr in <str>` raises TypeError.
        ip_block_network = ip_network(ip_desc.ip_block)
        list_allocated_ips = self._ip_state_map.list_ips(IPState.ALLOCATED)
        if any(ipaddr in ip_block_network for ipaddr in list_allocated_ips):
            # another allocated IP still lives in this block; keep it
            return

        if ip_block_network in self._assigned_ip_blocks:
            self._assigned_ip_blocks.remove(ip_block_network)
        logging.debug("del: _assigned_ip_blocks %s ipblock %s",
                      self._assigned_ip_blocks, ip_desc.ip_block)

    def stop_dhcp_sniffer(self):
        """Stop the underlying DHCP client."""
        self._dhcp_client.stop()

    def _alloc_ip_address_from_dhcp(self, mac: MacAddress) -> DHCPDescriptor:
        """Request an IP for `mac` from DHCP and poll for the result.

        A DISCOVER is sent every DEFAULT_DHCP_REQUEST_RETRY_FREQUENCY
        iterations; each iteration waits on dhcp_wait for up to
        DEFAULT_DHCP_REQUEST_RETRY_DELAY and then re-reads the client's
        descriptor. Stops once an IP is allocated or retry_limit iterations
        have elapsed.

        Returns:
            the last descriptor read (may be None or not yet allocated).
        """
        retry_count = 0
        with self.dhcp_wait:
            dhcp_desc = None
            while (retry_count < self._retry_limit
                   and dhcp_allocated_ip(dhcp_desc) is not True):

                if retry_count % DEFAULT_DHCP_REQUEST_RETRY_FREQUENCY == 0:
                    self._dhcp_client.send_dhcp_packet(mac, DHCPState.DISCOVER)
                # presumably the DHCP client notifies dhcp_wait on a reply;
                # the timeout bounds the wait either way.
                self.dhcp_wait.wait(timeout=DEFAULT_DHCP_REQUEST_RETRY_DELAY)

                dhcp_desc = self._dhcp_client.get_dhcp_desc(mac)

                retry_count = retry_count + 1

            return dhcp_desc
示例#42
0
class TileAgent:
    """Provide map tiles for a single map description.

    Tiles are looked up in a memory cache, then in a disk cache, and can be
    downloaded from the map's URL template either synchronously or by worker
    threads coordinated by a download-monitor thread.
    """
    ST_IDLE = 0
    ST_RUN = 1
    ST_PAUSE = 2
    ST_CLOSING = 3

    # Tile status values stored in the memory cache. The low nibble holds the
    # lookup result; the high nibble marks an outstanding request (TILE_REQ)
    # or a failed one (TILE_REQ_FAILED).
    TILE_VALID = 0x00
    TILE_NOT_IN_MEM = 0x01
    TILE_NOT_IN_DISK = 0x02
    TILE_EXPIRE = 0x03
    TILE_REQ = 0x10
    TILE_REQ_FAILED = 0x20

    # properties from map_desc
    @property
    def map_id(self):
        return self.__map_desc.map_id

    @property
    def map_title(self):
        return self.__map_desc.map_title

    @property
    def level_min(self):
        return self.__map_desc.level_min

    @property
    def level_max(self):
        return self.__map_desc.level_max

    @property
    def url_template(self):
        return self.__map_desc.url_template

    @property
    def server_parts(self):
        return self.__map_desc.server_parts

    @property
    def invert_y(self):
        return self.__map_desc.invert_y

    @property
    def coord_sys(self):
        return self.__map_desc.coord_sys

    @property
    def lower_corner(self):
        return self.__map_desc.lower_corner

    @property
    def upper_corner(self):
        return self.__map_desc.upper_corner

    @property
    def expire_sec(self):
        return self.__map_desc.expire_sec

    @property
    def tile_format(self):
        return self.__map_desc.tile_format

    @property
    def state(self):
        return self.__state

    def __init__(self, map_desc, cache_dir, auto_start=False):
        self.__map_desc = map_desc.clone()

        self.__state = self.ST_IDLE

        # local cache (opened in start())
        self.__cache_dir = cache_dir
        self.__disk_cache = None

        # memory cache
        self.__mem_cache = MemoryCache(self.TILE_NOT_IN_MEM,
                                       is_concurrency=True)

        # download helpers: __download_cv protects the request queue and the
        # worker table, and wakes the download monitor; state changes after
        # start() are made under it as well.
        self.__MAX_WORKS = 3
        self.__download_lock = Lock()
        self.__download_cv = Condition(self.__download_lock)
        self.__workers = {}
        self.__req_queue = OrderedDict()
        self.__download_monitor = Thread(target=self.__runDownloadMonitor)

        if auto_start:
            self.start()

    def start(self):
        """Open the disk cache and start the download monitor thread."""
        # create cache dir for the map
        self.__disk_cache = DBDiskCache(self.__cache_dir, self.__map_desc,
                                        conf.DB_SCHEMA)
        self.__disk_cache.start()
        # start download thread
        self.__state = self.ST_RUN
        self.__download_monitor.start()

    def close(self):
        """Stop the download monitor and close the disk cache.

        Safe to call even if start() was never called. Download workers that
        are already running are not joined.
        """
        # notify download monitor to exit
        with self.__download_cv:
            self.__state = self.ST_CLOSING
            self.__download_cv.notify()
        # Thread.join() raises RuntimeError on a thread that was never started.
        if self.__download_monitor.ident is not None:
            self.__download_monitor.join()

        # close resources
        if self.__disk_cache is not None:
            self.__disk_cache.close()

    def pause(self):
        """Stop dispatching new downloads until resume() is called."""
        with self.__download_cv:
            if self.__state == self.ST_RUN:
                self.__state = self.ST_PAUSE
                logging.debug("[%s] Change status from run to pause" %
                              (self.map_id, ))
                self.__download_cv.notify()

    def resume(self):
        """Resume dispatching downloads after pause()."""
        with self.__download_cv:
            if self.__state == self.ST_PAUSE:
                self.__state = self.ST_RUN
                logging.debug("[%s] Change status from pasue to run" %
                              (self.map_id, ))
                self.__download_cv.notify()

    def isSupportedLevel(self, level):
        return self.level_min <= level and level <= self.level_max

    def getCachePath(self):
        return os.path.join(self.__cache_dir, self.map_id)

    def genTileId(self, level, x, y):
        return "%s-%d-%d-%d" % (self.map_id, level, x, y)

    @classmethod
    def flipY(cls, y, level):
        return (1 << level) - 1 - y

    def genTileUrl(self, level, x, y):
        """Fill the map's URL template for tile (level, x, y)."""
        if self.invert_y:
            y = self.flipY(y, level)

        url = self.url_template

        if self.server_parts:
            url = url.replace("{$serverpart}",
                              random.choice(self.server_parts))
        url = url.replace("{$x}", str(x))
        url = url.replace("{$y}", str(y))
        url = url.replace("{$z}", str(level))
        return url

    def __downloadTile(self, id, req):
        """Download one tile and store the result in the caches.

        Returns the decoded image, or None when the download or decoding
        failed or the agent is closing. On failure the tile is marked
        TILE_REQ_FAILED in the memory cache so __getTile can retry it later.
        """
        level, x, y, status, cb = req  # unpack the req
        # Status recorded on failure: keep the lookup result (low nibble) and
        # replace the request flag with TILE_REQ_FAILED.
        failed_status = self.TILE_REQ_FAILED | (status & 0x0F)

        tile_data = None
        # Bound before the try so the failure log below cannot hit an
        # UnboundLocalError when genTileUrl() itself raises.
        url = None
        try:
            url = self.genTileUrl(level, x, y)
            logging.info("[%s] DL %s" % (self.map_id, url))
            with urllib.request.urlopen(url, timeout=30) as response:
                tile_data = response.read()
            logging.info('[%s] DL %s [OK]' % (self.map_id, url))
        except Exception as ex:
            logging.warning('[%s] DL %s [FAILED][%s]' %
                            (self.map_id, url, str(ex)))

        # as failed, and not save to memory/disk
        if self.__state == self.ST_CLOSING:
            return None

        if tile_data is None:
            # save to memory
            self.__mem_cache.set(id, failed_status)
            return None

        # get tile_img, and save to memory
        try:
            tile_img = Image.open(BytesIO(tile_data))
        except Exception as ex:
            logging.error("[%s] Error to open tile data: %s" %
                          (self.map_id, str(ex)))
            # Without this the tile would stay flagged TILE_REQ, which
            # __getTile never re-requests.
            self.__mem_cache.set(id, failed_status)
            return None
        self.__mem_cache.set(id, self.TILE_VALID, tile_img)

        # save tile_data to disk
        try:
            self.__disk_cache.put(level, x, y, tile_data)
        except Exception as ex:
            logging.error("[%s] Error to save tile data: %s" %
                          (self.map_id, str(ex)))

        return tile_img

    # The thread to download
    def __runDownloadJob(self, id, req):
        """Worker body: download one tile, free the worker slot, run cb."""

        # do download
        tile_img = self.__downloadTile(id, req)

        # the download is done
        # (do this before thread exit, to ensure monitor is notified)
        with self.__download_cv:
            self.__workers.pop(id, None)
            self.__download_cv.notify()
            # premature done
            if self.__state == self.ST_CLOSING:
                return

        # invoke cb. cb may be blocking, so do this AFTER removing the thread
        # from __workers
        if tile_img is not None:
            level, x, y, status, cb = req  # unpack the req
            if cb is not None:
                tile_info = (self.map_id, level, x, y)
                try:
                    cb(tile_info)
                except Exception as ex:
                    logging.warning(
                        "[%s] Invoke cb of download tile error: %s" %
                        (self.map_id, str(ex)))

    def __hasMonitorWork(self):
        """Return True when the download monitor has something to act on.

        Must be called with __download_cv held.
        """
        if self.__state == self.ST_CLOSING:
            return True
        return (self.__state == self.ST_RUN and len(self.__req_queue) > 0
                and len(self.__workers) < self.__MAX_WORKS)

    # The thread to handle all download requests
    def __runDownloadMonitor(self):
        """Dispatch queued requests to worker threads until closing."""
        while True:
            with self.__download_cv:
                # Wait on a predicate instead of a bare wait(): a notify()
                # issued while this thread is not waiting (before it starts,
                # or between iterations) would otherwise be lost.
                while not self.__hasMonitorWork():
                    self.__download_cv.wait()
                    if self.__state == self.ST_PAUSE:
                        logging.debug("[%s] status(pause), continue to wait" %
                                      (self.map_id, ))

                if self.__state == self.ST_CLOSING:
                    logging.debug(
                        "[%s] status(closing), download monitor closing" %
                        (self.map_id, ))
                    break

                # Fill every free worker slot, not just one per wakeup.
                while (len(self.__req_queue) > 0
                       and len(self.__workers) < self.__MAX_WORKS):
                    tile_id, req = self.__req_queue.popitem()  # LIFO
                    if tile_id in self.__workers:
                        logging.warning(
                            "[%s] Opps! the req is DUP and in progress." %
                            (self.map_id, ))  # should not happen
                        continue
                    # Pass arguments explicitly; a lambda would late-bind the
                    # loop variables and could run with a later request.
                    worker = Thread(name=tile_id,
                                    target=self.__runDownloadJob,
                                    args=(tile_id, req))
                    self.__workers[tile_id] = worker
                    worker.start()

        # todo: interrupt urllib.request.openurl to stop download workers.
        with self.__download_cv:
            self.__state = self.ST_IDLE
            logging.debug("[%s] status(idle), download monitor closed" %
                          (self.map_id, ))

    def __requestTile(self, id, req):
        """Queue an async download unless it is already queued or running."""
        # check and add to req queue
        with self.__download_cv:
            if id in self.__req_queue:
                return
            if id in self.__workers:
                return
            # add the req
            self.__req_queue[id] = req
            self.__download_cv.notify()

    def __getTileFromDisk(self, level, x, y):
        """Return (image, timestamp) from the disk cache, or (None, None)."""
        try:
            data, ts = self.__disk_cache.get(level, x, y)
            if data is not None:
                img = Image.open(BytesIO(data))
                return img, ts
        except Exception as ex:
            logging.warning("[%s] Error to read tile data: %s" %
                            (self.map_id, str(ex)))
        return None, None

    def __getTile(self, level, x, y, req_type=None, cb=None):
        """Look up a tile in memory, then on disk, optionally requesting it.

        req_type: falsy (no download), "async" (queue a download and return
        what is available now) or anything else (download synchronously).

        Raises:
            ValueError: if level is outside [level_min, level_max].
        """
        # check level
        if level > self.level_max or level < self.level_min:
            raise ValueError("level is out of range")

        id = self.genTileId(level, x, y)
        img, status, ts = self.__mem_cache.get(id)  # READ FROM memory
        status_bak = status

        if status == self.TILE_VALID:
            return img

        if (status & 0xF0) == self.TILE_REQ:
            return img  # None or Expire

        if (status & 0xF0) == self.TILE_REQ_FAILED:
            if (time.time() - ts) < 60:  # todo: user to specify retry period
                return None
            status &= 0x0F  # remove req_failed status

        if status == self.TILE_NOT_IN_MEM:
            img, ts = self.__getTileFromDisk(level, x, y)  # READ FROM disk
            if img is None:
                status = self.TILE_NOT_IN_DISK
            elif ts and self.expire_sec and (time.time() -
                                             ts) > self.expire_sec:
                status = self.TILE_EXPIRE
            else:
                self.__mem_cache.set(id, self.TILE_VALID, img)
                return img

        # check status, should be 'not in disk' or 'expire'
        if status == self.TILE_NOT_IN_DISK:
            pass
        elif status == self.TILE_EXPIRE:
            logging.warning("[%s] Tile(%d,%d,%d) is expired" %
                            (self.map_id, level, x, y))
        else:
            logging.critical("[%s] Error: unexpected tile status: %d" %
                             (self.map_id, status))
            return None

        # NOTE(review): the status-only set() calls below do not store the
        # expired image read from disk, so later lookups of an expired tile
        # return None rather than the stale image -- confirm this is intended.
        # req or not
        if not req_type:
            if status != status_bak:
                self.__mem_cache.set(id, status)
            return img
        else:
            status |= self.TILE_REQ
            self.__mem_cache.set(id, status)
            if req_type == "async":
                self.__requestTile(id, (level, x, y, status, cb))
                return img
            else:  # sync
                return self.__downloadTile(id, (level, x, y, status, None))

    def __genMagnifyFakeTile(self, level, x, y, diff=1):
        """Build a stand-in tile by enlarging part of a lower-level tile."""
        side = to_pixel(1, 1)[0]
        for i in range(1, diff + 1):
            scale = 2**i
            img = self.__getTile(level - i, int(x / scale), int(y / scale))
            if img:
                step = int(side / scale)
                px = step * (x % scale)
                py = step * (y % scale)
                img = img.crop((px, py, px + step, py + step))
                img = img.resize((side, side))  # magnify
                return img
        return None

    def __genMinifyFakeTile(self, level, x, y, diff=1):
        """Build a stand-in tile by shrinking a mosaic of higher-level tiles."""
        bg = 'lightgray'
        side = to_pixel(1, 1)[0]

        for i in range(1, diff + 1):
            scale = 2**i
            img = Image.new("RGBA", (side * scale, side * scale), bg)
            has_tile = False
            # paste tiles
            for p in range(scale):
                for q in range(scale):
                    t = self.__getTile(level + i, x * scale + p, y * scale + q)
                    if t:
                        img.paste(t, (p * side, q * side))
                        has_tile = True
            # minify
            if has_tile:
                img = img.resize((side, side))
                return img
        return None

    # gen fake from lower/higher level
    # return None if not avaliable
    def __genFakeTile(self, level, x, y):
        # gen from lower level
        level_diff = min(level - self.level_min, 3)
        img = self.__genMagnifyFakeTile(level, x, y, level_diff)
        if img:
            return img

        # gen from upper level
        level_diff = min(self.level_max - level, 1)
        img = self.__genMinifyFakeTile(level, x, y, level_diff)
        if img:
            return img

        return None

    # @cb is only for req_type == "async" to nitify the tile is done,
    # which call cb(tile_info), tile_info = (map_id, level, x, y)
    def getTile(self, level, x, y, req_type, cb=None, allow_fake=True):
        """Return the tile image (with an is_fake attribute) or None."""
        img = self.__getTile(level, x, y, req_type, cb)
        if img is not None:
            img.is_fake = False
            return img

        if allow_fake:
            img = self.__genFakeTile(level, x, y)
            if img is not None:
                img.is_fake = True
                return img

        return None
示例#43
0
class ServiceRegistrationInfo(_BaseObject):
    """
    Service Registration instances are used to register and expose services onto a DXL fabric.

    DXL Services are exposed to the DXL fabric and are invoked in a fashion similar to RESTful web services.
    Communication between an invoking client and the DXL service is one-to-one (request/response).

    Each service is identified by the "topics" it responds to. Each of these "topics" can be thought of as
    a method that is being "invoked" on the service by the remote client.

    Multiple service "instances" can be registered with the DXL fabric that respond to the same "topics". When
    this occurs (unless explicitly overridden by the client) the fabric will select the particular instance
    to route the request to (by default round-robin). Multiple service instances can be used to increase
    scalability and fault-tolerance.

    The following demonstrates registering a service that responds to a single topic with the DXL fabric:

    .. code-block:: python

        from dxlclient.callbacks import RequestCallback
        from dxlclient.message import Response
        from dxlclient.service import ServiceRegistrationInfo

        class MyRequestCallback(RequestCallback):
            def on_request(self, request):
                # Extract information from request
                print(request.payload.decode())

                # Create the response message
                res = Response(request)

                # Populate the response payload
                res.payload = "pong".encode()

                # Send the response
                dxl_client.send_response(res)

        # Create service registration object
        info = ServiceRegistrationInfo(dxl_client, "/mycompany/myservice")

        # Add a topic for the service to respond to
        info.add_topic("/testservice/testrequesttopic", MyRequestCallback())

        # Register the service with the fabric (wait up to 10 seconds for registration to complete)
        dxl_client.register_service_sync(info, 10)

    **NOTE:** A service is only considered "active" if there are references to the
    :class:`ServiceRegistrationInfo` instance. If no references to the info object exist, it will be
    destructed, and the service will be automatically *unregistered* from the fabric.

    The following demonstrates a client that is invoking the service in the example above:

    .. code-block:: python

        from dxlclient.message import Request, Message

        # Create the request message
        req = Request("/testservice/testrequesttopic")

        # Populate the request payload
        req.payload = "ping".encode()

        # Send the request and wait for a response (synchronous)
        res = dxl_client.sync_request(req)

        # Extract information from the response (if an error did not occur)
        if res.message_type != Message.MESSAGE_TYPE_ERROR:
            print(res.payload.decode())
    """
    def __init__(self, client, service_type):
        """
        Constructor parameters:

        :param client: The :class:`dxlclient.client.DxlClient` instance that will expose this service
        :param service_type: A textual name for the service. For example, "/mycompany/myservice"
        :raises ValueError: If ``service_type`` is empty or ``None``
        """
        super(ServiceRegistrationInfo, self).__init__()

        if not service_type:
            raise ValueError("Undefined service name")

        # The service type or name prefix
        self._service_type = service_type
        # The unique service ID
        self._service_id = UuidGenerator.generate_id_as_string()

        # The map of registered topics and their associated sets of callbacks
        self._callbacks_by_topic = {}
        # The map of meta data associated with this service (name-value pairs)
        self._metadata = {}
        # List of destination tenants
        self._destination_tenant_guids = []

        # The Time-To-Live (TTL) of the service registration (default: 60 minutes)
        self._ttl = 60  # minutes
        # The minimum Time-To-Live (TTL) of the service registration (default: 10 minutes)
        self._ttl_lower_limit = 10

        # Internal client reference (cleared by _destroy())
        self._dxl_client = client

        # Guards the registration flags below and signals changes to them
        self._registration_sync = Condition()

        # Whether at least one registration has occurred
        self._registration_occurred = False

        # Whether at least one unregistration has occurred
        self._unregistration_occurred = False

        self._destroy_lock = RLock()
        self._destroyed = False

    def __del__(self):
        """destructor"""
        super(ServiceRegistrationInfo, self).__del__()
        # __init__ may have raised (e.g. empty service_type) before the destroy
        # lock was created; there is nothing to unregister in that case.
        if hasattr(self, "_destroy_lock"):
            self._destroy()

    def _destroy(self, unregister=True):
        """
        Destroys the service registration. Only the first call has any effect.

        :param unregister: Whether to unregister the service from the fabric
        """
        with self._destroy_lock:
            if self._destroyed:
                return
            if unregister and self._dxl_client:
                try:
                    self._dxl_client.unregister_service_async(self)
                except Exception:
                    # Currently ignoring this as it can occur due to the fact that we are
                    # attempting to unregister a service that was never registered
                    pass
            self._dxl_client = None
            self._destroyed = True

    @property
    def service_type(self):
        """
        A textual name for the service. For example, "/mycompany/myservice"
        """
        return self._service_type

    @property
    def service_id(self):
        """
        A unique identifier for the service instance (automatically generated when the :class:`ServiceRegistrationInfo`
        object is constructed)
        """
        return self._service_id

    @property
    def metadata(self):
        """
        A dictionary of name-value pairs that are sent as part of the service registration. Brokers provide
        a registry service that allows for registered services and their associated meta-information to be inspected.
        The metadata is typically used to include information such as the versions for products that are
        exposing DXL services, etc.
        """
        return self._metadata

    @metadata.setter
    def metadata(self, metadata):
        self._metadata = metadata

    @property
    def ttl(self):
        """
        The interval (in minutes) at which the client will automatically re-register the service with the
        DXL fabric (defaults to 60 minutes)
        """
        return self._ttl

    @ttl.setter
    def ttl(self, ttl):
        self._ttl = ttl

    @property
    def topics(self):
        """
        Returns a tuple containing the topics that the service responds to
        """
        return tuple(self._callbacks_by_topic.keys())

    def add_topic(self, topic, callback):
        """
        Registers a topic for the service to respond to along with the :class:`dxlclient.callbacks.RequestCallback`
        that will be invoked. Several callbacks may be registered for the same topic.

        :param topic:  The topic for the service to respond to
        :param callback: The :class:`dxlclient.callbacks.RequestCallback` that will be invoked when a
            :class:`dxlclient.message.Request` message is received
        """
        # setdefault creates the callback set on first use; an unhashable topic
        # now surfaces as its own TypeError rather than an UnboundLocalError.
        self._callbacks_by_topic.setdefault(topic, set()).add(callback)

    def add_topics(self, callbacks_by_topic):
        """
        Registers a set of topics for the service to respond to along with their associated
        :class:`dxlclient.callbacks.RequestCallback` instances as a dictionary

        :param callbacks_by_topic: Dictionary containing a set of topics for the service to respond to along with
            their associated :class:`dxlclient.callbacks.RequestCallback` instances
        :raises ValueError: If ``callbacks_by_topic`` is not a dict or is empty
        """
        if not isinstance(callbacks_by_topic, dict):
            raise ValueError("Channel and callback should be a dictionary")
        if not callbacks_by_topic:
            raise ValueError("Undefined channel")
        # .items() works on both Python 2 and 3 (iteritems() is Python 2 only)
        for channel, callback in callbacks_by_topic.items():
            self.add_topic(channel, callback)

    @property
    def destination_tenant_guids(self):
        """
        The set of tenant identifiers that the service will be available to. Setting this value will limit
        which tenants can invoke the service.
        """
        return self._destination_tenant_guids

    @destination_tenant_guids.setter
    def destination_tenant_guids(self, tenant_guids=None):
        if tenant_guids is None:
            tenant_guids = []
        self._destination_tenant_guids = tenant_guids

    def _wait_for_registration_notification(self, wait_time, is_register):
        """
        Waits up to ``wait_time`` seconds for a registration notification (register or unregister).

        :param wait_time: The amount of time to wait, in seconds.
        :param is_register: Whether we are waiting for a register or unregister notification
            (currently unused).
        :raises DxlException: If ``wait_time`` is not positive, i.e. the deadline has passed.
        :return: None.
        """
        with self._registration_sync:
            if wait_time > 0:
                self._registration_sync.wait(wait_time)
            else:
                raise DxlException(
                    "Timeout waiting for service related notification")

    def _wait_for_registration(self, timeout):
        """
        Waits for the service to be registered with the broker for the first time.

        :param timeout: The amount of time (in seconds) to wait for the registration to occur.
        :raises DxlException: If the registration does not occur within ``timeout``.
        :return: None.
        """
        with self._registration_sync:
            # Use un-truncated time so the deadline is accurate to sub-second precision.
            end_time = time.time() + timeout
            while not self._registration_occurred:
                self._wait_for_registration_notification(
                    end_time - time.time(), True)

    def _notify_registration_succeeded(self):
        """
        Invoked when the service has been successfully registered with a broker.

        :return: None.
        """
        with self._registration_sync:
            self._registration_occurred = True
            self._unregistration_occurred = False
            self._registration_sync.notify_all()

    def _wait_for_unregistration(self, timeout):
        """
        Waits for the service to be unregistered with the broker for the first time.

        :param timeout: The amount of time (in seconds) to wait for the unregistration to occur.
        :raises DxlException: If the unregistration does not occur within ``timeout``.
        :return: None.
        """
        end_time = time.time() + timeout
        with self._registration_sync:
            while not self._unregistration_occurred:
                self._wait_for_registration_notification(
                    end_time - time.time(), False)

    def _notify_unregistration_succeeded(self):
        """
        Invoked when the service has been successfully unregistered with a broker.

        :return: None.
        """
        with self._registration_sync:
            self._registration_occurred = False
            self._unregistration_occurred = True
            self._registration_sync.notify_all()
# Example #44
# 0
class SecondaryMonitor(Thread):
    """
    Monitor data from secondary port and send programs to robot
    """
    def __init__(self, host):
        """
        Connect to ``host`` on the secondary client port, read the version
        packet, start the monitoring thread and wait for the first data packet.

        :param host: Robot hostname or IP address.
        :raises Exception: Whatever ``wait()`` raises for the first packet; the
            monitor is closed before re-raising.
        """
        Thread.__init__(self)
        self.logger = logging.getLogger("ursecmon")
        self._parser = ParserUtils()
        self._dict = {}
        self._dictLock = Lock()
        self.host = host
        self.connect()
        self._get_version()
        self._prog_queue = []
        self._prog_queue_lock = Lock()
        self._dataqueue = bytes()
        self._trystop = False  # to stop thread
        self.running = False  # True when robot is on and listening
        self._dataEvent = Condition()
        self.lastpacket_timestamp = 0

        self.start()
        try:
            self.wait()  # make sure we got some data before someone calls us
        except Exception:
            self.close()
            raise

    def _get_version(self):
        """Read one packet and store the (major, minor) version it reports, or (0, 0)."""
        tmp = self._s_secondary.recv(1024)
        tmpdict = self._parser.parse(tmp)
        self._dict = tmpdict
        if "VersionMessage" in self._dict:
            self._parser.version = (
                self._dict["VersionMessage"]['majorVersion'],
                self._dict["VersionMessage"]['minorVersion'])
        else:
            self._parser.version = (0, 0)

    def connect(self):
        """Open the TCP connection to the robot's secondary port (2 s timeout)."""
        secondary_port = 30002  # Secondary client interface on Universal Robots
        self._s_secondary = socket.create_connection(
            (self.host, secondary_port), timeout=2.0)

    def send_program(self, prog):
        """
        send program to robot in URRobot format
        If another program is send while a program is running the first program is aborded.

        The program is stripped of surrounding whitespace, encoded to bytes if
        needed, newline-terminated and queued; the call then waits up to 2 s for
        the monitor thread to report that it has been sent.
        """
        prog = prog.strip()  # str.strip/bytes.strip return a new object
        self.logger.debug("Enqueueing program: %s", prog)
        if not isinstance(prog, bytes):
            prog = prog.encode()

        data = Program(prog + b"\n")
        # Hold the program's condition while enqueueing so the monitor thread's
        # notify cannot happen before we start waiting.
        with data.condition:
            with self._prog_queue_lock:
                self._prog_queue.append(data)
            data.condition.wait(timeout=2.0)
            self.logger.debug("program sent: %s", data)

    def _robot_is_ready(self, robot_mode_data):
        """Return True if ``robot_mode_data`` says the robot is in the expected
        mode, enabled, connected, powered on and not stopped."""
        # robotMode value required for ``running``: 7 from version 3.0, 0 before
        expected_mode = 7 if self._parser.version >= (3, 0) else 0
        if self._parser.version >= (5, 8):
            # Versions >= 5.8 report these flags under different field names.
            return (robot_mode_data["robotMode"] == expected_mode
                    and robot_mode_data["isRealRobotEnabled"] is True
                    and robot_mode_data["isEmergencyStopped"] is False
                    and robot_mode_data["isProtectiveStopped"] is False
                    and robot_mode_data["isRealRobotConnected"] is True
                    and robot_mode_data["isRobotPowerOn"] is True)
        return (robot_mode_data["robotMode"] == expected_mode
                and robot_mode_data["isRealRobotEnabled"] is True
                and robot_mode_data["isEmergencyStopped"] is False
                and robot_mode_data["isSecurityStopped"] is False
                and robot_mode_data["isRobotConnected"] is True
                and robot_mode_data["isPowerOnRobot"] is True)

    def run(self):
        """
        check program execution status in the secondary client data packet we get from the robot
        This interface uses only data from the secondary client interface (see UR doc)
        Only the last connected client is the primary client,
        so this is not guaranted and we cannot rely on information to the primary client.
        """
        while not self._trystop:
            # Send at most one queued program per packet and wake its sender.
            with self._prog_queue_lock:
                if self._prog_queue:
                    data = self._prog_queue.pop(0)
                    self._s_secondary.send(data.program)
                    with data.condition:
                        data.condition.notify_all()

            data = self._get_data()
            try:
                tmpdict = self._parser.parse(data)
                with self._dictLock:
                    self._dict = tmpdict
            except ParsingException as ex:
                self.logger.warning(
                    "Error parsing one packet from urrobot: %s", ex)
                continue

            if "RobotModeData" not in self._dict:
                self.logger.warning(
                    "Got a packet from robot without RobotModeData, strange ..."
                )
                continue

            self.lastpacket_timestamp = time.time()

            robot_mode_data = self._dict["RobotModeData"]
            # Update ``running`` in both directions for every firmware version.
            if self._robot_is_ready(robot_mode_data):
                self.running = True
            else:
                if self.running:
                    self.logger.error("Robot not running: %s", robot_mode_data)
                self.running = False
            with self._dataEvent:
                self._dataEvent.notify_all()

    def _get_data(self):
        """
        returns something that looks like a packet, nothing is guaranted

        On a socket timeout the connection is re-opened and TimeoutException is raised.
        """
        while True:
            ans = self._parser.find_first_packet(self._dataqueue[:])
            if ans:
                self._dataqueue = ans[1]
                return ans[0]
            try:
                tmp = self._s_secondary.recv(1024)
            except socket.timeout as e:
                self._s_secondary.close()
                time.sleep(1.0)
                self.connect()
                raise TimeoutException(
                    "Did not receive a valid data packet from robot in {}".
                    format(e))
            self._dataqueue += tmp

    def wait(self, timeout=2.0):
        """
        wait for next data packet from robot

        :raises TimeoutException: If no new packet arrived within ``timeout`` seconds.
        :raises ProtectiveStopException: If the latest packet reports a protective stop.
        """
        tstamp = self.lastpacket_timestamp
        with self._dataEvent:
            # Condition.wait only returns (True/False); it never raises socket.timeout.
            self._dataEvent.wait(timeout)
            if tstamp == self.lastpacket_timestamp:
                raise TimeoutException(
                    "Did not receive a valid data packet from robot in {}".
                    format(timeout))
            if self._parser.version >= (5, 8):
                isProtectiveStopped = self._dict["RobotModeData"][
                    "isProtectiveStopped"]
            else:
                isProtectiveStopped = self._dict["RobotModeData"][
                    "isSecurityStopped"]
            if isProtectiveStopped is True:
                raise ProtectiveStopException("Protective stopped")

    def get_cartesian_info(self, wait=False):
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict.get("CartesianInfo")

    def get_all_data(self, wait=False):
        """
        return last data obtained from robot in dictionnary format
        """
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict.copy()

    def get_joint_data(self, wait=False):
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict.get("JointData")

    def get_digital_out(self, nb, wait=False):
        """Return 1 if digital output ``nb`` is set, else 0."""
        if wait:
            self.wait()
        with self._dictLock:
            output = self._dict["MasterBoardData"]["digitalOutputBits"]
        return 1 if output & (1 << nb) else 0

    def get_digital_out_bits(self, wait=False):
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict["MasterBoardData"]["digitalOutputBits"]

    def get_digital_in(self, nb, wait=False):
        """Return 1 if digital input ``nb`` is set, else 0."""
        if wait:
            self.wait()
        with self._dictLock:
            output = self._dict["MasterBoardData"]["digitalInputBits"]
        return 1 if output & (1 << nb) else 0

    def get_digital_in_bits(self, wait=False):
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict["MasterBoardData"]["digitalInputBits"]

    def get_analog_in(self, nb, wait=False):
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict["MasterBoardData"]["analogInput" + str(nb)]

    def get_analog_inputs(self, wait=False):
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict["MasterBoardData"]["analogInput0"], self._dict[
                "MasterBoardData"]["analogInput1"]

    def is_program_running(self, wait=False):
        """
        return True if robot is executing a program
        Rmq: The refresh rate is only 10Hz so the information may be outdated
        """
        if wait:
            self.wait()
        with self._dictLock:
            return self._dict["RobotModeData"]["isProgramRunning"]

    def is_protective_stopped(self, wait=False):
        """Return the protective-stop flag from the latest RobotModeData."""
        if wait:
            self.wait()
        with self._dictLock:
            # The data lives on this object; there is no ``secmon`` attribute here.
            if self._parser.version >= (5, 8):
                return self._dict["RobotModeData"]["isProtectiveStopped"]
            return self._dict["RobotModeData"]["isSecurityStopped"]

    def close(self):
        """Stop the monitor thread, wait for it to exit and close the socket."""
        self._trystop = True
        self.join()
        # with self._dataEvent: #wake up any thread that may be waiting for data before we close. Should we do that?
        # self._dataEvent.notify_all()
        if self._s_secondary:
            with self._prog_queue_lock:
                self._s_secondary.close()
# Example #45
# 0
# File: ssdp.py  Project: xphillyx/sabnzbd
class SSDP(Thread):
    """Background thread that announces a UPnP root device via SSDP.

    Every ``ssdp_broadcast_interval`` seconds (keyword argument, default 15) an
    ``NTS: ssdp:alive`` NOTIFY message is multicast to 239.255.255.250:1900, and
    :meth:`serve_xml` returns the ``description.xml`` document that the
    announced ``LOCATION`` refers to.
    """

    def __init__(self, host, server_name, url, description, manufacturer,
                 manufacturer_url, model, **kwargs):
        """
        :param host: LAN IP address of this machine (also used to derive the UUID)
        :param server_name: ``SERVER`` header value and friendly name
        :param url: base URL; ``{url}/description.xml`` is announced as LOCATION
        :param description: text of the second ``modelDescription`` element
        :param manufacturer: ``manufacturer`` element
        :param manufacturer_url: ``manufacturerURL`` and ``modelURL`` elements
        :param model: ``modelName`` and first ``modelDescription`` element
        :param kwargs: optional ``ssdp_broadcast_interval`` in seconds (default 15)
        """
        self.__host = host  # Note: this is the LAN IP address!
        self.__server_name = server_name
        self.__url = url
        self.__description = description
        self.__manufacturer = manufacturer
        self.__manufacturer_url = manufacturer_url
        self.__model = model
        self.__ssdp_broadcast_interval = kwargs.get(
            "ssdp_broadcast_interval", 15)  # optional, default 15 seconds

        self.__myhostname = socket.gethostname()
        # a steady uuid: stays the same as long as hostname and ip address stay the same:
        self.__uuid = uuid.uuid3(uuid.NAMESPACE_DNS,
                                 self.__myhostname + self.__host)

        # Create the SSDP broadcast message
        self.__mySSDPbroadcast = f"""NOTIFY * HTTP/1.1
HOST: 239.255.255.250:1900
CACHE-CONTROL: max-age=60
LOCATION: {self.__url}/description.xml
SERVER: {self.__server_name}
NT: upnp:rootdevice
USN: uuid:{self.__uuid}::upnp:rootdevice
NTS: ssdp:alive
OPT: "http://schemas.upnp.org/upnp/1/0/"; ns=01

"""
        # SSDP is HTTP-over-UDP, so lines must be CRLF-terminated.
        self.__mySSDPbroadcast = self.__mySSDPbroadcast.replace(
            "\n", "\r\n").encode("utf-8")

        # Create the XML info (description.xml)
        self.__myxml = f"""<?xml version="1.0" encoding="UTF-8" ?>
<root xmlns="urn:schemas-upnp-org:device-1-0">
<specVersion>
<major>1</major>
<minor>0</minor>
</specVersion>
<URLBase>{self.__url}</URLBase>
<device>
<deviceType>urn:schemas-upnp-org:device:Basic:1</deviceType>
<friendlyName>{self.__server_name} ({self.__myhostname})</friendlyName>
<manufacturer>{self.__manufacturer}</manufacturer>
<manufacturerURL>{self.__manufacturer_url}</manufacturerURL>
<modelDescription>{self.__model} </modelDescription>
<modelName>{self.__model}</modelName>
<modelNumber> </modelNumber>
<modelDescription>{self.__description}</modelDescription>
<modelURL>{self.__manufacturer_url}</modelURL>
<UDN>uuid:{self.__uuid}</UDN>
<presentationURL>{self.__url}</presentationURL>
</device>
</root>"""

        self.__stop = False
        self.__condition = Condition(Lock())
        super().__init__()

    def stop(self):
        """Ask the broadcast loop to exit and wake it if it is sleeping."""
        logging.info("Stopping SSDP")
        # Set the flag under the same lock run() checks it with, so the
        # request cannot slip in between run()'s check and its wait().
        with self.__condition:
            self.__stop = True
            self.__condition.notify()

    def run(self):
        """Broadcast the SSDP announcement until :meth:`stop` is called."""
        logging.info("Serving SSDP on %s as %s", self.__host,
                     self.__server_name)

        # the standard multicast settings for SSDP:
        MCAST_GRP = "239.255.255.250"
        MCAST_PORT = 1900
        MULTICAST_TTL = 2

        while not self.__stop:
            # Create socket, send the broadcast, and close the socket again
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                                   socket.IPPROTO_UDP) as sock:
                    sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL,
                                    MULTICAST_TTL)
                    sock.sendto(self.__mySSDPbroadcast,
                                (MCAST_GRP, MCAST_PORT))
            except OSError as err:
                # probably no network; try again on the next interval
                logging.debug("SSDP broadcast failed: %s", err)

            # Wait until awoken or timeout is up. Re-check the flag while holding
            # the lock: a stop() issued during the broadcast above notified
            # nobody, and waiting anyway would delay shutdown by a full interval.
            with self.__condition:
                if not self.__stop:
                    self.__condition.wait(self.__ssdp_broadcast_interval)

    def serve_xml(self):
        """Returns an XML-structure based on the information being
        served by this service, returns nothing if not running"""
        if self.__stop:
            return None
        return self.__myxml
# Example #46
# 0
    def Acquire(self, timeout=2):
        """Acquire ``self.lock``, waiting up to ``timeout`` seconds.

        Fast path: when no hand-off is pending (``self.cond is None``) a
        non-blocking acquire of ``self.lock`` is attempted. If that fails while
        ``self.locks == 0``, the lock is treated as stale and replaced with a new
        ``self.TLock()``.

        Slow path: a private ``Condition`` (as part of a ``(cond, thread_name,
        start_time)`` tuple) is appended to ``self._waitingList`` and the caller
        polls, sleeping at most 1 s at a time, until ``self.cond`` is that
        condition or ``None`` and ``self.lock`` can be acquired, or until the
        deadline passes.

        :param timeout: Maximum number of seconds to wait.
        :return: Whatever ``self._GotLock()`` / ``self._GotLock(tup)`` returns.
        :raises Exception: If a freshly recreated lock cannot be acquired; on
            timeout, with a diagnostic message (also printed and stored in
            ``self.e``); or a bare ``Exception()`` if ``self.e`` is already set
            by an earlier timeout.
        """
        #print("Locks : " + str(self.locks) + ". Trying to acquire lock")# at\n" + ''.join(traceback.format_stack()))
        # NOTE: self._lock must be re-entrant: it is acquired again by the
        # nested ``with self._lock`` blocks below while already held here.
        with self._lock:
            # Fast path: no hand-off pending, try to take the lock directly.
            if self.cond is None:
                if self.lock.acquire(False):
                    return self._GotLock()
                elif self.locks == 0:
                    # Bookkeeping says nobody holds the lock, yet it cannot be
                    # acquired: replace it with a fresh one.
                    print(
                        "LOCKS SHOULDVE BEEN UNLOCKED BUT CANT BE ACQUIRED. RECREATING LOCK"
                    )
                    self.lock = self.TLock()
                    if not self.lock.acquire(False):
                        raise Exception("What the f**k")
                    return self._GotLock()
            # NOTE(review): with timeout <= 0 this is a *blocking* acquire made
            # while self._lock is still held, so every other Acquire() call
            # blocks on self._lock meanwhile; if the release path also needs
            # self._lock this deadlocks — confirm.
            if timeout <= 0 and self.cond is None and self.lock.acquire():
                return self._GotLock()
            # Slow path: register a private condition and poll until it is our
            # turn and the lock is free, or the deadline passes.
            cond = Condition(RLock())
            start = now = time()
            end = now + timeout
            tup = (cond, currentThread().name, start)
            with cond:
                with self._lock:
                    name = currentThread().name
                    #print("WAITING FOR CONDITION %s" % name)
                    self._waitingList.append(tup)
                while now < end:
                    # Our turn: either the lock was handed to our condition or
                    # no hand-off is pending, and the lock is actually free.
                    if (self.cond == cond
                            or self.cond is None) and self.lock.acquire(False):
                        with self._lock:
                            self.cond = None
                            #print("GOT LOCK %s %g" % (name, (now-start)))
                            return self._GotLock(tup)
                    else:
                        # Our entry was removed from the wait list (presumably by
                        # another thread attempting a hand-off); re-register.
                        if tup not in self._waitingList:
                            if now < end:
                                with self._lock:
                                    name = currentThread().name
                                    #print("WAITING AGAIN FOR CONDITION %s %d" % (name, self.locks))
                                    self._waitingList.append(tup)

                    # Drop the outer hold on self._lock while sleeping so other
                    # threads can proceed; take it back before re-checking.
                    #cond.acquire()
                    self._lock.release()
                    waitTime = min(end - now, 1)
                    try:
                        cond.wait(waitTime)
                    except Exception:
                        # NOTE(review): cond is already held via ``with cond``
                        # above, so this retry path looks unreachable.
                        cond.acquire()
                        cond.wait(waitTime)
                    self._lock.acquire()
                    now = time()

            # Deadline reached.
            with self._lock:
                # NOTE(review): ``cond`` is never None here (created above).
                # NOTE(review): success is reported without calling
                # self.lock.acquire(); confirm the hand-off really leaves
                # self.lock held on our behalf.
                if self.cond == cond or cond is None:
                    with self._lock:
                        self.cond = None
                        return self._GotLock(tup)
                elif tup in self._waitingList:
                    self._waitingList.remove(tup)
            if self.cond:
                selfcond = str(id(self.cond))
            else:
                selfcond = "None"
            # An earlier call already timed out: fail without rebuilding the report.
            if self.e:
                raise Exception()
            # Build a timeout report: this thread's stack, the recorded
            # self.stackTraces, and the current stack of self.owner.
            s0 = "%s %s %d Lock acquiring timeout (%gs) Locks remain : %d %s " % (
                currentThread().name, str(
                    id(cond)), self.id, now - start, self.locks, selfcond)
            s = s0 + "at\n" + FormatStack() + '\n'
            i = 0
            for trace in self.stackTraces:
                s = s + ("[STACKTRACE:%d]\n%s" % (i, trace))
                i += 1
            s = s + '\nCurrent execution:\n' + FormatThreadStack(self.owner)
            print(s)
            self.e = s0
            raise Exception(s0)
# Example #47
# 0
class Q(ABC):
    """Abstract work queue processed by a pool of daemon worker threads.

    Subclasses implement three hooks:

    * ``process_item(item)`` -- handle one item taken from the queue.
    * ``empty_queue_for_lock()`` -- called by a worker whenever it finds the
      queue empty. An implementation can call :meth:`apply_lock` to sleep
      until new work arrives; if it returns immediately the worker polls the
      queue in a tight loop.
    * ``process_status()`` -- called repeatedly by the optional status
      thread (enabled with ``thread_status=True``).

    Args:
        maxsize_queue: capacity of the underlying ``queue.Queue``.
        numbers_threads: number of worker threads started by
            :meth:`run_queue`.
        thread_status: when true, :meth:`run_queue` also starts a thread
            that loops over :meth:`process_status`.
    """

    def __init__(self,
                 maxsize_queue=1000,
                 numbers_threads=1,
                 thread_status=False):
        super().__init__()
        self.__numbers_threads = numbers_threads
        self.__thread_status = thread_status
        self.__q = queue.Queue(maxsize=maxsize_queue)
        self.condition = Condition(Lock())
        # True while at least one thread is blocked in apply_lock().
        self.waiting = False
        self.started = False
        self.threads = []
        # Number of threads currently blocked in apply_lock().
        self.__waiters = 0
        # Incremented by every apply_unlock(); lets a sleeper in apply_lock()
        # notice a wake-up request even if the queue was drained again.
        self.__unlock_generation = 0

    @abstractmethod
    def process_item(self, item) -> None:
        """Handle one item taken from the queue."""
        pass

    @abstractmethod
    def empty_queue_for_lock(self) -> None:
        """Called by a worker when the queue is empty (may call apply_lock)."""
        pass

    @abstractmethod
    def process_status(self) -> None:
        """Called in a loop by the status thread, if enabled."""
        pass

    def apply_lock(self):
        """Block the calling thread until there is something to do.

        Returns once an item is in the queue, once :meth:`apply_unlock` has
        been called after this call started, or once the queue is stopped.
        The queue is inspected while holding ``self.condition``, so an item
        put just before the wait cannot be missed. Several threads may wait
        at the same time.
        """
        with self.condition:
            generation = self.__unlock_generation
            self.__waiters += 1
            self.waiting = True
            try:
                # Loop guards against spurious wake-ups.
                while (self.started and self.__q.empty()
                       and generation == self.__unlock_generation):
                    self.condition.wait()
            finally:
                self.__waiters -= 1
                self.waiting = self.__waiters > 0

    def apply_unlock(self):
        """Wake every thread blocked in :meth:`apply_lock`."""
        # Always notify under the lock: checking ``waiting`` without the lock
        # raced with a worker that was just about to wait.
        with self.condition:
            self.__unlock_generation += 1
            self.condition.notify_all()

    @staticmethod
    def __log_failure(hook_name):
        """Log the exception currently being handled, raised by *hook_name*."""
        import logging
        logging.getLogger(__name__).exception(
            "Unhandled exception in %s()", hook_name)

    def __worker(self):
        """Worker loop: drain the queue until :meth:`stop_queue` is called."""
        while self.started:
            try:
                # get_nowait() instead of empty()+get(): another worker may
                # take the last item in between, and a blocking get() would
                # then hang this thread forever.
                item = self.__q.get_nowait()
            except queue.Empty:
                try:
                    self.empty_queue_for_lock()
                except Exception:
                    self.__log_failure('empty_queue_for_lock')
                continue
            try:
                self.process_item(item)
            except Exception:
                # Keep the worker alive, but don't hide the failure.
                self.__log_failure('process_item')
            finally:
                # Always balance the get(); otherwise join_queue() would hang
                # after the first failing item.
                self.__q.task_done()

    def __worker__status(self):
        """Status loop: call process_status() until the queue is stopped."""
        while self.started:
            try:
                self.process_status()
            except Exception:
                self.__log_failure('process_status')

    def q_values(self):
        """Return the underlying ``queue.Queue``."""
        return self.__q

    def put_nowait_item(self, item: FrameDTO = None) -> None:
        """Enqueue *item* without blocking (raises ``queue.Full`` if full)
        and wake waiting workers."""
        self.__q.put_nowait(item)
        self.apply_unlock()

    def put_item(self,
                 item: FrameDTO = None,
                 block=True,
                 timeout=None) -> None:
        """Enqueue *item* (``queue.Queue.put`` semantics) and wake waiting
        workers."""
        self.__q.put(item, block, timeout)
        self.apply_unlock()

    def run_queue(self) -> None:
        """Start the worker threads (and the status thread if enabled).

        Does nothing if the queue is already running. All threads are
        daemons and are recorded in ``self.threads``.
        """
        if self.started:
            return
        self.started = True
        self.threads = []
        for _ in range(self.__numbers_threads):
            thr = Thread(target=self.__worker, daemon=True)
            thr.start()
            self.threads.append(thr)

        if self.__thread_status:
            thr = Thread(target=self.__worker__status, daemon=True)
            thr.start()
            self.threads.append(thr)

    def join_queue(self) -> None:
        """Block until every enqueued item has been processed."""
        self.__q.join()

    def stop_queue(self) -> None:
        """Ask all threads to exit after their current iteration.

        Threads blocked in :meth:`apply_lock` are woken so they can see the
        stop request. The threads are not joined here.
        """
        with self.condition:
            self.started = False
            self.condition.notify_all()
示例#48
0
    class Application:
        def __init__(self):
            """Create the main window, initialise state and start the loop.

            Builds the Tk root from ``CONFIG['win']`` and registers fonts and
            themes from ``STYLES``. Resets the game/RL state, then creates the
            feedback highlighter, the emotion analyzer and the camera capture.
            Finally it starts the object returned by :meth:`mainloop`.

            NOTE(review): no widgets are created here; ``init_gui``,
            ``apply_theme`` and ``init_listeners`` are presumably invoked
            elsewhere (e.g. from ``mainloop``) -- confirm.
            """

            root = tk.Tk()
            root.title(CONFIG['win']['title'])
            root.geometry(CONFIG['win']['geometry'])
            root.resizable(*CONFIG['win']['resizable'])
            # Route the window's close button to self.destructor.
            root.protocol('WM_DELETE_WINDOW', self.destructor)
            self.widgets = {}
            self.widgets['root'] = root

            # gui
            # Fonts need the Tk root above to exist, so this runs after it.
            self.fonts = None
            self.themes = None
            self.init_fonts()
            self.init_themes()
            self.curr_theme = CONFIG['win']['default_theme']

            # flags
            # Toggled by on_flow_button_clicked; also gates the reset button.
            self.stopped = True
            # Read by update_flow_button to decide whether stopping is allowed.
            self.stoppable = False
            self.exit = False
            # Enables the feedback scale/button and the indicator colour.
            self.feedback_required = False
            self.feedback_provided = False
            self.feedback_frame = False

            # status
            self.rl_session = None
            self.env = None
            self.agent = None
            self.feedback_id = None
            # Elapsed seconds shown by update_timer.
            self.time_secs = 0
            self.last_time_secs = None
            self.evaluation = 0
            # Counter shown by update_attempts.
            self.attempts = 0
            # User-selected secret code: one slot per step, None = not chosen.
            self.secret = np.full(CONFIG['rl']['code_len'], None)

            # camera
            self.feedback_highlighter = FeedbackHighlighter(
                CONFIG['highlighter']['fps'], CONFIG['highlighter']['res'],
                CONFIG['highlighter']['format'],
                CONFIG['highlighter']['duration'],
                CONFIG['highlighter']['video_path'])
            self.emotion_analyzer = EmotionAnalyzer(
                CONFIG['analyzer']['docker_image_repository'],
                CONFIG['analyzer']['docker_image_tag'],
                CONFIG['analyzer']['video_path'],
                CONFIG['analyzer']['csv_path'])
            self.vcap = self.init_camera()

            # mainloop
            # Held while toggling `stopped` in on_flow_button_clicked;
            # presumably the main-loop thread waits on it -- confirm.
            self.mainloop_cv = Condition()
            self.stoppable_mutex = Lock()
            # mainloop() returns an object with start() (presumably a
            # threading.Thread); it is started last, once all state exists.
            self.mainloop_thread = self.mainloop()
            self.mainloop_thread.start()

        # customtk--------------------------------------------------

        def custom_label(self, master, x, y, height, width, *args, **kwargs):
            """Place a Label inside a fixed-size frame at (x, y) in *master*.

            The wrapping frame has geometry propagation switched off, so it
            keeps the requested ``width`` x ``height`` instead of shrinking
            to its content. The label is packed to fill that frame. Extra
            arguments are forwarded to ``tk.Label``.

            Returns the Label, not the wrapping frame.
            """
            holder = tk.Frame(master, height=height, width=width)
            holder.place(x=x, y=y)
            holder.pack_propagate(0)
            result = tk.Label(holder, *args, **kwargs)
            result.pack(expand=1, fill=tk.BOTH)
            return result

        def custom_button(self, master, x, y, height, width, *args, **kwargs):
            """Place a Button inside a fixed-size frame at (x, y) in *master*.

            The wrapping frame keeps the requested ``width`` x ``height``
            because geometry propagation is off. The button is packed to fill
            it. Extra arguments are forwarded to ``tk.Button``.

            Returns the Button, not the wrapping frame.
            """
            holder = tk.Frame(master, height=height, width=width)
            holder.place(x=x, y=y)
            holder.pack_propagate(0)
            result = tk.Button(holder, *args, **kwargs)
            result.pack(expand=1, fill=tk.BOTH)
            return result

        def custom_option_menu(self, master, x, y, height, width, value,
                               values, *args, **kwargs):
            """Place an OptionMenu inside a fixed-size frame at (x, y).

            The arguments are forwarded as
            ``tk.OptionMenu(frame, value, values, *args, **kwargs)``. Despite
            their names, *value* is the control variable (``init_gui`` passes
            a ``StringVar``), *values* is the first choice and *args* holds
            the remaining choices.

            Returns the OptionMenu, not the wrapping frame.
            """
            holder = tk.Frame(master, height=height, width=width)
            holder.place(x=x, y=y)
            holder.pack_propagate(0)
            menu = tk.OptionMenu(holder, value, values, *args, **kwargs)
            menu.pack(expand=1, fill=tk.BOTH)
            return menu

        # init------------------------------------------------------

        def init_fonts(self):
            """Create one named Tk font per entry of ``STYLES['fonts']``.

            Each font is registered under its ``font_name`` so other code can
            look it up by name (``apply_theme`` uses ``exists=True``). The
            Font objects are kept in ``self.fonts`` so they stay referenced.
            """
            self.fonts = []
            self.fonts.extend(
                tkFont.Font(name=spec['font_name'],
                            family=spec['font'],
                            size=spec['font_size'],
                            weight=spec['font_weight'])
                for spec in STYLES['fonts'])

        def init_themes(self):
            """Index ``STYLES['themes']`` by theme name.

            ``self.themes`` maps each theme's ``name`` to its ``widgets``
            styling dict.
            """
            self.themes = {}
            self.themes.update(
                (entry['name'], entry['widgets']) for entry in STYLES['themes'])

        def init_gui(self):
            """Create every widget of the main window and register it in
            ``self.widgets``.

            Layout, as (x, y) positions in pixels inside the root window:

            * video preview panel at (40, 40), 640x510;
            * attempts and timer panels at (720, 40) and (960, 40), 200x110;
            * code selector at (720, 190), 440x360, holding a
              ``code_len`` x ``no_actions`` grid of 78x78 buttons;
            * feedback evaluation panel (scale + button) at (40, 590);
            * feedback indicator and agent code panels at (720, 590) and
              (960, 590), 200x110;
            * flow/reset buttons and the theme selector along y=740.

            Widgets are created unstyled. ``apply_theme`` supplies colours,
            fonts and texts, and most click callbacks are attached by
            ``init_listeners``.
            """
            root = self.widgets['root']

            def titled_frame(prefix, x, y, width, height):
                # Fixed-size panel on the root with a 30px title label across
                # its top; registers '<prefix>_frame' and '<prefix>_title'.
                panel = tk.Frame(master=root, width=width, height=height)
                panel.place(x=x, y=y)
                self.widgets[prefix + '_frame'] = panel
                self.widgets[prefix + '_title'] = self.custom_label(
                    panel, 0, 0, 30, width)
                return panel

            def text_panel(prefix, x, y, width, height):
                # titled_frame plus a '<prefix>_content' label filling the
                # area below the title.
                panel = titled_frame(prefix, x, y, width, height)
                self.widgets[prefix + '_content'] = self.custom_label(
                    panel, 0, 30, height - 30, width)

            text_panel('video_preview', 40, 40, 640, 510)
            text_panel('attempts', 720, 40, 200, 110)
            text_panel('timer', 960, 40, 200, 110)

            # Code selector: one row per code step, one column per action.
            selector = titled_frame('code_selector', 720, 190, 440, 360)
            grid = self.custom_label(selector, 34, 54, 283, 373)
            code_len = CONFIG['rl']['code_len']
            no_actions = CONFIG['rl']['no_actions']
            buttons = np.empty((code_len, no_actions), dtype=object)
            for step in range(code_len):
                for action in range(no_actions):
                    # No right-hand gap after the last column.
                    right_gap = 0 if action == no_actions - 1 else 20
                    cell = tk.Frame(grid, height=78, width=78)
                    cell.pack_propagate(0)
                    cell.grid(row=step,
                              column=action,
                              padx=(0, right_gap),
                              pady=(0, 24))
                    buttons[step][action] = tk.Button(
                        cell, text=str(action), command=None)
                    buttons[step][action].pack(fill=tk.BOTH, expand=1)
            self.widgets['code_selector_content'] = grid
            self.widgets['code_selector_buttons'] = buttons

            # Feedback evaluation: a slider plus a confirm button.
            evaluation = titled_frame('feedback_evaluation', 40, 590, 640, 110)
            scale = tk.Scale(evaluation,
                             from_=CONFIG['rl']['min_evaluation'],
                             to=CONFIG['rl']['max_evaluation'],
                             length=480,
                             resolution=0.1)
            scale.place(x=40, y=40)
            self.widgets['feedback_evaluation_scale'] = scale
            self.widgets['feedback_evaluation_button'] = self.custom_button(
                evaluation, 550, 42, 50, 50)

            text_panel('feedback_indicator', 720, 590, 200, 110)
            text_panel('code', 960, 590, 200, 110)

            # Control buttons.
            self.widgets['flow_button'] = self.custom_button(
                root, 40, 740, 40, 200)
            self.widgets['reset_button'] = self.custom_button(
                root, 280, 740, 40, 200)

            # Theme selector, preset to the current theme.
            theme = tk.StringVar(root)
            theme.set(self.curr_theme)
            self.widgets['theme'] = theme
            self.widgets['theme_selector'] = self.custom_option_menu(
                root,
                960,
                740,
                40,
                200,
                theme,
                *self.themes.keys(),
                command=self.on_theme_changed)

        def init_camera(self):
            """Return the first camera (indices 0, 1, 2) that opens.

            If none of them opens, an unopened ``cv2.VideoCapture()`` is
            returned instead.
            """
            for index in range(3):
                capture = cv2.VideoCapture(index)
                if capture is not None and capture.isOpened():
                    break
            else:
                capture = cv2.VideoCapture()
            return capture

        def init_listeners(self):
            """Attach click callbacks to the interactive widgets.

            Each code-selector button gets its own listener from
            ``on_code_selector_button_clicked(step, action)``. The feedback
            evaluation, reset and flow buttons get their ``on_*_clicked``
            handlers.
            """
            grid = self.widgets['code_selector_buttons']
            for step in range(CONFIG['rl']['code_len']):
                for action in range(CONFIG['rl']['no_actions']):
                    listener = self.on_code_selector_button_clicked(
                        step, action)
                    grid[step][action].configure(command=listener)

            handlers = (
                ('feedback_evaluation_button',
                 self.on_feedback_evaluation_button_clicked),
                ('reset_button', self.on_reset_button_clicked),
                ('flow_button', self.on_flow_button_clicked),
            )
            for key, handler in handlers:
                self.widgets[key].configure(command=handler)

        def init_rl_session(self):
            """Return a fresh bookkeeping record for one RL session.

            The id is ``CONFIG['rl']['session_prefix']`` followed by the
            current time in milliseconds. ``config`` captures the sorted
            secret plus the environment/agent parameters, every ``result``
            field starts as None, and ``feedback`` starts empty.
            """
            session_id = CONFIG['rl']['session_prefix'] + str(
                int(time.time() * 1000))
            agent = self.agent
            config = {
                'secret': sorted(self.secret),
                'no_pegs': self.env.action_space.n,
                'code_len': len(self.env.secret),
                'alpha': agent.alpha,
                'gamma': agent.gamma,
                'epsilon': agent.epsilon,
                'beta': agent.beta,
                'exploration_mode': agent.exploration_mode,
                'epsilon_decay': agent.epsilon_decay,
                'epsilon_low': agent.epsilon_low,
            }
            result = dict.fromkeys(
                ('guessed', 'optimal', 'qmatrix', 'attempts', 'time'))
            return {
                'session_id': session_id,
                'config': config,
                'result': result,
                'feedback': {},
            }

        # gui-------------------------------------------------------

        def apply_theme(self):
            """Restyle every widget from the currently selected theme.

            Colours, fonts and fixed texts come from
            ``self.themes[self.curr_theme]``. Fonts are looked up by name
            with ``exists=True``, so they must already be registered. Ends
            with a full :meth:`refresh` so state-dependent colours and texts
            are reapplied on top of the theme.
            """
            theme = self.themes[self.curr_theme]
            widgets = self.widgets

            def named_font(key):
                # Look up an already-registered font; never creates one.
                return tkFont.Font(name=theme[key]['font'], exists=True)

            def palette(key):
                # Background, foreground and font shared by labels & buttons.
                return {
                    'bg': theme[key]['background'],
                    'fg': theme[key]['foreground'],
                    'font': named_font(key),
                }

            def button_style(key):
                style = palette(key)
                style.update(
                    activebackground=theme[key]['background_active'],
                    activeforeground=theme[key]['foreground_active'],
                    disabledforeground=theme[key]['foreground_disabled'],
                    highlightthickness=0,
                    bd=0)
                return style

            widgets['root'].configure(background=theme['root']['background'])

            # Containers only take a background colour.
            for key in ('video_preview_frame', 'attempts_frame', 'timer_frame',
                        'code_selector_frame', 'code_selector_content',
                        'feedback_evaluation_frame',
                        'feedback_indicator_frame', 'code_frame'):
                widgets[key].configure(bg=theme[key]['background'])

            # Labels whose text is fixed by the theme.
            for key in ('video_preview_title', 'attempts_title',
                        'timer_title', 'code_selector_title',
                        'feedback_evaluation_title',
                        'feedback_indicator_title',
                        'feedback_indicator_content', 'code_title'):
                widgets[key].configure(text=theme[key]['text'], **palette(key))

            # Labels whose text is filled in by the update_* methods.
            for key in ('video_preview_content', 'attempts_content',
                        'timer_content', 'code_content'):
                widgets[key].configure(**palette(key))

            selector_style = button_style('code_selector_button')
            buttons = widgets['code_selector_buttons']
            for step in range(CONFIG['rl']['code_len']):
                for action in range(CONFIG['rl']['no_actions']):
                    buttons[step][action].configure(**selector_style)

            widgets['feedback_evaluation_scale'].configure(
                tickinterval=1,
                orient=tk.HORIZONTAL,
                troughcolor=theme['feedback_evaluation_scale']['trough'],
                highlightthickness=0,
                bd=0,
                **palette('feedback_evaluation_scale'))

            for key in ('feedback_evaluation_button', 'reset_button'):
                widgets[key].configure(text=theme[key]['text'],
                                       **button_style(key))
            # The flow button's text is set by update_flow_button.
            widgets['flow_button'].configure(**button_style('flow_button'))

            selector_theme = theme['theme_selector']
            widgets['theme_selector'].config(
                activebackground=selector_theme['background_active'],
                activeforeground=selector_theme['foreground_active'],
                highlightthickness=0,
                bd=0,
                relief=tk.FLAT,
                indicatoron=0,
                direction='above',
                **palette('theme_selector'))

            self.refresh()

        def refresh(self, refresh_type='all'):
            if refresh_type == 'all':
                self.update_timer()
                self.update_attempts()
                self.update_feedback_indicator()
                self.update_code()
                self.update_feedback_evaluation_scale()
                self.update_feedback_evaluation_button()
                self.update_flow_button()
                self.update_reset_button()
                self.update_code_selector()
            elif refresh_type == 'rl':
                self.update_attempts()
                self.update_code()
                self.update_feedback_indicator()
                self.update_feedback_evaluation_scale()
                self.update_feedback_evaluation_button()
                self.update_flow_button()
                self.update_reset_button()

        def update_attempts(self):
            attempts_str = str(self.attempts).replace('', ' ')[1:-1]
            self.widgets['attempts_content'].configure(text=attempts_str)

        def update_timer(self):
            mins, secs = divmod(int(round(self.time_secs)), 60)
            time_secs_str = str(mins).zfill(2) + ':' + str(secs).zfill(2)
            time_secs_str = time_secs_str.replace('', ' ')[1:-1]
            self.widgets['timer_content'].configure(text=time_secs_str)

        def update_feedback_indicator(self):
            theme = self.themes[self.curr_theme]
            if self.feedback_required:
                fg = theme['feedback_indicator_content']['foreground_required']
            else:
                fg = theme['feedback_indicator_content'][
                    'foreground_not_required']
            self.widgets['feedback_indicator_content'].configure(fg=fg)

        def update_code(self):
            code = None
            if self.agent is not None:
                code = list(self.agent.curr_state)
            if code is None or len(code) == 0:
                theme = self.themes[self.curr_theme]
                code_str = theme['code_content']['text_empty']
            else:
                code_str = '{' + str(code)[1:-1] + '}'
            self.widgets['code_content'].configure(text=code_str)

        def update_feedback_evaluation_scale(self):
            theme = self.themes[self.curr_theme]
            if self.feedback_required:
                state = tk.NORMAL
                troughcolor = theme['feedback_evaluation_scale']['trough']
            else:
                state = tk.DISABLED
                troughcolor = theme['feedback_evaluation_scale'][
                    'trough_disabled']
            self.widgets['feedback_evaluation_scale'].configure(
                state=state, troughcolor=troughcolor)

        def update_feedback_evaluation_button(self):
            if self.feedback_required:
                state = tk.NORMAL
            else:
                state = tk.DISABLED
            self.widgets['feedback_evaluation_button'].configure(state=state)

        def update_flow_button(self):
            theme = self.themes[self.curr_theme]
            if (self.env is not None
                    and self.env.is_guessed()) or not (self.stoppable
                                                       or self.stopped):
                bg = theme['flow_button']['background_disabled']
                state = tk.DISABLED
                text = self.widgets['flow_button']['text']
            else:
                bg = theme['flow_button']['background']
                state = tk.NORMAL
                if self.stopped:
                    text = theme['flow_button']['text_start']
                else:
                    text = theme['flow_button']['text_stop']
            self.widgets['flow_button'].configure(state=state,
                                                  bg=bg,
                                                  text=text)

        def update_reset_button(self):
            theme = self.themes[self.curr_theme]
            if self.stopped:
                bg = theme['reset_button']['background']
                state = tk.NORMAL
            else:
                bg = theme['reset_button']['background_disabled']
                state = tk.DISABLED
            self.widgets['reset_button'].configure(state=state, bg=bg)

        def update_code_selector_button(self, step, action):
            theme = self.themes[self.curr_theme]
            if self.secret[step] is None:
                bg = theme['code_selector_button']['background']
                fg = theme['code_selector_button']['foreground']
                state = tk.NORMAL
            else:
                if action == self.secret[step]:
                    bg = theme['code_selector_button']['background_selected']
                    fg = theme['code_selector_button']['foreground_selected']
                    if self.stopped:
                        state = tk.NORMAL
                    else:
                        state = tk.DISABLED
                else:
                    bg = theme['code_selector_button']['background_disabled']
                    fg = theme['code_selector_button']['foreground']
                    state = tk.DISABLED
            self.widgets['code_selector_buttons'][step][action].configure(
                bg=bg, fg=fg, state=state)

        def update_code_selector(self):
            """Restyle every button of the code-selector grid."""
            cells = ((step, action)
                     for step in range(CONFIG['rl']['code_len'])
                     for action in range(CONFIG['rl']['no_actions']))
            for step, action in cells:
                self.update_code_selector_button(step, action)

        def flash_code_selector_button(self,
                                       step,
                                       action,
                                       flash_bg_color,
                                       flash_count=3,
                                       delay=250):
            if flash_count > 0:
                self.widgets['code_selector_buttons'][step][action].configure(
                    background=flash_bg_color)
                self.widgets['code_selector_buttons'][step][action].after(
                    delay / 2,
                    lambda: self.update_code_selector_button(step, action))
                self.widgets['code_selector_buttons'][step][action].after(
                    delay, lambda: self.flash_code_selector_button(
                        step, action, flash_bg_color, flash_count - 1, delay))

        def flash_error_code_selector(self):
            """Blink every button of each step whose secret slot is still
            empty, using the theme's ``flash_error`` colour."""
            palette = self.themes[self.curr_theme]
            unset_steps = (step for step in range(CONFIG['rl']['code_len'])
                           if self.secret[step] is None)
            for step in unset_steps:
                for action in range(CONFIG['rl']['no_actions']):
                    self.flash_code_selector_button(
                        step, action,
                        palette['code_selector_button']['flash_error'])

        def flash_guessed_code_selector(self):
            """Blink every button of the code selector with the theme's
            "guessed" colour: 3 cycles of 500 ms each."""
            theme = self.themes[self.curr_theme]
            cells = [(position, action_idx)
                     for position in range(CONFIG['rl']['code_len'])
                     for action_idx in range(CONFIG['rl']['no_actions'])]
            for position, action_idx in cells:
                self.flash_code_selector_button(
                    position,
                    action_idx,
                    theme['code_selector_button']['flash_guessed'],
                    flash_count=3,
                    delay=500)

        def flash_action_code_selector(self, action):
            """Blink once (one 1500 ms cycle) the ``action`` button of every
            code position, using the theme's "action" colour."""
            theme = self.themes[self.curr_theme]
            code_len = CONFIG['rl']['code_len']
            for position in range(code_len):
                self.flash_code_selector_button(
                    position, action,
                    theme['code_selector_button']['flash_action'],
                    flash_count=1, delay=1500)

        # listeners-------------------------------------------------

        def on_code_selector_button_clicked(self, step, action):
            """Return the command bound to the code selector button at
            (``step``, ``action``).

            The command toggles code position ``step``. An empty position gets
            ``action``; a filled one is cleared, whatever action it held. The
            whole selector is then redrawn.
            """
            def on_code_selector_button_clicked_listener():
                current = self.secret[step]
                self.secret[step] = action if current is None else None
                self.update_code_selector()

            return on_code_selector_button_clicked_listener

        def on_flow_button_clicked(self):
            """Toggle between running and stopped (the START/STOP button).

            Starting is refused, with an error flash of the code selector,
            while some code position has no action selected. Otherwise the
            ``stopped`` flag is flipped under ``mainloop_cv`` and waiters are
            notified. On (re)start the selector is redrawn and the timer
            resumed; on stop the timer reference is dropped. Everything runs
            while holding ``stoppable_mutex``, which the RL thread in
            ``mainloop`` also takes before disabling the STOP button.
            """
            self.stoppable_mutex.acquire()
            # try/finally: an exception in a GUI update must not leave the
            # mutex held, or the RL thread and later clicks would deadlock.
            try:
                if self.stopped and (None in self.secret):
                    self.flash_error_code_selector()
                else:
                    with self.mainloop_cv:
                        self.stopped = not self.stopped
                        self.mainloop_cv.notifyAll()
                    if not self.stopped:
                        self.update_code_selector()
                        self.timer()
                    else:
                        self.last_time_secs = None
                self.update_reset_button()
                self.update_flow_button()
            finally:
                self.stoppable_mutex.release()

        def on_reset_button_clicked(self):
            """Save the current RL session to the DB, if there is one, then
            reset the application state and refresh the GUI."""
            session_open = self.rl_session is not None
            if session_open:
                self.fill_rl_session_result()
                DB.insert(self.rl_session)
                self.rl_session = None
            self.reset()
            self.refresh()

        def on_theme_changed(self, theme):
            """Make ``theme`` (a key of ``self.themes``) current and repaint."""
            self.curr_theme = theme
            self.apply_theme()

        def on_feedback_evaluation_button_clicked(self):
            """Record the user's evaluation of the agent's current attempt.

            Stores the scale value, the attempt (the agent's current state)
            and the elapsed time ("MM:SS") in the session feedback. The entry
            is keyed by a new id: the highlighter video prefix plus the
            current time in ms. The method then releases the RL thread
            waiting for feedback and refreshes the feedback widgets.
            """
            self.evaluation = self.widgets['feedback_evaluation_scale'].get()
            self.feedback_id = CONFIG['highlighter'][
                'video_name_prefix'] + str(int(time.time() * 1000))
            self.rl_session['feedback'][self.feedback_id] = {
                'evaluation':
                self.evaluation,
                'attempt':
                list(self.agent.curr_state),
                'time':
                str(int(self.time_secs / 60)).zfill(2) + ':' +
                str(int(self.time_secs % 60)).zfill(2)
            }
            with self.mainloop_cv:
                # Publish ``feedback_provided`` before waking the RL thread.
                # Otherwise it may wake, see neither flag set, and take a new
                # action instead of applying this feedback to the Q matrix.
                self.feedback_provided = True
                self.feedback_required = False
                self.mainloop_cv.notifyAll()
            self.feedback_frame = True
            self.update_feedback_indicator()
            self.update_feedback_evaluation_scale()
            self.update_feedback_evaluation_button()

        # status----------------------------------------------------

        def destructor(self):
            """Shut the application down.

            Raises the ``exit`` flag and notifies ``mainloop_cv``, so an RL
            thread parked there can save its session and return. Then destroys
            the Tk root window, releases the video capture and closes any
            OpenCV windows.
            """
            self.exit = True
            cv = self.mainloop_cv
            with cv:
                cv.notifyAll()
            self.widgets['root'].destroy()
            self.vcap.release()
            cv2.destroyAllWindows()

        def reset(self):
            """Bring the application back to its initial, stopped state.

            Clears the feedback flags and the RL session. Marks the app
            stopped with no feedback pending (under ``mainloop_cv``, notifying
            waiters). Zeroes the timer, evaluation and attempt counter.
            Empties the secret code (one ``None`` per code position) and drops
            env and agent.
            """
            self.feedback_provided = self.feedback_frame = False
            self.rl_session = None
            with self.mainloop_cv:
                self.feedback_required = False
                self.stopped = True
                self.mainloop_cv.notifyAll()
            self.time_secs = 0
            self.last_time_secs = None
            self.evaluation = self.attempts = 0
            self.secret = np.full(CONFIG['rl']['code_len'], None)
            self.env = self.agent = None

        def fill_rl_session_result(self):
            """Copy the final RL outcome into ``self.rl_session['result']``.

            Serialises the agent's Q-matrix: state keys become ``"{a, b, ...}"``
            and per-state arrays become ``str(list(...))``. Records whether the
            code was guessed, the optimal sequence, the attempt count and the
            elapsed time ("MM:SS"). Gives each feedback entry the path of its
            analyzer CSV when that file exists, ``None`` otherwise.
            """
            agent_q = self.agent.qmatrix
            serialised = {}
            for state in agent_q.keys():
                entry = agent_q[state]
                key = '{' + str(list(state))[1:-1] + '}'
                serialised[key] = {
                    'qvalues': str(list(entry['qvalues'])),
                    'td_errors': str(list(entry['td_errors'])),
                    'td_errors_variations': str(list(entry['td_errors_delta'])),
                    'visits': entry['visits'],
                }
            elapsed = (str(int(self.time_secs / 60)).zfill(2) + ':' +
                       str(int(self.time_secs % 60)).zfill(2))
            result = self.rl_session['result']
            result['guessed'] = self.env.is_guessed()
            result['optimal'] = list(self.agent.get_optimal())
            result['qmatrix'] = serialised
            result['attempts'] = self.attempts
            result['time'] = elapsed
            feedback = self.rl_session['feedback']
            for feedback_id in feedback.keys():
                candidate = (CONFIG['analyzer']['csv_path'] + '/' +
                             feedback_id + '.csv')
                feedback[feedback_id]['csv_path'] = (
                    candidate if os.path.isfile(candidate) else None)

        # services--------------------------------------------------

        def webcam(self):
            """Grab one camera frame, feed it to the highlighter and show it.

            The frame is mirrored horizontally. When the highlighter returns a
            video path, emotion analysis of that video starts on a separate
            thread. If the read fails, the capture device is re-initialised
            and the theme's error text is shown instead. The method
            reschedules itself every 1000 / CONFIG['vcap']['fps'] ms.
            """
            preview = self.widgets['video_preview_content']
            success, frame = self.vcap.read()
            if success:
                video_path = self.feedback_highlighter.scroll(
                    cv2.flip(frame, 1), self.feedback_frame, self.feedback_id)
                self.feedback_frame = False
                if video_path is not None:
                    Thread(target=lambda: self.emotion_analyzer.analyze(
                        os.path.basename(video_path))).start()
                rgba_frame = cv2.cvtColor(cv2.flip(frame, 1),
                                          cv2.COLOR_BGR2RGBA)
                img = Image.fromarray(rgba_frame)
                imgtk = ImageTk.PhotoImage(image=img)
                # Keep a Python reference to the PhotoImage on the widget;
                # tkinter does not hold one itself.
                preview.imgtk = imgtk
                preview.configure(image=imgtk, text='')
            else:
                self.vcap.release()
                self.vcap = self.init_camera()
                theme = self.themes[self.curr_theme]
                preview.configure(
                    text=theme['video_preview_content']['text_error'],
                    image='')
            # Tk's ``after`` needs integer milliseconds; ``1000 / fps`` is a
            # float on Python 3 (and whenever fps is a float).
            preview.after(int(1000 / CONFIG['vcap']['fps']), self.webcam)

        def timer(self):
            """Advance the elapsed-time counter while the app is running.

            Adds the wall-clock time since the previous tick to ``time_secs``,
            wrapped modulo one hour. Refreshes the timer widget and schedules
            the next tick in 1000 ms. When stopped it does nothing, which also
            ends the self-rescheduling chain.
            """
            if self.stopped:
                return
            if not self.last_time_secs:
                self.last_time_secs = time.time()
            now = time.time()
            self.time_secs = (self.time_secs + now -
                              self.last_time_secs) % 3600
            self.last_time_secs = now
            self.update_timer()
            self.widgets['timer_content'].after(1000, self.timer)

        def mainloop(self):
            """Build the GUI and return the RL worker thread (not started).

            Initialises widgets, listeners and theme, and starts the webcam
            preview loop. Returns a ``Thread`` whose target is the RL loop
            below.
            """
            def mainloop_thread():
                """RL worker loop.

                Waits passively while stopped or awaiting feedback. Creates
                env, agent and session on the first active iteration, then
                performs one agent step per iteration until ``exit`` is set.
                """

                while not self.exit:

                    # If the application is stopped or feedback is required,
                    # wait passively on the condition variable.
                    with self.mainloop_cv:
                        while self.stopped or self.feedback_required:
                            # On exit, save the session data if present.
                            if self.exit:
                                if self.rl_session is not None:
                                    self.fill_rl_session_result()
                                    DB.insert(self.rl_session)
                                    self.rl_session = None
                                return
                            self.mainloop_cv.wait()

                    # RL session initialisation
                    if self.rl_session is None:
                        self.env = gym.make(CONFIG['rl']['gym'],
                                            no_pegs=CONFIG['rl']['no_actions'],
                                            secret=self.secret,
                                            random_seed=np.random.randint(
                                                np.iinfo(np.int32).max))
                        self.agent = Agent(self.env)
                        self.rl_session = self.init_rl_session()

                    else:

                        # Ignore exceptions from the (GUI) main loop when the
                        # application exits before the RL step has finished.
                        try:

                            # Disable the STOP button. The mutex must be
                            # released on every path, including the early
                            # ``continue`` taken when a stop request won the
                            # race; otherwise the next acquire deadlocks.
                            self.stoppable_mutex.acquire()
                            try:
                                if self.stopped:
                                    continue
                                self.stoppable = False
                                self.update_flow_button()
                            finally:
                                self.stoppable_mutex.release()

                            # If feedback was provided, update the Q matrix
                            if self.feedback_provided:
                                self.agent.update_qmatrix(self.evaluation)
                                self.feedback_provided = False
                                self.agent.curr_state = self.env.reset()
                                print(self.agent.qmatrix_to_str())

                            # Otherwise choose an action to perform
                            elif not self.feedback_required:
                                action = self.agent.get_action()
                                self.feedback_required = self.agent.take_action(
                                    action)

                                # If the action is terminal, count the attempt
                                if self.feedback_required:
                                    self.attempts += 1

                                    # If the final multiset is correct, stop
                                    # and save the data
                                    if self.env.is_guessed():
                                        self.agent.update_qmatrix(
                                            CONFIG['rl']['max_evaluation'])
                                        self.fill_rl_session_result()
                                        DB.insert(self.rl_session)
                                        self.rl_session = None
                                        self.stopped = True
                                        self.feedback_required = False
                                        self.flash_guessed_code_selector()

                                    # Otherwise flash the chosen action
                                    else:
                                        self.flash_action_code_selector(action)

                                # Otherwise flash the chosen action
                                else:
                                    self.flash_action_code_selector(action)

                            # Re-enable the STOP button
                            time.sleep(0.5)
                            self.stoppable = True
                            self.refresh('rl')
                            #time.sleep(CONFIG['rl']['epoch_delay'])

                        except Exception:
                            # As documented above: swallow errors only while
                            # the application is shutting down.
                            if self.exit:
                                return
                            raise

            self.init_gui()
            self.init_listeners()
            self.apply_theme()
            self.webcam()

            return Thread(target=mainloop_thread)
                # Notify that the count was increment
                event_count_condition.notify_all()

    # Register the callback with the client
    client.add_event_callback(EVENT_TOPIC, MyEventCallback())

    #
    # Send events
    #

    # Record the start time
    start = time.time()

    # Loop and send the events
    for event_id in range(TOTAL_EVENTS):
        # Create the event
        event = Event(EVENT_TOPIC)
        # Set the payload
        event.payload = str(event_id).encode()
        # Send the event
        client.send_event(event)

    # Wait until all events have been received
    print("Waiting for events to be received...")
    with event_count_condition:
        while event_count[0] < TOTAL_EVENTS:
            event_count_condition.wait()

    # Print the elapsed time
    print("Elapsed time (ms): " + str((time.time() - start) * 1000))
示例#50
0
class RPC:
    """JSON-RPC 2.0 endpoint over a pair of text streams, using LSP-style
    ``Content-Length`` framing.

    Requests sent with a callback are asynchronous: the callback is stored
    under the request id and invoked from :meth:`handle`. Requests sent
    without one are synchronous: :meth:`call` blocks until :meth:`handle`
    (normally driven by :meth:`serve` on another thread) delivers the
    response. A response whose id has no registered callback is treated as
    the synchronous answer, so at most one synchronous call may be
    outstanding at a time.
    """

    def __init__(self, infile, outfile, onRequest, onNotification, onError):
        """
        infile / outfile: text streams to read messages from / write to.
        onRequest, onNotification: called with each incoming request /
            notification dict.
        onError: called with the ``error`` member of each error response.
        """
        self.infile = infile
        self.outfile = outfile
        self.onRequest = onRequest
        self.onNotification = onNotification
        self.onError = onError
        self.mid = 0
        # Pending asynchronous calls: request id -> result callback.
        self.queue = {}
        self.cv = Condition()
        # Result of the pending synchronous call. ``result`` alone cannot
        # distinguish "no response yet" from a legitimate null result, so a
        # separate flag tracks arrival.
        self.result = None
        self._hasResult = False

    def incMid(self) -> int:
        """Return the next request id (0, 1, 2, ...)."""
        mid = self.mid
        self.mid += 1
        return mid

    def message(self, contentDict: Dict[str, Any]) -> None:
        """Serialise ``contentDict`` and write it as one framed message."""
        content = json.dumps(contentDict)
        # json.dumps escapes non-ASCII by default, so len() equals the byte
        # count required by the Content-Length header.
        message = ("Content-Length: {}\r\n\r\n"
                   "{}".format(len(content), content))
        logger.debug(' => ' + content)
        self.outfile.write(message)
        self.outfile.flush()

    def call(self, method: str, params: Dict[str, Any], cb=None):
        """
        Send a request.

        @param cb: func. Callback to handle result. If None, turn to sync call:
            block until the response arrives and return its result. The
            return value is ``None`` both for a null result and for an error
            response, which is also reported to ``onError``.
        """
        mid = self.incMid()
        if cb is not None:
            # Registered before sending so a fast response finds it.
            self.queue[mid] = cb

        contentDict = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
            "id": mid,
        }  # type: Dict[str, Any]
        self.message(contentDict)

        if cb is not None:
            return

        with self.cv:
            while not self._hasResult:
                self.cv.wait()
            result = self.result
            self.result = None
            self._hasResult = False
            return result

    def notify(self, method: str, params: Dict[str, Any]) -> None:
        """Send a notification (a request without id; no response expected)."""
        contentDict = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        }  # type: Dict[str, Any]
        self.message(contentDict)

    def serve(self):
        """Read framed messages from ``infile`` and dispatch each one to
        :meth:`handle`, until the stream is closed or reaches end of file."""
        contentLength = 0
        while not self.infile.closed:
            rawLine = self.infile.readline()
            if not rawLine:
                # End of file: the peer is gone and nothing more can be read.
                break
            line = rawLine.strip()
            if line:
                # Split on the first colon only; header values may contain
                # colons themselves.
                header, _sep, value = line.partition(":")
                if header.strip() == "Content-Length":
                    contentLength = int(value)
            else:
                # A blank line ends the headers; the body follows.
                content = self.infile.read(contentLength)
                logger.debug(' <= ' + content)
                self.handle(json.loads(content))

    def _setResult(self, result):
        """Hand ``result`` to the thread blocked in a synchronous call."""
        with self.cv:
            self.result = result
            self._hasResult = True
            self.cv.notify()

    def handle(self, message: Dict[str, Any]):
        """Dispatch one decoded incoming message.

        Error and result responses settle the matching pending call (an
        async callback or the synchronous waiter). Requests and
        notifications are forwarded to ``onRequest`` / ``onNotification``.
        Exceptions raised by those user callbacks are logged, not propagated.
        """
        if "error" in message:  # error
            mid = message.get("id")
            # pop() instead of del: the id may belong to a synchronous call
            # (never queued) or be null (e.g. parse errors).
            isAsync = self.queue.pop(mid, None) is not None
            try:
                self.onError(message["error"])
            except Exception:
                logger.exception("Exception in RPC.onError.")
            if mid is not None and not isAsync:
                # Wake the synchronous caller, otherwise it waits forever.
                self._setResult(None)
        elif "result" in message:  # result
            mid = message['id']
            result = message["result"]
            cb = self.queue.pop(mid, None)
            if cb is not None:  # async call
                try:
                    cb(result)
                except Exception:
                    logger.exception("Exception in RPC request callback.")
            else:  # sync call
                self._setResult(result)
        elif "method" in message:  # request/notification
            if "id" in message:  # request
                try:
                    self.onRequest(message)
                except Exception:
                    logger.exception("Exception in RPC.onRequest")
            else:
                try:
                    self.onNotification(message)
                except Exception:
                    logger.exception("Exception in RPC.onNotification")
        else:
            logger.error('Unexpected')
示例#51
0
class PLCObject(object):
    def __init__(self, WorkingDir, argv, statuschange, evaluator, pyruntimevars):
        """
        WorkingDir: existing directory holding the transferred PLC files.
            Its "tmp" sub-directory is recreated empty (see _init_blobs).
        argv: extra arguments appended after WorkingDir to the argv given
            to the PLC's startPLC.
        statuschange: iterable of callables notified with the new PLCStatus,
            or None.
        evaluator: callable running a function (plus optional arguments) and
            returning (result, exc_info or None).
        pyruntimevars: dict merged into the runtime Python namespace.
        """
        self.workingdir = WorkingDir  # must exist already
        self.tmpdir = os.path.join(WorkingDir, 'tmp')
        # FIXME : is argv of any use nowadays ?
        self.argv = [WorkingDir] + argv  # force argv[0] to be "path" to exec...
        self.statuschange = statuschange
        self.evaluator = evaluator
        self.pyruntimevars = pyruntimevars
        self.PLCStatus = PlcStatus.Empty
        # Set by AutoLoad; None until a transferred PLC is known, so that
        # readers (StartPLC, _GetLibFileName) never hit AttributeError.
        self.CurrentPLCFilename = None
        self.PLClibraryHandle = None
        self.PLClibraryLock = Lock()
        # Creates fake C funcs proxies
        self._InitPLCStubCalls()
        self._loading_error = None
        self.python_runtime_vars = None
        self.TraceThread = None
        self.TraceLock = Lock()
        self.Traces = []
        self.DebugToken = 0

        # Recreates an empty tmpdir and the blob table. This replaces the
        # duplicate rmtree/mkdir that used to precede it.
        self._init_blobs()

    # First task of worker -> no @RunInMain
    def AutoLoad(self, autostart):
        """Restore the last transferred PLC, optionally loading and starting it.

        The library name is the content of the MD5 file plus lib_ext. On
        success the status becomes Stopped. With ``autostart``, a successful
        LoadPLC is followed by StartPLC and an immediate return, without
        notifying listeners here. On any exception the status falls back to
        Empty with no current PLC. In the remaining cases status listeners
        are notified.
        """
        # Get the last transfered PLC
        try:
            # Context manager: the previous open(...).read() leaked the handle.
            with open(self._GetMD5FileName(), "r") as md5file:
                self.CurrentPLCFilename = md5file.read().strip() + lib_ext
            self.PLCStatus = PlcStatus.Stopped
            if autostart:
                if self.LoadPLC():
                    self.StartPLC()
                    return
        except Exception:
            self.PLCStatus = PlcStatus.Empty
            self.CurrentPLCFilename = None

        self.StatusChange()

    def StatusChange(self):
        """Call every registered status listener with the current PLCStatus."""
        listeners = self.statuschange
        if listeners is None:
            return
        for callee in listeners:
            callee(self.PLCStatus)

    def LogMessage(self, *args):
        """Echo a message with PLCprint and forward it to the PLC library log.

        Accepts either (level, msg) or (msg,); in the latter case the level is
        LogLevelsDefault. Returns the library LogMessage result, or None when
        no library logger is bound.
        """
        if len(args) == 2:
            level, msg = args
        else:
            msg, = args
            level = LogLevelsDefault
        PLCprint(msg)
        library_log = self._LogMessage
        if library_log is None:
            return None
        return library_log(level, msg, len(msg))

    @RunInMain
    def ResetLogCount(self):
        """Reset the log counters through the PLC library's ResetLogCount
        entry point, when one is bound (not None)."""
        reset_fn = self._ResetLogCount
        if reset_fn is not None:
            reset_fn()

    # used internaly
    def GetLogCount(self, level):
        """Return the number of messages logged at ``level``.

        Asks the PLC library when it is bound. Otherwise returns 1 for level 0
        when the last load failed (the loading error counts as one message),
        and None in every other case.
        """
        if self._GetLogCount is not None:
            return int(self._GetLogCount(level))
        if self._loading_error is not None and level == 0:
            return 1
        return None

    @RunInMain
    def GetLogMessage(self, level, msgid):
        """Fetch log message ``msgid`` of ``level``.

        Returns (text, tick, tv_sec, tv_nsec) from the PLC library, or None
        when the library reports an empty or oversized message. Without a
        library, returns (loading_error, 0, 0, 0) for level 0 if the last load
        failed, else None.
        """
        if self._GetLogMessage is None:
            if self._loading_error is not None and level == 0:
                return self._loading_error, 0, 0, 0
            return None
        tick = ctypes.c_uint32()
        tv_sec = ctypes.c_uint32()
        tv_nsec = ctypes.c_uint32()
        buf = self._log_read_buffer
        # Reserve one byte for the terminating NUL written below.
        maxsz = len(buf) - 1
        sz = self._GetLogMessage(level, msgid, buf, maxsz,
                                 ctypes.byref(tick),
                                 ctypes.byref(tv_sec),
                                 ctypes.byref(tv_nsec))
        if not sz or sz > maxsz:
            return None
        buf[sz] = '\x00'
        return buf.value, tick.value, tv_sec.value, tv_nsec.value

    def _GetMD5FileName(self):
        """Path of the file holding the MD5 of the last transferred PLC,
        which AutoLoad also uses as the library base name."""
        md5_basename = "lasttransferedPLC.md5"
        return os.path.join(self.workingdir, md5_basename)

    def _GetLibFileName(self):
        """Path of the current PLC shared library inside the working dir."""
        lib_name = self.CurrentPLCFilename
        return os.path.join(self.workingdir, lib_name)

    def _LoadPLC(self):
        """
        Load PLC library
        Declare all functions, arguments and return values

        The MD5 file is read first, outside the lock; an I/O error there
        propagates to the caller. Its content becomes PLC_ID when it is
        exactly 32 characters long. The current PLC shared library is then
        opened and its entry points bound while holding PLClibraryLock.

        Returns True on success. On any failure while binding, the
        traceback is stored in _loading_error, printed with PLCprint, and
        False is returned.
        """
        # Context manager: the previous open(...).read() leaked the handle.
        with open(self._GetMD5FileName(), "r") as md5file:
            md5 = md5file.read()
        self.PLClibraryLock.acquire()
        try:
            # Raw handle is kept for dlclose() in _FreePLC; ctypes wraps it.
            self._PLClibraryHandle = dlopen(self._GetLibFileName())
            self.PLClibraryHandle = ctypes.CDLL(self.CurrentPLCFilename, handle=self._PLClibraryHandle)

            self.PLC_ID = ctypes.c_char_p.in_dll(self.PLClibraryHandle, "PLC_ID")
            if len(md5) == 32:
                self.PLC_ID.value = md5

            self._startPLC = self.PLClibraryHandle.startPLC
            self._startPLC.restype = ctypes.c_int
            self._startPLC.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)]

            self._stopPLC_real = self.PLClibraryHandle.stopPLC
            self._stopPLC_real.restype = None

            self._PythonIterator = getattr(self.PLClibraryHandle, "PythonIterator", None)
            if self._PythonIterator is not None:
                self._PythonIterator.restype = ctypes.c_char_p
                self._PythonIterator.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_void_p)]

                self._stopPLC = self._stopPLC_real
            else:
                # If python confnode is not enabled, we reuse _PythonIterator
                # as a call that block pythonthread until StopPLC
                self.PlcStopping = Event()

                def PythonIterator(res, blkid):
                    self.PlcStopping.clear()
                    self.PlcStopping.wait()
                    return None
                self._PythonIterator = PythonIterator

                def __StopPLC():
                    self._stopPLC_real()
                    self.PlcStopping.set()
                self._stopPLC = __StopPLC

            self._ResetDebugVariables = self.PLClibraryHandle.ResetDebugVariables
            self._ResetDebugVariables.restype = None

            self._RegisterDebugVariable = self.PLClibraryHandle.RegisterDebugVariable
            self._RegisterDebugVariable.restype = None
            self._RegisterDebugVariable.argtypes = [ctypes.c_int, ctypes.c_void_p]

            self._FreeDebugData = self.PLClibraryHandle.FreeDebugData
            self._FreeDebugData.restype = None

            self._GetDebugData = self.PLClibraryHandle.GetDebugData
            self._GetDebugData.restype = ctypes.c_int
            self._GetDebugData.argtypes = [ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_void_p)]

            self._suspendDebug = self.PLClibraryHandle.suspendDebug
            self._suspendDebug.restype = ctypes.c_int
            self._suspendDebug.argtypes = [ctypes.c_int]

            self._resumeDebug = self.PLClibraryHandle.resumeDebug
            self._resumeDebug.restype = None

            self._ResetLogCount = self.PLClibraryHandle.ResetLogCount
            self._ResetLogCount.restype = None

            self._GetLogCount = self.PLClibraryHandle.GetLogCount
            self._GetLogCount.restype = ctypes.c_uint32
            self._GetLogCount.argtypes = [ctypes.c_uint8]

            self._LogMessage = self.PLClibraryHandle.LogMessage
            self._LogMessage.restype = ctypes.c_int
            self._LogMessage.argtypes = [ctypes.c_uint8, ctypes.c_char_p, ctypes.c_uint32]

            self._log_read_buffer = ctypes.create_string_buffer(1 << 14)  # 16K
            self._GetLogMessage = self.PLClibraryHandle.GetLogMessage
            self._GetLogMessage.restype = ctypes.c_uint32
            self._GetLogMessage.argtypes = [ctypes.c_uint8, ctypes.c_uint32, ctypes.c_char_p, ctypes.c_uint32, ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_uint32)]

            self._loading_error = None

        except Exception:
            self._loading_error = traceback.format_exc()
            PLCprint(self._loading_error)
            return False
        finally:
            self.PLClibraryLock.release()

        return True

    @RunInMain
    def LoadPLC(self):
        """Load the PLC library and, on success, set up its Python runtime.

        Returns the result of _LoadPLC. On failure the library is released
        and the stub proxies are reinstalled via _FreePLC.
        """
        loaded = self._LoadPLC()
        if not loaded:
            self._FreePLC()
            return loaded
        self.PythonRuntimeInit()
        return loaded

    @RunInMain
    def UnLoadPLC(self):
        """Tear down the Python runtime first, then release the PLC library."""
        self.PythonRuntimeCleanup()
        self._FreePLC()

    def _InitPLCStubCalls(self):
        """
        create dummy C func proxies

        Rebinds every PLC library entry point to a harmless placeholder, or
        to None for the optional ones that callers test against None. This
        ensures no reference into an unloaded library survives. Called from
        __init__ and _FreePLC.
        """
        self._startPLC = lambda x, y: None
        self._stopPLC = lambda: None
        self._ResetDebugVariables = lambda: None
        self._RegisterDebugVariable = lambda x, y: None
        self._IterDebugData = lambda x, y: None
        self._FreeDebugData = lambda: None
        self._GetDebugData = lambda: -1
        self._suspendDebug = lambda x: -1
        self._resumeDebug = lambda: None
        # Same arity as the call in PythonThreadLoop; returning None ends
        # that loop cleanly instead of raising TypeError.
        self._PythonIterator = lambda res, blkid: None
        # Optional entry points, checked against None by their users. The
        # ResetLogCount one was previously left pointing into the unloaded
        # library (or missing before the first load).
        self._ResetLogCount = None
        self._GetLogCount = None
        self._LogMessage = None
        self._GetLogMessage = None
        self._PLClibraryHandle = None
        self.PLClibraryHandle = None

    def _FreePLC(self):
        """
        Unload PLC library.
        While holding PLClibraryLock, close the library handle if one is open
        and reinstall the dummy C function proxies. Always returns False.
        """
        with self.PLClibraryLock:
            # Unload library explicitely
            handle = getattr(self, "_PLClibraryHandle", None)
            if handle is not None:
                dlclose(handle)

            # Forget all refs to library
            self._InitPLCStubCalls()

        return False

    def PythonRuntimeCall(self, methodname, use_evaluator=True, reverse_order=False):
        """
        Calls init, start, stop or cleanup method provided by
        runtime python files, loaded when new PLC uploaded.

        Hooks are run through self.evaluator, or default_evaluator when
        use_evaluator is False. They run in registration order, or reversed
        when reverse_order is True. Whenever the evaluator reports exception
        info, it is logged at level 0 and the remaining hooks still run.
        """
        hooks = self.python_runtime_vars.get("_runtime_%s" % methodname, [])
        if reverse_order:
            hooks = reversed(hooks)
        for hook in hooks:
            run = self.evaluator if use_evaluator else default_evaluator
            _res, exp = run(hook)
            if exp is None:
                continue
            self.LogMessage(0, '\n'.join(traceback.format_exception(*exp)))

    # used internaly
    def PythonRuntimeInit(self):
        """Set up the Python runtime of a freshly loaded PLC.

        Steps:
        - Build python_runtime_vars: a copy of this module's globals, plus
          pyruntimevars, plus a few PLC objects.
        - Execute every runtime*.py file of the working directory in that
          namespace. Names are matched case-insensitively and files are
          processed in sorted order.
        - Collect their _<name>_init/start/stop/cleanup hooks.
        - Run the "init" hooks.
        - Start the PLC Python thread, which waits for a command.

        NOTE(review): runtime files run with full interpreter privileges;
        they are trusted as part of the transferred PLC.
        """
        MethodNames = ["init", "start", "stop", "cleanup"]
        self.python_runtime_vars = globals().copy()
        self.python_runtime_vars.update(self.pyruntimevars)
        # Captured by PLCSafeGlobals, whose own ``self`` is the proxy object.
        parent = self

        class PLCSafeGlobals(object):
            """Attribute-style proxy to PLC shared globals.

            ``PLCGlobals.<name>`` reads and writes go through the
            _PySafeGetPLCGlob_<name> / _PySafeSetPLCGlob_<name> functions and
            the _<name>_ctype / _pack / _unpack helpers looked up in the
            runtime namespace. Unknown names raise KeyError.
            """
            def __getattr__(self, name):
                try:
                    t = parent.python_runtime_vars["_"+name+"_ctype"]
                except KeyError:
                    raise KeyError("Try to get unknown shared global variable : %s" % name)
                v = t()
                parent.python_runtime_vars["_PySafeGetPLCGlob_"+name](ctypes.byref(v))
                return parent.python_runtime_vars["_"+name+"_unpack"](v)

            def __setattr__(self, name, value):
                try:
                    t = parent.python_runtime_vars["_"+name+"_ctype"]
                except KeyError:
                    raise KeyError("Try to set unknown shared global variable : %s" % name)
                v = parent.python_runtime_vars["_"+name+"_pack"](t, value)
                parent.python_runtime_vars["_PySafeSetPLCGlob_"+name](ctypes.byref(v))

        self.python_runtime_vars.update({
            "PLCGlobals":     PLCSafeGlobals(),
            "WorkingDir":     self.workingdir,
            "PLCObject":      self,
            "PLCBinary":      self.PLClibraryHandle,
            "PLCGlobalsDesc": []})

        # One hook list per lifecycle method, filled from the runtime files.
        for methodname in MethodNames:
            self.python_runtime_vars["_runtime_%s" % methodname] = []

        try:
            filenames = os.listdir(self.workingdir)
            filenames.sort()
            for filename in filenames:
                name, ext = os.path.splitext(filename)
                if name.upper().startswith("RUNTIME") and ext.upper() == ".PY":
                    execfile(os.path.join(self.workingdir, filename), self.python_runtime_vars)
                    # A file "<name>.py" may define _<name>_init, etc.
                    for methodname in MethodNames:
                        method = self.python_runtime_vars.get("_%s_%s" % (name, methodname), None)
                        if method is not None:
                            self.python_runtime_vars["_runtime_%s" % methodname].append(method)
        except Exception:
            # Log for the PLC log, then let LoadPLC's caller see the failure.
            self.LogMessage(0, traceback.format_exc())
            raise

        self.PythonRuntimeCall("init", use_evaluator=False)

        # Command mailbox for PythonThreadProc ("Wait" = no pending command).
        self.PythonThreadCondLock = Lock()
        self.PythonThreadCond = Condition(self.PythonThreadCondLock)
        self.PythonThreadCmd = "Wait"
        self.PythonThread = Thread(target=self.PythonThreadProc, name="PLCPythonThread")
        self.PythonThread.start()

    # used internaly
    def PythonRuntimeCleanup(self):
        """Tear down the Python runtime, if one was set up.

        Asks the PLC Python thread to finish and joins it, runs the
        "cleanup" hooks in reverse order, then forgets the runtime namespace.
        """
        if self.python_runtime_vars is None:
            return
        self.PythonThreadCommand("Finish")
        self.PythonThread.join()
        self.PythonRuntimeCall("cleanup", use_evaluator=False, reverse_order=True)
        self.python_runtime_vars = None

    def PythonThreadLoop(self):
        """Serve PLC python-block evaluation requests until the PLC stops.

        Each iteration hands the previous result string to _PythonIterator.
        It returns the next expression to evaluate (after filling ``blkid``
        with the requesting block id), or None to end the loop. Compiled
        code is cached per block id and recompiled when the expression text
        changes. Results go back as ``str(result)``. Errors are logged at
        level 1 and reported back as "#EXCEPTION : ..." strings.
        """
        res, cmd, blkid = "None", "None", ctypes.c_void_p()
        # Per-block cache: FBID -> (expression text, compiled code).
        compile_cache = {}
        while True:
            # May block (the fallback iterator waits for StopPLC).
            cmd = self._PythonIterator(res, blkid)
            FBID = blkid.value
            if cmd is None:
                break
            try:
                # Expose the calling block id to the evaluated expression.
                self.python_runtime_vars["FBID"] = FBID
                ccmd, AST = compile_cache.get(FBID, (None, None))
                if ccmd is None or ccmd != cmd:
                    AST = compile(cmd, '<plc>', 'eval')
                    compile_cache[FBID] = (cmd, AST)
                result, exp = self.evaluator(eval, AST, self.python_runtime_vars)
                if exp is not None:
                    res = "#EXCEPTION : "+str(exp[1])
                    self.LogMessage(1, ('PyEval@0x%x(Code="%s") Exception "%s"') % (
                        FBID, cmd, '\n'.join(traceback.format_exception(*exp))))
                else:
                    res = str(result)
                self.python_runtime_vars["FBID"] = None
            except Exception as e:
                # e.g. a SyntaxError from compile(); reported to the PLC.
                res = "#EXCEPTION : "+str(e)
                self.LogMessage(1, ('PyEval@0x%x(Code="%s") Exception "%s"') % (FBID, cmd, str(e)))

    def PythonThreadProc(self):
        """Body of the PLC Python thread.

        Waits for a command posted through PythonThreadCommand.
        - "Activate": run the runtime "start" hooks, serve python blocks
          until _PythonIterator returns None, run the "stop" hooks in
          reverse order, then wait again.
        - Any other command (i.e. "Finish"): end the thread.
        """
        while True:
            with self.PythonThreadCond:
                while self.PythonThreadCmd == "Wait":
                    self.PythonThreadCond.wait()
                cmd = self.PythonThreadCmd
                # Consume the command on every path. It used to be reset only
                # after a wait(), so a command posted before the thread first
                # waited stayed pending, and "Activate" ran again (start
                # hooks plus loop) right after the PLC stopped.
                self.PythonThreadCmd = "Wait"

            if cmd == "Activate":
                self.PythonRuntimeCall("start")
                self.PythonThreadLoop()
                self.PythonRuntimeCall("stop", reverse_order=True)
            else:  # "Finish"
                break

    def PythonThreadCommand(self, cmd):
        """Post ``cmd`` (e.g. "Activate" or "Finish") to the PLC Python thread
        and wake it. Only the latest posted command is kept."""
        with self.PythonThreadCond:
            self.PythonThreadCmd = cmd
            self.PythonThreadCond.notify()

    def _fail(self, msg):
        """Log ``msg`` at level 0, mark the PLC as Broken and notify the
        status listeners.

        ``self`` was missing from the signature, so every ``self._fail(...)``
        call raised TypeError instead of reporting the failure.
        """
        self.LogMessage(0, msg)
        self.PLCStatus = PlcStatus.Broken
        self.StatusChange()

    def PreStartPLC(self):
        """Hook run by StartPLC just before the PLC is started.

        Libraries and Python objects already exist at this point. Subclasses
        may override it, e.g. to restore saved proprietary parameters. The
        default implementation does nothing.
        """
        pass

    @RunInMain
    def StartPLC(self):
        """Load the PLC library if needed, then start the PLC.

        On success the status becomes Started, listeners are notified, the
        Python thread is activated and "PLC started" is logged. On failure
        the error is logged at level 0, the status becomes Broken and
        listeners are notified.
        """

        def fail(msg):
            # Log, mark the PLC Broken and notify listeners.
            self.LogMessage(0, msg)
            self.PLCStatus = PlcStatus.Broken
            self.StatusChange()

        if self.PLClibraryHandle is None:
            if not self.LoadPLC():
                fail(_("Problem starting PLC : can't load PLC"))
                # Nothing is loaded: PreStartPLC's contract is that libraries
                # and python objects already exist, so do not call it.
                return

        self.PreStartPLC()

        if self.CurrentPLCFilename is not None and self.PLCStatus == PlcStatus.Stopped:
            c_argv = ctypes.c_char_p * len(self.argv)
            res = self._startPLC(len(self.argv), c_argv(*self.argv))
            if res == 0:
                self.PLCStatus = PlcStatus.Started
                self.StatusChange()
                self.PythonThreadCommand("Activate")
                self.LogMessage("PLC started")
            else:
                # Translate the template first, then format it, so the
                # gettext lookup uses the untranslated msgid.
                fail(_("Problem starting PLC : error %d") % res)

    @RunInMain
    def StopPLC(self):
        """Stop a running PLC.

        Returns True if the PLC was Started (it is now Stopped and the trace
        thread, if any, has been joined), False otherwise.
        """
        if self.PLCStatus != PlcStatus.Started:
            return False

        self.LogMessage("PLC stopped")
        self._stopPLC()
        self.PLCStatus = PlcStatus.Stopped
        self.StatusChange()

        # TraceThreadProc sets self.TraceThread to None when it exits, which
        # can happen between a check and a join on the attribute itself; work
        # on a local reference so join() is never called on None.
        trace_thread = self.TraceThread
        if trace_thread is not None:
            trace_thread.join()
            self.TraceThread = None
        return True

    def GetPLCstatus(self):
        """Return the (status, log counts) pair from _GetPLCstatus, or
        (PlcStatus.Disconnected, None) when that call raises EOFError."""
        try:
            status = self._GetPLCstatus()
        except EOFError:
            status = (PlcStatus.Disconnected, None)
        return status

    @RunInMain
    def _GetPLCstatus(self):
        """Return (PLCStatus, GetLogCount(level) for every log level)."""
        status = self.PLCStatus
        log_counts = map(self.GetLogCount, xrange(LogLevelsCount))
        return status, log_counts

    @RunInMain
    def GetPLCID(self):
        """Return getPSKID's result, handing it a logger bound to level 0."""
        log_level0 = partial(self.LogMessage, 0)
        return getPSKID(log_level0)

    def _init_blobs(self):
        self.blobs = {}
        if os.path.exists(self.tmpdir):
            shutil.rmtree(self.tmpdir)
        os.mkdir(self.tmpdir)

    @RunInMain
    def SeedBlob(self, seed):
        """Start a new blob in self.tmpdir and return its ID.

        The ID is the MD5 digest of `seed`. The blob is stored as
        (fd, temp path, running md5) under that ID.
        """
        fd, path = mkstemp(dir=self.tmpdir)
        digest = hashlib.new('md5')
        digest.update(seed)
        blob_id = digest.digest()
        self.blobs[blob_id] = (fd, path, digest)
        return blob_id

    @RunInMain
    def AppendChunkToBlob(self, data, blobID):
        """Append `data` to blob `blobID` and return the blob's new ID.

        The new ID is the running MD5 digest after adding `data`, so the
        caller must chain IDs. Returns None if `blobID` is unknown.
        """
        entry = self.blobs.pop(blobID, None)
        if entry is None:
            return None

        fd, _tmp_path, digest = entry
        digest.update(data)
        new_id = digest.digest()
        os.write(fd, data)
        self.blobs[new_id] = entry
        return new_id

    @RunInMain
    def PurgeBlobs(self):
        """Close the file descriptor of every pending blob, then reset the
        blob store and temp directory via _init_blobs."""
        open_fds = [fd for fd, _path, _digest in self.blobs.values()]
        for fd in open_fds:
            os.close(fd)
        self._init_blobs()

    def _BlobAsFile(self, blobID, newpath):
        blob = self.blobs.pop(blobID, None)

        if blob is None:
            raise Exception(_("Missing data to create file: {}").format(newpath))

        fd, path, _md5sum = blob
        fobj = os.fdopen(fd)
        fobj.flush()
        os.fsync(fd)
        fobj.close()
        shutil.move(path, newpath)

    def _extra_files_log_path(self):
        return os.path.join(self.workingdir, "extra_files.txt")

    def RepairPLC(self):
        """Delete the installed PLC files (PurgePLC), then call
        MainWorker.quit()."""
        self.PurgePLC()
        MainWorker.quit()

    @RunInMain
    def PurgePLC(self):
        """Delete the files of the currently installed PLC from workingdir.

        The files to remove are those listed in the extra-files log, plus the
        log itself, the current PLC library and the MD5 file. If the log
        cannot be read, nothing is removed. Individual removal failures are
        logged and skipped. Always leaves PLCStatus set to Empty.
        """
        extra_files_log = self._extra_files_log_path()

        if self.CurrentPLCFilename is not None:
            old_PLC_filename = os.path.join(self.workingdir, self.CurrentPLCFilename)
        else:
            old_PLC_filename = None

        try:
            # `with` closes the log handle, which was previously leaked.
            with open(extra_files_log, "rt") as log:
                allfiles = log.readlines()
            allfiles.extend([extra_files_log, old_PLC_filename, self._GetMD5FileName()])
        except Exception:
            self.LogMessage("No files to purge")
            allfiles = []

        for filename in allfiles:
            if not filename:
                continue  # old_PLC_filename may be None
            filename = filename.strip()
            try:
                # os.path.join keeps `filename` unchanged if it is absolute
                os.remove(os.path.join(self.workingdir, filename))
            except Exception:
                self.LogMessage("Couldn't purge " + filename)

        self.PLCStatus = PlcStatus.Empty

        # TODO: PLCObject restart

    @RunInMain
    def NewPLC(self, md5sum, plc_object, extrafiles):
        """Install a new PLC from previously uploaded blobs.

        md5sum: MD5 of the new PLC; it names the library file
            (md5sum + lib_ext) and is written to the MD5 file.
        plc_object: blob ID holding the PLC library.
        extrafiles: iterable of (file name, blob ID) pairs; each is written
            into workingdir and recorded in the extra-files log.

        Only allowed while the PLC is Stopped, Empty or Broken. Returns True
        if the new PLC was installed and loaded (status Stopped), False
        otherwise.
        """
        if self.PLCStatus not in [PlcStatus.Stopped, PlcStatus.Empty, PlcStatus.Broken]:
            return False

        NewFileName = md5sum + lib_ext
        extra_files_log = self._extra_files_log_path()

        new_PLC_filename = os.path.join(self.workingdir, NewFileName)

        self.UnLoadPLC()

        self.PurgePLC()

        self.LogMessage("NewPLC (%s)" % md5sum)

        try:
            # Create new PLC file
            self._BlobAsFile(plc_object, new_PLC_filename)

            # Then write the extra files, recording each name so PurgePLC can
            # remove them later. `with` guarantees the log is flushed and
            # closed, also when a blob is missing (it used to be left open).
            with open(extra_files_log, "w") as log:
                for fname, blobID in extrafiles:
                    fpath = os.path.join(self.workingdir, fname)
                    self._BlobAsFile(blobID, fpath)
                    log.write(fname + '\n')

            # Store new PLC filename based on md5 key
            with open(self._GetMD5FileName(), "w") as f:
                f.write(md5sum)
                f.flush()
                os.fsync(f.fileno())

            # Store new PLC filename
            self.CurrentPLCFilename = NewFileName
        except Exception:
            self.PLCStatus = PlcStatus.Broken
            self.StatusChange()
            PLCprint(traceback.format_exc())
            return False

        if self.LoadPLC():
            self.PLCStatus = PlcStatus.Stopped
        else:
            self.PLCStatus = PlcStatus.Broken
        self.StatusChange()

        return self.PLCStatus == PlcStatus.Stopped

    def MatchMD5(self, MD5):
        """Return True if the stored MD5 of the last transferred PLC equals MD5.

        Best effort: any failure to locate or read the MD5 file yields False.
        """
        try:
            # `with` closes the handle, which the original left open.
            with open(self._GetMD5FileName(), "r") as md5_file:
                last_md5 = md5_file.read()
        except Exception:
            return False
        return last_md5 == MD5

    @RunInMain
    def SetTraceVariablesList(self, idxs):
        """
        Replace the set of variables registered with the PLC debugger.

        DebugToken is bumped on every call, so GetTraceVariables rejects
        requests carrying an older token.

        idxs: iterable of (idx, iectype, force) triples. When `force` is not
            None it is packed with the ctypes type/pack function that
            TypeTranslator gives for `iectype` and passed by reference to
            _RegisterDebugVariable. A falsy `idxs` disables the debugger.

        Returns the new DebugToken when the variables were registered, None
        otherwise (debugger disabled, or _suspendDebug(False) returned
        non-zero).
        """
        self.DebugToken += 1
        if idxs:
            # suspend but dont disable
            if self._suspendDebug(False) == 0:
                # keep a copy of requested idx
                self._ResetDebugVariables()
                for idx, iectype, force in idxs:
                    if force is not None:
                        # NOTE(review): for an iectype missing from
                        # TypeTranslator, pack_func is None and the call
                        # below raises TypeError.
                        c_type, _unpack_func, pack_func = \
                            TypeTranslator.get(iectype,
                                               (None, None, None))
                        force = ctypes.byref(pack_func(c_type, force))
                    self._RegisterDebugVariable(idx, force)
                # starts the trace thread if needed and discards buffered traces
                self._TracesSwap()
                self._resumeDebug()
                return self.DebugToken
        else:
            self._suspendDebug(True)
        return None

    def _TracesSwap(self):
        self.LastSwapTrace = time()
        if self.TraceThread is None and self.PLCStatus == PlcStatus.Started:
            self.TraceThread = Thread(target=self.TraceThreadProc, name="PLCTrace")
            self.TraceThread.start()
        self.TraceLock.acquire()
        Traces = self.Traces
        self.Traces = []
        self.TraceLock.release()
        return Traces

    @RunInMain
    def GetTraceVariables(self, DebugToken):
        """Return (PLCStatus, buffered traces) for a caller holding the
        current DebugToken; any other token gets (PlcStatus.Broken, [])."""
        if DebugToken is None or DebugToken != self.DebugToken:
            return PlcStatus.Broken, []
        return self.PLCStatus, self._TracesSwap()

    def TraceThreadProc(self):
        """
        Body of the "PLCTrace" thread.

        While the PLC is Started, fetch debug data from the PLC library and
        append (tick, buffer) pairs to self.Traces for _TracesSwap to hand
        out. Stops when _GetDebugData returns non-zero, when the PLC leaves
        the Started state, or when Traces have not been swapped for more than
        3 seconds (pending traces are then dropped and the debugger
        disabled). self.TraceThread is reset to None on exit.
        """
        try:
            self._resumeDebug()  # Re-enable debugger
            while self.PLCStatus == PlcStatus.Started:
                tick = ctypes.c_uint32()
                size = ctypes.c_uint32()
                buff = ctypes.c_void_p()
                TraceBuffer = None

                self.PLClibraryLock.acquire()
                try:
                    res = self._GetDebugData(ctypes.byref(tick),
                                             ctypes.byref(size),
                                             ctypes.byref(buff))
                    if res == 0:
                        if size.value:
                            TraceBuffer = ctypes.string_at(buff.value, size.value)
                        self._FreeDebugData()
                finally:
                    # never leave the PLC library locked if a ctypes call raises
                    self.PLClibraryLock.release()

                # leave thread if GetDebugData isn't happy.
                if res != 0:
                    break

                if TraceBuffer is not None:
                    self.TraceLock.acquire()
                    try:
                        lT = len(self.Traces)
                        # Keep roughly 1 MiB of traces, estimated from the
                        # oldest buffer's size. Items are (tick, buffer)
                        # tuples, so measure item [1]: len() of the tuple is
                        # always 2 and made the cap ~500k entries.
                        if lT != 0 and lT * len(self.Traces[0][1]) > 1024 * 1024:
                            self.Traces.pop(0)
                        self.Traces.append((tick.value, TraceBuffer))
                    finally:
                        self.TraceLock.release()

                # TraceProc stops here if Traces not polled for 3 seconds
                traces_age = time() - self.LastSwapTrace
                if traces_age > 3:
                    self.TraceLock.acquire()
                    self.Traces = []
                    self.TraceLock.release()
                    self._suspendDebug(True)  # Disable debugger
                    break
        finally:
            # Clear the handle even on error so _TracesSwap can restart us.
            self.TraceThread = None

    def RemoteExec(self, script, **kwargs):
        """Execute `script` using the keyword arguments as its globals.

        The script may set a global named `returnVal` to pass a result back.
        Returns (0, returnVal or None) on success, or (-1, message) naming
        the failing line when the script raises.

        NOTE(review): this runs arbitrary code supplied by the caller by
        design; it must only be reachable by trusted peers.
        """
        # Was `*kwargs`: a tuple was handed to exec() as globals and to
        # .get(), so every call failed. Keyword arguments are what the body
        # expects.
        try:
            exec(script, kwargs)
        except Exception:
            _e_type, e_value, e_traceback = sys.exc_info()
            line_no = traceback.tb_lineno(get_last_traceback(e_traceback))
            lines = script.splitlines()
            # line_no can point outside the script (e.g. when the failure is
            # raised by exec itself); don't let that turn into an IndexError.
            source_line = lines[line_no - 1] if 0 < line_no <= len(lines) else ""
            return (-1, "RemoteExec script failed!\n\nLine %d: %s\n\t%s" %
                    (line_no, e_value, source_line))
        return (0, kwargs.get("returnVal", None))
示例#52
0
class Api(vnokcoin_spot_usd.OkcoinApi):
    """OKCoin API implementation (BTC/USD spot) used by the gateway."""

    #----------------------------------------------------------------------
    def __init__(self, gateway):
        """Constructor.

        gateway: owning gateway; receives ticks, orders, trades, positions,
            accounts, errors and logs through its on* callbacks.
        """
        super(Api, self).__init__()

        self.gateway = gateway  # gateway object
        self.gatewayName = gateway.gatewayName  # gateway name

        self.cbDict = {}  # channel -> callback, used by onMessage
        self.tickDict = {}  # vtSymbol -> VtTickData
        self.orderDict = {}

        self.lastOrderID = ''
        self.orderCondition = Condition()
        # Set by onSendOrder while holding orderCondition. sendOrder waits for
        # it, so a reply arriving before sendOrder starts waiting is not lost.
        self._orderReplied = False
        self.trade_password = False

        self.tradeFlag = False
        self.tradeFlag_2 = True
        self.logger = rwLoggerFunction()

    #----------------------------------------------------------------------
    def qryInstruments(self):
        """Query contracts."""
        params = {'accountId': self.accountId}
        self.getInstruments(params)

    # ----------------------------------------------------------------------
    def qryGenerateCnyContract(self):
        """Push the locally generated contracts to the gateway."""
        for contract in self.generateCnyContract():
            contract.gatewayName = self.gatewayName
            self.gateway.onContract(contract)

    # ----------------------------------------------------------------------
    def onMessage(self, ws, evt):
        """Message push: decode the JSON and dispatch it to the callback
        registered in cbDict for its channel."""
        data = json.loads(evt)
        channel = data[RESPONSE_CHANNEL]
        callback = self.cbDict[channel]
        callback(data)

    #----------------------------------------------------------------------
    def onError(self, evt):
        """Error push: forward `evt` to the gateway as a VtErrorData."""
        error = VtErrorData()
        error.gatewayName = self.gatewayName
        error.errorMsg = str(evt)
        self.gateway.onError(error)

    #----------------------------------------------------------------------
    def onClose(self, ws):
        """Connection closed."""
        # The link is gone: report it as disconnected (was wrongly True).
        self.gateway.connected = False
        self.writeLog(u'服务器连接断开')

    #----------------------------------------------------------------------
    def onOpen(self, ws):
        """Connection opened; nothing to do."""
        pass

    #----------------------------------------------------------------------
    def writeLog(self, content):
        """Send `content` to the gateway log."""
        log = VtLogData()
        log.gatewayName = self.gatewayName
        log.logContent = content
        self.gateway.onLog(log)

    #----------------------------------------------------------------------
    def initCallback(self):
        """Initialise callbacks; currently a no-op."""
        pass

    #----------------------------------------------------------------------
    def generateSpecificContract(self, contract, symbol):
        """Return a copy of `contract` with symbol, vtSymbol and name set to
        `symbol`."""
        new = copy(contract)
        new.symbol = symbol
        new.vtSymbol = symbol
        new.name = symbol
        return new

    #----------------------------------------------------------------------
    def generateCnyContract(self):
        """Build the list of traded contracts.

        Despite the name, the only contract built is BTC_USD_SPOT.
        """
        contract = VtContractData()
        contract.exchange = EXCHANGE_NAME
        contract.productClass = PRODUCT_SPOT
        contract.size = 1
        contract.priceTick = 0.01
        return [self.generateSpecificContract(contract, BTC_USD_SPOT)]

    # ----------------------------------------------------------------------
    def onTicker(self, data):
        """Ticker push: update the cached BTC_USD_SPOT tick and push a copy
        to the gateway. Ignored when `data` has no 'ticker' entry."""
        if 'ticker' not in data:
            return
        ticker = data['ticker']
        symbol = BTC_USD_SPOT
        vtSymbol = symbol
        if vtSymbol not in self.tickDict:
            tick = VtTickData()
            tick.symbol = symbol
            tick.vtSymbol = vtSymbol
            tick.gatewayName = self.gatewayName
            tick.exchange = EXCHANGE_NAME
            self.tickDict[vtSymbol] = tick
        else:
            tick = self.tickDict[vtSymbol]

        tick.highPrice = float(ticker['high'])
        tick.lowPrice = float(ticker['low'])
        tick.lastPrice = float(ticker['last'])
        tick.volume = float(ticker['vol'])
        tick.date, tick.time = generateDateTime(data['date'])

        newtick = copy(tick)
        self.gateway.onTick(newtick)

    # ----------------------------------------------------------------------
    def onDepth(self, data):
        """Depth (order book) push; not implemented, the data is ignored."""
        pass

    # ----------------------------------------------------------------------
    def getOrders(self):
        """Query open orders; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def onGetOrders(self, data):
        """Callback for getOrders; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def getTrades(self):
        """Request the status of the last sent order (lastOrderID); the
        reply is handled by onGetTrade."""
        self.tradeFlag = False

        ORDER_INFO_RESOURCE = "/api/v1/order_info.do"
        params = {
            'api_key': self.apiKey,
            'symbol': 'btc_usd',
            'order_id': self.lastOrderID
        }
        params['sign'] = buildMySign(params, self.secretKey)
        self.sendRequest(params, self.onGetTrade, ORDER_INFO_RESOURCE)

    #----------------------------------------------------------------------
    def onGetTrade(self, result):
        """Callback for order-info replies.

        Pushes a VtOrderData for every entry of result['orders'] and, for
        entries whose mapped status is TRADER_STATUS_DEAL, a VtTradeData as
        well. Returns None.
        """
        if 'orders' not in result:
            print('Trade Data Error')
            return

        ordersData = result['orders']
        if len(ordersData) == 0:
            print('no Trade Data ')
            return

        for data in ordersData:
            orderID = str(data['order_id'])
            orderType = str(data['type'])
            status = tradeStatusMap[str(data['status'])]

            order = VtOrderData()
            order.gatewayName = self.gatewayName
            order.symbol = BTC_USD_SPOT
            order.exchange = EXCHANGE_OKCOIN
            order.vtSymbol = '.'.join([order.symbol, order.exchange])
            order.orderID = orderID
            order.direction, _priceType = priceTypeMap[orderType]
            order.offset = tradeTypeMap[orderType]
            order.price = data['price']
            order.totalVolume = data['amount']
            order.tradeVolume = data['deal_amount']
            order.status = status
            order.vtOrderID = '.'.join([self.gatewayName, order.orderID])
            self.gateway.onOrder(order)

            if status == TRADER_STATUS_DEAL:
                trade = VtTradeData()
                trade.gatewayName = self.gatewayName
                trade.symbol = BTC_USD_SPOT
                trade.exchange = EXCHANGE_OKCOIN
                trade.vtSymbol = '.'.join([trade.symbol, trade.exchange])
                trade.orderID = orderID
                trade.tradeID = orderID  # the order ID doubles as trade ID
                trade.direction, _priceType = priceTypeMap[orderType]
                trade.offset = tradeTypeMap[orderType]
                trade.price = data['avg_price']
                trade.volume = float(data['deal_amount'])
                trade.status = status
                trade.vtOrderID = '.'.join([self.gatewayName, trade.orderID])
                trade.vtTradeID = '.'.join([self.gatewayName, trade.tradeID])
                self.gateway.onTrade(trade)

        self.writeLog(u'成交信息查询完成')

    # ----------------------------------------------------------------------
    def spotUserInfo(self):
        """Query the spot account; the reply is handled by onSpotUserInfo."""
        USERINFO_RESOURCE = "/api/v1/userinfo.do"
        params = {'api_key': self.apiKey}
        params['sign'] = buildMySign(params, self.secretKey)
        self.sendRequest(params, self.onSpotUserInfo, USERINFO_RESOURCE)

    # ----------------------------------------------------------------------
    def onSpotUserInfo(self, data):
        """Callback for spotUserInfo: push the btc/usd positions and the
        account balance to the gateway."""
        funds = data['info']['funds']

        # positions
        for symbol in ['btc', 'usd']:
            pos = VtPositionData()
            pos.gatewayName = self.gatewayName
            pos.symbol = symbol
            pos.vtSymbol = symbol
            pos.vtPositionName = symbol + "_" + self.gatewayName
            pos.direction = DIRECTION_NET
            pos.frozen = float(funds['freezed'][symbol])
            pos.position = float(funds['free'][symbol])
            self.gateway.onPosition(pos)

        account = VtAccountData()
        account.gatewayName = self.gatewayName
        account.accountID = self.gatewayName
        account.vtAccountID = account.accountID
        account.balance = float(funds['asset']['total'])
        self.gateway.onAccount(account)

    # ----------------------------------------------------------------------
    def sendOrder(self, params):
        """Send an order and block until onSendOrder has handled the reply.

        params: order request providing direction, orderStyle, price and
            volume.
        Returns the vtOrderID "<gatewayName>.<orderID>"; the order ID part is
        empty when the exchange rejected the order.
        """
        self.lastOrderID = ''

        if params.direction == DIRECTION_LONG:
            direction = 'buy'
        else:
            direction = 'sell'

        TRADE_RESOURCE = "/api/v1/trade.do"
        paramsDict = {
            'api_key': self.apiKey,
            'symbol': 'btc_usd',
            'type': direction
        }
        if params.orderStyle == 1:
            self.tradeFlag_2 = False
        if params.price:
            paramsDict['price'] = params.price
        if params.volume:
            paramsDict['amount'] = params.volume
        paramsDict['sign'] = buildMySign(paramsDict, self.secretKey)

        # Wait for the send callback to deliver the order ID. The flag is
        # reset before sending so a reply that arrives before we wait (or
        # synchronously, inside sendRequest) is still seen; a bare wait()
        # would block forever in that case.
        with self.orderCondition:
            self._orderReplied = False
        self.sendRequest(paramsDict, self.onSendOrder, TRADE_RESOURCE)
        with self.orderCondition:
            while not self._orderReplied:
                self.orderCondition.wait()

        vtOrderID = '.'.join([self.gatewayName, self.lastOrderID])
        return vtOrderID

    # ----------------------------------------------------------------------
    def onSendOrder(self, data):
        """Callback for sendOrder: record the order ID on success and wake
        the thread waiting in sendOrder."""
        if 'result' in data and data['result']:
            self.lastOrderID = str(data['order_id'])
            self.tradeFlag = True
            print(u'okcoin  onSendOrder Sucess:', self.lastOrderID)
        else:
            print(u'OKCOIN下单失败,请查询账户资金额度')

        print("okcoin onSendOrder start_2")
        # Reply handled (success or not): notify the sending thread.
        with self.orderCondition:
            self._orderReplied = True
            self.orderCondition.notify()

    # ----------------------------------------------------------------------
    def cancelOrder(self, params):
        """Cancel an order.

        params: request carrying the exchange orderID to cancel.
        Returns whatever sendRequest returns.
        """
        CANCEL_ORDER_RESOURCE = "/api/v1/cancel_order.do"
        request = {
            'api_key': self.apiKey,
            # this gateway trades btc_usd; 'btc_cny' targeted another market
            'symbol': 'btc_usd',
            'order_id': params.orderID
        }
        # `self.__secretkey` is name-mangled to `self._Api__secretkey`, which
        # is never set; every other signed request uses secretKey.
        request['sign'] = buildMySign(request, self.secretKey)
        return self.sendRequest(request, self.onCancelOrder,
                                CANCEL_ORDER_RESOURCE)

    # ----------------------------------------------------------------------
    # Interface called from strategies
    def getTrades_huotou(self, orderID):
        """Synchronously query order `orderID` and push it through onGetTrade.

        Returns False when the reply has no orders or reports failure.
        NOTE(review): otherwise it returns onGetTrade's result, which is
        always None; callers should not rely on a status value here.
        """
        ORDER_INFO_RESOURCE = "/api/v1/order_info.do"
        params = {
            'api_key': self.apiKey,
            'symbol': 'btc_usd',
            'order_id': orderID
        }
        params['sign'] = buildMySign(params, self.secretKey)
        result = httpPost(OKCOIN_HOST, ORDER_INFO_RESOURCE, params)

        if result['result'] and len(result['orders']) > 0:
            orderStatus = self.onGetTrade(result)
            return orderStatus
        return False

    # ----------------------------------------------------------------------
    def onCancelOrder(self, data):
        """Callback for cancelOrder; the reply is ignored."""
        pass

    # ----------------------------------------------------------------------
    def generateUsdContract(self):
        """Generate USD contract info; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def onSpotTrade(self, data):
        """Order report; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def onSpotCancelOrder(self, data):
        """Cancel report; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def spotSendOrder(self, req):
        """Send order; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def spotCancel(self, req):
        """Cancel order; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def onSpotSubTrades(self, data):
        """Trade subscription push; not implemented."""
        pass

    # ----------------------------------------------------------------------
    def onSpotOrderInfo(self, data):
        """Order info push; not implemented."""
        pass
示例#53
0
class MypyEventHandler(BaseThread):
    """File-system event handler that re-typechecks a modified file and the
    modules depending on it using a pool of MypyWorker threads.

    Tasks are shared with the workers through `task_pool`, guarded by
    `task_cond`.
    """
    def __init__(self, dep_graph, queueing_handler, file_cache, compact,
                 num_workers):
        # type: (ModuleGraph, MypyQueueingHandler, MypyFileCache, bool, int) -> None
        self.dep_graph = dep_graph
        self.worker_pool = []  # type: List[MypyWorker]
        self.task_pool = []  # type: List[MypyTask]
        self.task_cond = Condition()
        self.queueing_handler = queueing_handler
        self.file_cache = file_cache
        self.compact = compact
        self.num_workers = num_workers
        super(MypyEventHandler, self).__init__()

    def on_deleted(self, event):
        # type: (FileSystemEvent) -> None
        pass

    def on_created(self, event):
        # type: (FileSystemEvent) -> None
        pass

    def _add_task(self, task, index=None):
        # type: (MypyTask, Optional[int]) -> None
        """Queue `task` (at `index`, or at the end) unless an equal task is
        already queued."""
        if task in self.task_pool:
            return
        if index is None:
            self.task_pool.append(task)
        else:
            self.task_pool.insert(index, task)

    def _find_modified_module(self, src_path):
        # type: (str) -> Optional[Module]
        """Return the graph module whose filename is `src_path`, or None."""
        for module_ in self.dep_graph.listModules():
            if module_.filename == src_path:
                return module_
        return None

    def _find_dependencies(self, root_module):
        # type: (Module) -> Set[str]
        """Return absolute paths of the files to re-check after
        `root_module` changed: its importers, transitively through
        `__init__` modules, plus the modified modules themselves."""
        modified_modules = {root_module}
        dependencies_to_check = set()

        def check_module(mod):
            # type: (Module) -> None
            # If it's an __init__ file then we might have changed the interface to
            # this module, so add it to the set of modified modules. This will cause
            # us to iterate all modules again.
            if mod.modname.endswith('__init__'):
                modified_modules.add(mod)
            else:
                dependencies_to_check.add(os.path.abspath(mod.filename))

        while True:
            modified_modules_size = len(modified_modules)

            for module_ in self.dep_graph.listModules():
                # modified_modules only grows right before the `break` below,
                # so the name set can be built once per module rather than
                # once per import.
                modified_modnames = {mod.modname for mod in modified_modules}
                for import_name in module_.imports:
                    if (import_name in modified_modnames or
                            '{}.__init__'.format(import_name) in modified_modnames):
                        check_module(module_)
                        break

            # Run until we haven't added any new modules to the set of modified modules.
            # We do this to catch the following cases:
            # (1) an init exports another init's methods (uncommon)
            # (2) we have a.py (modified) <- __init__.py <- b.py and we happen to scan b.py first (more common)
            if modified_modules_size == len(modified_modules):
                break

        dependencies_to_check.update(
            os.path.abspath(module_.filename) for module_ in modified_modules)
        return dependencies_to_check

    def _ensure_workers(self):
        # type: () -> None
        """Start workers until the pool holds num_workers of them."""
        while len(self.worker_pool) < self.num_workers:
            worker = MypyWorker(self.task_pool, self.task_cond,
                                self.file_cache, self.compact)
            self.worker_pool.append(worker)
            worker.start()

    def _disable_workers(self):
        # type: () -> None
        # Prevent workers from consuming any tasks in the queue until we're ready.
        for worker in self.worker_pool:
            worker.run_tasks = False

    def _enable_workers(self):
        # type: () -> None
        # Mark all workers to start running tasks and interrupt any
        # workers with tasks that will need to be re-run.
        # task_cond must be held: notify_all() requires it.
        for worker in self.worker_pool:
            worker.run_tasks = True
            if worker.current_task is not None and worker.current_task in self.task_pool:
                worker.current_task.interrupt()
        self.task_cond.notify_all()

    def _wait_until_tasks_completed(self):
        # type: () -> None
        """With task_cond held, wait until every queued task has been taken
        and finished (printing progress), or until new file-system events
        arrive."""
        start_size = len(self.task_pool)
        while self.task_pool and not self.queueing_handler.has_new_events:
            total_completed_tasks = start_size - len(self.task_pool)
            sys.stdout.write('  {}/{}\r'.format(total_completed_tasks,
                                                start_size))
            sys.stdout.flush()
            self.task_cond.wait()

        if self.queueing_handler.has_new_events:
            print('Detected new changes, interrupting...')
        else:
            # Even though all tasks have been pulled from the task_pool,
            # they might not have been completed, so we have to wait until all
            # workers have cleared their current_task field.
            while any(worker.current_task is not None
                      for worker in self.worker_pool):
                self.task_cond.wait()

    def on_modified(self, event):
        # type: (FileSystemEvent) -> None
        """Re-typecheck the modified file first, then its dependents."""
        print_divider('TYPECHECKING', newline_before=True)

        # `with` releases task_cond even if dependency lookup or waiting
        # raises; a leaked lock would block every worker forever.
        with self.task_cond:
            modified_module = self._find_modified_module(event.src_path)
            if modified_module is None:
                print('Unable to find module for modified file {}'.format(
                    event.src_path))
                dependencies_to_check = {event.src_path}
            else:
                dependencies_to_check = self._find_dependencies(modified_module)

                # Add the modified file first so it's the first one to be checked.
                self._add_task(MypyTask(os.path.abspath(modified_module.filename)),
                               index=0)

            for filename in dependencies_to_check:
                self._add_task(MypyTask(filename))

            self._ensure_workers()
            self._enable_workers()
            self._wait_until_tasks_completed()
            self._disable_workers()

            print_divider('DONE')

    def run(self):
        # type: () -> None
        """Feed file-system events to this handler forever."""
        while True:
            self.queueing_handler.next_event(self)
示例#54
0
class TradeAgent(object):
    """Blocking request/response wrapper around ``MyTraderApi``.

    The constructor starts three background threads:

    * ``init`` registers the front, subscribes to the private/public topics,
      calls ``Init()`` and then ``Join()`` on the API;
    * ``process_rsp_qry`` collects query responses into ``qry_results``;
    * ``process_data`` merges pushed orders/trades into the local caches and
      forwards every pushed item to ``callback``.

    The ``query_*`` methods send a request and block (up to ``timeout``
    seconds) until the last response for that request id has arrived.
    """

    def __init__(self, td_front, broker_id, user_id, password):
        """Create the agent and start its background threads.

        td_front: front address passed to ``RegisterFront``.
        broker_id, user_id, password: passed to ``MyTraderApi``; broker_id and
            user_id are also used as BrokerID/InvestorID in every request.
        """
        self.request_id = 1
        self.broker_id = broker_id
        self.user_id = user_id
        self.password = password
        self.callback = None

        # qry_cond guards qry_results, which maps request id ->
        # {'results': [...], 'status': <last response received>}.
        self.qry_cond = Condition()
        self.rsp_qry_queue = Queue()
        self.cond = Condition()
        self.data_queue = Queue()
        self.qry_results = {}

        # Caches filled by the query_* methods; None means "not queried yet".
        self.instrument_dict = {}
        self.position_dict = None
        self.trading_account = None
        self.trade_dict = None
        self.order_dict = None

        # Filled in by ready() from the login response.
        self.session_id = None
        self.front_id = None
        self.max_order_ref = None

        self.api = MyTraderApi(broker_id, user_id, password,
                               self.rsp_qry_queue, self.data_queue)
        # NOTE(review): these threads are not marked daemon and the two
        # processing loops never return, so they keep the process alive;
        # confirm that is intended.
        thread = Thread(target=self.init, args=(td_front, ))
        thread.start()
        thread_rsp_qry = Thread(target=self.process_rsp_qry)
        thread_rsp_qry.start()
        thread_data = Thread(target=self.process_data)
        thread_data.start()

    def init(self, td_front):
        """Connect the API to ``td_front`` and block in ``Join()``."""
        self.api.RegisterFront(td_front)
        self.api.SubscribePrivateTopic(2)
        self.api.SubscribePublicTopic(2)
        self.api.Init()
        self.api.Join()

    def process_rsp_qry(self):
        """Collect query responses forever, waking waiters on the last one.

        Queue items are expected to be ``(request_id, result, is_last)``
        tuples; anything else is ignored. Falsy results are not stored.
        """
        while True:
            msg = self.rsp_qry_queue.get()
            if not isinstance(msg, tuple):
                continue
            request_id, result, is_last = msg
            with self.qry_cond:
                entry = self.qry_results.setdefault(
                    request_id, {'results': [], 'status': False})
                if result:
                    entry['results'].append(result)
                if is_last:
                    entry['status'] = True
                    self.qry_cond.notify_all()

    def process_data(self):
        """Consume pushed data forever, updating the order/trade caches.

        Orders and trades are merged into ``order_dict``/``trade_dict`` only
        once those caches exist (after query_order/query_trade); a later query
        replaces the cache with a full snapshot anyway. Previously a push that
        arrived before the first query raised TypeError on ``None[...]`` and
        killed this thread. Every item is forwarded to ``callback`` if set.
        """
        while True:
            data = self.data_queue.get()
            if isinstance(data, ApiStruct.Order):
                if (self.order_dict is not None and data.ExchangeID
                        and data.TraderID and data.OrderLocalID):
                    self.order_dict[(data.ExchangeID, data.TraderID,
                                     data.OrderLocalID)] = data
            elif isinstance(data, ApiStruct.Trade):
                if (self.trade_dict is not None and data.ExchangeID
                        and data.TradeID):
                    self.trade_dict[(data.ExchangeID, data.TradeID)] = data
            if self.callback:
                self.callback(data)

    def ready(self, timeout=3):
        """Poll once per second until the API is ready or ``timeout`` elapses.

        On success, caches front id, session id and max order ref from the
        login response and returns True; otherwise returns False.
        """
        request_time = time.time()
        while not self.api.ready and time.time() < request_time + timeout:
            time.sleep(1)
        if self.api.ready and self.api.user_login:
            login = self.api.user_login
            self.front_id = login.FrontID
            self.session_id = login.SessionID
            self.max_order_ref = int(login.MaxOrderRef)
            print('trade agent initial succeed')
            return True
        print('trade agent inital failed')
        return False

    def load_instruments(self, instrument_ids):
        """Query each instrument in turn, sleeping 1.1 s between requests."""
        # NOTE(review): the pause presumably respects a server-side query
        # rate limit; confirm.
        for instrument_id in instrument_ids:
            self.query_instrument(instrument_id)
            time.sleep(1.1)

    def _next_request_id(self):
        """Advance the request-id counter and return the new value."""
        # NOTE(review): ``+=`` on an attribute is not atomic; concurrent
        # callers could receive the same id.
        self.request_id += 1
        return self.request_id

    def _get_results(self, timeout, request_id):
        """Wait up to ``timeout`` seconds for all responses to ``request_id``.

        Returns the list of collected results, or None if the last response
        did not arrive in time.
        """
        with self.qry_cond:
            # The request has already been sent, so process_rsp_qry may have
            # registered (or even completed) this entry; never overwrite it,
            # or a fast response would be lost and reported as a timeout.
            entry = self.qry_results.setdefault(
                request_id, {'results': [], 'status': False})
            request_time = time.time()
            while (not entry['status']
                   and request_time + timeout > time.time()):
                self.qry_cond.wait(0.1)
            # Remove under the lock so the processing thread never sees a
            # half-removed entry.
            self.qry_results.pop(request_id, None)
        if not entry['status']:
            return None
        return entry['results']

    def _query(self, api_call, req, error_msg, timeout):
        """Send a request via ``api_call`` and wait for its results.

        api_call: bound ``self.api.Req*`` method, called as (req, request_id).
        error_msg: printed, followed by the return code, if sending fails.
        Returns the result list, or None if sending failed or timed out.
        """
        request_id = self._next_request_id()
        ret = api_call(req, request_id)
        if ret != 0:
            print('{} {}'.format(error_msg, ret))
            return None
        return self._get_results(timeout, request_id)

    def query_settlement_info(self, timeout=3):
        """Return the settlement-info results, or None on failure/timeout."""
        req = ApiStruct.QrySettlementInfo(BrokerID=self.broker_id,
                                          InvestorID=self.user_id)
        return self._query(self.api.ReqQrySettlementInfo, req,
                           'query settlement info failed', timeout)

    def query_settlement_info_confirm(self, timeout=3):
        """Return existing settlement confirmations, or None on failure."""
        req = ApiStruct.QrySettlementInfoConfirm(BrokerID=self.broker_id,
                                                 InvestorID=self.user_id)
        results = self._query(self.api.ReqQrySettlementInfoConfirm, req,
                              'query settlement info confirm failed', timeout)
        if results:
            print(results[0])
            print('settlement info already confirmed.')
        return results

    def query_order(self, timeout=3):
        """Refresh ``order_dict`` keyed by (ExchangeID, TraderID, OrderLocalID)."""
        req = ApiStruct.QryOrder(BrokerID=self.broker_id,
                                 InvestorID=self.user_id)
        results = self._query(self.api.ReqQryOrder, req,
                              'query order failed', timeout)
        if results is None:
            return None
        self.order_dict = {(result.ExchangeID, result.TraderID,
                            result.OrderLocalID): result
                           for result in results}
        return self.order_dict

    def query_trade(self, timeout=3):
        """Refresh ``trade_dict`` keyed by (ExchangeID, TradeID)."""
        req = ApiStruct.QryTrade(BrokerID=self.broker_id,
                                 InvestorID=self.user_id)
        results = self._query(self.api.ReqQryTrade, req,
                              'query trade failed', timeout)
        if results is None:
            return None
        self.trade_dict = {(result.ExchangeID, result.TradeID): result
                           for result in results}
        return self.trade_dict

    def query_position(self, instrument_id=None, timeout=3):
        """Refresh ``position_dict``: (InstrumentID, PosiDirection) -> [records].

        Records missing InstrumentID, PosiDirection or Position are skipped.
        """
        req = ApiStruct.QryInvestorPosition(BrokerID=self.broker_id,
                                            InvestorID=self.user_id)
        if instrument_id:
            req.InstrumentID = instrument_id
        results = self._query(self.api.ReqQryInvestorPosition, req,
                              'query position failed', timeout)
        if results is None:
            return None
        self.position_dict = {}
        for result in results:
            if (result.InstrumentID is None or result.PosiDirection is None
                    or result.Position is None):
                continue
            key = (result.InstrumentID, result.PosiDirection)
            self.position_dict.setdefault(key, []).append(result)
        return self.position_dict

    def query_position_detail(self, timeout=3):
        """Return position-detail results, or None on failure/timeout."""
        req = ApiStruct.QryInvestorPositionDetail(BrokerID=self.broker_id,
                                                  InvestorID=self.user_id)
        return self._query(self.api.ReqQryInvestorPositionDetail, req,
                           'query position detail failed', timeout)

    def query_trading_account(self, timeout=3):
        """Refresh and return ``trading_account`` (first result), or None."""
        req = ApiStruct.QryTradingAccount(BrokerID=self.broker_id,
                                          InvestorID=self.user_id)
        results = self._query(self.api.ReqQryTradingAccount, req,
                              'query trading account failed', timeout)
        if not results:
            return None
        self.trading_account = results[0]
        return self.trading_account

    def query_instrument(self, instrument_id, timeout=3):
        """Query one instrument, cache it in ``instrument_dict`` and return it."""
        req = ApiStruct.QryInstrument(InstrumentID=instrument_id)
        results = self._query(
            self.api.ReqQryInstrument, req,
            'query instrument {} failed'.format(instrument_id), timeout)
        if not results:
            return None
        self.instrument_dict[instrument_id] = results[0]
        print(results[0])
        return results[0]

    def query_depth_market_data(self, instrument_id, timeout=3):
        """Return the first depth-market-data result, or None."""
        req = ApiStruct.QryDepthMarketData(InstrumentID=instrument_id)
        results = self._query(
            self.api.ReqQryDepthMarketData, req,
            'query {} depth market data failed'.format(instrument_id), timeout)
        if not results:
            return None
        return results[0]

    def settlement_info_confirm(self, timeout=3):
        """Confirm settlement info; return True once the request completes.

        An empty (but completed) response previously raised IndexError on
        ``results[0]``; it now counts as success without printing a record.
        """
        req = ApiStruct.SettlementInfoConfirm(BrokerID=self.broker_id,
                                              InvestorID=self.user_id)
        results = self._query(self.api.ReqSettlementInfoConfirm, req,
                              'settlement info confirm failed', timeout)
        if results is None:
            return False
        if results:
            print(results[0])
        print('settlement info confirm succeed.')
        return True

    def _make_order(self, instrument_id, direction, volume):
        """Build an opening speculative order, consuming a new order ref.

        Requires ``max_order_ref`` to have been set by ready().
        """
        self.max_order_ref += 1
        return ApiStruct.InputOrder(
            BrokerID=self.broker_id,
            InvestorID=self.user_id,
            InstrumentID=instrument_id,
            OrderRef=str(self.max_order_ref),
            Direction=direction,
            CombOffsetFlag=ApiStruct.OF_Open,
            CombHedgeFlag=ApiStruct.HF_Speculation,
            VolumeTotalOriginal=volume,
            ContingentCondition=ApiStruct.CC_Immediately,
            VolumeCondition=ApiStruct.VC_AV,
            MinVolume=1,
            ForceCloseReason=ApiStruct.FCC_NotForceClose,
            IsAutoSuspend=0,
            UserForceClose=0)

    def _set_price(self, order, direction, limit_price, time_condition,
                   market_data):
        """Set LimitPrice/TimeCondition on ``order``; False if no price.

        A ``limit_price`` of 0 means "use the opposite side's best quote from
        ``market_data``" with an immediate-or-cancel time condition.
        """
        price = None
        if limit_price == 0:
            if market_data:
                if direction == ApiStruct.D_Buy:
                    # Take the counterparty price (best ask).
                    price = market_data.AskPrice1
                else:
                    # Take the counterparty price (best bid).
                    price = market_data.BidPrice1
                time_condition = ApiStruct.TC_IOC
        else:
            price = limit_price

        if price is None:
            return False

        order.LimitPrice = price
        order.TimeCondition = time_condition
        return True

    def _insert_order(self, instrument_id, direction, offset_flag, volume,
                      limit_price, time_condition, market_data):
        """Build, price and submit a limit order.

        Returns the submitted order, or None (nothing sent) when no price
        could be determined. The order ref is consumed either way.
        """
        order = self._make_order(instrument_id, direction, volume)
        order.CombOffsetFlag = offset_flag
        order.OrderPriceType = ApiStruct.OPT_LimitPrice
        if not self._set_price(order, direction, limit_price, time_condition,
                               market_data):
            return None
        self.api.ReqOrderInsert(order, self._next_request_id())
        return order

    def buy(self,
            instrument_id,
            volume,
            limit_price=0,
            time_condition=ApiStruct.TC_GFD,
            market_data=None):
        """Open a long position (buy, OF_Open)."""
        return self._insert_order(instrument_id, ApiStruct.D_Buy,
                                  ApiStruct.OF_Open, volume, limit_price,
                                  time_condition, market_data)

    def sell(self,
             instrument_id,
             volume,
             limit_price=0,
             time_condition=ApiStruct.TC_GFD,
             market_data=None):
        """Close a long position (sell, OF_Close)."""
        return self._insert_order(instrument_id, ApiStruct.D_Sell,
                                  ApiStruct.OF_Close, volume, limit_price,
                                  time_condition, market_data)

    def sell_today(self,
                   instrument_id,
                   volume,
                   limit_price=0,
                   time_condition=ApiStruct.TC_GFD,
                   market_data=None):
        """Close today's long position (sell, OF_CloseToday)."""
        return self._insert_order(instrument_id, ApiStruct.D_Sell,
                                  ApiStruct.OF_CloseToday, volume, limit_price,
                                  time_condition, market_data)

    def short(self,
              instrument_id,
              volume,
              limit_price=0,
              time_condition=ApiStruct.TC_GFD,
              market_data=None):
        """Open a short position (sell, OF_Open)."""
        return self._insert_order(instrument_id, ApiStruct.D_Sell,
                                  ApiStruct.OF_Open, volume, limit_price,
                                  time_condition, market_data)

    def cover(self,
              instrument_id,
              volume,
              limit_price=0,
              time_condition=ApiStruct.TC_GFD,
              market_data=None):
        """Close a short position (buy, OF_Close)."""
        return self._insert_order(instrument_id, ApiStruct.D_Buy,
                                  ApiStruct.OF_Close, volume, limit_price,
                                  time_condition, market_data)

    def cover_today(self,
                    instrument_id,
                    volume,
                    limit_price=0,
                    time_condition=ApiStruct.TC_GFD,
                    market_data=None):
        """Close today's short position (buy, OF_CloseToday)."""
        return self._insert_order(instrument_id, ApiStruct.D_Buy,
                                  ApiStruct.OF_CloseToday, volume, limit_price,
                                  time_condition, market_data)

    def open(self,
             instrument_id,
             posi_direction,
             volume,
             limit_price=0,
             time_condition=ApiStruct.TC_GFD,
             market_data=None):
        """Open long (PD_Long) or short (PD_Short); None for other values."""
        if posi_direction == ApiStruct.PD_Long:
            return self.buy(instrument_id, volume, limit_price, time_condition,
                            market_data)
        elif posi_direction == ApiStruct.PD_Short:
            return self.short(instrument_id, volume, limit_price,
                              time_condition, market_data)

    def close(self,
              instrument_id,
              posi_direction,
              volume,
              limit_price=0,
              time_condition=ApiStruct.TC_GFD,
              market_data=None):
        """Close long (PD_Long) or short (PD_Short); None for other values."""
        if posi_direction == ApiStruct.PD_Long:
            return self.sell(instrument_id, volume, limit_price,
                             time_condition, market_data)
        elif posi_direction == ApiStruct.PD_Short:
            return self.cover(instrument_id, volume, limit_price,
                              time_condition, market_data)

    def close_today(self,
                    instrument_id,
                    posi_direction,
                    volume,
                    limit_price=0,
                    time_condition=ApiStruct.TC_GFD,
                    market_data=None):
        """Close today's long/short position; None for other directions."""
        if posi_direction == ApiStruct.PD_Long:
            return self.sell_today(instrument_id, volume, limit_price,
                                   time_condition, market_data)
        elif posi_direction == ApiStruct.PD_Short:
            return self.cover_today(instrument_id, volume, limit_price,
                                    time_condition, market_data)

    def refer(self, instrument_id, exchange_id, order_sys_id):
        """Send a delete action for an order identified by exchange/sys id.

        Returns True if ``ReqOrderAction`` returned 0.
        """
        order_action = ApiStruct.InputOrderAction(
            BrokerID=self.broker_id,
            InvestorID=self.user_id,
            InstrumentID=instrument_id,
            ActionFlag=ApiStruct.AF_Delete,
            ExchangeID=exchange_id,
            OrderSysID=order_sys_id)
        return self.api.ReqOrderAction(order_action,
                                       self._next_request_id()) == 0

    def refer_local(self, instrument_id, order_ref):
        """Send a delete action for an order identified by front/session/ref.

        Returns True if ``ReqOrderAction`` returned 0.
        """
        order_action = ApiStruct.InputOrderAction(
            BrokerID=self.broker_id,
            InvestorID=self.user_id,
            InstrumentID=instrument_id,
            ActionFlag=ApiStruct.AF_Delete,
            FrontID=self.front_id,
            SessionID=self.session_id,
            OrderRef=order_ref)
        return self.api.ReqOrderAction(order_action,
                                       self._next_request_id()) == 0
示例#55
0
class PooledDB:
    """Pool for DB-API 2 connections.

    After you have created the connection pool, you can use
    connection() to get pooled, steady DB-API 2 connections.

    """

    version = __version__

    def __init__(self,
                 creator,
                 mincached=0,
                 maxcached=0,
                 maxshared=0,
                 maxconnections=0,
                 blocking=False,
                 maxusage=None,
                 setsession=None,
                 failures=None,
                 *args,
                 **kwargs):
        """Set up the DB-API 2 connection pool.

        creator: either an arbitrary function returning new DB-API 2
            connection objects or a DB-API 2 compliant database module
        mincached: initial number of idle connections in the pool
            (0 means no connections are made at startup)
        maxcached: maximum number of idle connections in the pool
            (0 or None means unlimited pool size)
        maxshared: maximum number of shared connections
            (0 or None means all connections are dedicated)
            When this maximum number is reached, connections are
            shared if they have been requested as shareable.
        maxconnections: maximum number of connections generally allowed
            (0 or None means an arbitrary number of connections)
        blocking: determines behavior when exceeding the maximum
            (if this is set to true, block and wait until the number of
            connections decreases, otherwise an error will be reported)
        maxusage: maximum number of reuses of a single connection
            (0 or None means unlimited reuse)
            When this maximum usage number of the connection is reached,
            the connection is automatically reset (closed and reopened).
        setsession: optional list of SQL commands that may serve to prepare
            the session, e.g. ["set datestyle to ...", "set time zone ..."]
        failures: an optional exception class or a tuple of exception classes
            for which the connection failover mechanism shall be applied,
            if the default (OperationalError, InternalError) is not adequate
        args, kwargs: the parameters that shall be passed to the creator
            function or the connection constructor of the DB-API 2 module

        Raises NotSupportedError if the creator is not thread-safe, and (in
        non-blocking mode) TooManyConnections if mincached exceeds the limit.
        """
        try:
            threadsafety = creator.threadsafety
        except AttributeError:
            try:
                if not callable(creator.connect):
                    raise AttributeError
            except AttributeError:
                # A plain factory function: allow sharing connections.
                threadsafety = 2
            else:
                # Has connect() but does not declare threadsafety, so it is
                # not accepted as a DB-API 2 module.
                threadsafety = 0
        if not threadsafety:
            raise NotSupportedError("Database module is not thread-safe.")
        self._creator = creator
        self._args, self._kwargs = args, kwargs
        self._maxusage = maxusage
        self._setsession = setsession
        self._failures = failures
        # None is documented as equivalent to 0 for the size limits. maxshared
        # must be normalized too: it is compared with maxconnections below,
        # and ``int < None`` raises TypeError on Python 3.
        if mincached is None:
            mincached = 0
        if maxcached is None:
            maxcached = 0
        if maxshared is None:
            maxshared = 0
        if maxconnections is None:
            maxconnections = 0
        if maxcached:
            if maxcached < mincached:
                maxcached = mincached
            self._maxcached = maxcached
        else:
            self._maxcached = 0
        if threadsafety > 1 and maxshared:
            self._maxshared = maxshared
            self._shared_cache = []  # the cache for shared connections
        else:
            self._maxshared = 0
        if maxconnections:
            if maxconnections < maxcached:
                maxconnections = maxcached
            if maxconnections < maxshared:
                maxconnections = maxshared
            self._maxconnections = maxconnections
        else:
            self._maxconnections = 0
        self._idle_cache = []  # the actual pool of idle connections
        self._condition = Condition()
        if not blocking:
            # Non-blocking mode: any wait for a free slot becomes an error.
            def wait():
                raise TooManyConnections

            self._condition.wait = wait
        self._connections = 0
        # Establish an initial number of idle database connections:
        idle = [self.dedicated_connection() for _ in range(mincached)]
        while idle:
            idle.pop().close()

    def steady_connection(self):
        """Get a steady, unpooled DB-API 2 connection."""
        return connect(self._creator, self._maxusage, self._setsession,
                       self._failures, True, *self._args, **self._kwargs)

    def connection(self, shareable=True):
        """Get a steady, cached DB-API 2 connection from the pool.

        If shareable is set and the underlying DB-API 2 allows it,
        then the connection may be shared with other threads.

        """
        if shareable and self._maxshared:
            with self._condition:
                while (not self._shared_cache and self._maxconnections
                       and self._connections >= self._maxconnections):
                    self._condition.wait()
                if len(self._shared_cache) < self._maxshared:
                    # shared cache is not full, get a dedicated connection
                    try:  # first try to get it from the idle cache
                        con = self._idle_cache.pop(0)
                    except IndexError:  # else get a fresh connection
                        con = self.steady_connection()
                    con = SharedDBConnection(con)
                    self._connections += 1
                else:  # shared cache full or no more connections allowed
                    self._shared_cache.sort()  # least shared connection first
                    con = self._shared_cache.pop(0)  # get it
                    con.share()  # increase share of this connection
                # put the connection (back) into the shared cache
                self._shared_cache.append(con)
                self._condition.notify()
            con = PooledSharedDBConnection(self, con)
        else:  # try to get a dedicated connection
            with self._condition:
                while (self._maxconnections
                       and self._connections >= self._maxconnections):
                    self._condition.wait()
                # connection limit not reached, get a dedicated connection
                try:  # first try to get it from the idle cache
                    con = self._idle_cache.pop(0)
                except IndexError:  # else get a fresh connection
                    con = self.steady_connection()
                con = PooledDedicatedDBConnection(self, con)
                self._connections += 1
        return con

    def dedicated_connection(self):
        """Alias for connection(shareable=False)."""
        return self.connection(False)

    def unshare(self, con):
        """Decrease the share of a connection in the shared cache."""
        with self._condition:
            con.unshare()
            shared = con.shared
            if not shared:  # connection is idle,
                try:  # so try to remove it
                    self._shared_cache.remove(con)  # from shared cache
                except ValueError:
                    pass  # pool has already been closed
        if not shared:  # connection has become idle,
            self.cache(con.con)  # so add it to the idle cache

    def cache(self, con):
        """Put a dedicated connection back into the idle cache."""
        with self._condition:
            if not self._maxcached or len(self._idle_cache) < self._maxcached:
                # The idle cache is not full, so keep the connection, but
                # roll back first so uncommitted work cannot be committed
                # unintentionally by another thread.
                try:
                    con.rollback()
                except Exception:
                    # No transaction / not supported: deliberately ignored.
                    pass
                self._idle_cache.append(con)
            else:  # the idle cache is already full,
                con.close()  # so close the connection
            self._connections -= 1
            self._condition.notify()

    def close(self):
        """Close all connections in the pool."""
        with self._condition:
            while self._idle_cache:  # close all idle connections
                con = self._idle_cache.pop(0)
                try:
                    con.close()
                except Exception:
                    pass
            if self._maxshared:  # close all shared connections
                while self._shared_cache:
                    con = self._shared_cache.pop(0).con
                    try:
                        con.close()
                    except Exception:
                        pass
                    self._connections -= 1
            self._condition.notifyAll()

    def __del__(self):
        """Delete the pool."""
        try:
            self.close()
        except Exception:
            pass
示例#56
0
    test = 1
    # 2. In-process, thread-safe queue: queue.Queue (collections.deque is
    #    also thread-safe).
    # Imported under a distinct name: the multiprocessing import below also
    # binds ``Queue``. Inside a function, that later import makes ``Queue``
    # local, so using it here would raise UnboundLocalError.
    from queue import Queue as ThreadQueue
    detail_url_queue = ThreadQueue(maxsize=1000)
    # Thread synchronization: make accesses to the same variable happen in order.
    from threading import Lock, RLock, Condition

    lock = Lock()
    lock.acquire()
    lock.release()
    lock = RLock()  # re-entrant lock: the owner may call acquire() repeatedly
    lock.acquire()
    lock.release()
    cond = Condition()
    cond.acquire()
    # No other thread exists to notify() this condition, so an unbounded
    # wait() would block forever; bound it so the demo can continue.
    cond.wait(timeout=0.1)
    cond.notify()
    cond.release()

    # Inter-process communication
    from multiprocessing import Process, Queue, Pool, Manager, Pipe

    # Queue
    # Use the multiprocessing Queue, or one provided by a Manager.
    # A multiprocessing Queue cannot be used with a process Pool;
    # communication between Pool processes needs a Manager's queue.
    # (Named mp_queue so it does not shadow the stdlib ``queue`` module.)
    mp_queue = Queue()
    manager = Manager()
    m_queue = manager.Queue(10)
    # Pipe: performs better than Queue, but only connects two processes.
    receive_pipe, send_pipe = Pipe()
示例#57
0
class Dispatcher(InstrumentedThread):
    """Routes incoming messages through their registered handler chains.

    :meth:`dispatch` records each message and enqueues ``(priority,
    message_id)`` on a priority queue; :meth:`run` consumes that queue and
    starts the handler chain registered for the message's type. Each
    handler's result status (DROP, PASS, RETURN, RETURN_AND_PASS,
    RETURN_AND_CLOSE) decides whether the chain continues and whether a reply
    is sent back over the originating connection.
    """

    def __init__(self, timeout=10):
        super().__init__(name='Dispatcher')
        self._timeout = timeout  # stored; not read anywhere in this class
        self._msg_type_handlers = {}  # message_type -> [_HandlerManager, ...]
        self._in_queue = queue.PriorityQueue()  # items: (priority, message_id)
        self._send_message = {}  # connection -> send_message callable
        self._send_last_message = {}  # connection -> send_last_message callable
        # message_id -> (connection, connection_id, message, handler collection)
        self._message_information = {}
        # Notified when _message_information becomes empty.
        self._condition = Condition()
        self._dispatch_timers = {}  # handler class name -> timer
        self._priority = {}  # message_type -> priority

    def _get_dispatch_timer(self, tag):
        """Return the (lazily created) execution timer for handler ``tag``."""
        if tag not in self._dispatch_timers:
            self._dispatch_timers[tag] = COLLECTOR.timer(
                'dispatch_execution_time', tags={"handler": tag},
                instance=self)
        return self._dispatch_timers[tag]

    def add_send_message(self, connection, send_message):
        """Adds a send_message function to the Dispatcher's
        dictionary of functions indexed by connection.

        Args:
            connection (str): A locally unique identifier
                provided by the receiver of messages.
            send_message (fn): The method that should be called
                by the dispatcher to respond to messages which
                arrive via connection.
        """
        self._send_message[connection] = send_message
        LOGGER.debug("Added send_message function "
                     "for connection %s", connection)

    def add_send_last_message(self, connection, send_last_message):
        """Adds a send_last_message function to the Dispatcher's
        dictionary of functions indexed by connection.

        Args:
            connection (str): A locally unique identifier
                provided by the receiver of messages.
            send_last_message (fn): The method that should be called
                by the dispatcher to respond to messages which
                arrive via connection, when the connection should be closed
                after the message has been sent.
        """
        self._send_last_message[connection] = send_last_message
        LOGGER.debug("Added send_last_message function "
                     "for connection %s", connection)

    def remove_send_message(self, connection):
        """Removes a send_message function previously registered
        with the Dispatcher.

        Args:
            connection (str): A locally unique identifier provided
                by the receiver of messages.
        """
        try:
            del self._send_message[connection]
        except KeyError:
            LOGGER.debug("Attempted to remove send_message "
                         "function for connection %s, but no "
                         "send_message function was registered",
                         connection)
        else:
            LOGGER.debug("Removed send_message function "
                         "for connection %s", connection)

    def remove_send_last_message(self, connection):
        """Removes a send_last_message function previously registered
        with the Dispatcher.

        Args:
            connection (str): A locally unique identifier provided
                by the receiver of messages.
        """
        try:
            del self._send_last_message[connection]
        except KeyError:
            LOGGER.debug("Attempted to remove send_last_message "
                         "function for connection %s, but no "
                         "send_last_message function was registered",
                         connection)
        else:
            LOGGER.debug("Removed send_last_message function "
                         "for connection %s", connection)

    def dispatch(self, connection, message, connection_id):
        """Queue ``message`` for its handler chain, or log if none exists.

        Messages whose type has no registered handler are logged and
        otherwise ignored. Types without an explicit priority use
        ``Priority.LOW``.
        """
        if message.message_type in self._msg_type_handlers:
            priority = self._priority.get(message.message_type, Priority.LOW)
            message_id = _gen_message_id()
            self._message_information[message_id] = (
                connection,
                connection_id,
                message,
                _ManagerCollection(
                    self._msg_type_handlers[message.message_type])
            )
            self._in_queue.put_nowait((priority, message_id))

            queue_size = self._in_queue.qsize()
            if queue_size > 10:
                LOGGER.debug("Dispatch incoming queue size: %s", queue_size)
        else:
            LOGGER.info("received a message of type %s "
                        "from %s but have no handler for that type",
                        get_enum_name(message.message_type),
                        connection_id)

    def add_handler(self, message_type, handler, executor, priority=None):
        """Append ``handler`` (run on ``executor``) to ``message_type``'s chain.

        Raises:
            TypeError: if ``handler`` is not a ``Handler`` instance.
        """
        if not isinstance(handler, Handler):
            raise TypeError("%s is not a Handler subclass" % handler)
        self._msg_type_handlers.setdefault(message_type, []).append(
            _HandlerManager(executor, handler))

        if priority is not None:
            self._priority[message_type] = priority

    def set_message_priority(self, message_type, priority):
        """Set the queue priority used for messages of ``message_type``."""
        self._priority[message_type] = priority

    def _notify_if_idle(self):
        """Wake every block_until_complete() waiter if nothing is in flight."""
        with self._condition:
            if not self._message_information:
                self._condition.notify_all()

    def _process(self, message_id):
        """Run the next handler in ``message_id``'s chain.

        The handler's result is passed to :meth:`_determine_next` via the
        ``do_next`` callback. When the chain is exhausted the message's
        bookkeeping is dropped.
        """
        _, connection_id, \
            message, collection = self._message_information[message_id]

        try:
            handler_manager = next(collection)
        except IndexError:
            # IndexError is raised if done with handlers
            del self._message_information[message_id]
            self._notify_if_idle()
            return

        timer_tag = type(handler_manager.handler).__name__
        timer_ctx = self._get_dispatch_timer(timer_tag).time()

        def do_next(result):
            timer_ctx.stop()
            try:
                self._determine_next(message_id, result)
            except Exception:  # pylint: disable=broad-except
                LOGGER.exception(
                    "Unhandled exception while determining next")

        handler_manager.execute(connection_id, message.content, do_next)

    @staticmethod
    def _build_reply(original_message, result):
        """Build the reply to ``original_message`` from a handler ``result``."""
        return validator_pb2.Message(
            content=result.message_out.SerializeToString(),
            correlation_id=original_message.correlation_id,
            message_type=result.message_type)

    def _send_reply(self, connection, connection_id, message):
        """Send ``message`` on ``connection``; log if it is not registered."""
        try:
            self._send_message[connection](msg=message,
                                           connection_id=connection_id)
        except KeyError:
            LOGGER.info("Can't send message %s back to "
                        "%s because connection %s not in dispatcher",
                        get_enum_name(message.message_type),
                        connection_id,
                        connection)

    def _determine_next(self, message_id, result):
        """Act on a handler ``result`` for ``message_id``.

        DROP discards the message; PASS runs the next handler; RETURN sends
        a reply and discards; RETURN_AND_PASS sends a reply and runs the next
        handler; RETURN_AND_CLOSE sends a final reply via the connection's
        send_last_message function and discards. A ``None`` result is ignored
        and leaves the message in flight.
        """
        if result is None:
            LOGGER.debug('Ignoring None handler result, likely due to an '
                         'unhandled error while executing the handler')
            return

        if result.status == HandlerStatus.DROP:
            del self._message_information[message_id]

        elif result.status == HandlerStatus.PASS:
            self._process(message_id)

        elif result.status == HandlerStatus.RETURN_AND_PASS:
            connection, connection_id, \
                original_message, _ = self._message_information[message_id]

            if result.message_out and result.message_type:
                message = self._build_reply(original_message, result)
                self._send_reply(connection, connection_id, message)
                self._process(message_id)
            else:
                LOGGER.error("HandlerResult with status of RETURN_AND_PASS "
                             "is missing message_out or message_type")

        elif result.status == HandlerStatus.RETURN:
            connection, connection_id, \
                original_message, _ = self._message_information.pop(
                    message_id)

            if result.message_out and result.message_type:
                message = self._build_reply(original_message, result)
                self._send_reply(connection, connection_id, message)
            else:
                LOGGER.error("HandlerResult with status of RETURN "
                             "is missing message_out or message_type")

        elif result.status == HandlerStatus.RETURN_AND_CLOSE:
            connection, connection_id, \
                original_message, _ = self._message_information.pop(
                    message_id)

            if result.message_out and result.message_type:
                message = self._build_reply(original_message, result)
                try:
                    LOGGER.warning(
                        "Sending hang-up in reply to %s to connection %s",
                        get_enum_name(original_message.message_type),
                        connection_id)
                    self._send_last_message[connection](
                        msg=message,
                        connection_id=connection_id)
                except KeyError:
                    LOGGER.info("Can't send last message %s back to "
                                "%s because connection %s not in dispatcher",
                                get_enum_name(message.message_type),
                                connection_id,
                                connection)
            else:
                LOGGER.error("HandlerResult with status of RETURN_AND_CLOSE "
                             "is missing message_out or message_type")
        self._notify_if_idle()

    def run(self):
        """Consume the incoming queue until the stop sentinel (-1) arrives."""
        while True:
            try:
                _, msg_id = self._in_queue.get()
                if msg_id == -1:
                    break
                self._process(msg_id)
            except Exception:  # pylint: disable=broad-except
                LOGGER.exception("Unhandled exception while dispatching")

    def stop(self):
        """Enqueue the high-priority sentinel that makes run() return."""
        self._in_queue.put_nowait((Priority.HIGH, -1))

    def block_until_complete(self):
        """Blocks until no more messages are in flight,
        useful for unit tests.
        """
        with self._condition:
            # Loop rather than wait once: guards against spurious wakeups and
            # against new messages dispatched between the notify and the
            # moment this thread re-acquires the lock.
            while self._message_information:
                self._condition.wait()
示例#58
0
class QueueMessageHandler(MessageHandler, Thread):
    """Message handler that buffers messages and publishes them from a
    background daemon thread.

    At most ``queue_length`` of the most recent messages are kept; older ones
    are discarded. The worker publishes the oldest buffered message once
    ``time_remaining()`` (provided by MessageHandler, presumably the throttle
    interval — TODO confirm) reaches 0.
    """

    def __init__(self, previous_handler):
        Thread.__init__(self)
        MessageHandler.__init__(self, previous_handler)
        self.daemon = True
        self.queue = []  # buffered messages, oldest first
        self.c = Condition()  # guards self.queue and self.alive
        self.alive = True
        self.start()

    def handle_message(self, msg):
        """Buffer ``msg``, discarding the oldest entries beyond queue_length."""
        with self.c:
            # The worker only waits without a timeout while the queue is
            # empty, so it needs a wake-up only on the empty -> non-empty edge.
            should_notify = len(self.queue) == 0
            self.queue.append(msg)
            if len(self.queue) > self.queue_length:
                del self.queue[0:len(self.queue) - self.queue_length]
            if should_notify:
                self.c.notify()

    def transition(self):
        """Return the handler to use for the current throttle/queue settings.

        Returns a plain MessageHandler when both throttle_rate and
        queue_length are 0, a ThrottleMessageHandler when only queue_length
        is 0 (this worker is finished first in both cases); otherwise trims
        the buffer to queue_length and keeps using this handler.
        """
        if self.throttle_rate == 0 and self.queue_length == 0:
            self.finish()
            return MessageHandler(self)
        elif self.queue_length == 0:
            self.finish()
            return ThrottleMessageHandler(self)
        else:
            with self.c:
                if len(self.queue) > self.queue_length:
                    del self.queue[0:len(self.queue) - self.queue_length]
                self.c.notify()
            return self

    def finish(self):
        """ If throttle was set to 0, this pushes all buffered messages """
        # Notify the thread to finish
        with self.c:
            self.alive = False
            self.c.notify()

        self.join()

    def _publish_oldest(self):
        """Forward the oldest buffered message downstream, then drop it.

        Errors from the downstream handler are printed to stderr so a single
        bad message cannot kill the worker thread.
        """
        try:
            MessageHandler.handle_message(self, self.queue[0])
        except Exception:
            traceback.print_exc(file=sys.stderr)
        del self.queue[0]

    def run(self):
        """Worker loop: publish buffered messages as time_remaining() allows."""
        while self.alive:
            with self.c:
                # Sleep until a message is buffered and time_remaining() is 0,
                # or until finish() clears ``alive``.
                while self.alive and (self.time_remaining() > 0
                                      or not self.queue):
                    if not self.queue:
                        self.c.wait()
                    else:
                        self.c.wait(self.time_remaining())
                if self.alive and self.time_remaining() == 0 and self.queue:
                    self._publish_oldest()
        # After shutdown, flush whatever may be sent immediately.
        while self.time_remaining() == 0 and self.queue:
            self._publish_oldest()
示例#59
0
class System:
    """Top-level handle that wires together a Scheduler, a MessageManager and
    a SystemEventsManager, and exposes spawn/send/ask to code running outside
    those workers.
    """

    def __init__(self, config):
        self.tasks = deque()

        self.config = config

        self.scheduler = Scheduler(self)
        self._messages = deque()
        self.message_manager = MessageManager(self)
        self.events_manager = SystemEventsManager(self)
        # State of the outstanding ask() (a single slot, so concurrent ask()
        # calls overwrite each other): target, context id and reply.
        self._a_context = None
        self._a_ref = None
        self._result = None
        # Waited on by start() and ask(); nothing in this class notifies it,
        # so the notifier lives elsewhere.
        self._trigger = Condition()

        self._is_running = False
        self._active_schedulers = 0

    @property
    def is_running(self):
        return self._is_running

    def start(self):
        """Start the events manager and scheduler, then block on ``_trigger``."""
        self._is_running = True

        # Hold the trigger while starting the workers: a notify() issued by
        # them can then only happen once this thread is inside wait(),
        # instead of being lost and leaving start() blocked forever.
        with self._trigger:
            self.events_manager.start()
            self.scheduler.start()
            self._trigger.wait()

    def stop(self):
        """Clear the running flag, wake both workers and join them."""
        self._is_running = False

        with self.events_manager.trigger:
            self.events_manager.trigger.notify_all()

        self.events_manager.join()

        with self.scheduler.trigger:
            self.scheduler.trigger.notify_all()

        self.scheduler.join()

    def spawn(self, actor_class, *args, **kwargs):
        """Ask the scheduler to spawn ``actor_class(*args, **kwargs)``."""
        return self.scheduler.spawn_actor(
            SpawnRequest(
                parent_ref=ActorRef(0, 0),
                actor_class=actor_class,
                args=args,
                kwargs=kwargs
            )
        )

    def send(self, target_ref, message):
        """Fire-and-forget ``message`` to ``target_ref``."""
        self.message_manager.send(
            Message(
                sender_ref=ActorRef(0, 0),
                target_ref=target_ref,
                message=message
            )
        )

    def ask(self, target_ref, message, timeout=None):
        """Send ``message`` to ``target_ref`` and wait for the reply.

        The message carries a fresh random ``context`` id, recorded in
        ``_a_context`` together with ``_a_ref``. Another component (presumably
        the message manager on seeing the matching reply — TODO confirm) is
        expected to store the reply in ``_result`` and notify ``_trigger``.

        Args:
            target_ref: reference of the actor being asked.
            message: payload to send.
            timeout: seconds to wait, or None to wait indefinitely.

        Returns:
            The value found in ``_result`` after the wait (None if nothing
            was stored); ``_result`` is reset to None afterwards.
        """
        context = uuid.uuid4().int

        self._a_ref = target_ref
        self._a_context = context
        self._result = None

        self.message_manager.send(
            Message(
                sender_ref=ActorRef(0, 0),
                target_ref=target_ref,
                message=message,
                context=context
            )
        )
        with self._trigger:
            # If the reply was stored before this thread took the lock, its
            # notify() has already fired; waiting now would block until the
            # timeout (or forever). Only wait if no reply is present yet.
            if self._result is None:
                self._trigger.wait(timeout)

        result = self._result
        self._result = None

        return result
示例#60
0
class MainThreadController(object):
    def __init__(self):
        """Create a controller owned by the main thread.

        The constructing thread keeps ``_main_wake_up`` acquired from here
        on; run() later waits on it, which temporarily releases it so that
        invoke() callers on other threads can hand work over.

        Raises:
            Exception: when called from any thread other than the main one.
        """
        if not is_main_thread():
            raise Exception(
                "A controller can only be created from the main thread")
        wake_up = Condition()
        wake_up.acquire()
        self._main_wake_up = wake_up
        # Hand-off slots for the pending invoke() request and its result.
        self._obj = self._caller_wake_up = None
        self._args = self._kwargs = None
        self._return = None
        self._quit = False

    def invoke(self, obj, *args, **kwargs):
        """Call ``obj(*args, **kwargs)`` on the main thread and return its result.

        On the main thread the call is made directly. From any other thread
        the call is handed to the main thread's run() loop through the shared
        ``_obj``/``_args``/``_kwargs`` fields, and this thread blocks until
        run() stores ``_return`` and notifies ``_caller_wake_up``.

        NOTE(review): the wait for the result below has no timeout. If ``obj``
        raises on the main thread, run() (as far as it is visible) does not
        catch it, so this caller would presumably stay blocked — confirm.
        """
        if is_main_thread():
            # Already on the main thread: no hand-off needed.
            return obj.__call__(*args, **kwargs)
        # Records whether _main_wake_up was released on the normal path, so
        # the finally block does not release it a second time.
        released = False
        # Blocks until no other thread holds the condition; the main thread
        # holds it except while run() is waiting on it.
        self._main_wake_up.acquire()
        try:
            # A fresh per-call condition used to hand the result back.
            self._caller_wake_up = Condition()
            with self._caller_wake_up:
                self._obj = obj
                self._args = args
                self._kwargs = kwargs
                # tell the main thread to wake up
                self._main_wake_up.notify_all()
                self._main_wake_up.release()
                released = True
                # Releases _caller_wake_up so run() can enter it, call obj,
                # store _return and notify this thread.
                self._caller_wake_up.wait()
                # tell the main thread that we received the result:
                self._caller_wake_up.notify_all()
                ret = self._return
                return ret
        finally:
            # NOTE(review): these shared fields are cleared after
            # _main_wake_up has been released, i.e. without holding the lock
            # run() reads them under. Another invoke() that has meanwhile
            # acquired _main_wake_up could have its freshly set fields
            # overwritten here — verify before relying on concurrent callers.
            self._obj = None
            self._args = None
            self._kwargs = None
            self._caller_wake_up = None
            self._return = None
            if not released:
                # Failed before the hand-off: release the lock acquired above.
                self._main_wake_up.release()

    def quit(self):
        self._main_wake_up.acquire()
        try:
            self._quit = True
            self._main_wake_up.notify_all()
        finally:
            self._main_wake_up.release()

    def run(self):
        if not is_main_thread():
            raise Exception("run can only be called from the main thread!")
        from . import signals

        def signal_handler(signal, frame):
            self._quit = True

        signals.add_sigint_handler(signal_handler)
        while True:
            try:
                self._main_wake_up.wait(1.0)
            except KeyboardInterrupt:
                self._quit = True
            if self._quit:
                return
            elif self._caller_wake_up is None:
                # we timed out
                continue
            with self._caller_wake_up:
                self._return = self._obj.__call__(*self._args, **self._kwargs)
                self._caller_wake_up.notify_all()
                # wait for the calling thread to confirm it received the result
                while not self._caller_wake_up.wait(1.0):
                    if self._quit:
                        return