Example #1
0
 def __init__(self, num_chairs):
     """Set up the shared shop state.

     Both condition variables are bound to ``shopLock``, so waiting on either
     one releases that same lock.
     """
     lock = Lock()
     self.shopLock = lock
     self.numChairs = num_chairs
     self.nfreeBarbers, self.nwaitingCustomers = 0, 0
     self.waitingCustomersCond = Condition(lock)
     self.waitingBarbersCond = Condition(lock)
Example #2
0
    class _EventRef(object):
        """ Helper class which acts as a threading.Event, but which can be used
            to pass a reference when signaling the event.
        """
        def __init__(self):
            self._cond = Condition(Lock())
            self._flag = False
            self._ref = None

        def isSet(self):
            return self._flag

        def set(self, ref):
            with self._cond:
                assert self._ref is None
                self._ref = ref
                self._flag = True
                self._cond.notifyAll()

        def get(self, timeout=None):
            with self._cond:
                if not self._flag:
                    self._cond.wait(timeout)

                if not self._flag:
                    raise TimeoutExceeded('Could not get the reference.')

                return self._ref

        def clear(self, ref):
            with self._cond:
                self._ref = None
                self._flag = False
Example #3
0
class DirtyBit:
    """A flag that threads can set, clear, and wait on until it is set."""

    def __init__(self):
        self._mon = RLock()
        self._rc = Condition(self._mon)
        self._dirty = 0

    def clear(self):
        """Reset the bit. Waiters are not woken; they only care about dirtiness."""
        with self._mon:
            self._dirty = 0

    def set(self):
        """Mark the bit dirty and wake every waiting thread."""
        with self._mon:
            self._dirty = 1
            # A flag is a broadcast: every waiter should see it.
            self._rc.notify_all()

    def value(self):
        """Return the current bit (0 or 1) without locking."""
        return self._dirty

    def wait(self):
        """Block until the bit is dirty; return immediately if it already is."""
        # `with` releases the lock even if wait() is interrupted. The predicate
        # loop avoids a lost wake-up when set() ran before wait() was called.
        with self._mon:
            while not self._dirty:
                self._rc.wait()
Example #4
0
class TestBackend(Backend):
    """Backend for tests: performs no I/O and counts queue notifications."""

    def __init__(self, queue):
        self.lock = Condition()
        queue._set_backend(self)
        # Number of queue_notify() calls, for tests to inspect.
        self.notifies = 0

    # Lifecycle hooks: intentionally no-ops in the test backend.
    def start(self):
        pass

    def stop(self):
        pass

    def start_feedback(self):
        pass

    def sleep(self):
        pass

    def queue_lock(self):
        """Return the condition object guarding the queue."""
        return self.lock

    def queue_notify(self):
        """Count this call and wake all threads waiting on the queue lock.

        notify_all() requires the caller to already hold queue_lock().
        """
        self.notifies = self.notifies + 1
        self.lock.notify_all()
class SimpleBlockingQueue(object):
    """Minimal FIFO queue whose take() can optionally block for data."""

    def __init__(self):
        self.items = deque()
        self.condition = Condition()
        # Set by interrupt(); consumed by the next blocking take() on an
        # empty queue, which then returns None.
        self.interrupted = False

    def put(self, item):
        """Enqueue *item* and wake one thread blocked in take()."""
        with self.condition:
            # deque provides appendleft()/pop(); there is no pushleft()/popright().
            self.items.appendleft(item)
            self.condition.notify()

    def take(self, block=False):
        """Dequeue the oldest item.

        With block=False, return None immediately if the queue is empty.
        With block=True, wait for an item, but return None (once) if
        interrupt() was called while the queue is empty.
        """
        with self.condition:
            if block:
                while not self.items:
                    # Check before waiting: an interrupt issued just before
                    # this call must not be lost.
                    if self.interrupted:
                        self.interrupted = False
                        return None
                    self.condition.wait()
            elif not self.items:
                return None

            return self.items.pop()

    def interrupt(self):
        """Make a blocked (or the next blocking) take() return None."""
        with self.condition:
            self.interrupted = True
            self.condition.notify_all()
Example #6
0
class BinaryServer(Thread):

    """
    Socket server forwarding request to internal server
    """

    def __init__(self, internal_server, hostname, port):
        Thread.__init__(self)
        self.socket_server = None
        self.hostname = hostname
        self.port = port
        self.iserver = internal_server
        self._cond = Condition()
        # Set by run() once the server exists or its creation failed.
        self._ready = False
        # Exception raised while creating the server; re-raised by start().
        self._startup_error = None

    def start(self):
        """Start the server thread and return once it is listening.

        Re-raises the exception that prevented the server from being created
        (e.g. the address is in use) instead of blocking forever.
        """
        with self._cond:
            Thread.start(self)
            while not self._ready:
                self._cond.wait()
        if self._startup_error is not None:
            raise self._startup_error

    def run(self):
        """Create the TCP server, signal start(), then serve until stop()."""
        logger.warning("Listening on %s:%s", self.hostname, self.port)
        socketserver.TCPServer.allow_reuse_address = True  # get rid of "address already in use" warning
        try:
            server = ThreadingTCPServer((self.hostname, self.port), UAHandler)
            # server.daemon_threads = True # this will force a shutdown of all threads, maybe too hard
            server.internal_server = self.iserver  # allow handler to access server properties
        except Exception as exc:
            # Always wake start(), even on failure.
            with self._cond:
                self._startup_error = exc
                self._ready = True
                self._cond.notify_all()
            return
        with self._cond:
            self.socket_server = server
            self._ready = True
            self._cond.notify_all()
        try:
            server.serve_forever()
        finally:
            # Release the listening socket once serving stops.
            server.server_close()

    def stop(self):
        """Ask serve_forever() to return; a no-op if no server was created."""
        logger.warning("server shutdown request")
        if self.socket_server is not None:
            self.socket_server.shutdown()
 def __init__(self):
     '''Initialise the GObject base, locks, song/playlist tables and flags.'''
     gobject.GObject.__init__(self)

     # A general-purpose condition, plus a separate one for database operations.
     self.__condition = Condition()
     self.__db_operation_lock = Condition()

     # Song bookkeeping.
     self.__hiddens = set()
     self.__song_types_capability = {}
     self.__songs = {}
     self.__song_types = []
     self.__songs_by_type = {}

     # Playlist bookkeeping.
     self.__playlists, self.__playlist_types = {}, []

     # State flags. __save_song_type presumably lists the song types that get
     # persisted -- confirm against the save code.
     self.__is_loaded = self.__force_check = self.__dirty = False
     self.__save_song_type = ["local", "cue", "unknown"]

     # Queued-signal state is set up by a helper defined elsewhere in the class.
     self.__reset_queued_signal()
class Consumer(Thread):
    """Data-processing thread: prints the running count and average."""

    def __init__(self):
        """Create the thread with an empty queue of pending values."""
        super().__init__()
        self.condition = Condition()
        # Most recent value passed to send(); kept for compatibility.
        self.received_data = None
        # Values sent but not yet processed, oldest first. Buffering them means
        # a send() made while run() is busy (or not yet waiting) is not lost.
        self._pending = []

    def run(self):
        """Code executed in the thread: wait for data, then print statistics."""
        running_sum = 0
        count = 0

        while True:
            with self.condition:
                # Wait for data to become available. Looping on the predicate
                # handles data that arrived before we started waiting.
                while not self._pending:
                    self.condition.wait()
                # Take the data.
                data = self._pending.pop(0)

            # Process the data outside the lock so send() is never blocked.
            running_sum += data
            count += 1
            print('Got data: {}\nTotal count: {}\nAverage: {}\n'.format(
                data, count, running_sum / count))

    def send(self, data):
        """Send *data* to the processing thread and wake it."""
        # The consumer only wakes on notify(), which must happen under the lock.
        with self.condition:
            self.received_data = data
            self._pending.append(data)
            self.condition.notify()
Example #9
0
class SendBuffer(object):
    """Bounded, double-buffered outgoing data drained to a socket.

    Writers append to ``backbuffer`` and block while it is full; to_sock()
    swaps it out and sends from ``frontbuffer``.
    """

    def __init__(self, max_size):
        self.size = 0            # amount of data currently in backbuffer
        self.frontbuffer = ''    # data taken from backbuffer, not yet sent
        self.full = Condition()  # notified whenever backbuffer is drained
        self.backbuffer = StringIO()
        self.max_size = max_size

    def __len__(self):
        """Total buffered data not yet handed to the socket."""
        return len(self.frontbuffer) + self.size

    def write(self, data):
        """Append *data*, blocking while it would exceed max_size.

        A write larger than max_size can never fit, so it is accepted
        immediately instead of blocking forever.
        """
        dlen = len(data)
        with self.full:
            # `<=`: a write of exactly max_size fits only an empty buffer, so
            # it must wait for a drain rather than overflow a non-empty one.
            while self.size + dlen > self.max_size and dlen <= self.max_size:
                self.full.wait()
            self.backbuffer.write(data)
            self.size += dlen

    def to_sock(self, sock):
        """Send as much buffered data as one sock.send() call accepts."""
        if not self.frontbuffer:
            with self.full:
                buf, self.backbuffer = self.backbuffer, StringIO()
                self.size = 0
                self.full.notify_all()
            self.frontbuffer = buf.getvalue()

        written = sock.send(self.frontbuffer)
        self.frontbuffer = self.frontbuffer[written:]
Example #10
0
class Heartbeater(object):
    """Poll a job step until it reports 'finished' or close() is called."""
    interval = 5

    def __init__(self, api, jobstep_id, interval=None):
        self.api = api
        self.jobstep_id = jobstep_id
        self.cv = Condition()
        self.finished = Event()

        if interval is not None:
            self.interval = interval

    def wait(self):
        """Block until the job step is finished or close() is called.

        Polls ``api.get_jobstep`` every ``interval`` seconds. The API call is
        made without holding ``cv``, so close() never blocks on network I/O.
        """
        with self.cv:
            self.finished.clear()
        while not self.finished.is_set():
            data = self.api.get_jobstep(self.jobstep_id)
            if data['status']['id'] == 'finished':
                self.finished.set()
                break

            with self.cv:
                # Re-check under the lock: close() may have run during the
                # API call, and its notify would otherwise be missed.
                if self.finished.is_set():
                    break
                self.cv.wait(self.interval)

    def close(self):
        """Make a concurrent wait() return as soon as possible."""
        with self.cv:
            self.finished.set()
            self.cv.notify_all()
Example #11
0
class Future(object):
    """Deferred call: run it with ``future()``, collect it with ``result()``."""

    def __init__(self, func, *args, **kwargs):
        self.__done = False
        self.__result = None
        self.__cond = Condition()
        self.__func = func
        self.__args = args
        self.__kwargs = kwargs
        self.__except = None

    def __call__(self):
        """Run the wrapped callable and publish its result or exception."""
        # Run outside the lock so the work itself does not hold the condition.
        try:
            result = self.__func(*self.__args, **self.__kwargs)
            error = None
        except BaseException:
            # Captured on purpose: result() re-raises it in the caller's thread.
            result = None
            error = sys.exc_info()
        with self.__cond:
            self.__result = result
            self.__except = error
            self.__done = True
            # notify_all: every thread blocked in result() must wake up, not
            # just one of them.
            self.__cond.notify_all()

    def result(self):
        """Block until the call completes; return a deep copy of its value.

        If the callable raised, re-raise that exception with its traceback.
        """
        with self.__cond:
            while not self.__done:
                self.__cond.wait()
        if self.__except:
            exc = self.__except
            reraise(exc[0], exc[1], exc[2])
        return copy.deepcopy(self.__result)
class ResourcePool(object):
    """Counting pool: acquire(n) blocks until at least n units are free."""

    def __init__(self, initial_value):
        super(ResourcePool, self).__init__()
        self.condition = Condition()
        self.value = initial_value

    def acquire(self, amount):
        """Take *amount* units, waiting until that many are available."""
        with self.condition:
            while amount > self.value:
                self.condition.wait()
            self.value -= amount
            self.__validate()

    def release(self, amount):
        """Return *amount* units and wake every waiter to re-check."""
        with self.condition:
            self.value += amount
            self.__validate()
            # Waiters may want different amounts, so all must re-test their
            # own condition. (notifyAll is a deprecated alias of notify_all.)
            self.condition.notify_all()

    def __validate(self):
        # Sanity check only; asserts are stripped when run with -O.
        assert 0 <= self.value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return "ResourcePool(%i)" % self.value

    @contextmanager
    def acquisitionOf(self, amount):
        """Hold *amount* units for the duration of a ``with`` block."""
        self.acquire(amount)
        try:
            yield
        finally:
            self.release(amount)
Example #13
0
class IOManager(Thread):
    """Background thread that hands queued I/O requests to ``so.handlerIO``."""

    def __init__(self,so):
        Thread.__init__(self)
        self.mutexIO = Condition()
        self.listIO = []
        self.handlerIO = so.handlerIO

    def process(self):
        """Pass every pending request to the handler, then empty the list."""
        for irq in self.listIO:
            print("Ejecuto IOManager")
            self.handlerIO.handle(irq)
        self.removeIOFromListIO()

    def run(self):
        """Forever: wait until requests are queued, then process them."""
        while True:
            with self.mutexIO:
                # Sleep only while there is nothing to do, so requests queued
                # before we started waiting are not left unprocessed.
                while not self.listIO:
                    self.mutexIO.wait()
                self.process()

    def removeIOFromListIO(self):
        """Remove all pending requests, in place."""
        # Removing items while iterating the same list skipped every other one.
        del self.listIO[:]

    def add(self,iOReq):
        """Queue an I/O request and wake the manager thread."""
        with self.mutexIO:
            self.listIO.append(iOReq)
            self.mutexIO.notify()
Example #14
0
    def __init__(self, node_id, data):
        """
        Constructor.

        Indentation uses spaces only: the original mixed tabs and spaces,
        which Python 3 rejects with TabError.

        @type node_id: Integer
        @param node_id: the unique id of this node; between 0 and N-1

        @type data: List of Integer
        @param data: a list containing this node's data
        """
        self.node_id = node_id
        self.data = data
        # temporary buffer needed for scatter
        self.copy = data[:]
        self.lock_copy = Lock()
        self.nodes = None
        self.lock_data = Lock()
        # list of threads (16 for each node)
        self.thread_list = []
        # list of tasks that still need to be computed
        self.thread_pool = []
        self.mutex = Lock()
        # condition used for put and get
        self.condition = Condition(self.mutex)
        # condition used to check whether there are still tasks to be solved;
        # shares self.mutex with self.condition
        self.all_tasks_done = Condition(self.mutex)
        # number of unfinished tasks
        self.unfinished_tasks = 0
        # start the 16 worker threads (all attributes above are set first)
        for i in range(16):
            th = Worker(self, i)
            self.thread_list.append(th)
            th.start()
Example #15
0
    def __init__(self, client):
        """Set up the key prefix, a per-process RNG and two daemon workers.

        @param client: cache client; its ``delete`` is used by the deletion worker
        """
        self.prefix = "auau:"
        self.mc = client

        self.random = Random()
        self.random.seed(os.urandom(128))
        # Random.jumpahead() exists only on Python 2. The urandom seed above
        # already decorrelates processes, so skip it where it is missing.
        if hasattr(self.random, 'jumpahead'):
            self.random.jumpahead(os.getpid())
        # Thread exit flag: a one-element list shared by reference with both workers.
        self.exit_flag = [False]

        from threading import Thread
        from threading import Condition

        # Asynchronous deflation thread.
        self.def_cv = Condition()
        self.def_que = []
        self.def_thread = Thread(
            target=lambda: self.async_work(self.def_cv, self.def_que, self.exit_flag, lambda x: self.deflate(x))
        )
        # The `daemon` attribute replaces the deprecated setDaemon().
        self.def_thread.daemon = True
        self.def_thread.start()

        # Asynchronous deletion thread.
        self.del_cv = Condition()
        self.del_que = []
        self.del_thread = Thread(
            target=lambda: self.async_work(self.del_cv, self.del_que, self.exit_flag, lambda x: self.mc.delete(x))
        )
        self.del_thread.daemon = True
        self.del_thread.start()
Example #16
0
    def request(self, request_type, request, sequential=False):
        """Send a request via ``request_async`` and block for its outcome.

        Returns:
            tuple: ``(response_type, response)`` from the success callback.

        Raises:
            The exception passed to the error callback.
        """
        cond = Condition()
        # Filled in by whichever callback fires. A dict is used as a reference
        # type so the callbacks (possibly on a separate thread) can write to it.
        response_cont = {}

        def on_success(response_type, response):
            with cond:
                response_cont.update({
                    "success": True,
                    "type": response_type,
                    "response": response,
                })
                cond.notify()

        def on_error(exception):
            with cond:
                response_cont.update({
                    "success": False,
                    "exception": exception,
                })
                cond.notify()

        with cond:
            self.request_async(request_type, request,
                               on_success, on_error,
                               sequential)
            # Wait on the predicate: the callback may already have run (for
            # example synchronously inside request_async; the default RLock
            # lets it re-enter `cond`), so its notify() came before this wait.
            while not response_cont:
                cond.wait()

        if not response_cont["success"]:
            raise response_cont["exception"]

        return response_cont["type"], response_cont["response"]
class Presponse(PktHandler):
    """Packet handler that wakes threads waiting for a ping response.

    To measure ping latency with an instance ``presponse``::

        start = time.time()
        with presponse.ping_response_wait:
            presponse.ping_response_wait.wait(5)  # standard PING timeout is 5
        end = time.time()

    ``end - start`` is then the ping latency.
    """

    def __init__(self):
        # Notified (notify_all) each time a packet is handled.
        self.ping_response_wait = Condition()

    def handle(self, packet_map):
        """Wake all ping waiters; the packet contents are not inspected."""
        waiters = self.ping_response_wait
        with waiters:
            waiters.notify_all()
        return None

    def getproperties(self):
        """Return this handler's descriptive metadata."""
        props = {}
        props['DisplayName'] = 'Ping-Response Notifier'
        props['CodeName'] = 'PRESPONSE_NOTIFY'
        props['PacketType'] = 'PRESPONSE'
        props['Version'] = 0.01
        return props
Example #18
0
 def __init__(self, threads_num, task_queue_max=None):
     """Initialise pool bookkeeping; no threads are started here.

     ``task_queue_max`` is accepted but not used by this initialiser.
     """
     self._threads_num = threads_num
     self._tasks = deque()
     self._threads = []
     self._task_num = 0
     # Each condition wraps its own private lock.
     self._task_lock = Condition(Lock())
     self._thead_lock = Condition(Lock())
Example #19
0
 def __init__(self, numchairs):
     """All chairs start free; both conditions share ``shop_lock``."""
     self.barbers_ready = False
     self.numchairs = self.open_seats = numchairs
     self.shop_lock = Lock()
     self.barber_condition, self.customer_condition = (
         Condition(self.shop_lock), Condition(self.shop_lock))
Example #20
0
 def __init__(self):
     """Bathroom state: counters plus two queues sharing ``bathroomLock``.

     nwasherWaiting / nwasherUsing count sims waiting for, and currently
     using, the washer; toiletBusy is non-zero while the toilet is in use.
     """
     self.nwasherWaiting = self.nwasherUsing = self.toiletBusy = 0
     lock = Lock()
     self.bathroomLock = lock
     self.toiletLine, self.washerLine = Condition(lock), Condition(lock)
Example #21
0
class Broadcaster(object):
    def __init__(self):
        self._condition = Condition()
        self._last_id = 0
        self._queue = []
    
    @acquire_then_notify
    def send(self, item):
        print "Sending message", item
        self._last_id += 1
        self._queue.append((self._last_id, item))
    
    def _find_items(self, since_id=None):
        if not since_id:
            found = self._queue
        else:
            found = [(id, item) for (id,item) in self._queue if id > since_id]
        return found
    
    @acquire
    def recv(self, since_id=None, timeout=10):
        end_time = time() + timeout
        while time() < end_time:
            found = self._find_items(since_id)
            if found:
                break
            print "Waiting"
            self._condition.wait(timeout)
        print "found %d" % len(found)
        return found
class BackUpThread (threading.Thread):
    
    def __init__(self):
        threading.Thread.__init__ ( self )
        self.srcfile = "D:\study\operating system prac\CS4410\mp3\mp3\src\mail.txt"
        self.destfile = "D:\study\operating system prac\CS4410\mp3\mp3\src\\backup.txt"
        self.mutex = Lock()
        self.cv = Condition(self.mutex)
        self.msgCount = 0
        
    def run(self):
        with self.mutex:
            while True:
                # TODO: BUG here.
                while self.msgCount != 32:
                    self.cv.wait()

                print "Backing up the mail file."
#                TODO: copy only the new part.
#                desthandle = open(self.destfile, "r")
#                desthandle.seek(0, 2)
#                offset = desthandle.tell()
                shutil.copyfile(self.srcfile, self.destfile)
                self.msgCount = 0

    def newMsg(self):
        with self.mutex:
            self.msgCount += 1
            if self.msgCount == 32:
                self.cv.notifyAll()
Example #23
0
class Worker(Thread):
    def __init__(self):
        Thread.__init__(self)
        self.lock = Lock()
        self.cond = Condition(self.lock)
        self.stopped = False
        self.queue = []

    def run(self):
        while True:
            job = None
            with self.cond:
                if self.stopped:
                    return
                if not self.queue:
                    self.cond.wait()
                else:
                    job, params = self.queue.pop(0)
            if not job:
                continue
            self.execute(job, params)

    def execute(self, job, params):
        try:
            func, args = job(*params)
            # The async function may decide to NOT update
            # the UI:
            if func:
                gobject.idle_add(func, *args)
        except Exception, e:
            print "Warning:", e
Example #24
0
 def __init__(self, algorithm, interpreter, change_listener, debug=False,
              progress_listener=None):
     """Attach *algorithm* and initialise matcher state.

     Calls ``algorithm.assign_matcher(self)``. No thread is started here
     (``_thread`` starts as None).

     Raises:
         TypeError: if *algorithm* is not an ``Algorithm`` instance.
     """
     if not isinstance(algorithm, Algorithm):
         # Wrap in a 1-tuple so %-formatting cannot misfire when the
         # argument is itself a tuple.
         raise TypeError('%r is not a matching algorithm object' %
             (algorithm,))

     self.current_location = None
     self.previous_location = None
     self._debug_enabled = debug
     self._running = False
     self._within_piece = False
     self._startup_condition = Condition()
     self._piece_control = Condition()

     self._playing_piece = False
     self._history = HistoryQueue()
     self.intervals = None
     self._incoming_notes = None

     self.interpreter = interpreter
     self.change_listener = change_listener
     self.progress_listener = progress_listener

     algorithm.assign_matcher(self)
     self._algorithm = algorithm

     self._thread = None
     self._stopping = False
Example #25
0
 def __init__(self, number_threads):
     """Empty connection pool capped at *number_threads* connections.

     Both conditions share ``pool_lock``.
     """
     self.pool_lock = Lock()
     self.connection_available, self.request_available = (
         Condition(self.pool_lock), Condition(self.pool_lock))
     self.connection_pool = []
     self.number_connections = 0
     self.max_connections = number_threads
Example #26
0
 def __init__(self):
     """Per-direction crossing counters and two queues sharing ``cvLock``."""
     self.carsCrossingNorth = self.carsCrossingSouth = 0
     self.cvLock = Lock()
     shared = self.cvLock
     self.southQ = Condition(shared)
     self.northQ = Condition(shared)
 def __init__(self):
     """Writer/backup coordination state; both conditions share ``lock``."""
     self.lock = Lock()
     self.toWrite, self.toBackup = Condition(self.lock), Condition(self.lock)
     self.isWriting = self.isBackup = False
     self.emailNum = 1
Example #28
0
 def __init__(self, num_photos=10, sources=None, pool_dir=cache_dir):
     """Create the photo-pool thread and prepare its cache directory.

     :param num_photos: presumably the number of photos to keep in the pool
     :param sources: photo sources; a fresh empty list when omitted
     :param pool_dir: cache directory (created if missing, then cleaned)
     """
     Thread.__init__(self)

     self.num_photos = num_photos
     # A `[]` default would be a single list shared by every instance.
     self.sources = [] if sources is None else sources
     self.dir = pool_dir

     # Make sure cache dir exists
     if not os.path.exists(self.dir):
         os.mkdir(self.dir)

     # Clean cache directory
     self.clean_cache()

     # Load cached photos
     self.photos = os.listdir(self.dir)

     # Delete queue
     self.trash = []

     # Condition when a new photo is added
     self.added = Condition()

     # Condition when a photo is removed
     self.removed = Condition()

     # Event for stopping the pool.
     # NOTE(review): on Python 3, threading.Thread uses a private `_stop()`
     # method internally (e.g. from join()). Shadowing it with an Event may
     # break join(). Renaming requires updating the rest of the class, which
     # is not visible here.
     self._stop = Event()
Example #29
0
class Future:
    """Placeholder for a response matched by *correlation_id*.

    A result of None is indistinguishable from "not set yet" (see done()).
    """

    def __init__(self, correlation_id, request_type=None):
        self.correlation_id = correlation_id
        self._result = None
        self._condition = Condition()
        # Only used to name the request in the timeout error message.
        self._request_type = request_type

    def done(self):
        """True once a non-None result has been set."""
        return self._result is not None

    def result(self, timeout=None):
        """Wait up to *timeout* seconds (forever if None) for the result.

        Raises:
            FutureTimeoutError: if no result arrived within *timeout*.
        """
        with self._condition:
            # wait_for re-tests the predicate after every wake-up and returns
            # its final value, which is falsy only when the timeout expired.
            if not self._condition.wait_for(self.done, timeout):
                message_type = validator_pb2.Message.MessageType.Name(
                    self._request_type) if self._request_type else None
                raise FutureTimeoutError(
                    'Future timed out waiting for response to {}'.format(
                        message_type))
        return self._result

    def set_result(self, result):
        """Store *result* and wake every thread blocked in result()."""
        with self._condition:
            self._result = result
            # notify_all: several threads may be waiting on the same future.
            self._condition.notify_all()
Example #30
0
    def __init__(self, cache_dir, map_desc, db_schema, is_concurrency=True):
        """Describe the .mbtiles database for *map_desc*; no connection is
        opened here.

        When *is_concurrency* is true, also create the queue, locks and
        conditions used to hand DB work to a single surrogate thread.
        """
        self.__map_desc = map_desc
        self.__db_path = os.path.join(cache_dir, map_desc.map_id + ".mbtiles")
        self.__conn = None

        # Configuration.
        self.__db_schema = db_schema
        self.__has_timestamp = True
        self.__is_concurrency = is_concurrency

        if not is_concurrency:
            return

        # The thread that does all DB operations, because sqlite3 requires a
        # connection to be used only from the thread that created it.
        self.__surrogate = None
        self.__is_closed = False

        # Pending SQL work and the condition guarding it.
        self.__sql_queue = []
        self.__sql_queue_lock = Lock()
        self.__sql_queue_cv = Condition(self.__sql_queue_lock)

        # Serialises 'get' actions.
        self.__get_lock = Lock()

        # Reply slot for a 'get': the pair (data, exception), plus its condition.
        self.__get_respose = None
        self.__get_respose_lock = Lock()
        self.__get_respose_cv = Condition(self.__get_respose_lock)
Example #31
0
class PoolBlockPut(PoolBase):
    """Data pool whose producers block while enough data is already buffered."""

    def __init__(self, needed_count):
        super(PoolBlockPut, self).__init__()
        self.condition = Condition()  # used for inter-thread signalling
        self.needed_count = needed_count  # amount of data the next read needs

    def put(self, data):
        """Append the 1-D ndarray *data* to the pool.

        If the pool already holds at least ``needed_count`` items, first block
        until get() or release() signals.

        Raises:
            TypeError: if *data* is not an ndarray.
            ValueError: if *data* is not 1-D.
        """
        if not isinstance(data, np.ndarray):
            raise TypeError("Input data must be ndarray!")
        if data.ndim != 1:
            raise ValueError("Input data must be 1D!")

        # `with` releases the lock even if concatenate raises. A manual
        # acquire()/release() pair would leave it held and deadlock others.
        with self.condition:
            # A single wait, not a loop. Presumably deliberate, so that
            # release()'s notify_all lets producers through even while full.
            if len(self.datas) >= self.needed_count:
                self.condition.wait()
            self.datas = np.concatenate((self.datas, data))

    def get(self, count):
        """Remove and return a copy of up to *count* items from the front.

        Non-circular read. Records *count* as the new ``needed_count`` and wakes
        blocked producers when fewer than ``4 * count`` items remain.

        Raises:
            ValueError: if *count* is not an int or is negative.
        """
        # Type check first: comparing a non-number with 0 would raise
        # TypeError instead of the intended ValueError.
        if not isinstance(count, int):
            raise ValueError("Input count must be integer!")
        if count < 0:
            raise ValueError("Input count must >= 0!")

        with self.condition:
            # Clamp under the lock so the length cannot change in between.
            if count > len(self.datas):
                count = len(self.datas)
            re = self.datas[:count].copy()
            self.datas = np.delete(self.datas, range(0, count))
            self.needed_count = count
            # Not enough left for the next read: let producers refill.
            if len(self.datas) < 4 * self.needed_count:
                self.condition.notify_all()

        return re

    def release(self):
        """Wake every producer blocked in put()."""
        with self.condition:
            self.condition.notify_all()
Example #32
0
class EventRPCClient(object):
    def __init__(self,
                 event_engine,
                 service,
                 event_client=None,
                 event_server=None):
        self.EVENT_CLIENT = event_client if event_client else "%s_CLIENT" % service.upper(
        )
        self.EVENT_SERVER = event_server if event_server else "%s_SERVER" % service.upper(
        )
        self.rid = 0
        self._handlers = {}
        self._handlers_lock = Lock()
        self._event_engine = event_engine
        self._event_engine.register(self.EVENT_SERVER, self._process_apiback)
        self._pause_condition = Condition()
        self._sync_ret = None
        self._timeout = 0
        self._timer_sleep = 1
        self._sync_call_time_lock = Lock()
        self._sync_call_time = datetime.now()
        timer = Thread(target=self._run_timer)
        timer.daemon = True
        timer.start()

    def _run_timer(self):
        while True:
            if not self._timeout == 0:
                with self._sync_call_time_lock:
                    mtime = self._sync_call_time
                delta = (datetime.now() - mtime).seconds
                if delta >= self._timeout:
                    #print "timeout", self._timeout, delta
                    # 不可重入,保证self.rid就是超时的那个
                    with self._handlers_lock:
                        del self._handlers[self.rid]
                    log.debug("[RPCClient._runtimer] 处理超时, delete rid; %s" %
                              self.rid)
                    self._timeout = 0
                    self._notify_server_data()
            time.sleep(self._timer_sleep)

    def _process_apiback(self, event):
        assert (event.route == self.EVENT_SERVER)
        self._timeout = 0
        rid = event.args['rid']
        try:
            with self._handlers_lock:
                handler = self._handlers[rid]
        except KeyError:
            log.info('[RPCClient._process_apiback] 放弃超时任务的返回结果')
        else:
            try:
                if handler:
                    # 异步
                    handler(event.args['ret'])
                else:
                    # 同步
                    self._sync_ret = event.args['ret']
                    self._notify_server_data()
            except Exception as e:
                print e
            log.debug("[RPCClient._process_apiback] 删除已经完成的任务 rid; %s" % rid)
            with self._handlers_lock:
                del self._handlers[rid]

    def call(self, apiname, args, handler):
        """ 给定参数args,异步调用RPCServer的apiname服务,
        返回结果做为回调函数handler的参数。
        
        Args:
            apiname (str): 服务API名称。
            args (dict): 给服务API的参数。
            handler (function): 回调函数。
        """
        if not isinstance(args, dict):
            raise InvalidRPCClientArguments(argtype=type(args))
        assert (not handler == None)
        self.rid += 1
        args['apiname'] = apiname
        args['rid'] = self.rid
        self._event_engine.emit(Event(self.EVENT_CLIENT, args))
        with self._handlers_lock:
            self._handlers[self.rid] = handler

    def sync_call(self, apiname, args, timeout=10):
        """ 给定参数args,同步调用RPCServer的apiname服务,
        返回该服务的处理结果。如果超时,返回None。
        
        Args:
            apiname (str): 服务API名称。
            args (dict): 给服务API的参数。
            handler (function): 回调函数。
        """
        log.debug('sync_call: %s' % apiname)
        if not isinstance(args, dict):
            self._timeout = 0
            self._sync_ret = None
            raise InvalidRPCClientArguments(argtype=type(args))
        self.rid += 1
        args['apiname'] = apiname
        args['rid'] = self.rid
        with self._sync_call_time_lock:
            self._sync_call_time = datetime.now()
        self._timeout = timeout
        with self._handlers_lock:
            self._handlers[self.rid] = None
        self._event_engine.emit(Event(self.EVENT_CLIENT, args))
        self._waiting_server_data()
        ret = self._sync_ret
        #self._sync_ret = None
        return ret

    def _waiting_server_data(self):
        with self._pause_condition:
            self._pause_condition.wait()

    def _notify_server_data(self):
        with self._pause_condition:
            self._pause_condition.notify()
        # Pipe source: run the command that follows the 'pipe:' prefix.
        # NOTE(review): os.popen2 exists only in Python 2 and returns
        # (child_stdin, child_stdout). So `source` is the command's stdout and
        # `child_out` is actually its stdin -- confirm.
        cmd = config['source'][len('pipe:'):]
        (child_out,source) = os.popen2( cmd, 'b' )
    else:
        # File source: read from disk, rate-limited to tdef.get_bitrate().
        source = open(config['source'],"rb")
        dscfg.set_video_ratelimit(tdef.get_bitrate())

    restartstatefilename = config['name']+'.restart'
    dscfg.set_video_source(source, restartstatefilename=restartstatefilename)

    dscfg.set_max_uploads(config['nuploads'])

    d = s.start_download(tdef,dscfg)
    d.set_state_callback(state_callback,getpeerlist=False)

    # Block the main thread while the download runs. The authors noted that a
    # condition variable would be prettier but does not respond to
    # KeyboardInterrupt. Earlier alternatives, kept for reference:
    #time.sleep(sys.maxint/2048)
    #try:
    #    while True:
    #        x = sys.stdin.read()
    #except:
    #    print_exc()
    # NOTE(review): nothing visible here notifies `cond`, so wait() blocks
    # forever. On Python 2 an untimed Condition.wait() is not interruptible by
    # Ctrl-C, so the cleanup below may be unreachable -- confirm.
    cond = Condition()
    cond.acquire()
    cond.wait()

    # Cleanup: shut down the session, pause (presumably to let it finish
    # winding down), then remove the state directory.
    s.shutdown()
    time.sleep(3)
    shutil.rmtree(statedir)
Example #34
0
class ForkingBase(object):
    '''Base class for classes which fork children and wait for them to exit.

    Sub-classes must provide the following data attributes and methods:

        log - an object of type getmailcore.logging.Logger()

    '''
    def _child_handler(self, sig, stackframe):
        '''SIGCHLD handler: reap the child and wake _wait_for_child().'''
        def notify():
            with self.__child_exited:
                # Record the event before notifying. The signal can arrive
                # before the parent starts waiting, and a bare notify would
                # then be lost.
                self.__child_reaped = True
                self.__child_exited.notify_all()

        self.log.trace('handler called for signal %s' % sig)
        try:
            pid, r = os.waitpid(self.child.childpid, 0)
        except OSError as o:
            # No children on SIGCHLD.  Can't happen?
            self.log.warning('handler called, but no children (%s)' % o)
            notify()
            return
        signal.signal(signal.SIGCHLD, self.__orig_handler)
        self.__child_pid = pid
        self.__child_status = r
        self.log.trace('handler reaped child %s with status %s' % (pid, r))
        notify()

    def _prepare_child(self):
        '''Reset per-child state and install the SIGCHLD handler.'''
        self.log.trace('')
        self.__child_exited = Condition()
        self.__child_reaped = False
        self.__child_pid = 0
        self.__child_status = None
        # NOTE(review): forkchild() calls this in the parent *after* fork().
        # A child that exits before the handler is installed is not noticed,
        # and _wait_for_child() then times out.
        self.__orig_handler = signal.signal(signal.SIGCHLD,
                                            self._child_handler)

    def _wait_for_child(self, childpid):
        '''Wait up to 60 s for the SIGCHLD handler; return the exit code.

        Raises getmailOperationError on timeout, on a pid mismatch, or if the
        child stopped, was killed, or did not exit normally.
        '''
        import time
        deadline = time.time() + 60
        # Wait on the flag rather than a single notify, because the handler
        # may already have run. `with` also releases the lock when the
        # timeout error is raised.
        with self.__child_exited:
            while not self.__child_reaped:
                remaining = deadline - time.time()
                if remaining <= 0:
                    raise getmailOperationError('waiting child pid %d timed out' %
                                                childpid)
                self.__child_exited.wait(remaining)
        if self.__child_pid != childpid:
            #self.log.error('got child pid %d, not %d' % (pid, childpid))
            raise getmailOperationError('got child pid %d, not %d' %
                                        (self.__child_pid, childpid))
        if os.WIFSTOPPED(self.__child_status):
            raise getmailOperationError(
                'child pid %d stopped by signal %d' %
                (self.__child_pid, os.WSTOPSIG(self.__child_status)))
        if os.WIFSIGNALED(self.__child_status):
            raise getmailOperationError(
                'child pid %d killed by signal %d' %
                (self.__child_pid, os.WTERMSIG(self.__child_status)))
        if not os.WIFEXITED(self.__child_status):
            raise getmailOperationError('child pid %d failed to exit' %
                                        self.__child_pid)
        exitcode = os.WEXITSTATUS(self.__child_status)
        return exitcode

    def _pipemail(self, msg, delivered_to, received, unixfrom, stdout, stderr):
        '''Point fd 0 at a temp file holding *msg*, and fds 1/2 at
        *stdout*/*stderr* (used from child_replace_me before execl).'''
        # Write out message
        msgfile = TemporaryFile23()
        msgfile.write(
            msg.flatten(delivered_to, received, include_from=unixfrom))
        msgfile.flush()
        os.fsync(msgfile.fileno())
        # Rewind
        msgfile.seek(0)
        # Set stdin to read from this file
        os.dup2(msgfile.fileno(), 0)
        # Set stdout and stderr to write to files
        os.dup2(stdout.fileno(), 1)
        os.dup2(stderr.fileno(), 2)

    def child_replace_me(self,
                         msg,
                         delivered_to,
                         received,
                         unixfrom,
                         stdout,
                         stderr,
                         args,
                         nolog=False):
        '''Redirect stdio for *msg*, then replace this process via os.execl(*args).'''
        self._pipemail(msg, delivered_to, received, unixfrom, stdout, stderr)
        nolog or self.log.debug('about to execl() with args %s\n' % str(args))
        os.execl(*args)

    def forkchild(self, childfun, with_out=True):
        '''Fork; the child runs childfun(stdout, stderr). The parent waits.

        Returns (in the parent) a Namespace with childpid, exitcode, err and,
        if *with_out*, out.
        '''
        self.child = child = Namespace()
        child.stdout = TemporaryFile23()
        child.stderr = TemporaryFile23()
        child.childpid = os.fork()
        if child.childpid != 0:  # here (in the parent)
            self._prepare_child()
            self.log.debug('spawned child %d\n' % child.childpid)
            child.exitcode = self._wait_for_child(child.childpid)
            child.stderr.seek(0)
            child.err = child.stderr.read().strip().decode()
            child.stdout.seek(0)
            if with_out:
                child.out = child.stdout.read().strip()
            return child
        else:  #== 0 in the child
            # calls child_replace_me to execl external command
            childfun(child.stdout, child.stderr)

    def get_msginfo(self, msg):
        '''Return a dict with the sender and, if present, the recipient
        split into domain and local part.'''
        msginfo = {}
        msginfo['sender'] = msg.sender.strip()
        if msg.recipient is not None:
            rcpnt = msg.recipient.strip()
            msginfo['recipient'] = rcpnt
            msginfo['domain'] = rcpnt.lower().split('@')[-1]
            msginfo['local'] = '@'.join(rcpnt.split('@')[:-1])
        self.log.debug('msginfo "%s"\n' % msginfo)
        return msginfo
Example #35
0
class TaggerWorker(QObject):
    """Runs Mint tagging work outside the main Qt thread.

    This class is required to prevent locking up the main Qt thread.
    Errors, progress, review results and send results are reported back
    through the Qt signals declared below.
    """
    on_error = pyqtSignal(str)
    on_review_ready = pyqtSignal(tagger.UpdatesResult)
    on_updates_sent = pyqtSignal(int)
    on_stopped = pyqtSignal()
    on_mint_mfa = pyqtSignal()
    on_progress = pyqtSignal(str, int, int)
    stopping = False
    # Guards the MFA hand-off between the thread blocked inside the Mint MFA
    # callback and whoever calls mfa_code() with the user's answer.
    mfa_condition = Condition()
    # Most recent MFA code delivered via mfa_code(); None while waiting.
    # Stored separately so that assigning it never shadows the mfa_code slot.
    _pending_mfa_code = None

    @pyqtSlot()
    def stop(self):
        """Request that review results are not emitted once ready."""
        self.stopping = True

    @pyqtSlot(str)
    def mfa_code(self, code):
        """Deliver an MFA *code* to the thread waiting for it."""
        logger.info('Got code')
        # Condition.notify() may only be called while holding its lock.
        with self.mfa_condition:
            self._pending_mfa_code = code
            logger.info('Waking thread')
            self.mfa_condition.notify_all()

    @pyqtSlot(object, object)
    def create_updates(self, args, parent):
        """Slot wrapper around do_create_updates() reporting errors via on_error."""
        try:
            self.do_create_updates(args, parent)
        except Exception as e:
            msg = 'Internal error while creating updates: {}'.format(e)
            self.on_error.emit(msg)
            logger.exception(msg)

    @pyqtSlot(list, object)
    def send_updates(self, updates, args):
        """Slot wrapper around do_send_updates() reporting errors via on_error."""
        try:
            self.do_send_updates(updates, args)
        except Exception as e:
            msg = 'Internal error while sending updates: {}'.format(e)
            self.on_error.emit(msg)
            logger.exception(msg)

    def do_create_updates(self, args, parent):
        """Log into Mint, build the tagging updates and emit them for review."""
        def on_mint_mfa(prompt):
            logger.info('Asking for Mint MFA')
            with self.mfa_condition:
                # Drop any stale code so an earlier answer is never reused.
                self._pending_mfa_code = None
                # Emit while holding the lock: mfa_code() needs the same lock
                # to deliver the code, so the answer cannot arrive (and its
                # notify be lost) before we start waiting.
                self.on_mint_mfa.emit()
                logger.info('Blocking')
                # Predicate loop: robust to spurious wakeups.
                self.mfa_condition.wait_for(
                    lambda: self._pending_mfa_code is not None)
                code = self._pending_mfa_code
            # The code itself is a secret; do not write it to the log.
            logger.info('got code!')
            return code

        # Factory that handles indeterminate, determinate, and counter style.
        def progress_factory(msg, max=0):
            return QtProgress(msg, max, self.on_progress.emit)

        self.mint_client = MintClient(
            email=args.mint_email,
            password=args.mint_password,
            session_path=args.session_path,
            headless=args.headless,
            mfa_method=args.mint_mfa_method,
            wait_for_sync=args.mint_wait_for_sync,
            mfa_input_callback=on_mint_mfa,
            progress_factory=progress_factory)

        results = tagger.create_updates(
            args, self.mint_client,
            on_critical=self.on_error.emit,
            indeterminate_progress_factory=progress_factory,
            determinate_progress_factory=progress_factory,
            counter_progress_factory=progress_factory)

        # A stop() request suppresses publishing the review results.
        if results.success and not self.stopping:
            self.on_review_ready.emit(results)

    def do_send_updates(self, updates, args):
        """Send *updates* to Mint, emit how many were sent, then close the client."""
        num_updates = self.mint_client.send_updates(
            updates,
            progress=QtProgress(
                'Sending updates to Mint',
                len(updates),
                self.on_progress.emit),
            ignore_category=args.no_tag_categories)

        self.on_updates_sent.emit(num_updates)
        self.mint_client.close()
            print("{}:在".format(self.name))
            self.cond.notify()
            print("{}:我们来对古诗吧".format(self.name))


class TianMao(threading.Thread):
    """Thread that opens the two-party dialogue and waits for the reply.

    Holding the shared condition, it speaks first, wakes its partner with
    notify(), then releases the lock inside wait() until the partner
    notifies back, and finally answers.
    """

    def __init__(self, cond):
        super().__init__(name="tianmao")
        self.cond = cond

    def _say(self, text):
        # Every line is prefixed with this thread's name, e.g. "tianmao:...".
        print("{}:{}".format(self.name, text))

    def run(self) -> None:
        with self.cond:
            self._say("小爱同学")
            self.cond.notify()
            self.cond.wait()
            self._say("好啊")


if __name__ == '__main__':
    # Both threads share one Condition to take turns speaking.
    cond = Condition()
    xiaoAi = XiaoAi(cond=cond)
    tianMao = TianMao(cond=cond)

    # 1. The start order matters.
    # 2. wait() and notify() may only be called after entering `with cond`;
    #    the with statement performs acquire() and release() automatically,
    #    otherwise you have to call both yourself.
    # 3. A Condition involves two kinds of locks: the underlying lock, which
    #    is released while a thread is inside wait(), and a per-call waiter
    #    lock that each wait() allocates and places in the condition's wait
    #    queue until notify() wakes it.

    xiaoAi.start()  # xiaoAi must be started first, otherwise notify() cannot wake xiaoAi.
    # NOTE(review): start() only guarantees the thread has begun running, not
    # that it has already reached cond.wait(); starting xiaoAi first makes a
    # lost wakeup unlikely but does not rule it out — confirm against XiaoAi.run.
    tianMao.start()
Example #37
0
 def __init__(self, needed_count):
     """Initialise the pool plus the condition and count used by blocked reads."""
     super(PoolBlockPut, self).__init__()
     self.condition = Condition()  # used for communication between threads
     self.needed_count = needed_count  # amount of data required by the next read
class BlockPublisher(BlockPublisherInterface):
    """Proxy block publisher that defers finalization to an external engine.

    DevMode consensus uses the genesis utility to configure Min/MaxWaitTime,
    which determine when to claim a block.  MinWaitTime defaults to zero and
    MaxWaitTime to 0 or unset; ValidBlockPublishers defaults to None or an
    empty list.  The publisher reads these settings from the StateView in
    initialize_block().
    """
    def __init__(self, block_cache, state_view_factory, batch_publisher,
                 data_dir, config_dir, validator_id):
        super().__init__(block_cache, state_view_factory, batch_publisher,
                         data_dir, config_dir, validator_id)

        self._block_cache = block_cache
        self._state_view_factory = state_view_factory

        self._start_time = 0
        self._wait_time = 0

        # Set these to default values right now, when we asked to initialize
        # a block, we will go ahead and check real configuration
        self._min_wait_time = 0.01
        self._max_wait_time = 0.06
        # list of validators which may participate in consensus
        self._valid_block_publishers = None
        # Consensus name as bytes, set by set_consensus_name(); until then
        # initialize_block() refuses to start a block.
        self._consensus = None
        # Guards _is_finalize_complete, which finalize_block_complete() sets
        # while finalize_block() waits for it.
        self._condition = Condition()
        self._is_finalize_complete = None

    def set_consensus_name(self, name):
        """Register the external consensus name (stored UTF-8 encoded)."""
        self._consensus = bytes(name, 'utf-8')
        LOGGER.debug("PROXY:set_consensus_name=%s->%s", name, self._consensus)

    def set_publisher(self, publisher):
        """Set the object notified through on_finalize_block()."""
        self._publisher = publisher
        LOGGER.debug("PROXY:set_publisher=%s", publisher)

    def initialize_block(self, block_header):
        """Do initialization necessary for the consensus to claim a block,
        this may include initiating voting activities, starting proof of work
        hash generation, or create a PoET wait timer.

        Reloads the wait-time and valid-publisher settings from the state at
        the current chain head and draws this block's random wait time.

        Args:
            block_header (BlockHeader): the BlockHeader to initialize.
        Returns:
            True, or False if no external consensus has been registered.
        """
        if not self._consensus:
            LOGGER.debug(
                "initialize_block: external consensus not registered\n")
            return False
        # Using the current chain head, we need to create a state view so we
        # can get our config values.
        state_view = BlockWrapper.state_view_for_block(
            self._block_cache.block_store.chain_head, self._state_view_factory)

        settings_view = SettingsView(state_view)
        self._min_wait_time = settings_view.get_setting(
            "bgx.consensus.min_wait_time", self._min_wait_time, float)
        self._max_wait_time = settings_view.get_setting(
            "bgx.consensus.max_wait_time", self._max_wait_time, float)
        self._valid_block_publishers = settings_view.get_setting(
            "bgx.consensus.valid_block_publishers",
            self._valid_block_publishers, list)

        block_header.consensus = self._consensus  # b"Devmode"
        self._start_time = time.time()
        self._wait_time = random.uniform(self._min_wait_time,
                                         self._max_wait_time)
        LOGGER.debug(
            "PROXY:initialize_block min_wait_time=%s max_wait_time=%s",
            self._min_wait_time, self._max_wait_time)
        return True

    def check_publish_block(self, block_header):
        """Check if a candidate block is ready to be claimed.

        For many peers we should control the block's content: if this peer
        is not the owner of a batch, it must wait until all batches that the
        owning peer put into its block are put into this peer's block too.

        Args:
            block_header (BlockHeader): the block_header to be checked if it
                should be claimed
        Returns:
            bool: True if the candidate block_header should be claimed.
        """
        if (self._valid_block_publishers
                and block_header.signer_public_key
                not in self._valid_block_publishers):
            return False
        if self._min_wait_time == 0:
            return True
        if self._min_wait_time > 0 and self._max_wait_time <= 0:
            # Only a minimum is configured: wait exactly that long.
            return self._start_time + self._min_wait_time <= time.time()
        if self._min_wait_time > 0 \
                and self._max_wait_time >= self._min_wait_time:
            # Wait the random delay drawn in initialize_block(); when
            # max == min that delay is exactly min_wait_time.
            return self._start_time + self._wait_time <= time.time()
        # Negative minimum, or 0 < max < min: misconfigured, never claim.
        return False

    def finalize_block_complete(self, consensus):
        """Store the engine's finalize reply and wake finalize_block()."""
        with self._condition:
            self._is_finalize_complete = consensus
            self._condition.notify()

    def _finalize_complete(self):
        # Predicate for wait_for(): True once a reply has been stored.
        return self._is_finalize_complete is not None

    def finalize_block(self, block_header):
        """Finalize a block to be claimed. Provide any signatures and
        data updates that need to be applied to the block before it is
        signed and broadcast to the network.

        Notifies the publisher via on_finalize_block() and then blocks until
        finalize_block_complete() delivers the external engine's reply.

        Args:
            block_header (BlockHeader): The candidate block that needs to be
            finalized.
        Returns:
            True once the reply has been received.
        """
        LOGGER.debug(
            "PROXY:finalize_block inform external engine header=%s is_complete=%s",
            block_header.block_num, self._is_finalize_complete)
        # Clear the previous reply BEFORE informing the engine.  Clearing it
        # afterwards would wipe out a reply that arrives quickly, leaving
        # wait_for() below blocked forever.
        with self._condition:
            self._is_finalize_complete = None
        self._publisher.on_finalize_block(block_header)
        # After that the consensus engine should be informed that the block
        # could be finalized, and the engine can say finalize for this
        # candidate.
        # FIXME - for DAG we can say for all ready candidates that their
        # blocks could be finalized and only after that wait for the
        # engine's reply.
        LOGGER.debug(
            "PROXY:finalize_block wait proxy reply via finalize_block_complete...\n"
        )
        with self._condition:
            return self._condition.wait_for(self._finalize_complete)
class ForkResolver(ForkResolverInterface):
    """Provides the fork resolution interface for the BlockValidator to use
    when deciding between 2 forks.
    """

    # pylint: disable=useless-super-delegation

    def __init__(self, block_cache, state_view_factory, data_dir, config_dir,
                 validator_id):
        """Create the proxy fork resolver.

        Arguments are forwarded unchanged to ForkResolverInterface.
        """
        super().__init__(block_cache, state_view_factory, data_dir, config_dir,
                         validator_id)
        self._consensus = bytes(_CONSENSUS_NAME_, 'utf-8')
        # Guards _is_compare_forks: compare_forks_complete() stores the
        # engine's decision and notifies while compare_forks() waits on it.
        self._condition = Condition()
        # Initialised up front so _compare_forks_complete() cannot raise
        # AttributeError if it is consulted before the first compare_forks().
        self._is_compare_forks = None

    @staticmethod
    def hash_signer_public_key(signer_public_key, header_signature):
        """Return SHA-256(signer_public_key || header_signature) as an int.

        Both strings are encoded with the default (UTF-8) codec; the digest
        is interpreted as a big-endian unsigned integer.
        """
        hasher = hashlib.sha256(signer_public_key.encode())
        hasher.update(header_signature.encode())
        return int.from_bytes(hasher.digest(), 'big')

    def _compare_forks_complete(self):
        """Predicate for wait_for(): True once the engine's decision is stored."""
        decision = self._is_compare_forks
        return decision is not None

    def compare_forks_complete(self, result):
        """Record the external engine's fork decision and wake one waiter."""
        LOGGER.debug("PROXY:compare_forks_complete result=%s", result)
        cond = self._condition
        with cond:
            self._is_compare_forks = result
            cond.notify()

    def compare_forks(self, cur_fork_head, new_fork_head):
        """The longest chain is selected. If they are equal, then the hash
        value of the previous block id and publisher signature is computed.
        The lowest result value is the winning block.
        Args:
            cur_fork_head: The current head of the block chain.
            new_fork_head: The head of the fork that is being evaluated.
        Returns:
            bool: True if choosing the new chain head, False if choosing
            the current chain head.
        """
        LOGGER.debug(
            "PROXY:compare_forks cur~new=%s~%s new fork consensus=%s~%s",
            cur_fork_head.identifier[:8], new_fork_head.identifier[:8],
            new_fork_head.consensus, self._consensus)
        # If the new fork head is not DevMode consensus, bail out.  This should
        # never happen, but we need to protect against it.
        if new_fork_head.consensus != self._consensus and new_fork_head.consensus != b"Genesis":
            raise TypeError('New fork head {} is not a {} block'.format(
                new_fork_head.identifier[:8], _CONSENSUS_NAME_))

        self._is_compare_forks = None
        _consensus_notifier.notify_block_valid(new_fork_head.identifier)
        LOGGER.debug(
            "PROXY:compare_forks waiting consensus reply for new head=%s\n",
            new_fork_head.identifier[:8])
        with self._condition:
            if self._condition.wait_for(self._compare_forks_complete):
                if self._is_compare_forks:
                    # send message to external consensus
                    _consensus_notifier.notify_block_commit(
                        new_fork_head.identifier)
                return self._is_compare_forks

        # If the current fork head is not DevMode consensus, check the new fork
        # head to see if its immediate predecessor is the current fork head. If
        # so that means that consensus mode is changing.  If not, we are again
        # in a situation that should never happen, but we need to guard
        # against.
        """
Example #40
0
class PoolBlockGet(PoolBase):
    """Sample pool whose get() blocks until the requested amount is buffered."""

    def __init__(self, needed_count):
        super(PoolBlockGet, self).__init__()
        self.condition = Condition()  # used for communication between threads
        self.needed_count = needed_count  # amount of data a read needs
        # Bumped by release() so blocked readers can tell an explicit release
        # apart from a spurious wakeup.
        self._release_generation = 0

    def put(self, data):
        """Append a 1-D ndarray to the pool.

        Wakes blocked readers once at least ``needed_count`` items are
        buffered.

        Raises:
            TypeError: if *data* is not an ndarray.
            ValueError: if *data* is not 1-D.
        """
        if not isinstance(data, np.ndarray):
            raise TypeError("Input data must be ndarray!")
        if data.ndim != 1:
            raise ValueError("Input data must be 1D!")

        # `with` guarantees the lock is released even if concatenate raises.
        with self.condition:
            self.datas = np.concatenate((self.datas, data))
            if len(self.datas) >= self.needed_count:
                self.condition.notify_all()

    def get(self, count):
        """Remove and return the first *count* items (non-circular read).

        If fewer than *count* items are buffered, blocks until enough arrive
        or release() is called.  After a release, whatever is available
        (possibly fewer than *count* items) is returned.

        Raises:
            ValueError: if *count* is not an int or is negative.
        """
        # Check the type first so a non-int never reaches the comparison.
        if not isinstance(count, int):
            raise ValueError("Input count must be integer!")
        if count < 0:
            raise ValueError("Input count must >= 0!")

        with self.condition:
            if count > len(self.datas):
                # Not enough data yet: wait in the condition's queue until
                # enough exists (or an explicit release).  The predicate loop
                # guards against spurious wakeups and against other readers
                # consuming the data first.
                self.needed_count = count
                generation = self._release_generation
                self.condition.wait_for(
                    lambda: len(self.datas) >= count
                    or self._release_generation != generation)
            chunk = self.datas[:count].copy()
            # Delete only what was actually taken; after a release the pool
            # may hold fewer than `count` items, and deleting out-of-range
            # indices would raise IndexError.
            self.datas = np.delete(self.datas, range(0, len(chunk)))

        return chunk

    def release(self):
        """Wake every blocked get() so it returns with whatever is buffered."""
        with self.condition:
            self._release_generation += 1
            self.condition.notify_all()
Example #41
0
class Rusher(object):
    """
    A class for orchestrating the rush of some yielding callable.

    This class creates self.thread_count worker threads which will be used to
    rush a resource when they are ready.

    ``worker(index, thread_count)`` must return an iterable (typically a
    generator).  During Rusher.rush() each worker thread advances it once to
    get ready and reports in.  Once all the workers are ready they are woken
    up at once to advance it a second time; that value (or None if the
    iterable is exhausted) is the thread's result.
    """
    def __init__(self, worker, thread_count=2):
        self.worker = worker
        self.thread_count = thread_count
        # the orchestrator waits on this until every worker reports ready
        self.ready_progress = Condition()
        self._total_ready = 0
        # this triggers the rush in the worker threads
        self.trigger = Event()
        self._return_list = []
        self._threads = []

    def _create_threads(self):
        """
        Create and start the worker threads so they can get ready to rush.

        Threads left over from a previous rush are joined first.
        """
        self._wait_for_threads()
        # Re-arm the trigger only now that every previous worker has exited,
        # so none of them can be left blocked on trigger.wait().
        self.trigger.clear()
        self._threads = []
        self._return_list = []
        for i in range(self.thread_count):
            thread = Thread(target=self._work, args=(i, ))
            thread.start()
            self._threads.append(thread)

    def _wait_for_threads(self, end_time=None):
        """
        Wait for all worker threads to finish, or until the time.time()
        value *end_time* (when not None) has passed.

        Unfinished threads are not killed.
        """
        for thread in self._threads:
            if end_time is not None:
                max_wait = end_time - time()
                if max_wait < 0:
                    return
            else:
                max_wait = None
            thread.join(max_wait)
            # this is very likely to happen if the timeout tripped
            if thread.is_alive():
                return

    def rush(self, max_seconds=None):
        """
        Create worker threads and trigger their rush once they are ready.

        max_seconds is either None (or 0) for no limiting, or a float.

        Returns (duration, results), where results holds the values of the
        workers that finished within the limit.
        """
        with self.ready_progress:
            self._total_ready = 0
            self._create_threads()
            # Predicate loop: immune to spurious wakeups, and returns at once
            # when thread_count is 0 instead of waiting forever.
            self.ready_progress.wait_for(
                lambda: self._total_ready >= self.thread_count)

        start = time()
        wait_until = time() + max_seconds if max_seconds else None

        self.trigger.set()
        self._wait_for_threads(wait_until)
        # A worker still alive after a timeout may not have reached
        # trigger.wait() yet; clearing now would strand it forever.  In that
        # case _create_threads() clears the trigger before the next rush.
        if not any(thread.is_alive() for thread in self._threads):
            self.trigger.clear()

        results = tuple(self._return_list)
        end = time()
        duration = end - start
        return duration, results

    def _work(self, index):
        """
        Interface with the orchestration in the rush method, to run the worker.

        self.worker is iterated once so it can perform any needed preparation,
        then again when this worker thread is awakened in self.rush() when
        the final worker notifies it that it has completed.

        Each worker is passed a unique index 0 ≤ index < self.thread_count.

        The result of the second iteration is appended to self._return_list.
        """
        try:
            worker_run = iter(self.worker(index, self.thread_count))
            next(worker_run)
        finally:
            # Report in even if preparation failed; otherwise rush() would
            # wait forever for a thread that has already died.
            with self.ready_progress:
                self._total_ready += 1
                if self._total_ready == self.thread_count:
                    self.ready_progress.notify_all()
        # Wait for the trigger to be fired
        self.trigger.wait()
        try:
            result = next(worker_run)
        except StopIteration:
            result = None
        self._return_list.append(result)

    def rush_and_report(self, max_seconds=None, output=sys.stdout):
        """
        Perform a rush and write a summary of results to an output.

        This requires the results of self.worker be hashable.

        Returns (duration, results) from the rush method.
        """
        duration, results = self.rush(max_seconds)
        # This avoids the use of collections.Counter to be 2.6+ compatible
        counts = defaultdict(int)
        for result in results:
            counts[result] += 1
        output.write("{} threads completed in {}, results:\n".format(
            len(results),
            str(timedelta(seconds=duration)).lstrip('0:'),
        ))
        for result, count in counts.items():
            output.write("\t{}: {}\n".format(result, count))

        return (duration, results)
 def __init__(self, block_cache, state_view_factory, data_dir, config_dir,
              validator_id):
     """Initialise the base class, then the condition and consensus name.

     All arguments are forwarded unchanged to the parent initialiser.
     """
     super().__init__(block_cache, state_view_factory, data_dir, config_dir,
                      validator_id)
     # Condition presumably used to hand results between threads.
     self._condition = Condition()
     self._consensus = bytes(_CONSENSUS_NAME_, encoding='utf-8')
Example #43
0
 def __init__(self):
     """Create the work condition and start with no executing tasks."""
     # Presumably guards _num_work_executing — confirm against its users.
     self._work_condition = Condition()
     # How many tasks are currently being executed.
     self._num_work_executing = 0
Example #44
0
    def __init__(self, queue, lock, condition):
        """Keep references to the shared queue, lock and condition."""
        super().__init__()
        self._queue, self._lock, self._condition = queue, lock, condition

    def run(self):
        """Consume forever, taking turns with the other thread via the condition.

        Each round, while holding the condition: if the queue is empty, print
        a notice; otherwise take one item (under the extra lock), notify the
        other thread, wait until notified back, then print the item.  A short
        sleep outside the condition follows every round.
        """
        while True:
            with self._condition:
                if self._queue.empty():
                    print("Queue is empty, please wait!")
                else:
                    self._lock.acquire()
                    item = self._queue.get()
                    self._lock.release()
                    # Hand control over, then block until notified back
                    # (presumably by the producer).
                    self._condition.notify()
                    self._condition.wait()
                    print('Consume ', item)
            time.sleep(1.0e-4)


if __name__ == '__main__':
    # Bounded queue (at most 20 items) shared by producer and consumer.
    q = Queue(maxsize=20)
    # Extra lock handed to both threads; Consumer.run holds it around
    # queue.get().  NOTE(review): if Queue is queue.Queue it is already
    # thread-safe, so this lock is likely redundant — confirm.
    lock = Lock()
    # Thread condition variable used to schedule the two threads
    # (they take turns via notify()/wait()).
    condition = Condition()
    producer = Producer(q, lock, condition)
    consumer = Consumer(q, lock, condition)
    producer.start()
    consumer.start()
Example #45
0
class Job(object):
    """A callable plus arguments whose outcome other threads can wait for.

    Calling the job only runs the function; the outcome is recorded
    separately via complete() or complete_exception(), after which wait()
    and the result property return the result or re-raise the exception.
    """
    __slots__ = '__func', '__args', '__kwargs', '__completed', '__result', '__exception', '__lock', '__cond', '__creation_stack', '__id'

    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise ValueError(func)
        self.__id = str(uuid.uuid4())
        # Non-reentrant lock: while it is held, only the *_unlocked helpers
        # may be used to describe the job (str()/repr() would re-acquire it).
        self.__lock = Lock()
        self.__cond = Condition(self.__lock)
        self.__func, self.__args, self.__kwargs = func, args, kwargs
        # Where the job was created, included in str(job) for debugging.
        self.__creation_stack = tuple(extract_stack())
        self.__completed, self.__result, self.__exception = False, None, None

    def abbrev_info_unlocked(self):
        """One-line description; caller must hold the lock or not need it."""
        return 'Job %s: %s(*%s, **%s)' % (
            self.__id, self.__func, repr(self.__args), repr(self.__kwargs))

    def full_info_unlocked(self):
        """Description plus the formatted creation stack."""
        out = [self.abbrev_info_unlocked(), '\n']
        out.extend(format_list(self.__creation_stack))
        return ''.join(out)

    def __str__(self):
        with self.__lock:
            return self.full_info_unlocked()

    def __repr__(self):
        with self.__lock:
            return self.abbrev_info_unlocked()

    @property
    def stack(self):
        """The stack captured when the job was created."""
        with self.__lock:
            return self.__creation_stack

    def complete(self, result):
        """Record a successful *result* and wake all waiters."""
        with self.__lock:
            self.__completed = True
            self.__result = result
            self.__cond.notify_all()
            LOGGER.debug('Job %s is complete', self.abbrev_info_unlocked())

    def complete_exception(self, exception):
        """Record a failure; waiters and result readers will re-raise it."""
        with self.__lock:
            self.__completed = True
            self.__exception = exception
            self.__cond.notify_all()
            LOGGER.debug('Job %s is complete', self.abbrev_info_unlocked())

    def __call__(self):
        """Run the wrapped function (outside the lock) and return its value."""
        with self.__lock:
            LOGGER.debug('Job %s is being executed',
                         self.abbrev_info_unlocked())
            func, args, kwargs = self.__func, self.__args, self.__kwargs
        return func(*args, **kwargs)

    def _outcome_unlocked(self):
        """Return the result or raise the stored exception; lock must be held.

        Raises:
            RuntimeError: if the job has not completed.
        """
        if not self.__completed:
            # Format with full_info_unlocked(): '%s' % self would call
            # __str__, which re-acquires the held non-reentrant lock and
            # deadlocks.  The message text is identical to str(self).
            raise RuntimeError(
                'Job %s is not complete' % self.full_info_unlocked())
        if self.__exception is not None:
            raise self.__exception
        return self.__result

    @property
    def result(self):
        """The job's result, without waiting.

        Raises:
            RuntimeError: if the job has not completed yet.
            Exception: the exception recorded by complete_exception().
        """
        with self.__lock:
            return self._outcome_unlocked()

    def wait(self, timeout=None):
        """Block until the job completes (or *timeout* seconds pass).

        Returns the result, or re-raises the recorded exception.

        Raises:
            RuntimeError: if the job is still not complete after *timeout*.
        """
        with self.__lock:
            # wait_for re-checks the predicate, so a spurious wakeup cannot
            # end the wait early.
            self.__cond.wait_for(lambda: self.__completed, timeout)
            return self._outcome_unlocked()
class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.

    The public methods read and update state while holding
    ``self._condition``; the same condition is handed to SchedulerIterator
    (presumably so iterators can block until new work or completion is
    signalled via notify_all()).
    """
    def __init__(self, squash_handler, first_state_hash, always_persist):
        # Transactions not yet handed out, in the order they were added.
        self._txn_queue = queue.Queue()
        # TxnInformation for every transaction handed out so far.
        self._scheduled_transactions = []
        # batch signature -> BatchExecutionResult, filled in as batches end.
        self._batch_statuses = {}
        # txn signature -> signature of the batch that contains it.
        self._txn_to_batch = {}
        # Signature of the single transaction currently executing, if any.
        self._in_progress_transaction = None
        # True once finalize() was called; no more batches are accepted.
        self._final = False
        # True once finalized and every batch has a recorded status.
        self._complete = False
        self._cancelled = False
        # Context of the most recent valid transaction; used as the base
        # context of the next scheduled transaction.
        self._previous_context_id = None
        # Context at the end of the most recent fully valid batch.
        self._previous_valid_batch_c_id = None
        # Callable used to squash contexts into a new state root.
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._previous_state_hash = first_state_hash
        # The state hashes here are the ones added in add_batch, and
        # are the state hashes that correspond with block boundaries.
        self._required_state_hashes = {}
        # Set once a merkle root has been computed, so complete() does not
        # compute it a second time.
        self._already_calculated = False
        # Forwarded as `persist` to the squash handler.
        self._always_persist = always_persist

    def __del__(self):
        # On garbage collection, mark cancelled and wake any waiters.
        self.cancel()

    def __iter__(self):
        # Iteration is delegated to SchedulerIterator, which shares our
        # condition variable.
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(self, txn_signature, is_valid,
                                         context_id):
        """Record the outcome of the transaction currently in progress.

        An invalid transaction fails its whole batch.  When the transaction
        is the last of its batch, the batch's final status is recorded (with
        a state hash if add_batch() required one), and the scheduler becomes
        complete if it is finalized and every batch now has a status.

        Args:
            txn_signature (str): must equal the in-progress transaction.
            is_valid (bool): whether the transaction executed successfully.
            context_id (str): context produced by executing the transaction.

        Raises:
            ValueError: if txn_signature is not the in-progress transaction,
                or (after the in-progress marker has been cleared) is not
                part of any added batch.
        """
        with self._condition:
            if (self._in_progress_transaction is None
                    or self._in_progress_transaction != txn_signature):
                raise ValueError(
                    "transaction not in progress: {}".format(txn_signature))
            # Free the single execution slot for next_transaction().
            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError(
                    "transaction not in any batches: {}".format(txn_signature))

            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                # The next transaction builds on top of this one's context.
                self._previous_context_id = context_id

            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=False, state_hash=None)
            if txn_signature in self._last_in_batch:
                # This transaction closes its batch.
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._previous_valid_batch_c_id = self._previous_context_id
                    state_hash = self._calculate_state_root_if_required(
                        batch_id=batch_signature)
                    self._batch_statuses[batch_signature] = \
                        BatchExecutionResult(is_valid=True,
                                             state_hash=state_hash)
                else:
                    # Batch failed: roll the base context back to the end of
                    # the last valid batch, discarding this batch's changes.
                    self._previous_context_id = self._previous_valid_batch_c_id

                is_last_batch = \
                    len(self._batch_statuses) == len(self._last_in_batch)

                if self._final and is_last_batch:
                    self._complete = True
            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None):
        """Queue every transaction of *batch* for execution.

        Args:
            batch: the batch whose transactions are queued in order.
            state_hash (str, optional): state root that must result after
                this batch; triggers a state-root computation for it.

        Raises:
            SchedulerError: if the scheduler has been finalized.
        """
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            batch_signature = batch.header_signature
            if state_hash is not None:
                self._required_state_hashes[batch_signature] = state_hash
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.put(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        """Return the batch's BatchExecutionResult, or None if not finished."""
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def count(self):
        """Return how many transactions have been handed out so far."""
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        """Return the TxnInformation handed out at position *index*."""
        with self._condition:
            return self._scheduled_transactions[index]

    def next_transaction(self):
        """Hand out the next queued transaction as a TxnInformation, or None.

        Returns None while another transaction is still in progress (only
        one may be unapplied at a time) or when no transaction is queued.
        """
        with self._condition:
            if self._in_progress_transaction is not None:
                return None
            try:
                txn = self._txn_queue.get(block=False)
            except queue.Empty:
                return None

            self._in_progress_transaction = txn.header_signature
            base_contexts = [] if self._previous_context_id is None \
                else [self._previous_context_id]
            txn_info = TxnInformation(txn=txn,
                                      state_hash=self._previous_state_hash,
                                      base_context_ids=base_contexts)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def finalize(self):
        """Stop accepting batches; complete now if every batch has a status."""
        with self._condition:
            self._final = True
            if len(self._batch_statuses) == len(self._last_in_batch):
                self._complete = True
            self._condition.notify_all()

    def _compute_merkle_root(self, required_state_root):
        """Computes the merkle root of the state changes in the context
        corresponding with _previous_valid_batch_c_id as applied to
        _previous_state_hash.

        When always_persist is false, the squash is first done without
        persisting and repeated with persist=True only if the result equals
        required_state_root.

        Args:
            required_state_root (str): The merkle root that these txns
                should equal.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context_id, or
                None if no batch has completed validly yet.
        """

        state_hash = None
        if self._previous_valid_batch_c_id is not None:
            # Clean up contexts right away when always persisting or when
            # there is no required root to compare against.
            publishing_or_genesis = self._always_persist or \
                                    required_state_root is None
            state_hash = self._squash(
                state_root=self._previous_state_hash,
                context_ids=[self._previous_valid_batch_c_id],
                persist=self._always_persist,
                clean_up=publishing_or_genesis)
            if self._always_persist is True:
                return state_hash
            if state_hash == required_state_root:
                # Root matched: squash again, this time persisting it.
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_valid_batch_c_id],
                             persist=True,
                             clean_up=True)
        return state_hash

    def _calculate_state_root_if_not_already_done(self):
        """Compute the final state root once, if no batch triggered it yet.

        The root is computed against the required state hash (if any) of
        the last added batch and stored on the result of the last valid
        batch.
        """
        if not self._already_calculated:
            if not self._last_in_batch:
                return
            last_txn_signature = self._last_in_batch[-1]
            batch_id = self._txn_to_batch[last_txn_signature]
            required_state_hash = self._required_state_hashes.get(batch_id)

            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
            # NOTE(review): assumes BatchExecutionResult allows attribute
            # assignment, and that every batch has a status here (true once
            # _complete is set) — otherwise this raises.
            for t_id in self._last_in_batch[::-1]:
                b_id = self._txn_to_batch[t_id]
                if self._batch_statuses[b_id].is_valid:
                    self._batch_statuses[b_id].state_hash = state_hash
                    # found the last valid batch, so break out
                    break

    def _calculate_state_root_if_required(self, batch_id):
        """Compute the state root for *batch_id* only if add_batch() was
        given a required state hash for it; otherwise return None."""
        required_state_hash = self._required_state_hashes.get(batch_id)
        state_hash = None
        if required_state_hash is not None:
            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
        return state_hash

    def complete(self, block):
        """Report whether scheduling has completed.

        Returns False if finalize() has not been called.  Otherwise returns
        True once every batch has a status, computing the final state root
        if still needed; when not yet complete it blocks until completion if
        *block* is true, else returns False.

        NOTE(review): cancel() notifies the condition, but the wait_for
        predicate below ignores _cancelled, so a blocked complete(True)
        keeps waiting after cancellation.
        """
        with self._condition:
            if not self._final:
                return False
            if self._complete:
                self._calculate_state_root_if_not_already_done()
                return True
            if block:
                self._condition.wait_for(lambda: self._complete)
                self._calculate_state_root_if_not_already_done()
                return True
            return False

    def cancel(self):
        """Mark the scheduler cancelled and wake all waiters."""
        with self._condition:
            self._cancelled = True
            self._condition.notify_all()

    def is_cancelled(self):
        """Return True once cancel() has been called."""
        with self._condition:
            return self._cancelled
Example #47
0
class DistributedPipelineRecord:
    """A class for storing a single mini-batch (consisting of multiple micro-batches) as input to
    a single partition.

    Args:
        device: the local device that runs the partition.
        rank: the rank of the partition in the pipeline.
        chunks: number of micro-batches in a mini-batch.
        num_inputs: number of inputs to the partition.
        num_outputs: number of outputs of the partition; ``None`` is treated as 1.
        consumers: list of consumers of outputs of the partition. Each consumer exposes
            ``consumer`` (an RRef pointing to the remote DistributedPipelineRecord of the consuming
            partition for this mini-batch), ``consumer_input_idx`` and ``output_idx``: output number
            ``output_idx`` of this partition is fed as input number ``consumer_input_idx`` of that
            partition.
    """

    # Need to use Union due to https://github.com/python/mypy/issues/7866
    DataConsumer = Union[DataConsumer[rpc.RRef]]

    def __init__(
        self,
        device: torch.device,
        rank: int,
        chunks: int,
        num_inputs: int,
        num_outputs: Optional[int],
        consumers: List[DataConsumer],
    ) -> None:
        self.ready_cv = Condition()
        # Each chunk consists of num_inputs tensors. self.tensors stores these individual tensors.
        self.tensors: List[List[Optional[Tensor]]] = [[None] * num_inputs
                                                      for _ in range(chunks)]
        # For each tensor in self.tensors, we record a cuda event in the corresponding tensorpipe stream in
        # self.recv_events, and later the stream that processes that tensor will wait on that event.
        self.recv_events = [[None] * num_inputs for _ in range(chunks)]
        # Once all num_inputs tensors of a given chunk are received, they are assembled as a batch and stored in
        # self.batches
        self.batches: List[Optional[Batch]] = [None] * chunks
        # For each output of each chunk, collect the phony tensors returned by the consumers' feed() calls
        # (see forward_results); fence() joins them to inject dependencies between chunks in the backward path.
        if num_outputs is None:
            num_outputs = 1
        self.forwarded_phony: List[List[List[rpc.RRef]]] = [
            [[] for _ in range(num_outputs)] for _ in range(chunks)
        ]
        self.consumers = consumers
        self.rank = rank
        self.device = device

    def __getstate__(self) -> Dict:
        # avoid pickling failure.
        return {}

    def feed(self, chunk: int, input_idx: int, input: Tensor) -> Tensor:
        """This function is called remotely to provide individual tensors of a given chunk.

        CPU inputs are moved to ``self.device``. When the input lives on CUDA, an event is recorded on the
        current stream so that sync_stream() can later make the processing stream wait for the transfer.

        Returns:
            The phony tensor forked from the input (used by the producer for backward ordering).
        """
        if input.device.type == "cpu":
            input = input.to(self.device)
        cuda_stream = torch.cuda.current_stream(
            input.device) if input.device.type == "cuda" else None

        with self.ready_cv:
            assert self.tensors[chunk][input_idx] is None
            input, phony = fork(input)
            self.recv_events[chunk][input_idx] = (
                cuda_stream.record_event()
                if cuda_stream is not None else None  # type: ignore
            )
            self.tensors[chunk][input_idx] = input
            # Wake wait_for(), which checks whether the chunk is now complete.
            self.ready_cv.notify_all()
        return phony

    def wait_for(self, chunk: int) -> None:
        """Waits until all elements of the given chunk are populated in self.tensors.
        Then it constructs self.batches[chunk] if it is not constructed yet.
        """
        with self.ready_cv:
            while self.batches[chunk] is None and any(
                    b is None for b in self.tensors[chunk]):
                self.ready_cv.wait()
            if self.batches[chunk] is None:
                tensors = cast(List[Tensor], self.tensors[chunk])
                self.batches[chunk] = Batch(tuple(tensors), chunk)

    def fence(self, chunk: int) -> None:
        """Prepares micro-batches for computation.

        For chunk > 0 (on partitions with rank > 0 that have consumers), joins the phony tensors forwarded
        for the previous chunk into the first tensor of this chunk's batch.
        """
        # Ensure that batches[chunk-1] is executed after batches[chunk] in
        # backpropagation by an explicit dependency.
        # TODO: This dependency injection causes deadlock if this partition
        # gets its input from model input. 1) Figure out why 2) If we need to live
        # with this constraint, replace the condition 'self.rank > 0' below with
        # a more accurate one.
        if chunk != 0 and self.consumers and self.rank > 0:
            batch = self.batches[chunk]
            assert batch is not None
            dependant_tensors = list(batch.tensors)
            for remote_ph_list in self.forwarded_phony[chunk - 1]:
                for remote_ph in remote_ph_list:
                    phony = remote_ph.to_here()
                    dependant_tensors[0] = join(dependant_tensors[0], phony)
            self.batches[chunk] = Batch(tuple(dependant_tensors), chunk)

    def sync_stream(self, chunk: int, stream: torch.cuda.Stream) -> None:
        """Syncs the stream with the cuda events associated with transmission of the chunk to the cuda device."""
        for e in self.recv_events[chunk]:
            if e is not None:
                stream.wait_event(e)

    def forward_results(self, chunk: int) -> None:
        """Forward outputs of processing the chunk in this partition for processing by the next partition."""
        if not self.consumers:
            return
        # The batch does not change while forwarding; look it up once instead of per consumer.
        outputs = self.get_batch(chunk).value
        for consumer in self.consumers:
            v = outputs[consumer.output_idx]
            self.forwarded_phony[chunk][consumer.output_idx].append(
                consumer.consumer.remote().feed(chunk,
                                                consumer.consumer_input_idx,
                                                v))

    def get_batch(self, chunk: int) -> Batch:
        """Return the assembled batch for ``chunk``; it must already exist (see wait_for)."""
        batch = self.batches[chunk]
        assert batch is not None
        return batch
Example #48
0
            else:
                sleep_time = None
        if sleep_time: time.sleep(sleep_time)

        # a new t_id produced
        cv.acquire()
        produced.append(t_id)
        cv.notify()
        cv.release()

        # exit condition
        if not sleep_time: break
        with lock:
            print(f'Thread {t_id}: sleep finished')

    with lock:
        print(f'Thread {t_id}: finished')


print('----- Condition Variable -----')
# `produced` is module-global state shared with the worker threads:
# calc_producer_consumer appends each new t_id to it while holding `cv`
# and then calls cv.notify().
# NOTE(review): the meaning of the initial contents [0, 1, 2, 4] is not
# visible in this excerpt — confirm against the start of calc_producer_consumer.
produced, threads = [0, 1, 2, 4], []
now = time.time()
# One condition variable shared by all workers. `lock` (defined elsewhere)
# appears, in the visible code, to only serialise the workers' print() output.
cv = Condition()
for i in range(4):
    # Start each worker right away; all share the same lock and condition.
    t = Thread(target=calc_producer_consumer, args=(lock, cv))
    t.start()
    threads.append(t)
# Wait for every worker before reporting the elapsed wall-clock time.
for t in threads:
    t.join()
print(f'Total time: {round(time.time() - now, 2)} second(s)\n')
Example #49
0
    def __init__(self, name):
        """Set up the drone approach action server and its ROS clients/subscriptions.

        :param name: robot namespace; it is interpolated into every topic,
            action and service name used below (e.g. "/<name>/ground_truth/state").
        """
        self.robot_name = name

        # Mutual exclusion odometry
        self.odometry_me = Condition()

        # Auxiliary variables. They are initialised BEFORE any rospy.Subscriber
        # is created: rospy may start invoking poseCallback/planCallback from its
        # own threads as soon as a subscription exists. Presumably those callbacks
        # read or set these fields (TODO confirm); assigning the defaults after
        # subscribing could overwrite a value a callback already stored, or let a
        # callback run into a missing attribute.
        self.trajectory = []  # Array with the trajectory to be executed
        self.trajectory_received = False  # Flag to signal trajectory received
        self.odom_received = False  # Flag to signal odom received
        self.collision = False  # Flag for the collision callback

        # Create trajectory server (not auto-started; started at the end of __init__)
        self.trajectory_server = SimpleActionServer(
            'approach_server', ExecuteDroneApproachAction, self.goCallback,
            False)
        self.server_feedback = ExecuteDroneApproachFeedback()
        self.server_result = ExecuteDroneApproachResult()

        # Get client from hector_quadrotor_actions
        self.move_client = SimpleActionClient("/{}/action/pose".format(name),
                                              PoseAction)
        self.move_client.wait_for_server()

        # Subscribe to ground_truth to monitor the current pose of the robot
        rospy.Subscriber("/{}/ground_truth/state".format(name), Odometry,
                         self.poseCallback)

        # Subscribe to topic to receive the planned trajectory
        rospy.Subscriber("/{}/move_group/display_planned_path".format(name),
                         DisplayTrajectory, self.planCallback)

        self.robot = RobotCommander(
            robot_description="{}/robot_description".format(name),
            ns="/{}".format(name))
        self.display_trajectory_publisher = rospy.Publisher(
            '/{}/move_group/display_planned_path'.format(name),
            DisplayTrajectory,
            queue_size=20)

        # Service used by the collision check
        self.validity_srv = rospy.ServiceProxy(
            '/{}/check_state_validity'.format(name), GetStateValidity)
        self.validity_srv.wait_for_service()

        # Set planning algorithm
        # self.move = MoveGroupCommander(PLANNING_GROUP, robot_description="{}/robot_description".format(name), ns="/UAV_1")        # Set group from srdf
        # self.move_group.set_planner_id("RRTConnectkConfigDefault")                      # Set planner type  (RRTConnectkConfigDefault)
        # self.move_group.set_num_planning_attempts(10)                                   # Set planning attempts
        # self.move_group.set_workspace([XMIN,YMIN,ZMIN,XMAX,YMAX,ZMAX])                  # Set the workspace size

        # Start move_group
        self.move_group = MoveGroup('earth', name)
        self.move_group.set_planner()

        # Start planningScenePublisher
        self.scene_pub = PlanningScenePublisher(name)

        # Get current robot position to define as start planning point
        self.current_pose = self.robot.get_current_state()

        # Start trajectory server
        self.trajectory_server.start()
Example #50
0
class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """
    def __init__(self, squash_handler, first_state_hash):
        # Pending transactions, in the order their batches were added.
        self._txn_queue = queue.Queue()
        # TxnInformation for every transaction handed out so far.
        self._scheduled_transactions = []
        # batch header_signature -> BatchExecutionResult
        self._batch_statuses = {}
        # txn header_signature -> batch header_signature
        self._txn_to_batch = {}
        # Signature of the transaction currently handed out, or None.
        self._in_progress_transaction = None
        self._final = False
        self._complete = False
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._last_state_hash = first_state_hash

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(
            self, txn_signature, is_valid, context_id):
        """Record the result of the transaction currently in progress.

        A valid transaction has its context squashed onto the last state
        hash, producing a new state root. An invalid transaction immediately
        fails its batch. If the transaction is the last one in its batch and
        no earlier transaction failed the batch, the batch is marked valid
        with the current state hash.

        Raises:
            ValueError: if ``txn_signature`` is not the transaction currently
                in progress, or does not belong to any added batch.
        """
        with self._condition:
            if (self._in_progress_transaction is None or
                    self._in_progress_transaction != txn_signature):
                # Format the message (it was previously passed as a second
                # argument, leaving a literal "{}" in the error text).
                raise ValueError("transaction not in progress: {}".format(
                    txn_signature))
            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError("transaction not in any batches: {}".format(
                    txn_signature))
            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                # txn is valid, get a new state hash
                self._last_state_hash = self._squash(
                    self._last_state_hash, [context_id])
            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=is_valid, state_hash=None)
            if (txn_signature in self._last_in_batch and
                    batch_signature not in self._batch_statuses):
                # because of the else clause above, txn is valid here
                self._batch_statuses[batch_signature] = BatchExecutionResult(
                    is_valid=is_valid,
                    state_hash=self._last_state_hash)

            if self._final and self._txn_queue.empty():
                self._complete = True
            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None):
        """Queue every transaction of ``batch`` in order.

        ``state_hash`` is accepted for interface compatibility and unused.

        Raises:
            SchedulerError: if the scheduler has already been finalized.
        """
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            batch_signature = batch.header_signature
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.put(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        """Return the BatchExecutionResult for a batch, or None if unknown."""
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def count(self):
        """Return how many transactions have been scheduled so far."""
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        """Return the TxnInformation of the ``index``-th scheduled transaction."""
        with self._condition:
            return self._scheduled_transactions[index]

    def next_transaction(self):
        """Hand out the next transaction, or None.

        Returns None while another transaction is still in progress or when
        no transaction is queued.
        """
        with self._condition:
            if self._in_progress_transaction is not None:
                return None
            try:
                txn = self._txn_queue.get(block=False)
            except queue.Empty:
                return None

            self._in_progress_transaction = txn.header_signature
            txn_info = TxnInformation(txn=txn,
                                      state_hash=self._last_state_hash,
                                      base_context_ids=[])
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def finalize(self):
        """Stop accepting batches; mark complete if nothing is queued."""
        with self._condition:
            self._final = True
            if self._txn_queue.empty():
                self._complete = True
            self._condition.notify_all()

    def complete(self, block):
        """Return True once finalized and complete.

        If ``block`` is True and the scheduler is finalized, wait until it
        becomes complete.
        """
        with self._condition:
            if not self._final:
                return False
            if self._complete:
                return True
            if block:
                self._condition.wait_for(lambda: self._complete)
                return True
            return False
Example #51
0
class _ContextFuture:
    """Controls access to bytes set in the _result variable. The booleans
     that are flipped in set_result, based on whether the value is being set
     from the merkle tree or a direct set on the context manager are needed
     to later determine whether the value was set in that context or was
     looked up as a new address location from the merkle tree and then only
     read from, not set.

    In any context the lifecycle of a _ContextFuture can be several paths:

    Input:
    Address not in base:
      F -----> get from merkle database ----> get from the context
    Address in base:
            |---> set (F)
      F --->|
            |---> get
    Output:
      Doesn't exist ----> set address in context (F)

    Input + Output:
    Address not in base:

                             |-> set
      F |-> get from merkle -|
        |                    |-> get
        |                    |
        |                    |-> noop
        |--> set Can happen before the pre-fetch operation


                     |-> set (F) ---> get
                     |
                     |-> set (F) ----> set
                     |
    Address in base: |-> set (F)
      Doesn't exist -|
                     |-> get Future doesn't exist in context
                     |
                     |-> get ----> set (F)

    """
    def __init__(self, address, result=None, wait_for_tree=False):
        self.address = address
        self._result = result
        # True once the value was set directly in the context (not from the tree).
        self._result_set_in_context = False
        self._condition = Condition()
        # When True, readers block until the tree or the context provides a value.
        self._wait_for_tree = wait_for_tree
        self._tree_has_set = False
        self._read_only = False
        self._deleted = False

    def make_read_only(self):
        """Freeze the future, first waiting for a pending tree value if needed."""
        with self._condition:
            if self._wait_for_tree and not self._result_set_in_context:
                self._condition.wait_for(
                    lambda: self._tree_has_set or self._result_set_in_context)

            self._read_only = True

    def set_in_context(self):
        """Return True if the value was set directly in this context."""
        with self._condition:
            return self._result_set_in_context

    def deleted_in_context(self):
        """Return True if the address was deleted in this context."""
        with self._condition:
            return self._deleted

    def result(self):
        """Return the value at an address, optionally waiting until it is
        set from the context_manager, or set based on the pre-fetch mechanism.

        Returns:
            (bytes): The opaque value for an address.
        """

        # Lock-free fast path: once read-only, set_result() never changes
        # _result again.
        if self._read_only:
            return self._result
        with self._condition:
            if self._wait_for_tree and not self._result_set_in_context:
                self._condition.wait_for(
                    lambda: self._tree_has_set or self._result_set_in_context)
            return self._result

    def set_deleted(self):
        """Mark the address as deleted in this context.

        Takes the condition's lock like every other mutator, so readers such
        as set_in_context()/deleted_in_context() never observe a half-applied
        update.
        """
        with self._condition:
            self._result_set_in_context = False
            self._deleted = True

    def set_result(self, result, from_tree=False):
        """Set the address's value unless the future has been declared
        read only.

        Args:
            result (bytes): The value at an address.
            from_tree (bool): Whether the value is being set by a read from
                the merkle tree.

        Returns:
            None
        """

        if self._read_only:
            if not from_tree:
                LOGGER.warning(
                    "Tried to set address %s on a"
                    " read-only context.", self.address)
            return

        with self._condition:
            # Re-check under the lock: make_read_only() may have run since the
            # unlocked check above.
            if self._read_only:
                if not from_tree:
                    LOGGER.warning(
                        "Tried to set address %s on a"
                        " read-only context.", self.address)
                return
            if from_tree:
                # If the result has not been set in the context, overwrite the
                # value with the value from the merkle tree. Otherwise, do
                # nothing.
                if not self._result_set_in_context:
                    self._result = result
                    self._tree_has_set = True
            else:
                self._result = result
                self._result_set_in_context = True
                self._deleted = False

            self._condition.notify_all()
Example #52
0
 def __init__(self):
     Thread.__init__(self)
     self.loop = None
     self._cond = Condition()
Example #53
0
class SprintDatasetBase(Dataset):
  """
  In Sprint, we use this object for multiple purposes:
  - Multiple epoch handling via SprintInterface.getSegmentList().
    For this, we get the segment list from Sprint and use the Dataset
    shuffling method.
  - Fill in data which we get via SprintInterface.feedInput*().
    Note that each such input doesn't necessarily correspond to a single
    segment. This depends which type of FeatureExtractor is used in Sprint.
    If we use the BufferedFeatureExtractor in utterance mode, we will get
    one call for every segment and we get also segmentName as parameter.
    Otherwise, we will get batches of fixed size - in that case,
    it doesn't correspond to the segments.
    In any case, we use this data as-is as a new seq.
    Because of that, we cannot really know the number of seqs in advance,
    nor the total number of time frames, etc.

  If you want to use this directly in RETURNN, see ExternSprintDataset.
  """

  SprintCachedSeqsMax = 200
  SprintCachedSeqsMin = 100

  def __init__(self, target_maps=None, str_add_final_zero=False, input_stddev=1.,
               orth_post_process=None, bpe=None, orth_vocab=None,
               suppress_load_seqs_print=False,
               **kwargs):
    """
    :param dict[str,str|dict] target_maps: e.g. {"speaker": "speaker_map.txt"}.
      A str value is read as a file; the label on line i is mapped to index i.
    :param bool str_add_final_zero: adds e.g. "orth0" with '\0'-ending
    :param float input_stddev: if != 1, will divide the input "data" by that
    :param str|list[str]|((str)->str)|None orth_post_process: :func:`get_post_processor_function`, applied on orth
    :param None|dict[str] bpe: if given, will be opts for :class:`BytePairEncoding`
    :param None|dict[str] orth_vocab: if given, orth_vocab is applied to orth and orth_classes is an available target
    :param bool suppress_load_seqs_print: less verbose
    """
    super(SprintDatasetBase, self).__init__(**kwargs)
    self.suppress_load_seqs_print = suppress_load_seqs_print
    if target_maps:
      assert isinstance(target_maps, dict)
      target_maps = target_maps.copy()
      for key, tmap in list(target_maps.items()):
        if isinstance(tmap, (str, unicode)):
          # Close the map file deterministically instead of leaking the handle.
          with open(tmap) as map_file:
            tmap = {label: idx for (idx, label) in enumerate(map_file.read().splitlines())}
        assert isinstance(tmap, dict)  # dict[str,int]
        target_maps[key] = tmap
    self.target_maps = target_maps
    self.str_add_final_zero = str_add_final_zero
    self.input_stddev = input_stddev
    # Note: "orth" is actually the raw bytes of the utf8 string,
    # so it does not make quite sense to associate a single str to each byte.
    # However, some other code might expect that the labels are all strings, not bytes,
    # and the API requires the labels to be strings.
    # The code in Dataset.serialize_data tries to decode this case as utf8 (if possible).
    # One label per possible byte value 0..255, matching the "orth" dim of 256 in set_dimensions().
    self.labels["orth"] = [chr(i) for i in range(256)]
    self.orth_post_process = None  # type: typing.Optional[typing.Callable[[str],str]]
    if orth_post_process:
      if callable(orth_post_process):
        self.orth_post_process = orth_post_process
      else:
        from LmDataset import get_post_processor_function
        self.orth_post_process = get_post_processor_function(orth_post_process)
    self.bpe = None
    if bpe:
      from GeneratingDataset import BytePairEncoding
      self.bpe = BytePairEncoding(**bpe)
      self.labels["bpe"] = self.bpe.labels
    self.orth_vocab = None
    if orth_vocab:
      assert not bpe, "bpe has its own vocab"
      from GeneratingDataset import Vocabulary
      self.orth_vocab = Vocabulary(**orth_vocab)
      self.labels["orth_classes"] = self.orth_vocab.labels
    # Shares the dataset lock, so "with self.lock" sections may wait/notify on it.
    self.cond = Condition(lock=self.lock)
    self.add_data_thread_id = thread.get_ident()  # This will be created in the Sprint thread.
    self.ready_for_data = False
    self.reached_final_seq = False
    self.reached_final_seq_seen_all = False
    self.multiple_epochs = False
    self._complete_frac = None
    self.sprintEpoch = None  # in SprintInterface.getSegmentList()
    self.crnnEpoch = None  # in CRNN train thread, Engine.train(). set via init_seq_order
    self.predefined_seq_list_order = None  # via init_seq_order
    self.sprintFinalized = False
    self._target_black_list = []  # if we get non numpy arrays and cannot convert them
    self._reset_cache()
    assert self.shuffle_frames_of_nseqs == 0  # Currently broken. But just use Sprint itself to do this.

  def use_multiple_epochs(self):
    """
    Switch to multi-epoch mode; invoked from SprintInterface.getSegmentList().
    Afterwards set_dimensions() does not start a Sprint epoch by itself.
    """
    self.multiple_epochs = True

  def set_dimensions(self, input_dim, output_dim):
    """
    Register input/output dimensions; invoked via python_train.

    :param int input_dim: feature dimension of one frame, must be > 0
    :param int output_dim: number of classes; "classes" is only registered when > 0
    """
    assert input_dim > 0
    self.num_inputs = input_dim
    self.num_outputs = outputs = {"data": (input_dim * self.window, 2)}
    if output_dim > 0:
      outputs["classes"] = (output_dim, 1)
    if self.bpe:
      outputs["bpe"] = (self.bpe.num_labels, 1)
    if self.orth_vocab:
      outputs["orth_classes"] = (self.orth_vocab.num_labels, 1)
    outputs["orth"] = (256, 1)
    self._base_init()
    # Data may flow from now on. Without Sprint's PythonSegmentOrdering
    # (SprintInterface.getSegmentList()) nobody else starts an epoch, so do it once here.
    if self.multiple_epochs:
      return
    self.init_sprint_epoch(None)

  def _reset_cache(self):
    """
    Forget all cached seqs and load bookkeeping, and mark the dataset ready for new data.
    """
    self.expected_load_seq_start = self.requested_load_seq_end = self.next_seq_to_be_added = 0
    self.reached_final_seq = self.reached_final_seq_seen_all = False
    self._num_timesteps = 0
    self.added_data = []  # type: typing.List[DatasetSeq]
    self.ready_for_data = True

  def init_sprint_epoch(self, epoch):
    """
    Start a new Sprint epoch; called by SprintInterface.getSegmentList().
    This must not be triggered from self.init_seq_order(): Sprint will already have filled
    the cache before the CRNN train thread starts the epoch.

    :param int|None epoch:
    """
    with self.lock:
      self.sprintEpoch, self.sprintFinalized = epoch, False
      self._reset_cache()
      # Wake any thread waiting on the epoch or cache state.
      self.cond.notify_all()

  def finalize_sprint(self):
    """
    Mark the Sprint side as finished; invoked when SprintInterface.getSegmentList() ends.
    """
    lock, cond = self.lock, self.cond
    with lock:
      self.sprintFinalized = True
      cond.notify_all()

  def init_seq_order(self, epoch=None, seq_list=None):
    """
    Enter a new epoch on the CRNN train thread side.

    :param int|None epoch:
    :param list[str]|None seq_list: optional predefined seq order
    :return: always True
    """
    super(SprintDatasetBase, self).init_seq_order(epoch=epoch, seq_list=seq_list)
    with self.lock:
      self.crnnEpoch, self.predefined_seq_list_order = epoch, seq_list
      # Nothing to wait for here: SprintInterface.getSegmentList() waits for us,
      # so just wake it up.
      self.cond.notify_all()
    return True

  def _cleanup_old_seq_cache(self, seq_end):
    """
    Drop the leading cached seqs with index < seq_end.
    Stops at the first seq whose index is >= seq_end; entries after it are kept.

    :param int seq_end:
    """
    keep_from = next(
      (pos for pos, seq in enumerate(self.added_data) if seq.seq_idx >= seq_end),
      len(self.added_data))
    del self.added_data[:keep_from]

  def wait_for_returnn_epoch(self, epoch):
    """
    Called by SprintInterface.
    Block until the RETURNN train thread has entered ``epoch`` (set via init_seq_order()).

    :param int epoch:
    """
    with self.lock:
      while epoch != self.crnnEpoch:
        # crnnEpoch stays None until init_seq_order() ran the first time;
        # "epoch > None" would raise TypeError on Python 3, so treat None as "before any epoch".
        assert self.crnnEpoch is None or epoch > self.crnnEpoch, (
          "%s: waiting for epoch %r but RETURNN is already at epoch %r" % (self, epoch, self.crnnEpoch))
        self.cond.wait()

  def _wait_for_seq_can_pass_check(self, seq_start, seq_end):
    """
    Decide whether _wait_for_seq may stop waiting.

    :param int seq_start:
    :param int seq_end:
    :return: True if the final seq was reached or all seqs [seq_start, seq_end) are cached;
      False means we must wait for the next signal
    :rtype: bool
    """
    return bool(self.reached_final_seq or self._have_seqs_added(seq_start, seq_end))

  def _wait_for_seq(self, seq_start, seq_end=None):
    """
    Called by RETURNN train thread; expected to run with ``self.lock`` held (cond.wait() requires it).
    Block until seqs [seq_start, seq_end) are cached, or until we know that they will
    never be loaded because the final seq was reached.

    :param int seq_start:
    :param int|None seq_end: exclusive end; defaults to seq_start + 1
    """
    if seq_end is None:
      seq_end = seq_start + 1
    if seq_end > self.requested_load_seq_end:
      # Tell the Sprint thread how far ahead we need data.
      self.requested_load_seq_end = seq_end
      self.cond.notify_all()

    def can_pass():
      return self._wait_for_seq_can_pass_check(seq_start=seq_start, seq_end=seq_end)

    if can_pass():
      return
    # Waiting in the thread that adds the data would never finish.
    assert thread.get_ident() != self.add_data_thread_id
    print("%s %s: wait for seqs (%i,%i) (last added: %s) (current time: %s)" % (
      self, currentThread().name, seq_start, seq_end, self._latest_added_seq(), time.strftime("%H:%M:%S")), file=log.v5)
    while not can_pass():
      self.cond.wait()

  def _latest_added_seq(self):
    """
    :return: seq_idx of the most recently cached seq, or None if the cache is empty
    :rtype: int|None
    """
    return self.added_data[-1].seq_idx if self.added_data else None

  def _have_seqs_added(self, start, end=None):
    """
    Check whether all seqs [start, end) are present, consecutively, in the cache.

    :param int start:
    :param int|None end: exclusive end; defaults to start + 1
    :rtype: bool
    """
    if end is None:
      end = start + 1
    if start >= end:
      return True
    wanted = start
    for seq in self.added_data:
      assert wanted >= seq.seq_idx, "%s: We expect that we only ask about the cache of the upcoming seqs." % self
      if seq.seq_idx == wanted:
        wanted += 1
        if wanted >= end:
          return True
    return False

  def _get_seq(self, seq_idx):
    """
    Look up a cached seq by index.

    :param int seq_idx:
    :return: the cached seq, or None if it is not in the cache
    :rtype: DatasetSeq|None
    """
    return next((seq for seq in self.added_data if seq.seq_idx == seq_idx), None)

  def is_cached(self, start, end):
    """
    Never report seqs as cached, so that every load goes through self._load_seqs(),
    which our buffer management relies on.

    :param int start: ignored
    :param int end: ignored
    :rtype: bool
    """
    return False

  def load_seqs(self, start, end):
    """
    Called by RETURNN train thread.
    Loads seqs [start, end) while holding ``self.lock``.

    :param int start:
    :param int end:
    """
    if start == end:
      return
    if not self.suppress_load_seqs_print:
      print("%s load_seqs in %s:" % (self, currentThread().name), start, end, end=' ', file=log.v5)
    with self.lock:
      super(SprintDatasetBase, self).load_seqs(start, end)
      if not self.suppress_load_seqs_print:
        first_seq = self._get_seq(start)
        if first_seq is not None:
          print("first features shape:", first_seq.features["data"].shape, file=log.v5)
        else:
          # _wait_for_seq() also returns once the final seq was reached, so seq `start`
          # may not exist; _get_seq() then returns None instead of a seq.
          print("first seq %i not available" % start, file=log.v5)

  def _load_seqs(self, start, end):
    """
    Called by RETURNN train thread, already with _load_seqs_superset indices.
    ``start`` is expected to increase monotonically across calls for not-yet-loaded data.
    Drops cache entries before ``start`` and then waits until [start, end) is available.

    :param int start:
    :param int end:
    """
    assert start >= self.expected_load_seq_start
    moved_forward = start > self.expected_load_seq_start
    if moved_forward:
      # Everything before `start` is no longer needed; tell the Sprint thread too.
      self._cleanup_old_seq_cache(start)
      self.expected_load_seq_start = start
      self.cond.notify_all()
    self._wait_for_seq(start, end)

  def add_new_data(self, features, targets=None, segment_name=None):
    """
    Adds a new seq.
    This is called via the Sprint main thread.

    :param numpy.ndarray features: format (input-feature,time) (via Sprint)
    :param dict[str,numpy.ndarray|str] targets: format (time) (idx of output-feature)
    :param str|None segment_name:
    :returns the sorted seq index
    :rtype: int
    """

    # is in format (feature,time)
    assert self.num_inputs == features.shape[0]
    num_frames = features.shape[1]
    # must be in format: (time,feature)
    features = features.transpose()
    assert features.shape == (num_frames, self.num_inputs)
    if self.input_stddev != 1:
      features /= self.input_stddev
    if self.window > 1:
      features = self.sliding_window(features)
      assert features.shape == (num_frames, self.num_inputs * self.window)

    if targets is None:
      targets = {}
    if not isinstance(targets, dict):
      targets = {"classes": targets}
    if "classes" in targets:
      # 'classes' is always the alignment
      assert targets["classes"].shape == (num_frames,), (  # is in format (time,)
        "Number of targets %s does not equal to number of features %s" % (targets["classes"].shape, (num_frames,)))
    if "orth" in targets:
      targets["orth"] = targets["orth"].decode("utf8").strip()
    if "orth" in targets and self.orth_post_process:
      targets["orth"] = self.orth_post_process(targets["orth"])
    if self.bpe:
      assert "orth" in targets
      orth = targets["orth"]
      assert isinstance(orth, (str, unicode))
      assert "bpe" not in targets
      targets["bpe"] = numpy.array(self.bpe.get_seq(orth), dtype="int32")
    if self.orth_vocab:
      assert not self.orth_post_process
      assert "orth" in targets
      orth = targets["orth"]
      assert isinstance(orth, (str, unicode))
      assert "orth_classes" not in targets
      targets["orth_classes"] = numpy.array(self.orth_vocab.get_seq(orth), dtype="int32")

    # Maybe convert some targets.
    if self.target_maps:
      for key, target_map in self.target_maps.items():
        assert key in targets
        v = target_map[targets[key]]
        v = numpy.asarray(v)
        if v.ndim == 0:
          v = numpy.zeros((num_frames,), dtype=v.dtype) + v  # add time dimension
        targets[key] = v

    # Maybe remove some targets.
    for key in self._target_black_list:
      if key in targets:
        del targets[key]

    # Check if all targets are valid.
    for key, v in sorted(list(targets.items())):
      if isinstance(v, numpy.ndarray):
        continue  # ok
      if isinstance(v, unicode):
        v = v.encode("utf8")
      if isinstance(v, (str, bytes)):
        if PY3:
          assert isinstance(v, bytes)
          v = list(v)
        else:
          v = list(map(ord, v))
        v = numpy.array(v, dtype="uint8")
        targets[key] = v
        if self.str_add_final_zero:
          v = numpy.append(v, numpy.array([0], dtype=v.dtype))
          assert key + "0" not in targets
          targets[key + "0"] = v
        continue
      print("%s, we will ignore the target %r because it is not a numpy array: %r" % (self, key, v), file=log.v3)
      self._target_black_list += [key]
      del targets[key]

    with self.lock:
      # This gets called in the Sprint main thread.
      # If this is used together with SprintInterface.getSegmentList(), we are always in a state where
      # we just yielded a segment name, thus we are always in a Sprint epoch and thus ready for data.
      assert self.ready_for_data
      assert not self.reached_final_seq
      assert not self.sprintFinalized

      seq_idx = self.next_seq_to_be_added

      if self.predefined_seq_list_order:
        # Note: Only in ExternSprintDataset, we can reliably set the seq order for now.
        assert seq_idx < len(self.predefined_seq_list_order), "seq_idx %i, expected predef num seqs %i" % (
          seq_idx, len(self.predefined_seq_list_order))
        expected_seq_name = self.predefined_seq_list_order[seq_idx]
        if expected_seq_name != segment_name:
          if segment_name in self.predefined_seq_list_order:
            raise Exception("seq_idx %i expected to be tag %r but got tag %r; tag %r is at idx %i" % (
              seq_idx, expected_seq_name, segment_name, segment_name,
              self.predefined_seq_list_order.index(segment_name)))
          raise Exception("seq_idx %i expected to be tag %r but got tag %r; tag %r not found" % (
            seq_idx, expected_seq_name, segment_name, segment_name))

      self.next_seq_to_be_added += 1
      self._num_timesteps += num_frames
      self.cond.notify_all()

      if seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMax:
        print("%s add_new_data: seq=%i, len=%i. Cache filled, waiting to get loaded..." % (
          self, seq_idx, num_frames), file=log.v5)
        while seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMin:
          assert not self.reached_final_seq
          assert seq_idx + 1 == self.next_seq_to_be_added
          self.cond.wait()

      self.added_data += [DatasetSeq(seq_idx, features, targets, seq_tag=segment_name)]
      self.cond.notify_all()
      return seq_idx

  def finish_sprint_epoch(self, seen_all=True):
    """
    Mark the current Sprint epoch as finished.

    Invoked from SprintInterface.getSegmentList() when Sprint asks for the next segment right after an
    epoch ended. Every later self.add_new_data() call therefore belongs to the next epoch.

    :param bool seen_all: stored as self.reached_final_seq_seen_all
    """
    with self.lock:
      self.ready_for_data = False
      self.reached_final_seq_seen_all = seen_all
      self.reached_final_seq = True
      # Wake up everyone blocked on self.cond so they can observe the end of the epoch.
      self.cond.notify_all()

  def _shuffle_frames_in_seqs(self, start, end):
    """
    Frame shuffling is not supported by this dataset; Sprint itself has to do it.

    An explicit raise is used instead of ``assert False`` so the error is not stripped under ``python -O``.

    :param int start:
    :param int end:
    :raises NotImplementedError: always
    """
    raise NotImplementedError("Shuffling in SprintDataset only via Sprint at the moment.")

  def get_num_timesteps(self):
    """
    Total number of frames accumulated by add_new_data() over all added seqs.
    Only valid once the final seq was reached.

    :rtype: int
    """
    with self.lock:
      assert self.reached_final_seq
      total_frames = self._num_timesteps
    return total_frames

  @property
  def num_seqs(self):
    """
    Number of seqs: the length of the predefined seq list if there is one,
    otherwise the number of added seqs once the final seq was reached.

    :rtype: int
    :raises NotImplementedError: if neither is known yet
    """
    with self.lock:
      seq_order = self.predefined_seq_list_order
      if seq_order:
        return len(seq_order)
      if self.reached_final_seq:
        return self.next_seq_to_be_added
      raise NotImplementedError

  def have_seqs(self):
    """
    Whether at least one seq exists; waits for seq 0 if none has been added yet.

    :rtype: bool
    """
    with self.lock:
      if self.next_seq_to_be_added <= 0:
        self._wait_for_seq(0)
      return self.next_seq_to_be_added > 0

  def is_less_than_num_seqs(self, n):
    """
    Wait until seq ``n`` is available or known not to exist, then tell whether it exists.

    :param int n:
    :rtype: bool
    """
    with self.lock:
      self._wait_for_seq(n)
      added_count = self.next_seq_to_be_added
      return n < added_count

  def get_data_keys(self):
    """
    Sorted feature keys of the first added seq, waiting for seq 0 if nothing was added yet.

    :rtype: list[str]
    """
    with self.lock:
      if not self.added_data:
        self._wait_for_seq(0)
      assert self.added_data
      first_seq = self.added_data[0]
      return sorted(first_seq.features.keys())

  def get_target_list(self):
    """
    All data keys except the input key "data".

    :rtype: list[str]
    """
    return [data_key for data_key in self.get_data_keys() if data_key != "data"]

  def set_complete_frac(self, frac):
    """
    Remember the completion fraction. get_complete_frac() interprets it as the
    fraction reached after self.next_seq_to_be_added seqs.

    :param float frac:
    """
    self._complete_frac = frac

  def get_complete_frac(self, seq_idx):
    """
    Estimated fraction of the epoch completed after ``seq_idx``.

    Uses the predefined seq list if there is one, else the value given to set_complete_frac(),
    else the base class estimate.

    :param int seq_idx:
    :rtype: float
    """
    with self.lock:
      seq_order = self.predefined_seq_list_order
      if seq_order:
        return float(seq_idx + 1) / len(seq_order)
      if self._complete_frac is None:
        return super(SprintDatasetBase, self).get_complete_frac(seq_idx)
      num_added = self.next_seq_to_be_added
      if not num_added:
        return self._complete_frac
      # self._complete_frac refers to num_added seqs; scale it down to seq_idx.
      return self._complete_frac * float(seq_idx + 1) / num_added

  def get_seq_length(self, sorted_seq_idx):
    """
    Frame counts of the given seq, waiting until it is available.

    :param int sorted_seq_idx:
    :rtype: Util.NumbersDict
    """
    with self.lock:
      self._wait_for_seq(sorted_seq_idx)
      seq = self._get_seq(sorted_seq_idx)
      return seq.num_frames

  def get_data(self, seq_idx, key):
    """
    Feature array ``key`` of the given seq, waiting until the seq is available.

    :param int seq_idx:
    :param str key:
    :rtype: numpy.ndarray
    """
    with self.lock:
      self._wait_for_seq(seq_idx)
      seq_features = self._get_seq(seq_idx).features
      return seq_features[key]

  def get_input_data(self, sorted_seq_idx):
    """
    The "data" feature array of the given seq, waiting until the seq is available.

    :param int sorted_seq_idx:
    :rtype: numpy.ndarray
    """
    with self.lock:
      self._wait_for_seq(sorted_seq_idx)
      seq_features = self._get_seq(sorted_seq_idx).features
      return seq_features["data"]

  def get_targets(self, target, sorted_seq_idx):
    """
    Target array ``target`` of the given seq, or None if that seq has no such key.

    :param str target:
    :param int sorted_seq_idx:
    :rtype: numpy.ndarray|None
    """
    with self.lock:
      self._wait_for_seq(sorted_seq_idx)
      seq_features = self._get_seq(sorted_seq_idx).features
      return seq_features.get(target)

  def get_ctc_targets(self, sorted_seq_idx):
    """
    CTC targets are not provided by this dataset.

    An explicit raise is used instead of ``assert False`` so the error is not stripped under ``python -O``.

    :param int sorted_seq_idx:
    :raises NotImplementedError: always
    """
    raise NotImplementedError("No CTC targets.")

  def get_tag(self, sorted_seq_idx):
    """
    Segment name (seq tag) of the given seq, waiting until the seq is available.

    :param int sorted_seq_idx:
    :rtype: str
    """
    with self.lock:
      self._wait_for_seq(sorted_seq_idx)
      seq = self._get_seq(sorted_seq_idx)
      return seq.seq_tag
Example #54
0
class Approach(object):
    '''
        Drone "approach" action server.

        For each goal it plans a path with MoveGroup, waits for the planned
        path published on /<name>/move_group/display_planned_path, then sends
        the resulting waypoints one by one to the /<name>/action/pose action
        server. On every feedback message it checks the segment ahead for
        collisions.
    '''

    def __init__(self, name):
        self.robot_name = name

        # Mutual exclusion for self.odometry: poseCallback writes it, and
        # go(), planCallback() and collisionCallback() read it.
        self.odometry_me = Condition()

        # Auxiliary variables. They are initialised before any subscriber is
        # created, so an early callback never sees a missing attribute.
        self.odometry = None  # Latest ground-truth pose (set by poseCallback)
        self.next_pose = None  # Pose the move client is currently heading to
        self.trajectory = []  # Array with the trajectory to be executed
        self.trajectory_received = False  # Flag to signal trajectory received
        self.odom_received = False  # Flag to signal odom received
        self.collision = False  # Set by collisionCallback on a detected collision

        # Create trajectory server (started at the end of __init__)
        self.trajectory_server = SimpleActionServer(
            'approach_server', ExecuteDroneApproachAction, self.goCallback,
            False)
        self.server_feedback = ExecuteDroneApproachFeedback()
        self.server_result = ExecuteDroneApproachResult()

        # Get client from hector_quadrotor_actions
        self.move_client = SimpleActionClient("/{}/action/pose".format(name),
                                              PoseAction)
        self.move_client.wait_for_server()

        # Subscribe to ground_truth to monitor the current pose of the robot
        rospy.Subscriber("/{}/ground_truth/state".format(name), Odometry,
                         self.poseCallback)

        # Subscribe to topic to receive the planned trajectory
        rospy.Subscriber("/{}/move_group/display_planned_path".format(name),
                         DisplayTrajectory, self.planCallback)

        self.robot = RobotCommander(
            robot_description="{}/robot_description".format(name),
            ns="/{}".format(name))
        self.display_trajectory_publisher = rospy.Publisher(
            '/{}/move_group/display_planned_path'.format(name),
            DisplayTrajectory,
            queue_size=20)

        # State-validity service used by collisionCallback
        self.validity_srv = rospy.ServiceProxy(
            '/{}/check_state_validity'.format(name), GetStateValidity)
        self.validity_srv.wait_for_service()

        # Start move_group
        self.move_group = MoveGroup('earth', name)
        self.move_group.set_planner()

        # Start planningScenePublisher
        self.scene_pub = PlanningScenePublisher(name)

        # Robot state used as planning start state; go() overwrites its first
        # transform with the latest odometry.
        self.current_pose = self.robot.get_current_state()

        # Start trajectory server
        self.trajectory_server.start()

    def goCallback(self, pose):
        '''
            Action goal callback: plan and execute a path to pose.goal.

            go() is retried up to 5 times while it returns 'replan' or
            'no_plan'. The action is then set succeeded ('ok'), preempted
            ('preempted') or aborted (anything else).
        '''
        self.target = pose.goal

        rospy.loginfo("Try to start from [{},{},{}]".format(
            self.odometry.position.x, self.odometry.position.y,
            self.odometry.position.z))
        rospy.loginfo("Try to go to [{},{},{}]".format(self.target.position.x,
                                                       self.target.position.y,
                                                       self.target.position.z))

        result = None
        for trial in range(5):
            rospy.logwarn("Attempt {}".format(trial + 1))
            result = self.go(self.target)
            self.collision = False
            if result not in ('replan', 'no_plan'):
                break

        if result == 'ok':
            self.trajectory_server.set_succeeded()
        elif result == 'preempted':
            self.trajectory_server.set_preempted()
        else:
            self.trajectory_server.set_aborted()

        # Kept for backward compatibility; not read anywhere in this class.
        self.trials = 0

    def go(self, target_):
        '''
            Plan and execute the trajectory to target_ once.

            :param target_: goal pose with position and orientation quaternion.
            :return: 'ok', 'preempted', 'aborted', 'replan' or 'no_plan'.
        '''
        position = target_.position
        orientation = target_.orientation
        # Goal as [x, y, z, qx, qy, qz, qw] for move_group
        self.move_group.set_target([
            position.x, position.y, position.z,
            orientation.x, orientation.y, orientation.z, orientation.w])

        # Copy the latest odometry (position and all four quaternion
        # components) into the planning start state.
        with self.odometry_me:
            transform = self.current_pose.multi_dof_joint_state.transforms[0]
            transform.translation.x = self.odometry.position.x
            transform.translation.y = self.odometry.position.y
            transform.translation.z = self.odometry.position.z
            transform.rotation.x = self.odometry.orientation.x
            transform.rotation.y = self.odometry.orientation.y
            transform.rotation.z = self.odometry.orientation.z
            transform.rotation.w = self.odometry.orientation.w

        # Set start state
        self.move_group.set_start_state(self.current_pose)

        # Update PlanningScene
        self.scene_pub.publishScene(self.current_pose)

        # Plan a trajectory till the desired target
        plan = self.move_group.plan()

        if not plan.planned_trajectory.multi_dof_joint_trajectory.points:
            rospy.logerr("Trajectory is empty. Planning was unsuccessful.")
            return 'no_plan'

        # planCallback converts the published plan into self.trajectory
        while not self.trajectory_received:
            rospy.loginfo("Waiting for trajectory!")
            rospy.sleep(0.2)

        if not self.trajectory:
            # Nothing to execute; let goCallback try again instead of crashing.
            self.trajectory_received = False
            self.odom_received = False
            rospy.logerr("Trajectory is empty. Planning was unsuccessful.")
            return 'no_plan'

        # Execute trajectory with action_pose
        last_pose = self.trajectory[0]
        for pose in self.trajectory:

            # Verify preempt call
            if self.trajectory_server.is_preempt_requested():
                self.move_client.send_goal(last_pose)
                self.trajectory_received = False
                self.odom_received = False
                return 'preempted'

            # Send next pose to move
            self.next_pose = pose.target_pose.pose
            self.move_client.send_goal(pose,
                                       feedback_cb=self.collisionCallback)
            self.move_client.wait_for_result()
            state = self.move_client.get_state()

            # Abort if the drone can not reach the position
            if state == GoalStatus.ABORTED:
                self.move_client.send_goal(last_pose)  # Go back to the last pose
                self.trajectory_received = False
                self.odom_received = False
                return 'aborted'
            elif state == GoalStatus.PREEMPTED:
                # E.g. cancelled by collisionCallback. Reset the flags so the
                # next attempt waits for a fresh trajectory instead of reusing
                # this one.
                self.trajectory_received = False
                self.odom_received = False
                return 'replan'
            last_pose = pose
            self.server_feedback.current_pose = self.odometry
            self.trajectory_server.publish_feedback(self.server_feedback)

        # Reset control variables
        self.trajectory_received = False
        self.odom_received = False
        rospy.loginfo("Trajectory is traversed!")
        return 'ok'

    def _make_waypoint(self, translation, orientation):
        '''
            Build a PoseGoal in the "<robot_name>/world" frame at translation
            (object with x, y, z) with orientation given as (x, y, z, w).
        '''
        waypoint = PoseGoal()
        waypoint.target_pose.header.frame_id = "{}/world".format(
            self.robot_name)
        target_pose = waypoint.target_pose.pose
        target_pose.position.x = translation.x
        target_pose.position.y = translation.y
        target_pose.position.z = translation.z
        (target_pose.orientation.x, target_pose.orientation.y,
         target_pose.orientation.z, target_pose.orientation.w) = orientation
        return waypoint

    def planCallback(self, msg):
        '''
            Receive planned trajectories and convert them into PoseGoal waypoints.

            Each waypoint is yawed towards its direction of motion when it
            moves more than RESOLUTION in x or y. Otherwise it keeps the
            planned rotation. The last point of every trajectory is appended
            once more with its planned rotation, so the robot ends at the
            right pose. Ignored until odometry has been received.
        '''
        if not self.odom_received:
            return

        with self.odometry_me:
            odometry = self.odometry

        # Position of the previous waypoint, used to compute the motion direction
        last_x = odometry.position.x
        last_y = odometry.position.y

        trajectory = []
        for t in msg.trajectory:
            points = t.multi_dof_joint_trajectory.points
            if not points:
                continue

            for point in points:
                transform = point.transforms[0]
                delta_x = transform.translation.x - last_x
                delta_y = transform.translation.y - last_y

                # Face the motion direction if the xy movement exceeds RESOLUTION
                if (abs(delta_x) > RESOLUTION) or (abs(delta_y) > RESOLUTION):
                    orientation = quaternion_from_euler(
                        0, 0, atan2(delta_y, delta_x))
                else:
                    rot = transform.rotation
                    orientation = (rot.x, rot.y, rot.z, rot.w)

                trajectory.append(
                    self._make_waypoint(transform.translation, orientation))
                last_x = transform.translation.x
                last_y = transform.translation.y

            # Insert a last point to ensure that the robot ends at the right pose
            final = points[-1].transforms[0]
            rot = final.rotation
            trajectory.append(
                self._make_waypoint(final.translation,
                                    (rot.x, rot.y, rot.z, rot.w)))

        self.trajectory = trajectory
        self.trajectory_received = True

    def poseCallback(self, odometry):
        '''
            Monitor the current position of the robot.
        '''
        with self.odometry_me:
            self.odometry = odometry.pose.pose
        self.odom_received = True

    def collisionCallback(self, feedback):
        '''
            Feedback callback of the move client.

            Samples the straight line from the current odometry position to
            self.next_pose every RESOLUTION and asks the state-validity
            service about each sample. On the first invalid state it cancels
            the current move goal and sets self.collision.
        '''
        validity_msg = GetStateValidityRequest()  # Build message to verify collision
        validity_msg.group_name = PLANNING_GROUP

        if not self.next_pose or self.collision:
            return

        next_position = self.next_pose.position
        with self.odometry_me:
            x = self.odometry.position.x
            y = self.odometry.position.y
            z = self.odometry.position.z

            # Pose to verify collision, keeping the current orientation
            pose = Transform()
            pose.rotation.x = self.odometry.orientation.x
            pose.rotation.y = self.odometry.orientation.y
            pose.rotation.z = self.odometry.orientation.z
            pose.rotation.w = self.odometry.orientation.w

        # Distance between the robot and the next position
        dist = sqrt((next_position.x - x)**2 + (next_position.y - y)**2 +
                    (next_position.z - z)**2)

        # Verify possible collisions at points between the robot and the goal point
        for d in arange(RESOLUTION, dist, RESOLUTION):
            ratio = d / dist
            pose.translation.x = (next_position.x - x) * ratio + x
            pose.translation.y = (next_position.y - y) * ratio + y
            pose.translation.z = (next_position.z - z) * ratio + z

            # Insert the sampled pose into the robot state to check
            self.current_pose.multi_dof_joint_state.transforms[0] = pose
            validity_msg.robot_state = self.current_pose

            # Update PlanningScene
            self.scene_pub.publishScene(self.current_pose)

            # Call service to verify collision
            collision_res = self.validity_srv.call(validity_msg)

            # Check if robot is in collision
            if not collision_res.valid:
                rospy.logwarn('Collision in front [x:{} y:{} z:{}]'.format(
                    pose.translation.x, pose.translation.y,
                    pose.translation.z))
                self.move_client.cancel_goal()
                self.collision = True
                return
 def __init__(self):
     """Start with no frame, an empty in-memory byte buffer and a fresh Condition."""
     self.condition = Condition()
     self.buffer = io.BytesIO()
     self.frame = None
Example #56
0
 def __init__(self, target_maps=None, str_add_final_zero=False, input_stddev=1.,
              orth_post_process=None, bpe=None, orth_vocab=None,
              suppress_load_seqs_print=False,
              **kwargs):
   """
   :param dict[str,str|dict] target_maps: e.g. {"speaker": "speaker_map.txt"}.
     A str value is read as a file path; each line of that file is a label, mapped to its 0-based line index.
   :param bool str_add_final_zero: adds e.g. "orth0" with '\0'-ending
   :param float input_stddev: if != 1, will divide the input "data" by that
   :param str|list[str]|((str)->str)|None orth_post_process: :func:`get_post_processor_function`, applied on orth
   :param None|dict[str] bpe: if given, will be opts for :class:`BytePairEncoding`
   :param None|dict[str] orth_vocab: if given, orth_vocab is applied to orth and orth_classes is an available target
   :param bool suppress_load_seqs_print: less verbose
   """
   super(SprintDatasetBase, self).__init__(**kwargs)
   self.suppress_load_seqs_print = suppress_load_seqs_print
   if target_maps:
     assert isinstance(target_maps, dict)
     target_maps = target_maps.copy()  # do not modify the caller's dict
     for key, tmap in list(target_maps.items()):
       if isinstance(tmap, (str, unicode)):
         # Use a context manager so the map file is closed right away instead of leaking the handle.
         with open(tmap) as tmap_file:
           tmap = {l: i for (i, l) in enumerate(tmap_file.read().splitlines())}
       assert isinstance(tmap, dict)  # dict[str,int]
       target_maps[key] = tmap
   self.target_maps = target_maps
   self.str_add_final_zero = str_add_final_zero
   self.input_stddev = input_stddev
   # Note: "orth" is actually the raw bytes of the utf8 string,
   # so it does not make quite sense to associate a single str to each byte.
   # However, some other code might expect that the labels are all strings, not bytes,
   # and the API requires the labels to be strings.
   # The code in Dataset.serialize_data tries to decode this case as utf8 (if possible).
   # NOTE(review): range(255) covers bytes 0..254 only; byte 0xFF never occurs in valid UTF-8.
   self.labels["orth"] = [chr(i) for i in range(255)]
   self.orth_post_process = None  # type: typing.Optional[typing.Callable[[str],str]]
   if orth_post_process:
     if callable(orth_post_process):
       self.orth_post_process = orth_post_process
     else:
       from LmDataset import get_post_processor_function
       self.orth_post_process = get_post_processor_function(orth_post_process)
   self.bpe = None
   if bpe:
     from GeneratingDataset import BytePairEncoding
     self.bpe = BytePairEncoding(**bpe)
     self.labels["bpe"] = self.bpe.labels
   self.orth_vocab = None
   if orth_vocab:
     assert not bpe, "bpe has its own vocab"
     from GeneratingDataset import Vocabulary
     self.orth_vocab = Vocabulary(**orth_vocab)
     self.labels["orth_classes"] = self.orth_vocab.labels
   # self.cond shares self.lock (presumably created by the base class __init__),
   # so code holding the lock can wait on / notify this condition.
   self.cond = Condition(lock=self.lock)
   self.add_data_thread_id = thread.get_ident()  # This will be created in the Sprint thread.
   self.ready_for_data = False
   self.reached_final_seq = False
   self.reached_final_seq_seen_all = False
   self.multiple_epochs = False
   self._complete_frac = None
   self.sprintEpoch = None  # in SprintInterface.getSegmentList()
   self.crnnEpoch = None  # in CRNN train thread, Engine.train(). set via init_seq_order
   self.predefined_seq_list_order = None  # via init_seq_order
   self.sprintFinalized = False
   self._target_black_list = []  # if we get non numpy arrays and cannot convert them
   self._reset_cache()
   assert self.shuffle_frames_of_nseqs == 0  # Currently broken. But just use Sprint itself to do this.
Example #57
0
    EngineSettingsTable,
    ProcessStepTable,
    ProcessSubscriptionTable,
    ProcessTable,
    SubscriptionTable,
    WorkflowTable,
    db,
)
from orchestrator.services.processes import shutdown_thread_pool
from orchestrator.settings import app_settings
from orchestrator.targets import Target
from orchestrator.workflow import ProcessStatus, done, init, step, workflow
from test.unit_tests.conftest import CUSTOMER_ID
from test.unit_tests.workflows import WorkflowInstanceForTests

# Module-level condition shared with the fixture below: long_running_step blocks in
# test_condition.wait() until the condition is notified.
test_condition = Condition()


@pytest.fixture
def long_running_workflow():
    @step("Long Running Step")
    def long_running_step():
        with test_condition:
            test_condition.wait()
        return {"done": True}

    @workflow("Long Running Workflow")
    def long_running_workflow_py():
        return init >> long_running_step >> long_running_step >> done

    with WorkflowInstanceForTests(long_running_workflow_py,
class Sampler(object):
    """ Sampler used to play, stop and mix multiple sounds.

        .. warning:: A single sampler instance should be used at a time.

    """

    def __init__(self, sr=22050, backend='sounddevice', timeout=1):
        """
        :param int sr: samplerate used - every sound given to :meth:`play` must already have this samplerate, otherwise :meth:`play` raises ValueError (resample it beforehand, e.g. with the Sound.resample method).
        :param str backend: backend used for playing sound. Can be either 'sounddevice' or 'dummy'.
        :param timeout: seconds the play loop waits for the next mixed chunk before it stops, so the backend stream can exit gracefully.
        :raises ValueError: if ``backend`` is neither 'sounddevice' nor 'dummy'.

        """
        self.sr = sr
        self.sounds = []

        self.chunks = Queue(1)  # mixed chunks handed from the producer thread to the stream writer
        self.chunk_available = Condition()  # guards self.sounds; play() notifies it to wake next_chunks()
        self.is_done = Event()  # new event to prevent play to be called again before the sound is actually played
        self.timeout = timeout  # timeout value for graceful exit of the BackendStream

        if backend == 'dummy':
            from .dummy_stream import DummyStream
            self.BackendStream = DummyStream
        elif backend == 'sounddevice':
            from sounddevice import OutputStream
            self.BackendStream = OutputStream
        else:
            raise ValueError("Backend can either be 'sounddevice' or 'dummy'")

        self.play_thread = None

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # play() may never have been called, in which case there is no thread to wait for.
        if self.play_thread is not None:
            self.play_thread.join()

    def play(self, sound, wait=False):
        """ Adds and plays a new Sound to the Sampler.

            :param sound: sound to play
            :param bool wait: if True, block until a sound finishes playing
            :raises ValueError: if the sound's samplerate differs from the sampler's

            .. note:: If the sound is already playing, it will restart from the beginning.

        """
        self.is_done.clear()  # hold is_done until the sound is played
        if self.sr != sound.sr:
            raise ValueError('You can only play sound with a samplerate of {} (here {}). Use the Sound.resample method for instance.'.format(self.sr, sound.sr))

        if sound in self.sounds:
            self.remove(sound)

        with self.chunk_available:
            self.sounds.append(sound)
            sound.playing = True

            # Wake the producer if it is waiting in next_chunks() for something to play.
            self.chunk_available.notify()

        if self.play_thread is None or not self.play_thread.is_alive():
            self.play_thread = Thread(target=self.run)
            self.play_thread.daemon = True
            self.play_thread.start()

        if wait:
            self.is_done.wait()  # wait for the sound to be entirely played

    def remove(self, sound):
        """ Remove a currently played sound. """
        with self.chunk_available:
            sound.playing = False
            self.sounds.remove(sound)

    # Play loop

    def next_chunks(self):
        """ Gets a new chunk from all played sounds and mixes them together.

            Blocks until at least one playing sound yields a chunk. Exhausted sounds are
            dropped and signal is_done. Returns the element-wise mean of the collected chunks.
        """
        with self.chunk_available:
            while True:
                playing_sounds = [s for s in self.sounds if s.playing]

                chunks = []
                for s in playing_sounds:
                    try:
                        chunks.append(next(s.chunks))
                    except StopIteration:
                        s.playing = False
                        self.sounds.remove(s)
                        self.is_done.set()  # sound was played, release is_done to end the wait in play

                if chunks:
                    break

                self.chunk_available.wait()

            return numpy.mean(chunks, axis=0)

    def run(self):
        """ Play loop, i.e. send all sound chunk by chunk to the soundcard.

            Stops once no chunk arrives within ``self.timeout`` seconds.
        """
        self.running = True

        def chunks_producer():
            while self.running:
                self.chunks.put(self.next_chunks())

        t = Thread(target=chunks_producer)
        t.start()

        with self.BackendStream(samplerate=self.sr, channels=1) as stream:
            while self.running:
                try:
                    stream.write(self.chunks.get(timeout=self.timeout))  # timeout so stream.write() thread can exit
                except Empty:
                    self.running = False  # let play_thread exit
Example #59
0
def main():
    """Start a consumer thread, then a producer thread, both sharing one Condition."""
    shared_cond = Condition()
    consumer_thread = Thread(target=consumer, args=(shared_cond,))
    consumer_thread.start()
    producer_thread = Thread(target=producer, args=(shared_cond,))
    producer_thread.start()
        while (wypelniona == 0
               ):  #chroni przed przedwczesnym budzeniem, nie stosować if!!!
            print("czekanie...")
            zm_war.wait()  #otwiera atomowo zamek
        zm_war.release()
        for i in range(0, 10):
            print(i, ": tab = ", tab[i])


class Piszacz(Thread):
    """Writer thread: appends 1/i**2 for i = 1..10 to the global list ``tab``,
    then sets ``wypelniona`` ("filled") and notifies the global condition ``zm_war``."""

    def run(self):
        # Only `wypelniona` is rebound; `tab` is declared global here as well for clarity.
        # `zm_war` is merely read, so it needs no global declaration.
        global tab, wypelniona
        print("Pisarz wlasnie pisze")
        for i in range(1, 11):
            value = 1.0 / i / i
            tab.append(value)
            print(i, "Piszetab = ", value)
        # Set the flag and notify while holding the condition's lock. The `with`
        # block releases the lock even if notify() raises.
        with zm_war:
            wypelniona = 1
            zm_war.notify()


if __name__ == "__main__":
    # Condition shared through module globals: Piszacz sets `wypelniona` and notifies it.
    zm_war = Condition()
    czyt = Czytacz()  # presumably the reader thread that waits on zm_war
    pisz = Piszacz()
    for watek in (czyt, pisz):
        watek.start()