Example #1
0
def manage(worker, n_threads = 10, catalogs = download.catalogs):
    '''Manage a bunch of worker threads, and generate their results.

    Every (catalog, dataset) pair found through ``download.datasets`` is put,
    in random order, on a job queue. ``n_threads`` threads are started with
    ``worker(read_queue, write_queue)`` as their target, and whatever they put
    on the result queue is yielded as it arrives.

    The generator finishes once every worker thread has exited and the result
    queue has been drained.
    '''
    # Download in random order
    args = [(catalog, dataset)
            for catalog in catalogs
            for dataset in download.datasets(catalog)]
    random.shuffle(args)

    read_queue = Queue()
    for a in args:
        read_queue.put(a)

    write_queue = Queue()
    threads = []
    for _ in range(n_threads):
        thread = Thread(target = worker, args = (read_queue, write_queue))
        thread.start()
        threads.append(thread)

    while True:
        # Sample liveness BEFORE polling: if every worker was already dead and
        # the poll still comes back empty, no further result can ever arrive.
        # (Checking in the opposite order could drop a result that a worker
        # enqueued just before exiting.)
        finished = not any(t.is_alive() for t in threads)
        try:
            # A short blocking wait instead of get_nowait() avoids spinning a
            # CPU core while the workers are busy.
            x = write_queue.get(timeout = 0.1)
        except Empty:
            if finished:
                # Nobody is left to consume read_queue, so waiting for it to
                # empty would only hang.
                return
        else:
            yield x
Example #2
0
    def generer_graph(self):
        """Build the graph of exits between every pair of rooms.

        The graph is a dictionary keyed by an (origin, destination) tuple of
        room mnemonics; each value is the ordered list of exit names to
        follow to go from the origin to the destination.

        A breadth-first search is run once per origin (every exit counts as
        one step), so each stored path uses as few exits as possible.
        Destinations that cannot be reached from an origin get no entry.

        Origins and destinations are the mnemonics "1" .. str(len(self.salles)).
        The result is stored in ``self.graph``.

        """
        aretes = {}
        sorties = {}

        # Record every direct link between rooms
        for salle in self.salles.values():
            origine = salle.mnemonic
            aretes[origine] = []
            for sortie in salle.sorties:
                destination = sortie.salle_dest.mnemonic
                aretes[origine].append(destination)
                sorties[origine, destination] = sortie.nom

        mnemonics = [str(numero) for numero in range(1, len(self.salles) + 1)]
        graph = {}

        for origine in mnemonics:
            # One traversal per origin yields the predecessor of every room
            # reachable from it; all destinations reuse that single tree.
            frontier = Queue()
            frontier.put(origine)
            origines = {origine: None}
            while not frontier.empty():
                actuel = frontier.get()
                # .get() tolerates exits leading to a room that is not
                # registered in self.salles
                for fils in aretes.get(actuel, ()):
                    if fils not in origines:
                        frontier.put(fils)
                        origines[fils] = actuel

            for destination in mnemonics:
                if destination == origine or destination not in origines:
                    # Same room, or no route exists: nothing to store
                    continue

                # Walk back from the destination to the origin, then reverse
                chemin = []
                actuel = destination
                while actuel != origine:
                    precedent = origines[actuel]
                    chemin.append(sorties[precedent, actuel])
                    actuel = precedent
                chemin.reverse()

                graph[origine, destination] = chemin

        self.graph = graph
Example #3
0
def main(filename):
    """Run every task block read from *filename* against its hosts.

    Prompts for credentials, lets ``Task`` read *filename* (presumably
    filling the task queue with task blocks -- confirm against Task.read),
    and for each block starts ``ExecuteTask`` workers over the block's hosts,
    waits until the host queue has been fully processed, then prints every
    collected log entry.
    """
    from getpass import getpass

    # NOTE(review): in the reviewed source the credential prompts had been
    # redacted into a single unparsable line; the password prompt below is a
    # reconstruction -- confirm the original wording.
    username = input('User: ')
    password = getpass('Password: ')
    print(time_stamp(), ': script started')

    task_queue = Queue()

    new_task = Task(task_queue)
    new_task.read(filename)

    while not task_queue.empty():
        host_queue = Queue()
        log_queue = Queue()
        task_block = task_queue.get()
        for host_ip in task_block['hosts']:
            host_queue.put(host_ip)
        config_list = task_block['tasks']

        # Keeps launching batches of 50 daemon workers until the workers have
        # taken every host off the queue.
        while not host_queue.empty():

            for x in range(50):
                worker = ExecuteTask(host_queue, config_list, log_queue, username, password, mode=task_block['mode'])
                worker.daemon = True
                worker.start()

        # Blocks until task_done() has been called for every queued host
        host_queue.join()

        while not log_queue.empty():
            st = log_queue.get()
            print(st)

    print(time_stamp(), ': script stopped')
class FakeConsumer(BrightsideConsumer):
    """Test double for a consumer that wraps messaging middleware.

    Add messages to ``queue``; ``receive`` takes the next one off, ``purge``
    discards everything still queued, and the most recently acknowledged
    message is remembered so tests can query it with ``has_acknowledged``.
    """
    def __init__(self):
        self._acknowledged_message = None
        self._queue = Queue()

    def acknowledge(self, message):
        """Remember *message* as the latest acknowledged message."""
        self._acknowledged_message = message

    def has_acknowledged(self, message):
        """Return True if the latest acknowledged message has the same id as *message*."""
        latest = self._acknowledged_message
        if latest is None:
            return False
        return latest.id == message.id

    @property
    def queue(self):
        """The underlying queue that tests fill with messages."""
        return self._queue

    def purge(self):
        """Discard every message still waiting in the queue."""
        pending = self._queue
        while not pending.empty():
            pending.get_nowait()

        assert pending.empty()

    def receive(self, timeout: int):
        """Block for up to *timeout* seconds and return the next queued message."""
        return self._queue.get(True, timeout)
 def zigzagLevelOrder(self, root):
     """Return the node values level by level, alternating direction.

     Even-numbered levels (the root is level 0) are listed left to right,
     odd-numbered levels right to left. An empty tree gives an empty list.
     """
     if root is None:
         return []

     pending = Queue()
     pending.put(root)
     zigzag = []
     depth = 0
     while not pending.empty():
         values = []
         # Everything queued right now belongs to the current level
         for _ in range(pending.qsize()):
             current = pending.get()
             values.append(current.val)
             for child in (current.left, current.right):
                 if child:
                     pending.put(child)
         zigzag.append(values[::-1] if depth % 2 else values)
         depth += 1

     return zigzag
Example #6
0
def check_if_pingable(ip_list):
    """
    Sort the given IP addresses into reachable and unreachable ones.

    Ping workers take addresses from a shared queue and report into two
    result queues, which are emptied into lists once the work is done.
    :param ip_list: list of IP addresses to ping
    :return: tuple (reachable, unreachable) of lists
    """

    def _drain(source):
        # Move everything currently in the queue into a list
        collected = []
        while not source.empty():
            collected.append(source.get())
        return collected

    pending = Queue()
    answered = Queue()
    silent = Queue()

    for address in ip_list:
        pending.put(address)

    # Keep launching batches of 50 daemon workers until the queue is taken
    while not pending.empty():
        for _ in range(50):
            pinger = Ping(pending, answered, silent)
            pinger.daemon = True
            pinger.start()

    pending.join()

    reachable = _drain(answered)
    unreachable = _drain(silent)
    return reachable, unreachable
 def levelOrderBottom(self, root):
     """Return the node values level by level, deepest level first.

     Within a level the values run left to right. An empty tree gives an
     empty list.
     """
     if root is None:
         return []

     pending = Queue()
     pending.put(root)
     topDown = []
     while not pending.empty():
         values = []
         # Everything queued right now belongs to the current level
         for _ in range(pending.qsize()):
             current = pending.get()
             values.append(current.val)
             for child in (current.left, current.right):
                 if child:
                     pending.put(child)
         topDown.append(values)

     return topDown[::-1]
Example #8
0
class GridMap():
    """
    Grid of cells in which visited positions are marked, plus a queue of
    the freshly marked positions that still have to be drawn.
    """
    def __init__(self, width, height, size):
        self.width = width
        self.height = height
        self.cellSize = size
        self.__map = np.zeros((self.width, self.height))
        self.__drawQueue = Queue()

        self.linear = Linear()

    def update(self, x, y, theta):
        """Mark cell (x, y) and queue it for drawing; theta is currently unused."""
        row = self.__map[x]
        row[y] = 1
        self.__drawQueue.put((x, y))

    def getPoints(self):
        """Remove and return every queued point as a list of (x, y) tuples."""
        pending = self.__drawQueue
        drained = []
        while not pending.empty():
            drained.append(pending.get())
        return drained

    def drawPoints(self, drawFunc):
        """
        Pass the oldest queued point, if there is one, to the callback drawFunc.
        """
        if not self.__drawQueue.empty():
            drawFunc(self.__drawQueue.get())
Example #9
0
class InsnPool:
    """Runs queued instructions in batches of at most ``max_threads``.

    Instructions are added with ``query``; ``poll`` starts a batch by calling
    ``start()`` on each dequeued instruction, and ``signal`` decrements the
    running count again.
    NOTE(review): presumably each started instruction calls ``signal`` when
    it finishes -- nothing in this class does so; confirm against the caller.
    """

    def __init__(self, proc, max_threads = None):
        """Store *proc* and size the batches.

        :param proc: stored on the instance; not used by this class itself
        :param max_threads: upper bound on instructions started per batch;
            defaults to five times the CPU count
        """
        # Number of started instructions that have not signalled completion
        self.numThreads = 0
        if max_threads is None:
            max_threads = multiprocessing.cpu_count() * 5
        self.max_threads = max_threads
        self.queue = Queue()
        self.proc = proc

        # Semaphore used to wake up the polling thread once a batch is done.
        # threading.Semaphore() starts with a count of 1, so the first
        # acquire() in poll_all does not block.
        self.lock = threading.Semaphore()

    def query(self, insn):
        """Queue *insn* for a later batch."""
        self.queue.put(insn)

    def signal(self):
        """Record that one started instruction is done."""
        self.numThreads -= 1
        if self.has_cycled():
            # Wake up the polling thread if batch is done processing
            self.lock.release()

    def has_cycled(self):
        """Return True when no started instruction is still outstanding."""
        return self.numThreads == 0

    def has_finished(self):
        """Return True when nothing is outstanding and nothing is queued."""
        return self.has_cycled() and self.queue.empty()

    def poll_all(self, blocking = True, callback = None):
        """Start batches until the pool has finished, then call *callback*.

        :param blocking: run the loop in the calling thread when True,
            otherwise in a daemon thread
        :param callback: called with no arguments once everything is done;
            required when *blocking* is False
        :raises ValueError: if *blocking* is False and no callback is given
        """
        if not blocking and callback is None:
            raise ValueError("If called in a non blocking way you must provide a callback function.")

        def poll_all_impl():
            while not self.has_finished():  # Wait for all threads to process
                self.poll()
                self.lock.acquire()  # Pauses the current thread and waits till batch is processed
            if callback:
                callback()

        if blocking:
            poll_all_impl()
        else:
            thread = threading.Thread(daemon = True, target = poll_all_impl)
            thread.start()

    def poll(self):
        """Start the next batch if the previous one has completed."""
        # Batch processing, we only start a new batch when the old
        # one has finished. It is only possible to jump to a location once.
        if not self.has_cycled(): return

        # pc values already started in this batch; duplicates are dropped
        locations = set()
        while self.numThreads < self.max_threads and not self.queue.empty():
            insn = self.queue.get_nowait()
            if insn.pc in locations:
                continue
            else:
                locations.add(insn.pc)

            insn.start()
            self.numThreads += 1
Example #10
0
class SLIPEncoderDecoder:
	""" This class acts as an intermediary to a socket: it sends data packets
	    in SLIP framing and turns incoming SLIP data back into plain datagrams.

	    Frames are delimited by SLIP_END on both sides; inside a frame
	    SLIP_END and SLIP_ESC are sent as the two-byte sequences
	    SLIP_ESC SLIP_ESC_END and SLIP_ESC SLIP_ESC_ESC.
	"""
	def __init__(self, socket):
		self.socket = socket
		# Fully decoded packets waiting to be handed out by getPacket()
		self.inPacketBuf = Queue()
		# Partially received frame carried over between retrieveData() calls
		self.workingRetrievalBuffer = None
		# True when the last byte received was SLIP_ESC and the byte that
		# completes the escape sequence has not arrived yet
		self.escapePending = False

	def sendPacket(self, packet):
		""" Frame *packet* (an iterable of byte values) and send it. """
		tempSLIPBuffer = bytearray()
		tempSLIPBuffer.append(SLIP_END)
		for i in packet:
			if i == SLIP_END:
				tempSLIPBuffer.append(SLIP_ESC)
				tempSLIPBuffer.append(SLIP_ESC_END)
			elif i == SLIP_ESC:
				tempSLIPBuffer.append(SLIP_ESC)
				tempSLIPBuffer.append(SLIP_ESC_ESC)
			else:
				tempSLIPBuffer.append(i)
		tempSLIPBuffer.append(SLIP_END)
		self.socket.send(bytes(tempSLIPBuffer))

	def getPacket(self):
		""" Return the next decoded packet, reading from the socket if none is buffered. """
		if self.inPacketBuf.empty():
			self.retrieveData()
		return self.inPacketBuf.get()

	def dataWaiting(self):
		""" Return True if a decoded packet is already buffered. """
		return not self.inPacketBuf.empty()

	def retrieveData(self):
		""" Read from the socket until at least one complete packet is buffered.

		    Raises SLIPException if an escape byte is followed by an invalid
		    byte, or if the socket is closed before a packet is complete.
		"""
		workingBuf = self.workingRetrievalBuffer
		escaped = self.escapePending
		while self.inPacketBuf.empty():
			newData = self.socket.recv(SOCKET_BUF_SIZE)
			if not newData:
				# recv() returning nothing means the peer closed the
				# connection; looping on would spin forever
				raise(SLIPException("Socket closed before a complete packet was received"))
			for i in newData:
				if escaped:
					# Second byte of an escape sequence; tracking this with a
					# flag lets the sequence straddle two recv() chunks
					escaped = False
					if i == SLIP_ESC_END:
						workingBuf.append(SLIP_END)
					elif i == SLIP_ESC_ESC:
						workingBuf.append(SLIP_ESC)
					else:
						raise(SLIPException("Unexpected byte %x following ESCAPE character"%i))
				elif i == SLIP_END:
					if workingBuf is None:
						# Opening delimiter: start a new frame
						workingBuf = bytearray()
					else:
						# Closing delimiter: the frame is complete
						self.inPacketBuf.put(bytes(workingBuf))
						workingBuf = None
				elif i == SLIP_ESC:
					escaped = True
				else:
					workingBuf.append(i)
		self.workingRetrievalBuffer = workingBuf
		self.escapePending = escaped
Example #11
0
class Display:
    """Column-oriented dot display that scrolls queued columns through a matrix.

    ``matrix`` holds ``columns`` lists of ``rows`` values. ``write`` queues the
    columns of each character, and ``show`` shifts them in one at a time,
    optionally driving output ports.
    """
    def __init__(self, rows, columns, ports=None):
        self.ports = ports
        self.rows = rows
        self.columns = columns
        # One independent list per column, all cells off
        self.matrix = [[0] * self.rows for _ in range(self.columns)]
        self.buffer = Queue()
        self.separator = [0] * self.rows

    def write(self, text):
        """Queue the columns of every character in *text*, then a blank screen."""
        for char in text:
            # A blank column between consecutive glyphs
            if not self.buffer.empty():
                self.buffer.put(self.separator)
            glyph = Character(char).matrix
            width = len(glyph[0])
            height = len(glyph)
            for x in range(width):
                # Each queued column lists the glyph rows from last to first
                self.buffer.put([glyph[y][x] for y in reversed(range(height))])
        # Trailing blank columns, one per display column
        for _ in range(self.columns):
            self.buffer.put(self.separator)

    def print_current_buffer(self):
        """Print every queued column as one line of text, emptying the buffer."""
        while not self.buffer.empty():
            for item in self.buffer.get():
                print('{}'.format(item), end=' ')
                time.sleep(0.05)
            print()

    def show(self):
        """Scroll every queued column through the matrix, printing each one."""
        driving_ports = bool(self.ports)
        if driving_ports:
            port_adapter.setup(self.ports)
        last = self.columns - 1
        while not self.buffer.empty():
            incoming = self.buffer.get()
            # Shift every column one position towards index 0
            for target in range(last):
                self.matrix[target] = self.matrix[target + 1]
            self.matrix[last] = incoming
            if driving_ports:
                for column_index in range(self.columns):
                    for row_index in range(self.rows):
                        port_adapter.output(self.ports[column_index][row_index],
                                            self.matrix[column_index][row_index])
            for item in incoming:
                print('{}'.format(item), end=' ')
            print()
            time.sleep(0.2)
Example #12
0
class Pipe:
    """
    In-memory file-like object that may be read and written by co-operating
    threads.
    Note: only write(), read() and readline() are provided.
    Note: read() and readline() wait for more data until close() is called,
    giving up after more than 50 waits of 0.1 s each have passed without
    any data arriving.
    """
    def __init__(self):
        self.buffer = Queue()
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def write(self, s):
        """Append every character of *s* to the pipe."""
        for ch in s:
            self.buffer.put(ch)

    def pending(self):
        """Return True if unread data is waiting in the pipe."""
        return not self.buffer.empty()

    def read(self, size=None):
        """Read up to *size* characters, or everything until close() if None.

        Raises Exception when data stops arriving before the request is
        satisfied and the pipe is not closed.
        """
        from queue import Empty

        ret = ""
        # Counts waits that timed out -- NOT characters read, so long reads
        # are not mistaken for a stalled writer
        count = 0
        while size is None or len(ret) < size:
            if self.buffer.empty() and self.closed:
                break
            try:
                ret += self.buffer.get(timeout=0.1)
            except Empty:
                count += 1
                if count > 50:
                    raise Exception("Timed out waiting for read that never ended.")
        return ret

    def readline(self):
        """Read up to and including the next newline, or until close().

        Raises Exception when data stops arriving before the line is
        complete and the pipe is not closed.
        """
        from queue import Empty

        ret = ""
        # Counts waits that timed out, not characters read
        count = 0
        c = "x"
        while c not in ("\n", ""):
            if self.buffer.empty() and self.closed:
                break
            try:
                c = self.buffer.get(timeout=0.1)
            except Empty:
                count += 1
                if count > 50:
                    raise Exception("Timed out waiting for read that never ended.")
                continue
            ret += c
        return ret

    def flush(self):
        """No-op; present for file-object compatibility."""
        pass

    def close(self):
        """Mark the pipe closed so readers stop once the buffer is empty."""
        self.closed = True
Example #13
0
File: uart.py Project: flaviut/d16i
class D16Uart():
    """Emulated UART with 8-entry transmit and receive FIFOs.

    A background thread prints every transmitted character, another one
    feeds the receive FIFO from ``_find_getch()`` keyboard input.
    """

    def __init__(self):
        self.tx_fifo = Queue(maxsize=8)
        self.rx_fifo = Queue(maxsize=8)
        # Shared state must exist before the workers that use it are started
        self.rx_overrun = 0
        self.rx_lock = Lock()
        self.tx_thread = Thread(target=self.tx_worker, daemon=True)
        self.tx_thread.start()
        self.rx_thread = Thread(target=self.rx_worker, daemon=True)
        self.rx_thread.start()

    def read_uart_data(self) -> int:
        """Block until a received character is available and return its code."""
        c = self.rx_fifo.get(block=True)
        return int(c)

    def end(self):
        """Block until every queued character has been transmitted."""
        # tx_worker calls task_done() per character, so join() waits for the
        # output itself without spinning on empty()
        self.tx_fifo.join()

    def write_uart_data(self, data: int):
        """Queue *data* for transmission, blocking while the FIFO is full."""
        self.tx_fifo.put(data)

    def read_uart_status(self) -> int:
        """Return the status bits and clear the overrun flag.

        bit 0: transmit FIFO has room, bit 1: transmit FIFO is empty,
        bit 2: received data is available, bit 3: a received character was
        dropped since the last status read.
        """
        with self.rx_lock:
            # Queue.not_full / Queue.not_empty are Condition objects (always
            # truthy); the actual state comes from full() / empty()
            tx_free = not self.tx_fifo.full()
            tx_empty = self.tx_fifo.empty()
            rx_ready = not self.rx_fifo.empty()
            rx_overrun = bool(self.rx_overrun)
            self.rx_overrun = 0

        return tx_free | tx_empty << 1 | rx_ready << 2 | rx_overrun << 3

    def read_uart_baud(self) -> int:
        """Baud rate is not emulated; always 0."""
        return 0

    def write_uart_baud(self, data: int):
        """Baud rate is not emulated; the value is ignored."""
        pass

    def tx_worker(self):
        """Print every transmitted character (runs in a daemon thread)."""
        while True:
            c = self.tx_fifo.get(block=True)
            print(chr(c), end='')
            self.tx_fifo.task_done()

    def rx_worker(self):
        """Feed typed characters into the receive FIFO (runs in a daemon thread)."""
        from queue import Full

        getch = _find_getch()
        while True:
            c = getch()
            try:
                self.rx_fifo.put(ord(c[0]), block=False)
            except Full:
                # FIFO full: the character is dropped and the overrun flagged
                with self.rx_lock:
                    self.rx_overrun = True
	def test_dequeue(self):
		"""dequeue() hands back the item at the front and removes it from the queue."""
		queue = Queue()

		# Single item: dequeuing it leaves the queue empty
		queue.enqueue('a')
		removed = queue.dequeue()
		self.assertEqual('a', removed.data)
		self.assertTrue(queue.empty())

		# Two items: dequeuing once leaves the second one at the front
		for item in ('a', 'b'):
			queue.enqueue(item)
		queue.dequeue()
		self.assertEqual('b', queue.first.data)
		self.assertFalse(queue.empty())
Example #15
0
def handleProcessOutputBytewise(proc, fhOut, fhErr):
	"""Forward a process's stdout/stderr to callbacks until it has exited.

	Two daemon threads run enqueue_output over proc.stdout and proc.stderr.
	The collected pieces are joined and handed to fhOut / fhErr at most about
	ten times per second, and the Qt event loop is serviced at the same pace.

	Returns the process exit code once the process has terminated, both
	reader threads have finished and both queues are empty.
	"""
	q = Queue()
	errQ = Queue()

	t = Thread(target=enqueue_output, args=(proc.stdout, q))
	t.daemon = True # thread dies with the program
	t.start()

	tErr = Thread(target=enqueue_output, args=(proc.stderr, errQ))
	tErr.daemon = True # thread dies with the program
	tErr.start()

	writeEveryXSecs = 1.0 / 10
	lastWrite = time.time()

	chars = []
	errChars = []

	while 1:
		# The timeout keeps the loop turning even while the process is silent
		try:
			char = q.get(timeout=.1)
		except Empty:
			pass
		else:
			chars.append(char)

		if not errQ.empty():
			try:
				while 1:
					errChar = errQ.get_nowait()
					errChars.append(errChar)
			except Empty:
				pass

		now = time.time()
		if now - lastWrite > writeEveryXSecs:
			# Restart the interval; without this the throttle only held for
			# the first 0.1 s and every later pass flushed immediately
			lastWrite = now

			if chars:
				fhOut(''.join(chars))
				del chars[:]

			if errChars:
				fhErr(''.join(errChars))
				del errChars[:]

			QtGui.QApplication.instance().processEvents()

			# Process terminated?
			exitCode = proc.poll()

			if (exitCode is not None) and (not t.is_alive()) and (not tErr.is_alive()) and q.empty() and errQ.empty():
				return exitCode
    def get_points(self, street_data):
        """Resolve every house in *street_data* to a Point and yield the points.

        Street data has the following structure:
        [("Street name", [1, 4, 5, 9]), ...]

        Each (street, house) pair is looked up in the cache first; on a miss
        it is geocoded and the result is stored in the cache. Houses whose
        geocoding fails are logged and skipped. All lookups are completed
        before the first point is yielded.
        """
        processed = []

        for street_name, houses in street_data:
            for house_number in houses:
                # Lookup in cache
                coord = self._cache.get_coordinate(street_name, house_number)
                if coord is not None:
                    point = Point(longitude=coord[0],
                                  latitude=coord[1],
                                  description="")
                    msg = "Cached value found for {} - {}: {}"
                    logger.debug(msg.format(street_name, house_number, point))
                    processed.append(point)
                    continue

                # House numbers may be integers (see the structure above),
                # which str.join would reject
                search_str = " ".join(["Минск", street_name, str(house_number)])

                try:
                    location = self._locator.geocode(search_str)
                    point = Point(longitude=location.longitude,
                                  latitude=location.latitude,
                                  description="")
                    logger.debug("Putting {} in cache".format(search_str))
                    self._cache.put_coordinate(street_name, house_number,
                                               point.longitude, point.latitude)
                    processed.append(point)
                except Exception:
                    # Best effort: one failed address must not stop the rest
                    err_msg = "While processing {} error occured"
                    logger.error(err_msg.format(search_str), exc_info=True)

        for point in processed:
            yield point
Example #17
0
    def track(self):
        """Show a progress indicator until the status updates report completion.

        ``self._update_status`` runs in a background thread and posts status
        dicts (or an exception) on a queue. The indicator is refreshed every
        0.1 s; a posted exception is re-raised here. Returns the last status
        dict, i.e. the first one whose 'processed' entry is truthy.
        """
        updates = Queue()
        poller = Thread(target=self._update_status, args=(updates,))
        poller.start()

        widgets = ['Processing...', AnimatedMarker()]
        bar = ProgressBar(widgets=widgets, maxval=UnknownLength)
        bar.start()

        content = {}
        tick = 0
        while True:
            if not updates.empty():
                content = updates.get()
                if isinstance(content, Exception):
                    raise content
                widgets[0] = self._get_message(content)
            bar.update(tick)
            if content.get('processed'):
                break
            sleep(0.1)
            tick += 1
        bar.finish()

        self.__content = content

        return content
Example #18
0
class ScraperThread(QThread):
    """Thread that fetches queued URLs and emits the response headers of each."""

    result_signal = pyqtSignal(dict)

    def __init__(self, parent=None):
        super(ScraperThread, self).__init__(parent)
        self._stop = False
        self._queue = Queue()

    def run(self):
        """Fetch URLs until the queue is empty or ``stop`` has been set."""
        self._stop = False
        while True:
            if self._queue.empty() or self._stop:
                break
            response = requests.get(self._queue.get())
            self.result_signal.emit({'headers': response.headers})

    def clear_queue(self):
        """Drop all pending URLs by starting over with a fresh queue."""
        self._queue = Queue()

    @property
    def queue(self):
        """The queue of pending URLs."""
        return self._queue

    @queue.setter
    def queue(self, urls):
        # Assigning an iterable of URLs appends them (stripped) to the queue
        for raw_url in urls:
            self._queue.put(raw_url.strip())

    @property
    def stop(self):
        """Flag asking ``run`` to stop before the next URL."""
        return self._stop

    @stop.setter
    def stop(self, stop):
        self._stop = stop
Example #19
0
def astar(maze, start, end):
    """Search *maze* for a path from *start* to *end*.

    Cells are taken from a queue; the neighbours of each cell (from
    ``moves``) are queued in order of path length so far plus straight-line
    distance to *end*. Every cell is queued at most once.

    Returns (path, len(path), number of cells visited) when *end* is
    reached, and None when the queue runs out first.
    """
    def estimate(distance, cell):
        # Steps taken so far plus the Euclidean distance still to go
        return distance + ((end[0] - cell[0]) ** 2
                           + (end[1] - cell[1]) ** 2) ** 0.5

    visited = set()
    will_visit = {start}

    frontier = Queue()
    frontier.put((start, []))

    while not frontier.empty():
        current, path = frontier.get()
        path.append(current)
        visited.add(current)

        if current == end:
            return path, len(path), len(visited)

        travelled = len(path)
        options = moves(maze, current)
        options.sort(key=lambda cell: estimate(travelled, cell))
        for potential in options:
            if potential not in will_visit:
                will_visit.add(potential)
                # Each queued cell carries its own copy of the path so far
                frontier.put((potential, path[:]))
Example #20
0
    def diagram(self):
        """Draw the trie with Graphviz and open the rendered result.

        Nodes are numbered in breadth-first order starting from the root (0)
        and labelled with their value; nodes that end a word are drawn as
        double circles. The DOT source is written to 'trie_dot.gv' and then
        rendered with ``view=True``.
        NOTE(review): the original ended with a bare 'trie_dot.gv.pdf'
        string, presumably the name of the rendered file -- confirm.
        """
        from graphviz import Digraph
        from queue import Queue
        diagram=Digraph(comment='The Trie')

        i=0

        diagram.attr('node', shape='circle')
        diagram.node(str(i), self.root.getValue())

        q=Queue()
        q.put((self.root, i))

        while not q.empty():

            node, parent_index=q.get()

            for child in node.getChildren():
                i+=1
                if child.getEnding():
                    # Switch the default shape just for this word-ending node
                    diagram.attr('node', shape='doublecircle')
                    diagram.node(str(i), child.getValue())
                    diagram.attr('node', shape='circle')
                else:
                    diagram.node(str(i), child.getValue())
                diagram.edge(str(parent_index), str(i))
                q.put((child, i))

        # The context manager closes the file even if the write fails
        with open('trie_dot.gv', 'w') as o:
            o.write(diagram.source)
        diagram.render('trie_dot.gv', view=True)
Example #21
0
 def ac3(self, var_index):
     """Propagate constraints with AC-3 after the variable *var_index* changed.

     An arc towards *var_index* is queued for every variable whose entry in
     ``self.assignment`` is empty (i.e. unassigned). Whenever
     ``self.revise_ac3`` reports that an arc's domain was reduced, arcs from
     all other unassigned variables towards the revised variable are queued.

     Returns False as soon as a revised domain becomes empty, True otherwise.
     """
     from queue import Queue

     # Initial arcs: every unassigned variable against var_index
     queue = Queue()
     for i, item_list in enumerate(self.assignment):
         if not item_list:
             queue.put(Arc(i, var_index))

     # Take each arc from the queue and propagate
     while not queue.empty():
         arc = queue.get()
         # If the domain of the arc's variable could be reduced
         if self.revise_ac3(arc):
             # Remember the revised variable once: rebinding `arc` while
             # building the follow-up arcs made later iterations read the
             # wrong variable
             revised = arc.x
             # No value left in the domain: failure
             if len(self.domain_list[revised]) < 1:
                 return False
             # The domain changed, so its unassigned neighbours must be
             # checked again
             for i, item_list in enumerate(self.assignment):
                 if i != revised and not item_list:
                     queue.put(Arc(i, revised))
     return True
Example #22
0
def level_print(x):
    """
    Prints out binary tree nodes level by level.

    A tree is a linked data structure that can be regarded as a graph. Level by
     level tree traversal is performed using a breadth-first-search-like
     algorithm. A level is printed out as soon as its frontier has been
     discovered: the keys of one level share a line, each followed by a space.

    An empty tree (x is None) prints nothing.

    Complexity: O(n)
    :param Node x: Starting node, or None for an empty tree
    :return None: Prints to stdout
    """
    from queue import Queue

    if x is None:
        return

    Q = Queue()
    Q.put(x)
    nodes_on_curr_level = 1
    nodes_on_next_level = 0
    while not Q.empty():
        node = Q.get()
        print(str(node.key), end=' ')
        nodes_on_curr_level -= 1
        for child in (node.left, node.right):
            if child is not None:
                Q.put(child)
                nodes_on_next_level += 1
        if nodes_on_curr_level == 0:
            # Done with the level: the children counted so far form the next
            nodes_on_curr_level = nodes_on_next_level
            nodes_on_next_level = 0
            print()
Example #23
0
class BlockingInProcessChannel(InProcessChannel):
    """In-process channel that queues incoming messages for blocking retrieval."""

    def __init__(self, *args, **kwds):
        super(BlockingInProcessChannel, self).__init__(*args, **kwds)
        self._in_queue = Queue()

    def call_handlers(self, msg):
        """Queue *msg* instead of dispatching it."""
        self._in_queue.put(msg)

    def get_msg(self, block=True, timeout=None):
        """ Gets a message if there is one that is ready. """
        # Queue.get(timeout=None) cannot be interrupted, so "no timeout"
        # is replaced by one week (in seconds)
        wait = 604800 if timeout is None else timeout
        return self._in_queue.get(block, wait)

    def get_msgs(self):
        """ Get all messages that are currently ready. """
        ready = []
        try:
            while True:
                ready.append(self.get_msg(block=False))
        except Empty:
            return ready

    def msg_ready(self):
        """ Is there a message that has been received? """
        nothing_waiting = self._in_queue.empty()
        return not nothing_waiting
Example #24
0
 def estAPorte(self, entite, tuile):
    """Check whether a tile is within range of an entity.

    Clears the current selection, then explores outwards from the entity's
    tile, adding every newly discovered tile to ``self.selectedTuile``, until
    the target tile is met or the neighbour budget derived from
    ``entite.porte`` is used up. Prints and returns whether it was found.
    """
    self.deselect()
    depart = entite.parent
    aExplorer = Queue()
    aExplorer.put(depart)
    dejaVues = [depart]
    compteur = 0
    trouve = False
    while not (aExplorer.empty() or compteur >= entite.porte*(2*(entite.porte + 1)) or trouve):
       courante = aExplorer.get()
       voisins, nbVoisin = self.getVoisinComptage(courante)
       compteur += nbVoisin
       for voisin in voisins:
          if voisin == tuile:
             trouve = True
             break
          if voisin in dejaVues:
             # Already seen: it does not count against the budget
             compteur -= 1
             continue
          aExplorer.put(voisin)
          self.selectedTuile.append(voisin)
          dejaVues.append(voisin)
    print(trouve)
    return trouve
Example #25
0
def solve():
    """Read one undirected graph and print the distance from a start node to every other node.

    Input, obtained through ``read(int)``: the node and edge counts, one
    line per edge (1-based endpoints), then the 1-based start node. Every
    edge costs 6; unreachable nodes are reported as -1. The distances of all
    nodes except the start node are printed on one space-separated line.
    """
    (node, edge) = read(int)
    graph = [[] for _ in range(node)]
    cost = [-1] * node
    for _ in range(edge):
        (fro, to) = read(int)
        fro, to = fro-1, to-1
        # Ignore repeated edges
        if to not in graph[fro]:
            graph[fro].append(to)
            graph[to].append(fro)

    start = read(int)[0] - 1
    cost[start] = 0
    q = Queue()
    q.put(start)
    while not q.empty():
        select = q.get()
        for neighbour in graph[select]:
            if cost[neighbour] == -1:
                cost[neighbour] = cost[select] + 6
                q.put(neighbour)

    # Joining guarantees the line ends with a newline even when the start
    # node is the last node (previously the line was left unterminated)
    distances = [str(c) for i, c in enumerate(cost) if i != start]
    if distances:
        print(' '.join(distances))
Example #26
0
class TTS(object):
    """Text-to-speech front end driving a SAPI voice from a background thread.

    Only the most recently requested text is kept: ``speak`` replaces
    whatever is still waiting. Speech events are forwarded as JSON messages
    to every object in ``clients``. Without ``win32com`` in the module
    globals no background thread is started.
    """

    def __init__(self):
        self.clients = []
        self.voice_choices = []
        self.queue = Queue()
        if 'win32com' in globals():
            Thread(target=self._background).start()

    def _background(self):
        """Set up the SAPI voice in this thread, then speak queued texts forever."""
        pythoncom.CoInitialize()
        self.tts = win32com.client.Dispatch("SAPI.SpVoice")
        installed = self.tts.GetVoices()
        self.voices = [installed.Item(i) for i in range(installed.Count)]
        self.voice_choices = [
            {'desc': voice.GetDescription(), 'id': index}
            for index, voice in enumerate(self.voices)
        ]
        self.tts.Rate = -5
        self.event_sink = win32com.client.WithEvents(self.tts, TTSEventSink)
        self.event_sink.setTTS(self)
        while True:
            text = self.queue.get(True)
            self._speak(text)

    def _speak(self, text):
        """Cut off the current speech, start *text*, and pump until it ends."""
        self._speaking = True
        self.tts.Skip("Sentence", INT32_MAX)
        self.tts.Speak(text, SVSFlagsAsync)
        self._pump()

    def speak(self, text):
        """Request *text* to be spoken, discarding anything still queued."""
        try:
            while True:
                self.queue.get(False)
        except Empty:
            pass
        self.queue.put(text)

    def get_voice_choices(self):
        """Return the list of {'desc', 'id'} dicts describing the voices."""
        return self.voice_choices

    def set_voice(self, voice_id):
        """Switch to the voice at index *voice_id*."""
        self.tts.Voice = self.voices[voice_id]

    def handle_event(self, event, *args):
        """Track the end of speech and broadcast *event* to all clients as JSON."""
        msg = {'type': event}
        if event == 'end':
            self._speaking = False
        elif event == 'word':
            msg['char_pos'] = args[0]
            msg['length'] = args[1]
        payload = json.dumps(msg)
        for client in self.clients:
            client.write_message(payload)

    def _pump(self):
        """Pump COM messages until speech ends, skipping ahead once if new text is queued."""
        skipped = False
        while self._speaking:
            newer_text_waiting = not skipped and not self.queue.empty()
            if newer_text_waiting:
                self.tts.Skip("Sentence", INT32_MAX)
                skipped = True
            pythoncom.PumpWaitingMessages()
            time.sleep(0.05)
	def findFirstTri(vert, lastMFace, mirrorMesh, searchData) :
		#
		#	Walks the mirror mesh outwards from lastMFace, testing each face for
		#	intersection with the vertex. The first intersecting face found is
		#	recorded on searchData[vert.index]; nothing is returned.
		#

		# Faces that were tagged (tested or queued); untagged again at the end
		visited = [lastMFace]
		pending = Queue()

		# Start testing from the initial face
		lastMFace.tag = True
		pending.put_nowait(lastMFace)

		while not pending.empty() :
			candidate = pending.get_nowait()
			hit = MirrorMesh.triIntersection(vert.co, candidate)
			if hit is not None and hit._intersected :
				# Found an intersecting triangle
				searchData[vert.index].setMirror(hit)
				break
			# Queue the faces connected to this one
			MirrorMesh.queueConnectedFaces(candidate, mirrorMesh, pending, visited)

		for taggedFace in visited :
			taggedFace.tag = False
Example #28
0
    def classify(self, data: list, labels: list):
        """Classify one sample with the trained decision tree.

        :param data: feature values of the sample
        :param labels: feature names; ``self._findData`` maps a node's name to
            an index into *data* (-1 when the feature is absent)
        :return: the value of the first leaf reached, or None if the tree is
            exhausted without reaching one
        :raises ValueError: if the tree has not been trained
        """
        if len(self._t) < 1:
            raise ValueError("You must train it")

        pending = Queue()
        pending.put((self._t, 0))
        while not pending.empty():
            node, deep = pending.get()
            name = self._getName(node)
            branches = node.get(name)
            index = self._findData(name, labels)

            if index == -1:
                # Feature missing from the sample: take the first leaf among
                # the branches, queueing subtrees met before it
                for branch in branches.values():
                    if not isinstance(branch, dict):
                        return branch
                    pending.put((branch, deep + 1))
                continue

            outcome = branches.get(data[index])
            if isinstance(outcome, dict):
                pending.put((outcome, deep + 1))
                continue
            if outcome is not None:
                return outcome
            # Value never seen in training: fall back to the first branch
            return list(branches.values())[0]
def BFS(Vs,Ve,outMaze):
    """Search the grid ``outMaze`` from ``Vs`` for ``Ve`` using a FIFO queue.

    Neighbours are produced by adding four unit offsets to the current node.
    Each newly reached cell is written back into ``outMaze`` with its visit
    counter. When a neighbour satisfies ``isEndNote`` the route is printed via
    ``showWayNums`` and ``True`` is returned; when the queue runs dry a
    failure message is printed and ``False`` is returned.
    """
    grid = outMaze  # same object as the caller's grid; it is updated in place
    frontier = Queue()
    frontier.put(Vs)
    Vs.hadVisited = True
    grid[0][0] = Vs
    offsets = (Node(1,0,0),Node(0,1,0),Node(-1,0,0),Node(0,-1,0))
    count = 1

    while not frontier.empty():
        current = frontier.get()
        for offset in offsets:
            neighbour = Node.NodeAdd(current, offset)

            if isEndNote(neighbour, Ve):
                neighbour.hadVisited = True
                neighbour.count = count
                grid[neighbour.x][neighbour.y] = neighbour
                print("The way is:")
                print("(%d,%d)"%(Ve.x,Ve.y))
                showWayNums(neighbour, Vs, grid)
                return True

            if not isValid(neighbour):
                continue

            # Copy the wall flag from the cell currently stored at that position
            neighbour.isWall = grid[neighbour.x][neighbour.y].isWall
            if neighbour.isWall != True and grid[neighbour.x][neighbour.y].hadVisited != True:
                frontier.put(neighbour)
                neighbour.hadVisited = True
                neighbour.count = count
                grid[neighbour.x][neighbour.y] = neighbour
                count += 1

    # Queue exhausted without meeting the end node
    print("WTF! No result!")
    return False
Example #30
0
 def selectTerritoireEntite(self, tuile):
    """Select the territories an entity can travel through and display the selection zone in the window."""
    self.deselect()
    unit = tuile.getEntite()[0]
    if not unit.canMoove():
        return

    frontier = Queue()
    frontier.put(tuile)
    visited = [tuile]
    explored = 0
    # The budget is re-read from the entity's ``pa`` on every pass
    while not frontier.empty() and explored < unit.pa*(2*(unit.pa + 1)):
        current = frontier.get()
        neighbours, neighbour_count = self.getVoisinComptage(current)
        explored += neighbour_count
        for neighbour in neighbours:
            if neighbour in visited:
                # Already handled: give back the unit counted for it above
                explored -= 1
                continue
            frontier.put(neighbour)
            self.selectedTuile.append(neighbour)
            visited.append(neighbour)

    self.selectionType = "Entite"
    self.selectTuile(self.selectedTuile, "case selection entite.gif")
    self.selectionType = "Entite"
    def person_following(self, run_py_eyes, cvQueue: Queue):
        """Steer the robot after the largest blob found in the configured colour range.

        Every loop iteration reads one frame from ``self.vs``, masks it with
        ``self.greenLower``/``self.greenUpper``, sends exactly one serial
        command chosen from the blob's position and radius, and shows the
        annotated frame in the ``result`` window. The loop ends when ``q`` is
        pressed in that window, when ``cvQueue`` is no longer empty, or when a
        frame cannot be read; a STOP command is sent on every exit path.

        :param run_py_eyes: when truthy, also draw two pygame "eyes" whose
            pupils follow the tracked blob.
        :param cvQueue: only checked for emptiness (nothing is taken from it);
            any pending item ends the loop.
        """
        # The frame is resized to 600 px wide before processing.
        # Below are variables to set what we consider center and in-range
        radiusInRangeLowerBound, radiusInRangeUpperBound = 80, 120
        centerRightBound, centerLeftBound = 400, 200
        radiusTooCloseLowerLimit = 250
        # Blobs whose radius is at or below this are not treated as the target
        minTrackableRadius = 25
        textColour = (200, 255, 155)

        # Creating a window for later use
        cv2.namedWindow('result')
        cv2.resizeWindow('result', 600, 600)

        # Variables to 'smarten' the following procedure (see usage below)
        objectSeenOnce = False  # Object has never been seen before
        leftOrRightLastSent = None  # Keep track of whether we sent left or right last

        # TODO delete this block when done
        start = time.time()
        num_frames = 0

        # PyEyes Setup... Instead of redrawing the entire frame on each iteration, only the pupils drawn on the
        # previous frame are painted white (to match the background) before the new pupils are drawn.
        if run_py_eyes:
            screen = pygame.display.set_mode((1024, 600), DOUBLEBUF)
            screen.set_alpha(None)  # Not needed, so set it to this for performance improvement
            surface = pygame.display.get_surface()
            # Draw the eyeballs (without pupils) and white background that we'll use for the rest of the process
            screen.fill(self.RGB_WHITE)  # Fill PyGame screen (white background)
            pygame.draw.circle(surface, self.RGB_BLACK, (256, 300), 255, 15)
            pygame.draw.circle(surface, self.RGB_BLACK, (768, 300), 255, 15)
            pygame.display.flip()
            rects = []

        while True:
            # Reset to default pupil coordinates (in case no object is found on this iteration)
            leftx, lefty = 256, 350

            # Grab frame - leave the loop if we don't get it (some unknown error occurred)
            frame = self.vs.read()
            if frame is None:
                print("ERROR - frame read a NONE")
                # Without a camera image the last movement command would stay in effect, so halt first
                self.send_serial_command(Direction.STOP, b'h')
                break

            # TODO delete this block when done (if you want)
            seconds = time.time() - start
            num_frames += 1
            fps = 0 if (seconds == 0) else num_frames / seconds

            # Resize the frame, blur it, and convert it to the HSV color space
            frame = imutils.resize(frame, width=600)
            blurred = cv2.GaussianBlur(frame, (5, 5), 0)
            hsv = cv2.cvtColor(blurred, cv2.COLOR_BGR2HSV)

            # construct a mask for the desired color, then perform a series of dilations and erosions to
            # remove any small blobs left in the mask
            mask = cv2.inRange(hsv, self.greenLower, self.greenUpper)
            mask = cv2.erode(mask, None, iterations=2)  # TODO: these were 3 or 5 before (more small blob removal)
            mask = cv2.dilate(mask, None, iterations=2)

            # find contours in the mask
            cnts = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cnts = imutils.grab_contours(cnts)

            # Use the largest contour (if any) to compute the minimum enclosing circle
            radius = 0
            if len(cnts) > 0:
                c = max(cnts, key=cv2.contourArea)
                ((x, y), radius) = cv2.minEnclosingCircle(c)

            if radius <= minTrackableRadius:
                # Nothing found, or only something too small to be the target: search or stay halted
                commandString = self._search_or_halt(objectSeenOnce, leftOrRightLastSent)
            else:
                objectSeenOnce = True

                # Centroid from the image moments; fall back to the circle centre when the contour
                # has zero area, which would otherwise divide by zero
                M = cv2.moments(c)
                if M["m00"] != 0:
                    center = (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"]))
                else:
                    center = (int(x), int(y))
                centerX, centerY = center

                #  draw the circle on the frame TODO consider removing this eventually - could speed things up (barely)
                cv2.circle(frame, (int(x), int(y)), int(radius), (0, 255, 255), 2)
                cv2.circle(frame, center, 5, (0, 0, 255), -1)

                # Update PyGame Values
                if run_py_eyes:
                    lefty = int(centerY + 100)
                    leftx = int(abs(centerX - 600) / 2 + 106)

                # Check radius and center of the blob to determine robot action
                # What actions should take priority?
                # 1. Moving Backward (only if it's super close)
                # 2. Moving Left/Right
                # 3. Moving Forward/Backward
                # Why? Because if we're too close any turn would be too extreme. We need to take care of that first
                if radius > radiusTooCloseLowerLimit:
                    commandString = "MOVE BACKWARD - TOO CLOSE TO TURN"
                    self.send_serial_command(Direction.BACKWARD, b'b')
                elif centerX > centerRightBound:
                    commandString = "GO RIGHT"
                    self.send_serial_command(Direction.RIGHT, b'r')
                    leftOrRightLastSent = Direction.RIGHT
                elif centerX < centerLeftBound:
                    commandString = "GO LEFT"
                    self.send_serial_command(Direction.LEFT, b'l')
                    leftOrRightLastSent = Direction.LEFT
                elif radius < radiusInRangeLowerBound:
                    commandString = "MOVE FORWARD"
                    self.send_serial_command(Direction.FORWARD, b'f')
                elif radius > radiusInRangeUpperBound:
                    commandString = "MOVE BACKWARD"
                    self.send_serial_command(Direction.BACKWARD, b'b')
                else:
                    # Bounds are inclusive so that a radius exactly on a limit still gets a command
                    commandString = "STOP MOVING - IN RANGE"
                    self.send_serial_command(Direction.STOP, b'h')

                # Put text on the camera image to display on the screen
                cv2.putText(frame, 'center coordinate: (' + str(centerX) + ',' + str(centerY) + ')',
                            (10, 60), self.font, 0.5, textColour, 1, cv2.LINE_AA)
                cv2.putText(frame, 'filtered radius: (' + str(radius) + ')', (10, 90), self.font, 0.5,
                            textColour, 1, cv2.LINE_AA)

            # The below steps are run regardless of whether we see a valid object or not ...

            # Show the command and FPS (TODO delete this later)
            cv2.putText(frame, commandString, (10, 30), self.font, 0.5, textColour, 1, cv2.LINE_AA)
            cv2.putText(frame, 'FPS: (' + str(fps) + ')', (10, 120), self.font, 0.5,
                        textColour, 1, cv2.LINE_AA)

            # show webcam video with object drawings on the screen
            cv2.imshow("result", frame)

            if run_py_eyes:
                # If our coordinates are out of bounds for the eyes, then cap them at their correct bounds
                leftx = min(max(leftx, 106), 406)
                lefty = min(max(lefty, 150), 450)
                # Both pupils share the same offset; the eyes are not made to cross when the target is close
                pupils = ((leftx, lefty), (leftx + 500 + 12, lefty))

                # Draw left and right pupil
                for pupil in pupils:
                    rects.append(pygame.draw.circle(surface, self.RGB_BLACK, pupil, self.EYEBALL_RADIUS, 0))
                # Update the display so our changes show up
                pygame.display.update(rects)
                # Paint the pupils white again and keep their rects so that the next update erases them
                # (instead of clearing the whole display and redrawing it all, which is expensive)
                rects = [pygame.draw.circle(surface, self.RGB_WHITE, pupil, self.EYEBALL_RADIUS, 0)
                         for pupil in pupils]

            # Close application on 'q' key press, or if the queue is not empty (there's some command to respond to).
            key = cv2.waitKey(1) & 0xFF
            if (key == ord("q")) or (not cvQueue.empty()):
                self.send_serial_command(Direction.STOP, b'h')
                # We've been requested to leave ...
                # Don't destroy everything - just destroy cv2 windows ... webcam still runs
                cv2.destroyAllWindows()
                if run_py_eyes:
                    pygame.display.quit()
                    pygame.quit()
                break

    def _search_or_halt(self, objectSeenOnce, leftOrRightLastSent):
        """Send the command used while no valid target is visible and return its on-screen label.

        If the target has never been seen the robot is halted; otherwise it
        turns in the direction of the last left/right command, defaulting to
        left when no turn has been recorded.
        """
        if not objectSeenOnce:
            self.send_serial_command(Direction.STOP, b'h')
            return "STOP"
        if leftOrRightLastSent == Direction.RIGHT:
            self.send_serial_command(Direction.RIGHT, b'r')
            return "SEARCHING: GO RIGHT"
        if leftOrRightLastSent == Direction.LEFT:
            self.send_serial_command(Direction.LEFT, b'l')
            return "SEARCHING: GO LEFT"
        # variable hasn't been set yet (seems unlikely), but default to left
        self.send_serial_command(Direction.LEFT, b'l')
        return "DEFAULT SEARCHING: GO LEFT"
class BST:
    """Binary search tree built from nodes exposing ``data``, ``left`` and ``right``.

    Keys smaller than a node go into its left subtree; equal or larger keys go
    into its right subtree. Both insertion methods follow the same rule.
    """

    def __init__(self):
        self.root = None
        # Work queue reused by display_bft
        self.queuy = Queue()

    def insert_node_rescusrion(self, root, node):
        """Insert ``node`` into the subtree at ``root`` and return the subtree's root."""
        if root is None:
            return node
        if node.data < root.data:
            root.left = self.insert_node_rescusrion(root.left, node)
        else:
            root.right = self.insert_node_rescusrion(root.right, node)
        return root

    def insert_node(self, node):
        """Insert ``node`` iteratively, descending from the tree root."""
        if self.root is None:
            self.root = node
            return
        current = self.root
        while True:
            if node.data < current.data:
                if current.left is None:
                    current.left = node
                    return
                current = current.left
            else:
                if current.right is None:
                    current.right = node
                    return
                current = current.right

    def insert(self, data):
        """ insert node """
        node = Node()
        node.data = data
        self.insert_node(node)

    def insert_recursion(self, data):
        """ insert node using recursion"""
        node = Node()
        node.data = data
        self.root = self.insert_node_rescusrion(self.root, node)

    def max(self, root):
        """Return the node holding the largest key in the subtree at ``root``.

        :raises ValueError: if ``root`` is ``None``.
        """
        if root is None:
            # Raising a plain string is itself a TypeError in Python 3
            raise ValueError("no nodes in the tree")
        while root.right is not None:
            root = root.right
        return root

    def find_max(self):
        """Return the largest key in the tree."""
        return self.max(self.root).data

    def min(self, root):
        """Return the node holding the smallest key in the subtree at ``root``.

        :raises ValueError: if ``root`` is ``None``.
        """
        if root is None:
            raise ValueError("no nodes in the tree")
        # The smallest key is the left-most node
        while root.left is not None:
            root = root.left
        return root

    def find_min(self):
        """Return the smallest key in the tree."""
        return self.min(self.root).data

    def max_recursion(self, root):
        """Return the largest key in the subtree at ``root`` using recursion.

        :raises ValueError: if ``root`` is ``None``.
        """
        if root is None:
            raise ValueError("no nodes in the tree")
        if root.right is None:
            return root.data
        return self.max_recursion(root.right)

    def find_max_recursion(self):
        """ find max element in the tree using recursion"""
        return self.max_recursion(self.root)

    def recusrive_search(self, root, data):
        """Search the subtree at ``root``; return ``(found, node)`` with ``node`` None on a miss."""
        if root is None:
            return False, root
        if root.data == data:
            return True, root
        if root.data < data:
            return self.recusrive_search(root.right, data)
        return self.recusrive_search(root.left, data)

    def search(self, data):
        """ search for element in the tree (return bool)"""
        return self.recusrive_search(self.root, data)[0]

    def recursive_height(self, node):
        """Return the height of ``node`` in edges (``-1`` for ``None``) using recursion."""
        if node is None:
            return -1
        return max(self.recursive_height(node.left),
                   self.recursive_height(node.right)) + 1

    def height(self, node):
        """Return the height of ``node`` in edges (``-1`` for ``None``) without recursion."""
        if node is None:
            return -1
        edges = -1
        level = [node]
        # Each pass replaces the current level by its children; the number of
        # non-empty levels minus one is the height.
        while level:
            edges += 1
            level = [child for parent in level
                     for child in (parent.left, parent.right)
                     if child is not None]
        return edges

    def depth(self, node):
        """Return the number of edges from the root to ``node``, or ``-1`` if it is not in the tree."""
        if node is None:
            return -1
        current = self.root
        edges = 0
        while current is not None and current is not node:
            # Same routing rule as insertion, so equal keys are looked for on the right
            current = current.left if node.data < current.data else current.right
            edges += 1
        return edges if current is node else -1

    def find_height(self, data):
        """find height of a node (height is number of edges in the longest path
        from a node to a leaf node)
        """
        find_node = self.recusrive_search(self.root, data)[1]
        height = self.recursive_height(find_node)
        if height == -1:
            return "node doesnot exist"
        return height

    def tree_height(self):
        """ find the height of the tree"""
        return self.height(self.root)

    def find_depth(self, data):
        """ depth of a node is the number of edges from the root to the node"""
        find_node = self.recusrive_search(self.root, data)[1]
        if find_node is None:
            return "node doesnot exist"
        return self.depth(find_node)

    def _delete(self, root, data):
        """Remove one node holding ``data`` from the subtree at ``root``; return the new subtree root."""
        if root is None:
            return root  # not found
        if root.data > data:
            root.left = self._delete(root.left, data)
        elif root.data < data:
            root.right = self._delete(root.right, data)
        elif root.left is None:  # only right child or no children (leaf node)
            root = root.right
        elif root.right is None:  # only left child
            root = root.left
        else:
            # Two children: take over the smallest key of the right subtree,
            # then remove that key from the right subtree.
            successor = self.min(root.right)
            root.data = successor.data
            root.right = self._delete(root.right, successor.data)
        return root

    def delete_node(self, data):
        """Remove ``data`` from the tree if present."""
        self.root = self._delete(self.root, data)

    def display_bft(self):  # breadth first traversal
        """Print the keys level by level, left to right."""
        root = self.root

        if root is None:
            print("Tree is empty")
            return

        self.queuy.put(root)
        while not self.queuy.empty():
            temp = self.queuy.get()
            print(temp.data)
            if temp.left is not None:
                self.queuy.put(temp.left)
            if temp.right is not None:
                self.queuy.put(temp.right)

    def _inorder(self, root):
        """Print the subtree in <left><root><right> order."""
        if root is None:
            return
        self._inorder(root.left)
        print(root.data)
        self._inorder(root.right)

    def _preorder(self, root):
        """Print the subtree in <root><left><right> order."""
        if root is None:
            return
        print(root.data)
        self._preorder(root.left)
        self._preorder(root.right)

    def _postorder(self, root):
        """Print the subtree in <left><right><root> order."""
        if root is None:
            return
        self._postorder(root.left)
        self._postorder(root.right)
        print(root.data)

    def dfs_perorder(self):  # depth first search traversal <root><left><right>
        self._preorder(self.root)

    def dfs_inorder(self):  # depth first search traversal <left><root><right>
        self._inorder(self.root)

    def dfs_postorder(
            self):  # depth first search traversal <left><right><root>
        self._postorder(self.root)
Example #33
0
class EventGenerator(object):
    def __init__(self, args=None):
        '''
        This object will allow you to generate and control eventgen.  It should be handed the parse_args object
        from __main__ and will hand the argument object to the config parser of eventgen5.  This will provide the
        bridge to using the old code with the newer style.  As things get moved from the config parser, this should
        start to control all of the configuration items that are global, and the config object should only handle the
        localized .conf entries.
        :param args: __main__ parse_args() object.
        '''
        self.stopping = False
        self.force_stop = False
        self.started = False
        self.completed = False
        self.config = None
        self.args = args

        self._setup_loggers(args=args)
        # attach to the logging queue
        self.logger.info("Logging Setup Complete.")

        self._generator_queue_size = getattr(self.args, 'generator_queue_size',
                                             500)
        if self._generator_queue_size < 0:
            self._generator_queue_size = 0
        self.logger.info("Set generator queue size:{}".format(
            self._generator_queue_size))

        if self.args and 'configfile' in self.args and self.args.configfile:
            self._load_config(self.args.configfile, args=args)

    def _load_config(self, configfile, **kwargs):
        '''
        Parse configfile into self.config, translating the command-line arguments given as kwargs["args"]
        into Config overrides, then reload the plugins and rebuild the worker pools.
        kwargs will need to match eventgenconfig.py
        :param configfile: passed to Config as its first argument.
        :return:
        '''
        # TODO: The old eventgen had strange cli args. We should probably update the module args to match this usage.
        new_args = {}

        def copy_if_set(args, pairs):
            # Copy every truthy argument to its Config keyword
            for attribute, keyword in pairs:
                value = getattr(args, attribute)
                if value:
                    new_args[keyword] = value

        if "args" in kwargs:
            args = kwargs["args"]
            # The first output flag that is set wins
            chosen_outputs = [
                flag for flag in ("keepoutput", "devnull", "modinput")
                if getattr(args, flag)
            ]
            if chosen_outputs:
                new_args["override_outputter"] = chosen_outputs[0]
            copy_if_set(args, (
                ("count", "override_count"),
                ("interval", "override_interval"),
                ("backfill", "override_backfill"),
                ("end", "override_end"),
            ))
            if getattr(args, "multiprocess"):
                new_args["threading"] = "process"
            copy_if_set(args, (
                ("generators", "override_generators"),
                ("disableOutputQueue", "override_outputqueue"),
                ("profiler", "profiler"),
                ("sample", "sample"),
                ("verbosity", "verbosity"),
            ))

        self.config = Config(configfile, **new_args)
        self.config.parse()
        # The parsed configuration can switch multiprocess mode on, never off
        process_mode = self.config.threading == "process"
        self.args.multiprocess = True if process_mode else self.args.multiprocess
        self._reload_plugins()

        # A generator count given on the command line takes precedence over the configured one
        requested_workers = getattr(kwargs["args"], "generators") if "args" in kwargs else None
        if requested_workers:
            generator_worker_count = requested_workers
        else:
            generator_worker_count = self.config.generatorWorkers

        # TODO: Probably should destroy pools better so processes are cleaned.
        if self.args.multiprocess:
            self.kill_processes()
        self._setup_pools(generator_worker_count)

    def _reload_plugins(self):
        '''
        Initialize the output, generator and rater plugins found under FILE_PATH/lib/plugins.
        Plugins must be loaded before objects that do work, otherwise threads and processes generated will not
        have the modules loaded in active memory. Any exception raised while loading is logged, not re-raised.
        '''
        try:
            self.config.outputPlugins = {}
            plugin_root = os.path.join(FILE_PATH, 'lib', 'plugins')

            output_names = self._initializePlugins(
                os.path.join(plugin_root, 'output'),
                self.config.outputPlugins, 'output')
            self.config.validOutputModes.extend(output_names)

            self._initializePlugins(
                os.path.join(plugin_root, 'generator'),
                self.config.plugins, 'generator')

            rater_names = self._initializePlugins(
                os.path.join(plugin_root, 'rater'),
                self.config.plugins, 'rater')
            self.config._complexSettings['rater'] = rater_names
        except Exception as e:
            self.logger.exception(str(e))

    def _load_custom_plugins(self, PluginNotLoadedException):
        plugintype = PluginNotLoadedException.type
        plugin = PluginNotLoadedException.name
        bindir = PluginNotLoadedException.bindir
        plugindir = PluginNotLoadedException.plugindir
        pluginsdict = self.config.plugins if plugintype in (
            'generator', 'rater') else self.config.outputPlugins
        # APPPERF-263: be picky when loading from an app bindir (only load name)
        self._initializePlugins(bindir, pluginsdict, plugintype, name=plugin)

        # APPPERF-263: be greedy when scanning plugin dir (eat all the pys)
        self._initializePlugins(plugindir, pluginsdict, plugintype)

    def _setup_pools(self, generator_worker_count):
        '''
        This method is an internal method called on init to generate pools needed for processing.

        :return:
        '''
        # Load the things that actually do the work.
        self._create_generator_pool()
        self._create_timer_threadpool()
        self._create_output_threadpool()
        self._create_generator_workers(generator_worker_count)

    def _create_timer_threadpool(self, threadcount=100):
        '''
        Timer threadpool is used to contain the timer object for each sample.  A timer will stay active
        until the end condition is met for the sample.  If there is no end condition, the timer will exist forever.
        :param threadcount: is how many active timers we want to allow inside of eventgen.  Default 100.  If someone
                            has over 100 samples, additional samples won't run until the first ones end.
        :return:
        '''
        self.sampleQueue = Queue(maxsize=0)
        num_threads = threadcount
        for i in range(num_threads):
            worker = Thread(target=self._worker_do_work,
                            args=(
                                self.sampleQueue,
                                self.loggingQueue,
                            ),
                            name="TimeThread{0}".format(i))
            worker.setDaemon(True)
            worker.start()

    def _create_output_threadpool(self, threadcount=1):
        '''
        the output thread pool is used for output plugins that need to control file locking, or only have 1 set thread
        to send all the data out of. This FIFO queue just helps make sure there are file collisions or write collisions.
        There's only 1 active thread for this queue, if you're ever considering upping this, don't.  Just shut off the
        outputQueue and let each generator directly output it's data.
        :param threadcount: is how many active output threads we want to allow inside of eventgen.  Default 1
        :return:
        '''
        # TODO: Make this take the config param and figure out what we want to do with this.
        if getattr(self, "manager", None):
            self.outputQueue = self.manager.Queue(maxsize=500)
        else:
            self.outputQueue = Queue(maxsize=500)
        num_threads = threadcount
        for i in range(num_threads):
            worker = Thread(target=self._worker_do_work,
                            args=(
                                self.outputQueue,
                                self.loggingQueue,
                            ),
                            name="OutputThread{0}".format(i))
            worker.setDaemon(True)
            worker.start()

    def _create_generator_pool(self, workercount=20):
        '''
        The generator pool has two main options, it can run in multiprocessing or in threading.  We check the argument
        from configuration, and then build the appropriate queue type.  Each time a timer runs for a sample, if the
        timer says it's time to generate, it will create a new generator plugin object, and place it in this queue.
        :param workercount: is how many active workers we want to allow inside of eventgen.  Default 10.  If someone
                            has over 10 generators working, additional samples won't run until the first ones end.
        :return:
        '''
        if self.args.multiprocess:
            self.manager = multiprocessing.Manager()
            if self.config.disableLoggingQueue:
                self.loggingQueue = None
            else:
                # TODO crash caused by logging Thread https://github.com/splunk/eventgen/issues/217
                self.loggingQueue = self.manager.Queue()
                self.logging_thread = Thread(target=self.logger_thread,
                                             args=(self.loggingQueue, ),
                                             name="LoggerThread")
                self.logging_thread.start()
            # since we're now in multiprocess, we need to use better queues.
            self.workerQueue = multiprocessing.JoinableQueue(
                maxsize=self._generator_queue_size)
            self.genconfig = self.manager.dict()
            self.genconfig["stopping"] = False
        else:
            self.workerQueue = Queue(maxsize=self._generator_queue_size)
            worker_threads = workercount
            if hasattr(self.config,
                       'outputCounter') and self.config.outputCounter:
                self.output_counters = []
                for i in range(workercount):
                    self.output_counters.append(OutputCounter())
                for i in range(worker_threads):
                    worker = Thread(target=self._generator_do_work,
                                    args=(self.workerQueue, self.loggingQueue,
                                          self.output_counters[i]))
                    worker.setDaemon(True)
                    worker.start()
            else:
                for i in range(worker_threads):
                    worker = Thread(target=self._generator_do_work,
                                    args=(self.workerQueue, self.loggingQueue,
                                          None))
                    worker.setDaemon(True)
                    worker.start()

    def _create_generator_workers(self, workercount=20):
        '''
        Start the generator worker processes when ``self.args.multiprocess`` is set; otherwise do nothing.

        Each process runs ``self._proc_worker_do_work`` with the worker queue, the logging queue, the shared
        ``genconfig`` and the disable-logging flag, and is appended to ``self.workerPool`` before it is started.

        :param workercount: number of worker processes to create.
        :return:
        '''
        if not self.args.multiprocess:
            return

        import multiprocessing
        self.workerPool = []
        for _ in range(workercount):
            # Read the flag once per worker, as the arguments tuple is built.
            disable_logging = bool(self.args and self.args.disable_logging)
            worker_args = (self.workerQueue, self.loggingQueue,
                           self.genconfig, disable_logging)
            process = multiprocessing.Process(
                target=self._proc_worker_do_work, args=worker_args)
            self.workerPool.append(process)
            process.start()

    def _setup_loggers(self, args=None):
        '''
        Attach the module-level ``logger`` to this instance and configure its level.

        :param args: parsed argument namespace or None.  Only ``disable_logging`` and ``verbosity`` are read.
                     When ``args`` is None, or its ``verbosity`` is None, the level defaults to ERROR.
        :return:
        '''
        if args and args.disable_logging:
            # Drop every existing handler and discard all records.
            logger.handlers = []
            logger.addHandler(logging.NullHandler())
        self.logger = logger
        self.loggingQueue = None
        # ``args`` defaults to None, so it has to be checked before ``verbosity`` is dereferenced.
        if args is None or args.verbosity is None:
            # Set the default log level to ERROR when Generator is called directly in tests
            self.logger.setLevel(logging.ERROR)
        elif args and args.verbosity:
            self.logger.setLevel(args.verbosity)

    def _worker_do_work(self, work_queue, logging_queue):
        '''
        Run items taken from ``work_queue`` until ``self.stopping`` becomes true.

        :param work_queue: queue whose items expose ``run()``; ``task_done()`` is called after each successful run.
        :param logging_queue: accepted for signature compatibility; not used here.
        :return:
        '''
        while True:
            if self.stopping:
                return
            try:
                job = work_queue.get(timeout=10)
                began = time.time()
                job.run()
                elapsed = time.time() - began
                overran = elapsed > self.config.interval and self.config.end != 1
                if overran:
                    self.logger.warning(
                        "work took longer than current interval, queue/threading throughput limitation"
                    )
                work_queue.task_done()
            except Empty:
                # Nothing arrived within the timeout: go round again and re-check the stop flag.
                continue
            except Exception as e:
                self.logger.exception(str(e))
                raise

    def _generator_do_work(self,
                           work_queue,
                           logging_queue,
                           output_counter=None):
        '''
        Run generator items taken from ``work_queue`` until ``self.stopping`` becomes true.

        :param work_queue: queue whose items expose ``run(output_counter=...)`` and a ``_sample`` with ``end``.
        :param logging_queue: accepted for signature compatibility; not used here.
        :param output_counter: forwarded unchanged to every item's ``run``.
        :return:
        '''
        while True:
            if self.stopping:
                return
            try:
                job = work_queue.get(timeout=10)
                began = time.time()
                job.run(output_counter=output_counter)
                elapsed = time.time() - began
                overran = elapsed > self.config.interval and job._sample.end != 1
                if overran:
                    self.logger.warning(
                        "work took longer than current interval, queue/threading throughput limitation"
                    )
                work_queue.task_done()
            except Empty:
                # Timed out waiting for work: loop and re-check the stop flag.
                continue
            except Exception as e:
                if self.force_stop:
                    # During a forced stop errors are swallowed and the worker simply ends.
                    return
                self.logger.exception(str(e))
                raise

    @staticmethod
    def _proc_worker_do_work(work_queue, logging_queue, config,
                             disable_logging):
        '''
        Main loop of a worker process: run queued items until ``config['stopping']`` is true, then exit.

        The root logger is set to DEBUG and given exactly one extra handler: a ``QueueHandler`` when a logging
        queue is supplied, otherwise a ``NullHandler`` (logging disabled) or a ``StreamHandler``.

        :param work_queue: queue of items exposing ``config``, ``_out.updateConfig()`` and ``run()``.
        :param logging_queue: queue for log records, or None.
        :param config: mapping holding the ``'stopping'`` flag; re-read after every item and every timeout.
        :param disable_logging: only consulted when ``logging_queue`` is None.
        :return: does not return normally; calls ``sys.exit(0)`` once the loop ends.
        '''
        stopping = config['stopping']
        root = logging.getLogger()
        root.setLevel(logging.DEBUG)
        if logging_queue is not None:
            # TODO https://github.com/splunk/eventgen/issues/217
            handler = logging.handlers.QueueHandler(logging_queue)
        elif disable_logging:
            handler = logging.NullHandler()
        else:
            handler = logging.StreamHandler()
        root.addHandler(handler)

        while not stopping:
            try:
                root.info("Checking for work")
                item = work_queue.get(timeout=10)
                item.logger = root
                item._out.updateConfig(item.config)
                item.run()
                work_queue.task_done()
                stopping = config['stopping']
                item.logger.debug(
                    "Current Worker Stopping: {0}".format(stopping))
            except Empty:
                # No work within the timeout; just refresh the stop flag.
                stopping = config['stopping']
            except Exception as e:
                root.exception(e)
                raise

        # The loop has no ``break``, so this point is reached exactly when ``stopping`` turned true.
        root.info("Stopping Process")
        sys.exit(0)

    def logger_thread(self, loggingQueue):
        '''
        Drain log records from ``loggingQueue`` into the module-level ``logger`` until ``self.stopping`` is true.

        :param loggingQueue: queue of log records; ``task_done()`` is called after each record is handled.
        :return:
        '''
        while True:
            if self.stopping:
                return
            try:
                entry = loggingQueue.get(timeout=10)
                logger.handle(entry)
                loggingQueue.task_done()
            except Empty:
                # No record within the timeout: loop and re-check the stop flag.
                continue
            except Exception as e:
                if self.force_stop:
                    # Errors during a forced stop end the thread quietly.
                    return
                self.logger.exception(str(e))
                raise

    def _initializePlugins(self, dirname, plugins, plugintype, name=None):
        """Load python plugin modules from ``dirname`` and add them to ``plugins`` (only accessed by getPlugin).

        Every ``*.py`` file directly inside ``dirname`` is a candidate.  With ``name`` left as None, files whose
        name starts with ``_`` are skipped; otherwise only ``{name}.py`` is loaded (APPPERF-263).  Each module must
        expose ``load()``; its return value is stored under ``"{plugintype}.{base}"``.  Settings declared on the
        plugin (``validSettings``, ``defaultableSettings``, ``intSettings``, ``floatSettings``, ``boolSettings``,
        ``jsonSettings``, ``complexSettings``) are merged into the matching ``self.config._*`` collections.

        :param dirname: directory to scan; a directory that does not exist yields an empty list.
        :param plugins: dict updated in place with the loaded plugins.
        :param plugintype: key prefix, giving names such as output.file or generator.default.
        :param name: optional module base name restricting which file is loaded.
        :return: list of base names of the plugins that were loaded.
        :raises Exception: errors other than ValueError/ImportError raised while loading are logged and re-raised.
        """
        # Function-scope import: importlib replaces the ``imp`` module (removed in Python 3.12), whose
        # find_module() also handed back an open file object that was never closed here.
        import importlib.util

        ret = []

        dirname = os.path.abspath(dirname)
        self.logger.debug("looking for plugin(s) in {}".format(dirname))
        if not os.path.isdir(dirname):
            self.logger.debug(
                "directory {} does not exist ... moving on".format(dirname))
            return ret

        # Include all plugin directories in sys.path for includes.  Append instead of rebuilding the list
        # from a set, which would put the existing import search path into arbitrary order.
        if dirname not in sys.path:
            sys.path.append(dirname)

        # Plugin attributes whose values extend the same-named, underscore-prefixed list on the config.
        list_settings = ('validSettings', 'defaultableSettings', 'intSettings',
                         'floatSettings', 'boolSettings', 'jsonSettings')

        # Loop through all files in passed dirname looking for plugins
        for basename in os.listdir(dirname):
            filename = os.path.join(dirname, basename)
            if not os.path.isfile(filename):
                continue

            base, extension = os.path.splitext(basename)
            if extension != ".py":
                continue
            # APPPERF-263: If name param is supplied, only attempt to load {name}.py
            wanted = (name is None and not basename.startswith("_")) or base == name
            if not wanted:
                continue

            self.logger.debug("Searching for plugin in file '%s'" % filename)
            try:
                spec = importlib.util.spec_from_file_location(base, filename)
                if spec is None or spec.loader is None:
                    raise ImportError("cannot build an import spec for %s" % filename)
                module = importlib.util.module_from_spec(spec)

                # Register the module under its base name while it executes, and put back whatever was
                # there before if execution fails, so no half-initialised module is left behind.
                previous = sys.modules.get(base)
                sys.modules[base] = module
                try:
                    spec.loader.exec_module(module)
                except BaseException:
                    if previous is None:
                        sys.modules.pop(base, None)
                    else:
                        sys.modules[base] = previous
                    raise
                # TODO: Probably need to adjust module.load() to be added later so this can be pickled.
                plugin = module.load()

                # set plugin to something like output.file or generator.default
                pluginname = plugintype + '.' + base
                plugins[pluginname] = plugin

                # Return is used to determine valid configs, so only return the base name of the plugin
                ret.append(base)

                self.logger.debug("Loading module '%s' from '%s'" %
                                  (pluginname, basename))

                # 12/3/13 If we haven't loaded a plugin right or we haven't initialized all the variables
                # in the plugin, we will get an exception and the plan is to not handle it
                for setting in list_settings:
                    if hasattr(plugin, setting):
                        getattr(self.config, '_' + setting).extend(
                            getattr(plugin, setting))
                if hasattr(plugin, 'complexSettings'):
                    self.config._complexSettings.update(
                        plugin.complexSettings)
            except ValueError:
                self.logger.error(
                    "Error loading plugin '%s' of type '%s'" %
                    (base, plugintype))
            except ImportError as ie:
                self.logger.warning(
                    "Could not load plugin: %s, skipping" % base)
                self.logger.exception(ie)
            except Exception as e:
                self.logger.exception(str(e))
                raise
        return ret

    def start(self, join_after_start=True):
        '''
        Mark the generator as started and queue a ``Timer`` for every runnable sample.

        A sample is runnable when its interval is positive, its mode is 'replay', or its ``end`` is not "0".
        If building a timer raises ``PluginNotLoaded``, the custom plugins are loaded and the build is retried once.

        :param join_after_start: when true, call ``join_process()`` once every timer has been queued.
        :return:
        '''
        self.stopping = False
        self.started = True
        self.config.stopping = False
        self.completed = False

        def build_timer(sample):
            # This is where the timer is finally sent to a queue to be processed.  Needs to move to this object.
            return Timer(1.0,
                         sample=sample,
                         config=self.config,
                         genqueue=self.workerQueue,
                         outputqueue=self.outputQueue,
                         loggingqueue=self.loggingQueue)

        if len(self.config.samples) == 0:
            self.logger.info("No samples found.  Exiting.")

        for s in self.config.samples:
            runnable = s.interval > 0 or s.mode == 'replay' or s.end != "0"
            if not runnable:
                continue
            self.logger.info(
                "Creating timer object for sample '%s' in app '%s'" %
                (s.name, s.app))
            try:
                t = build_timer(s)
            except PluginNotLoaded as pnl:
                self._load_custom_plugins(pnl)
                t = build_timer(s)
            self.sampleQueue.put(t)

        if join_after_start:
            self.logger.info(
                "All timers started, joining queue until it's empty.")
            self.join_process()

    def join_process(self):
        '''
        Block until the sample queue is empty with no unfinished tasks and the worker queue is empty, polling
        every 5 seconds, then call ``stop()``.  If the queues never drain this never returns.
        Any exception is logged and re-raised.
        :return:
        '''
        try:
            while True:
                busy = (not self.sampleQueue.empty()
                        or self.sampleQueue.unfinished_tasks > 0
                        or not self.workerQueue.empty())
                if not busy:
                    break
                time.sleep(5)
            self.logger.info(
                "All timers have finished, signalling workers to exit.")
            self.stop()
        except Exception as e:
            self.logger.exception(str(e))
            raise

    def stop(self, force_stop=False):
        '''
        Flag the generator as stopping, join the worker queue and, in multiprocess mode, shut the workers down.

        In multiprocess mode a normal stop sets ``genconfig["stopping"]`` and polls each worker every 2 seconds;
        a worker that has still not exited after a minute (30 polls) is terminated.  A forced stop instead
        replaces the worker queue via ``_create_generator_pool()`` and calls ``kill_processes()``.

        :param force_stop: skip the graceful wait; also skips joining the output queue.
        :return:
        '''
        poll_seconds = 2
        max_polls = 30  # 30 polls * 2 seconds = the one minute grace period

        # empty the sample queue:
        self.config.stopping = True
        self.stopping = True
        self.force_stop = force_stop

        self.logger.info(
            "All timers exited, joining generation queue until it's empty.")
        if force_stop:
            self.logger.info(
                "Forcibly stopping Eventgen: Deleting workerQueue.")
            del self.workerQueue
            self._create_generator_pool()
        self.workerQueue.join()
        # if we're in multiprocess, make sure we don't add more generators after the timers stopped.
        if self.args.multiprocess:
            if force_stop:
                self.kill_processes()
            else:
                self.genconfig["stopping"] = True
                for worker in self.workerPool:
                    polls = 0
                    # We wait for a minute until terminating the worker
                    while worker.exitcode is None and polls < max_polls:
                        self.logger.info(
                            "Worker {0} still working, waiting for it to finish."
                            .format(worker.name))
                        time.sleep(poll_seconds)
                        polls += 1
                    if worker.exitcode is None:
                        # Grace period used up and the process is still alive.
                        self.logger.info("Terminating worker {0}".format(
                            worker.name))
                        worker.terminate()

        self.logger.info(
            "All generators working/exited, joining output queue until it's empty."
        )
        if not self.args.multiprocess and not force_stop:
            self.outputQueue.join()
        self.logger.info(
            "All items fully processed. Cleaning up internal processes.")
        self.started = False
        self.stopping = False

    def reload_conf(self, configfile):
        '''
        This method will allow a user to supply a new .conf file for generation and reload the sample files.
        All of the work is delegated to ``self._load_config``; a debug message is logged once it returns.
        :param configfile: passed unchanged as the ``configfile`` keyword argument of ``_load_config``.
        :return:
        '''
        self._load_config(configfile=configfile)
        self.logger.debug("Config File Loading Complete.")

    def check_running(self):
        '''
        Report whether eventgen is running.

        :return: False when any of the output/sample/worker queues has not been created yet.  True while any
                 queue holds items or a tracked task is unfinished.  Otherwise the value of ``self.started``.
        '''
        queue_names = ("outputQueue", "sampleQueue", "workerQueue")
        if not all(hasattr(self, queue_name) for queue_name in queue_names):
            return False

        multiprocess = self.args.multiprocess
        # Shared by both modes: nothing queued anywhere and no unfinished sample task.
        drained = (self.outputQueue.empty() and self.sampleQueue.empty()
                   and self.workerQueue.empty()
                   and self.sampleQueue.unfinished_tasks <= 0)
        if drained and not multiprocess:
            # Only outside multiprocess mode are the output and worker task counters consulted.
            drained = (self.outputQueue.unfinished_tasks <= 0
                       and self.workerQueue.unfinished_tasks <= 0)
        if not drained:
            return True

        self.logger.info(
            "Queues are all empty and there are no pending tasks")
        return self.started

    def check_done(self):
        '''
        Report whether all queued work has been consumed.

        :return: True when the sample queue is empty with no unfinished tasks and the worker queue is empty,
                 else False
        '''
        samples = self.sampleQueue
        if not samples.empty() or samples.unfinished_tasks > 0:
            return False
        return self.workerQueue.empty()

    def kill_processes(self):
        '''
        Best-effort hard shutdown for multiprocess mode: SIGKILL every worker in ``self.workerPool``, drop the
        output queue and shut the manager down.  Does nothing when ``self.args.multiprocess`` is false.

        Ordinary errors are deliberately ignored so that cleanup never raises, but KeyboardInterrupt and
        SystemExit are no longer swallowed.
        :return:
        '''
        try:
            if self.args.multiprocess:
                for worker in self.workerPool:
                    try:
                        os.kill(int(worker.pid), signal.SIGKILL)
                    except Exception:
                        # e.g. the worker was never started (pid is None) or has already exited.
                        continue
                del self.outputQueue
                self.manager.shutdown()
        except Exception:
            # Cleanup is best effort; there is nothing useful to do if it fails.
            pass
Example #34
0
class ZeroConfClient:

    # The discovery protocol name for Ultimaker printers.
    ZERO_CONF_NAME = u"_ultimaker._tcp.local."

    # Signals emitted when new services were discovered or removed on the network.
    addedNetworkCluster = Signal()
    removedNetworkCluster = Signal()

    def __init__(self) -> None:
        # Everything stays unset until start() is called.
        self._zero_conf = None  # type: Optional[Zeroconf]
        self._zero_conf_browser = None  # type: Optional[ServiceBrowser]
        self._service_changed_request_queue = None  # type: Optional[Queue]
        self._service_changed_request_event = None  # type: Optional[Event]
        self._service_changed_request_thread = None  # type: Optional[Thread]

    ## Start discovery.
    #  Service-changed requests are processed on a separate daemon thread so the UI is not blocked.
    #  Requests whose detailed service info could not be fetched are put back on the request queue,
    #  and the same thread picks them up again later.
    def start(self) -> None:
        self._service_changed_request_queue = Queue()
        self._service_changed_request_event = Event()
        try:
            self._zero_conf = Zeroconf()
        except OSError:
            # CURA-6855 catch WinErrors
            Logger.logException("e", "Failed to create zeroconf instance.")
            return

        handler_thread = Thread(target=self._handleOnServiceChangedRequests,
                                daemon=True,
                                name="ZeroConfServiceChangedThread")
        self._service_changed_request_thread = handler_thread
        handler_thread.start()
        self._zero_conf_browser = ServiceBrowser(self._zero_conf,
                                                 self.ZERO_CONF_NAME,
                                                 [self._queueService])

    ## Release the ZeroConf instance and the service browser, if they exist.
    def stop(self) -> None:
        if self._zero_conf is not None:
            self._zero_conf.close()
            self._zero_conf = None
        if self._zero_conf_browser is not None:
            self._zero_conf_browser.cancel()
            self._zero_conf_browser = None

    ## Browser callback: queue a service change and wake the handler thread.
    #  Does nothing when the queue or the event has not been created yet.
    def _queueService(self, zeroconf: Zeroconf, service_type, name: str,
                      state_change: ServiceStateChange) -> None:
        pending = self._service_changed_request_queue
        wake_up = self._service_changed_request_event
        if not (pending and wake_up):
            return
        pending.put((zeroconf, service_type, name, state_change))
        wake_up.set()

    ## Body of the handler thread: process queued service changes until the application shuts down.
    def _handleOnServiceChangedRequests(self) -> None:
        if not (self._service_changed_request_queue
                and self._service_changed_request_event):
            return

        while True:
            # Sleep until new work is signalled, or at most five seconds.
            self._service_changed_request_event.wait(timeout=5.0)

            # Stop if the application is shutting down
            if CuraApplication.getInstance().isShuttingDown():
                return

            self._service_changed_request_event.clear()

            # Drain the queue, remembering every request that did not succeed.
            failed_requests = []
            while not self._service_changed_request_queue.empty():
                request = self._service_changed_request_queue.get()
                zeroconf, service_type, name, state_change = request
                try:
                    succeeded = self._onServiceChanged(zeroconf, service_type,
                                                       name, state_change)
                except Exception:
                    Logger.logException(
                        "e",
                        "Failed to get service info for [%s] [%s], the request will be rescheduled",
                        service_type, name)
                    succeeded = False
                if not succeeded:
                    failed_requests.append(request)

            # Put the failures back. The event is not set, so they wait for the next wake-up.
            for request in failed_requests:
                self._service_changed_request_queue.put(request)

    ##  Dispatch one service change to the matching handler.
    #   Returns True or False indicating whether the change was processed successfully.
    #   Note that this function can take over 3 seconds to complete. Be careful calling it from the main thread.
    def _onServiceChanged(self, zero_conf: Zeroconf, service_type: str,
                          name: str, state_change: ServiceStateChange) -> bool:
        if state_change == ServiceStateChange.Added:
            return self._onServiceAdded(zero_conf, service_type, name)
        if state_change == ServiceStateChange.Removed:
            return self._onServiceRemoved(name)
        # Any other state change needs no action and counts as handled.
        return True

    ## Handle a newly added ZeroConf service.
    #  Returns False when no address could be obtained for the service, True otherwise.
    def _onServiceAdded(self, zero_conf: Zeroconf, service_type: str,
                        name: str) -> bool:
        # First try getting info from zero-conf cache
        info = ServiceInfo(service_type, name, properties={})
        for record in zero_conf.cache.entries_with_name(name.lower()):
            info.update_record(zero_conf, time(), record)

        for record in zero_conf.cache.entries_with_name(info.server):
            info.update_record(zero_conf, time(), record)
            if info.address:
                break

        # Request more data if info is not complete
        if not info.address:
            info = zero_conf.get_service_info(service_type, name)

        if not (info and info.address):
            Logger.log("w", "Could not get information about %s" % name)
            return False

        type_of_device = info.properties.get(b"type", None)
        if not type_of_device:
            return True

        if type_of_device != b"printer":
            Logger.log(
                "w",
                "The type of the found device is '%s', not 'printer'."
                % type_of_device)
            return True

        address = '.'.join(str(part) for part in info.address)
        self.addedNetworkCluster.emit(str(name), address, info.properties)
        return True

    ## Handle a removed ZeroConf service: log it and emit removedNetworkCluster.
    def _onServiceRemoved(self, name: str) -> bool:
        Logger.log("d", "ZeroConf service removed: %s" % name)
        self.removedNetworkCluster.emit(str(name))
        return True
Example #35
0
class CtaEngine(BaseEngine):
    """Engine that hosts CTA strategies and dispatches tick, order, trade and position events to them."""

    engine_type = EngineType.LIVE  # live trading engine

    # File names for strategy settings and strategy data.
    # NOTE(review): presumably consumed by load_strategy_setting / load_strategy_data — confirm.
    setting_filename = "cta_strategy_setting.json"
    data_filename = "cta_strategy_data.json"

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """Initialise the base engine under APP_NAME and create empty bookkeeping containers."""
        super(CtaEngine, self).__init__(main_engine, event_engine, APP_NAME)

        # Per-strategy configuration and data, both keyed by strategy_name.
        self.strategy_setting = {}
        self.strategy_data = {}

        # class_name -> strategy class, and strategy_name -> strategy instance.
        self.classes = {}
        self.strategies = {}

        # vt_symbol -> list of strategies.
        self.symbol_strategy_map = defaultdict(list)
        # vt_orderid -> strategy.
        self.orderid_strategy_map = {}
        # strategy_name -> set of order ids.
        self.strategy_orderid_map = defaultdict(set)

        # Counter used for generating stop_orderid, and stop_orderid -> stop_order.
        self.stop_order_count = 0
        self.stop_orders = {}

        self.init_thread = None
        self.init_queue = Queue()

        self.rq_client = None
        self.rq_symbols = set()

        # Trade ids already seen, for filtering duplicate trade pushes.
        self.vt_tradeids = set()

        self.offset_converter = OffsetConverter(self.main_engine)

    def init_engine(self):
        """
        Run the start-up steps in order: RQData, strategy classes, settings, data,
        event registration; then write a log line reporting success.
        """
        startup_steps = (
            self.init_rqdata,
            self.load_strategy_class,
            self.load_strategy_setting,
            self.load_strategy_data,
            self.register_event,
        )
        for step in startup_steps:
            step()
        self.write_log("CTA策略引擎初始化成功")

    def close(self):
        """Shut the engine down by delegating to ``stop_all_strategies``."""
        self.stop_all_strategies()

    def register_event(self):
        """Register this engine's tick, order, trade and position handlers with the event engine."""
        subscriptions = (
            (EVENT_TICK, self.process_tick_event),
            (EVENT_ORDER, self.process_order_event),
            (EVENT_TRADE, self.process_trade_event),
            (EVENT_POSITION, self.process_position_event),
        )
        for event_type, handler in subscriptions:
            self.event_engine.register(event_type, handler)

    def init_rqdata(self):
        """
        Initialise the RQData client and write a log line when ``init()`` returns a truthy result.
        """
        if rqdata_client.init():
            self.write_log("RQData数据接口初始化成功")

    def query_bar_from_rq(self, symbol: str, exchange: Exchange,
                          interval: Interval, start: datetime, end: datetime):
        """
        Build a HistoryRequest from the arguments and return whatever
        ``rqdata_client.query_history`` gives back for it.
        """
        request = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=interval,
            start=start,
            end=end,
        )
        return rqdata_client.query_history(request)

    def process_tick_event(self, event: Event):
        """Check local stop orders against the tick, then pass it to every inited strategy on that symbol."""
        tick = event.data

        subscribers = self.symbol_strategy_map[tick.vt_symbol]
        if subscribers:
            self.check_stop_order(tick)

            for strategy in subscribers:
                if not strategy.inited:
                    continue
                self.call_strategy_func(strategy, strategy.on_tick, tick)

    def process_order_event(self, event: Event):
        """
        Update the offset converter with the order and, when a strategy owns it,
        notify that strategy (also via on_stop_order for server stop orders).
        """
        order = event.data
        orderid = order.vt_orderid

        self.offset_converter.update_order(order)

        owner = self.orderid_strategy_map.get(orderid, None)
        if not owner:
            return

        # Forget the order id once the order is no longer active.
        active_ids = self.strategy_orderid_map[owner.strategy_name]
        if orderid in active_ids and not order.is_active():
            active_ids.remove(orderid)

        # Server stop orders are additionally reported through on_stop_order.
        if order.type == OrderType.STOP:
            stop_order = StopOrder(
                vt_symbol=order.vt_symbol,
                direction=order.direction,
                offset=order.offset,
                price=order.price,
                volume=order.volume,
                stop_orderid=order.vt_orderid,
                strategy_name=owner.strategy_name,
                status=STOP_STATUS_MAP[order.status],
                vt_orderids=[order.vt_orderid],
            )
            self.call_strategy_func(owner, owner.on_stop_order, stop_order)

        self.call_strategy_func(owner, owner.on_order, order)

    def process_trade_event(self, event: Event):
        """
        Handle a trade once: update the offset converter, adjust the owning strategy's
        position, call its on_trade, then sync its data and push a strategy event.
        """
        trade = event.data

        # Each trade id is processed a single time; repeats are dropped.
        already_seen = trade.vt_tradeid in self.vt_tradeids
        if already_seen:
            return
        self.vt_tradeids.add(trade.vt_tradeid)

        self.offset_converter.update_trade(trade)

        owner = self.orderid_strategy_map.get(trade.vt_orderid, None)
        if not owner:
            return

        # The position is updated before on_trade runs: long trades add, all others subtract.
        if trade.direction == Direction.LONG:
            owner.pos += trade.volume
        else:
            owner.pos -= trade.volume

        self.call_strategy_func(owner, owner.on_trade, trade)

        # Sync strategy variables to data file, then update GUI.
        self.sync_strategy_data(owner)
        self.put_strategy_event(owner)

    def process_position_event(self, event: Event):
        """Forward the position carried by the event to the offset converter."""
        self.offset_converter.update_position(event.data)

    def check_stop_order(self, tick: TickData):
        """
        Turn every local stop order on the tick's symbol whose price has been reached
        into a limit order; on success mark it TRIGGERED and notify the strategy.
        """
        # Iterate over a snapshot because triggered orders are removed from the dict.
        for stop_order in list(self.stop_orders.values()):
            if stop_order.vt_symbol != tick.vt_symbol:
                continue

            long_triggered = (stop_order.direction == Direction.LONG
                              and tick.last_price >= stop_order.price)
            short_triggered = (stop_order.direction == Direction.SHORT
                               and tick.last_price <= stop_order.price)
            if not (long_triggered or short_triggered):
                continue

            strategy = self.strategies[stop_order.strategy_name]

            # To get executed immediately after the stop order is triggered,
            # use the limit price if available, otherwise ask_price_5 / bid_price_5.
            if stop_order.direction == Direction.LONG:
                price = tick.limit_up or tick.ask_price_5
            else:
                price = tick.limit_down or tick.bid_price_5

            contract = self.main_engine.get_contract(stop_order.vt_symbol)

            vt_orderids = self.send_limit_order(strategy, contract,
                                                stop_order.direction,
                                                stop_order.offset, price,
                                                stop_order.volume,
                                                stop_order.lock)
            if not vt_orderids:
                # Nothing was placed: leave the stop order in place.
                continue

            # Remove from relation map.
            self.stop_orders.pop(stop_order.stop_orderid)

            strategy_vt_orderids = self.strategy_orderid_map[
                strategy.strategy_name]
            if stop_order.stop_orderid in strategy_vt_orderids:
                strategy_vt_orderids.remove(stop_order.stop_orderid)

            # Mark the stop order as triggered and report it to the strategy.
            stop_order.status = StopOrderStatus.TRIGGERED
            stop_order.vt_orderids = vt_orderids

            self.call_strategy_func(strategy, strategy.on_stop_order,
                                    stop_order)
            self.put_stop_order_event(stop_order)

    def send_server_order(self, strategy: CtaTemplate, contract: ContractData,
                          direction: Direction, offset: Offset, price: float,
                          volume: float, type: OrderType, lock: bool):
        """
        Send a new order to server.

        The request is first passed through the offset converter, which may return
        several requests; each is sent through the main engine and recorded against
        the strategy. Returns the list of vt_orderids obtained from the main engine.
        """
        original_req = OrderRequest(
            symbol=contract.symbol,
            exchange=contract.exchange,
            direction=direction,
            offset=offset,
            type=type,
            price=price,
            volume=volume,
        )

        converted_requests = self.offset_converter.convert_order_request(
            original_req, lock)

        strategy_name = strategy.strategy_name
        vt_orderids = []
        for request in converted_requests:
            vt_orderid = self.main_engine.send_order(request,
                                                     contract.gateway_name)
            vt_orderids.append(vt_orderid)

            self.offset_converter.update_order_request(request, vt_orderid)

            # Save relationship between orderid and strategy.
            self.orderid_strategy_map[vt_orderid] = strategy
            self.strategy_orderid_map[strategy_name].add(vt_orderid)

        return vt_orderids

    def send_limit_order(self, strategy: CtaTemplate, contract: ContractData,
                         direction: Direction, offset: Offset, price: float,
                         volume: float, lock: bool):
        """
        Submit a limit order by delegating to send_server_order.

        Returns whatever send_server_order returns.
        """
        order_type = OrderType.LIMIT
        return self.send_server_order(strategy, contract, direction, offset,
                                      price, volume, order_type, lock)

    def send_server_stop_order(self, strategy: CtaTemplate,
                               contract: ContractData, direction: Direction,
                               offset: Offset, price: float, volume: float,
                               lock: bool):
        """
        Submit a stop order to the server by delegating to send_server_order.

        Intended only for contracts whose trading server supports stop
        orders. Returns whatever send_server_order returns.
        """
        order_type = OrderType.STOP
        return self.send_server_order(strategy, contract, direction, offset,
                                      price, volume, order_type, lock)

    def send_local_stop_order(self, strategy: CtaTemplate,
                              direction: Direction, offset: Offset,
                              price: float, volume: float, lock: bool):
        """
        Register a stop order that is kept inside the engine.

        A new id is built from STOPORDER_PREFIX and an incrementing counter,
        the order is stored in self.stop_orders and in the strategy's order
        id collection, and the strategy and the event engine are notified.

        Returns the id of the new stop order.
        """
        self.stop_order_count += 1
        stop_orderid = f"{STOPORDER_PREFIX}.{self.stop_order_count}"

        stop_order = StopOrder(
            vt_symbol=strategy.vt_symbol,
            direction=direction,
            offset=offset,
            price=price,
            volume=volume,
            stop_orderid=stop_orderid,
            strategy_name=strategy.strategy_name,
            lock=lock,
        )
        self.stop_orders[stop_orderid] = stop_order

        # Local stop ids share the per-strategy collection with server ids.
        self.strategy_orderid_map[strategy.strategy_name].add(stop_orderid)

        self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
        self.put_stop_order_event(stop_order)

        return stop_orderid

    def cancel_server_order(self, strategy: CtaTemplate, vt_orderid: str):
        """
        Ask the gateway to cancel the order identified by vt_orderid.

        When the main engine does not know the order, a log entry is written
        for the strategy and nothing is sent.
        """
        order = self.main_engine.get_order(vt_orderid)
        if order:
            cancel_request = order.create_cancel_request()
            self.main_engine.cancel_order(cancel_request, order.gateway_name)
            return

        self.write_log(f"撤单失败,找不到委托{vt_orderid}", strategy)

    def cancel_local_stop_order(self, strategy: CtaTemplate,
                                stop_orderid: str):
        """
        Cancel a stop order held inside the engine.

        Unknown ids are ignored. Otherwise the order is dropped from
        self.stop_orders and from its strategy's id collection, marked as
        CANCELLED, and reported to the strategy and the event engine.
        """
        stop_order = self.stop_orders.get(stop_orderid, None)
        if not stop_order:
            return

        # The owning strategy is resolved from the order itself; the
        # `strategy` argument is replaced by that lookup.
        strategy = self.strategies[stop_order.strategy_name]

        self.stop_orders.pop(stop_orderid)

        strategy_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if stop_orderid in strategy_orderids:
            strategy_orderids.remove(stop_orderid)

        stop_order.status = StopOrderStatus.CANCELLED

        self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
        self.put_stop_order_event(stop_order)

    def send_order(self, strategy: CtaTemplate, direction: Direction,
                   offset: Offset, price: float, volume: float, stop: bool,
                   lock: bool):
        """
        Route an order request of a strategy to the matching send method.

        Price and volume are rounded with round_to against the contract's
        pricetick and min_volume. Non-stop orders go out as limit orders;
        stop orders go to the server when the contract reports
        stop_supported and are kept as local stop orders otherwise.

        Returns the result of the chosen send method, or "" when the
        contract of the strategy cannot be found.
        """
        contract = self.main_engine.get_contract(strategy.vt_symbol)
        if not contract:
            self.write_log(f"委托失败,找不到合约:{strategy.vt_symbol}", strategy)
            return ""

        price = round_to(price, contract.pricetick)
        volume = round_to(volume, contract.min_volume)

        if not stop:
            return self.send_limit_order(strategy, contract, direction, offset,
                                         price, volume, lock)

        if contract.stop_supported:
            return self.send_server_stop_order(strategy, contract, direction,
                                               offset, price, volume, lock)

        return self.send_local_stop_order(strategy, direction, offset, price,
                                          volume, lock)

    def cancel_order(self, strategy: CtaTemplate, vt_orderid: str):
        """
        Cancel an order, choosing the local or server path from its id.

        Ids starting with STOPORDER_PREFIX are local stop orders; all other
        ids are cancelled on the server.
        """
        is_local_stop = vt_orderid.startswith(STOPORDER_PREFIX)
        if not is_local_stop:
            self.cancel_server_order(strategy, vt_orderid)
            return
        self.cancel_local_stop_order(strategy, vt_orderid)

    def cancel_all(self, strategy: CtaTemplate):
        """
        Cancel every order id currently recorded for the strategy.
        """
        active_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if active_orderids:
            # Walk a copy: cancelling a local stop order removes its id from
            # this very collection.
            for orderid in copy(active_orderids):
                self.cancel_order(strategy, orderid)

    def get_engine_type(self):
        """Return the value stored in self.engine_type."""
        engine_type = self.engine_type
        return engine_type

    def load_bar(self, vt_symbol: str, days: int, interval: Interval,
                 callback: Callable[[BarData], None]):
        """
        Feed the bars of the last `days` days to `callback`, one by one.

        Bars are requested through query_bar_from_rq first; when that yields
        nothing they are loaded from database_manager instead.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days=days)

        bars = (
            self.query_bar_from_rq(symbol, exchange, interval, start, end)
            or database_manager.load_bar_data(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                start=start,
                end=end,
            )
        )

        for bar in bars:
            callback(bar)

    def load_tick(self, vt_symbol: str, days: int,
                  callback: Callable[[TickData], None]):
        """
        Feed the ticks of the last `days` days to `callback`, one by one.

        Ticks are loaded from database_manager for the window ending now.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days=days)

        for tick in database_manager.load_tick_data(
                symbol=symbol,
                exchange=exchange,
                start=start,
                end=end,
        ):
            callback(tick)

    def call_strategy_func(self,
                           strategy: CtaTemplate,
                           func: Callable,
                           params: Any = None):
        """
        Call a strategy callback and contain any exception it raises.

        `func` is called with `params` as its only argument unless `params`
        is None, in which case it is called without arguments. If the
        callback raises, the strategy is switched off (trading and inited
        set to False) and the traceback is written to the log; the exception
        is not propagated.
        """
        try:
            # Test for None explicitly: a falsy argument such as 0, "" or an
            # empty container must still be handed to the callback.
            if params is not None:
                func(params)
            else:
                func()
        except Exception:
            strategy.trading = False
            strategy.inited = False

            msg = f"触发异常已停止\n{traceback.format_exc()}"
            self.write_log(msg, strategy)

    def add_strategy(self, class_name: str, strategy_name: str, vt_symbol: str,
                     setting: dict):
        """
        Create a strategy instance and register it with the engine.

        Nothing is created (only a log entry is written) when the name is
        already taken or the strategy class is unknown. On success the
        instance is stored, attached to its symbol, persisted through
        update_strategy_setting and announced with a strategy event.
        """
        if strategy_name in self.strategies:
            self.write_log(f"创建策略失败,存在重名{strategy_name}")
            return

        strategy_class = self.classes.get(class_name)
        if not strategy_class:
            self.write_log(f"创建策略失败,找不到策略类{class_name}")
            return

        instance = strategy_class(self, strategy_name, vt_symbol, setting)
        self.strategies[strategy_name] = instance
        self.symbol_strategy_map[vt_symbol].append(instance)

        self.update_strategy_setting(strategy_name, setting)
        self.put_strategy_event(instance)

    def init_strategy(self, strategy_name: str):
        """
        Queue a strategy for initialisation.

        The name is put on self.init_queue; a thread running _init_strategy
        is started only when self.init_thread is currently unset.
        """
        self.init_queue.put(strategy_name)

        if self.init_thread:
            return

        self.init_thread = Thread(target=self._init_strategy)
        self.init_thread.start()

    def _init_strategy(self):
        """
        Initialise every strategy waiting in self.init_queue.

        For each queued name that is not yet inited: call on_init, restore
        the persisted variables from self.strategy_data, subscribe to the
        strategy's market data through the main engine, flag the strategy as
        inited and announce it. self.init_thread is cleared once the queue
        is drained.
        """
        while not self.init_queue.empty():
            strategy_name = self.init_queue.get()
            strategy = self.strategies[strategy_name]

            if strategy.inited:
                self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
                continue

            self.write_log(f"{strategy_name}开始执行初始化")

            # Call on_init function of strategy
            self.call_strategy_func(strategy, strategy.on_init)

            # Restore strategy data(variables)
            data = self.strategy_data.get(strategy_name, None)
            if data:
                for name in strategy.variables:
                    value = data.get(name, None)
                    # Only a missing entry is skipped: persisted falsy values
                    # (0, False, "") are real state and must be restored.
                    if value is not None:
                        setattr(strategy, name, value)

            # Subscribe market data
            contract = self.main_engine.get_contract(strategy.vt_symbol)
            if contract:
                req = SubscribeRequest(symbol=contract.symbol,
                                       exchange=contract.exchange)
                self.main_engine.subscribe(req, contract.gateway_name)
            else:
                self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy)

            # Put event to update init completed status.
            strategy.inited = True
            self.put_strategy_event(strategy)
            self.write_log(f"{strategy_name}初始化完成")

        # Clearing the handle lets init_strategy start a fresh worker later.
        self.init_thread = None

    def start_strategy(self, strategy_name: str):
        """
        Switch a strategy into trading mode.

        Refused with a log entry when the strategy is not inited or is
        already trading. Otherwise on_start is called, the trading flag is
        set and a strategy event is emitted.
        """
        strategy = self.strategies[strategy_name]

        if not strategy.inited:
            self.write_log(f"策略{strategy.strategy_name}启动失败,请先初始化")
        elif strategy.trading:
            self.write_log(f"{strategy_name}已经启动,请勿重复操作")
        else:
            self.call_strategy_func(strategy, strategy.on_start)
            strategy.trading = True
            self.put_strategy_event(strategy)

    def stop_strategy(self, strategy_name: str):
        """
        Take a trading strategy out of trading mode.

        Does nothing when the strategy is not trading. Otherwise, in order:
        on_stop is called, the trading flag is cleared, all orders of the
        strategy are cancelled, its variables are synced to the data file
        and a strategy event is emitted.
        """
        strategy = self.strategies[strategy_name]
        if strategy.trading:
            self.call_strategy_func(strategy, strategy.on_stop)

            # The flag is cleared before the cancels are issued.
            strategy.trading = False

            self.cancel_all(strategy)
            self.sync_strategy_data(strategy)
            self.put_strategy_event(strategy)

    def edit_strategy(self, strategy_name: str, setting: dict):
        """
        Apply new parameters to a strategy, persist them and emit an event.
        """
        target = self.strategies[strategy_name]

        target.update_setting(setting)
        self.update_strategy_setting(strategy_name, setting)

        self.put_strategy_event(target)

    def remove_strategy(self, strategy_name: str):
        """
        Unregister a strategy from the engine.

        A strategy that is still trading is not removed; a log entry is
        written and None is returned. Otherwise its saved setting, its
        symbol mapping, its order id mappings and the strategy itself are
        dropped and True is returned.
        """
        strategy = self.strategies[strategy_name]
        if strategy.trading:
            self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止")
            return

        self.remove_strategy_setting(strategy_name)
        self.symbol_strategy_map[strategy.vt_symbol].remove(strategy)

        # Drop the strategy's order ids and their reverse mapping entries.
        if strategy_name in self.strategy_orderid_map:
            for orderid in self.strategy_orderid_map.pop(strategy_name):
                self.orderid_strategy_map.pop(orderid, None)

        self.strategies.pop(strategy_name)
        return True

    def load_strategy_class(self):
        """
        Load strategy classes from the two "strategies" folders.

        The folder next to this source file is loaded under the module
        prefix "vnpy.app.cta_strategy.strategies"; the folder under the
        current working directory is loaded under the prefix "strategies".
        """
        builtin_folder = Path(__file__).parent.joinpath("strategies")
        self.load_strategy_class_from_folder(
            builtin_folder, "vnpy.app.cta_strategy.strategies")

        user_folder = Path.cwd().joinpath("strategies")
        self.load_strategy_class_from_folder(user_folder, "strategies")

    def load_strategy_class_from_folder(self,
                                        path: Path,
                                        module_name: str = ""):
        """
        Load strategy classes from every ".py" file found under a folder.

        The folder is walked with os.walk; for each file ending in ".py" a
        module name is built from `module_name` and the file name without
        its suffix, and passed to load_strategy_class_from_module.

        With the default empty `module_name` the bare file stem is used, so
        no leading dot (which importlib would treat as a relative import)
        is produced.
        """
        suffix = ".py"
        for _dirpath, _dirnames, filenames in os.walk(str(path)):
            for filename in filenames:
                if not filename.endswith(suffix):
                    continue

                # Strip only the trailing suffix; str.replace would also
                # remove ".py" occurring elsewhere in the file name.
                stem = filename[:-len(suffix)]
                if module_name:
                    strategy_module_name = f"{module_name}.{stem}"
                else:
                    strategy_module_name = stem

                self.load_strategy_class_from_module(strategy_module_name)

    def load_strategy_class_from_module(self, module_name: str):
        """
        Import a module and register the strategy classes it exposes.

        Every attribute of the module that is a class derived from
        CtaTemplate (CtaTemplate itself excluded) is stored in self.classes
        under its class name. Any failure while importing or scanning is
        written to the log together with the traceback.
        """
        try:
            module = importlib.import_module(module_name)

            for attr_name in dir(module):
                candidate = getattr(module, attr_name)
                if not isinstance(candidate, type):
                    continue
                if not issubclass(candidate, CtaTemplate):
                    continue
                if candidate is CtaTemplate:
                    continue
                self.classes[candidate.__name__] = candidate
        except:  # noqa
            msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)

    def load_strategy_data(self):
        """
        Read self.data_filename with load_json into self.strategy_data.
        """
        loaded = load_json(self.data_filename)
        self.strategy_data = loaded

    def sync_strategy_data(self, strategy: CtaTemplate):
        """
        Persist the variables of a strategy to the json data file.

        The status flags "inited" and "trading" are removed from the
        variables before they are stored under the strategy's name and the
        whole self.strategy_data mapping is saved.
        """
        variables = strategy.get_variables()

        # Strategy status (inited, trading) should not be synced.
        for status_key in ("inited", "trading"):
            variables.pop(status_key)

        self.strategy_data[strategy.strategy_name] = variables
        save_json(self.data_filename, self.strategy_data)

    def get_all_strategy_class_names(self):
        """
        Return a new list with the names of all loaded strategy classes.
        """
        return list(self.classes)

    def get_strategy_class_parameters(self, class_name: str):
        """
        Return the default parameters of a loaded strategy class.

        The result maps each name listed in the class's `parameters`
        attribute to the value of the class attribute of that name.
        """
        strategy_class = self.classes[class_name]
        return {
            name: getattr(strategy_class, name)
            for name in strategy_class.parameters
        }

    def get_strategy_parameters(self, strategy_name):
        """
        Return get_parameters() of the strategy registered under the name.
        """
        return self.strategies[strategy_name].get_parameters()

    def init_all_strategies(self):
        """
        Call init_strategy for every registered strategy name.
        """
        for name in self.strategies:
            self.init_strategy(name)

    def start_all_strategies(self):
        """
        Call start_strategy for every registered strategy name.
        """
        for name in self.strategies:
            self.start_strategy(name)

    def stop_all_strategies(self):
        """
        Call stop_strategy for every registered strategy name.
        """
        for name in self.strategies:
            self.stop_strategy(name)

    def load_strategy_setting(self):
        """
        Load the setting file and add every strategy it describes.

        Each entry provides "class_name", "vt_symbol" and "setting", which
        are passed to add_strategy together with the entry's key as the
        strategy name.
        """
        self.strategy_setting = load_json(self.setting_filename)

        for name, config in self.strategy_setting.items():
            class_name = config["class_name"]
            vt_symbol = config["vt_symbol"]
            setting = config["setting"]
            self.add_strategy(class_name, name, vt_symbol, setting)

    def update_strategy_setting(self, strategy_name: str, setting: dict):
        """
        Store the setting entry of a strategy and save the setting file.

        The entry records the strategy's class name, its vt_symbol and the
        given setting.
        """
        strategy = self.strategies[strategy_name]

        entry = {
            "class_name": strategy.__class__.__name__,
            "vt_symbol": strategy.vt_symbol,
            "setting": setting,
        }
        self.strategy_setting[strategy_name] = entry

        save_json(self.setting_filename, self.strategy_setting)

    def remove_strategy_setting(self, strategy_name: str):
        """
        Drop a strategy's entry from the setting file.

        Nothing is written when the strategy has no entry.
        """
        if strategy_name in self.strategy_setting:
            self.strategy_setting.pop(strategy_name)
            save_json(self.setting_filename, self.strategy_setting)

    def put_stop_order_event(self, stop_order: StopOrder):
        """
        Put an EVENT_CTA_STOPORDER event carrying the stop order on the
        event engine.
        """
        self.event_engine.put(Event(EVENT_CTA_STOPORDER, stop_order))

    def put_strategy_event(self, strategy: CtaTemplate):
        """
        Put an EVENT_CTA_STRATEGY event carrying strategy.get_data() on the
        event engine.
        """
        payload = strategy.get_data()
        self.event_engine.put(Event(EVENT_CTA_STRATEGY, payload))

    def write_log(self, msg: str, strategy: CtaTemplate = None):
        """
        Emit a log message as an EVENT_CTA_LOG event.

        When a strategy is given, the message is prefixed with the
        strategy's name.
        """
        text = f"{strategy.strategy_name}: {msg}" if strategy else msg

        record = LogData(msg=text, gateway_name="CtaStrategy")
        self.event_engine.put(Event(type=EVENT_CTA_LOG, data=record))

    def send_email(self, msg: str, strategy: CtaTemplate = None):
        """
        Send `msg` by email through the main engine.

        The subject is the strategy's name when a strategy is given and a
        fixed engine title otherwise.
        """
        subject = f"{strategy.strategy_name}" if strategy else "CTA策略引擎"
        self.main_engine.send_email(subject, msg)
Example #36
0
class Brute:
    """Multi-threaded POST fuzzer driven by a wordlist.

    Each line of the wordlist replaces the ``FUZZ`` marker in a template and
    the result is queued. Worker threads drain the queue, send one POST
    request per entry and print a colour-coded line for every recognised
    HTTP status code.
    """

    # Reported status code -> (attribute name on colorama.Fore, line tag).
    # Responses with any other status code are not printed.
    _STATUS_STYLES = {
        200: ("GREEN", "[+]"),
        403: ("CYAN", "[!]"),
        302: ("BLUE", "[+]"),
        301: ("BLUE", "[+]"),
        405: ("CYAN", "[!]"),
        400: ("CYAN", "[-]"),
        500: ("RED", "[-]"),
        404: ("RED", "[-]"),
    }

    def __init__(self, args):
        """Store the options and fill the work queue from the wordlist.

        ``args`` must provide ``urls``, ``wordlists`` (path of the wordlist
        file), ``data``, ``method`` and ``thread_num``.
        """
        self.urls = args.urls
        self.wordlists = args.wordlists
        self.data = args.data
        self.method = args.method
        self.thread_num = args.thread_num
        self.queue = Queue()

        # NOTE(review): without a data template the marker is substituted
        # into the URL, because PostWebScan treats queue entries as URLs --
        # confirm this is the intended mode. Previously data=None raised a
        # TypeError here.
        template = args.data if args.data is not None else args.urls

        # `with` closes the wordlist even if reading fails.
        with open(self.wordlists, 'r') as wordlist:
            for line in wordlist:
                # Drop the line terminator so it does not end up in the
                # payload, and substitute literally: re.sub would interpret
                # backslashes in the word as replacement escapes.
                word = line.rstrip('\r\n')
                self.queue.put(template.replace("FUZZ", word))

    def to_do(self):
        """Run the worker threads and wait until all of them have finished.

        With a data template the workers run ``fuzz``; without one they run
        ``PostWebScan`` when the method is "Post". Any other combination
        starts nothing.
        """
        if self.data is not None:
            worker = self.fuzz
        elif self.method == "Post":
            worker = self.PostWebScan
        else:
            return

        threads = [threading.Thread(target=worker)
                   for _ in range(int(self.thread_num))]
        # Start every thread before joining any of them; joining right
        # after each start would run the workers one after another.
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def _drain(self, handle):
        """Pass queue entries to ``handle`` until the queue is empty."""
        from queue import Empty

        gc.collect()
        while True:
            # get_nowait cannot block forever when another worker takes the
            # last entry between an emptiness check and the get.
            try:
                item = self.queue.get_nowait()
            except Empty:
                return
            handle(item)

    @classmethod
    def _report(cls, resp, target):
        """Print one result line for ``resp`` if its status is reported."""
        style = cls._STATUS_STYLES.get(resp.status_code)
        if style is None:
            return
        colour_name, tag = style
        # Responses without a Content-Length header (e.g. chunked ones) are
        # shown with "-" instead of raising KeyError.
        size = resp.headers.get('content-length', '-')
        try:
            sys.stdout.write(
                '\r' + getattr(colorama.Fore, colour_name) +
                '{}\t{}\t{}\t\t{}\n'.format(tag, resp.status_code, size,
                                            target))
        except EOFError:
            sys.exit(1)

    def PostWebScan(self):
        """Worker: POST to every queued URL and report the response."""
        def scan(url):
            self._report(requests.post(url, headers=headers), url)

        self._drain(scan)

    def fuzz(self):
        """Worker: POST every queued payload to ``self.urls`` and report."""
        def send(payload):
            resp = requests.post(self.urls, headers=headers, data=payload)
            self._report(resp, payload)

        self._drain(send)
class TBGatewayService:
    def __init__(self, config_file=None):
        """Start the gateway and run its main loop.

        Loads the YAML configuration (``config/tb_gateway.yaml`` two
        directories above this file when ``config_file`` is None),
        configures logging from ``logs.conf`` in the configuration
        directory, connects the ThingsBoard client, loads and connects the
        configured connectors, starts the storage reader thread and then
        loops until KeyboardInterrupt or an unexpected exception, after
        which the connectors are closed and the client is stopped.

        The constructor therefore does not return while the gateway runs.
        """
        self.__lock = RLock()
        if config_file is None:
            config_file = path.dirname(path.dirname(path.abspath(__file__))) + '/config/tb_gateway.yaml'.replace('/', path.sep)
        with open(config_file) as general_config:
            config = safe_load(general_config)
        self._config_dir = path.dirname(path.abspath(config_file)) + path.sep
        logging.config.fileConfig(self._config_dir + "logs.conf")
        global log
        log = logging.getLogger('service')
        log.info("Gateway starting...")
        self.available_connectors = {}
        self.__connector_incoming_messages = {}
        self.__connected_devices = {}
        self.__saved_devices = {}
        self.__events = []
        self.name = ''.join(choice(ascii_lowercase) for _ in range(64))
        self.__rpc_requests_in_progress = {}
        self.__connected_devices_file = "connected_devices.json"
        self.tb_client = TBClient(config["thingsboard"])
        self.tb_client.connect()
        self.subscribe_to_required_topics()
        self.counter = 0
        global main_handler
        self.main_handler = main_handler
        self.remote_handler = TBLoggerHandler(self)
        self.main_handler.setTarget(self.remote_handler)
        # Connector type -> class name used when the connector entry does
        # not name a class itself (see _load_connectors).
        self._default_connectors = {
            "mqtt": "MqttConnector",
            "modbus": "ModbusConnector",
            "opcua": "OpcUaConnector",
            "ble": "BLEConnector",
            "request": "RequestConnector",
            "can": "CanConnector"
        }
        self._implemented_connectors = {}
        self._event_storage_types = {
            "memory": MemoryEventStorage,
            "file": FileEventStorage,
        }
        self.__gateway_rpc_methods = {
            "ping": self.__rpc_ping,
            "stats": self.__form_statistics,
            "devices": self.__rpc_devices,
        }
        self.__sheduled_rpc_calls = []
        self.__self_rpc_sheduled_methods_functions = {
            "restart": {"function": execv, "arguments": (executable, [executable.split(pathsep)[-1]] + argv)},
            "reboot": {"function": system, "arguments": ("reboot 0",)},
            }
        self._event_storage = self._event_storage_types[config["storage"]["type"]](config["storage"])
        self.connectors_configs = {}
        self._load_connectors(config)
        self._connect_with_connectors()
        self.__remote_configurator = None
        self.__request_config_after_connect = False
        if config["thingsboard"].get("remoteConfiguration"):
            try:
                self.__remote_configurator = RemoteConfigurator(self, config)
            except Exception as e:
                log.exception(e)
        if self.__remote_configurator is not None:
            self.__remote_configurator.send_current_configuration()
        self.__load_persistent_devices()
        self._published_events = Queue(-1)
        self._send_thread = Thread(target=self.__read_data_from_storage, daemon=True,
                                   name="Send data to Thingsboard Thread")
        self._send_thread.start()
        log.info("Gateway started.")

        try:
            gateway_statistic_send = 0
            while True:
                cur_time = time()*1000
                if self.__sheduled_rpc_calls:
                    # Walk a snapshot of the list. Popping by index while
                    # iterating range(len(...)) over the live list indexes
                    # past its end after the first removal, and the
                    # resulting IndexError would stop the whole gateway.
                    for rpc_call in list(self.__sheduled_rpc_calls):
                        if cur_time > rpc_call[0]:
                            self.__sheduled_rpc_calls.remove(rpc_call)
                            result = None
                            try:
                                result = rpc_call[1]["function"](*rpc_call[1]["arguments"])
                            except Exception as e:
                                log.exception(e)
                            if result == 256:
                                log.warning("Error on RPC command: 256. Permission denied.")
                if self.__rpc_requests_in_progress and self.tb_client.is_connected():
                    # Expired requests are marked with "del" and filtered
                    # out afterwards, so the dict keeps its size while it is
                    # being iterated.
                    for rpc_in_progress, data in self.__rpc_requests_in_progress.items():
                        if cur_time >= data[1]:
                            data[2](rpc_in_progress)
                            self.cancel_rpc_request(rpc_in_progress)
                            self.__rpc_requests_in_progress[rpc_in_progress] = "del"
                    new_rpc_request_in_progress = {key: value for key, value in self.__rpc_requests_in_progress.items() if value != 'del'}
                    self.__rpc_requests_in_progress = new_rpc_request_in_progress
                else:
                    try:
                        sleep(.1)
                    except Exception as e:
                        log.exception(e)
                        break
                # Request the attributes once, after the first connection
                # with no subscriptions still in progress.
                if not self.__request_config_after_connect and \
                        self.tb_client.is_connected() and not self.tb_client.client.get_subscriptions_in_progress():
                    self.__request_config_after_connect = True
                    self.__check_shared_attributes()

                # Send gateway statistics when more than 5000 ms have passed
                # since the last send.
                if cur_time - gateway_statistic_send > 5000.0 and self.tb_client.is_connected():
                    summary_messages = self.__form_statistics()
                    # with self.__lock:
                    self.tb_client.client.send_telemetry(summary_messages)
                    gateway_statistic_send = time()*1000
                    # self.__check_shared_attributes()
        except KeyboardInterrupt:
            log.info("Stopping...")
            self.__close_connectors()
            log.info("The gateway has been stopped.")
            self.tb_client.stop()
        except Exception as e:
            log.exception(e)
            self.__close_connectors()
            log.info("The gateway has been stopped.")
            self.tb_client.stop()

    def __close_connectors(self):
        """Close every available connector, logging failures individually.

        An exception raised by one connector is logged and does not stop
        the remaining connectors from being closed.
        """
        for connector_name, connector in self.available_connectors.items():
            try:
                connector.close()
                log.debug("Connector %s closed connection.", connector_name)
            except Exception as e:
                log.exception(e)

    def __stop_gateway(self):
        """Placeholder without behaviour: the body is a bare ``pass``."""
        pass

    def _attributes_parse(self, content, *args):
        """Handle an attributes payload.

        Reads a new "configuration" (from the "shared" section when present
        there, else from the top level) and hands it to the remote
        configurator when one exists, then applies "RemoteLoggingLevel"
        (from the "shared" section when that section exists, else from the
        top level) to the remote log handler. Every exception is logged and
        swallowed.
        """
        try:
            log.debug("Received data: %s", content)
            log.debug(args)
            if content is None:
                return

            shared_attributes = content.get("shared")
            client_attributes = content.get("client")

            # The shared section wins; the top level is the fallback.
            new_configuration = None
            if shared_attributes is not None:
                new_configuration = shared_attributes.get("configuration")
            if new_configuration is None:
                new_configuration = content.get("configuration")

            if new_configuration is not None and self.__remote_configurator is not None:
                try:
                    self.__remote_configurator.process_configuration(new_configuration)
                    self.__remote_configurator.send_current_configuration()
                except Exception as e:
                    log.exception(e)

            if shared_attributes is not None:
                remote_logging_level = shared_attributes.get('RemoteLoggingLevel')
            else:
                remote_logging_level = content.get("RemoteLoggingLevel")

            if remote_logging_level == 'NONE':
                self.remote_handler.deactivate()
                log.info('Remote logging has being deactivated.')
            elif remote_logging_level is not None:
                level_changed = self.remote_handler.current_log_level != remote_logging_level
                if level_changed or not self.remote_handler.activated:
                    self.main_handler.setLevel(remote_logging_level)
                    self.remote_handler.activate(remote_logging_level)
                    log.info('Remote logging has being updated. Current logging level is: %s ', remote_logging_level)

            if shared_attributes is not None:
                shared_names = ", ".join(list(shared_attributes.keys()))
                log.debug("Shared attributes received (%s).", shared_names)
            if client_attributes is not None:
                client_names = ", ".join(list(client_attributes.keys()))
                log.debug("Client attributes received (%s).", client_names)
        except Exception as e:
            log.exception(e)

    def get_config_path(self):
        """Return the value stored in self._config_dir."""
        config_dir = self._config_dir
        return config_dir

    def subscribe_to_required_topics(self):
        """Register the RPC and attribute-update callbacks on the platform client.

        Both the gateway-level (``gw_*``) and the plain client registrations
        are made, each with the same handler.
        """
        client = self.tb_client.client
        rpc_handler = self._rpc_request_handler
        attributes_handler = self._attribute_update_callback

        # Server-side RPC requests.
        client.gw_set_server_side_rpc_request_handler(rpc_handler)
        client.set_server_side_rpc_request_handler(rpc_handler)
        # Attribute updates.
        client.subscribe_to_all_attributes(attributes_handler)
        client.gw_subscribe_to_all_attributes(attributes_handler)

    def __check_shared_attributes(self):
        """Ask the platform client for attributes; the reply goes to ``_attributes_parse``."""
        client = self.tb_client.client
        client.request_attributes(callback=self._attributes_parse)

    def _load_connectors(self, main_config):
        """Import every configured connector class and read its JSON config file.

        Resets ``self.connectors_configs`` and fills it as
        ``{connector_type: [{"name": ..., "config": {file_name: parsed_config}}]}``.
        The imported class is remembered in ``self._implemented_connectors``.

        A connector that fails to import or load is logged and skipped.

        Raises:
            Exception: if ``main_config`` has no (or an empty) "connectors" entry.
        """
        self.connectors_configs = {}
        connectors = main_config.get("connectors")
        if not connectors:
            raise Exception("Configuration for connectors not found, check your config file.")

        for connector in connectors:
            try:
                connector_type = connector["type"]
                # Built-in connector types take their class from the defaults table;
                # anything else must name its class in the config entry.
                class_name = self._default_connectors.get(connector_type, connector.get("class"))
                self._implemented_connectors[connector_type] = TBUtility.check_and_import(connector_type, class_name)

                config_file_name = connector['configuration']
                with open(self._config_dir + config_file_name, 'r') as conf_file:
                    connector_conf = load(conf_file)

                if not self.connectors_configs.get(connector_type):
                    self.connectors_configs[connector_type] = []
                connector_conf["name"] = connector["name"]
                self.connectors_configs[connector_type].append(
                    {"name": connector["name"], "config": {config_file_name: connector_conf}})
            except Exception as e:
                log.error("Error on loading connector:")
                log.exception(e)

    def _connect_with_connectors(self):
        """Instantiate, register and open one connector per loaded configuration.

        A connector that fails at any step is logged and, if it was already
        constructed, closed; the remaining configurations are still processed.
        """
        for connector_type, type_entries in self.connectors_configs.items():
            for entry in type_entries:
                for config_key in entry["config"]:
                    connector = None
                    try:
                        connector_class = self._implemented_connectors[connector_type]
                        connector = connector_class(self, entry["config"][config_key], connector_type)
                        connector.setName(entry["name"])
                        # Registered before open(), so it stays listed even if open() fails.
                        self.available_connectors[connector.get_name()] = connector
                        connector.open()
                    except Exception as e:
                        log.exception(e)
                        if connector is not None:
                            connector.close()

    def send_to_storage(self, connector_name, data):
        """Normalise converted device data and put it into the event storage.

        For data coming from a connector (``connector_name`` differs from the
        gateway's own name) the payload is validated first, the device is
        registered if it is not known yet, and the per-connector incoming
        message counter is increased.

        Telemetry items without a "ts" key are merged into one values dict
        stamped with the current time in milliseconds; if at least one item
        carries "ts", only the timestamped items are stored.

        Args:
            connector_name: Name of the connector that produced ``data``.
            data: Dict with at least "deviceName" and "telemetry" (a list of
                dicts); "deviceType" is read when a new device is registered.
                The dict is modified in place ("telemetry" is replaced).

        Returns:
            None. Invalid connector data is logged and dropped.
        """
        if connector_name != self.name:
            if not TBUtility.validate_converted_data(data):
                log.error("Data from %s connector is invalid.", connector_name)
                return None
            if data["deviceName"] not in self.get_devices():
                self.add_device(data["deviceName"],
                                {"connector": self.available_connectors[connector_name]},
                                wait_for_publish=True, device_type=data["deviceType"])
            # The previous "if not counter.get(name): counter = 0 else += 1" never
            # got past 0, because a stored 0 is falsy. Count every message instead.
            self.__connector_incoming_messages[connector_name] = \
                self.__connector_incoming_messages.get(connector_name, 0) + 1

        telemetry = {}
        telemetry_with_ts = []
        for item in data["telemetry"]:
            if item.get("ts") is None:
                telemetry.update(item)
            else:
                telemetry_with_ts.append({"ts": item["ts"], "values": {**item["values"]}})
        if telemetry_with_ts:
            # NOTE(review): values collected without "ts" are discarded when any
            # timestamped item is present - confirm this is intended.
            data["telemetry"] = telemetry_with_ts
        else:
            data["telemetry"] = {"ts": int(time() * 1000), "values": telemetry}

        json_data = dumps(data)
        save_result = self._event_storage.put(json_data)
        if not save_result:
            log.error('Data from the device "%s" cannot be saved, connector name is %s.',
                      data["deviceName"],
                      connector_name)

    def check_size(self, size, devices_data_in_event_pack):
        """Flush the accumulated pack once its tracked size reaches 48000.

        Args:
            size: Running size estimate kept by the caller.
            devices_data_in_event_pack: Per-device data accumulated so far.

        Returns:
            ``size`` unchanged while below the limit, otherwise 0 after the
            pack has been handed to ``__send_data``.
        """
        if size < 48000:
            return size
        self.__send_data(devices_data_in_event_pack)
        return 0

    def __read_data_from_storage(self):
        """Worker loop: drain the event storage and publish the data to the platform.

        Each iteration, while the client is connected and no remote
        configuration is in progress, fetches an event pack from storage,
        groups telemetry and attributes per device, sends them through
        ``__send_data`` and then waits on the queued publish results. The
        storage pack is marked done only when the last checked publish
        succeeded.

        Unexpected exceptions are logged and the loop resumes after one second.
        """
        # Per-device accumulator: {device_name: {"telemetry": [...], "attributes": {...}}}
        devices_data_in_event_pack = {}
        log.debug("Send data Thread has been started successfully.")
        while True:
            try:
                if self.tb_client.is_connected():
                    # NOTE(review): presumably sys.getsizeof, which is shallow - the
                    # running "size" is only a rough estimate of the payload size.
                    size = getsizeof(devices_data_in_event_pack)
                    events = []
                    # Do not pull new events while a remote configuration update is running.
                    if self.__remote_configurator is None or not self.__remote_configurator.in_process:
                        events = self._event_storage.get_event_pack()
                    if events:
                        for event in events:
                            self.counter += 1
                            try:
                                current_event = loads(event)
                            except Exception as e:
                                # An event that cannot be parsed is logged and skipped.
                                log.exception(e)
                                continue
                            if not devices_data_in_event_pack.get(current_event["deviceName"]):
                                devices_data_in_event_pack[current_event["deviceName"]] = {"telemetry": [],
                                                                                           "attributes": {}}
                            if current_event.get("telemetry"):
                                if isinstance(current_event["telemetry"], list):
                                    for item in current_event["telemetry"]:
                                        size += getsizeof(item)
                                        # May flush the pack early and reset size to 0.
                                        size = self.check_size(size, devices_data_in_event_pack)
                                        devices_data_in_event_pack[current_event["deviceName"]]["telemetry"].append(item)
                                else:
                                    # This break leaves the "for event in events" loop.
                                    if not self.tb_client.is_connected():
                                        break
                                    size += getsizeof(current_event["telemetry"])
                                    size = self.check_size(size, devices_data_in_event_pack)
                                    devices_data_in_event_pack[current_event["deviceName"]]["telemetry"].append(current_event["telemetry"])
                            if current_event.get("attributes"):
                                if isinstance(current_event["attributes"], list):
                                    for item in current_event["attributes"]:
                                        # This break only leaves the inner attribute-item loop.
                                        if not self.tb_client.is_connected():
                                            break
                                        size += getsizeof(item)
                                        size = self.check_size(size, devices_data_in_event_pack)
                                        devices_data_in_event_pack[current_event["deviceName"]]["attributes"].update(item.items())
                                else:
                                    # This break leaves the "for event in events" loop.
                                    if not self.tb_client.is_connected():
                                        break
                                    size += getsizeof(current_event["attributes"].items())
                                    size = self.check_size(size, devices_data_in_event_pack)
                                    devices_data_in_event_pack[current_event["deviceName"]]["attributes"].update(
                                        current_event["attributes"].items())
                        if devices_data_in_event_pack:
                            # NOTE(review): this break is not inside a for loop, so it exits
                            # the outer "while True" and the method returns - confirm that
                            # ending the worker on a disconnect here is intended.
                            if not self.tb_client.is_connected():
                                break
                            self.__send_data(devices_data_in_event_pack)
                        if self.tb_client.is_connected() and (self.__remote_configurator is None or not self.__remote_configurator.in_process):
                            success = True
                            # Check every queued publish result.
                            while not self._published_events.empty():
                                if (self.__remote_configurator is not None and self.__remote_configurator.in_process) or not self.tb_client.is_connected() or self._published_events.empty():
                                    success = False
                                    break
                                event = self._published_events.get(False, 10)
                                try:
                                    if self.tb_client.is_connected() and (self.__remote_configurator is None or not self.__remote_configurator.in_process):
                                        success = event.get() == event.TB_ERR_SUCCESS
                                    else:
                                        # NOTE(review): leaves "success" at its previous value,
                                        # which may still be True - verify.
                                        break
                                except Exception as e:
                                    log.exception(e)
                                    success = False
                            if success:
                                # Mark the storage pack as processed and start a fresh accumulator.
                                self._event_storage.event_pack_processing_done()
                                del devices_data_in_event_pack
                                devices_data_in_event_pack = {}
                        else:
                            continue
                    else:
                        # Nothing to send: short pause before polling the storage again.
                        sleep(.01)
                else:
                    # Not connected: poll the connection state less aggressively.
                    sleep(.1)
            except Exception as e:
                log.exception(e)
                sleep(1)

    def __send_data(self, devices_data_in_event_pack):
        """Publish the accumulated attributes and telemetry of every device.

        Data of the gateway itself (device name equal to ``self.name``) goes
        through the plain client calls, everything else through the ``gw_*``
        calls. Each publish result is queued in ``self._published_events`` and
        the device's entry in the pack is reset afterwards.

        Args:
            devices_data_in_event_pack: ``{device: {"telemetry": [...],
                "attributes": {...}}}``; modified in place.
        """
        try:
            for device in devices_data_in_event_pack:
                pending = devices_data_in_event_pack[device]

                attributes = pending.get("attributes")
                if attributes:
                    if device == self.name:
                        publish_info = self.tb_client.client.send_attributes(attributes)
                    else:
                        publish_info = self.tb_client.client.gw_send_attributes(device, attributes)
                    self._published_events.put(publish_info)

                telemetry = pending.get("telemetry")
                if telemetry:
                    if device == self.name:
                        publish_info = self.tb_client.client.send_telemetry(telemetry)
                    else:
                        publish_info = self.tb_client.client.gw_send_telemetry(device, telemetry)
                    self._published_events.put(publish_info)

                # Replacing the value of an existing key is safe while iterating the dict.
                devices_data_in_event_pack[device] = {"telemetry": [], "attributes": {}}
        except Exception as e:
            log.exception(e)

    def _rpc_request_handler(self, request_id, content):
        """Dispatch a server-side RPC request to a connector or to the gateway.

        Requests carrying a "device" key are forwarded to that device's
        connector. Otherwise the part of "method" before the first underscore
        selects the target: a configured connector type (every available
        connector of that type receives the request; the last result is
        replied), or "gateway" for the gateway's own RPC methods.

        Args:
            request_id: Identifier used when replying to the platform.
            content: Request dict; reads "device", "method" (and whatever the
                target handler reads).
        """
        try:
            device = content.get("device")
            if device is not None:
                # .get() chain so an unknown device reaches the "not found" log
                # below instead of surfacing as a bare KeyError traceback.
                connector = self.get_devices().get(device, {}).get("connector")
                if connector is not None:
                    connector.server_side_rpc_handler(content)
                else:
                    log.error("Received RPC request but connector for the device %s not found. Request data: \n %s",
                              content["device"],
                              dumps(content))
            else:
                try:
                    method = content["method"]
                    # str.split always returns at least one element, so the module
                    # name is always available here.
                    module = method.split('_')[0]
                    result = None
                    if self.connectors_configs.get(module):
                        log.debug("Connector \"%s\" for RPC request \"%s\" found", module, method)
                        for connector_name, connector in self.available_connectors.items():
                            if connector._connector_type == module:
                                log.debug("Sending command RPC %s to connector %s", method, connector_name)
                                result = connector.server_side_rpc_handler(content)
                    elif module == 'gateway':
                        result = self.__rpc_gateway_processing(request_id, content)
                    else:
                        log.error("Connector \"%s\" not found", module)
                        result = {"error": "%s - connector not found in available connectors." % module, "code": 404}
                    if result is None:
                        self.send_rpc_reply(None, request_id, success_sent=False)
                    else:
                        self.send_rpc_reply(None, request_id, dumps(result))
                except Exception as e:
                    # Serialise with dumps so quotes or backslashes in the message
                    # cannot produce an invalid JSON reply.
                    self.send_rpc_reply(None, request_id, dumps({"error": str(e), "code": 500}))
                    log.exception(e)
        except Exception as e:
            log.exception(e)

    def __rpc_gateway_processing(self, request_id, content):
        """Run an RPC method addressed to the gateway itself.

        The method name is ``content["method"]`` with "gateway_" removed.
        List params are unpacked into the registered method; methods from the
        scheduled table are queued with a delay of ``params`` seconds (0 when
        params is empty); everything else is called without arguments.

        Returns:
            The called method's result, or ``{"success": True}`` for a
            scheduled call.
        """
        log.info("Received RPC request to the gateway, id: %s, method: %s", str(request_id), content["method"])
        arguments = content.get('params')
        method_to_call = content["method"].replace("gateway_", "")

        if isinstance(arguments, list):
            return self.__gateway_rpc_methods[method_to_call](*arguments)

        scheduled_methods = self.__self_rpc_sheduled_methods_functions
        if method_to_call not in scheduled_methods:
            # NOTE(review): non-list params are not forwarded to the method -
            # confirm whether they should be.
            return self.__gateway_rpc_methods[method_to_call]()

        # Timestamps in the schedule are kept in milliseconds.
        delay_ms = arguments*1000 if arguments else 0
        run_at = time()*1000 + delay_ms
        self.__sheduled_rpc_calls.append([run_at, scheduled_methods[method_to_call]])
        log.info("Gateway %s sheduled in %i seconds", method_to_call, delay_ms/1000)
        return {"success": True}

    def __rpc_ping(self, *args):
        return {"code": 200, "resp": "pong"}

    def __rpc_devices(self, *args):
        data_to_send = {}
        for device in self.__connected_devices:
            if self.__connected_devices[device]["connector"] is not None:
                data_to_send[device] = self.__connected_devices[device]["connector"].get_name()
        return {"code": 200, "resp": data_to_send}

    def rpc_with_reply_processing(self, topic, content):
        """Reply to the pending RPC request registered under ``topic``.

        Sends ``content`` as the reply for the stored device and request id,
        then calls ``cancel_rpc_request`` for the same topic.
        """
        pending_request = self.__rpc_requests_in_progress[topic][0]
        req_id = pending_request["data"]["id"]
        device = pending_request["device"]
        self.send_rpc_reply(device, req_id, content)
        self.cancel_rpc_request(topic)

    def send_rpc_reply(self, device=None, req_id=None, content=None, success_sent=None, wait_for_publish=None):
        """Send an RPC reply for a device or for the gateway itself.

        With ``success_sent`` given, the reply body is the JSON
        ``{"success": <bool>}``; otherwise ``content`` is sent as-is. A device
        reply goes through ``gw_send_rpc_reply``; a gateway reply (``device``
        is None) through ``send_rpc_reply`` with QoS 1. Nothing is sent when
        neither a status nor usable content is supplied. Errors are logged.
        """
        try:
            rpc_response = {"success": bool(success_sent)}
            has_status = success_sent is not None

            if device is not None:
                if has_status:
                    self.tb_client.client.gw_send_rpc_reply(device, req_id, dumps(rpc_response))
                elif req_id is not None and content is not None:
                    self.tb_client.client.gw_send_rpc_reply(device, req_id, content)
            elif has_status:
                self.tb_client.client.send_rpc_reply(req_id, dumps(rpc_response), quality_of_service=1,
                                                     wait_for_publish=wait_for_publish)
            elif content is not None:
                self.tb_client.client.send_rpc_reply(req_id, content, quality_of_service=1,
                                                     wait_for_publish=wait_for_publish)
        except Exception as e:
            log.exception(e)

    def register_rpc_request_timeout(self, content, timeout, topic, cancel_method):
        """Remember a pending RPC request under ``topic``.

        The entry is stored as the tuple ``(content, timeout, cancel_method)``.
        """
        pending_entry = (content, timeout, cancel_method)
        self.__rpc_requests_in_progress[topic] = pending_entry

    def cancel_rpc_request(self, rpc_request):
        """Send an unsuccessful reply for the pending request stored under ``rpc_request``."""
        pending_request = self.__rpc_requests_in_progress[rpc_request][0]
        device = pending_request["device"]
        request_id = pending_request["data"]["id"]
        self.send_rpc_reply(device=device, req_id=request_id, success_sent=False)

    def _attribute_update_callback(self, content, *args):
        """Route an attribute update to a device's connector or to the gateway.

        Content with a "device" key is passed to that device's connector
        (errors are logged); anything else is handled by ``_attributes_parse``.
        """
        log.debug("Attribute request received with content: \"%s\"", content)
        log.debug(args)
        if content.get('device') is None:
            # Update addressed to the gateway itself.
            self._attributes_parse(content)
            return
        try:
            connector = self.__connected_devices[content["device"]]["connector"]
            connector.on_attributes_update(content)
        except Exception as e:
            log.exception(e)

    def __form_statistics(self):
        summary_messages = {"eventsProduced": 0, "eventsSent": 0}
        telemetry = {}
        for connector in self.available_connectors:
            connector_camel_case = connector.lower().replace(' ', '')
            telemetry[(connector_camel_case + ' EventsProduced').replace(' ', '')] = \
                self.available_connectors[connector].statistics['MessagesReceived']
            self.available_connectors[connector].statistics['MessagesReceived'] = 0
            telemetry[(connector_camel_case + ' EventsSent').replace(' ', '')] = \
                self.available_connectors[connector].statistics['MessagesSent']
            self.available_connectors[connector].statistics['MessagesSent'] = 0
            summary_messages['eventsProduced'] += telemetry[
                str(connector_camel_case + ' EventsProduced').replace(' ', '')]
            summary_messages['eventsSent'] += telemetry[
                str(connector_camel_case + ' EventsSent').replace(' ', '')]
            summary_messages.update(**telemetry)
        return summary_messages

    def add_device(self, device_name, content, wait_for_publish=False, device_type=None):
        """Register a device that has not been saved yet and connect it on the platform.

        Args:
            device_name: Name of the device.
            content: Data stored for the device (e.g. ``{"connector": ...}``).
            wait_for_publish: When true, block on the connect message's
                ``wait_for_publish()``.
            device_type: Device type sent to the platform; 'default' if None.

        Does nothing when the device is already in the saved-devices table.
        """
        if device_name in self.__saved_devices:
            return
        self.__connected_devices[device_name] = content
        self.__saved_devices[device_name] = content
        if device_type is None:
            device_type = 'default'
        publish_info = self.tb_client.client.gw_connect_device(device_name, device_type)
        if wait_for_publish:
            publish_info.wait_for_publish()
        self.__save_persistent_devices()

    def update_device(self, device_name, event, content):
        """Store ``content`` under ``event`` for a connected device.

        When the device's connector actually changes, the connected-devices
        file is rewritten - after the new value has been stored, so the file
        reflects the new connector rather than the one being replaced.

        Raises:
            KeyError: if ``device_name`` is not a connected device.
        """
        device = self.__connected_devices[device_name]
        connector_changed = event == 'connector' and device.get(event) != content
        device[event] = content
        if connector_changed:
            self.__save_persistent_devices()

    def del_device(self, device_name):
        """Forget a device, disconnect it on the platform and persist the change.

        The device is also dropped from the saved-devices table; otherwise
        ``add_device`` would treat it as already registered and silently
        refuse to add it again later.

        Raises:
            KeyError: if ``device_name`` is not a connected device.
        """
        del self.__connected_devices[device_name]
        self.__saved_devices.pop(device_name, None)
        self.tb_client.client.gw_disconnect_device(device_name)
        self.__save_persistent_devices()

    def get_devices(self):
        """Return the connected-devices mapping itself (not a copy)."""
        connected_devices = self.__connected_devices
        return connected_devices

    def __load_persistent_devices(self):
        """Restore the device-to-connector mapping from the connected-devices file.

        The file (inside the config directory) maps device names to connector
        names. Devices whose connector is currently available are put into the
        connected-devices table; the others are reported with a warning. If
        the file is missing or empty, an empty file is created.
        """
        devices_file_path = self._config_dir + self.__connected_devices_file
        devices = {}
        file_has_content = self.__connected_devices_file in listdir(self._config_dir) and \
            path.getsize(devices_file_path) > 0
        if file_has_content:
            try:
                with open(devices_file_path) as devices_file:
                    devices = load(devices_file)
            except Exception as e:
                log.exception(e)
        else:
            # Create (or truncate) the file so later saves have a target.
            with open(devices_file_path, 'w'):
                pass

        if devices is None:
            log.debug("No device found in connected device file.")
            if self.__connected_devices is None:
                self.__connected_devices = {}
            return

        log.debug("Loaded devices:\n %s", devices)
        for device_name in devices:
            try:
                connector = self.available_connectors.get(devices[device_name])
                if connector:
                    self.__connected_devices[device_name] = {"connector": connector}
                else:
                    log.warning("Device %s connector not found, maybe it had been disabled.", device_name)
            except Exception as e:
                log.exception(e)
                continue

    def __save_persistent_devices(self):
        """Write the device-name to connector-name mapping to the connected-devices file.

        Devices whose "connector" entry is None are not written. The JSON text
        is built before the file is opened: opening with 'w' truncates it, so
        a failure while collecting the data must not leave an empty file
        behind. Failures are logged and the success message is only emitted
        when the write went through.
        """
        try:
            data_to_save = {
                device: info["connector"].get_name()
                for device, info in self.__connected_devices.items()
                if info["connector"] is not None
            }
            serialized = dumps(data_to_save, indent=2, sort_keys=True)
        except Exception as e:
            log.exception(e)
            return

        with open(self._config_dir + self.__connected_devices_file, 'w') as config_file:
            try:
                config_file.write(serialized)
            except Exception as e:
                log.exception(e)
                return
        log.debug("Saved connected devices.")
Example #38
0
class TestFunction(unittest.TestCase, Common):
    """Functional tests for LogDogs: matching lines are delivered to a handler.

    Every test writes to real log files in the working directory, calls
    ``logdogs.process()`` and checks what the handler pushed into ``self.q``.
    """

    def setUp(self):
        # opened files
        self.files = []
        # batches of lines received by the handler
        self.q = Queue()
        # start from a clean working directory
        self.rm('logdogs.log')
        self.rm('a.log')
        self.rm('b.log')
        self.rm('logs')

    def tearDown(self):
        # Every test must have consumed all handler output. Close the files in
        # a finally block so a failing assertion does not leak open handles.
        try:
            self.assertTrue(self.q.empty())
        finally:
            for f in self.files:
                f.close()

    def handler(self, file, lines):
        """Handler passed to LogDogs: record the received lines."""
        # print(lines)
        self.q.put(lines)

    def open(self, path):
        """Create or truncate ``path`` and track the file for tearDown."""
        # w: create an empty file or truncate if it exists
        f = open(path, 'w')
        self.files.append(f)
        return f

    def write(self, f, s):
        """Write ``s`` to ``f`` and flush so the watcher can see it right away."""
        # it's not easy to write str to file without bufferring in both py2 and py3
        f.write(s)
        f.flush()

    def test_1_file(self):
        """
        the simplest case
        """
        DOGS = {
            'test': {
                'paths': ['a.log'],
                'includes': ['wrong'],
                'excludes': ['long'],
                'handler': self.handler
            }
        }
        f = self.open('a.log')
        logdogs = LogDogs(DOGS)

        self.write(f, 'hello world\n')
        logdogs.process()
        self.assertTrue(self.q.empty())

        self.write(f, 'something wrong\nwhats wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(),
                         ['something wrong\n', 'whats wrong\n'])

        # matches an include but also an exclude
        self.write(f, 'a long wrong answer\n')
        logdogs.process()
        self.assertTrue(self.q.empty())

    def test_ignore_old_logs(self):
        """
        old logs in the log file are ignored
        """
        DOGS = {
            'test': {
                'paths': ['a.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }
        f = self.open('a.log')
        self.write(f, 'you are on a wrong way\n')
        logdogs = LogDogs(DOGS)

        self.write(f, 'hello world\n')
        logdogs.process()
        self.assertTrue(self.q.empty())

        self.write(f, 'something wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])

    def test_rotate(self):
        """
        log file is moved and a new one is created
        """
        DOGS = {
            'test': {
                'paths': ['a.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }
        f = self.open('a.log')
        logdogs = LogDogs(DOGS)

        self.write(f, 'something wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])

        shutil.move('a.log', 'b.log')
        self.write(f, 'whats wrong\n')
        # nginx will still log to old log file before reopen signal is handled
        # in some cases the last few logs are missing

        # write to new file immediately
        f = self.open('a.log')
        self.write(f, 'all is wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['whats wrong\n'])
        self.assertEqual(self.q.get_nowait(), ['all is wrong\n'])

    def test_2_files(self):
        """
        a dog can watch more than 1 files
        """
        DOGS = {
            'test': {
                'paths': ['a.log', 'b.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }
        f1 = self.open('a.log')
        f2 = self.open('b.log')
        logdogs = LogDogs(DOGS)

        self.write(f1, 'something wrong\n')
        self.write(f2, 'whats wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])
        self.assertEqual(self.q.get_nowait(), ['whats wrong\n'])

    def test_overlap(self):
        """
        2 dogs can watch the same file
        """
        DOGS = {
            'test1': {
                'paths': ['a.log'],
                'includes': ['error', 'wrong'],
                'handler': self.handler
            },
            'test2': {
                'paths': ['a.log'],
                'includes': ['warning', 'wrong'],
                'handler': self.handler
            }
        }
        f = self.open('a.log')
        logdogs = LogDogs(DOGS)

        self.write(f, 'an error\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['an error\n'])

        self.write(f, 'a warning\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['a warning\n'])

        self.assertTrue(self.q.empty())

        self.write(f, 'something wrong\n')
        logdogs.process()
        # received 2 times
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])

    def test_not_exists(self):
        """
        log file is not required to exist before watch
        """
        DOGS = {
            'test': {
                'paths': ['a.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }
        logdogs = LogDogs(DOGS)

        # create file after watch
        f = self.open('a.log')

        self.write(f, 'something wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])

    def test_2_not_exists(self):
        """
        bug: the same dog watch the same path(.) twice
        """
        DOGS = {
            'test': {
                'paths': ['a.log', 'b.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }
        logdogs = LogDogs(DOGS)

        # create file after watch
        f = self.open('a.log')
        self.write(f, 'something wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])
        # q must be empty now (checked by tearDown)

    def test_glob(self):
        """
        glob pattern can be used in path
        """
        DOGS = {
            'test': {
                'paths': ['logs/*.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }
        os.makedirs('logs')
        f = self.open('logs/a.log')
        logdogs = LogDogs(DOGS)

        self.write(f, 'something wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])

        # a file created after the watch started
        f = self.open('logs/b.log')

        self.write(f, 'whats wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['whats wrong\n'])

    def test_glob_recursively(self):
        """
        ** can be used in glob pattern
        """
        DOGS = {
            'test': {
                'paths': ['logs/**/*.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }

        os.makedirs('logs/b')
        os.makedirs('logs/c')
        fa = self.open('logs/a.log')
        fb = self.open('logs/b/b.log')
        logdogs = LogDogs(DOGS)

        self.write(fa, 'something wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])

        self.write(fb, 'you are wrong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['you are wrong\n'])

        # create a new log file
        fc = self.open('logs/c/c.log')
        self.write(fc, 'Am I wrong?\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['Am I wrong?\n'])

        # create a new sub-directory
        os.makedirs('logs/d')
        fd = self.open('logs/d/d.log')
        self.write(fd, 'wrong! wrong! wrong!\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['wrong! wrong! wrong!\n'])

    def test_half_line(self):
        """
        bug: half line will be read if the log is being written at the same time
        """
        DOGS = {
            'test': {
                'paths': ['a.log'],
                'includes': ['wrong'],
                'handler': self.handler
            }
        }
        # create an empty file or truncate if it exists
        f = self.open('a.log')
        logdogs = LogDogs(DOGS)

        # an unterminated line must not be delivered yet
        self.write(f, 'something w')
        logdogs.process()
        self.assertTrue(self.q.empty())

        self.write(f, 'rong\n')
        logdogs.process()
        self.assertEqual(self.q.get_nowait(), ['something wrong\n'])
Example #39
0
    def fit(X0,
            W0,
            X1,
            W1,
            max_depth=2,
            min_samples_leaf=10,
            allowed_features=None,
            clip=3,
            quantizer=32):
        """ Train decision tree

        Inputs
        ------
        X0, X1 : ndarray
            Samples of class 0 and class 1; the first axis indexes samples.
            NOTE(review): as_features presumably flattens each sample, since
            split features are mapped back with np.unravel_index(f, X0.shape[1:]).
        W0, W1 : ndarray or None
            Sample weights, one per sample. None means uniform weights.
        max_depth : scalar int
            Maximal depth of decision tree; nodes at this depth become leaves.
        min_samples_leaf : scalar int/float
            A node holding fewer samples than this is not split further.
        allowed_features : indexable or None
            When given, ``allowed_features[depth]`` is the index array of the
            feature columns that may be used for splits at that depth.
        clip : scalar or None
            Leaf/node predictions are clipped to [-clip, clip] unless None.
        quantizer : scalar or None
            Predictions are rounded to multiples of 1/quantizer unless None.

        Outputs
        -------
        tree : BaseDTree
            Decision tree built from per-node feature, threshold, left/right
            child index and prediction arrays.
        """
        shape = X0.shape[1:]

        # The documented contract allows None weights; use uniform weights
        # instead of failing inside np.concatenate.
        if W0 is None:
            W0 = np.ones(X0.shape[0])
        if W1 is None:
            W1 = np.ones(X1.shape[0])

        X = np.concatenate([as_features(X0), as_features(X1)])
        Y = np.array([0] * X0.shape[0] + [1] * X1.shape[0])
        W = np.concatenate([W0, W1])

        n_samples = W.size
        # Weights used for split search: each class is normalised to sum to 1/2.
        _W = W.copy()
        _W[Y == 0] /= _W[Y == 0].sum() * 2
        _W[Y == 1] /= _W[Y == 1].sum() * 2

        node_id = count()  # Counter of node ids
        # Queue holding sample sets that need to be processed (breadth-first)
        sample_queue = Queue()
        # Initial set includes all samples, depth is 0 and id is 0 (from the counter)
        sample_queue.put((np.arange(n_samples), 0, next(node_id)))
        nodes = dict()  # Intermediate representation of tree
        while not sample_queue.empty():
            self_samples, depth, self_index = sample_queue.get()
            # Create either leaf or split node
            create_leaf = (depth == max_depth) or (self_samples.size <
                                                   min_samples_leaf)
            if create_leaf:
                nodes[self_index] = {
                    "samples": self_samples,
                    "feature": -1,
                    "threshold": -1,
                    "left": -1,
                    "right": -1
                }
                continue

            if allowed_features is None:
                ftrs = None
                _X = X
            else:
                ftrs = allowed_features[depth]
                _X = X[:, ftrs]
            feature, threshold, _ = _find_split(_X[self_samples],
                                                Y[self_samples],
                                                _W[self_samples])
            if allowed_features is not None:
                # Map the column index within the restricted view back to X.
                feature = ftrs[feature]
            # TODO: check for feature/metric feasibility
            # split data
            goes_left = X[self_samples, feature] <= threshold
            # schedule l/r nodes
            left_index = next(node_id)
            sample_queue.put((self_samples[goes_left], depth + 1, left_index))
            right_index = next(node_id)
            sample_queue.put((self_samples[~goes_left], depth + 1, right_index))
            # create node
            nodes[self_index] = {
                "samples": self_samples,
                "feature": feature,
                "threshold": threshold,
                "left": left_index,
                "right": right_index
            }

        # Compose final structure
        n_nodes = len(nodes)
        feature = [None] * n_nodes
        threshold = np.empty(n_nodes)
        left = np.empty(n_nodes, "i")
        right = np.empty(n_nodes, "i")
        pred = np.empty(n_nodes, "f")

        for node_idx, node_data in nodes.items():
            f = node_data["feature"]
            feature[node_idx] = np.unravel_index(f, shape) if f >= 0 else None
            threshold[node_idx] = node_data["threshold"]
            left[node_idx] = node_data["left"]
            right[node_idx] = node_data["right"]
            idx = node_data["samples"]
            y, w = Y[idx], W[idx]
            # Half log-ratio of the class weights; 1e-3 keeps empty classes finite.
            w0 = w[y == 0].sum() + 1e-3
            w1 = w[y == 1].sum() + 1e-3
            pred[node_idx] = np.log(w1 / w0) / 2

        if clip is not None:
            pred = np.clip(pred, -clip, clip)

        if quantizer is not None:
            pred = np.round(quantizer * pred) / quantizer

        # Return initialized tree instance
        return BaseDTree(feature, threshold, left, right, pred)
Example #40
0
class BiliBiliTop(object):
    """Crawl Bilibili ranking lists with worker threads and store them in MongoDB.

    ``run`` fills ``self.queue`` with one descriptor per ranking URL, starts
    ``self.thread_num`` threads that each execute ``get_rank`` and waits for
    all of them before reporting the elapsed time and the number of newly
    inserted records.
    """

    def __init__(self):
        """Open the MongoDB connection and set up crawl state and lookup tables."""
        # Connect to MongoDB and keep a handle on the target collection
        self.client = pymongo.MongoClient(host='mongodb://127.0.0.1', port=27017)
        self.collection = self.client['bilibili']['top']
        # Work queue of URL descriptors (dicts with 'type_', 'rid', 'url')
        self.queue = Queue()
        # Re-entrant lock guarding the queue hand-off and the DB/counter updates
        self.lock = threading.RLock()
        # Number of worker threads
        self.thread_num = 8
        # Worker threads started by run()
        self.thread_list = []
        # Number of records inserted so far
        self.count = 0
        # `type` query value -> display name, for the general ranking API
        self.type1_dict = {
            '1': '全站榜',
            '2': '原创榜',
            '3': '新人榜',
        }
        # `rid` query value -> display name, for the general ranking API
        self.rid1_dict = {
            '0': '全站',
            '1': '动画',
            '3': '音乐',
            '4': '游戏',
            '5': '娱乐',
            '36': '科技',
            '119': '鬼畜',
            '129': '舞蹈',
            '155': '时尚',
            '160': '生活',
            '168': '国创相关',
            '181': '影视',
            '188': '数码',
        }
        # `season_type` query value -> ranking family name, for the PGC API
        self.type2_dict = {
            '1': '新番榜',
            '4': '新番榜',
            '2': '影视榜',
            '3': '影视榜',
            '5': '影视榜',
        }
        # `season_type` query value -> category name, for the PGC API
        self.rid2_dict = {
            '1': '番剧',
            '2': '电影',
            '3': '纪录片',
            '4': '国产动画',
            '5': '电视剧',
        }

    def save_item(self, item):
        """Insert *item*, or update the stored document that matches it.

        Documents are matched on ``type_``, ``rid`` and ``video_url``. Only a
        fresh insert increments ``self.count`` and prints a progress line.
        The whole lookup/write sequence runs under ``self.lock`` so that
        concurrent workers cannot interleave it.
        """
        query = {
            'type_': item.type_,
            'rid': item.rid,
            'video_url': item.video_url,
        }
        with self.lock:
            if self.collection.find_one(query):
                self.collection.update_one(query, {'$set': item.__dict__})
            else:
                self.collection.insert_one(item.__dict__)
                self.count += 1
                print('已保存第{}条记录:{}'.format(self.count, item.__dict__))

    def _build_item(self, url_info, rank, entry):
        """Translate one raw ranking entry into a populated ``BibiliItem``."""
        bilibili = BibiliItem()
        bilibili.type_ = url_info['type_']
        bilibili.rid = url_info['rid']
        bilibili.rank_url = url_info['url']
        bilibili.rank = rank
        bilibili.title = entry['title']
        bilibili.coins = entry['coins'] if 'coins' in entry else '0'
        bilibili.video_review = (
            entry['video_review'] if 'video_review' in entry else '0'
        )
        # The two APIs expose the play count under different keys
        if 'play' in entry:
            bilibili.play = entry['play']
        else:
            bilibili.play = entry['stat']['view']
        if 'author' in entry:
            bilibili.author_name = entry['author']
            bilibili.author_url = 'https://space.bilibili.com/{}'.format(entry['mid'])
        else:
            bilibili.author_name = '大会员抢先看'
            bilibili.author_url = ''
        if 'aid' in entry:
            bilibili.video_url = 'https://www.bilibili.com/video/av{}/'.format(entry['aid'])
        else:
            bilibili.video_url = entry['url']
        return bilibili

    def _crawl(self, url_info):
        """Download one ranking URL and save every entry it lists."""
        headers = {
            'User-Agent': requests.get('http://127.0.0.1:8888/headers').text,
            'Referer': 'https://www.bilibili.com/'
        }
        response = requests.get(url=url_info['url'], headers=headers)
        data = json.loads(response.text)
        # jsonpath returns a falsy value when nothing matches
        matches = jsonpath.jsonpath(data, '$..list')
        entries = matches[0] if matches else []
        for index, entry in enumerate(entries):
            # The position in the list is the rank (1-based)
            self.save_item(self._build_item(url_info, index + 1, entry))

    def get_rank(self):
        """Worker loop: take URL descriptors off the queue until it is empty.

        The lock is held only for the empty-check/get pair, which keeps that
        pair atomic without serialising the network requests of the workers.
        ``task_done`` is called exactly once per item taken, even when the
        crawl of that item raises.
        """
        while True:
            with self.lock:
                if self.queue.empty():
                    return
                url_info = self.queue.get()
            try:
                self._crawl(url_info)
            finally:
                self.queue.task_done()

    def get_urls_info(self):
        """Queue one descriptor per ranking URL (general rankings, then PGC)."""
        # NOTE: the day value is fixed; its meaning for the general ranking
        # API is not documented in this file.
        day = 3
        # e.g. https://api.bilibili.com/x/web-interface/ranking?rid=1&day=3&type=1
        base_url = 'https://api.bilibili.com/x/web-interface/ranking?rid={}&day={}&type={}'
        for type_key, type_name in self.type1_dict.items():
            for rid_key, rid_name in self.rid1_dict.items():
                self.queue.put_nowait({
                    'type_': type_name,
                    'rid': rid_name,
                    'url': base_url.format(rid_key, day, type_key),
                })
        # e.g. https://api.bilibili.com/pgc/web/rank/list?day=3&season_type=1
        base_url = 'https://api.bilibili.com/pgc/web/rank/list?day={}&season_type={}'
        for rid_key, rid_name in self.rid2_dict.items():
            self.queue.put_nowait({
                'type_': self.type2_dict[rid_key],
                'rid': rid_name,
                'url': base_url.format(day, rid_key),
            })

    def run(self):
        """Crawl every ranking URL with ``thread_num`` workers, then report."""
        start_time = time.time()
        self.get_urls_info()
        print('已获取{}个url'.format(self.queue.qsize()))
        for _ in range(self.thread_num):
            thread = threading.Thread(target=self.get_rank)
            # Register before starting so the list is complete when joined
            self.thread_list.append(thread)
            thread.start()
        # Joining is sufficient to know every worker has finished
        for thread in self.thread_list:
            thread.join()
        del self.thread_list[:]
        self.client.close()
        print('耗时:{:.2f}s,共保存{}条记录'.format(time.time() - start_time, self.count))

    def __del__(self):
        """Close the MongoDB client if the constructor got far enough to make one."""
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()
Example #41
0
class mainflow():
    """Scrape 51job search results for one keyword through three queues.

    URLs flow from ``url_queue`` to ``resp_queue`` (fetcher threads) to
    ``data_queue`` (parser threads); ``writedata`` appends every parsed
    result to a ``~``-separated text file.
    """

    # Seconds writedata waits for data before re-checking whether it is done
    _WRITE_POLL_SECONDS = 0.1

    def __init__(self, keyword):
        """Prepare the HTTP session, the queues and the request parameters."""
        self.url = 'https://search.51job.com/list/000000,000000,0000,00,9,99,{},2,{}.html'
        self.logurl = 'https://login.51job.com/login.php?lang=c'
        self.session = requests.session()
        self.url_queue = Queue()   # URLs waiting to be fetched
        self.resp_queue = Queue()  # responses waiting to be parsed
        self.data_queue = Queue()  # parsed data waiting to be written
        # 51job URL-encodes the keyword twice, so the same is done here
        self.keyword = preparam().encodeurl(preparam().encodeurl(keyword))
        self.paramstr = preparam().str2dic(file='utils/paramdata.txt')
        self.logdata = preparam().str2dic(file='utils/loginfo.txt')
        self.HTTPErrorURL = []

    def login(self, header=None):
        """Log in to 51job, optionally updating the session headers first.

        Raises:
            RuntimeError: if the login request does not answer with HTTP 200.
                The message carries the status code (the original wrapped
                ``print(...)``, which left the exception payload as ``None``).
        """
        if header is not None:
            self.session.headers.update(header)
        r = self.session.post(self.logurl, data=self.logdata)
        if r.status_code != 200:
            raise RuntimeError(
                'The Login process has Failed , status code : %s' % str(r.status_code)
            )
        return None

    def put_url(self):
        """Read the page count from the first result page and queue every page URL."""
        r = self.session.get(self.url.format(self.keyword, 1), params=self.paramstr)
        # The third text node of the 'rt' div holds the total number of pages
        pagecnt = int(etree.HTML(r.content).xpath("//div[@class='rt']/text()")[2].replace(' / ', ''))
        # Re-create the queues bounded to the number of pages
        self.url_queue = Queue(pagecnt)
        self.resp_queue = Queue(pagecnt)
        self.data_queue = Queue(pagecnt)
        for i in range(1, pagecnt + 1):
            self.url_queue.put(self.url.format(self.keyword, i))
        return None

    def startThreading(self):
        """Start 5 fetcher and 10 parser threads and wait for all of them."""
        rthread = []
        for _ in range(5):
            getresp = getResp_thread(self.session, self.url_queue, self.resp_queue, self.paramstr)
            getresp.start()
            rthread.append(getresp)

        pthread = []
        for _ in range(10):
            parsedata = parseData_thread(self.url_queue, self.resp_queue, self.data_queue)
            parsedata.start()
            pthread.append(parsedata)

        for thread in rthread + pthread:
            thread.join()

        return None

    def writedata(self, outfile):
        """Append every item of ``data_queue`` to *outfile* until all queues are empty.

        The wait for new data is bounded so that the "everything is empty"
        condition is re-evaluated periodically. An unbounded ``get`` would
        block forever when the last pending URL/response produces no data.
        """
        from queue import Empty  # local import: only this method needs it

        while True:
            # All three queues empty means every item has been written
            if self.url_queue.empty() and self.resp_queue.empty() and self.data_queue.empty():
                return
            try:
                data = self.data_queue.get(block=True, timeout=self._WRITE_POLL_SECONDS)
            except Empty:
                continue
            data.to_csv(outfile, mode='a', header=False, index=False, encoding='utf-8', sep='~')

    def run(self):
        """Log in, queue the URLs, run the worker threads and write the results."""
        self.login(header=preparam().user_agent)
        self.put_url()
        self.startThreading()
        # Output file is named after the current local date
        self.writedata('data/jobinfo_%s.txt' % time.strftime('%Y%m%d', time.localtime(time.time())))
        return None
Example #42
0
class Server:
    """Console-driven TCP server that sends commands and files to connected clients.

    One thread accepts connections, another runs the interactive prompt.
    Replies from a client are considered complete when they end with the
    prompt marker ``->``.
    """

    def __init__(self, port):
        """Bind a listening socket on *port* (all interfaces)."""
        self.port = port
        self.sock = socket.socket()
        self.sock.bind(('', self.port))
        self.sock.listen(5)  # backlog: up to 5 connections may wait for accept()
        self.accepted = []  # list of (conn, address) tuples
        self.threads = Queue()
        print("Listening to port:", self.port)

    def run(self):
        """Start the listener and UI threads and block until ``kill_threads`` runs."""
        listening_thread = threading.Thread(target=self.__listen, daemon=True)
        ui_thread = threading.Thread(target=self.__ui, daemon=True)

        self.threads.put(listening_thread)
        self.threads.put(ui_thread)

        listening_thread.start()
        ui_thread.start()

        # Returns once every queued entry has been marked done
        self.threads.join()

    def kill_threads(self):
        """Mark every queued thread entry as done so that ``run`` can return."""
        while not self.threads.empty():
            self.threads.get()
            self.threads.task_done()

    def __listen(self):
        """Accept connections forever and remember each (conn, address) pair."""
        while True:
            accept = self.sock.accept()
            print()
            print("Received connection from: " + str(accept[1][0]) + ":" +
                  str(accept[1][1]))
            print("->", end='')
            self.accepted.append(accept)

    def send(self, conn, data):
        """Send *data* and return the reply, read until it ends with ``->``.

        ``sendall`` is used so that large payloads are never cut short.

        Raises:
            ConnectionError: if the peer closes the connection before the
                prompt marker arrives (``recv`` then returns ``b''``, which
                previously made this loop spin forever).
        """
        conn.sendall(data)
        response = b""
        while response[-2:] != b"->":
            chunk = conn.recv(1024)
            if not chunk:
                raise ConnectionError(
                    "connection closed before the prompt was received")
            response += chunk
        return response

    def update_clients(self):
        """Ping every client, drop the unreachable ones and return the listing.

        The listing is built after the dead entries are removed, so the
        printed numbers are the indices the UI accepts.
        """
        dead = []
        # Iterate over a snapshot: the listener thread may append meanwhile
        for accept in list(self.accepted):
            try:
                self.send(accept[0], b' ')
            except OSError:
                dead.append(accept)
        for accept in dead:
            try:
                self.accepted.remove(accept)
            except ValueError:
                pass  # already removed elsewhere
        clients = "Current connections\n"
        for i, accept in enumerate(self.accepted):
            clients += str(i) + ": " + str(accept[1][0]) + ":" + str(
                accept[1][1]) + '\n'
        return clients

    def save_files(self, files):
        """Store files received as ``<name>NAME<content>FILE_END`` records.

        Files are written into ``received_files/`` (created when missing).
        Only the first ``NAME`` separates the name from the content, so file
        contents may contain that byte sequence. The name comes from the
        remote client, so only its final path component is used.
        """
        for record in files.split(b'FILE_END')[:-1]:
            raw_name, separator, content = record.partition(b'NAME')
            try:
                if not separator:
                    raise ValueError("record has no NAME separator")
                # basename() keeps a client from writing outside the folder
                file_name = os.path.basename(raw_name.decode("utf-8"))
                if not file_name:
                    raise ValueError("record has an empty file name")
                os.makedirs("received_files", exist_ok=True)
                with open(os.path.join("received_files", file_name),
                          "wb") as f:
                    f.write(content)
            except Exception as e:
                print(e)
                print("Error receiving", raw_name)

    def parse_files(self, files):
        """Pack the named files from ``send/`` into one ``FILES...CONN_END`` payload.

        Files that cannot be read are reported on the console and skipped.
        """
        parts = [b"FILES"]
        for file in files:
            try:
                with open("send/" + file, "rb") as f:
                    parts.append(file.encode() + b"NAME" + f.read() + b"FILE_END")
            except FileNotFoundError:
                print("File", file, "doesn't exist!")
            except Exception as e:
                print(e)
                print("Unable to send", file)
        parts.append(b"CONN_END")
        return b"".join(parts)

    def send_kill_signal(self):
        """Tell every reachable client to exit."""
        self.update_clients()
        for conn, _address in list(self.accepted):
            try:
                conn.sendall(b"exit")
            except OSError:
                # A client that vanished meanwhile must not block the others
                continue

    def communicate(self, conn):
        """Run the interactive command loop against one client connection.

        ``exit`` closes the connection, ``send <files...>`` uploads files from
        ``send/``, anything else is sent as a command. A reply starting with
        ``FILES`` is stored via ``save_files``; other replies are printed.
        Any failure ends the session and refreshes the client list.
        """
        try:
            print("Type a command")
            print("->", end='')
            while True:
                command = input().strip()
                if command.lower() == "exit":
                    conn.sendall(b"exit")
                    conn.close()
                    return
                if not command:
                    continue
                command_args = shlex.split(command)
                if command_args[0] == "send":
                    files = command_args[1:]
                    response = self.send(
                        conn, self.parse_files(files)).decode("utf-8")
                    print(response, end='')
                    continue
                response = self.send(conn, command.encode())
                try:
                    if response.startswith(b"FILES"):
                        # Strip the FILES prefix and the trailing prompt
                        self.save_files(response[len(b"FILES"):-len(b"->")])
                        print("->", end='')
                    else:
                        print(response.decode("utf-8"), end='')
                except UnicodeDecodeError:
                    # Undecodable output: print the raw lines, keeping the
                    # cursor on the final (prompt) line
                    lines = response.split(b'\n')
                    for line in lines[:-1]:
                        print(line)
                    print(lines[-1], end='')
        except Exception as e:
            print(e)
            print("Failed to send command!")
            self.update_clients()

    def __ui(self):
        """Top-level prompt: list clients, show help, connect to one, or exit."""
        separator = "------------------------------------------------------------------"
        help_lines = (
            "Type list to view active connections",
            separator,
            "Type the number of the client to connect",
            separator,
            "Type receive and name of the files to receive files from client",
            "Files must be in the current directory",
            separator,
            "Type send to send files, you must create a folder called send",
            "and put the files there",
            separator,
            "Type exit to exit",
        )
        print("Reverse shell by dbc201")
        print("Type help for further info")
        while True:
            connect = input("->").strip()
            choice = connect.lower()
            if choice == "exit":
                self.send_kill_signal()
                self.kill_threads()
                self.sock.close()
                return
            elif choice == "help":
                for line in help_lines:
                    print(line)
            elif choice == "list":
                print(self.update_clients())
            else:
                try:
                    conn = self.accepted[int(connect)][0]
                    self.communicate(conn)
                except Exception:
                    print("Failed to connect to client!")
Example #43
0
class MTurkAgent(Agent):
    """Agent that relays a ParlAI world's traffic to and from one MTurk worker.

    All outbound traffic goes through the ``manager`` object handed to the
    constructor; inbound messages arrive through ``put_data`` and are
    consumed by ``act``.
    """

    # Work-status strings compared against manager.get_agent_work_status()
    ASSIGNMENT_NOT_DONE = 'NotDone'
    ASSIGNMENT_DONE = 'Submitted'
    ASSIGNMENT_APPROVED = 'Approved'
    ASSIGNMENT_REJECTED = 'Rejected'

    def __init__(self, opt, manager, hit_id, assignment_id, worker_id):
        """Record the HIT/assignment/worker identity and reset all state flags."""
        super().__init__(opt)

        # Identity within the world and on MTurk
        self.conversation_id = None
        self.manager = manager
        self.id = None
        self.state = AssignState()
        self.assignment_id = assignment_id
        self.hit_id = hit_id
        self.worker_id = worker_id

        # Connection / HIT life-cycle flags
        self.some_agent_disconnected = False
        self.hit_is_expired = False
        self.hit_is_abandoned = False  # state from Amazon MTurk system
        self.hit_is_returned = False  # state from Amazon MTurk system
        self.hit_is_complete = False  # state from Amazon MTurk system
        self.disconnected = False

        self.task_group_id = manager.task_group_id
        # Time of the pending message request used by non-blocking act()
        self.message_request_time = None
        # Packet ids that were already queued (see put_data)
        self.recieved_packets = {}

        self.msg_queue = Queue()

    def get_connection_id(self):
        """Return the connection id, built from the worker and assignment ids."""
        return "{}_{}".format(self.worker_id, self.assignment_id)

    def log_reconnect(self):
        """Write a debug log entry describing a reconnect of this agent."""
        message = 'Agent ({})_({}) reconnected to {} with status {}'.format(
            self.worker_id,
            self.assignment_id,
            self.conversation_id,
            self.state.status,
        )
        shared_utils.print_and_log(logging.DEBUG, message)

    def get_inactive_command_data(self):
        """Build the inactive-command payload used to answer a reconnect."""
        inactive_text, command_text = self.state.get_inactive_command_text()
        payload = {
            'text': command_text,
            'inactive_text': inactive_text,
            'conversation_id': self.conversation_id,
            'agent_id': self.worker_id,
        }
        return payload

    def wait_for_status(self, desired_status):
        """Block the calling thread until the assignment state equals
        ``desired_status``, polling with a short sleep in between.
        """
        while self.state.status != desired_status:
            time.sleep(shared_utils.THREAD_SHORT_SLEEP)

    def is_in_task(self):
        """Return True when the conversation id marks a task conversation."""
        if not self.conversation_id:
            return False
        return 't_' in self.conversation_id

    def observe(self, msg):
        """Forward ``msg`` to the worker through the mturk manager."""
        self.manager.send_message(self.worker_id, self.assignment_id, msg)

    def put_data(self, id, data):
        """Queue ``data`` unless a packet with the same id was queued before."""
        if id in self.recieved_packets:
            return
        self.recieved_packets[id] = True
        self.msg_queue.put(data)

    def get_new_act_message(self):
        """Return the next act message, a terminal message, or None.

        A queued message is returned only when its ``id`` matches this
        agent's id. Otherwise a disconnect or return of the HIT produces an
        episode-ending message; with none of these the result is None.
        """
        # A message from the Turker takes precedence
        if not self.msg_queue.empty():
            candidate = self.msg_queue.get()
            if candidate['id'] == self.id:
                return candidate

        # Either this agent or a partner has disconnected
        if self.disconnected or self.some_agent_disconnected:
            return {
                'id': self.id,
                'text': MTURK_DISCONNECT_MESSAGE,
                'episode_done': True
            }

        # The Turker has already returned the HIT
        if self.hit_is_returned:
            return {
                'id': self.id,
                'text': RETURN_MESSAGE,
                'episode_done': True
            }

        # Nothing to deliver
        return None

    def prepare_timeout(self):
        """Log the timeout, notify the mturk manager and return the message
        that the act call should hand back.
        """
        shared_utils.print_and_log(
            logging.INFO, '{} timed out before sending.'.format(self.id)
        )
        self.manager.handle_turker_timeout(self.worker_id, self.assignment_id)
        return {
            'id': self.id,
            'text': TIMEOUT_MESSAGE,
            'episode_done': True
        }

    def request_message(self):
        """Ask the worker for a message unless the conversation is already over."""
        if (self.disconnected or self.some_agent_disconnected or
                self.hit_is_expired):
            return
        command = {'text': data_model.COMMAND_SEND_MESSAGE}
        self.manager.send_command(self.worker_id, self.assignment_id, command)

    def act(self, timeout=None, blocking=True):
        """Obtain the worker's next message for the other agents in the world.

        When blocking, wait (polling) until a message or a timeout arrives.
        When not blocking, return whatever is available right now, which may
        be None.
        """
        if blocking:
            self.request_message()

            # Timeout in seconds, after which the HIT will be expired automatically
            if timeout:
                started_at = time.time()

            # Poll until the worker answers or the timeout elapses
            while True:
                reply = self.get_new_act_message()
                if reply is not None:
                    return reply

                if timeout and (time.time() - started_at) > timeout:
                    return self.prepare_timeout()
                time.sleep(shared_utils.THREAD_SHORT_SLEEP)

        # Non-blocking: the first call after a delivered message starts the clock
        if self.message_request_time is None:
            self.request_message()
            self.message_request_time = time.time()

        if timeout and time.time() - self.message_request_time > timeout:
            return self.prepare_timeout()

        # A delivered message resets the clock for the next request
        reply = self.get_new_act_message()
        if reply is not None:
            self.message_request_time = None
        return reply

    def change_conversation(self, conversation_id, agent_id, change_callback):
        """Move the agent to another conversation; ``change_callback`` is
        passed on as the acknowledgement handler of the command.
        """
        self.conversation_id = conversation_id
        self.id = agent_id
        self.manager.send_command(
            self.worker_id,
            self.assignment_id,
            {
                'text': data_model.COMMAND_CHANGE_CONVERSATION,
                'conversation_id': conversation_id,
                'agent_id': agent_id
            },
            ack_func=change_callback
        )

    def episode_done(self):
        """Return True unless the manager reports the assignment as not done."""
        status = self.manager.get_agent_work_status(self.assignment_id)
        return not (status == self.ASSIGNMENT_NOT_DONE)

    def _print_not_available_for(self, item):
        """Warn that the HIT is abandoned and therefore unavailable for ``item``."""
        template = ('Conversation ID: {}, Agent ID: {} - HIT is abandoned '
                    'and thus not available for {}.')
        shared_utils.print_and_log(
            logging.WARN,
            template.format(self.conversation_id, self.id, item),
            should_print=True
        )

    def approve_work(self):
        """Approve the assignment, which requires it to have been submitted."""
        if self.hit_is_abandoned:
            self._print_not_available_for('review')
            return

        status = self.manager.get_agent_work_status(self.assignment_id)
        if status != self.ASSIGNMENT_DONE:
            shared_utils.print_and_log(
                logging.WARN,
                "Cannot approve HIT. Turker hasn't completed the HIT yet."
            )
            return

        self.manager.approve_work(assignment_id=self.assignment_id)
        shared_utils.print_and_log(
            logging.INFO,
            'Conversation ID: {}, Agent ID: {} - HIT is approved.'.format(
                self.conversation_id, self.id
            )
        )

    def reject_work(self, reason='unspecified'):
        """Reject the assignment, which requires it to have been submitted."""
        if self.hit_is_abandoned:
            self._print_not_available_for('review')
            return

        status = self.manager.get_agent_work_status(self.assignment_id)
        if status != self.ASSIGNMENT_DONE:
            shared_utils.print_and_log(
                logging.WARN,
                "Cannot reject HIT. Turker hasn't completed the HIT yet."
            )
            return

        self.manager.reject_work(self.assignment_id, reason)
        shared_utils.print_and_log(
            logging.INFO,
            'Conversation ID: {}, Agent ID: {} - HIT is rejected.'.format(
                self.conversation_id, self.id
            )
        )

    def block_worker(self, reason='unspecified'):
        """Block the worker through the manager and log the reason."""
        self.manager.block_worker(worker_id=self.worker_id, reason=reason)
        notice = 'Blocked worker ID: {}. Reason: {}'.format(
            self.worker_id, reason
        )
        shared_utils.print_and_log(logging.WARN, notice, should_print=True)

    def pay_bonus(self, bonus_amount, reason='unspecified'):
        """Pay a bonus for a submitted or approved assignment."""
        if self.hit_is_abandoned:
            self._print_not_available_for('bonus')
            return

        status = self.manager.get_agent_work_status(self.assignment_id)
        payable = (self.ASSIGNMENT_DONE, self.ASSIGNMENT_APPROVED)
        if status not in payable:
            shared_utils.print_and_log(
                logging.WARN,
                "Cannot pay bonus for HIT. Reason: Turker "
                "hasn't completed the HIT yet."
            )
            return

        # A fresh token is generated for every bonus request
        unique_request_token = str(uuid.uuid4())
        self.manager.pay_bonus(
            worker_id=self.worker_id,
            bonus_amount=bonus_amount,
            assignment_id=self.assignment_id,
            reason=reason,
            unique_request_token=unique_request_token
        )

    def email_worker(self, subject, message_text):
        """Email the worker; return True on success and False on failure.

        When the manager's response holds neither a 'success' nor a
        'failure' key, nothing is logged and None is returned.
        """
        response = self.manager.email_worker(
            worker_id=self.worker_id,
            subject=subject,
            message_text=message_text
        )
        if 'success' in response:
            sent_note = 'Email sent to worker ID: {}: Subject: {}: Text: {}'.format(
                self.worker_id, subject, message_text
            )
            shared_utils.print_and_log(logging.INFO, sent_note)
            return True
        if 'failure' in response:
            failed_note = "Unable to send email to worker ID: {}. Error: {}".format(
                self.worker_id, response['failure']
            )
            shared_utils.print_and_log(logging.WARN, failed_note)
            return False

    def set_hit_is_abandoned(self):
        """Flag the HIT as abandoned (once) and force-expire it via the manager."""
        if self.hit_is_abandoned:
            return
        self.hit_is_abandoned = True
        self.manager.force_expire_hit(self.worker_id, self.assignment_id)

    def wait_for_hit_completion(self, timeout=None):
        """Wait until the HIT is complete or submitted.

        Returns True when the HIT finished (or ``timeout`` is negative) and
        False when the worker returned the HIT, disconnected, or the timeout
        elapsed. The manager's ``free_workers`` is called on every exit path.
        """
        # Timeout in seconds, after which the HIT will be expired automatically
        if timeout:
            if timeout < 0:
                # Negative timeout is for testing
                self.manager.free_workers([self])
                return True
            start_time = time.time()

        # Number of short sleeps that make up one polling interval
        iters = (shared_utils.THREAD_MTURK_POLLING_SLEEP /
                 shared_utils.THREAD_SHORT_SLEEP)

        def _nap_until_next_poll():
            # Sleep in short steps so a completed HIT is noticed quickly
            ticks = 0
            while not self.hit_is_complete and ticks < iters:
                time.sleep(shared_utils.THREAD_SHORT_SLEEP)
                ticks += 1

        _nap_until_next_poll()
        while not self.hit_is_complete and \
                self.manager.get_agent_work_status(self.assignment_id) != \
                self.ASSIGNMENT_DONE:
            # The Turker already returned the HIT or disconnected
            if self.hit_is_returned or self.disconnected:
                self.manager.free_workers([self])
                return False

            if timeout and (time.time() - start_time) > timeout:
                shared_utils.print_and_log(
                    logging.INFO,
                    "Timeout waiting for ({})_({}) to complete {}.".format(
                        self.worker_id,
                        self.assignment_id,
                        self.conversation_id
                    )
                )
                self.set_hit_is_abandoned()
                self.manager.free_workers([self])
                return False

            shared_utils.print_and_log(
                logging.DEBUG,
                'Waiting for ({})_({}) to complete {}...'.format(
                    self.worker_id, self.assignment_id, self.conversation_id
                )
            )
            _nap_until_next_poll()

        done_note = 'Conversation ID: {}, Agent ID: {} - HIT is done.'.format(
            self.conversation_id, self.id
        )
        shared_utils.print_and_log(logging.INFO, done_note)
        self.manager.free_workers([self])
        return True

    def reduce_state(self):
        """Drop the message queue, stored messages and the seen-packet table."""
        self.msg_queue = None
        self.state.clear_messages()
        self.recieved_packets = None

    def shutdown(self, timeout=None, direct_submit=False):
        """Finish the HIT: show the done button (or submit directly) and wait.

        Nothing is sent, and None is returned, when the HIT is abandoned,
        returned or expired, or the agent is disconnected. Otherwise the
        result of ``wait_for_hit_completion`` is returned.
        """
        if direct_submit:
            command_to_send = data_model.COMMAND_SUBMIT_HIT
        else:
            command_to_send = data_model.COMMAND_SHOW_DONE_BUTTON

        if (self.hit_is_abandoned or self.hit_is_returned or
                self.disconnected or self.hit_is_expired):
            return None

        self.manager.mark_workers_done([self])
        self.manager.send_command(
            self.worker_id,
            self.assignment_id,
            {'text': command_to_send},
        )
        # Timeout in seconds, after which the HIT will be expired automatically
        return self.wait_for_hit_completion(timeout=timeout)
Example #44
0
class ClThread(threading.Thread):
    """Consumer thread that executes ``Cl_Task`` objects taken from a queue.

    The thread owns a queue (or shares the one passed in) and exposes
    ``put`` / ``put_nowait`` / ``get`` / ``get_nowait`` / ``qsize`` as thin
    wrappers around it. ``pause`` and ``resume`` suspend and continue the
    processing; ``stop`` ends the thread.

    If ``run`` is overridden, the subclass is responsible for consuming the
    queue itself.
    """

    # Seconds to sleep between polls while there is nothing to do; keeps an
    # idle thread from pegging a CPU core
    _IDLE_SLEEP = 0.01

    def __init__(self, queue=None, name=None, daemon=False):
        """Create the thread.

        Args:
            queue: queue to consume; a new ``Queue`` is created when None.
            name: thread name; a random one is generated when None.
            daemon: value assigned to the thread's ``daemon`` flag.
        """
        threading.Thread.__init__(self)
        self.queue = Queue() if queue is None else queue
        self.thread_stop = False
        self.__flag = threading.Event()  # cleared while the thread is paused
        self.__flag.set()
        self.__running = threading.Event()  # cleared once the thread must end
        self.__running.set()
        self.name = util_random_with_topic(
            topic='czsc',
            lens=3
        ) if name is None else name
        self.idle = False
        self.daemon = daemon

    def __repr__(self):
        return '<QA_Thread: {}  id={} ident {}>'.format(
            self.name,
            id(self),
            self.ident
        )

    def run(self):
        """Process queued tasks until ``stop`` is called.

        Every iteration first waits on the pause flag, so ``pause`` takes
        effect between two tasks. An empty queue marks the thread idle and
        sleeps briefly instead of spinning. Each task taken from the queue
        is acknowledged with ``task_done`` (also when it has no worker and
        is therefore skipped), so ``queue.join()`` cannot hang on it.
        ``ValueError`` is ignored; any other exception ends the thread.
        """
        import time  # local import: only the idle back-off needs it

        while self.__running.is_set():
            self.__flag.wait()  # blocks while paused
            if self.thread_stop:
                # Soft-stopped via the public flag but not terminated yet
                time.sleep(self._IDLE_SLEEP)
                continue
            try:
                if self.queue.empty():
                    self.idle = True
                    time.sleep(self._IDLE_SLEEP)
                    continue
                self.idle = False
                _task = self.queue.get()
                assert isinstance(_task, Cl_Task)
                if _task.worker is not None:
                    _task.do()
                self.queue.task_done()
            except ValueError:
                pass

    def pause(self):
        """Suspend task processing until ``resume`` is called."""
        self.__flag.clear()

    def resume(self):
        """Continue task processing after ``pause``."""
        self.__flag.set()

    def stop(self):
        """Ask the thread to finish; also wakes it up if it is paused."""
        self.__running.clear()
        self.thread_stop = True
        # Without this a paused thread would wait on the flag forever
        self.__flag.set()

    def __start(self):
        # NOTE(review): the standard Queue has no start(); this presumably
        # targets a custom queue type - confirm or remove.
        self.queue.start()

    def put(self, task):
        """Add *task* to the queue (blocking)."""
        self.queue.put(task)

    def put_nowait(self, task):
        """Add *task* to the queue without blocking."""
        self.queue.put_nowait(task)

    def get(self):
        """Remove and return the next queued item (blocking)."""
        return self.queue.get()

    def get_nowait(self):
        """Remove and return the next queued item without blocking."""
        return self.queue.get_nowait()

    def qsize(self):
        """Return the approximate number of queued items."""
        return self.queue.qsize()
Example #45
0
class LambdaTrader(object):
    def __init__(self, logger):
        """Create the reply queues, DynamoDB table handles and the FIX client.

        :param logger: logger stored as ``self.Logger`` and passed to the FIX
            client factory.
        """
        self.Logger = logger
        self.Messages = []

        # One queue per kind of FIX reply filled by the listeners below,
        # plus the queue of orders waiting to be validated.
        self.CurrentPositions = Queue()
        self.CurrentBalance = Queue()
        self.SubmittedOrders = Queue()
        self.PendingOrders = Queue()

        dynamo = boto3.resource('dynamodb', region_name='us-east-1')
        self.__Securities, self.__Orders = (
            dynamo.Table(table_name) for table_name in ('Securities', 'Orders'))

        client = self.FixClient = gain.FixClient.Create(
            self.Logger, 'config.ini', False)
        client.addOrderListener(self.OrderNotificationReceived)
        client.addAccountInquiryListener(self.AccountInquiryReceived)

    def AccountInquiryReceived(self, event):
        """Listener for account-inquiry replies from the FIX client.

        A collateral reply is logged and queued on ``CurrentBalance`` as
        ``(CollInquiryID, Balance, Currency)``; a positions reply is logged and
        queued on ``CurrentPositions`` as ``(PosReqID, LongQty - ShortQty)``.
        Each ``put`` is immediately followed by ``task_done`` on the same queue.

        :param event: reply object exposing the attributes read below.
        """
        if event.AccountInquiry == gain.AccountInquiry.CollateralInquiry:
            inquiry_id, balance, currency = (
                event.CollInquiryID, event.Balance, event.Currency)
            self.Logger.info('CollInquiryID: %s Account: %s' %
                             (inquiry_id, event.Account))
            self.Logger.info('Balance: %s Currency: %s' % (balance, currency))
            self.CurrentBalance.put((inquiry_id, balance, currency))
            self.CurrentBalance.task_done()

        if event.AccountInquiry == gain.AccountInquiry.RequestForPositions:
            request_id = event.PosReqID
            net_quantity = event.LongQty - event.ShortQty
            self.Logger.info('PosReqID: %s Account: %s' %
                             (request_id, event.Account))
            self.Logger.info('Quantity: %s Amount: %s' %
                             (net_quantity, event.PosAmt))
            self.CurrentPositions.put((request_id, net_quantity))
            self.CurrentPositions.task_done()

    def OrderNotificationReceived(self, event):
        """Listener for order notifications from the FIX client.

        Every notification is logged.  Notifications whose status is Filled or
        Rejected are also queued on ``SubmittedOrders`` as
        ``(ClientOrderId, Status, AvgPx)``, followed by ``task_done``.

        :param event: notification object exposing the attributes read below.
        """
        status = event.Status
        self.Logger.info('OrderId: %s Status: %s Side: %s' %
                         (event.ClientOrderId, status, event.Side))
        self.Logger.info('Symbol: %s AvgPx: %s Quantity: %s' %
                         (event.Symbol, event.AvgPx, event.Quantity))
        self.Logger.info('order notification received')

        final_statuses = (gain.OrderStatus.Filled, gain.OrderStatus.Rejected)
        if status not in final_statuses:
            return
        self.SubmittedOrders.put((event.ClientOrderId, status, event.AvgPx))
        self.SubmittedOrders.task_done()

    def SendOrder(self, side, quantity, symbol, maturity, newOrderId,
                  transactionTime, timeout=5):
        """Send a market order and wait for its Filled/Rejected notification.

        The notification is read from ``SubmittedOrders`` (filled by
        ``OrderNotificationReceived``).  Replies for other order ids are set
        aside while waiting and put back on the queue afterwards.  The old
        code re-queued a mismatched reply and immediately read it again, which
        spun without ever timing out while the wrong reply was the only item.

        :param side: order side; compared case-insensitively with the Buy and
            Sell constants of ``gain.OrderSide``.
        :param quantity: order quantity passed to the order constructor.
        :param symbol: instrument symbol.
        :param maturity: contract maturity string.
        :param newOrderId: DynamoDB attribute value identifying the order row.
        :param transactionTime: DynamoDB attribute value of the order row.
        :param timeout: total seconds to wait for the matching notification.
        :raises ValueError: if ``side`` is neither Buy nor Sell (previously an
            unbound-variable error).
        :raises Empty: if no matching notification arrives within ``timeout``.
        """
        import time  # local import: the module's import block is not visible here

        self.Logger.info('Submitting Validated order %s %s %s %s' %
                         (side, quantity, symbol, maturity))
        wanted_side = side.upper()
        if wanted_side == gain.OrderSide.Buy.upper():
            order = gain.BuyFutureMarketOrder(symbol, maturity, quantity)
        elif wanted_side == gain.OrderSide.Sell.upper():
            order = gain.SellFutureMarketOrder(symbol, maturity, quantity)
        else:
            raise ValueError('Unknown side received. Side: %s' % side)
        trade = self.FixClient.send(order)

        deadline = time.time() + timeout
        set_aside = []  # replies that belong to some other order
        try:
            while True:
                remaining = deadline - time.time()
                if remaining <= 0:
                    raise Empty()
                orderId, status, price = self.SubmittedOrders.get(
                    True, remaining)
                if trade.OrderId == orderId:
                    break
                self.Logger.error(
                    'requests do not match orderId: %s, trade.OrderId: %s' %
                    (orderId, trade.OrderId))
                set_aside.append((orderId, status, price))
        finally:
            # Return foreign replies so a later caller can still consume them;
            # put/task_done are paired the same way the listeners pair them.
            for stale in set_aside:
                self.SubmittedOrders.put(stale)
                self.SubmittedOrders.task_done()

        self.Logger.info(
            'Confirmed orderId %s. Status: %s. Price: %s. Symbol: %s' %
            (orderId, status, price, symbol))
        self.UpdateStatus(
            'Confirmed newOrderId: %s. ClientOrderId: %s. Status: %s. Side: %s. Qty: %s. Symbol: %s. '
            'Maturity: %s. Price: %s' % (newOrderId, orderId, status, side,
                                         quantity, symbol, maturity, price),
            newOrderId, transactionTime, orderId, status)

    def UpdateStatus(self, text, newOrderId, transactionTime, clientOrderId,
                     status):
        """Update the order row's status and record a report line.

        Issues a conditional ``update_item`` on the Orders table that sets
        ``Status`` and ``ClientOrderId`` only while the row is still
        ``PENDING``.  The outcome (returned attributes or the error text) is
        appended to ``text``, which is then logged and appended to
        ``self.Messages``.  Errors are logged and reported, never raised.

        :param text: report line to extend with the update outcome.
        :param newOrderId: mapping whose ``'S'`` entry is the order id key.
        :param transactionTime: mapping whose ``'S'`` entry is the range key.
        :param clientOrderId: value stored in ``ClientOrderId``.
        :param status: value stored in ``Status``.
        """
        try:
            # Built inside the try so a malformed id is reported like any
            # other failure.
            row_key = {
                'NewOrderId': newOrderId['S'],
                'TransactionTime': transactionTime['S'],
            }
            placeholders = {
                ':s': status,
                ':c': clientOrderId,
                ':n': newOrderId['S'],
                ':p': 'PENDING'
            }
            response = self.__Orders.update_item(
                Key=row_key,
                UpdateExpression="set #s = :s, ClientOrderId = :c",
                ConditionExpression="#s = :p and NewOrderId = :n",
                ExpressionAttributeNames={'#s': 'Status'},
                ExpressionAttributeValues=placeholders,
                ReturnValues="UPDATED_NEW")
            text = text + '. %s' % response['Attributes']
        except ClientError as client_error:
            reason = client_error.response['Error']['Message']
            self.Logger.error(reason)
            text = text + '%s. %s' % ('', reason)
        except Exception as failure:
            self.Logger.error(failure)
            text = text + '%s. %s' % ('', failure)
        else:
            text = text + ". UpdateItem succeeded."
            pretty = json.dumps(response, indent=4, cls=DecimalEncoder)
            self.Logger.info(pretty)

        self.Logger.info('To Send Email: %s', text)
        self.Messages.append(text)

    def SendReport(self, text):
        try:
            self.Logger.info('Send Email: %s', text)

            def hash_smtp_pass_from_secret_key(key):
                message = "SendRawEmail"
                version = '\x02'
                h = hmac.new(key, message, digestmod=hashlib.sha256)
                return base64.b64encode("{0}{1}".format(version, h.digest()))

            msg = MIMEMultipart('alternative')
            msg['Subject'] = 'Lambda FIX Trader report'
            msg['From'] = os.environ['email_address']
            msg['To'] = os.environ['email_address']
            mime_text = MIMEText(text, 'html')
            msg.attach(mime_text)

            server = smtplib.SMTP('email-smtp.us-east-1.amazonaws.com',
                                  587,
                                  timeout=10)
            server.set_debuglevel(10)
            server.starttls()
            server.ehlo()
            server.login(
                os.environ['aws_access_key_id'],
                hash_smtp_pass_from_secret_key(
                    os.environ['aws_secret_access_key']))
            server.sendmail(os.environ['email_address'],
                            os.environ['email_address'], msg.as_string())
            res = server.quit()
            self.Logger.info(res)
        except Exception as e:
            self.Logger.error(e)

    def Run(self):
        self.FixClient.start()
        if not self.PendingOrders.empty():
            self.validate()

            report = reduce(
                lambda x, y: x + y,
                map(lambda x, y: '<br><b>%s</b>. %s\n' % (x + 1, y),
                    range(len(self.Messages)), self.Messages))
            self.SendReport(report)
        self.FixClient.stop()

    def validate_order(self, order, security):
        try:
            side = str(order['Details']['M']['Side']['S'])
            ordType = str(order['Details']['M']['OrdType']['S'])
            riskFactor = float(security['Risk']['RiskFactor'])
            margin = int(security['Risk']['Margin']['Amount'])
            marginCcy = str(security['Risk']['Margin']['Currency'])
            colReqId = self.FixClient.collateralInquiry()
            receiveColReqId, balance, ccy = self.CurrentBalance.get(True, 5)
            while colReqId != receiveColReqId:
                self.Logger.error(
                    'requests do not match colReqId: %s, receiveColReqId: %s' %
                    (colReqId, receiveColReqId))
                self.CurrentBalance.put((receiveColReqId, balance, ccy))
                self.CurrentBalance.task_done()
                receiveColReqId, balance, ccy = self.CurrentBalance.get(
                    True, 5)
            if marginCcy != ccy:
                raise Exception(
                    'Margin Currency does not match Balance Currency for %s' %
                    security['Symbol'])
            if balance * riskFactor < margin:
                raise Exception(
                    'Margin exceeded for %s. Balance: %s, RF: %s, Margin: %s' %
                    (security['Symbol'], balance, riskFactor, margin))
        except Exception as e:
            self.Logger.error(e)
            self.UpdateStatus(
                'Error validate_order NewOrderId: %s. %s' %
                (order['NewOrderId'], e), order['NewOrderId'],
                order['TransactionTime'], 0, 'INVALID')
            return False, None
        else:
            if ordType.upper() != gain.OrderType.Market.upper():
                supported = 'Only MARKET Orders are supported'
                self.Logger.error(supported)
                self.UpdateStatus(
                    'Error validate_order NewOrderId: %s. %s' %
                    (order['NewOrderId'], supported), order['NewOrderId'],
                    order['TransactionTime'], 0, 'INVALID')
                return False, None
            if side.upper() == gain.OrderSide.Buy.upper() or side.upper(
            ) == gain.OrderSide.Sell.upper():
                return True, side
            else:
                error = 'Unknown side received. Side: %s' % side
                self.Logger.error(error)
                self.UpdateStatus(
                    'Error validate_order NewOrderId: %s. %s' %
                    (order['NewOrderId'], error), order['NewOrderId'],
                    order['TransactionTime'], 0, 'INVALID')
                return False, None

    def validate_quantity(self, order, security):
        try:
            quantity = int(order['Details']['M']['Quantity']['N'])
            side = order['Details']['M']['Side']['S']
            maxPosition = security['Risk']['MaxPosition']
            reqId = self.FixClient.requestForPositions()
            receiveReqId, position = self.CurrentPositions.get(True, 5)
            while reqId != receiveReqId:
                self.Logger.error(
                    'requests do not match reqId: %s, receivedId: %s' %
                    (reqId, receiveReqId))
                self.CurrentPositions.put((receiveReqId, position))
                self.CurrentPositions.task_done()
                receiveReqId, position = self.CurrentPositions.get(True, 5)

            if side.upper() == gain.OrderSide.Buy.upper(
            ) and maxPosition < position + quantity:
                raise Exception('MaxPosition exceeded for %s' %
                                security['Symbol'])
            if side.upper() == gain.OrderSide.Sell.upper(
            ) and maxPosition < abs(position - quantity):
                raise Exception('MaxPosition exceeded for %s' %
                                security['Symbol'])
        except Empty:
            error = 'No reply to requestForPositions'
            self.Logger.error(error)
            self.UpdateStatus(
                'Error validate_quantity NewOrderId: %s. %s' %
                (order['NewOrderId'], error), order['NewOrderId'],
                order['TransactionTime'], 0, 'INVALID')
            return 0
        except Exception as e:
            self.Logger.error(e)
            self.UpdateStatus(
                'Error validate_quantity NewOrderId: %s. %s' %
                (order['NewOrderId'], e), order['NewOrderId'],
                order['TransactionTime'], 0, 'INVALID')
            return 0
        else:
            return quantity

    def validate_maturity(self, order):
        try:
            maturity = order['Details']['M']['Maturity']['S']
            year = int(maturity[:4])
            month = int(maturity[-2:])
            date = datetime.date(year, month, 1)
            expiry = self.get_expiry_date(date)
            if expiry < datetime.date.today():
                raise Exception('%s maturity date has expired' % expiry)

        except Exception as e:
            self.Logger.error(e)
            self.UpdateStatus(
                'Error validate_maturity NewOrderId: %s. %s' %
                (order['NewOrderId'], e), order['NewOrderId'],
                order['TransactionTime'], 0, 'INVALID')
            return None
        else:
            return maturity

    def validate_symbol(self, order):
        try:
            symbol = order['Details']['M']['Symbol']['S']
            self.Logger.info('Validating %s' % symbol)
            response = self.__Securities.get_item(Key={'Symbol': symbol})
        except ClientError as e:
            self.Logger.error(e.response['Error']['Message'])
            self.UpdateStatus(
                'ClientError validate_symbol NewOrderId: %s. %s' %
                (order['NewOrderId'], e), order['NewOrderId'],
                order['TransactionTime'], 0, 'INVALID')
            return False, None
        except Exception as e:
            self.Logger.error(e)
            self.UpdateStatus(
                'Error validate_symbol NewOrderId: %s. %s' %
                (order['NewOrderId'], e), order['NewOrderId'],
                order['TransactionTime'], 0, 'INVALID')
            return False, None
        else:
            # self.Logger.info(json.dumps(security, indent=4, cls=DecimalEncoder))
            if response.has_key('Item') and response['Item'][
                    'Symbol'] == symbol and response['Item']['TradingEnabled']:
                return True, response['Item']
            self.UpdateStatus(
                'Symbol is unknown or not enabled for trading %s' % symbol,
                order['NewOrderId'], order['TransactionTime'], 0, 'INVALID')
            return False, None

    def validate(self):
        while not self.PendingOrders.empty():

            order = self.PendingOrders.get()
            found, security = self.validate_symbol(order)
            if not found: continue

            maturity = self.validate_maturity(order)
            if not maturity: continue

            quantity = self.validate_quantity(order, security)
            if quantity < 1: continue

            good, side = self.validate_order(order, security)
            if not good: continue

            self.SendOrder(str(side), int(quantity), str(security['Symbol']),
                           str(maturity), order['NewOrderId'],
                           order['TransactionTime'])

    # lifted from https://github.com/conor10/examples/blob/master/python/expiries/vix.py
    @staticmethod
    def get_expiry_date(date):
        """
        http://cfe.cboe.com/products/spec_vix.aspx

        TERMINATION OF TRADING:

        Trading hours for expiring VIX futures contracts end at 7:00 a.m. Chicago
        time on the final settlement date.

        FINAL SETTLEMENT DATE:

        The Wednesday that is thirty days prior to the third Friday of the
        calendar month immediately following the month in which the contract
        expires ("Final Settlement Date"). If the third Friday of the month
        subsequent to expiration of the applicable VIX futures contract is a
        CBOE holiday, the Final Settlement Date for the contract shall be thirty
        days prior to the CBOE business day immediately preceding that Friday.
        """
        # Date of third friday of the following month
        if date.month == 12:
            third_friday_next_month = datetime.date(date.year + 1, 1, 15)
        else:
            third_friday_next_month = datetime.date(date.year, date.month + 1,
                                                    15)

        one_day = datetime.timedelta(days=1)
        thirty_days = datetime.timedelta(days=30)
        while third_friday_next_month.weekday() != 4:
            # Using += results in a timedelta object
            third_friday_next_month = third_friday_next_month + one_day

        # TODO: Incorporate check that it's a trading day, if so move the 3rd
        # Friday back by one day before subtracting
        return third_friday_next_month - thirty_days
    def april_following(self, desiredTag, desiredDistance, cvQueue: Queue, isFirstUse, isLastUse):
        """Steer the robot toward AprilTag ``desiredTag`` until its image has the wanted size.

        Each camera frame is resized to 600 px wide and run through the
        apriltag detector.  While the tag is visible, a turn / forward /
        backward / stop command is chosen from the tag's centre x-coordinate
        and the radius of its minimum enclosing circle (both in pixels).  While
        it is not visible the robot halts for two seconds and then rotates in
        the direction it last turned.

        Changes from the previous version: a radius exactly on a range bound
        now counts as in range (it used to match no branch and crash the text
        overlay), a degenerate tag outline no longer divides by zero, and a
        failed frame read stops the robot and closes the window before
        returning.

        :param desiredTag: id of the tag to follow (compared with ``tag_id``).
        :param desiredDistance: target enclosing-circle radius in pixels; the
            robot is in range within +/- 10 px of it, bounds included.
        :param cvQueue: command queue; any pending item aborts the call.
        :param isFirstUse: when true, exposure and gain are set for motion first.
        :param isLastUse: when true, camera settings are restored once the tag
            has been reached (they are always restored on an abort).
        :return: None.
        """
        # Fast-fail. If there is something on the cvQueue then we need to respond to it. This method is called
        # several times in succession to reach a destination, and each call should exit quickly in that case.
        if not cvQueue.empty():
            return

        # Tune the webcam to better see april tags while the robot is moving (compensating for motion blur).
        # The settings are restored when done.
        if isFirstUse:
            self.WebcamVideoStreamObject.stream.set(cv2.CAP_PROP_EXPOSURE, 0.5)
            self.WebcamVideoStreamObject.stream.set(cv2.CAP_PROP_GAIN, 1)

        # Thresholds in pixels of the resized frame: what counts as centred and what counts as in range.
        radiusInRangeLowerBound, radiusInRangeUpperBound = desiredDistance - 10, desiredDistance + 10
        centerRightBound, centerLeftBound = 400, 200
        radiusTooCloseLowerLimit = 250

        # Creating a window for later use
        cv2.namedWindow('result')
        cv2.resizeWindow('result', 600, 600)

        leftOrRightLastSent = None  # Last turn direction sent; reused when searching for a lost tag
        # Timestamp (seconds) of the first frame in the current run of frames without the tag. A single bad
        # frame should not start a search, so the robot waits before turning.
        firstTimeObjectNotSeen = None

        # Initialize apriltag detector
        options = apriltag.DetectorOptions(
            families='tag36h11',
            border=1,
            nthreads=1,
            quad_decimate=1.0,
            quad_blur=0.0,
            refine_edges=True,
            refine_decode=True,
            refine_pose=False,
            debug=False,
            quad_contours=True)
        det = apriltag.Detector(options)

        # Diagnostics shown on the frame (FPS and number of frames without the tag)
        start = time.time()
        num_frames = 0
        numHalts = 0
        inPosition = False

        while True:

            # Grab frame. Without a frame there is no feedback, so stop the robot before giving up.
            frame = self.vs.read()
            if frame is None:
                self.send_serial_command(Direction.STOP, b'h')
                cv2.destroyAllWindows()
                break

            seconds = time.time() - start
            num_frames += 1
            fps = 0 if (seconds == 0) else num_frames / seconds

            frame = imutils.resize(frame, width=600)
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)  # Use grayscale image for detection
            res = det.detect(gray)  # Run the image through the apriltag detector and get the results

            commandString = None  # Text of the current command, drawn on the frame

            # Check if the desiredTag is visible (the last matching detection wins, as before)
            tagObject = None
            for detection in res:
                if detection.tag_id == desiredTag:
                    tagObject = detection

            if tagObject is None:  # We don't see the tag that we're looking for

                numHalts += 1

                if firstTimeObjectNotSeen is None:
                    firstTimeObjectNotSeen = time.time()
                    self.send_serial_command(Direction.STOP, b'h')
                    commandString = "STOP"
                elif time.time() - firstTimeObjectNotSeen > 2:  # No tag for more than 2 seconds: search
                    if leftOrRightLastSent is not None:
                        if leftOrRightLastSent == Direction.RIGHT:
                            self.send_serial_command(Direction.RIGHT, b'r')
                            commandString = "SEARCHING: GO RIGHT"
                        elif leftOrRightLastSent == Direction.LEFT:
                            self.send_serial_command(Direction.LEFT, b'l')
                            commandString = "SEARCHING: GO LEFT"
                    else:
                        # No turn has been sent yet.
                        # NOTE(review): the direction is LEFT but the byte and the label say right; every
                        # other call pairs LEFT with b'l'. Left unchanged -- confirm which one is intended.
                        self.send_serial_command(Direction.LEFT, b'r')
                        commandString = "DEFAULT SEARCHING: GO RIGHT"
                else:  # Keep waiting - 2 seconds haven't elapsed
                    self.send_serial_command(Direction.STOP, b'h')
                    commandString = "STOP"

            else:  # We see the desired tag!

                # Reset for the next time the tag is lost
                firstTimeObjectNotSeen = None

                # Integer corner points, their minimum enclosing circle and their centroid
                cornersList = np.array(tagObject.corners, dtype=np.float32).reshape((4, 2)).astype(np.int32)
                ((x, y), radius) = cv2.minEnclosingCircle(cornersList)
                M = cv2.moments(cornersList)
                if M["m00"] != 0:
                    center = (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"]))
                else:
                    # Degenerate outline (zero area): fall back to the circle centre
                    center = (int(x), int(y))

                # Draw circle and center point on the frame
                cv2.circle(frame, (int(x), int(y)), int(radius), (0, 255, 255), 2)
                cv2.circle(frame, center, 5, (0, 0, 255), -1)

                # Determine what command to send to the Arduino (motors)
                if radius > radiusTooCloseLowerLimit:
                    commandString = "MOVE BACKWARD - TOO CLOSE TO TURN"
                    self.send_serial_command(Direction.BACKWARD, b'b')
                elif center[0] > centerRightBound:
                    commandString = "GO RIGHT"
                    self.send_serial_command(Direction.RIGHT, b'r')
                    leftOrRightLastSent = Direction.RIGHT
                elif center[0] < centerLeftBound:
                    commandString = "GO LEFT"
                    self.send_serial_command(Direction.LEFT, b'l')
                    leftOrRightLastSent = Direction.LEFT
                elif radius < radiusInRangeLowerBound:
                    commandString = "MOVE FORWARD"
                    self.send_serial_command(Direction.FORWARD, b'f')
                elif radius > radiusInRangeUpperBound:
                    commandString = "MOVE BACKWARD"
                    self.send_serial_command(Direction.BACKWARD, b'b')
                else:
                    # Centred and within the range bounds (inclusive): destination reached
                    commandString = "STOP MOVING - IN RANGE"
                    self.send_serial_command(Direction.STOP, b'h')
                    inPosition = True

                # Put text on the camera image to display on the screen
                cv2.putText(frame, 'center coordinate: (' + str(center[0]) + ',' + str(center[1]) + ')',
                            (10, 60), self.font, 0.5, (200, 255, 155), 1, cv2.LINE_AA)
                cv2.putText(frame, 'filtered radius: (' + str(radius) + ')', (10, 90), self.font, 0.5,
                            (200, 255, 155), 1, cv2.LINE_AA)

            # Command, FPS and number of halts are drawn whether or not the tag is visible
            cv2.putText(frame, commandString, (10, 30), self.font, 0.5, (200, 255, 155), 1, cv2.LINE_AA)
            cv2.putText(frame, 'FPS: (' + str(fps) + ')', (10, 120), self.font, 0.5,
                        (200, 255, 155), 1, cv2.LINE_AA)
            cv2.putText(frame, 'numHalts: (' + str(numHalts) + ')', (10, 150), self.font, 0.5,
                        (200, 255, 155), 1, cv2.LINE_AA)

            # Display frame
            cv2.imshow("result", frame)

            # Finish on 'q' key press, new stuff on queue, or if we've reached our destination
            key = cv2.waitKey(1) & 0xFF
            if (key == ord("q")) or (not cvQueue.empty()) or inPosition:
                self.send_serial_command(Direction.STOP, b'h')
                # Restore webcam settings on an abort, or after the last leg of a multi-call route
                if not inPosition or isLastUse:
                    self.WebcamVideoStreamObject.stream.set(cv2.CAP_PROP_EXPOSURE, self.originalExposure)
                    self.WebcamVideoStreamObject.stream.set(cv2.CAP_PROP_GAIN, self.originalGain)
                    # Re-activate auto exposure, which the manual exposure change above switched off
                    subprocess.check_call(["v4l2-ctl", "-d", "/dev/video1", "-c", "exposure_auto=3"])
                cv2.destroyAllWindows()
                break
    def get_coordinates(self, cvQueue: Queue):
        """Estimate the robot's (x, z) position from the AprilTags around it.

        The robot turns left in ten short steps, halting after each one to take
        a steady frame.  For every frame with known tags, each tag's pose is
        solved with ``cv2.solvePnP`` against the world points loaded from
        ``worldPoints.json`` and the poses are combined into a weighted average
        (weights come from ``self.calc_weight``).  The estimate from the frame
        that showed the largest tag is returned.

        Changes from the previous version: a detected tag that has no entry in
        ``worldPoints.json`` is skipped instead of raising ``KeyError``, the
        detection loop variable is no longer overwritten by the ``solvePnP``
        result, and a failed frame read closes the preview window.

        :param cvQueue: command queue; a pending item aborts the scan.
        :return: ``(x, z)`` as ints; ``(0, -1)`` when no frame produced an
            estimate; ``None`` when the scan is aborted (``q`` key, queued
            command, or a failed frame read).
            NOTE(review): the visible callers unpack two values, so an abort
            would fail there -- confirm how aborts should be reported.
        """
        # Time to turn and time to wait after each partial turn, so that every shot is stable
        searchingTimeToTurn = 0.5  # seconds
        searchingTimeToHalt = 2.0  # seconds

        # Note that refine_pose is set to True (more processing, hopefully better coordinates)
        options = apriltag.DetectorOptions(
            families='tag36h11',
            border=1,
            nthreads=1,
            quad_decimate=1.0,
            quad_blur=0.0,
            refine_edges=True,
            refine_decode=True,
            refine_pose=True,
            debug=False,
            quad_contours=True)
        det = apriltag.Detector(options)

        # Load camera data
        with open('cameraParams.json', 'r') as f:
            data = json.load(f)
        cameraMatrix = np.array(data['cameraMatrix'], dtype=np.float32)
        distCoeffs = np.array(data['distCoeffs'], dtype=np.float32)

        # Load world points, keyed by integer tag id
        with open('worldPoints.json', 'r') as f:
            data = json.load(f)
        world_points = {int(k): np.array(v, dtype=np.float32).reshape((4, 3, 1))
                        for k, v in data.items()}

        # One (x, z, largest tag radius) entry per frame that produced an estimate
        coordinates_list = []
        numIterations = 10

        for iterationNumber in range(1, numIterations + 1):
            # Rotate camera by going left by some amount
            self.send_serial_command(Direction.LEFT, b'l')
            time.sleep(searchingTimeToTurn)
            self.send_serial_command(Direction.STOP, b'h')
            time.sleep(searchingTimeToHalt)

            # Read the frame while the robot is halted so that the image is clean
            frame = self.vs.read()
            if frame is None:
                print("ERROR - frame read a NONE")
                cv2.destroyAllWindows()
                break

            # Use grayscale image for detection
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            res = det.detect(gray)

            numTagsSeen = len(res)
            print("\nNumber of tags seen", numTagsSeen)  # TODO remove

            poses = []  # (pose, weight) for each known tag in this frame
            tagRadiusList = []  # Enclosing-circle radius of each known tag in this frame

            for detection in res:  # Iterate over each tag in the frame
                tag_id = detection.tag_id
                if tag_id not in world_points:
                    # A tag we have no world position for cannot contribute a pose
                    print("Ignoring tag without world points:", tag_id)
                    continue

                corners = np.array(detection.corners, dtype=np.float32).reshape((4, 2, 1))

                # Radius of the circle enclosing the tag's (integer) corner points
                cornersList = corners.reshape((4, 2)).astype(np.int32)
                ((x, y), radius) = cv2.minEnclosingCircle(cornersList)

                # Solve pose: rotation and translation vectors, then pose = -R^T * t
                _, rot, t = cv2.solvePnP(world_points[tag_id], corners, cameraMatrix, distCoeffs)
                rot_mat, _ = cv2.Rodrigues(rot)  # convert to rotation matrix
                pose = -rot_mat.transpose() @ t
                weight = self.calc_weight(pose, world_points[tag_id][0])
                poses.append((pose, weight))
                tagRadiusList.append(radius)

            if poses:
                # Weighted average pose across the tags, plus the largest tag radius seen in this frame
                avgPose = sum([pose * weight for pose, weight in poses]) / sum([weight for _, weight in poses])
                largestTagRadius = max(tagRadiusList)
                coordinates = (avgPose[0][0], avgPose[2][0], largestTagRadius)
                print(str(coordinates))  # TODO remove this
                coordinates_list.append(coordinates)

            # Display frame
            cv2.imshow('frame', frame)

            # After the last iteration, choose the coordinate, clean up and return
            if iterationNumber == numIterations:
                if len(coordinates_list) > 0:
                    # Two candidates were considered upstream: the estimate with the smallest z, or the one
                    # from the frame with the largest tag. The z value varied a lot, so the largest tag
                    # radius is used.
                    best = max(coordinates_list, key=lambda entry: entry[2])
                    coordinateToReturn = (int(best[0]), int(best[1]))
                else:
                    coordinateToReturn = (0, -1)  # TODO set to some default outside the door

                cv2.destroyAllWindows()
                print("Value to return:")  # TODO remove
                print(coordinateToReturn)  # TODO remove
                return coordinateToReturn

            # Q to quit; a queued command also aborts the scan
            key = cv2.waitKey(1) & 0xFF
            if (key == ord("q")) or (not cvQueue.empty()):
                self.send_serial_command(Direction.STOP, b'h')
                cv2.destroyAllWindows()
                break

        return None
def run(cvQueue: Queue):
    """Serve OpenCV commands received on ``cvQueue`` until told to terminate.

    One ``OpenCVController`` is created up front; every command string taken
    from the queue is then dispatched to it:

    * ``"terminate"``      - call ``cleanup_resources()`` and return.
    * ``"halt"``           - send the STOP serial command.
    * ``"getCoordinates"`` - compute the coordinates, put x and then z back on
      the queue and ``join()`` it until both have been marked done.
    * ``"personFollow"`` / ``"eyeballFollow"`` - call ``person_following`` with
      ``False`` / ``True`` as its first argument.
    * ``"photo"``          - call ``take_photo``.
    * ``"aprilFollow"``    - read the target tag number and radius from the
      queue, then follow the tag route up to that target.

    Any other command is ignored.

    :param cvQueue: queue shared with the controlling thread; it carries the
        commands in and the coordinate results out.
    """
    # Creating the controller connects to the webcam and serial port and
    # begins grabbing frames.
    cvObject = OpenCVController()

    while True:
        # Block until a command arrives rather than spinning on empty(),
        # which would keep a CPU core busy while idle.
        commandFromQueue = cvQueue.get()
        cvQueue.task_done()

        if commandFromQueue == "terminate":
            cvObject.cleanup_resources()
            print("Terminate OpenCV")
            return
        elif commandFromQueue == "halt":
            cvObject.send_serial_command(Direction.STOP, b'h')
            print("Sent halt command")
        elif commandFromQueue == "getCoordinates":
            print("got command getCoordinates")
            x, z = cvObject.get_coordinates(cvQueue)
            print("sending coordinates")
            cvQueue.put(x)
            cvQueue.put(z)
            # Wait until the receiver has called task_done() for both values
            # so they are not read back here as commands.
            cvQueue.join()
            print("nonblocking")
        elif commandFromQueue == "personFollow":
            cvObject.person_following(False, cvQueue)
        elif commandFromQueue == "photo":
            cvObject.take_photo(cvQueue)
        elif commandFromQueue == "eyeballFollow":
            cvObject.person_following(True, cvQueue)
        elif commandFromQueue == "aprilFollow":
            # The next 2 items in the queue are the target tag number and the
            # desired radius - grab them (get() blocks until one is available).
            print("Receive april Tag request")
            final_target_tag_number = cvQueue.get()
            final_target_tag_radius = cvQueue.get()
            cvQueue.task_done()
            cvQueue.task_done()

            print("getting my location")
            x_cord, z_cord = cvObject.get_coordinates(cvQueue)
            # See the aprilTags class for the details: it provides a list of
            # tags to go to and the desired distance for each tag; each step
            # is passed into april_following one-by-one.
            first_tag = aprilTags.getClosestTag(x_cord, z_cord, True)

            print("first tag is " + str(first_tag))
            print("last tag is " + str(final_target_tag_number))
            cvObject.april_following(first_tag[0], first_tag[1], cvQueue, True, False)

            # Follow the intermediate targets up to and including end_index.
            end_index = aprilTags.endOptions[final_target_tag_number]
            for i, target_pair in enumerate(aprilTags.aprilTargets):
                print("going to:" + str(target_pair[0]))
                cvObject.april_following(target_pair[0], target_pair[1], cvQueue, False, False)
                if i == end_index:
                    break

            print("going to " + str(final_target_tag_number))
            cvObject.april_following(final_target_tag_number, final_target_tag_radius, cvQueue, False, True)
Example #49
0
class Takeover(Module):
    """
    OneForAll multithreaded subdomain takeover risk check module

    Example:
        python3 takeover.py --target www.example.com  --format csv run
        python3 takeover.py --target ./subdomains.txt --thread 10 run

    Note:
        The format parameter accepts 'txt', 'rst', 'csv', 'tsv', 'json', 'yaml', 'html',
                          'jira', 'xls', 'xlsx', 'dbf', 'latex', 'ods'
        When path is None (the default) the export path is generated under
        the OneForAll result directory

    :param any target:  a single subdomain, or the path of a file holding one subdomain per line (required)
    :param int thread:  number of threads (default 100)
    :param str format:  export format (default csv)
    :param str path:    export path (default None)
    """
    def __init__(self, target, thread=100, path=None, format='csv'):
        Module.__init__(self)
        self.subdomains = set()
        self.module = 'Check'
        self.source = 'Takeover'
        self.target = target
        self.thread = thread
        self.path = path
        self.format = format
        self.fingerprints = None
        self.subdomainq = Queue()  # work queue consumed by the check threads
        self.cnames = list()
        self.results = Dataset()

    def save(self):
        """Serialise ``self.results`` in ``self.format`` and write it to ``self.path``."""
        logger.log('DEBUG', '正在保存检查结果')
        if self.format == 'txt':
            data = str(self.results)
        else:
            data = self.results.export(self.format)
        utils.save_data(self.path, data)

    def compare(self, subdomain, cname, responses):
        """Record ``subdomain`` as at risk when a fingerprint text is served by both hosts.

        Both ``http://<subdomain>`` and ``http://<cname>`` are fetched; if
        either request yields ``None`` nothing is recorded. The first entry
        of ``responses`` found in both response bodies appends
        ``[subdomain, cname]`` to ``self.results``.
        """
        domain_resp = self.get('http://' + subdomain, check=False)
        cname_resp = self.get('http://' + cname, check=False)
        if domain_resp is None or cname_resp is None:
            return

        for resp in responses:
            if resp in domain_resp.text and resp in cname_resp.text:
                logger.log('ALERT', f'{subdomain}存在子域接管风险')
                self.results.append([subdomain, cname])
                break

    def worker(self, subdomain):
        """Check one subdomain against every fingerprint whose cname list contains its CNAME's main domain."""
        cname = get_cname(subdomain)
        if cname is None:
            return
        maindomain = get_maindomain(cname)
        for fingerprint in self.fingerprints:
            cnames = fingerprint.get('cname')
            if maindomain not in cnames:
                continue
            responses = fingerprint.get('response')
            self.compare(subdomain, cname, responses)

    def check(self):
        """Thread body: consume subdomains from the queue until it is empty.

        ``get_nowait()`` is used so that a thread losing the race for the last
        item exits instead of blocking forever, and ``task_done()`` is always
        called so that one failing subdomain cannot make ``subdomainq.join()``
        in ``run()`` hang.
        """
        from queue import Empty  # local import: only Queue is known to be imported at file level

        while True:
            try:
                subdomain = self.subdomainq.get_nowait()  # take a domain from the queue
            except Empty:
                return  # queue exhausted: let the thread finish
            try:
                self.worker(subdomain)
            except Exception as e:
                # Keep the thread alive for the remaining subdomains.
                logger.log('DEBUG', f'{subdomain} check failed: {e!r}')
            finally:
                self.subdomainq.task_done()

    def progress(self):
        """Thread body: display a progress bar until the subdomain queue has been emptied."""
        # Set up the progress bar
        bar = tqdm()
        bar.total = len(self.subdomains)
        bar.desc = 'Check Progress'
        bar.ncols = 80
        while True:
            done = bar.total - self.subdomainq.qsize()
            bar.n = done
            bar.update()
            if done == bar.total:  # every queued subdomain has been taken: stop
                break
            # Yield between polls instead of busy-spinning against the workers.
            time.sleep(0.1)
        # bar.close()

    def run(self):
        """Load the targets, check them all with ``self.thread`` threads and save the results."""
        start = time.time()
        logger.log('INFOR', f'开始执行{self.source}模块')
        self.subdomains = utils.get_domains(self.target)
        self.format = utils.check_format(self.format, len(self.subdomains))
        timestamp = utils.get_timestamp()
        name = f'takeover_check_result_{timestamp}'
        self.path = utils.check_path(self.path, name, self.format)
        if self.subdomains:
            logger.log('INFOR', f'正在检查子域接管风险')
            self.fingerprints = get_fingerprint()
            self.results.headers = ['subdomain', 'cname']
            # Build the queue of subdomains to check
            for domain in self.subdomains:
                self.subdomainq.put(domain)
            # Checking threads
            for _ in range(self.thread):
                check_thread = Thread(target=self.check, daemon=True)
                check_thread.start()
            # Progress thread
            progress_thread = Thread(target=self.progress, daemon=True)
            progress_thread.start()

            # Wait until task_done() has been called for every queued subdomain
            self.subdomainq.join()
            self.save()
        else:
            logger.log('FATAL', f'获取域名失败')
        end = time.time()
        elapse = round(end - start, 1)
        logger.log(
            'INFOR', f'{self.source}模块耗时{elapse}秒'
            f'发现{len(self.results)}个子域存在接管风险')
        logger.log('INFOR', f'子域接管风险检查结果 {self.path}')
        logger.log('INFOR', f'结束执行{self.source}模块')
Example #50
0
    def execute_parallel(self, query_list, need_result=True):
        """
        Run every entry of ``query_list`` concurrently, one thread per entry.

        A session is obtained from ``self.get_session()`` for each entry and
        handed, together with the entry and its position in the list, to the
        private thread worker. All threads are created first, then started,
        then joined.

        When ``need_result`` is true the workers report through a shared
        queue as ``(exec_id, result)`` pairs; the results are returned as a
        list ordered by ``exec_id``, i.e. in the order the queries were given.
        When it is false nothing is collected and ``None`` is returned.

        Raises ``ValueError`` if ``query_list`` is ``None`` or empty.
        """
        if query_list is None or len(query_list) == 0:
            raise ValueError('Empty query list.')

        # Results are only gathered when the caller asked for them.
        result_queue = Queue() if need_result else None

        # One worker thread per query; the list index doubles as the
        # execution id used later to restore the original ordering.
        workers = []
        for exec_id, query in enumerate(query_list):
            session = self.get_session()
            workers.append(
                Thread(
                    target=self.__execute_thread,
                    args=(exec_id, session, query, result_queue),
                ))

        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        if not need_result:
            return

        # Drain the queue, indexing every result by the id of its thread.
        collected = {}
        while not result_queue.empty():
            exec_id, result = result_queue.get()
            collected[exec_id] = result

        # Hand the results back in the order the queries were submitted.
        return [collected[key] for key in sorted(collected)]
Example #51
0
class Inspector(threading.Thread):
    """Inspector of one cavity sharing some devices on a location.

    The thread takes orders from ``self.orders`` one at a time, runs a full
    test for each and publishes notifications on ``self.events``.
    """
    def __init__(self, database, toolbox, location_key, cavity=None, tff=True):
        """Args:
        database(Container): Persistence layer container of providers
        toolbox(Container): Container of devices
        location_key: key of the location of the test station and its devices
        cavity: cavity number, None if the station is not multicavity
        tff(Boolean): Till first failure, stops test when first failure is found
        """
        name = 'Inspector'
        if cavity is not None:
            name += '_{}'.format(cavity)
        super().__init__(name=name)

        self.db = database
        self.toolbox = toolbox
        self.location_key = location_key
        self.cavity = cavity
        self.tff = tff

        # Inputs and Outputs of inspector
        self.orders = Queue()
        self.events = Queue()
        # NOTE: spelling kept as is - it is a runtime value other code may compare against
        self.state = 'avalaible'

        # Batch variables
        self.responsible = None
        self.part_model = None
        self.part = None
        self.control_plan = None
        self.test = None

        self._stop_event = threading.Event()

    def set_responsible(self, responsible_key):
        """Load the responsible person if it has changed."""
        if self.responsible is None or self.responsible.key != responsible_key:
            self.responsible = self.db.Persons().get(responsible_key)

    def set_part_model(self, part_number):
        """Load the part model and its control plan if they have changed.

        Uses ``self.location``, which is only assigned in ``run()``.

        Raises:
            NotFoundResource: no part model exists for ``part_number``.
            NotFoundPath: no control plan exists for the model at this location.
        """
        if (self.part_model is None or self.part_model.key != part_number):
            self.part_model = self.db.PartModels().get(part_number)
            if self.part_model is None:
                raise NotFoundResource(
                    'Not found part model for partnumber{}'.format(
                        part_number))

        if (self.control_plan is None
                or self.part_model not in self.control_plan.outputs):
            self.control_plan = (self.db.ControlPlans().get_by(
                self.part_model, self.location))
            if self.control_plan is None:
                raise NotFoundPath(
                    'Not found control plan for {}'.format(part_number))

    def run(self):
        """Thread activation: process orders one by one.

        The loop ends when the stop event is set or a ``None`` order is
        received. Any exception is published as a ``loop_error`` event,
        logged and re-raised, which terminates the thread.
        """
        self.location = self.db.Locations().get(self.location_key)
        self.devices = {
            device.tracking: device
            for device in self.db.Devices().get_all_from(self.location_key)
        }

        while not self._stop_event.is_set():
            try:
                self.state = 'idle'
                logger.info('Inspector {} is idle'.format(self.name))
                order = self.orders.get()
                if order is None:
                    # Sentinel posted by stop()
                    self.orders.task_done()
                    break
                else:
                    self.state = 'busy'
                    self.run_test(order)
                    self.orders.task_done()
                    logger.info('Inspector {} has finished'.format(self.name))
            except Exception as e:
                trc = sys.exc_info()
                self.update('loop_error', e, traceback.format_tb(trc[2]))
                logger.exception(e)
                raise
        logger.info('Inspector {} has stopped'.format(self.name))
        self.state = 'stopped'

    def get_part(self, serial_number, pars):
        """Get the part from the data layer if it exists, or create a new one.

        Raises:
            WrongLocationError: the stored part is at another location.
        """

        part = self.db.Parts().get_by(self.part_model, serial_number)
        if part and part.location != self.location:
            raise WrongLocationError('Part {} with sn {} found on {}'.format(
                part.model.key, part.tracking, part.location.key))

        if part is None:
            part = prd.Part(self.part_model, serial_number, pars=pars)

        if part.model.is_device():
            part.dut = self.toolbox.dut(part.model, self.cavity)

        return part

    def run_test(self, order):
        """Process a full test from an order.

        ``order`` is a ``(part_info, responsible_key)`` pair; ``part_info``
        is a dict from which ``part_number`` and ``serial_number`` are popped,
        the remainder being passed on to ``get_part``.

        A failure while preparing the test is published as a ``crash`` event
        and re-raised; a failure while executing it is published as a
        ``test_error`` event and the test is cancelled.
        """
        try:
            logger.info('Init testing on cavity {}'.format(self.cavity))
            part_info, responsible_key = order
            self.set_responsible(responsible_key)

            part_number = part_info.pop('part_number')
            self.set_part_model(part_number)
            serial_number = part_info.pop('serial_number')
            self.part = part = self.get_part(serial_number, part_info)
            self.test = test = self.control_plan.implement(
                self.responsible, self.update)

            self.db.Tests().add(test)
            self.db.Parts().add(part)
            logger.info('Test has been added on cavity {}'.format(self.cavity))
            test.start(part=part,
                       toolbox=self.toolbox,
                       devices=self.devices,
                       cavity=self.cavity,
                       tff=self.tff)
            try:
                if part.dut and hasattr(part.dut, 'supply_voltage'):
                    self.toolbox.dyncir().switch_on_dut(
                        voltage=part.dut.supply_voltage,
                        wait_after=1,
                        cavity=self.cavity)
                test.walk()
                test.execute()
                test.close()
            except DefectFound:
                test.close()
            except Exception as e:
                logger.exception(e)
                trc = sys.exc_info()
                self.update('test_error', e, traceback.format_tb(trc[2]))
                test.cancel()
            finally:
                # Always persist and power the DUT down, whatever the outcome
                self.db.Session().commit()
                if part.dut and hasattr(part.dut, 'supply_voltage'):
                    self.toolbox.dyncir().switch_off_dut(
                        voltage=part.dut.supply_voltage,
                        wait_after=0,
                        cavity=self.cavity)
                self.part = None
        except Exception as e:
            self.part = None
            trc = sys.exc_info()
            self.update('crash', e, traceback.format_tb(trc[2]))
            raise

    def update(self, state, obj, *args):
        """Receive state notifications from the test and queue them as events."""
        self.events.put([state, obj] + list(args))

    def stop(self):
        """Stop the thread and return the unprocessed orders.

        The queue is drained without blocking and the ``None`` sentinel is
        then always posted, so a ``run()`` loop that is blocked in (or about
        to enter) ``orders.get()`` is guaranteed to wake up and exit.
        """
        from queue import Empty  # local import: only Queue is known to be imported at file level

        self._stop_event.set()

        pending_orders = []
        while True:
            try:
                order = self.orders.get_nowait()
            except Empty:
                break
            if order is not None:  # skip a sentinel left by an earlier stop()
                pending_orders.append(order)

        self.orders.put(None)

        return pending_orders

    def answer(self, **kwargs):
        """Forward an answer to the pending question of the current test, if any."""
        if hasattr(self.test, 'question'):
            self.test.question.answer(**kwargs)

    def get_last_events(self):
        """Retrieve the list of last events of the current test.

        Reading stops after the event that closes a test (state ``success``,
        ``failed`` or ``cancelled`` emitted by an object whose class is named
        ``Test``); later events stay queued for the next call.
        """
        events = []
        for _ in range(self.events.qsize()):
            event = self.events.get()
            events.append(event)
            if event[0] in ('success', 'failed', 'cancelled') \
               and event[1].__class__.__name__ == 'Test':
                break

        return events

    def cancel(self):
        """Cancel the running test, if the inspector is busy."""
        if self.state == 'busy':
            self.test.cancel()
Example #52
0
class Neo4jAdapter(object):
    """
    Neo4jAdapter is a data access layer over Neo4j Python Bolt driver:
    - to access Neo4j database directly,
    - for multiple callers,
    - thread-safe.

    It supports following features:
    - implicitly keeps neo4j driver's sessions in a FIFO queue,
    - each user's transaction is carried out in a separate session,
    - three transaction functions: execute_one(), execute_sequential(), execute_parallel(),
    - each transaction function requires a work unit
    - a transaction with a single or multiple cypher queries with arguments and named arguments
    - allows to return none, one or all execution results.
    - convert the results (if required) into [nested] Python instances.

    Note: for more information about transaction functions see
    https://neo4j.com/docs/developer-manual/current/drivers/sessions-transactions/
    """
    def __init__(self, database_credentials):
        """
        Create a Neo4j Bolt Driver based on given database credentials
        (keys 'neo4j_bolt_server', 'neo4j_user' and 'neo4j_password').
        Create a thread-safe Queue instance to keep the driver sessions
        """
        self.driver = GraphDatabase.driver(
            database_credentials['neo4j_bolt_server'],
            auth=(database_credentials['neo4j_user'],
                  database_credentials['neo4j_password']),
        )
        self.session_queue = Queue()

    def close(self):
        """
        Explicitly close the Neo4j driver instance
        """
        self.driver.close()

    def get_session(self):
        """
        Sessions are neo4j driver's sessions and are kept in a FIFO queue.
        A new session is created when the queue is empty,
        meaning all of them are in use by pending executions.

        The number of sessions in use at any time is the sum of:
        - the number of pending execute_one() invocations
        - the number of pending execute_sequential() invocations
        - the total number of all cypher queries in pending execute_parallel() invocations

        All invocations by all threads using this Neo4jAdapter instance.
        """
        from queue import Empty  # local import: only Queue is known to be imported at file level

        # get_nowait() rather than empty() + get(): with the latter, another
        # thread could take the last session in between and leave this caller
        # blocked forever.
        try:
            return self.session_queue.get_nowait()
        except Empty:
            return self.driver.session()

    def execute_one(self,
                    work,
                    mode="READ_ACCESS",
                    need_result=True,
                    **kw_args):
        """
        This function is to use a work unit containing a cypher query.

        It passes all keyword arguments to the work unit,
        executes it in READ_ACCESS or WRITE_ACCESS mode,
        depending on whether the transaction can make any modifications
        to any existing data inside the database.

        The function returns the result of the execution if required:
        a list of converted records, or None when no result is needed
        or the transaction returned None.

        Raises ValueError when no work unit is given.
        """
        if work is None:
            raise ValueError('No transaction function  is specified.')

        # acquire session from the queue
        session = self.get_session()
        try:
            r = self.__execute_transaction(session, work, mode, **kw_args)

            if need_result and r is not None:
                # result is always wrapped into a list if it is a single row
                return [self.__convert_result(e) for e in r.records()]
            return None
        finally:
            # return the session back to the queue, even when the work failed
            self.session_queue.put(session)

    def execute_sequential(self,
                           query_list,
                           need_result=True,
                           last_only=False):
        """
        This function is to run a list of queries SEQUENTIALLY.

        It is used when there is a number of cypher queries
        that must be executed in a given order
        and no two queries can be executed at the same time.

        Each query in the list consists of:
        - a work unit - a transaction function containing a cypher query
        - access mode
        - keyword arguments

        The function passes all the arguments to the work unit,
        executes it in READ_ACCESS or WRITE_ACCESS mode,
        depending on whether the transaction can make any modifications
        to any existing data inside the database.

        The function returns the result of the execution if required:
        one entry per query, or only the last entry when last_only is set.

        Raises ValueError when the query list is None or empty.
        """
        if query_list is None or len(query_list) == 0:
            raise ValueError('Empty query list.')

        # acquire session
        session = self.get_session()

        # Run transactions sequentially, one-by-one,
        # collect results into a list if required
        r_list = []
        try:
            for work_unit, mode, kwargs in query_list:
                r = self.__execute_transaction(session, work_unit, mode,
                                               **kwargs)
                if need_result:
                    r_list.append(
                        [self.__convert_result(e)
                         for e in r.records()] if r is not None else None)
        finally:
            # release the session, even when one of the queries failed
            self.session_queue.put(session)

        # No need for result
        if not need_result:
            return None

        # Only the last request is required.
        if last_only:
            return r_list[-1] if r_list else None
        return r_list

    def execute_parallel(self, query_list, need_result=True):
        """
        This function is to run a list of queries in PARALLEL.

        It is used when there is a number of cypher queries
        that can be executed simultaneously without any problems.

        Each query in the list consists of:
        - a work unit - a transaction function containing a cypher query
        - access mode
        - keyword arguments

        The function passes all the arguments to the work unit,
        executes it in READ_ACCESS or WRITE_ACCESS mode,
        depending on whether the transaction can make any modifications
        to any existing data inside the database.

        The queries are executed in separate threads, each with its own
        session; results are collected if needed from a queue and returned
        in the order the queries were given. A query whose thread raised
        contributes no entry to the returned list.

        Raises ValueError when the query list is None or empty.
        """
        if query_list is None or len(query_list) == 0:
            raise ValueError('Empty query list.')

        # create queue to collect result if needed
        result_queue = Queue() if need_result else None

        # acquire a session and create a thread for every query;
        # the position in the list is passed as the execution thread id
        sessions, executors = [], []
        for exec_id, query in enumerate(query_list):
            session = self.get_session()
            sessions.append(session)
            executors.append(
                Thread(
                    target=self.__execute_thread,
                    args=(exec_id, session, query, result_queue),
                ))

        for executor in executors:
            executor.start()

        for executor in executors:
            executor.join()

        # every thread has finished: hand the sessions back for reuse,
        # as execute_one() and execute_sequential() do
        for session in sessions:
            self.session_queue.put(session)

        # return if no need for results
        if not need_result:
            return None

        # collect all results from the queue,
        # each has an exec_id to indicate from which thread
        r_dict = dict()
        while not result_queue.empty():
            exec_id, result = result_queue.get()
            r_dict[exec_id] = result

        # sort the results to return them in order queries were given
        return [r_dict[exec_id] for exec_id in sorted(r_dict)]

    def __execute_thread(self, exec_id, session, query, queue=None):
        """
        Perform a __execute_transaction by dissecting the query into
        (work unit, mode, keyword arguments), and put the converted result
        together with the execution id into the queue for results
        """
        if query is None:
            raise ValueError('Cannot execute empty query.')

        w_unit, mode, kwargs = query
        r = self.__execute_transaction(session, w_unit, mode, **kwargs)
        if queue is not None:
            queue.put([
                exec_id, [self.__convert_result(e)
                          for e in r.records()] if r is not None else None
            ])

    @staticmethod
    def __execute_transaction(session, work_unit, mode, **kw_args):
        """
        Perform a read_transaction (mode "READ_ACCESS") or a
        write_transaction (any other mode) inside the session,
        passing the keyword arguments on
        """
        if mode == "READ_ACCESS":
            return session.read_transaction(work_unit, **kw_args)
        else:
            return session.write_transaction(work_unit, **kw_args)

    def __convert_result(self, result):
        """
        Convert [nested] statement results by identifying:
        - the nested structure
        - the data type
        https://neo4j.com/docs/developer-manual/current/drivers/cypher-values/

        Values of any other type are returned unchanged.
        """
        if isinstance(result, Record):
            return self.__convert_dict(result)
        if isinstance(result, Node):
            return self.__convert_node(result)
        if isinstance(result, Relationship):
            return self.__convert_relationship(result)
        if isinstance(result, Path):
            return self.__convert_path(result)
        if isinstance(result, list):
            return self.__convert_list(result)
        if isinstance(result, dict):
            return self.__convert_dict(result)
        return result

    @staticmethod
    def __convert_node(result):
        """ Convert a Neo4j node into a python dictionary """
        return {
            'id': result.id,
            'labels': set(result.labels),
            'properties': {i[0]: i[1]
                           for i in result.items()}
        }

    @staticmethod
    def __convert_relationship(result):
        """ Convert a Neo4j relationship into a python dictionary """
        return {
            'id': result.id,
            'type': result.type,
            'properties': {i[0]: i[1]
                           for i in result.items()}
        }

    def __convert_path(self, result):
        """
        Convert a Neo4j path, a sequence of nodes and relationships
        in Bolt statement result format, into a python dictionary
        """
        return {
            'start_node':
            self.__convert_node(result.start_node),
            'end_node':
            self.__convert_node(result.end_node),
            'nodes': [self.__convert_node(n) for n in result.nodes],
            'relationships':
            [self.__convert_relationship(r) for r in result.relationships],
        }

    def __convert_list(self, result):
        """ Convert Bolt list """
        return [self.__convert_result(e) for e in result]

    def __convert_dict(self, result):
        """ Convert Bolt dict/map """
        return {k: self.__convert_result(v) for k, v in result.items()}
Example #53
0
class CmdThread(Thread):
  """Daemon thread that runs queued callables one after the other.

  The thread starts itself on construction. Its status is one of
  'NORMAL' (running tasks), 'SLEEPING' (waiting on the condition, either
  because the queue ran empty or because pause() was called) and 'DONE'
  (closed, or a task raised). Status changes are made while holding the
  condition's lock.
  """
  def __init__(self):
    Thread.__init__(self)
    self._queue = Queue()
    self._status = 'NORMAL'
    self._statusLock = Lock()
    self._possibleStatus = ('NORMAL', 'SLEEPING', 'DONE')
    # The condition shares _statusLock, so holding either one guards _status.
    self._condition = Condition(self._statusLock)
    self.daemon = True
    self.start()

  def run(self):
    """Execute queued tasks until the status becomes 'DONE'.

    If a task raises, the status is set to 'DONE' and the exception is
    re-raised, which ends the thread.
    """
    while True:
      with self._condition:
        if self._status == 'DONE':
          break
        if self._status == 'SLEEPING':
          self._condition.wait() # this call releases the lock while sleeping
          continue

      try:
        task = self._queue.get_nowait()
      except Empty:
        with self._condition:
          # Re-check under the lock: between the failed get_nowait() and this
          # point addCmd() may have queued a task, or close()/pause() may have
          # changed the status. Going to sleep unconditionally would lose that
          # wakeup or overwrite 'DONE' and make close() hang in join().
          if self._status == 'NORMAL' and self._queue.empty():
            self._status = 'SLEEPING'
            # The status is not reset to 'NORMAL' here; that is the job of
            # the waker, before this command thread wakes up.
            self._condition.wait()
        continue

      if callable(task):
        try:
          print('perform task.')
          task()
        except Exception:
          with self._statusLock:
            self._status = 'DONE'
          raise
      else:
        # you can define the task interface here.
        pass

  def addCmd(self, callableObj, *args, **argd):
    """ non-blocking call.

    Queue ``callableObj(*args, **argd)`` for execution, set the status to
    'NORMAL' and wake the thread. Does nothing once the thread is 'DONE'.
    """
    assert(callable(callableObj))

    with self._condition:
      if self._status == 'DONE':
        return

      self._queue.put(partial(callableObj, *args, **argd))
      self._status = 'NORMAL'
      self._condition.notify_all()

  def close(self, timeout=None):
    """ blocking call, return after close/cancel is done.

    Marks the thread 'DONE', removes the tasks that have not been started
    and joins the thread for at most ``timeout`` seconds. Returns the list
    of removed tasks, or None when the thread was already 'DONE'.
    """
    undoneJobs = []
    with self._condition:
      if self._status == 'DONE':
        return

      self._status = 'DONE'
      self._condition.notify_all()
      while not self._queue.empty():
        undoneJobs.append(self._queue.get())
    # Join outside the lock so the worker can observe 'DONE' and exit.
    self.join(timeout)
    return undoneJobs

  def pause(self, timeout=None):
    """ non-blocking call.

    Set the status to 'SLEEPING' so that no further task is started.
    ``timeout`` is accepted but not used. Does nothing once 'DONE'.
    """
    with self._condition:
      if self._status == 'DONE':
        return
      self._status = 'SLEEPING'
      self._condition.notify_all()

  def resume(self):
    """ non-blocking call.

    Set the status back to 'NORMAL' and wake the thread. Does nothing once
    'DONE'.
    """
    with self._condition:
      if self._status == 'DONE':
        return
      self._status = 'NORMAL'
      self._condition.notify_all()

  def hasCmd(self):
    """ non-blocking call.

    Return True when at least one task is still waiting in the queue.
    """
    return not self._queue.empty()
Example #54
0
class ReplaceRobot(SingleSiteBot):
    """A bot that can do text replacements.

    @param generator: generator that yields Page objects
    @type generator: generator
    @param replacements: a list of Replacement instances or sequences of
        length 2 with the original text (as a compiled regular expression)
        and replacement text (as a string).
    @type replacements: list
    @param exceptions: a dictionary which defines when not to change an
        occurrence. This dictionary can have these keys:

        title
            A list of regular expressions. All pages with titles that
            are matched by one of these regular expressions are skipped.
        text-contains
            A list of regular expressions. All pages with text that
            contains a part which is matched by one of these regular
            expressions are skipped.
        inside
            A list of regular expressions. All occurrences are skipped which
            lie within a text region which is matched by one of these
            regular expressions.
        inside-tags
            A list of strings. These strings must be keys from the
            dictionary in textlib._create_default_regexes() or must be
            accepted by textlib._get_regexes().

    @type exceptions: dict
    @param allowoverlap: when matches overlap, all of them are replaced.
    @type allowoverlap: bool
    @param recursive: Recurse replacement as long as possible.
    @type recursive: bool
    @warning: Be careful, this might lead to an infinite loop.
    @param addcat: category to be added to every page touched
    @type addcat: pywikibot.Category or str or None
    @param sleep: slow down between processing multiple regexes
    @type sleep: int
    @param summary: Set the summary message text bypassing the default
    @type summary: str
    @keyword always: the user won't be prompted before changes are made
    @type keyword: bool
    @keyword site: Site the bot is working on.
    @warning: site parameter should be passed to constructor.
        Otherwise the bot takes the current site and warns the operator
        about the missing site
    """
    @deprecated_args(acceptall='always', addedCat='addcat')
    def __init__(self, generator, replacements, exceptions=None, **kwargs):
        """Initializer.

        @param generator: generator that yields Page objects
        @param replacements: list of Replacement instances or two-element
            sequences; two-element sequences are converted in place to
            Replacement instances
        @type replacements: list
        @param exceptions: dictionary describing when not to replace;
            C{None} (the default) means no exceptions
        @type exceptions: dict or None
        @raise ValueError: a sequence in C{replacements} does not have
            exactly two elements
        """
        self.availableOptions.update({
            'addcat': None,
            'allowoverlap': False,
            'recursive': False,
            'sleep': 0.0,
            'summary': None,
        })
        super(ReplaceRobot, self).__init__(generator=generator, **kwargs)

        for i, replacement in enumerate(replacements):
            if isinstance(replacement, Sequence):
                if len(replacement) != 2:
                    raise ValueError('Replacement number {0} does not have '
                                     'exactly two elements: {1}'.format(
                                         i, replacement))
                # Replacement assumes it gets strings but it's already compiled
                replacements[i] = Replacement.from_compiled(
                    replacement[0], replacement[1])
        self.replacements = replacements
        # Use a fresh dict per instance: a ``{}`` default in the signature is
        # created once and would be shared by every bot that omits it.
        self.exceptions = {} if exceptions is None else exceptions

        self.sleep = self.getOption('sleep')
        self.summary = self.getOption('summary')

        self.addcat = self.getOption('addcat')
        # A category given by name is turned into a Category object.
        if self.addcat and isinstance(self.addcat, UnicodeType):
            self.addcat = pywikibot.Category(self.site, self.addcat)

        # Filled by the save callbacks, emptied when results are reported.
        self._pending_processed_titles = Queue()

    def isTitleExcepted(self, title, exceptions=None):
        """
        Tell whether the title is ruled out by the exceptions.

        A title is excepted when it matches any C{'title'} pattern or fails
        to match any C{'require-title'} pattern. Without an explicit
        C{exceptions} argument the bot's own exceptions are used.

        @rtype: bool
        """
        rules = self.exceptions if exceptions is None else exceptions
        if 'title' in rules and any(
                pattern.search(title) for pattern in rules['title']):
            return True
        if 'require-title' in rules and not all(
                pattern.search(title) for pattern in rules['require-title']):
            return True
        return False

    def isTextExcepted(self, original_text):
        """
        Tell whether the text matches one of the C{'text-contains'} patterns.

        @rtype: bool
        """
        if 'text-contains' not in self.exceptions:
            return False
        return any(pattern.search(original_text)
                   for pattern in self.exceptions['text-contains'])

    def apply_replacements(self, original_text, applied, page=None):
        """
        Run every configured replacement over the given text.

        A replacement is skipped when the page title is excepted for it;
        for replacements belonging to a named container, the rest of that
        container is skipped as well. Every replacement that changes the
        text is added to C{applied}.

        @param original_text: text to work on
        @param applied: set that collects the replacements which changed
            the text
        @param page: page the text belongs to; omitting it is deprecated
        @return: the text after all replacements
        @rtype: str
        """
        if page is None:
            pywikibot.warn(
                'You must pass the target page as the "page" parameter to '
                'apply_replacements().',
                DeprecationWarning,
                stacklevel=2)
        text = original_text
        base_exceptions = _get_text_exceptions(self.exceptions)
        skipped_fixes = set()
        for replacement in self.replacements:
            if self.sleep:
                pywikibot.sleep(self.sleep)
            container = replacement.container
            if container and container.name in skipped_fixes:
                continue
            if page is not None and self.isTitleExcepted(
                    page.title(), replacement.exceptions):
                if container:
                    pywikibot.output(
                        'Skipping fix "{0}" on {1} because the title is on '
                        'the exceptions list.'.format(
                            container.name,
                            page.title(as_link=True)))
                    # Remember the fix so its other replacements are skipped
                    # without repeating the message.
                    skipped_fixes.add(container.name)
                else:
                    pywikibot.output(
                        'Skipping unnamed replacement ({0}) on {1} because '
                        'the title is on the exceptions list.'.format(
                            replacement.description, page.title(as_link=True)))
                continue
            previous = text
            text = textlib.replaceExcept(
                previous,
                replacement.old_regex,
                replacement.new,
                base_exceptions + replacement.get_inside_exceptions(),
                allowoverlap=self.getOption('allowoverlap'),
                site=self.site)
            if previous != text:
                applied.add(replacement)

        return text

    @deprecated('apply_replacements', since='20160816')
    def doReplacements(self, original_text, page=None):
        """Deprecated wrapper around apply_replacements.

        Returns the replaced text; the set of applied replacements is
        discarded.
        """
        if page is None:
            pywikibot.warn(
                'You must pass the target page as the "page" parameter to '
                'doReplacements().',
                DeprecationWarning,
                stacklevel=2)
        # The missing-page warning was already issued above, so the one from
        # apply_replacements() is silenced.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            return self.apply_replacements(original_text, set(), page=page)

    def _log_changes(self, page, err):
        """Queue the page title together with a success flag for display."""
        # This is an async put callback
        # Anything that is not an exception instance counts as success.
        succeeded = not isinstance(err, Exception)
        self._pending_processed_titles.put(
            (page.title(as_link=True), succeeded))

    def _replace_async_callback(self, page, err):
        """Callback for asynchronous page edit.

        Only records the result through C{_log_changes}; an exception
        passed as C{err} is not raised here.
        """
        self._log_changes(page, err)

    def _replace_sync_callback(self, page, err):
        """Record the result of a synchronous edit, then re-raise failures."""
        self._log_changes(page, err)
        if not isinstance(err, Exception):
            return
        raise err

    def generate_summary(self, applied_replacements):
        """Build the edit summary for the replacements that were applied.

        Replacements with their own edit summary contribute it verbatim
        (sorted, without duplicates). Replacements that use the default
        summary are merged into one leading part: the bot's C{summary} if
        set, otherwise the translated 'replace-replacing' message listing
        the old/new pairs. Parts are joined with the site's semicolon
        separator.
        """
        explicit_parts = set()
        default_pairs = set()
        for replacement in applied_replacements:
            if replacement.edit_summary:
                explicit_parts.add(replacement.edit_summary)
            elif replacement.default_summary:
                default_pairs.add((replacement.old, replacement.new))

        parts = sorted(explicit_parts)
        if default_pairs:
            if self.summary:
                lead = self.summary
            else:
                comma = self.site.mediawiki_message('comma-separator')
                described = comma.join(
                    '-{0} +{1}'.format(old, new)
                    for old, new in default_pairs)
                lead = i18n.twtranslate(
                    self.site, 'replace-replacing',
                    {'description': ' ({0})'.format(described)})
            parts.insert(0, lead)

        semicolon = self.site.mediawiki_message('semicolon-separator')
        return semicolon.join(parts)

    def treat(self, page):
        """Work on each page retrieved from generator.

        Loads the page text, applies the replacements (repeatedly when the
        'recursive' option is set), optionally adds the configured category,
        shows the diff and, unless the 'always' option is set, asks the
        user what to do. With 'always' the page is saved synchronously
        after the loop; an interactive 'Yes' saves it asynchronously.

        Pages are skipped when the title or text is excepted, when the page
        does not exist, when it may not be edited, or when the replacements
        change nothing.
        """
        if self.isTitleExcepted(page.title()):
            pywikibot.output(
                'Skipping {0} because the title is on the exceptions list.'.
                format(page.title(as_link=True)))
            return

        try:
            # Load the page's text from the wiki
            original_text = page.get(get_redirect=True)
            if not page.has_permission():
                pywikibot.output("You can't edit page " +
                                 page.title(as_link=True))
                return
        except pywikibot.NoPage:
            pywikibot.output('Page {0} not found'.format(
                page.title(as_link=True)))
            return

        # applied: replacements that changed the text (used for the summary)
        # last_text: text the replacements were last run on; None forces a run
        # context: number of context lines passed to showDiff
        applied = set()
        new_text = original_text
        last_text = None
        context = 0
        # Each pass shows the current diff; 'continue' redisplays it after
        # the user changed something, 'break' ends the dialogue.
        while True:
            if self.isTextExcepted(new_text):
                pywikibot.output('Skipping {0} because it contains text '
                                 'that is on the exceptions list.'.format(
                                     page.title(as_link=True)))
                break
            # Runs once, or until the text stops changing when the
            # 'recursive' option is set.
            while new_text != last_text:
                last_text = new_text
                new_text = self.apply_replacements(last_text, applied, page)
                if not self.getOption('recursive'):
                    break
            if new_text == original_text:
                pywikibot.output('No changes were necessary in ' +
                                 page.title(as_link=True))
                break
            if self.addcat:
                # Fetch only categories in wikitext, otherwise the others
                # will be explicitly added.
                cats = textlib.getCategoryLinks(new_text, site=page.site)
                if self.addcat not in cats:
                    cats.append(self.addcat)
                    new_text = textlib.replaceCategoryLinks(new_text,
                                                            cats,
                                                            site=page.site)
            # Show the title of the page we're working on.
            # Highlight the title in purple.
            self.current_page = page
            pywikibot.showDiff(original_text, new_text, context=context)
            if self.getOption('always'):
                # No prompt; the page is saved after the loop.
                break
            choice = pywikibot.input_choice(
                'Do you want to accept these changes?',
                [('Yes', 'y'), ('No', 'n'), ('Edit original', 'e'),
                 ('edit Latest', 'l'), ('open in Browser', 'b'),
                 ('More context', 'm'), ('All', 'a')],
                default='N')
            if choice == 'm':
                # Triple the diff context (starting at 3) and show it again.
                context = context * 3 if context else 3
                continue
            if choice in ('e', 'l'):
                text_editor = editor.TextEditor()
                edit_text = original_text if choice == 'e' else new_text
                as_edited = text_editor.edit(edit_text)
                # if user didn't press Cancel
                if as_edited and as_edited != new_text:
                    new_text = as_edited
                    if choice == 'l':
                        # prevent changes from being applied again
                        last_text = new_text
                continue
            if choice == 'b':
                pywikibot.bot.open_webbrowser(page)
                # Reload the page and start over from the reloaded text.
                try:
                    original_text = page.get(get_redirect=True, force=True)
                except pywikibot.NoPage:
                    pywikibot.output('Page {0} has been deleted.'.format(
                        page.title()))
                    break
                new_text = original_text
                last_text = None
                continue
            if choice == 'a':
                # Switch to 'always' mode; this page is saved after the loop
                # and later pages are no longer prompted for.
                self.options['always'] = True
            if choice == 'y':
                self.save(page,
                          original_text,
                          new_text,
                          applied,
                          show_diff=False,
                          quiet=True,
                          callback=self._replace_async_callback,
                          asynchronous=True)
            # Report the results queued so far by the save callbacks.
            while not self._pending_processed_titles.empty():
                proc_title, res = self._pending_processed_titles.get()
                pywikibot.output('Page {0}{1} saved'.format(
                    proc_title, '' if res else ' not'))
            # choice was 'y', 'n' or 'a' (all other choices continued above),
            # so the dialogue for this page is over
            break

        if self.getOption('always') and new_text != original_text:
            self.save(page,
                      original_text,
                      new_text,
                      applied,
                      show_diff=False,
                      quiet=True,
                      callback=self._replace_sync_callback,
                      asynchronous=False)
            # In 'always' mode results are only reported once more than 50
            # have accumulated.
            if self._pending_processed_titles.qsize() > 50:
                while not self._pending_processed_titles.empty():
                    proc_title, res = self._pending_processed_titles.get()
                    pywikibot.output('Page {0}{1} saved'.format(
                        proc_title, '' if res else ' not'))

    def save(self, page, oldtext, newtext, applied, **kwargs):
        """Save the given page.

        Passes the change to C{userPut} with an edit summary generated from
        C{applied} and with C{ignore_save_related_errors=True}; any further
        keyword arguments are handed through unchanged.

        @param page: page to save
        @param oldtext: text before the replacements
        @param newtext: text to be saved
        @param applied: replacements that were applied, used for the summary
        """
        self.userPut(page,
                     oldtext,
                     newtext,
                     summary=self.generate_summary(applied),
                     ignore_save_related_errors=True,
                     **kwargs)

    def user_confirm(self, question):
        """Always return True due to our own input choice.

        @param question: ignored
        @rtype: bool
        """
        return True
Example #55
0
class InstallationManager:
    def __init__(self, backend_service, printer):
        self.backend = backend_service
        self.printer = printer

        self.fetch_q = Queue()
        self.download_q = Queue()
        self.install_q = Queue()

    def queue_fetch(self, request: FetchRequest):
        self.fetch_q.put(item=request)

    def queue_download(self, request: DownloadRequest):
        self.download_q.put(item=request)

    def queue_install(self, request: InstallRequest):
        self.install_q.put(item=request)

    def run(self, carrot: CarrotModel, args):
        while not self.fetch_q.empty():
            req = self.fetch_q.get()
            self.do_fetch(req, carrot, args)

        self.printer.handle('info all_mod_check_complete')

        self._download_hist = set()

        while not self.download_q.empty():
            req = self.download_q.get()
            self.do_download(req, carrot, args)

        self.printer.handle('info all_mod_fetch_complete')

        self._install_hist = set()

        installed_list = []
        while not self.install_q.empty():
            req = self.install_q.get()
            installed_list.append(req.mod_info.key)
            self.do_install(req, carrot, args)

        self.printer.handle('info all_mod_install_complete',
                            {'installed_list': installed_list})

    def do_fetch(self, req: FetchRequest, carrot: CarrotModel, args):
        """Resolve one fetch request and queue the resulting work.

        Looks up the mod and its newest file for the requested version and
        channel. If there is a file and the mod is either not installed yet
        or may be upgraded/downgraded according to ``args``, a download and
        an install request are queued, followed by a fetch request for each
        dependency of the file. In every other case a message is sent to
        the printer and nothing is queued.
        """
        mod_info = self.backend.get_mod_info(req.mod_key)
        if not mod_info.key:
            self.printer.handle('error mod_key_not_found',
                                Namespace(mod_key=req.mod_key))
            return

        self.printer.handle('info mod_resolved',
                            Namespace(mod_key=str(req.mod_key), mod=mod_info))

        mod_info.file = self.backend.get_newest_file_info(
            req.mod_key, req.mc_version, req.channel)
        current_mod = find_mod_by_key(carrot.mods, mod_info.key)

        if not mod_info.file:
            self.printer.handle('warn no_files_in_channel',
                                Namespace(mod=mod_info))
            return

        if current_mod:
            # The mod is already installed: compare file ids to tell
            # upgrade, same version and downgrade apart.
            installed_id = current_mod.file.id
            candidate_id = mod_info.file.id
            if installed_id < candidate_id:
                if not args.upgrade:
                    self.printer.handle(
                        'warn upgrade_not_allowed',
                        Namespace(mod=mod_info, dependency=req.dependency))
                    return
                self.printer.handle('info will_upgrade_mod',
                                    Namespace(mod=mod_info))
            elif installed_id == candidate_id:
                self.printer.handle('warn already_newest_version',
                                    Namespace(mod=mod_info))
                return
            elif not args.downgrade:
                self.printer.handle('warn downgrade_not_allowed',
                                    Namespace(mod=mod_info))
                return
            else:
                self.printer.handle('info will_downgrade_mod',
                                    Namespace(mod=mod_info))

        self.printer.handle('info will_download_mod',
                            Namespace(mod=mod_info))

        self.download_q.put(
            DownloadRequest(mod_info=mod_info, dependency=req.dependency))
        self.install_q.put(
            InstallRequest(mod_info=mod_info, dependency=req.dependency))

        dependencies = mod_info.file.mod_dependencies
        if not dependencies:
            return

        self.printer.handle('info dependencies_detected',
                            Namespace(deps=dependencies))
        # Dependencies are fetched for the same version and channel.
        for dep in dependencies:
            self.queue_fetch(
                FetchRequest(mod_key=dep,
                             mc_version=req.mc_version,
                             channel=req.channel,
                             dependency=True))

    def do_download(self, req: DownloadRequest, carrot: CarrotModel, args):
        if req.mod_info.file.file_name in self._download_hist:
            return

        self.printer.handle('info downloading_file',
                            Namespace(file=req.mod_info.file))

        #        file_contents = self.backend.download_file(req.mod_info.file.download_url)
        # API_ENDPOINT = 'https://cursemeta.dries007.net'
        # QString metaurl = QString("%1/%2/%3.json").arg(metabase, projectIdStr, fileIdStr);
        awful_hack = self.backend.download_file(
            'https://cursemeta.dries007.net/' + str(req.mod_info.id) + '/' +
            str(req.mod_info.file.id) + '.json')
        awful_hack_json = awful_hack.decode('utf8')
        awful_hack_data = json.loads(awful_hack_json)
        #        print(awful_hack_data)
        file_contents = self.backend.download_file(
            awful_hack_data['DownloadURL'])
        self.put_file_in_cache(file_contents, req.mod_info.file.file_name)
        self._download_hist.add(req.mod_info.file.file_name)

    def do_install(self, req: InstallRequest, carrot: CarrotModel, args):
        if req.mod_info.file.file_name in self._install_hist:
            return

        current_mod = find_mod_by_key(carrot.mods, req.mod_info.key)
        new_mod = InstalledModModel.from_dict(req.mod_info.to_dict())
        new_mod.dependency = req.dependency

        if not current_mod:
            self.printer.handle('info installing_mod',
                                Namespace(mod=req.mod_info, new_mod=True))

            carrot.mods.append(new_mod)

            self.move_file_from_cache_to_content(new_mod.file.file_name)

        else:
            self.printer.handle('info installing_mod',
                                Namespace(mod=req.mod_info, new_mod=False))

            # Prevent a user-installed mod from becoming a dependency
            if not current_mod.dependency and new_mod.dependency:
                new_mod.dependency = False

            replace_mod_by_key(carrot.mods, req.mod_info.key, new_mod)

            enabled = self.delete_file(current_mod.file.file_name)

            self.printer.handle(
                'info updating_file',
                Namespace(file=current_mod.file, enabled=enabled))

            self.move_file_from_cache_to_content(req.mod_info.file.file_name,
                                                 enabled)

        self._install_hist.add(req.mod_info.file.file_name)

        self.printer.handle('info mod_install_complete',
                            Namespace(mod=new_mod))

    def put_file_in_cache(self, content: bytes, file_name: str):
        if not os.path.exists('.carrot_cache'):
            os.mkdir('.carrot_cache')

        with open('.carrot_cache/' + file_name, 'wb+') as f:
            f.write(content)

    def delete_file(self, file_name: str):
        if os.path.exists(file_name):
            os.remove(file_name)
            return True

        elif os.path.exists(file_name + '.disabled'):
            os.remove(file_name + '.disabled')
            return False

        # If file is missing, assume it's meant to be installed and enabled
        return True

    def move_file_from_cache_to_content(self,
                                        file_name: str,
                                        enabled: bool = True):
        target_file_name = file_name
        if not enabled:
            target_file_name += '.disabled'

        os.rename('.carrot_cache/' + file_name, target_file_name)
def check_if_url_legal():
    """Analyse the video behind the requested URL and report the verdict.

    Reads ``url`` and ``domain`` from the JSON body of the current request.
    When no crawler is available for the URL, ``{"Blocked": "Wait"}`` is
    returned. Otherwise the video is downloaded, its audio is transcribed
    and checked for words in two parallel threads, the result is stored in
    the database and returned as JSON. Errors during the analysis are
    printed and the (possibly unchanged) result is still returned.

    NOTE(review): nothing is returned for OPTIONS requests -- confirm that
    such requests are answered before this function is reached.
    """
    if request.method == "OPTIONS":
        return None

    req_data = request.get_json()
    url = str(req_data["url"])
    print("url in analysis: " + url)
    domain = str(req_data["domain"])
    # one crawler for each page
    crawler = crawlerMG.get_crawler(url)
    data = {
        "URL": url,
        "Domain": domain,
        "Blocked": "False",
        "CaughtWord": ""
    }

    audioMG = AUDIO.AudioHandler(8)

    q_audio_words = Queue()
    q_stop_flag = Queue()
    q_caught_word = Queue()

    if crawler is None:
        return jsonify({"Blocked": "Wait"})

    try:
        # download video to local disk
        start = timer()
        crawler.collect_video_tags()
        crawler.download_video()
        end = timer()
        print("Time elapse for downloading: " + str(timedelta(seconds=end-start)))

        filename = crawler.get_filename()

        # new threads to handle SR transcribing and checking all captured
        # words at the same time.
        start = timer()
        transcribe_thread = threading.Thread(
            target=audioMG.trans_audio_file_batch,
            args=(filename, q_audio_words, q_stop_flag))
        check_word_thread = threading.Thread(
            target=catcher.catch_word,
            args=(q_audio_words, q_stop_flag, q_caught_word))
        transcribe_thread.start()
        check_word_thread.start()
        # wait for both threads to finish
        transcribe_thread.join()
        check_word_thread.join()
        end = timer()
        print("Time elapse for transcribing: " + str(timedelta(seconds=end-start)))

        # set url information if there are words got captured
        if not q_caught_word.empty():
            data["Blocked"] = "True"
            data["CaughtWord"] = q_caught_word.get()
            print(colored(data["CaughtWord"], "red"))

        # insert current information into database
        print("Insert into database")
        db.insert_url_details(data)

        # remove crawler
        crawlerMG.del_crawler(url)
    except Exception as e:
        # Best effort: report the problem and still answer the request.
        print(e)

    # Returning here instead of from a ``finally`` clause keeps
    # KeyboardInterrupt/SystemExit from being silently discarded.
    return jsonify(data)
Example #57
0
    def randomPickHelper(self,node,N):
        """Pick leaf entries from the subtree below ``node``, aiming for ``N``.

        Level information is generated on first use. Nodes are taken from a
        FIFO queue of ``(node, target)`` pairs, starting with ``(node, N)``.
        When a node has no more children than its target, the target is
        split as evenly as possible over the children; otherwise entries
        are drawn with ``self.pick``.

        Returns a dict with the picked entries, keyed as in the nodes' leaf
        directories.

        NOTE(review): the size of the result is not checked against ``N``
        here -- confirm whether callers rely on an exact count.
        """
        # Generate the level information once and remember that it exists.
        if self._levelFlag == False:
            self.generateLevel()
            self._levelFlag = True

        pickedList = {}
        q = Queue()
        q.put((node,N))
        while not q.empty():
            # Handle the entries that are queued right now as one batch.
            size = q.qsize()

            for i in range(size):
                item,target = q.get()

                # print(item.getName())
                childNum = item.childrenSize()
                if childNum == 0:
                    continue

                # count: children not handled yet
                # left: part of the target not yet accounted for
                count = childNum
                left = target
                #assign target to each child
                if childNum <= target:
                    # Even split; the first `remainder` children get one more.
                    remainder = target % childNum
                    unit = target // childNum

                    for childName, childNode in item.getChildren().items():
                        count -= 1

                        # tmp is the quota of this child.
                        if remainder > 0:
                            tmp = unit + 1
                            remainder -= 1
                        else:
                            tmp = unit
                        #stop condition
                        # Children on level `_levels - 2` are never queued;
                        # their leaves are taken or picked right here.
                        if childNode.getLevel() == self._levels -2:
                            if childNode.leafSize() == tmp:
                                # Exactly as many leaves as the quota: take
                                # every leaf that is not picked yet.
                                for k, v in childNode.getLeafDir().items():
                                    if k in pickedList.keys():
                                        continue
                                    pickedList[k] = v
                                    left -= 1
                            elif childNode.leafSize() < tmp:
                                # Fewer leaves than the quota: take them all.
                                left -= childNode.leafSize()

                                # re-assign what is left to the children
                                # that are still to come
                                if count > 0:
                                    remainder = left % count
                                    unit = left // count

                                for k, v in childNode.getLeafDir().items():
                                    if k in pickedList.keys():
                                        # Already picked: give the slot back.
                                        # This happens after the re-assignment
                                        # above, so it only affects a later one.
                                        left += 1
                                        continue
                                    pickedList[k] = v
                            else:

                                # More leaves than the quota: call pick once
                                # per quota slot.
                                left -= tmp
                                # NOTE: this rebinds the outer loop variable
                                # `i`, which is not read afterwards.
                                for i in range(tmp):
                                    self.pick(pickedList,childNode.getLeafDir())
                            continue

                        # Children on other levels: same handling of the
                        # "equal" and "fewer" cases as above.
                        if childNode.leafSize() == tmp:

                            for k,v in childNode.getLeafDir().items():
                                if k in pickedList.keys():
                                    continue
                                pickedList[k] = v
                                left -= 1

                        elif childNode.leafSize() < tmp:
                            left -= childNode.leafSize()

                            # re-assign what is left to the children that
                            # are still to come

                            if count > 0:
                                remainder = left % count
                                unit = left // count

                            for k,v in childNode.getLeafDir().items():
                                if k in pickedList.keys():
                                    left +=1
                                    continue
                                pickedList[k] = v
                        else:
                            # More leaves than the quota: descend later with
                            # this child's quota as its target.
                            q.put((childNode,tmp))

                #randomly pick target
                # (fewer targets than children, so no per-child split)
                else:

                    if item.getLevel() == self._levels - 2:
                        # Since target < childNum here, this calls pick
                        # `target` times on the item's own leaf directory.
                        for i in range(min(target,childNum)):
                            self.pick(pickedList,item.getLeafDir())

                    elif item.getLevel() == self._levels -1 :
                        continue
                    else:
                        # Each pick goes to a randomly chosen child; the same
                        # child may be chosen more than once.
                        for i in range(target):
                            tmpNode = item.getChildren().get(random.choice(list(item.getChildren())))
                            self.pick(pickedList,tmpNode.getLeafDir())
        return pickedList
Example #58
0
class TTS(object):
    def __init__(self):
        self.clients = []
        self.voice_choices = []
        self.queue = Queue()
        if 'win32com' not in globals():
            return
        Thread(target=self._background).start()

    def _background(self):
        """Thread body: set up the SAPI voice, then speak queued texts forever."""
        pythoncom.CoInitialize()
        self.tts = win32com.client.Dispatch("SAPI.SpVoice")

        voice_tokens = self.tts.GetVoices()
        self.voices = [voice_tokens.Item(index)
                       for index in range(voice_tokens.Count)]
        # Description/id pairs as returned by get_voice_choices(); the id is
        # the index into self.voices.
        self.voice_choices = [
            {'desc': voice.GetDescription(), 'id': index}
            for index, voice in enumerate(self.voices)
        ]

        self.tts.Rate = -5
        self.event_sink = win32com.client.WithEvents(self.tts, TTSEventSink)
        self.event_sink.setTTS(self)

        while True:
            # Blocks until a text is available.
            text = self.queue.get(True)
            self._speak(text)

    def _speak(self, text):
        """Speak ``text`` and return once speaking has ended.

        ``_speaking`` is set here and cleared by ``handle_event('end')``;
        ``_pump`` loops until then.
        """
        self._speaking = True
        # presumably discards whatever is still being spoken -- TODO confirm
        self.tts.Skip("Sentence", INT32_MAX)
        self.tts.Speak(text, SVSFlagsAsync)
        self._pump()

    def speak(self, text):
        while True:
            try:
                self.queue.get(False)
            except Empty:
                break
        self.queue.put(text)

    def get_voice_choices(self):
        """Return the list of voice choices.

        The list is empty until the background thread has filled it with
        ``desc``/``id`` dicts.
        """
        return self.voice_choices

    def set_voice(self, voice_id):
        """Select the voice at index ``voice_id`` of ``self.voices``.

        ``self.tts`` and ``self.voices`` are created by the background
        thread, so they must exist before this is called.
        """
        self.tts.Voice = self.voices[voice_id]

    def handle_event(self, event, *args):
        msg = dict(type=event)
        if event == 'end':
            self._speaking = False
        elif event == 'word':
            msg.update(dict(char_pos=args[0], length=args[1]))
        msg = json.dumps(msg)
        for c in self.clients:
            c.write_message(msg)

    def _pump(self):
        skipped = False
        while self._speaking:
            if not skipped and not self.queue.empty():
                self.tts.Skip("Sentence", INT32_MAX)
                skipped = True
            pythoncom.PumpWaitingMessages()
            time.sleep(0.05)
Example #59
0
                sleep(0)
            else:
                handler.join(1)
                command_handlers.remove(handler)

        logging.info('Stopping and Releasing OT-2 Commander')
        ot2.disconnect.set()
        sleep(0)
        if ot2.is_alive():
            command_queue.join()
            ot2.stop.set()
            sleep(0)
        ot2.join(2)

        logging.info('Resolving Outgoing Missives')
        while not response_queue.empty():
            response = response_queue.get()
            address = response['address']
            for handler in command_handlers:
                if address == handler.address:
                    handler.response = response['missive']
                    handler.send.set()
            response_queue.task_done()
        response_queue.join()

        logging.info('Finalizing Command Handlers')
        for handler in command_handlers:
            handler.stop.set()
            handler.join(2)
            command_handlers.remove(handler)
Example #60
0
class Fish():
    """This class models each fish robot node in the collective from the fish's
    perspective.

    Each fish has an ID, communicates over the channel, and perceives its
    neighbors and takes actions accordingly. In taking actions, the fish can
    weight information from neighbors based on their distance. Different collective behaviors run different methods of this class. It can perceive and move according to its perceptual and dynamics model, and updates its behavior on every clock tick.

    Attributes:
        behavior (str): Behavior that fish follows
        body_length (int): Length of a BlueBot (130mm)
        caudal (int): Caudal fin control
        channel (Class): Communication channel
        clock (int): Local clock time
        clock_freq (float): Clock speed (Hz)
        clock_speed (float): Clock speed (s)
        d_center (int): Relative distance to center of perceived neighbors
        dorsal (int): Dorsal fin control
        dynamics (Class): Fish dynamics model
        fish_max_speed (int): Maximum forward speed of fish (old simulator)
        hop_count (int): Hop count variable
        hop_count_initiator (bool): Hop count started or not
        hop_distance (int): Hop distance to other fish
        id (int): ID number of fish
        info (str): Some information
        info_clock (int): Time stamp of the information, i.e., the clock
        info_hops (int): Number of hops until the information arrived
        initial_hop_count_clock (int): Hop count start time
        interaction (Class): Class for interactions between fish
        is_started (bool): True/false
        last_hop_count_clock (TYPE): Time since last hop count
        last_leader_election_clock (int): Time since last leader election
        leader_election_max_id (int): Highest ID
        lim_neighbors (int): Max. and min. desired number of neighbors
        messages (list): Messages between fish
        name (string): Name for logger file
        neighbor_weight (float): Gain that influences decision making
        neighbors (set): Set of observed neighboring fish
        pect_l (int): Pectoral left fin control
        pect_r (int): Pectoral right fin control
        queue (TYPE): Message queue for messages from neighbors
        saw_hop_count (bool): True/false
        status (str): Behavioral status
        target_depth (int): Target depth for diving
        target_dist (int): Target distance for orbiting
        target_pos (int): Target position instructed by observer
        verbose (bool): Print statements on/off
    """

    def __init__(
        self,
        id,
        channel,
        interaction,
        dynamics,
        w_blindspot=50,
        r_blocking=65,
        target_dist=390,
        lim_neighbors=[0, math.inf],
        fish_max_speed=1,
        clock_freq=1,
        neighbor_weight=1.0,
        name='Unnamed',
        verbose=False
    ):
        """Create a new fish

        Arguments:
            id (TYPE): UUID of fish
            channel (Class): Communication channel
            interaction (Class): Class for interactions between fish
            dynamics (Class): Fish dynamics model
            target_dist (int, optional): target_distance to neighbors
            lim_neighbors (int, int): Lower and upper limit of neighbors each
                fish aims to be connected to.
                (default: {0, math.inf})
            fish_max_speed (float): Max speed of each fish. Defines by how
                much it can change its position in one simulation step.
                (default: {1})
            clock_freq (number): Behavior update rate in Hertz (default: {1})
            neighbor_weight (number): A weight based on distance that defines
                how much each of a fish's neighbor affects its next move.
                (default: {1.0})
            name (str): Unique name of the fish. (default: {'Unnamed'})
            verbose (bool): If `true` log out some stuff (default: {False})
        """

        self.id = id
        self.channel = channel
        self.interaction = interaction
        self.dynamics = dynamics
        self.w_blindspot = w_blindspot
        self.r_blocking = r_blocking
        self.target_dist = target_dist
        self.neighbor_weight = neighbor_weight
        self.lim_neighbors = lim_neighbors
        self.fish_max_speed = fish_max_speed
        self.clock_freq = clock_freq
        self.name = name
        self.verbose = verbose

        self.caudal = 0
        self.dorsal = 0
        self.pect_r = 0
        self.pect_l = 0
        self.target_depth = 0

        self.d_center = 0
        self.body_length = 130
        self.clock_speed = 1 / self.clock_freq
        self.clock = 0
        self.queue = Queue()
        self.target_pos = np.zeros((3,))
        self.is_started = False
        self.neighbors = set()

        self.status = None
        self.behavior = 'home'

        self.info = None  # Some information
        self.info_clock = 0  # Time stamp of the information, i.e., the clock
        self.info_hops = 0  # Number of hops until the information arrived
        self.last_hop_count_clock = -math.inf
        self.hop_count = 0
        self.hop_distance = 0
        self.hop_count_initiator = False
        self.initial_hop_count_clock = 0

        self.leader_election_max_id = -1
        self.last_leader_election_clock = -1

        now = datetime.datetime.now()

        # Stores messages to be sent out at the end of the clock cycle
        self.messages = []

        # Logger instance
        # with open('{}_{}.log'.format(self.name, self.id), 'w') as f:
        #     f.truncate()
        #     f.write('TIME  ::  #NEIGHBORS  ::  INFO  ::  ({})\n'.format(
        #         datetime.datetime.now())
        #     )

    def start(self):
        """Start the process

        This sets `is_started` to true and invokes `run()`.
        """
        self.is_started = True
        self.run()

    def stop(self):
        """Stop the process

        This sets `is_started` to false.
        """
        self.is_started = False

    def log(self, neighbors=set()):
        """Log current state
        """

        with open('{}_{}.log'.format(self.name, self.id), 'a+') as f:
            f.write(
                '{:05}    {:04}    {}    {}\n'.format(
                    self.clock,
                    len(neighbors),
                    self.info,
                    self.info_hops
                )
            )

    def run(self):
        """Run the process recursively

        This method simulates the fish and calls `eval` on every clock tick as
        long as the fish `is_started`.
        """

        while self.is_started:

            start_time = time.time()
            self.eval()
            time_elapsed = time.time() - start_time

            sleep_time = (self.clock_speed / 2) - time_elapsed

            # print(time_elapsed, sleep_time, self.clock_speed / 2)
            time.sleep(max(0, sleep_time))
            if sleep_time < 0 and self.verbose:
                print('Warning frequency too high or computer too slow')

            start_time = time.time()
            self.communicate()
            time_elapsed = time.time() - start_time

            sleep_time = (self.clock_speed / 2) - time_elapsed
            time.sleep(max(0, sleep_time))
            if sleep_time < 0 and self.verbose:
                print('Warning frequency too high or computer too slow')


    def move_handler(self, event):
        """Handle move events, i.e., update the target position.

        Arguments:
            event (Move): Event holding an x, y, and z target position
        """
        self.target_pos[0] = event.x
        self.target_pos[1] = event.y
        self.target_pos[2] = event.z

    def ping_handler(self, neighbors, rel_pos, event):
        """Handle ping events

        Arguments:
            neighbors {set} -- Set of active neighbors, i.e., nodes from which
                this fish received a ping event.
            rel_pos {dict} -- Dictionary of relative positions from this fish
                to the source of the ping event.
            event {Ping} -- The ping event instance
        """
        neighbors.add(event.source_id)

        # When the other fish is not perceived its relative position is [0,0]
        rel_pos[event.source_id] = self.interaction.perceive_pos(
            self.id, event.source_id
        )

        if self.verbose:
            print('Fish #{}: saw friend #{} at {}'.format(
                self.id, event.source_id, rel_pos[event.source_id]
            ))

    def homing_handler(self, event, pos):
        """Handle a homing event.

        Sets the info to 'signal_aircraft', queues an internal info message
        for broadcast, switches the status to 'wait' and replaces the target
        position with what the interaction perceives for `pos`.

        Arguments:
            event {Homing} -- Homing event
            pos {np.array} -- Position of the homing event initiator
        """
        self.info = 'signal_aircraft'  # Very bad practice. Needs to be fixed!
        self.info_clock = self.clock

        announcement = InfoInternal(self.id, self.clock, self.info)
        self.messages.append((self, announcement))

        # update behavior based on external event
        self.status = 'wait'
        perceived_target = self.interaction.perceive_object(self.id, pos)
        self.target_pos = perceived_target

        if not self.verbose:
            return
        print('Fish #{} got external info {}'.format(
            self.id, event.message
        ))

    def info_ext_handler(self, event):
        """Handle information coming from outside the collective.

        External information is always accepted: it becomes the current info,
        is stamped with the local clock and is queued for broadcast as an
        internal info message.

        Arguments:
            event {InfoExternal} -- InfoExternal event
        """
        message = event.message
        self.info = message
        self.info_clock = self.clock

        relay = InfoInternal(self.id, self.clock, self.info)
        self.messages.append((self, relay))

        if not self.verbose:
            return
        print('Fish #{} got external info {}'.format(
            self.id, event.message
        ))

    def info_int_handler(self, event):
        """Internal information event handler.

        Only accept the information of the clock is higher than from the last
        information

        Arguments:
            event {InfoInternal} -- Internal information event instance
        """
        if self.info_clock >= event.clock:
            return

        self.info = event.message
        self.info_clock = event.clock
        self.info_hops = event.hops + 1

        self.messages.append((
            self,
            InfoInternal(self.id, self.info_clock, self.info, self.info_hops)
        ))

        if self.verbose:
            print('Fish #{} got info: {} from #{}'.format(
                self.id, event.message, event.source_id
            ))

    def hop_count_handler(self, event):
        """Hop count handler

        Initialize only of the last hop count event is 4 clocks old. Otherwise
        update the hop count and resend the new value only if its larger than
        the previous hop count value.

        Arguments:
            event {HopCount} -- Hop count event instance
        """
        # initialize
        if (self.clock - self.last_hop_count_clock) > 4:
            self.hop_count_initiator = False
            self.hop_distance = event.hops + 1
            self.hop_count = event.hops + 1
            self.messages.append((
                self,
                HopCount(self.id, self.info_clock, self.hop_count)
            ))

        else:
            # propagate value
            if self.hop_count < event.hops:
                self.hop_count = event.hops

                if not self.hop_count_initiator:
                    self.messages.append((
                        self,
                        HopCount(self.id, self.info_clock, self.hop_count)
                    ))

        self.last_hop_count_clock = self.clock

        if self.verbose:
            print('Fish #{} counts hops {} from #{}'.format(
                self.id, event.hop_count, event.source_id
            ))

    def start_hop_count_handler(self, event):
        """Begin a new hop count with this fish as the initiator.

        A start event is always accepted: the hop state is reset to zero,
        the fish marks itself as initiator and a hop count message is queued
        for broadcast.

        Arguments:
            event {StartHopCount} -- Hop count start event
        """
        now = self.clock
        self.last_hop_count_clock = now
        self.hop_distance = 0
        self.hop_count = 0
        self.hop_count_initiator = True
        self.initial_hop_count_clock = now

        first_count = HopCount(self.id, self.info_clock, self.hop_count)
        self.messages.append((self, first_count))

        if not self.verbose:
            return
        # NOTE(review): reads `hop_count` and `source_id` from the start
        # event - confirm that the start event provides both attributes.
        print('Fish #{} counts hops {} from #{}'.format(
            self.id, event.hop_count, event.source_id
        ))

    def leader_election_handler(self, event):
        """Leader election handler

        Arguments:
            event {LeaderElection} -- Leader election event instance
        """
        # This need to be adjusted in the future
        if (self.clock - self.last_leader_election_clock) < math.inf:
            new_max_id = max(event.max_id, self.id)
            # propagate value
            if self.leader_election_max_id < new_max_id:
                self.leader_election_max_id = new_max_id

                self.messages.append((
                    self,
                    LeaderElection(self.id, new_max_id)
                ))

        self.last_leader_election_clock = self.clock

    def weight_neighbor(self, rel_pos_to_neighbor): #xx obsolete with lj-pot?
        """Weight neighbors by the relative position to them

        Currently only returns a static value but this could be tweaked in the
        future to calculate a weighted center point.

        Arguments:
            rel_pos_to_neighbor {np.array} -- Relative position to a neighbor

        Returns:
            float -- Weight for this neighbor
        """
        return self.neighbor_weight

    def start_leader_election_handler(self, event):
        """Begin a leader election with this fish as first candidate.

        A start event is always accepted: the election clock is refreshed,
        the own ID becomes the highest known ID and a leader election message
        carrying it is queued for broadcast.

        Arguments:
            event {StartLeaderElection} -- Leader election start event
        """
        own_id = self.id
        self.last_leader_election_clock = self.clock
        self.leader_election_max_id = own_id

        nomination = LeaderElection(self.id, self.id)
        self.messages.append((self, nomination))

    def comp_center(self, rel_pos):
        """Compute the (potentially weighted) centroid of the fish neighbors

        Arguments:
            rel_pos {dict} -- Dictionary of relative positions to the
                neighboring fish.

        Returns:
            np.array -- 3D centroid
        """
        center = np.zeros((3,))
        n = max(1, len(rel_pos))

        for key, value in rel_pos.items():
            weight = self.weight_neighbor(value)
            center += value * weight

        center /= n

        if self.verbose:
            print('Fish #{}: swarm centroid {}'.format(self.id, center))

        return center

    def lj_force(self, neighbors, rel_pos):
        """lj_force derives the Lennard-Jones potential and force based on the relative positions of all neighbors and the desired self.target_dist to neighbors. The force is a gain factor, attracting or repelling a fish from a neighbor. The center is a point in space toward which the fish will move, based on the sum of all weighted neighbor positions.

        Args:
            neighbors (set): Visible neighbors
            rel_pos (dict): Relative positions of visible neighbors

        Returns:
            np.array: Weighted 3D direction based on visible neighbors
        """
        if not neighbors:
            return np.zeros((3,))

        a = 12 # 12
        b = 6 # 6
        epsilon = 100 # depth of potential well, V_LJ(r_target) = epsilon
        gamma = 100 # force gain
        r_target = self.target_dist
        r_const = r_target + 1 * self.body_length

        center = np.zeros((3,))
        n = len(neighbors)

        for neighbor in neighbors:
            r = np.clip(np.linalg.norm(rel_pos[neighbor]), 0.001, r_const)
            f_lj = -gamma * epsilon /r * (a * (r_target / r)**a - 2 * b * (r_target / r)**b)
            center += f_lj * rel_pos[neighbor]

        center /= n

        return center

    def depth_ctrl(self, r_move_g):
        """Controls diving depth based on direction of desired move.

        Args:
            r_move_g (np.array): Relative position of desired goal location in robot frame.
        """
        pitch = np.arctan2(r_move_g[2], math.sqrt(r_move_g[0]**2 + r_move_g[1]**2)) * 180 / math.pi

        if pitch > 1:
            self.dorsal = 1
        elif pitch < -1:
            self.dorsal = 0

    def depth_waltz(self, r_move_g):
        """Controls diving depth in a pressure sensor fashion. Own depth is "measured", i.e. reveiled by the interaction. Depth control is then done based on a target depth coming from a desired goal location in the robot frame.

        Args:
            r_move_g (np.array): Relative position of desired goal location in robot frame.
        """
        depth = self.interaction.perceive_depth(self.id)

        if self.target_depth == 0:
            self.target_depth = depth + r_move_g[2] / 2

        if depth > self.target_depth:
            self.dorsal = 0
        else:
            self.dorsal = 1

    def home(self, r_move_g):
        """Homing behavior. Sets fin controls to move toward a desired goal location.

        Args:
            r_move_g (np.array): Relative position of desired goal location in robot frame.
        """
        caudal_range = 20 # abs(heading) below which caudal fin is switched on

        heading = np.arctan2(r_move_g[1], r_move_g[0]) * 180 / math.pi

        # target to the right
        if heading > 0:
            self.pect_l = min(1, 0.6 + abs(heading) / 180)
            self.pect_r = 0

            if heading < caudal_range:
                self.caudal = min(1, 0.1 + np.linalg.norm(r_move_g[0:2])/(8*self.body_length))
            else:
                self.caudal = 0

        # target to the left
        else:
            self.pect_r = min(1, 0.6 + abs(heading) / 180)
            self.pect_l = 0

            if heading > -caudal_range:
                self.caudal = min(1, 0.1 + np.linalg.norm(r_move_g[0:2])/(8*self.body_length))
            else:
                self.caudal = 0

    def collisions(self, r_move_g):
        """Local collision avoidance where r_move_g comes from a local Lennard-Jones potential.

        Args:
            r_move_g (np.array): Relative position of desired goal location in robot frame.
        """
        caudal_range = 20 # abs(heading) below which caudal fin is switched on

        heading = np.arctan2(r_move_g[1], r_move_g[0]) * 180 / math.pi

        # target to the right
        if heading > 0:
            self.pect_l = min(1, 0.6 + abs(heading) / 180)

            if heading < caudal_range:
                self.caudal = min(self.caudal+0.5, self.caudal+0.2 + np.linalg.norm(r_move_g[0:2])/(8*self.body_length))

        # target to the left
        else:
            self.pect_r += min(1, 0.6 + abs(heading) / 180)

            if heading > -caudal_range:
                self.caudal = min(self.caudal+0.5, self.caudal+0.2 + np.linalg.norm(r_move_g[0:2])/(8*self.body_length))

    def transition(self, r_move_g):
        """Transitions between homing and orbiting. Uses pectoral right fin to align tangentially with the orbit.

        Args:
            r_move_g (np.array): Relative position of desired goal location in robot frame.
        """
        self.caudal = 0
        self.pect_l = 0
        self.pect_r = 1

        heading = np.arctan2(r_move_g[1], r_move_g[0]) * 180 / math.pi

        if heading > 35:
            self.pect_r = 0
            self.behavior = 'orbit'

    def orbit(self, r_move_g, target_dist):
        """Orbits an object, e.g. two vertically stacked LEDs, at a predefined radius

        Uses four zones to control the orbit with pectoral and caudal fins. The problem is reduced to 2D and depth control is handled separately.
        Could make fin frequencies dependent on distance and heading, i.e., use proportianl control.

        Args:
            r_move_g (np.array): Relative position of desired goal location in robot frame.
            target_dist (int): Target orbiting radius, [mm]
        """
        dist = np.linalg.norm(r_move_g[0:2]) # 2D, ignoring z
        heading = np.arctan2(r_move_g[1], r_move_g[0]) * 180 / math.pi

        if dist > target_dist:
            if heading < 90:
                self.caudal = 0.45
                self.pect_l = 0
                self.pect_r = 0
            else:
                self.caudal = 0.3
                self.pect_l = 1
                self.pect_r = 0
        else:
            if heading < 90:
                self.caudal = 0.45
                self.pect_l = 0
                self.pect_r = 1
            else:
                self.caudal = 0.45
                self.pect_l = 0
                self.pect_r = 0

    def move(self, neighbors, rel_pos):
        """Make a cohesion and target-driven move

        The move is determined by the relative position of the centroid and a
        target position and is limited by the maximum fish speed.

        Arguments:
            neighbors (TYPE): Description
            rel_pos (TYPE): Description
            neighbors {set} -- Set of active neighbors, i.e., other fish that
                responded to the most recent ping event.
            rel_pos {dict} -- Relative positions to all neighbors

        Returns:
            np.array -- Move direction as a 3D vector
        """

        # Get the centroid of the swarm
        centroid_pos = np.zeros((3,))

        # Get the relative direction to the centroid of the swarm
        centroid_pos = self.lj_force(neighbors, rel_pos)
        #centroid_pos = -self.comp_center(rel_pos)

        move = self.target_pos + centroid_pos

        # Global to Robot Transformation
        r_T_g = self.interaction.rot_global_to_robot(self.id)
        r_move_g = r_T_g @ move

        # Simulate dynamics and restrict movement
        self.depth_ctrl(r_move_g)
        self.home(r_move_g)

        self.dynamics.update_ctrl(self.dorsal, self.caudal, self.pect_r, self.pect_l)
        final_move = self.dynamics.simulate_move(self.id)

        return final_move

    def update_behavior(self):
        """Update the fish behavior.

        This actively changes the cohesion strategy to either 'wait', i.e, do
        not care about any neighbors or 'signal_aircraft', i.e., aggregate with
        as many fish friends as possible.

        In robotics 'signal_aircraft' is a secret key word for robo-fish-nerds
        to gather in a secret lab until some robo fish finds a robo aircraft.
        """
        if self.status == 'wait':
            self.lim_neighbors = [0, math.inf]
        elif self.info == 'signal_aircraft':
            self.lim_neighbors = [math.inf, math.inf]

    def eval(self):
        """The fish evaluates its state

        Currently the fish checks all responses to previous pings and evaluates
        its relative position to all neighbors. Neighbors are other fish that
        received the ping element.
        """

        # Set of neighbors at this point. Will be reconstructed every time
        neighbors = set()
        rel_pos = {}

        self.saw_hop_count = False

        while not self.queue.empty():
            (event, pos) = self.queue.get()

            if event.opcode == PING:
                self.ping_handler(neighbors, rel_pos, event)

            if event.opcode == HOMING:
                self.homing_handler(event, pos)

            if event.opcode == START_HOP_COUNT:
                self.start_hop_count_handler(event)

            if event.opcode == HOP_COUNT:
                self.hop_count_handler(event)

            if event.opcode == INFO_EXTERNAL:
                self.info_ext_handler(event)

            if event.opcode == INFO_INTERNAL:
                self.info_int_handler(event)

            if event.opcode == START_LEADER_ELECTION:
                self.start_leader_election_handler(event)

            if event.opcode == LEADER_ELECTION:
                self.leader_election_handler(event)

            if event.opcode == MOVE:
                self.move_handler(event)

        if self.clock > 1:
            # Move around (or just stay where you are)
            self.d_center = np.linalg.norm(self.comp_center(rel_pos)) # mean neighbor distance

            no_neighbors_before = len(neighbors)
            self.interaction.blind_spot(self.id, neighbors, rel_pos)
            no_neighbors_blind = len(neighbors)
            self.interaction.occlude(self.id, neighbors, rel_pos, self.r_blocking)
            no_neighbors_blocking = len(neighbors)
            if self.id == 5:
                #print('fish #5 sees {} neighbors before blindspot and {} after in current iteration'.format(no_neighbors_before, no_neighbors_blind))
                print('fish #5 sees {} neighbors before blocking sphere and {} after in current iteration'.format(no_neighbors_blind, no_neighbors_blocking))
            self.interaction.move(self.id, self.move(neighbors, rel_pos))

        # Update behavior based on status and information - update behavior
        self.update_behavior()

        self.neighbors = neighbors

        # self.log(neighbors)
        self.clock += 1

    def communicate(self):
        """Broadcast all collected event messages, followed by a ping.

        Every entry of `self.messages` is unpacked into the arguments of
        `channel.transmit()`. Afterwards the message list is replaced by a
        new empty list and a ping carrying this fish's ID is transmitted.

        This method is called as part of the second clock cycle.
        """
        for queued in self.messages:
            self.channel.transmit(*queued)

        self.messages = []

        # Always send out a ping to other fish
        ping = Ping(self.id)
        self.channel.transmit(self, ping)