Example #1

# Assumed imports for a runnable excerpt:
import time
from threading import Condition

class Fork:

    def __init__(self):
        self.condition = Condition()
        self.being_used = False

    def use(self, delay):
        with self.condition:
            self.condition.wait_for(self.can_use)
            self.being_used = True
        # sleep outside the lock: holding the condition while "eating"
        # would serialize every caller and make being_used pointless
        time.sleep(delay)
        with self.condition:
            self.being_used = False
            self.condition.notify()

    def can_use(self):
        return not self.being_used
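
A minimal driver for the Fork example above, shown here as a hypothetical sketch (not part of the original source): two philosopher threads contend for a single fork, each blocking in use() until can_use() holds.

from threading import Thread

def philosopher(name, fork):
    for _ in range(3):
        fork.use(0.01)   # blocks until the fork is free, "eats", releases
        print(name, "finished eating once")

fork = Fork()
threads = [Thread(target=philosopher, args=(n, fork)) for n in ("A", "B")]
for t in threads:
    t.start()
for t in threads:
    t.join()
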
Example #2

# Assumed import; ProcessorType and Processor come from the surrounding
# module in the original source.
from threading import Condition

class ProcessorIteratorCollection(object):

    def __init__(self, processor_iterator_class):
        self._processors = {}
        self._proc_iter_class = processor_iterator_class
        self._condition = Condition()

    def __getitem__(self, item):
        """Get a particular ProcessorIterator

        :param item (ProcessorType):
        :return: (Processor)
        """
        with self._condition:
            return self._processors[item].next_processor()

    def __contains__(self, item):
        with self._condition:
            return item in self._processors

    def __setitem__(self, key, value):
        """Map a ProcessorType to a ProcessorIterator. If the key is
        already set, add the processor to the existing iterator.

        :param key (ProcessorType):
        :param value (Processor):
        """
        with self._condition:
            if key not in self._processors:
                proc_iterator = self._proc_iter_class()
                proc_iterator.add_processor(value)
                self._processors[key] = proc_iterator
            else:
                self._processors[key].add_processor(value)
            self._condition.notify_all()

    def __repr__(self):
        return ",".join([repr(k) for k in self._processors.keys()])

    def wait_to_process(self, item):
        with self._condition:
            self._condition.wait_for(lambda: item in self)
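
The processor_iterator_class itself is not shown in the original, so the following hypothetical demo sketches a minimal round-robin stand-in and shows a consumer blocking in wait_to_process() until a matching processor is registered:

from threading import Thread
import itertools

class RoundRobinIterator:
    """Hypothetical minimal stand-in for the processor_iterator_class."""

    def __init__(self):
        self._procs = []
        self._cycle = None

    def add_processor(self, proc):
        self._procs.append(proc)
        self._cycle = itertools.cycle(self._procs)

    def next_processor(self):
        return next(self._cycle)

collection = ProcessorIteratorCollection(RoundRobinIterator)
waiter = Thread(target=collection.wait_to_process, args=("intkey/1.0",))
waiter.start()                       # blocks: nothing registered yet
collection["intkey/1.0"] = "proc-1"  # __setitem__ calls notify_all()
waiter.join()
print(collection["intkey/1.0"])      # proc-1
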
Example #3

from threading import Condition

class Account:
    def __init__(self):
        self.condition = Condition()
        self.balance = 0

    def make_payment(self, amount):
        with self.condition:
            while not self.can_pay(amount):
                self.condition.wait()
            self.balance -= amount
            self.condition.notify()

    def receive_payment(self, amount):
        with self.condition:
            self.condition.wait_for(self.can_receive)
            self.balance += amount
            self.condition.notify()

    def can_pay(self, amount):
        # >= rather than > so that the full balance can be paid out
        return self.balance >= amount

    def can_receive(self):
        return True
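
A hypothetical two-thread sketch of the Account example: the payer blocks in make_payment() until receive_payment() raises the balance far enough and notifies.

from threading import Thread

account = Account()
payer = Thread(target=account.make_payment, args=(50,))
payer.start()                   # blocks: balance is still 0
account.receive_payment(100)    # notify() wakes the payer
payer.join()
print(account.balance)          # 50
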
Example #4

from threading import Condition

class SC(object):
    def __init__(self):
        self.events = []
        self.event_cond = Condition()

    def __call__(self, event):
        with self.event_cond:
            self.events.append(event)
            self.event_cond.notify()

    def wait(self, count):
        with self.event_cond:
            result = self.event_cond.wait_for(lambda: len(self.events) == count, 2)
            assert result, "expected %s events, got %s" % (count, len(self.events))
            return self.events[:]
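
A hypothetical usage sketch of SC as a thread-safe event collector: a worker thread records events while the main thread blocks until both have arrived (or the 2-second timeout trips the assert).

from threading import Thread

collector = SC()
Thread(target=lambda: (collector("a"), collector("b"))).start()
print(collector.wait(2))        # ['a', 'b']
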
Example #5

# Assumed imports; LOGGER, ProcessorType, Processor, NoProcessorVacancyError
# and WaitCancelledException are defined elsewhere in the original module.
from threading import Condition, Event

class ProcessorManager:
    """Contains all of the registered (added via __setitem__)
    transaction processors in a _processors (dict) where the keys
    are ProcessorTypes and the values are ProcessorIterators.
    """

    def __init__(self, processor_iterator_class):
        # bytes: list of ProcessorType
        self._identities = {}
        # ProcessorType: ProcessorIterator
        self._processors = {}
        self._proc_iter_class = processor_iterator_class
        self._condition = Condition()
        self._cancelled_event = Event()

    def __getitem__(self, item):
        """Get a particular ProcessorIterator

        Args:
            item (ProcessorType): The processor type key.
        """
        with self._condition:
            return self._processors[item]

    def __contains__(self, item):
        with self._condition:
            return item in self._processors

    def get_next_of_type(self, processor_type):
        """Get the next available processor of a particular type and increment
        its occupancy counter.

        Args:
            processor_type (ProcessorType): The processor type associated with
                a zmq identity.

        Returns:
            (Processor): Information about the transaction processor
        """
        with self._condition:
            if processor_type not in self:
                self.wait_for_registration(processor_type)
            try:
                processor = self[processor_type].next_processor()
            except NoProcessorVacancyError:
                processor = self.wait_for_vacancy(processor_type)
            processor.inc_occupancy()
            return processor

    def get_all_processors(self):
        processors = []
        for processor in self._processors.values():
            processors += processor.processor_identities()
        return processors

    def __setitem__(self, key, value):
        """Either create a new ProcessorIterator, if none exists for a
        ProcessorType, or add the Processor to the ProcessorIterator.

        Args:
            key (ProcessorType): The type of transactions this transaction
                processor can handle.
            value (Processor): Information about the transaction processor.
        """
        with self._condition:
            if key not in self._processors:
                proc_iterator = self._proc_iter_class()
                proc_iterator.add_processor(value)
                self._processors[key] = proc_iterator
            else:
                self._processors[key].add_processor(value)
            if value.connection_id not in self._identities:
                self._identities[value.connection_id] = [key]
            else:
                self._identities[value.connection_id].append(key)
            self._condition.notify_all()

    def remove(self, processor_identity):
        """Removes all of the Processors for
        a particular transaction processor zeromq identity.

        Args:
            processor_identity (str): The zeromq identity of the transaction
                processor.
        """
        with self._condition:
            processor_types = self._identities.get(processor_identity)
            if processor_types is None:
                LOGGER.warning("transaction processor with identity %s tried "
                               "to unregister but was not registered",
                               processor_identity)
                return
            for processor_type in processor_types:
                if processor_type not in self._processors:
                    LOGGER.warning("processor type %s not a known processor "
                                   "type but is associated with identity %s",
                                   processor_type,
                                   processor_identity)
                    continue
                self._processors[processor_type].remove_processor(
                    processor_identity=processor_identity)
                if not self._processors[processor_type]:
                    del self._processors[processor_type]

    def __repr__(self):
        return ",".join([repr(k) for k in self._processors])

    def wait_for_registration(self, processor_type):
        """Waits for a particular processor type to register, or until
        cancel() is called. Cancellation is checked per wait: only the
        waits currently in flight are aborted, not all future waits for
        this processor_type.

        Args:
            processor_type (ProcessorType): The family and version of
                the transaction processor.

        Returns:
            None

        Raises:
            WaitCancelledException: If cancel() was called while waiting.
        """
        with self._condition:
            self._condition.wait_for(lambda: (
                processor_type in self
                or self._cancelled_event.is_set()))
            if self._cancelled_event.is_set():
                raise WaitCancelledException()

    def wait_for_vacancy(self, processor_type):
        """Waits for a particular processor type to have the capacity to
        handle additional transactions, or until cancel() is called.

        Args:
            processor_type (ProcessorType): The family and version of
                the transaction processor.

        Returns:
            Processor

        Raises:
            WaitCancelledException: If cancel() was called while waiting.
        """

        with self._condition:
            self._condition.wait_for(lambda: (
                self._processor_available(processor_type)
                or self._cancelled_event.is_set()))
            if self._cancelled_event.is_set():
                raise WaitCancelledException()
            processor = self[processor_type].next_processor()
            return processor

    def _processor_available(self, processor_type):
        try:
            self[processor_type].next_processor()
        except NoProcessorVacancyError:
            return False
        return True

    def cancel(self):
        with self._condition:
            self._cancelled_event.set()
            self._condition.notify_all()

    def notify(self):
        with self._condition:
            self._condition.notify_all()
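
The cancellable wait in wait_for_registration() above is a reusable shape: every waiter sleeps on one Condition and re-checks a shared Event after each wakeup. A distilled, self-contained sketch (hypothetical, with RuntimeError standing in for WaitCancelledException):

from threading import Condition, Event, Thread
import time

cond, cancelled = Condition(), Event()
registry = set()                     # stands in for the _processors dict

def wait_for_registration(key):
    with cond:
        cond.wait_for(lambda: key in registry or cancelled.is_set())
        if cancelled.is_set():
            raise RuntimeError("wait cancelled")

def waiter():
    try:
        wait_for_registration("intkey/1.0")
    except RuntimeError as err:
        print("waiter aborted:", err)

t = Thread(target=waiter)
t.start()
time.sleep(0.1)                      # no registration ever arrives
with cond:
    cancelled.set()                  # abort every in-flight wait
    cond.notify_all()
t.join()
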
Example #6

# Assumed imports for this excerpt; `logger` is a module-level logger in
# the original source.
import collections
import logging
import os
from multiprocessing import Manager, Process
from threading import Condition, Lock, Thread
from typing import Any, Callable, Dict, Iterable

logger = logging.getLogger(__name__)

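The executor below mutates a small Result wrapper per task. That class is not included in the original excerpt; the following minimal stand-in is inferred from how map() and submit() use it (ready/value attributes, success()/failure() methods):

class Result:
    """Hypothetical minimal stand-in for the Result wrapper used below."""

    def __init__(self):
        self.ready = False   # set True by submit() once the task is done
        self.value = None    # the function's return value, or None

    def success(self, value):
        self.value = value

    def failure(self):
        self.value = None
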
class ProcessKillingExecutor:
    """
    The ProcessKillingExecutor works like an `Executor <https://docs.python.org/dev/library/concurrent.futures.html#executor-objects>`_
    in that it uses a bunch of processes to execute calls to a function with different arguments asynchronously.

    But other than the `ProcessPoolExecutor <https://docs.python.org/dev/library/concurrent.futures.html#concurrent.futures.ProcessPoolExecutor>`_,
    the ProcessKillingExecutor forks a new Process for each function call that terminates after the function returns or
    when a timeout occurs.

    This means that contrary to the Executors and similar classes provided by the Python Standard Library, you can
    rely on the fact that a process will get killed if a timeout occurs and that absolutely no side effects can occur
    between function calls.

    Note that descendant processes of each process will not be terminated – they will simply become orphaned.
    """
    def __init__(self, max_workers: int = None):
        """
        Initializes a new ProcessKillingExecutor instance.
        :param max_workers: The maximum number of processes that can be used to execute the given calls.
        """
        super().__init__()
        self.max_workers = (os.cpu_count()
                            or 1) if max_workers is None else max_workers
        if self.max_workers <= 0:
            raise ValueError("max_workers must be greater than 0")
        self.manager = Manager()
        self.lock = Lock()
        self.cv = Condition()
        self.worker_count = 0

    def map(self,
            func: Callable,
            iterable: Iterable,
            timeout: float = None,
            callback_timeout: Callable = None,
            daemon: bool = True):
        """
        Returns an iterator (i.e. a generator) equivalent to map(fn, iter).
        :param func: the function to execute
        :param iterable: an iterable of function arguments
        :param timeout: after this time, the process executing the function will be killed if it did not finish
        :param callback_timeout: this function will be called, if the task times out. It gets the same arguments as
                                 the original function
        :param daemon: run the child process as daemon
        :return: An iterator equivalent to: map(func, *iterables) but the calls may be evaluated out-of-order.
        """

        # approach:
        # create a fixed number of threads; all threads share an input queue
        # the queue holds the function params and can be fairly short
        # each thread takes a function param from the queue, starts the process and joins
        # after the process has joined, the thread takes the next element from the queue
        # or terminates, if the shutdown flag was set (this is done by the main thread,
        # when the generator has no more elements)
        # the results are stored in an output queue, which the main thread flushes
        # and yields whenever a new place becomes free in the input queue
        # after the generator is empty, the main thread waits for all threads to terminate
        # and yields the remaining results.
        # issue: input order is not preserved with that approach. May not be that bad for me,
        # but this can be solved...

        # slight modification:
        # - we store the unfinished results in the output queue immediately. the threads do not alter the queues,
        #   they just mutate the result objects. therefore we don't even have to synchronize any queues
        # - whenever the main thread is notified, fill up the input queue, yield the next elements of the output queue,
        #   until it is empty or the next element is unfinished
        # that way we preserve order and don't interrupt processing.
        # Only issue: a long-running task might stall many short tasks, but that cannot be prevented when order has to be preserved...

        params = ({
            'func': func,
            'args': args,
            'timeout': timeout,
            'callback_timeout': callback_timeout,
            'daemon': daemon,
            'result': Result()
        } for args in iterable)

        # supports indexing. not threadsafe. use append() and popleft()
        output_q = collections.deque()

        for thread_kwargs in params:
            # store result wrapper in output queue
            output_q.append(thread_kwargs['result'])
            # start the thread
            self.__worker_count_inc()
            t = Thread(target=self.submit, kwargs=thread_kwargs)
            t.start()
            # yield all results from the output queue that are ready
            while len(output_q) > 0 and output_q[0].ready:
                yield output_q.popleft().value
            # block while the worker limit is reached. wait_for() re-checks
            # the predicate before and after every wait, so the notify/wait
            # race described in the original comment cannot occur here.
            with self.cv:
                self.cv.wait_for(
                    lambda: self.__worker_count_get() < self.max_workers)

        # almost done, wait for threads to finish, then yield the remaining results
        with self.cv:
            self.cv.wait_for(lambda: self.worker_count == 0)
        for result in output_q:
            yield result.value

    def submit(self,
               func: Callable = None,
               args: Any = (),
               kwargs: Dict = None,
               result: 'Result' = None,
               timeout: float = None,
               callback_timeout: Callable[[Any], Any] = None,
               daemon: bool = True):
        """
        Submits a callable to be executed with the given arguments.
        Schedules the callable to be executed as func(*args, **kwargs) in a new process.
        Returns the result, if the process finished successfully, or None, if it fails or a timeout occurs.
        :param func: the function to execute
        :param args: the arguments to pass to the function. Can be one argument or a tuple of multiple args.
        :param kwargs: the kwargs to pass to the function
        :param timeout: after this time, the process executing the function will be killed if it did not finish
        :param callback_timeout: this function will be called with the same arguments, if the task times out.
        :param daemon: run the child process as daemon
        :return: the result of the function, or None if the process failed or timed out
        """
        try:
            args = args if isinstance(args, tuple) else (args, )
            kwargs = kwargs or {}  # avoid sharing a mutable default dict
            shared_dict = self.manager.dict()
            process_kwargs = {
                'func': func,
                'args': args,
                'kwargs': kwargs,
                'share': shared_dict
            }
            p = Process(target=self._process_run,
                        kwargs=process_kwargs,
                        daemon=daemon)
            p.start()
            p.join(timeout=timeout)
            if 'return' in shared_dict:
                if result:
                    result.success(shared_dict['return'])
                return shared_dict['return']
            else:
                if result:
                    result.failure()
                if callback_timeout:
                    callback_timeout(*args, **kwargs)
                if p.is_alive():
                    p.terminate()
                return None
        except Exception:
            logger.error("Process failed due to exception", exc_info=True)
        finally:
            if result:
                result.ready = True
            self.__worker_count_dec()
            with self.cv:
                self.cv.notify()

    @staticmethod
    def _process_run(func: Callable[[Any], Any] = None,
                     args: Any = (),
                     kwargs: Dict = None,
                     share: Dict = None):
        """
        Executes the specified function as func(*args, **kwargs).
        The result will be stored in the shared dictionary
        :param func: the function to execute
        :param args: the arguments to pass to the function
        :param kwargs: the kwargs to pass to the function
        :param share: a dictionary created using Manager.dict()
        """
        result = func(*args, **(kwargs or {}))
        share['return'] = result

    def __worker_count_inc(self):
        with self.lock:
            self.worker_count += 1
            return self.worker_count

    def __worker_count_dec(self):
        with self.lock:
            self.worker_count -= 1
            return self.worker_count

    def __worker_count_get(self):
        with self.lock:
            return self.worker_count
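
A hypothetical usage sketch, assuming the Result stand-in above: the middle call exceeds the timeout, so its process is terminated and None is yielded in its place.

import time

def slow(x):
    time.sleep(x)
    return x

if __name__ == '__main__':
    executor = ProcessKillingExecutor(max_workers=2)
    for value in executor.map(slow, [0.1, 5, 0.2], timeout=1):
        print(value)   # 0.1, then None for the killed call, then 0.2
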
Example #7

# Assumed imports; BaseProcessor and primed() come from the surrounding
# package in the original source.
import types
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from threading import BoundedSemaphore, Condition, Thread
from tqdm import tqdm

class MultiprocProcessor(BaseProcessor):
    """
    Processor to run builders using python multiprocessing
    """
    def __init__(self, builders, num_workers=None):
        # multiprocessing only if mpi is not used, no mixing
        self.num_workers = num_workers
        super(MultiprocProcessor, self).__init__(builders)
        self.logger.info(
            "Building with multiprocessing, {} workers in the pool".format(
                self.num_workers))

    def process(self, builder_id):
        """
        Run the builder using the builtin multiprocessing.

        Args:
            builder_id (int): the index of the builder in the builders list
        """
        self.builder = self.builders[builder_id]
        self.builder.connect()

        cursor = self.builder.get_items()

        self.setup_pbars(cursor)

        self.setup_multithreading()
        self.put_tasks()
        self.clean_up_data()
        self.builder.finalize(cursor)
        self.cleanup_pbars()

    def setup_pbars(self, cursor):
        """
        Sets up progress bars
        """
        total = None

        if isinstance(cursor, types.GeneratorType):
            try:
                cursor = primed(cursor)
                if hasattr(self.builder, "total"):
                    total = self.builder.total
            except StopIteration:
                self.logger.debug("Get items returned empty iterator")

        elif hasattr(cursor, "__len__"):
            total = len(cursor)
        elif hasattr(cursor, "count"):
            total = cursor.count()

        self.get_pbar = tqdm(cursor, desc="Get Items", total=total)
        self.process_pbar = tqdm(desc="Processing Item", total=total)
        self.update_pbar = tqdm(desc="Updating Targets", total=total)

    def cleanup_pbars(self):
        """
        Cleans up the TQDM bars
        """
        self.get_pbar.close()
        self.process_pbar.close()
        self.update_pbar.close()

    def setup_multithreading(self):
        """
        Sets up objects necessary to store and synchronize data in multiprocessing
        """
        self.data = deque()
        self.task_count = BoundedSemaphore(self.builder.chunk_size)
        self.update_data_condition = Condition()

        self.run_update_targets = True
        self.update_targets_thread = Thread(target=self.update_targets)
        self.update_targets_thread.start()

    def put_tasks(self):
        """
        Processes all items from builder using a pool of processes
        """
        # 1.) setup a process pool
        with ProcessPoolExecutor(self.num_workers) as executor:
            # 2.) Loop over every item wrapped in a tqdm bar
            for item in self.get_pbar:
                # 3.) Limit total number of queues tasks using a semaphore
                self.task_count.acquire()
                # 4.) Submit a task to processing pool
                f = executor.submit(self.builder.process_item, item)
                # 5.) Add call back to update our data list
                f.add_done_callback(self.update_data_callback)

    def clean_up_data(self):
        """
        Updates targets with remaining data and then cleans up the data collection
        """
        try:
            with self.update_data_condition:
                self.run_update_targets = False
                self.update_data_condition.notify_all()
        except Exception as e:
            self.logger.debug(
                "Problem in updating targets at end of builder run: {}".format(
                    e))

        self.update_targets_thread.join()

    def update_data_callback(self, future):
        """
        Call back to add data into a list in thread safe manner and signal other threads to add more tasks or update_targets
        """
        with self.update_data_condition:
            self.process_pbar.update(1)
            self.data.append(future.result())
            self.update_data_condition.notify_all()

        self.task_count.release()

    def update_targets(self):
        """
        Thread to update targets periodically
        """

        while self.run_update_targets:
            with self.update_data_condition:
                self.update_data_condition.wait_for(
                    lambda: not self.run_update_targets or len(
                        self.data) > self.builder.chunk_size)
                try:
                    if self.data:  # the deque is never None; flush only if non-empty
                        self.update_pbar.unpause()
                        self.builder.update_targets(self.data)
                        self.update_pbar.update(len(self.data))
                        self.data.clear()
                except Exception as e:
                    self.logger.debug(
                        "Problem in updating targets in builder run: {}".
                        format(e))
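
The synchronization core above (a flusher thread sleeping on a Condition until enough results accumulate or shutdown is signalled) can be distilled into a self-contained sketch:

from collections import deque
from threading import Condition, Thread

CHUNK = 4
data = deque()
cond = Condition()
running = True

def flusher():
    # mirrors update_targets(): sleep until enough data or shutdown
    while True:
        with cond:
            cond.wait_for(lambda: not running or len(data) > CHUNK)
            if data:
                print("flushing", len(data), "items")
                data.clear()
            if not running:
                return

t = Thread(target=flusher)
t.start()
for i in range(10):
    with cond:
        data.append(i)       # in the real code a pool callback appends
        cond.notify_all()
with cond:
    running = False          # triggers the final flush and shutdown
    cond.notify_all()
t.join()
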
Example #8

# Assumed imports; DiskCache and mkdirSafely come from the surrounding
# package in the original source.
import logging
import os
import sqlite3
import time
from threading import Condition, Lock, Thread

class DBDiskCache(DiskCache):
    @property
    def map_id(self):
        return self.__map_desc.map_id

    def __init__(self, cache_dir, map_desc, db_schema, is_concurrency=True):
        self.__map_desc = map_desc
        self.__db_path = os.path.join(cache_dir, map_desc.map_id + ".mbtiles")
        self.__conn = None

        #configs
        self.__db_schema = db_schema
        self.__has_timestamp = True

        self.__is_concurrency = is_concurrency

        if is_concurrency:
            # the surrogate thread does all DB operations, since a sqlite3
            # connection may only be used from the thread that created it
            self.__surrogate = None

            self.__is_closed = False

            # concurrency get/put
            self.__sql_queue = []
            self.__sql_queue_lock = Lock()
            self.__sql_queue_cv = Condition(self.__sql_queue_lock)

            self.__get_lock = Lock()  # serializes concurrent 'get' calls

            self.__get_response = None  # the pair (data, exception)
            self.__get_response_lock = Lock()
            self.__get_response_cv = Condition(self.__get_response_lock)

    def __initDB(self):
        def getBoundsText(map_desc):
            left, bottom = map_desc.lower_corner
            right, top = map_desc.upper_corner
            # OpenLayers Bounds format
            return "%f,%f,%f,%f" % (left, bottom, right, top)

        desc = self.__map_desc
        conn = self.__conn

        # metadata
        meta_create_sql = "CREATE TABLE metadata(name TEXT PRIMARY KEY, value TEXT)"
        meta_data_sqls = (
            "INSERT INTO metadata(name, value) VALUES('%s', '%s')" %
            ('name', desc.map_id),
            "INSERT INTO metadata(name, value) VALUES('%s', '%s')" %
            ('type', 'overlayer'),
            "INSERT INTO metadata(name, value) VALUES('%s', '%s')" %
            ('version', '1.0'),
            "INSERT INTO metadata(name, value) VALUES('%s', '%s')" %
            ('description', desc.map_title),
            "INSERT INTO metadata(name, value) VALUES('%s', '%s')" %
            ('format', desc.tile_format),
            "INSERT INTO metadata(name, value) VALUES('%s', '%s')" %
            ('bounds', getBoundsText(desc)),
            "INSERT INTO metadata(name, value) VALUES('%s', '%s')" %
            ('schema', self.__db_schema),
        )
        #tiles
        tiles_create_sql = "CREATE TABLE tiles("
        tiles_create_sql += "zoom_level  INTEGER, "
        tiles_create_sql += "tile_column INTEGER, "
        tiles_create_sql += "tile_row    INTEGER, "
        tiles_create_sql += "tile_data   BLOB     NOT NULL, "
        tiles_create_sql += "timestamp   INTEGER  NOT NULL, "
        tiles_create_sql += "PRIMARY KEY (zoom_level, tile_column, tile_row))"

        #tiles_idx
        tiles_idx_create_sql = "CREATE INDEX tiles_idx on tiles(zoom_level, tile_column, tile_row)"

        #exec
        conn.execute(meta_create_sql)
        conn.execute(tiles_create_sql)
        conn.execute(tiles_idx_create_sql)
        for sql in meta_data_sqls:
            conn.execute(sql)
        conn.commit()

    def __getMetadata(self, name):
        try:
            sql = 'SELECT value FROM metadata WHERE name="%s"' % (name, )
            cursor = self.__conn.execute(sql)
            row = cursor.fetchone()
            data = None if row is None else row[0]
            return data
        except Exception as ex:
            logging.warning('[%s] Get mbtiles metadata error: %s' %
                            (self.map_id, str(ex)))
        return None

    def __tableHasColumn(self, tbl_name, col_name):
        try:
            sql = "PRAGMA table_info(%s)" % tbl_name
            cursor = self.__conn.execute(sql)
            rows = cursor.fetchall()

            for row in rows:
                if row[1] == col_name:
                    return True

        except Exception as ex:
            logging.warning(
                "[%s] detect table '%s' has column '%s' error: %s" %
                (self.map_id, tbl_name, col_name, str(ex)))
        return False

    def __readConfig(self):
        # db schema from metadata
        schema = self.__getMetadata('schema')
        if schema and self.__db_schema != schema:
            logging.info("[%s] Reset db schema from %s to %s" %
                         (self.map_id, self.__db_schema, schema))
            self.__db_schema = schema

        self.__has_timestamp = self.__tableHasColumn("tiles", "timestamp")

    # the real actions, executed by the surrogate thread
    def __start(self):
        if not os.path.exists(self.__db_path):
            logging.info("[%s] Initializing local cache DB..." %
                         (self.map_id, ))
            mkdirSafely(os.path.dirname(self.__db_path))
            self.__conn = sqlite3.connect(self.__db_path)
            self.__initDB()
        else:
            self.__conn = sqlite3.connect(self.__db_path)
            self.__readConfig()

        logging.info("[%s][Config] db schema: %s" %
                     (self.map_id, self.__db_schema))
        logging.info("[%s][Config] suuport tile timestamp: %s" %
                     (self.map_id, self.__has_timestamp))

    def __close(self):
        logging.info("[%s] Closing local cache DB..." % (self.map_id, ))
        self.__conn.close()

    @classmethod
    def flipY(cls, y, level):
        return (1 << level) - 1 - y

    def __put(self, level, x, y, data):
        #sql
        if self.__db_schema == 'tms':
            y = self.flipY(y, level)

        sql = None
        if self.__has_timestamp:
            sql = "INSERT OR REPLACE INTO tiles(zoom_level, tile_column, tile_row, tile_data, timestamp)"
            sql += " VALUES(%d, %d, %d, ?, %d)" % (level, x, y, int(
                time.time()))
        else:
            sql = "INSERT OR REPLACE INTO tiles(zoom_level, tile_column, tile_row, tile_data)"
            sql += " VALUES(%d, %d, %d, ?)" % (level, x, y)

        #query
        try:
            self.__conn.execute(sql, (data, ))
            self.__conn.commit()
            logging.info("[%s] %s [OK]" % (self.map_id, sql))
        except Exception as ex:
            logging.info("[%s] %s [Fail]" % (self.map_id, sql))
            raise ex

    def __get(self, level, x, y):
        #sql
        if self.__db_schema == 'tms':
            y = self.flipY(y, level)

        cols = "tile_data, timestamp" if self.__has_timestamp else "tile_data"
        sql = "SELECT %s FROM tiles WHERE zoom_level=%d AND tile_column=%d AND tile_row=%d" % \
                (cols, level, x, y,)

        row = None
        try:
            #query
            cursor = self.__conn.execute(sql)
            row = cursor.fetchone()
        except Exception as ex:
            logging.info("[%s] %s [Fail]" % (self.map_id, sql))
            raise ex

        #result (tile, timestamp)
        if row is None:
            logging.info("[%s] %s [NA]" % (self.map_id, sql))
            return (None, None)
        elif self.__has_timestamp:
            logging.info("[%s] %s [OK][TS]" % (self.map_id, sql))
            return row
        else:
            logging.info("[%s] %s [OK]" % (self.map_id, sql))
            return (row[0], None)

    # the interface called by the user
    def start(self):
        if not self.__is_concurrency:
            self.__start()
        else:
            self.__surrogate = Thread(target=self.__runSurrogate)
            self.__surrogate.start()

    def close(self):
        if not self.__is_concurrency:
            self.__close()
        else:
            with self.__sql_queue_cv:
                self.__is_closed = True
                self.__sql_queue_cv.notify()
            self.__surrogate.join()

    def put(self, level, x, y, data):
        if not self.__is_concurrency:
            self.__put(level, x, y, data)
        else:
            with self.__sql_queue_cv:
                item = (level, x, y, data)
                self.__sql_queue.append(item)
                self.__sql_queue_cv.notify()

    def get(self, level, x, y):
        if not self.__is_concurrency:
            return self.__get(level, x, y)
        else:

            def has_response():
                return self.__get_response is not None

            with self.__get_lock:  # blocks back-to-back get() calls
                # request the tile
                with self.__sql_queue_cv:
                    item = (level, x, y, None)
                    self.__sql_queue.insert(0, item)  # serve 'get' requests first
                    self.__sql_queue_cv.notify()

                # wait for the response
                res = None
                with self.__get_response_cv:
                    self.__get_response_cv.wait_for(has_response)
                    res, self.__get_response = self.__get_response, res  # swap: pop the response of get()

                # return data
                data, ex = res
                if ex:
                    raise ex
                return data

    #the Surrogate thread
    def __runSurrogate(self):
        def has_sql_events():
            return self.__is_closed or len(self.__sql_queue)

        self.__start()
        try:
            while True:
                #wait events
                item = None
                with self.__sql_queue_cv:
                    self.__sql_queue_cv.wait_for(has_sql_events)
                    if self.__is_closed:
                        return
                    item = self.__sql_queue.pop(0)

                level, x, y, data = item
                # put data ('get' requests carry data=None)
                if data is not None:
                    try:
                        self.__put(level, x, y, data)
                    except Exception as ex:
                        logging.error("[%s] DB put data error: %s" %
                                      (self.map_id, str(ex)))
                # get data
                else:
                    res_data, res_ex = None, None
                    try:
                        res_data = self.__get(level, x, y)
                        res_ex = None
                    except Exception as ex:
                        logging.error("[%s] DB get data error: %s" %
                                      (self.map_id, str(ex)))
                        res_data = None
                        res_ex = ex

                    # notify the waiting get() caller
                    with self.__get_response_cv:
                        self.__get_response = (res_data, res_ex)
                        self.__get_response_cv.notify()

        finally:
            self.__close()
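
get() above hands a request to the surrogate thread through one Condition and blocks on a second Condition for the response. The same request/response handoff, distilled into a hypothetical self-contained sketch:

from threading import Condition, Thread

requests, response = [], None
req_cv, resp_cv = Condition(), Condition()

def worker():
    global response
    while True:
        with req_cv:
            req_cv.wait_for(lambda: requests)
            item = requests.pop(0)
        if item is None:                 # shutdown sentinel
            return
        with resp_cv:
            response = item * 2          # stands in for the DB lookup
            resp_cv.notify()

t = Thread(target=worker)
t.start()
with req_cv:
    requests.append(21)
    req_cv.notify()
with resp_cv:
    resp_cv.wait_for(lambda: response is not None)
    print(response)                      # 42
with req_cv:
    requests.append(None)
    req_cv.notify()
t.join()
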
Example #9

from threading import Condition, Lock

class BlockingQueue:
    '''
    Fetching the next item blocks until someone calls append() or close():
    in the first case the appended value is returned, in the second
    StopIteration is raised.

    >>> from time import sleep
    >>> from threading import Thread
    >>> q = BlockingQueue()
    >>> results = []
    >>> def foo():
    ...     for i in q:
    ...         results.append(i)
    ...     results.append('exited')
    ...
    >>> thread = Thread(target=foo)
    >>> thread.start()
    >>> sleep(.1) # making sure thread blocked
    >>> results
    []
    >>> q.append(99)
    >>> while len(results) != 1: pass
    >>> results
    [99]
    >>> sleep(.1)
    >>> results
    [99]
    >>> q.append(199)
    >>> while len(results) != 2: pass
    >>> results
    [99, 199]
    >>> sleep(.1)
    >>> results
    [99, 199]
    >>> q.close()
    >>> while len(results) != 3: pass
    >>> results
    [99, 199, 'exited']
    '''
    def __init__(self):
        self._storage = []
        self._open = True
        self._lock = Lock()
        self._condition = Condition(self._lock)

    def __iter__(self):
        return self

    def __next__(self):
        with self._condition:
            self._condition.wait_for(lambda: self._storage or not self._open)
            if self._storage:
                return self._storage.pop(0)
            elif not self._open:
                raise StopIteration

    def append(self, obj):
        with self._condition:
            if self._open:
                self._storage.append(obj)
                self._condition.notify()
            else:
                raise RuntimeError("cannot append to closed queue")

    def close(self):
        with self._condition:
            if not self._open:
                raise RuntimeError("already closed")

            self._open = False
            # wake every waiting consumer so each can raise StopIteration
            self._condition.notify_all()

Example #10

# Assumed imports; Scheduler, SchedulerIterator, SchedulerError,
# TxnInformation and BatchExecutionResult come from the surrounding
# package in the original source.
import queue
from threading import Condition

class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """
    def __init__(self, squash_handler, first_state_hash, always_persist):
        self._txn_queue = queue.Queue()
        self._scheduled_transactions = []
        self._batch_statuses = {}
        self._txn_to_batch = {}
        self._in_progress_transaction = None
        self._final = False
        self._complete = False
        self._cancelled = False
        self._previous_context_id = None
        self._previous_valid_batch_c_id = None
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._previous_state_hash = first_state_hash
        # The state hashes here are the ones added in add_batch, and
        # are the state hashes that correspond with block boundaries.
        self._required_state_hashes = {}
        self._already_calculated = False
        self._always_persist = always_persist

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(
            self, txn_signature, is_valid, context_id):
        with self._condition:
            if (self._in_progress_transaction is None or
                    self._in_progress_transaction != txn_signature):
                raise ValueError("transaction not in progress: {}".format(
                                 txn_signature))
            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError("transaction not in any batches: {}".format(
                    txn_signature))

            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                self._previous_context_id = context_id

            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=False, state_hash=None)
            if txn_signature in self._last_in_batch:
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._previous_valid_batch_c_id = self._previous_context_id
                    state_hash = self._calculate_state_root_if_required(
                        batch_id=batch_signature)
                    self._batch_statuses[batch_signature] = \
                        BatchExecutionResult(is_valid=True,
                                             state_hash=state_hash)
                else:
                    self._previous_context_id = self._previous_valid_batch_c_id

                is_last_batch = \
                    len(self._batch_statuses) == len(self._last_in_batch)

                if self._final and is_last_batch:
                    self._complete = True
            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None):
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            batch_signature = batch.header_signature
            if state_hash is not None:
                self._required_state_hashes[batch_signature] = state_hash
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.put(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def count(self):
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_transactions[index]

    def next_transaction(self):
        with self._condition:
            if self._in_progress_transaction is not None:
                return None
            try:
                txn = self._txn_queue.get(block=False)
            except queue.Empty:
                return None

            self._in_progress_transaction = txn.header_signature
            base_contexts = [] if self._previous_context_id is None \
                else [self._previous_context_id]
            txn_info = TxnInformation(txn=txn,
                                      state_hash=self._previous_state_hash,
                                      base_context_ids=base_contexts)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def finalize(self):
        with self._condition:
            self._final = True
            if len(self._batch_statuses) == len(self._last_in_batch):
                self._complete = True
            self._condition.notify_all()

    def _compute_merkle_root(self, required_state_root):
        """Computes the merkle root of the state changes in the context
        corresponding with _last_valid_batch_c_id as applied to
        _previous_state_hash.

        Args:
            required_state_root (str): The merkle root that these txns
                should equal.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context_id
        """

        state_hash = None
        if self._previous_valid_batch_c_id is not None:
            publishing_or_genesis = self._always_persist or \
                                    required_state_root is None
            state_hash = self._squash(
                state_root=self._previous_state_hash,
                context_ids=[self._previous_valid_batch_c_id],
                persist=self._always_persist, clean_up=publishing_or_genesis)
            if self._always_persist is True:
                return state_hash
            if state_hash == required_state_root:
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_valid_batch_c_id],
                             persist=True, clean_up=True)
        return state_hash

    def _calculate_state_root_if_not_already_done(self):
        if not self._already_calculated:
            if not self._last_in_batch:
                return
            last_txn_signature = self._last_in_batch[-1]
            batch_id = self._txn_to_batch[last_txn_signature]
            required_state_hash = self._required_state_hashes.get(
                batch_id)

            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
            for t_id in self._last_in_batch[::-1]:
                b_id = self._txn_to_batch[t_id]
                if self._batch_statuses[b_id].is_valid:
                    self._batch_statuses[b_id].state_hash = state_hash
                    # found the last valid batch, so break out
                    break

    def _calculate_state_root_if_required(self, batch_id):
        required_state_hash = self._required_state_hashes.get(
            batch_id)
        state_hash = None
        if required_state_hash is not None:
            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
        return state_hash

    def complete(self, block):
        with self._condition:
            if not self._final:
                return False
            if self._complete:
                self._calculate_state_root_if_not_already_done()
                return True
            if block:
                self._condition.wait_for(lambda: self._complete)
                self._calculate_state_root_if_not_already_done()
                return True
            return False

    def cancel(self):
        with self._condition:
            self._cancelled = True
            self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
Example #11

# Assumed imports for a runnable excerpt:
import logging
from threading import Condition

LOGGER = logging.getLogger(__name__)

class _ContextFuture(object):
    """Controls access to bytes set in the _result variable. The booleans
     that are flipped in set_result, based on whether the value is being set
     from the merkle tree or a direct set on the context manager are needed
     to later determine whether the value was set in that context or was
     looked up as a new address location from the merkle tree and then only
     read from, not set.

    In any context the lifecycle of a _ContextFuture can be several paths:

    Input:
    Address not in base:
      F -----> get from merkle database ----> get from the context
    Address in base:
            |---> set (F)
      F --->|
            |---> get
    Output:
      Doesn't exist ----> set address in context (F)

    Input + Output:
    Address not in base:

                             |-> set
      F |-> get from merkle -|
        |                    |-> get
        |                    |
        |                    |-> noop
        |--> set (can happen before the pre-fetch operation)


                     |-> set (F) ---> get
                     |
                     |-> set (F) ----> set
                     |
    Address in base: |-> set (F)
      Doesn't exist -|
                     |-> get (future doesn't exist in context)
                     |
                     |-> get ----> set (F)

    """
    def __init__(self, address, result=None, wait_for_tree=False):
        self.address = address
        self._result = result
        self._result_set_in_context = False
        self._condition = Condition()
        self._wait_for_tree = wait_for_tree
        self._tree_has_set = False
        self._read_only = False
        self._deleted = False

    def make_read_only(self):
        with self._condition:
            if self._wait_for_tree and not self._result_set_in_context:
                self._condition.wait_for(
                    lambda: self._tree_has_set or self._result_set_in_context)

            self._read_only = True

    def set_in_context(self):
        with self._condition:
            return self._result_set_in_context

    def deleted_in_context(self):
        with self._condition:
            return self._deleted

    def result(self):
        """Return the value at an address, optionally waiting until it is
        set from the context_manager, or set based on the pre-fetch mechanism.

        Returns:
            (bytes): The opaque value for an address.
        """

        if self._read_only:
            return self._result
        with self._condition:
            if self._wait_for_tree and not self._result_set_in_context:
                self._condition.wait_for(
                    lambda: self._tree_has_set or self._result_set_in_context)
            return self._result

    def set_deleted(self):
        with self._condition:  # take the lock, like the other mutators
            self._result_set_in_context = False
            self._deleted = True

    def set_result(self, result, from_tree=False):
        """Set the addresses's value unless the future has been declared
        read only.

        Args:
            result (bytes): The value at an address.
            from_tree (bool): Whether the value is being set by a read from
                the merkle tree.

        Returns:
            None
        """

        if self._read_only:
            if not from_tree:
                LOGGER.warning(
                    "Tried to set address %s on a"
                    " read-only context.", self.address)
            return

        with self._condition:
            if self._read_only:
                if not from_tree:
                    LOGGER.warning(
                        "Tried to set address %s on a"
                        " read-only context.", self.address)
                return
            if from_tree:
                # If the result has not been set in the context, overwrite the
                # value with the value from the merkle tree. Otherwise, do
                # nothing.
                if not self._result_set_in_context:
                    self._result = result
                    self._tree_has_set = True
            else:
                self._result = result
                self._result_set_in_context = True
                self._deleted = False

            self._condition.notify_all()
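
A hypothetical demo of the pre-fetch behaviour: with wait_for_tree=True, result() blocks until either a merkle-tree read or a context set arrives.

from threading import Thread

fut = _ContextFuture("00aabb", wait_for_tree=True)
reader = Thread(target=lambda: print("value:", fut.result()))
reader.start()                                   # blocks: nothing set yet
fut.set_result(b"state-bytes", from_tree=True)   # wakes the reader
reader.join()
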
Example #12

# Assumed imports, as in Example #10; this is another version of the same
# SerialScheduler.
import queue
from threading import Condition

class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """
    def __init__(self, squash_handler, first_state_hash, always_persist):
        self._txn_queue = queue.Queue()
        self._scheduled_transactions = []
        self._batch_statuses = {}
        self._txn_to_batch = {}
        self._in_progress_transaction = None
        self._final = False
        self._complete = False
        self._cancelled = False
        self._previous_context_id = None
        self._previous_valid_batch_c_id = None
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._previous_state_hash = first_state_hash
        # The state hashes here are the ones added in add_batch, and
        # are the state hashes that correspond with block boundaries.
        self._required_state_hashes = {}
        self._always_persist = always_persist

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(self, txn_signature, is_valid,
                                         context_id):
        with self._condition:
            if (self._in_progress_transaction is None
                    or self._in_progress_transaction != txn_signature):
                raise ValueError(
                    "transaction not in progress: {}".format(txn_signature))
            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError(
                    "transaction not in any batches: {}".format(txn_signature))

            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                self._previous_context_id = context_id

            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=is_valid, state_hash=None)
            if txn_signature in self._last_in_batch:
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._previous_valid_batch_c_id = self._previous_context_id
                    state_hash = None
                    required_state_hash = self._required_state_hashes.get(
                        batch_signature)
                    if required_state_hash is not None \
                            or self._last_in_batch[-1] == txn_signature:
                        state_hash = self._compute_merkle_root(
                            required_state_hash)
                    self._batch_statuses[batch_signature] = \
                        BatchExecutionResult(
                            is_valid=is_valid,
                            state_hash=state_hash)
                else:
                    self._previous_context_id = self._previous_valid_batch_c_id

                is_last_batch = \
                    len(self._batch_statuses) == len(self._last_in_batch)

                if self._final and is_last_batch:
                    self._complete = True
            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None):
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            batch_signature = batch.header_signature
            if state_hash is not None:
                self._required_state_hashes[batch_signature] = state_hash
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.put(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def count(self):
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_transactions[index]

    def next_transaction(self):
        with self._condition:
            if self._in_progress_transaction is not None:
                return None
            try:
                txn = self._txn_queue.get(block=False)
            except queue.Empty:
                return None

            self._in_progress_transaction = txn.header_signature
            base_contexts = [] if self._previous_context_id is None \
                else [self._previous_context_id]
            txn_info = TxnInformation(txn=txn,
                                      state_hash=self._previous_state_hash,
                                      base_context_ids=base_contexts)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def finalize(self):
        with self._condition:
            self._final = True
            if len(self._batch_statuses) == len(self._last_in_batch):
                self._complete = True
            self._condition.notify_all()

    def _compute_merkle_root(self, required_state_root):
        """Computes the merkle root of the state changes in the context
        corresponding with _last_valid_batch_c_id as applied to
        _previous_state_hash.

        Args:
            required_state_root (str): The merkle root that these txns
                should equal.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context_id
        """
        state_hash = self._squash(
            state_root=self._previous_state_hash,
            context_ids=[self._previous_valid_batch_c_id],
            persist=self._always_persist)
        if self._always_persist is True:
            return state_hash
        if state_hash == required_state_root:
            self._squash(state_root=self._previous_state_hash,
                         context_ids=[self._previous_valid_batch_c_id],
                         persist=True)
        return state_hash

    def complete(self, block):
        with self._condition:
            if not self._final:
                return False
            if self._complete:
                return True
            if block:
                self._condition.wait_for(lambda: self._complete)
                return True
            return False

    def cancel(self):
        with self._condition:
            self._cancelled = True
            self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
Example #13

# Assumed imports; the paho-mqtt helpers (connack_string, MQTTv5,
# WebsocketConnectionError) and the module-private helpers
# (_translate_tls_params, _create_mqtt_client, _to_qos, _get_broker,
# _to_protocol) come from the original package.
import socket
from threading import Condition

class MqttClient(object):
    def __init__(self, params):
        self._params = _translate_tls_params(params)
        self._params.update(params)
        _replace_ssl_params(self._params)
        logger.debug(self._params)

        try:
            self._mqttc = _create_mqtt_client(self._params)
        except ValueError as ex:
            raise InvalidArgumentError(ex)

        self._mqttc.on_connect = self._on_connect
        self.qos = _to_qos(self._params)
        self.connection_timeout = self._params.get('connection_timeout', 10)
        self.keepalive = self._params.get('keepalive', 60)
        self.host, self.port = _get_broker(self._params)
        logger.debug(f'broker={self.host} port={self.port}')
        # note: no trailing comma here; a stray comma would turn
        # self.protocol into a tuple and break the MQTTv5 comparison below
        self.protocol = _to_protocol(self._params.get('protocol'))
        self._connection_result = None
        self._conn_cond = Condition()

    def open(self, timeout=None):
        logger.debug("open")
        try:
            self._mqttc.connect(self.host, self.port, self.keepalive)
        except (socket.error, OSError, WebsocketConnectionError):
            logger.error(f"cannot connect broker: {self.host}:{self.port}")
            self.close()
            raise ConnectionError(
                f"cannot connect broker: {self.host}:{self.port}")

        self._mqttc.loop_start()

        if timeout is None:
            timeout = self.connection_timeout
        with self._conn_cond:
            ret = self._conn_cond.wait_for(self._is_connected, timeout)
            if not ret:
                self.close()
                raise ConnectionError('connection timed out')
            if self._connection_result != 0:
                if self.protocol == MQTTv5:
                    reason = str(self._connection_result)
                else:
                    reason = connack_string(self._connection_result)
                self.close()
                raise ConnectionError(f'connection error: reason={reason}')
        return self

    def close(self):
        logger.debug("close")
        try:
            self._mqttc.disconnect()
            self._mqttc.loop_stop()
        except Exception:
            logger.error("mqtt close() error")

    def _is_connected(self):
        return self._connection_result is not None

    def _on_connect(self, client, userdata, flags, rc, properties=None):
        logger.debug(f"MQTT:on_connect: rc={rc}")
        if rc != 0:
            logger.error(f"MQTT: {connack_string(rc)}: {rc}")
        with self._conn_cond:
            self._connection_result = rc
            self._conn_cond.notify_all()
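
open() blocks on the Condition with wait_for(self._is_connected, timeout) until the paho network loop fires _on_connect(). A self-contained sketch of that callback-to-caller handshake (the callback and values below are illustrative, not the real paho API):

import time
from threading import Condition, Thread

conn_cond = Condition()
connection_result = [None]          # None = still connecting, 0 = success

def on_connect(rc):                 # stand-in for the paho on_connect callback
    with conn_cond:
        connection_result[0] = rc
        conn_cond.notify_all()

Thread(target=lambda: (time.sleep(0.1), on_connect(0))).start()
with conn_cond:
    if not conn_cond.wait_for(lambda: connection_result[0] is not None, timeout=5):
        raise ConnectionError('connection timed out')
    if connection_result[0] != 0:
        raise ConnectionError('connection refused')
print("connected")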
Example #14
class TaskExecutor(object):
    def __init__(self, balancer, index):
        self.balancer = balancer
        self.index = index
        self.task = None
        self.proc = None
        self.pid = None
        self.conn = None
        self.state = WorkerState.STARTING
        self.key = str(uuid.uuid4())
        self.result = AsyncResult()
        self.exiting = False
        self.thread = gevent.spawn(self.executor)
        self.cv = Condition()
        self.status_lock = RLock()

    def checkin(self, conn):
        with self.cv:
            self.balancer.logger.debug('Check-in of worker #{0} (key {1})'.format(self.index, self.key))
            self.conn = conn
            self.state = WorkerState.IDLE
            self.cv.notify_all()

    def get_status(self):
        with self.cv:
            self.cv.wait_for(lambda: self.state == WorkerState.EXECUTING)
            try:
                st = TaskStatus(0)
                st.__setstate__(self.conn.call_sync('taskproxy.get_status'))
                return st
            except RpcException as err:
                self.balancer.logger.error(
                    "Cannot obtain status from task #{0}: {1}".format(self.task.id, str(err))
                )
                self.terminate()

    def put_status(self, status):
        with self.cv:
            # Try to collect rusage at this point, when process is still alive
            try:
                kinfo = bsd.kinfo_getproc(self.pid)
                self.task.rusage = kinfo.rusage
            except LookupError:
                pass

            if status['status'] == 'ROLLBACK':
                self.task.set_state(TaskState.ROLLBACK)

            if status['status'] == 'FINISHED':
                self.result.set(status['result'])

            if status['status'] == 'FAILED':
                error = status['error']
                cls = TaskException

                if error['type'] == 'TaskAbortException':
                    cls = TaskAbortException

                if error['type'] == 'ValidationException':
                    cls = ValidationException

                self.result.set_exception(cls(
                    code=error['code'],
                    message=error['message'],
                    stacktrace=error['stacktrace'],
                    extra=error.get('extra')
                ))

    def put_warning(self, warning):
        self.task.add_warning(warning)

    def run(self, task):
        with self.cv:
            self.cv.wait_for(lambda: self.state == WorkerState.IDLE)
            self.result = AsyncResult()
            self.task = task
            self.task.set_state(TaskState.EXECUTING)
            self.state = WorkerState.EXECUTING
            self.cv.notify_all()

        self.balancer.logger.debug('Actually starting task {0}'.format(task.id))

        filename = None
        module_name = inspect.getmodule(task.clazz).__name__
        for plugin_dir in self.balancer.dispatcher.plugin_dirs:
            found = False
            try:
                for root, _, files in os.walk(plugin_dir):
                    file = first_or_default(lambda f: module_name in f, files)
                    if file:
                        filename = os.path.join(root, file)
                        found = True
                        break

                if found:
                    break
            except FileNotFoundError:
                continue

        try:
            self.conn.call_sync('taskproxy.run', {
                'id': task.id,
                'user': task.user,
                'class': task.clazz.__name__,
                'filename': filename,
                'args': task.args,
                'debugger': task.debugger,
                'environment': task.environment
            })
        except RpcException as e:
            self.balancer.logger.warning('Cannot start task {0} on executor #{1}: {2}'.format(
                task.id,
                self.index,
                str(e)
            ))

            self.balancer.logger.warning('Killing unresponsive task executor #{0} (pid {1})'.format(
                self.index,
                self.proc.pid
            ))

            self.terminate()

        try:
            self.result.get()
        except BaseException as e:
            if not isinstance(e, TaskException):
                self.balancer.dispatcher.report_error(
                    'Task {0} raised exception other than TaskException'.format(self.task.name),
                    e
                )

            if isinstance(e, TaskAbortException):
                self.task.set_state(TaskState.ABORTED, TaskStatus(0, 'aborted'))
            else:
                self.task.error = serialize_error(e)
                self.task.set_state(TaskState.FAILED, TaskStatus(0, str(e), extra={
                    "stacktrace": traceback.format_exc()
                }))

            with self.cv:
                self.task.ended.set()
                self.balancer.task_exited(self.task)

                if self.state == WorkerState.EXECUTING:
                    self.state = WorkerState.IDLE
                    self.cv.notify_all()

                return

        with self.cv:
            self.task.result = self.result.value
            self.task.set_state(TaskState.FINISHED, TaskStatus(100, ''))
            self.task.ended.set()
            self.balancer.task_exited(self.task)
            if self.state == WorkerState.EXECUTING:
                self.state = WorkerState.IDLE
                self.cv.notify_all()

    def abort(self):
        self.balancer.logger.info("Trying to abort task #{0}".format(self.task.id))
        # Try to abort via RPC. If this fails, kill process
        try:
            self.conn.call_sync('taskproxy.abort')
        except RpcException as err:
            self.balancer.logger.warning("Failed to abort task #{0} gracefully: {1}".format(self.task.id, str(err)))
            self.balancer.logger.warning("Killing process {0}".format(self.pid))
            self.terminate()

    def terminate(self):
        try:
            self.proc.terminate()
        except ProcessLookupError:
            self.balancer.logger.warning('Executor process with PID {0} already dead'.format(self.proc.pid))

    def executor(self):
        while not self.exiting:
            try:
                self.proc = Popen(
                    [TASKWORKER_PATH, self.key],
                    close_fds=True,
                    preexec_fn=os.setpgrp,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT)

                self.pid = self.proc.pid
                self.balancer.logger.debug('Started executor #{0} as PID {1}'.format(self.index, self.pid))
            except OSError:
                self.result.set_exception(TaskException(errno.EFAULT, 'Cannot spawn task executor'))
                self.balancer.logger.error('Cannot spawn task executor #{0}'.format(self.index))
                return

            for line in self.proc.stdout:
                line = line.decode('utf8')
                self.balancer.logger.debug('Executor #{0}: {1}'.format(self.index, line.strip()))
                if self.task:
                    self.task.output += line

            self.proc.wait()

            with self.cv:
                self.state = WorkerState.STARTING
                self.cv.notify_all()

            if self.proc.returncode == -signal.SIGTERM:
                self.balancer.logger.info(
                    'Executor process with PID {0} was terminated gracefully'.format(
                        self.proc.pid
                    )
                )
            else:
                self.balancer.logger.error('Executor process with PID {0} died abruptly with exit code {1}'.format(
                    self.proc.pid,
                    self.proc.returncode)
                )

            self.result.set_exception(TaskException(errno.EFAULT, 'Task executor died'))
            gevent.sleep(1)

    def die(self):
        self.exiting = True
        if self.proc:
            self.terminate()
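
run() and checkin() coordinate through a small WorkerState machine guarded by self.cv: run() waits for IDLE, then advances the state to EXECUTING. A self-contained sketch of that handshake (WorkerState here is a stand-in; the real enum lives elsewhere in the codebase):

import time
from threading import Condition, Thread

class WorkerState:                   # stand-in for the real enum (assumed)
    STARTING, IDLE, EXECUTING = range(3)

cv = Condition()
state = [WorkerState.STARTING]

def checkin():                       # what the worker does once it connects back
    time.sleep(0.1)
    with cv:
        state[0] = WorkerState.IDLE
        cv.notify_all()

Thread(target=checkin).start()
with cv:                             # what run() does before dispatching a task
    cv.wait_for(lambda: state[0] == WorkerState.IDLE)
    state[0] = WorkerState.EXECUTING
    cv.notify_all()
print("task dispatched")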
class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """
    def __init__(self, squash_handler, first_state_hash):
        self._txn_queue = queue.Queue()
        self._scheduled_transactions = []
        self._batch_statuses = {}
        self._txn_to_batch = {}
        self._in_progress_transaction = None
        self._final = False
        self._complete = False
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._last_state_hash = first_state_hash

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(
            self, txn_signature, is_valid, context_id):
        """the control flow is that on every valid txn a new state root is
        generated. If the txn is invalid the batch status is set,
        if the txn is the last txn in the batch, is valid, and no
         prior txn failed the batch, the
        batch is valid
        """
        with self._condition:
            if (self._in_progress_transaction is None or
                    self._in_progress_transaction != txn_signature):
                raise ValueError("transaction not in progress: {}".format(
                    txn_signature))
            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError("transaction not in any batches: {}".format(
                    txn_signature))
            if is_valid:
                # txn is valid, get a new state hash
                state_hash = self._squash(self._last_state_hash, [context_id])
                self._last_state_hash = state_hash
            else:
                # txn is invalid, preemptively fail the batch
                batch_signature = self._txn_to_batch[txn_signature]
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=is_valid, state_hash=None)
            if txn_signature in self._last_in_batch:
                batch_signature = self._txn_to_batch[txn_signature]
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._batch_statuses[batch_signature] = \
                        BatchExecutionResult(
                            is_valid=is_valid,
                            state_hash=self._last_state_hash)

            if self._final and self._txn_queue.empty():
                self._complete = True
            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None):
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            batch_signature = batch.header_signature
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.put(txn)

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def count(self):
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_transactions[index]

    def next_transaction(self):
        with self._condition:
            if self._in_progress_transaction is not None:
                return None
            try:
                txn = self._txn_queue.get(block=False)
            except queue.Empty:
                return None

            self._in_progress_transaction = txn.header_signature
            txn_info = TxnInformation(txn, self._last_state_hash)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def finalize(self):
        with self._condition:
            self._final = True
            self._condition.notify_all()

    def complete(self, block):
        with self._condition:
            if not self._final:
                return False
            if self._complete:
                return True
            if block:
                self._condition.wait_for(lambda: self._complete)
                return True
            return False
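
A caller-side sketch of the intended flow, under the assumption that TxnInformation and BatchExecutionResult are importable from the scheduler's module; the batch and transaction shapes below are illustrative stand-ins for the real types:

from collections import namedtuple

# Hypothetical stand-ins for the real transaction/batch types.
Txn = namedtuple('Txn', 'header_signature')
Batch = namedtuple('Batch', 'header_signature transactions')

def squash(state_root, context_ids):          # signature inferred from _squash use
    return 'state-hash-1'

scheduler = SerialScheduler(squash, 'genesis-state-hash')
scheduler.add_batch(Batch('batch-1', [Txn('txn-1')]))
scheduler.finalize()

scheduler.next_transaction()                  # marks txn-1 as in progress
scheduler.set_transaction_execution_result(
    'txn-1', is_valid=True, context_id='ctx-1')
scheduler.complete(block=True)                # True; the queue is already empty
print(scheduler.get_batch_execution_result('batch-1'))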
Example #16
class _SendReceiveThread(Thread):
    """
    Internal thread used by the Stream class to run the asyncio event loop.
    """

    def __init__(self, url, futures, ready_event, error_queue):
        """constructor for background thread

        :param url (str): the address to connect to the validator on
        :param futures (FutureCollection): The Futures associated with
                messages sent through Stream.send
        :param ready_event (threading.Event): used to notify waiting/asking
               classes that the background thread of Stream is ready after
               a disconnect event.
        :param error_queue (queue.Queue): used to pass connection errors (or
               a no-error sentinel) back to the thread that started this one.
        """
        super(_SendReceiveThread, self).__init__()
        self._futures = futures
        self._url = url
        self._shutdown = False
        self._event_loop = None
        self._sock = None
        self._monitor_sock = None
        self._monitor_fd = None
        self._recv_queue = None
        self._send_queue = None
        self._context = None
        self._ready_event = ready_event
        self._error_queue = error_queue
        self._condition = Condition()
        self.identity = _generate_id()[0:16]

    @asyncio.coroutine
    def _receive_message(self):
        """
        internal coroutine that receives messages and puts
        them on the recv_queue
        """
        while True:
            if not self._ready_event.is_set():
                break
            msg_bytes = yield from self._sock.recv()
            message = validator_pb2.Message()
            message.ParseFromString(msg_bytes)
            try:
                self._futures.set_result(
                    message.correlation_id,
                    FutureResult(message_type=message.message_type,
                                 content=message.content))
                self._futures.remove(message.correlation_id)
            except FutureCollectionKeyError:
                # if we are getting an initial message, not a response
                if not self._ready_event.is_set():
                    break
                self._recv_queue.put_nowait(message)

    @asyncio.coroutine
    def _send_message(self):
        """
        internal coroutine that sends messages from the send_queue
        """
        while True:
            if not self._ready_event.is_set():
                break
            msg = yield from self._send_queue.get()
            yield from self._sock.send_multipart([msg.SerializeToString()])

    @asyncio.coroutine
    def _put_message(self, message):
        """
        Puts a message on the send_queue. Not to be accessed directly.
        :param message: protobuf generated validator_pb2.Message
        """
        self._send_queue.put_nowait(message)

    @asyncio.coroutine
    def _get_message(self):
        """
        Gets a message from the recv_queue. Not to be accessed directly.
        """
        with self._condition:
            self._condition.wait_for(lambda: self._recv_queue is not None)
        msg = yield from self._recv_queue.get()

        return msg

    @asyncio.coroutine
    def _monitor_disconnects(self):
        """Monitors the client socket for disconnects
        """
        yield from self._monitor_sock.recv_multipart()
        self._sock.disable_monitor()
        self._monitor_sock.disconnect(self._monitor_fd)
        self._monitor_sock.close(linger=0)
        self._monitor_sock = None
        self._sock.disconnect(self._url)
        self._ready_event.clear()
        LOGGER.debug("monitor socket received disconnect event")
        for future in self._futures.future_values():
            future.set_result(FutureError())
        for task in asyncio.Task.all_tasks(self._event_loop):
            task.cancel()
        self._event_loop.stop()
        self._send_queue = None
        self._recv_queue = None

    def put_message(self, message):
        """
        :param message: protobuf generated validator_pb2.Message
        """
        if not self._ready_event.is_set():
            return
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None
                                     and self._send_queue is not None)
        asyncio.run_coroutine_threadsafe(self._put_message(message),
                                         self._event_loop)

    def get_message(self):
        """
        :return message: concurrent.futures.Future
        """
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        return asyncio.run_coroutine_threadsafe(self._get_message(),
                                                self._event_loop)

    def _cancel_tasks_yet_to_be_done(self):
        """Cancels all the tasks (pending coroutines and futures)
        """
        for task in asyncio.Task.all_tasks(self._event_loop):
            self._event_loop.call_soon_threadsafe(task.cancel)
        self._event_loop.call_soon_threadsafe(self._done_callback)

    def shutdown(self):
        """Shutdown the _SendReceiveThread. Is an irreversible operation.
        """

        self._shutdown = True
        self._cancel_tasks_yet_to_be_done()

    def _done_callback(self):
        """Stops the event loop, closes the socket, and destroys the context

        :param future: concurrent.futures.Future not used
        """
        self._event_loop.call_soon_threadsafe(self._event_loop.stop)
        self._sock.close(linger=0)
        self._monitor_sock.close(linger=0)
        self._context.destroy(linger=0)

    def run(self):
        first_time = True
        while True:
            try:
                if self._event_loop is None:
                    self._event_loop = zmq.asyncio.ZMQEventLoop()
                    asyncio.set_event_loop(self._event_loop)
                if self._context is None:
                    self._context = zmq.asyncio.Context()
                if self._sock is None:
                    self._sock = self._context.socket(zmq.DEALER)
                self._sock.identity = self.identity

                self._sock.connect(self._url)

                self._monitor_fd = "inproc://monitor.s-{}".format(
                    _generate_id()[0:5])
                self._monitor_sock = self._sock.get_monitor_socket(
                    zmq.EVENT_DISCONNECTED,
                    addr=self._monitor_fd)
                self._send_queue = asyncio.Queue(loop=self._event_loop)
                self._recv_queue = asyncio.Queue(loop=self._event_loop)
                if first_time is False:
                    self._recv_queue.put_nowait(RECONNECT_EVENT)
                with self._condition:
                    self._condition.notify_all()
                asyncio.ensure_future(self._send_message(),
                                      loop=self._event_loop)
                asyncio.ensure_future(self._receive_message(),
                                      loop=self._event_loop)
                asyncio.ensure_future(self._monitor_disconnects(),
                                      loop=self._event_loop)
                # pylint: disable=broad-except
            except Exception as e:
                LOGGER.error("Exception connecting to validator "
                             "address %s, so shutting down", self._url)
                self._error_queue.put_nowait(e)
                break

            self._error_queue.put_nowait(_NO_ERROR)
            self._ready_event.set()
            self._event_loop.run_forever()
            if self._shutdown:
                self._sock.close(linger=0)
                self._monitor_sock.close(linger=0)
                self._context.destroy(linger=0)
                break
            if first_time is True:
                first_time = False
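
put_message() and get_message() gate on the Condition until run() has built the event loop and queues, then hand work to the loop with run_coroutine_threadsafe. A self-contained sketch of that gate (names illustrative):

import asyncio
from threading import Condition, Thread

cond = Condition()
shared = {'loop': None}

def background():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    with cond:
        shared['loop'] = loop
        cond.notify_all()        # what run() does once its queues exist
    loop.run_forever()

Thread(target=background, daemon=True).start()
with cond:
    cond.wait_for(lambda: shared['loop'] is not None)   # put_message()'s gate

async def ping():
    return 'pong'

future = asyncio.run_coroutine_threadsafe(ping(), shared['loop'])
print(future.result(timeout=5))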
Example #17
class RingBuffer(object):
    def __init__(self, size):
        self.data = bytearray(size)
        self.view = memoryview(self.data)
        self.size = size
        self.head = 0
        self.tail = 0
        self.closed = False
        self.cv = Condition()

    @property
    def empty(self):
        return self.head == self.tail

    @property
    def full(self):
        return self.head == (self.tail + 1) % self.size

    @property
    def used_space(self):
        if self.empty:
            return 0

        if self.tail > self.head:
            return self.tail - self.head

        if self.head > self.tail:
            return (self.size - self.head) + self.tail

    @property
    def avail_space(self):
        return self.size - self.used_space - 1

    def write(self, data):
        with self.cv:
            if self.full:
                self.cv.wait_for(lambda: not self.full or self.closed)
                if self.closed:
                    return 0

            towrite = min(len(data), self.avail_space)

            if self.tail >= self.head:
                first = min(towrite, self.size - self.tail)
                rest = towrite - first

                if first:
                    self.view[self.tail:self.tail+first] = data[:first]
                    self.tail = (self.tail + first) % self.size

                if rest:
                    self.view[:rest] = data[first:first+rest]
                    self.tail = (self.tail + rest) % self.size

                self.cv.notify_all()
                return towrite

            if self.head > self.tail:
                # slice length must equal towrite, not the whole free gap
                self.view[self.tail:self.tail + towrite] = data[:towrite]
                self.tail = (self.tail + towrite) % self.size
                self.cv.notify_all()
                return towrite

    def writeall(self, data):
        done = 0
        while done < len(data):
            ret = self.write(data[done:])
            if ret == 0:
                break

            done += ret

    def read(self, count):
        with self.cv:
            if self.empty:
                if self.closed:
                    return b''

                self.cv.wait_for(lambda: not self.empty or self.closed)

            toread = min(count, self.used_space)

            if self.tail >= self.head:
                toread = min(toread, self.tail - self.head)
                result = bytes(self.view[self.head:self.head+toread])
                self.head = (self.head + toread) % self.size
                self.cv.notify_all()
                return result

            if self.head > self.tail:
                first = min(toread, self.size - self.head)
                rest = toread - first
                result = b''

                if first:
                    result = bytes(self.view[self.head:self.head+first])
                    self.head = (self.head + first) % self.size

                if rest:
                    result += bytes(self.view[:rest])
                    self.head = (self.head + rest) % self.size

                self.cv.notify_all()
                return result

    def readall(self, count):
        result = b''
        done = 0

        while done < count:
            ret = self.read(count - done)
            if ret == b'':
                break

            done += len(ret)
            result += ret

        return result

    def close(self):
        with self.cv:
            self.closed = True
            self.cv.notify_all()
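
A small usage sketch: writeall() blocks the producer whenever the 15 usable bytes fill up, and readall() blocks the consumer whenever the buffer drains, with all wakeups handled by the shared Condition:

from threading import Thread

buf = RingBuffer(16)                     # 15 usable bytes, so the writer blocks

def producer():
    buf.writeall(b'hello ring buffer')   # 17 bytes; waits for the reader
    buf.close()

Thread(target=producer).start()
print(buf.readall(17))                   # b'hello ring buffer'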
Example #18
class Context(object):
    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.msock = msock.client.Client()
        self.msock.on_closed = self.on_msock_close
        self.rpc_fd = -1
        self.connection_id = None
        self.jobs = []
        self.state = ConnectionState.OFFLINE
        self.config = None
        self.keepalive = None
        self.connected_at = None
        self.cv = Condition()
        self.rpc = RpcContext()
        self.client = Client()
        self.server = Server()
        self.middleware_endpoint = None

    def start(self, configpath, sockpath):
        signal.signal(signal.SIGUSR2, lambda signo, frame: self.connect())
        self.read_config(configpath)
        self.server.rpc = RpcContext()
        self.server.rpc.register_service_instance("control", ControlService(self))
        self.server.start(sockpath)
        threading.Thread(target=self.server.serve_forever, name="server thread", daemon=True).start()

    def init_dispatcher(self):
        def on_error(reason, **kwargs):
            if reason in (ClientError.CONNECTION_CLOSED, ClientError.LOGOUT):
                self.logger.warning("Connection to dispatcher lost")
                self.connect_dispatcher()

        self.middleware_endpoint = Client()
        self.middleware_endpoint.on_error(on_error)
        self.connect_dispatcher()

    def connect_dispatcher(self):
        while True:
            try:
                self.middleware_endpoint.connect("unix:")
                self.middleware_endpoint.login_service("debugd")
                self.middleware_endpoint.enable_server()
                self.middleware_endpoint.register_service("debugd.management", ControlService(self))
                self.middleware_endpoint.resume_service("debugd.management")
                return
            except (OSError, RpcException) as err:
                self.logger.warning("Cannot connect to dispatcher: {0}, retrying in 1 second".format(str(err)))
                time.sleep(1)

    def read_config(self, path):
        try:
            with open(path) as f:
                self.config = json.load(f)
        except (IOError, OSError, ValueError) as err:
            self.logger.fatal("Cannot open config file: {0}".format(str(err)))
            self.logger.fatal("Exiting.")
            sys.exit(1)

    def connect(self, discard=False):
        if discard:
            self.connection_id = None

        self.keepalive = threading.Thread(target=self.connect_keepalive, daemon=True)
        self.keepalive.start()

    def connect_keepalive(self):
        while True:
            try:
                if not self.connection_id:
                    self.connection_id = uuid.uuid4()

                self.msock.connect(SUPPORT_PROXY_ADDRESS)
                self.logger.info("Connecting to {0}".format(SUPPORT_PROXY_ADDRESS))
                self.rpc_fd = self.msock.create_channel(0)
                time.sleep(1)  # FIXME
                self.client = Client()
                self.client.connect("fd://", fobj=self.rpc_fd)
                self.client.channel_serializer = MSockChannelSerializer(self.msock)
                self.client.standalone_server = True
                self.client.enable_server()
                self.client.register_service("debug", DebugService(self))
                self.client.call_sync(
                    "server.login", str(self.connection_id), socket.gethostname(), get_version(), "none"
                )
                self.set_state(ConnectionState.CONNECTED)
            except BaseException as err:
                self.logger.warning("Failed to initiate support connection: {0}".format(err), exc_info=True)
                self.msock.disconnect()
            else:
                self.connected_at = datetime.now()
                with self.cv:
                    self.cv.wait_for(lambda: self.state in (ConnectionState.LOST, ConnectionState.OFFLINE))
                    if self.state == ConnectionState.OFFLINE:
                        return

            self.logger.warning("Support connection lost, retrying in 10 seconds")
            time.sleep(10)

    def disconnect(self):
        self.connected_at = None
        self.set_state(ConnectionState.OFFLINE)
        self.client.disconnect()
        self.msock.destroy_channel(0)
        self.msock.disconnect()
        self.jobs.clear()

    def on_msock_close(self):
        self.connected_at = None
        self.set_state(ConnectionState.LOST)

    def run_job(self, job):
        self.jobs.append(job)
        job.context = self
        job.start()

    def set_state(self, state):
        with self.cv:
            self.state = state
            self.cv.notify_all()
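
connect_keepalive() parks on self.cv until another thread (via set_state() or on_msock_close()) moves the state to LOST or OFFLINE. A self-contained sketch of that state-watch idiom (the enum here is a stand-in for the real ConnectionState):

import time
from enum import Enum
from threading import Condition, Thread

class ConnectionState(Enum):             # stand-in for the real enum (assumed)
    OFFLINE = 0
    CONNECTED = 1
    LOST = 2

cv = Condition()
state = [ConnectionState.CONNECTED]

def set_state(new_state):                # mirrors Context.set_state()
    with cv:
        state[0] = new_state
        cv.notify_all()

Thread(target=lambda: (time.sleep(0.1), set_state(ConnectionState.LOST))).start()
with cv:
    cv.wait_for(lambda: state[0] in (ConnectionState.LOST, ConnectionState.OFFLINE))
print("connection lost; retrying")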
Example #19
class Match(Thread):
    def __init__(
        self, ai_list, game_manager, player_manager=None, sound_manager=None,
        two_wind_game=False, starting_score=25000, wall=None, deadwall=None
    ):
        super().__init__()
        self.ai_list = ai_list
        self.players = []
        self.player_manager = player_manager
        player_manager.current_match = self
        self.sound_manager = sound_manager
        self.game_manager = game_manager
        self.scores = [starting_score] * 4
        self.delta_scores = [0] * 4
        self.current_board = None
        self.east_prevalent = True
        self.round_number = 0
        self.two_wind_game = two_wind_game
        self.match_ready = False
        self.match_alive = True
        self.process_lock = Condition()
        self.match_lock = Lock()
        self.game_id = -1
        self.current_board = None
        self.encountered_end_game = False
        if wall is not None and deadwall is not None:
            self.current_board = Board(
                wall=wall,
                deadwall=deadwall,
                dora_revealed=0,
            )

    def new_board(self, wall=None, deadwall=None):
        self.current_board = None
        self.current_board = Board(
            wall=wall,
            deadwall=deadwall,
            dora_revealed=0,
        )
        self.current_board.current_dealer = (self.round_number - 1) % 4

    def play_clack(self):
        if self.sound_manager is not None:
            self.sound_manager.play_from_set("clack")

    def bootstrap_match(self):
        for i in range(4):
            if i == self.player_manager.player_id:
                self.players += [self.player_manager]
            else:
                self.players += [Player("Bot {}".format(i), starting_hand=[Piece(PieceType.ERROR)] * 13, player_id=i)]
        self.round_number = 1
        self.new_board()

        self.match_ready = True
    
    def run(self):
        settings = GameSettings()
        settings.seat_controllers = self.ai_list
        settings.seed = numpy.random.randint(0, 2147483647)
        #print('SEED:', settings.seed)
        if self.current_board is not None:
            wall = list(map(Piece, self.current_board.wall+self.current_board.deadwall))
            settings.override_wall = wall
        self.game_id = start_game(settings, True)
        while self.match_alive:
            self.on_update()
        #print('Match Halting...')
        #if not self.encountered_end_game:
        halt_game(self.game_id)

    def start_next_round(self):
        self.scores = list(map(lambda x: x[0] + x[1], zip(self.scores, self.delta_scores)))
        self.delta_scores = [0] * 4
        if self.player_manager.seat_wind != self.player_manager.next_round_seat:
            self.round_number += 1
            self.east_prevalent = not (self.east_prevalent and self.player_manager.prevalent_wind != Wind.East)
        self.player_manager.reset()
        self.player_manager.next_round()
        self.new_board()
        for i in range(4):
            if i != self.player_manager.player_id:
                self.players[i].reset()
                self.players[i].hand = [Piece(PieceType.ERROR)] * 13
        self.game_manager.board_manager.did_exhaustive_draw = True
        self.game_manager.board_manager.round_should_end = False

    def on_update(self):
        if not self.match_ready and self.player_manager.player_id is None:
            return
        if not self.match_ready and self.player_manager.player_id is not None:
            #print("Player ID Found, bootstrapping...")
            self.bootstrap_match()

        with self.process_lock:
            self.process_lock.wait_for(lambda: self.player_manager.GetQueueLength() > 0, timeout=500)
            with self.match_lock:
                process_event_queue(self.game_manager, self)
        
        if self.game_manager.board_manager.round_should_end:
            sleep(7)
            self.start_next_round()
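
on_update() polls with a timed wait_for: it sleeps on process_lock until the player manager's queue is non-empty (or the timeout lapses), then drains events under match_lock. A self-contained sketch of that two-lock poll (names illustrative):

import time
from threading import Condition, Lock, Thread

process_lock = Condition()
match_lock = Lock()
events = []

def enqueue(event):                      # stand-in for the event producer
    with process_lock:
        events.append(event)
        process_lock.notify_all()

Thread(target=lambda: (time.sleep(0.1), enqueue('discard'))).start()
with process_lock:
    process_lock.wait_for(lambda: len(events) > 0, timeout=5)
    with match_lock:
        print('processing', events.pop(0))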
Example #20
class Configuration(ServiceMixin):

    # states
    STATE_SYNC = 0
    STATE_GOGO = 1

    def __init__(
            self,
            control_pipe,  # control pipe for shutting down service
            local_ep,  # this is us
            local_uuid,  # this is also us
            config_svc,  # obviously, this must be None
            parent_ep,  # from where we receive config updates/snapshots
            level,
            group,
            sos_func,  # call when our parent stops sending HUGZ
            config_state=None,  # should only be given for root level
    ):
        ServiceMixin.__init__(self, control_pipe, local_ep, local_uuid, config_svc)
        assert config_svc is None

        assert isinstance(parent_ep, (EndpntSpec, type(None)))
        self.parent = parent_ep

        assert level in ['root', 'branch', 'leaf']
        self.level = level

        self.group = group

        if self.level in ['branch', 'leaf']:
            assert isinstance(sos_func, MethodType)
        else:
            assert sos_func is None
        self.sos = sos_func

        # Access to __kvdict, __kv_seq, and __tree are protected by __kvlock. There are multiple
        # writers (both the Management and Configuration services can write at the same time) and
        # multiple readers (all services, snapshots requests, etc).
        #
        # When reading, the Lock only needs to be held while reading multiple values (i.e.
        # iterating across the entire dictionary or walking the tree). Getting individual elements
        # should not require the Lock.

        # { key : ( value, seq-num ) }
        self.__kvlock = RLock()  # reentrant lock because I'm lazy
        self.__kvdict = {}
        self.__kv_seq = -1

        # 1) tree starts empty
        # 2) as config receives /TOPO keys, add nodes to tree as topo-node
        # tree.nodes contains { EndpntSpec: TopoNode }
        # TopoNode contains (endpoint, role, group, parent, children, last-seen)
        # topo keys come from tree.get_topo_key(node)

        # topo changes SUB'ed from parent are mimicked in __tree and PUB'ed to children; topo
        # changes by local services are applied to tree and then all generated key-value
        # updates are PUB'ed to children all at once.
        self.__tree = TopoTreeMixin()

        ### Common Members

        # external callers wait on this condition until state==GOGO
        self.state_cond = Condition()
        self.state = Configuration.STATE_SYNC

        ### Root/Branch Members

        # sending heartbeats to children
        self.hug_msg = None
        self.next_hug = None

        ### Branch/Leaf Members

        # detecting parent heartbeats
        self.next_sos = None

        # receiving snapshots
        # { topic : final-seq-num }
        self.kvsync_completed = {}
        self.pending_updates = []

        # TODO: fix timing issue w.r.t. root node sync--root needs to be sync'ed before any other node.

        # receiving updates
        self.topics = []
        self.topics.append('/TOPO/root')  # must be first item sync'ed for all nodes
        if 'leaf' == self.level:
            assert self.group is not None
            self.topics.append('/CONFIG/global/')
            self.topics.append('/CONFIG/%s/' % self.group)
        elif 'branch' == self.level:
            self.topics.append('/')  # collectors get everything

        ### let's get it started

        # sockets and message counts
        (self.update_sub, self.update_pub, self.kvsync_req, self.kvsync_rep) = (None, None, None, None)
        (self.subcnt, self.pubcnt, self.hugcnt, self.reqcnt, self.repcnt) = (0, 0, 0, 0, 0)

        # populate dict with config_file values if root level; values are used later
        if 'root' == self.level:
            if isinstance(config_state, str):
                self.__init_kvdict_from_file(config_state)
            else:
                assert isinstance(config_state, dict)
                self.__init_kvdict_from_dict(config_state)

            # set GOGO state; this basically means we have all the config values
            self.__set_gogo()
            self.__init_producer_sockets()
        else:
            self.__init_consumer_sockets()

    #####
    #  BEGIN dictionary access methods

    def copy_kvdict(self):
        assert 'root' == self.level
        with self.__kvlock:
            return self.__kvdict.copy()  # shallow copy; kvdict values should not be modified
                                         # after this call

    def get(self, k, default=None):
        """ returns (value, seq-id) """
        return self.__kvdict.get(k, (default, -1))

    def __getitem__(self, k):
        (val, seq) = self.__kvdict[k]
        return val

    def __setitem__(self, k, v):
        assert 'root' == self.level, "only root level allowed to make modifications"
        self.__kvlist_store_and_pub([(k, v)])  # add to our dict and publish update

    def __delitem__(self, k):
        assert 'root' == self.level, "only root level allowed to make modifications"
        self.__kvlist_store_and_pub([(k, None)])  # remove from our dict and publish update

    def __len__(self):
        return len(self.__kvdict)

    def __iter__(self):
        return iter(self.__kvdict)

    def __kvlist_store_and_pub(self, kvlist, ignore_seq=False, skip_topo=False):
        pub_kvlist = []
        with self.__kvlock:
            for item in kvlist:
                # for convenience, list contains CONFIG messages or 3-tuples
                if isinstance(item, config.CONFIG):
                    (k, v, seq) = (item.key, item.value, item.sequence)
                else:
                    assert isinstance(item, tuple)
                    if 3 == len(item):
                        (k, v, seq) = item
                    elif 2 == len(item):
                        (k, v) = item
                        seq = None
                    else:
                        raise NotImplementedError('unknown tuple length')

                if seq is None:
                    assert not ignore_seq
                    seq = self.__kv_seq + 1

                # write to dict
                if self.__kv_write(k, v, seq, ignore_seq):
                    # if successful, pub later
                    pub_kvlist.append((k, v, seq))

        for (k, v, seq) in pub_kvlist:
            # pass all topo updates to tree; if update is actually coming from
            # tree, skip_topo should be True
            if not skip_topo and k.startswith('/TOPO'):
                self.__tree.kv_update(k, v)

            # wait until finished with sync state before sending updates
            if self._is_gogo():
                assert self.update_pub is not None
                update = config.KVPUB(k, v, seq)
                update.send(self.update_pub)
                self.__hb_sent()
                self.pubcnt += 1

    def __kv_write(self, key, value, sequence, ignore_seq):
        """ N.B.: self.__kvlock MUST be held when calling __kv_write() """

        if ignore_seq:
            # during kvsync, we allow out of sequence updates; otherwise,
            assert self._is_sync()
        else:
            # if not greater than current kv-sequence, skip this one
            if sequence <= self.__kv_seq:
                # TODO: trigger a kvsync if sequence != kv_seq + 1; leaf nodes do not get every
                #       update, so that won't work...
                self.logger.warn('kv write out of sequence (cur=%d, recvd=%d); dropping' % (
                    self.__kv_seq, sequence))
                return False

            # always set seq-num if not ignore_seq
            self.__kv_seq = sequence

        # set/delete given key-value pair
        self.__kvdict[key] = (value, sequence)
        if value is None:
            del self.__kvdict[key]

        return True

    # END dictionary access methods
    #####

    #####
    # BEGIN topo tree access methods

    def topo_get_size(self):
        return len(self.__tree)

    # root access

    def topo_get_root(self):
        return self.__tree.root()

    def topo_set_root(self, ep, uuid):
        with self.__kvlock:
            kvlist = self.__tree.insert_root(ep, uuid)
            self.__kvlist_store_and_pub(kvlist, skip_topo=True)

    # branch/collector access

    def topo_group_size(self, group):
        c = self.topo_get_collector(group)
        return len(c.children) if c is not None else 0

    def topo_get_collector(self, group):
        return self.__tree.get_collector(group)

    def topo_get_all_collectors(self):
        cs = []
        for g in self.config_get_groups():
            c = self.__tree.get_collector(g)
            if c is not None:
                cs.append(c)
        return cs

    def topo_del_branch(self, collector):
        with self.__kvlock:
            kvlist = self.__tree.remove_collector(collector)
            self.__kvlist_store_and_pub(kvlist, skip_topo=True)

    # leaf/node access

    def topo_get_node(self, endpoint):
        return self.__tree.get_node(endpoint)

    def topo_set_node(self, node):
        with self.__kvlock:
            kvlist = self.__tree.update_node(node)
            self.__kvlist_store_and_pub(kvlist, skip_topo=True)

    def topo_insert_endpoint(self, ep, uuid, level, group, parent):
        node = TopoNode(ep, uuid, level, group)
        node.touch()
        with self.__kvlock:
            kvlist = self.__tree.insert_node(node, parent)
            self.__kvlist_store_and_pub(kvlist, skip_topo=True)
        return node

    def topo_touch_node(self, node):
        with self.__kvlock:
            kvlist = self.__tree.touch_node(node)
            self.__kvlist_store_and_pub(kvlist, skip_topo=True)

    # END topo tree access methods
    #####

    def __set_gogo(self):
        self.state_cond.acquire()
        self.state = Configuration.STATE_GOGO
        self.state_cond.notify_all()
        self.state_cond.release()

    def _is_gogo(self):
        return Configuration.STATE_GOGO == self.state

    def _is_sync(self):
        return Configuration.STATE_SYNC == self.state

    def wait_for_gogo(self):
        self.state_cond.acquire()
        self.state_cond.wait_for(self._is_gogo)
        self.state_cond.release()

    #####
    # BEGIN config access methods

    def config_get_hb_int(self):
        if self._is_gogo():
            return self['/CONFIG/global/heartbeat']
        return 60  # default hb interval until init is complete

    def config_get_metric_specs(self, group=None):
        assert self._is_gogo()

        # root needs all metrics
        if self.level == 'root':
            s = -1
            m = set()  # returning unique set of metrics from all groups

            for g in self.config_get_groups():
                (gm, gs) = self.get('/CONFIG/%s/metrics' % g, [])
                s = max(gs, s)
                m.update(gm)

            return m, s

        # otherwise, just return this group's metrics
        if group is None:
            group = self.group

        # return tuple = (spec-list, seq-id) or (None, -1)
        return self.get('/CONFIG/%s/metrics' % group, [])

    def config_get_endpoints(self, group=None):
        assert self._is_gogo()

        if group is None:
            group = self.group

        # return ep-list or []
        try:
            return self['/CONFIG/%s/endpoints' % group]
        except KeyError:
            return []

    def config_get_groups(self):
        regex = re.compile(r'/CONFIG/(\w+)/endpoints')
        gs = set()
        with self.__kvlock:
            for key in self.__kvdict:
                m = regex.match(key)
                if m is not None:
                    # every group has three keys in the dict; lazily find the
                    # unique group names by using a set
                    gs.add(m.group(1))
        return gs

    # END config access methods
    #####

    def __init_kvdict_from_file(self, config_file):
        assert 'root' == self.level
        assert config_file is not None
        cfg = ConfigFileMixin()
        cfg.read_file(open(config_file))
        for (k, v) in cfg.kvdict.items():
            assert isinstance(k, str)
            self.logger.debug('INIT: {}: {}'.format(k, v))
        self.__kvlist_store_and_pub(cfg.kvdict.items())
        self.logger.debug('INIT: final kv-seq = {}'.format(self.__kv_seq))

    def __init_kvdict_from_dict(self, cfg):
        assert 'root' == self.level
        for (k, (v, seq)) in cfg.items():
            assert isinstance(k, str)
            self.logger.debug('INIT: [{}] {}: {}'.format(seq, k, v))
        self.__kvlist_store_and_pub([(k, v, seq) for k, (v, seq) in cfg.items()])
        self.logger.debug('INIT: final kv-seq = {}'.format(self.__kv_seq))

    def __init_consumer_sockets(self):
        assert self.level in ['branch', 'leaf']
        assert self.parent is not None
        assert len(self.topics) > 0

        # 1) subscribe to updates from parent
        self.update_sub = self.ctx.socket(SUB)
        for t in self.topics:
            self.update_sub.setsockopt_string(SUBSCRIBE, t)
        self.update_sub.connect(self.parent.connect_uri(EndpntSpec.CONFIG_UPDATE))

        self.poller.register(self.update_sub, POLLIN)

        # 2) request snapshot(s) from parent
        self.kvsync_req = self.ctx.socket(DEALER)
        self.kvsync_req.connect(self.parent.connect_uri(EndpntSpec.CONFIG_SNAPSHOT))
        for t in self.topics:
            icanhaz = config.ICANHAZ(t)
            icanhaz.send(self.kvsync_req)

        self.poller.register(self.kvsync_req, POLLIN)

    def __init_producer_sockets(self):
        assert self.level in ['root', 'branch']
        assert self._is_gogo()

        # 3) publish updates to children (bind)
        self.update_pub = self.ctx.socket(PUB)
        self.update_pub.bind(self.endpoint.bind_uri(EndpntSpec.CONFIG_UPDATE))

        if 'branch' == self.level:
            assert self.group is not None
            t = '/CONFIG/%s' % self.group
        elif 'root' == self.level:
            t = '/TOPO'
        else:
            raise NotImplementedError('unknown level: %s' % self.level)

        self.hug_msg = config.HUGZ(t)
        self.__hb_sent()  # start the hb timer

        # 4) service snapshot requests to children (bind)
        self.kvsync_rep = self.ctx.socket(ROUTER)
        self.kvsync_rep.bind(self.endpoint.bind_uri(EndpntSpec.CONFIG_SNAPSHOT))
        self.poller.register(self.kvsync_rep, POLLIN)

        # process pending updates; this will trigger kvpub updates for each message processed
        self.__kvlist_store_and_pub(self.pending_updates)
        del self.pending_updates

    def _cleanup(self):
        # service exiting; return some status info and cleanup
        self.logger.debug(
            "%d subs; %d pubs; %d hugz; %d reqs; %d reps" %
            (self.subcnt, self.pubcnt, self.hugcnt, self.reqcnt, self.repcnt))

        # print each key-value pair; value is really (value, seq-num)
        self.logger.debug('kv-seq: %d' % self.__kv_seq)
        width = len(str(self.__kv_seq))
        with self.__kvlock:
            for (k, (v, s)) in sorted(self.__kvdict.items()):
                self.logger.debug('({0:0{width}d}) {1}: {2}'.format(s, k, v, width=width))

        self.__tree.print()

        if self.update_sub is not None:
            self.update_sub.close()
        del self.update_sub

        if self.update_pub is not None:
            self.update_pub.close()
        del self.update_pub

        if self.kvsync_req is not None:
            self.kvsync_req.close()
        del self.kvsync_req

        if self.kvsync_rep is not None:
            self.kvsync_rep.close()
        del self.kvsync_rep

        ServiceMixin._cleanup(self)

    def _pre_poll(self):
        # wait until finished with sync state before sending anything
        if not self._is_gogo():
            self.poller_timer = None
            return

        if self.level in ['root', 'branch']:
            assert self.next_hug is not None
            if self.next_hug <= now_secs():
                self.__send_hug()

        if self.level in ['branch', 'leaf']:
            assert self.next_sos is not None
            if self.next_sos <= now_secs():
                self.sos()
                self.__hb_received()  # reset hb monitor so we don't flood the system with sos

        self.poller_timer = self.__get_next_wakeup()

    def _post_poll(self, items):
        if self.update_sub in items:
            self.__recv_update()
        if self.kvsync_req in items:
            self.__recv_snapshot()
        if self.kvsync_rep in items:
            self.__send_snapshot()

    def __send_hug(self):
        assert self.level in ['root', 'branch']
        assert self._is_gogo()
        assert self.update_pub is not None
        assert self.hug_msg is not None

        self.hug_msg.send(self.update_pub)
        self.__hb_sent()
        self.hugcnt += 1

    def __get_next_wakeup(self):
        """ @returns next wakeup time (as msecs delta) """
        assert self._is_gogo()

        next_wakeup = None
        next_hug_wakeup = None
        next_sos_wakeup = None

        if self.level in ['root', 'branch']:
            assert self.next_hug is not None  # initialized by __init_producer_sockets()
            # next_hug is in secs; convert to msecs and subtract the current time
            next_hug_wakeup = (self.next_hug * 1e3) - now_msecs()
            next_wakeup = next_hug_wakeup

        if self.level in ['branch', 'leaf']:
            assert self.next_sos is not None  # initialized by __recv_snapshot() / _recv_update()
            # next_sos is in secs; convert to msecs and subtract the current time
            next_sos_wakeup = (self.next_sos * 1e3) - now_msecs()
            next_wakeup = next_sos_wakeup

        if 'branch' == self.level:
            assert next_hug_wakeup is not None
            assert next_sos_wakeup is not None
            # pick the sooner of the two wakeups
            next_wakeup = min(next_hug_wakeup, next_sos_wakeup)

        assert next_wakeup is not None
        # make sure it's not negative
        val = max(0, next_wakeup)

        self.logger.debug('next wakeup in %dms' % val)

        return val

    def __hb_sent(self):
        # reset hugz using last pub time
        self.next_hug = now_secs() + self.config_get_hb_int()

    def __hb_received(self):
        # reset sos using last hb time
        self.next_sos = now_secs() + (self.config_get_hb_int() * 5)

    def __recv_update(self):
        update = config.CONFIG.recv(self.update_sub)
        self.__hb_received()  # any message from parent is considered a heartbeat
        self.subcnt += 1

        if update.is_error:
            self.logger.error('received error message from parent: %s' % update)
            return

        if update.is_hugz:
            # already noted above. moving on...
            self.logger.debug('received hug.')
            return

        if self._is_sync():
            # another solution is to just not read the message; let them queue
            # up on the socket itself...but that relies on the HWM of the socket
            # being set high enough to account for all messages received while
            # in the SYNC state. this approach guarantees no updates are lost.
            self.pending_updates.append(update)
        elif self._is_gogo():
            self.__kvlist_store_and_pub([update])
        else:
            raise NotImplementedError('unknown state')

    def __recv_snapshot(self):
        assert self.level in ['branch', 'leaf']
        assert self._is_sync()

        # should either be KVSYNC or KTHXBAI
        response = config.CONFIG.recv(self.kvsync_req)
        self.__hb_received()  # any message from parent is considered a heartbeat
        self.repcnt += 1

        if response.is_error:
            self.logger.error(response)
            return

        if isinstance(response, config.KTHXBAI):
            if response.value not in self.topics:
                self.logger.error('received KTHXBAI of unexpected subtree: %s' % response.value)
                return

            # add given subtree to completed list; return if still waiting for other
            # subtree kvsync sessions
            self.kvsync_completed[response.value] = response.sequence
            if len(self.kvsync_completed) != len(self.topics):
                return

            self.__kv_seq = max(self.kvsync_completed.values())
            del self.kvsync_completed

            # set GOGO state; this basically means we have all the config values
            self.__set_gogo()

            if 'branch' == self.level:
                self.__init_producer_sockets()

        else:
            self.__kvlist_store_and_pub([response], ignore_seq=True)

    def __send_snapshot(self):
        assert self._is_gogo()

        request = config.CONFIG.recv(self.kvsync_rep)
        self.reqcnt += 1

        if request.is_error:
            self.logger.error(request)
            return

        peer_id = request.peer_id
        subtree = request.value  # subtree stored as value in ICANHAZ message

        # send all the key-value pairs in our dict
        max_seq = -1
        with self.__kvlock:
            for (k, (v, s)) in self.__kvdict.items():
                # skip keys not in the requested subtree
                if not k.startswith(subtree):
                    continue
                max_seq = max([max_seq, s])
                snap = config.KVSYNC(k, v, s, peer_id)
                snap.send(self.kvsync_rep)

        # send final message, closing the kvsync session
        snap = config.KTHXBAI(max_seq, peer_id, subtree)
        snap.send(self.kvsync_rep)
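
wait_for_gogo() is the external synchronization point: callers block on state_cond until the sync phase finishes and __set_gogo() flips the state. A self-contained sketch of that SYNC-to-GOGO gate (names illustrative):

import time
from threading import Condition, Thread

STATE_SYNC, STATE_GOGO = 0, 1
state_cond = Condition()
state = [STATE_SYNC]

def finish_sync():
    time.sleep(0.1)                      # stand-in for receiving the final KTHXBAI
    with state_cond:
        state[0] = STATE_GOGO
        state_cond.notify_all()

Thread(target=finish_sync).start()
state_cond.acquire()                     # mirrors wait_for_gogo()'s acquire/release
state_cond.wait_for(lambda: state[0] == STATE_GOGO)
state_cond.release()
print("config synchronized; safe to read values")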
Example #21
class Connection:
    def __init__(self, server_host: str, server_port: int):
        self.server_host = server_host
        self.server_port = server_port
        self.response_map = dict()
        self.response_map_cv = Condition()
        self.connection_state = ConnectionState.NEW
        self.connection_state_cv = Condition()

    def connect(self, timeout=10):
        """
        Attempt to connect to the configured server.
        """

        if self.connection_state == ConnectionState.NEW:
            log.info(
                f"Attempting connection to {self.server_host}:{self.server_port}"
            )

            Thread(target=self.handle_connection).start()

            with self.connection_state_cv:
                # Wait until the handler thread moves the state out of NEW,
                # then report whether it reached CONNECTED.
                if not self.connection_state_cv.wait_for(
                        lambda: self.connection_state != ConnectionState.NEW,
                        timeout):
                    return False

                return self.connection_state == ConnectionState.CONNECTED
        else:
            raise Error("Illegal connection state")

    def handle_connection(self):
        """
        Run connection handling routine.
        """

        context = SSLContext(PROTOCOL_TLS_CLIENT)
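        # NOTE: hostname checking and certificate verification are disabled
        # below, so this TLS layer encrypts the channel without
        # authenticating the peer.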
        context.check_hostname = False
        context.verify_mode = CERT_NONE

        with create_connection(
            (self.server_host, self.server_port)) as self.socket:
            with context.wrap_socket(
                    self.socket, server_hostname=self.server_host) as self.tls:

                # Use this buffer for all reads
                read_buffer = bytearray(4096)

                # Perform SID handshake
                rq = RQ_Cvid()
                rq.uuid = os.environ["S7S_UUID"]
                rq.instance = InstanceType.CLIENT
                rq.instance_flavor = InstanceFlavor.CLIENT_BRIGHTSTONE

                msg = MSG()
                setattr(msg, "payload", rq.SerializeToString())
                setattr(msg, "id", randint(0, 65535))

                self.tls.write(_VarintBytes(msg.ByteSize()))
                self.tls.sendall(msg.SerializeToString())

                read = self.tls.recv_into(read_buffer, 4096)
                if read == 0:
                    raise EOFError("")

                msg_len, msg_start = _DecodeVarint32(read_buffer, 0)

                msg = MSG()
                msg.ParseFromString(read_buffer[msg_start:msg_start + msg_len])

                rs = RS_Cvid()
                rs.ParseFromString(msg.payload)

                self.server_cvid = rs.server_cvid
                self.sid = rs.sid

                # The connection is now connected
                with self.connection_state_cv:
                    self.connection_state = ConnectionState.CONNECTED
                    self.connection_state_cv.notify()

                # Begin accepting messages
                while True:

                    # TODO there may be another message in the read buffer

                    # Read from the socket
                    read = self.tls.recv_into(read_buffer, 4096)
                    if read == 0:
                        raise EOFError("")

                    n = 0
                    while n < read:
                        msg_len, n = _DecodeVarint32(read_buffer, n)

                        msg = MSG()
                        msg.ParseFromString(read_buffer[n:n + msg_len])
                        n += msg_len  # advance past the message body

                        # Place message in response map
                        with self.response_map_cv:
                            self.response_map[msg.id] = msg
                            # notify_all: several request() calls may be
                            # waiting, each on a different message id
                            self.response_map_cv.notify_all()

        # The connection is now closed
        with self.connection_state_cv:
            self.connection_state = ConnectionState.CLOSED
            self.connection_state_cv.notify()

    def request(self, rq, to=0, timeout=10):
        """
        Send the given request and return the response it provoked.
        """

        if self.connection_state != ConnectionState.CONNECTED:
            raise Error("Illegal connection state")

        msg = MSG()
        setattr(msg, "payload", rq.SerializeToString())
        setattr(msg, "id", randint(0, 65535))
        if to == 0:
            setattr(msg, "to", self.server_cvid)
        else:
            setattr(msg, "to", to)
        setattr(msg, "from", self.sid)

        # Write the message
        self.tls.write(_VarintBytes(msg.ByteSize()))
        self.tls.sendall(msg.SerializeToString())

        # Wait for response
        with self.response_map_cv:
            if not self.response_map_cv.wait_for(
                    lambda: msg.id in self.response_map.keys(), timeout):
                return None

            # Return the response message
            rs = self.response_map[msg.id]
            del self.response_map[msg.id]
            return rs
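request() correlates replies with requests through a dict guarded by a
Condition: the reader thread files each inbound message under its id, and the
requesting thread waits on a predicate over its own id. A stripped-down,
runnable sketch of that pattern (ResponseMap and its method names are
illustrative):

from threading import Condition, Thread
import time

class ResponseMap:
    def __init__(self):
        self._responses = {}
        self._cv = Condition()

    def deliver(self, msg_id, payload):
        # called by the reader thread for every inbound message
        with self._cv:
            self._responses[msg_id] = payload
            self._cv.notify_all()  # each waiter re-checks its own id

    def await_response(self, msg_id, timeout=10):
        with self._cv:
            if not self._cv.wait_for(lambda: msg_id in self._responses,
                                     timeout):
                return None  # timed out, mirroring request() above
            return self._responses.pop(msg_id)

rmap = ResponseMap()
Thread(target=lambda: (time.sleep(0.1), rmap.deliver(42, b"pong"))).start()
print(rmap.await_response(42))  # b'pong'

notify_all() matters once several requests are in flight: a bare notify()
wakes one arbitrary waiter, which may not be the one whose reply just arrived.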
Example #22
class ProcessorIteratorCollection(object):
    """Contains all of the registered (added via __setitem__)
    transaction processors in a _processors (dict) where the keys
    are ProcessorTypes and the values are ProcessorIterators.
    """

    def __init__(self, processor_iterator_class):
        # bytes: list of ProcessorType
        self._identities = {}
        # ProcessorType: ProcessorIterator
        self._processors = {}
        self._proc_iter_class = processor_iterator_class
        self._condition = Condition()

    def __getitem__(self, item):
        """Get a particular ProcessorIterator

        Args:
            item (ProcessorType): The processor type key.
        """
        with self._condition:
            return self._processors[item]

    def __contains__(self, item):
        with self._condition:
            return item in self._processors

    def get_next_of_type(self, processor_type):
        """Get the next processor of a particular type

        Args:
            processor_type (ProcessorType): The processor type associated with
                a zmq identity.

        Returns:
            (Processor): Information about the transaction processor
        """
        with self._condition:
            if processor_type in self:
                return self[processor_type].next_processor()
            return None

    def get_all_processors(self):
        processors = []
        for processor in self._processors.values():
            processors += processor.processor_identities()
        return processors

    def __setitem__(self, key, value):
        """Either create a new ProcessorIterator, if none exists for a
        ProcessorType, or add the Processor to the ProcessorIterator.

        Args:
            key (ProcessorType): The type of transactions this transaction
                processor can handle.
            value (Processor): Information about the transaction processor.
        """
        with self._condition:
            if key not in self._processors:
                proc_iterator = self._proc_iter_class()
                proc_iterator.add_processor(value)
                self._processors[key] = proc_iterator
            else:
                self._processors[key].add_processor(value)
            if value.connection_id not in self._identities:
                self._identities[value.connection_id] = [key]
            else:
                self._identities[value.connection_id].append(key)
            self._condition.notify_all()

    def remove(self, processor_identity):
        """Removes all of the Processors for
        a particular transaction processor zeromq identity.

        Args:
            processor_identity (str): The zeromq identity of the transaction
                processor.
        """
        with self._condition:
            processor_types = self._identities.get(processor_identity)
            if processor_types is None:
                LOGGER.warning("transaction processor with identity %s tried "
                               "to unregister but was not registered",
                               processor_identity)
                return
            for processor_type in processor_types:
                if processor_type not in self._processors:
                    LOGGER.warning("processor type %s not a known processor "
                                   "type but is associated with identity %s",
                                   processor_type,
                                   processor_identity)
                    continue
                self._processors[processor_type].remove_processor(
                    processor_identity=processor_identity)
                if not self._processors[processor_type]:
                    del self._processors[processor_type]

    def __repr__(self):
        return ",".join([repr(k) for k in self._processors])

    def cancellable_wait(self, processor_type, cancelled_event):
        """Waits for a particular processor type to register or until
        is_cancelled is True. is_cancelled cannot be part of this class
        since we aren't cancelling all waiting for a processor_type,
        but just this particular wait.

        Args:
            processor_type (ProcessorType): The family, and version of
                the transaction processor.
            cancelled_event (threading.Event): is_set() will return True when
                the wait is cancelled.

        Returns:
            None
        """
        with self._condition:
            self._condition.wait_for(
                lambda: processor_type in self or cancelled_event.is_set()
            )

    def notify(self):
        """Must be called after setting the cancelled_event, when
        cancelling a wait.
        """
        with self._condition:
            self._condition.notify_all()
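cancellable_wait() blocks until either the processor type registers or the
caller-owned Event is set; because the predicate is only re-evaluated on a
notification, a cancel must set the event first and then call notify(). A
hedged usage sketch, assuming the example's own imports (threading.Condition,
LOGGER); the 'intkey-1.0' key and the iterator-class argument are placeholders:

from threading import Event, Thread
import time

collection = ProcessorIteratorCollection(object)  # iterator class unused here
cancelled = Event()

waiter = Thread(
    target=lambda: collection.cancellable_wait('intkey-1.0', cancelled))
waiter.start()

time.sleep(0.1)
cancelled.set()       # set the event first, as the notify() docstring requires,
collection.notify()   # then wake the waiter so its predicate is re-evaluated
waiter.join()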
class ParallelScheduler(Scheduler):
    def __init__(self, squash_handler, first_state_hash, always_persist):
        self._squash = squash_handler
        self._first_state_hash = first_state_hash
        self._last_state_hash = first_state_hash
        self._condition = Condition()
        self._predecessor_tree = PredecessorTree()
        self._txn_predecessors = {}

        self._always_persist = always_persist

        self._predecessor_chain = PredecessorChain()

        # Transaction identifiers which have been scheduled.  Stored as a list,
        # since order is important; SchedulerIterator instances, for example,
        # must all return scheduled transactions in the same order.
        self._scheduled = []

        # Transactions that must be replayed but the prior result hasn't
        # been returned yet.
        self._outstanding = set()

        # Batch id for the batch with the property that the batch doesn't have
        # all txn results, and all batches prior to it have all their txn
        # results.
        self._least_batch_id_wo_results = None

        # A dict of transaction id to TxnInformation objects, containing all
        # transactions present in self._scheduled.
        self._scheduled_txn_info = {}

        # All batches in their natural order (the order they were added to
        # the scheduler).
        self._batches = []
        # The batches that have state hashes added in add_batch, used in
        # Block validation.
        self._batches_with_state_hash = {}

        # Indexes to find a batch quickly
        self._batches_by_id = {}
        self._batches_by_txn_id = {}

        # Transaction results
        self._txn_results = {}

        self._txns_available = OrderedDict()
        self._transactions = {}

        self._cancelled = False
        self._final = False

    def _find_input_dependencies(self, inputs):
        """Use the predecessor tree to find dependencies based on inputs.

        Returns: A list of transaction ids.
        """
        dependencies = []
        for address in inputs:
            dependencies.extend(
                self._predecessor_tree.find_read_predecessors(address))
        return dependencies

    def _find_output_dependencies(self, outputs):
        """Use the predecessor tree to find dependencies based on outputs.

        Returns: A list of transaction ids.
        """
        dependencies = []
        for address in outputs:
            dependencies.extend(
                self._predecessor_tree.find_write_predecessors(address))
        return dependencies

    def add_batch(self, batch, state_hash=None, required=False):
        with self._condition:
            if self._final:
                raise SchedulerError('Invalid attempt to add batch to '
                                     'finalized scheduler; batch: {}'.format(
                                         batch.header_signature))
            if not self._batches:
                self._least_batch_id_wo_results = batch.header_signature

            preserve = required
            if not required:
                # If this is the first non-required batch, it is preserved for
                # the schedule to be completed (i.e. no empty schedules in the
                # event of unschedule_incomplete_batches being called before
                # the first batch is completed).
                preserve = _first(
                    filterfalse(lambda sb: sb.required,
                                self._batches_by_id.values())) is None

            self._batches.append(batch)
            self._batches_by_id[batch.header_signature] = \
                _AnnotatedBatch(batch, required=required, preserve=preserve)
            for txn in batch.transactions:
                self._batches_by_txn_id[txn.header_signature] = batch
                self._txns_available[txn.header_signature] = txn
                self._transactions[txn.header_signature] = txn

            if state_hash is not None:
                b_id = batch.header_signature
                self._batches_with_state_hash[b_id] = state_hash

            # For dependency handling: First, we determine our dependencies
            # based on the current state of the predecessor tree.  Second,
            # we update the predecessor tree with reader and writer
            # information based on input and outputs.
            for txn in batch.transactions:
                header = TransactionHeader()
                header.ParseFromString(txn.header)

                # Calculate predecessors (transaction ids which must come
                # prior to the current transaction).
                predecessors = self._find_input_dependencies(header.inputs)
                predecessors.extend(
                    self._find_output_dependencies(header.outputs))

                txn_id = txn.header_signature
                # Update our internal state with the computed predecessors.
                self._txn_predecessors[txn_id] = set(predecessors)
                self._predecessor_chain.add_relationship(
                    txn_id=txn_id, predecessors=predecessors)

                # Update the predecessor tree.
                #
                # Order of reader/writer operations is relevant.  A writer
                # may overshadow a reader.  For example, if the transaction
                # has the same input/output address, the end result will be
                # this writer (txn.header_signature) stored at the address of
                # the predecessor tree.  The reader information will have been
                # discarded.  Write operations to partial addresses will also
                # overshadow entire parts of the predecessor tree.
                #
                # Thus, the order here (inputs then outputs) will cause the
                # minimal amount of relevant information to be stored in the
                # predecessor tree, with duplicate information being
                # automatically discarded by the set_writer() call.
                for address in header.inputs:
                    self._predecessor_tree.add_reader(address, txn_id)
                for address in header.outputs:
                    self._predecessor_tree.set_writer(address, txn_id)

            self._condition.notify_all()

    def _is_explicit_request_for_state_root(self, batch_signature):
        return batch_signature in self._batches_with_state_hash

    def _is_implicit_request_for_state_root(self, batch_signature):
        return self._final and self._is_last_valid_batch(batch_signature)

    def _is_valid_batch(self, batch):
        for txn in batch.transactions:
            if txn.header_signature not in self._txn_results:
                raise _UnscheduledTransactionError()

            result = self._txn_results[txn.header_signature]
            if not result.is_valid:
                return False
        return True

    def _is_last_valid_batch(self, batch_signature):
        batch = self._batches_by_id[batch_signature].batch
        if not self._is_valid_batch(batch):
            return False
        index_of_next = self._batches.index(batch) + 1
        for later_batch in self._batches[index_of_next:]:
            if self._is_valid_batch(later_batch):
                return False
        return True

    def _get_contexts_for_squash(self, batch_signature):
        """Starting with the batch referenced by batch_signature, iterate back
        through the batches and for each valid batch collect the context_id.
        At the end remove contexts for txns that are other txn's predecessors.

        Args:
            batch_signature (str): The batch to start from, moving back through
                the batches in the scheduler

        Returns:
            (list): Context ids that haven't previously been used as base
                contexts.
        """

        batch = self._batches_by_id[batch_signature].batch
        index = self._batches.index(batch)
        contexts = []
        txns_added_predecessors = []
        for b in self._batches[index::-1]:
            batch_is_valid = True
            contexts_from_batch = []
            for txn in b.transactions[::-1]:
                result = self._txn_results[txn.header_signature]
                if not result.is_valid:
                    batch_is_valid = False
                    break
                else:
                    txn_id = txn.header_signature
                    if txn_id not in txns_added_predecessors:
                        txns_added_predecessors.append(
                            self._txn_predecessors[txn_id])
                        contexts_from_batch.append(result.context_id)
            if batch_is_valid:
                contexts.extend(contexts_from_batch)

        return contexts

    def _is_state_hash_correct(self, state_hash, batch_id):
        return state_hash == self._batches_with_state_hash[batch_id]

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            # This method calculates the BatchExecutionResult on the fly,
            # where only the TxnExecutionResults are cached, instead
            # of BatchExecutionResults, as in the SerialScheduler
            if batch_signature not in self._batches_by_id:
                return None

            batch = self._batches_by_id[batch_signature].batch

            if not self._is_valid_batch(batch):
                return BatchExecutionResult(is_valid=False, state_hash=None)

            state_hash = None
            try:
                if self._is_explicit_request_for_state_root(batch_signature):
                    contexts = self._get_contexts_for_squash(batch_signature)
                    state_hash = self._squash(self._first_state_hash,
                                              contexts,
                                              persist=False,
                                              clean_up=False)
                    if self._is_state_hash_correct(state_hash,
                                                   batch_signature):
                        self._squash(self._first_state_hash,
                                     contexts,
                                     persist=True,
                                     clean_up=True)
                    else:
                        self._squash(self._first_state_hash,
                                     contexts,
                                     persist=False,
                                     clean_up=True)
                elif self._is_implicit_request_for_state_root(batch_signature):
                    contexts = self._get_contexts_for_squash(batch_signature)
                    state_hash = self._squash(self._first_state_hash,
                                              contexts,
                                              persist=self._always_persist,
                                              clean_up=True)
            except _UnscheduledTransactionError:
                return None

            return BatchExecutionResult(is_valid=True, state_hash=state_hash)

    def get_transaction_execution_results(self, batch_signature):

        with self._condition:
            annotated_batch = self._batches_by_id.get(batch_signature)
            if annotated_batch is None:
                return None

            results = []
            for txn in annotated_batch.batch.transactions:
                result = self._txn_results.get(txn.header_signature)
                if result is not None:
                    results.append(result)
            return results

    def _is_predecessor_of_possible_successor(self, txn_id,
                                              possible_successor):
        return self._predecessor_chain.is_predecessor_of_other(
            txn_id, [possible_successor])

    def _txn_has_result(self, txn_id):
        return txn_id in self._txn_results

    def _is_in_same_batch(self, txn_id_1, txn_id_2):
        return self._batches_by_txn_id[txn_id_1] == \
            self._batches_by_txn_id[txn_id_2]

    def _is_txn_to_replay(self, txn_id, possible_successor, already_seen):
        """Decide if possible_successor should be replayed.

        Args:
            txn_id (str): Id of txn in failed batch.
            possible_successor (str): Id of txn to possibly replay.
            already_seen (list): A list of possible_successors that have
                been replayed.

        Returns:
            (bool): If the possible_successor should be replayed.
        """

        is_successor = self._is_predecessor_of_possible_successor(
            txn_id, possible_successor)
        in_different_batch = not self._is_in_same_batch(
            txn_id, possible_successor)
        has_not_been_seen = possible_successor not in already_seen

        return is_successor and in_different_batch and has_not_been_seen

    def _remove_subsequent_result_because_of_batch_failure(self, sig):
        """Remove transactions from scheduled and txn_results for
        successors of txns in a failed batch. These transactions will be
        rescheduled by next_transaction, either now or in the future,
        which gives them the ability to be replayed.

        Args:
            sig (str): Transaction header signature

        """

        batch = self._batches_by_txn_id[sig]
        seen = []
        for txn in batch.transactions:
            txn_id = txn.header_signature
            for poss_successor in self._scheduled.copy():
                if not self.is_transaction_in_schedule(poss_successor):
                    continue

                if self._is_txn_to_replay(txn_id, poss_successor, seen):
                    if self._txn_has_result(poss_successor):
                        del self._txn_results[poss_successor]
                        self._scheduled.remove(poss_successor)
                        self._txns_available[poss_successor] = \
                            self._transactions[poss_successor]
                    else:
                        self._outstanding.add(poss_successor)
                    seen.append(poss_successor)

    def _reschedule_if_outstanding(self, txn_signature):
        if txn_signature in self._outstanding:
            self._txns_available[txn_signature] = \
                self._transactions[txn_signature]
            self._scheduled.remove(txn_signature)
            self._outstanding.discard(txn_signature)
            return True
        return False

    def _index_of_batch(self, batch):
        batch_index = None
        try:
            batch_index = self._batches.index(batch)
        except ValueError:
            pass
        return batch_index

    def _set_least_batch_id(self, txn_signature):
        """Set the first batch id that doesn't have all results.

        Args:
            txn_signature (str): The txn identifier of the transaction with
                results being set.

        """

        batch = self._batches_by_txn_id[txn_signature]

        least_index = self._index_of_batch(
            self._batches_by_id[self._least_batch_id_wo_results].batch)

        current_index = self._index_of_batch(batch)
        all_prior = False

        if current_index <= least_index:
            return
        # Test to see if all batches from the least batch up to, but not
        # including, the current batch have results.
        if all(
                all(t.header_signature in self._txn_results
                    for t in b.transactions)
                for b in self._batches[least_index:current_index]):
            all_prior = True
        if not all_prior:
            return
        possible_least = self._batches[current_index].header_signature
        # Find the first batch from the current batch on, that doesn't have
        # all results.
        for b in self._batches[current_index:]:
            if not all(t.header_signature in self._txn_results
                       for t in b.transactions):
                possible_least = b.header_signature
                break
        self._least_batch_id_wo_results = possible_least

    def set_transaction_execution_result(self,
                                         txn_signature,
                                         is_valid,
                                         context_id,
                                         state_changes=None,
                                         events=None,
                                         data=None,
                                         error_message="",
                                         error_data=b""):
        with self._condition:
            if txn_signature not in self._scheduled:
                raise SchedulerError(
                    "transaction not scheduled: {}".format(txn_signature))

            if txn_signature not in self._batches_by_txn_id:
                return

            self._set_least_batch_id(txn_signature=txn_signature)
            if not is_valid:
                self._remove_subsequent_result_because_of_batch_failure(
                    txn_signature)
            is_rescheduled = self._reschedule_if_outstanding(txn_signature)

            if not is_rescheduled:
                self._txn_results[txn_signature] = TxnExecutionResult(
                    signature=txn_signature,
                    is_valid=is_valid,
                    context_id=context_id if is_valid else None,
                    state_hash=self._first_state_hash if is_valid else None,
                    state_changes=state_changes,
                    events=events,
                    data=data,
                    error_message=error_message,
                    error_data=error_data)

            self._condition.notify_all()

    def _has_predecessors(self, txn_id):
        for predecessor_id in self._txn_predecessors[txn_id]:
            if predecessor_id not in self._txn_results:
                return True
            # Since get_initial_state_for_transaction gets context ids not
            # just from predecessors but also in the case of an enclosing
            # writer failing, predecessors of that predecessor, this extra
            # check is needed.
            for pre_pred_id in self._txn_predecessors[predecessor_id]:
                if pre_pred_id not in self._txn_results:
                    return True

        return False

    def _is_outstanding(self, txn_id):
        return txn_id in self._outstanding

    def _txn_is_in_valid_batch(self, txn_id):
        """Returns whether the transaction is in a valid batch.

        Args:
            txn_id (str): The transaction header signature.

        Returns:
            (bool): True if the txn's batch is valid, False otherwise.
        """

        batch = self._batches_by_txn_id[txn_id]

        # Return whether every transaction in the batch with a
        # transaction result is valid
        return all(self._txn_results[sig].is_valid
                   for sig in set(self._txn_results).intersection((
                       txn.header_signature for txn in batch.transactions)))

    def _get_initial_state_for_transaction(self, txn):
        # Collect contexts that this transaction depends upon
        # We assume that all prior txns in the batch are valid
        # or else this transaction wouldn't run. We assume that
        # the mechanism in next_transaction makes sure that each
        # predecessor txn has a result. Also, any explicit
        # dependency that could have failed this txn has already done so.
        contexts = []
        txn_dependencies = deque()
        txn_dependencies.extend(self._txn_predecessors[txn.header_signature])
        while txn_dependencies:
            prior_txn_id = txn_dependencies.popleft()
            if self._txn_is_in_valid_batch(prior_txn_id):
                result = self._txn_results[prior_txn_id]
                if (prior_txn_id, result.context_id) not in contexts:
                    contexts.append((prior_txn_id, result.context_id))
            else:
                txn_dependencies.extend(self._txn_predecessors[prior_txn_id])

        contexts.sort(key=lambda x: self._index_of_txn_in_schedule(x[0]),
                      reverse=True)
        return [c_id for _, c_id in contexts]

    def _index_of_txn_in_schedule(self, txn_id):
        batch = self._batches_by_txn_id[txn_id]
        index_of_batch_in_schedule = self._batches.index(batch)
        number_of_txns_in_prior_batches = 0
        for prior in self._batches[:index_of_batch_in_schedule]:
            number_of_txns_in_prior_batches += len(prior.transactions)

        txn_index, _ = next((i, t) for i, t in enumerate(batch.transactions)
                            if t.header_signature == txn_id)

        return number_of_txns_in_prior_batches + txn_index - 1

    def _can_fail_fast(self, txn_id):
        batch_id = self._batches_by_txn_id[txn_id].header_signature
        return batch_id == self._least_batch_id_wo_results

    def next_transaction(self):
        with self._condition:
            # We return the next transaction which hasn't been scheduled and
            # is not blocked by a dependency.

            next_txn = None

            no_longer_available = []

            for txn_id, txn in self._txns_available.items():
                if (self._has_predecessors(txn_id)
                        or self._is_outstanding(txn_id)):
                    continue

                header = TransactionHeader()
                header.ParseFromString(txn.header)
                deps = tuple(header.dependencies)

                if self._dependency_not_processed(deps):
                    continue

                if self._txn_failed_by_dep(deps):
                    no_longer_available.append(txn_id)
                    self._txn_results[txn_id] = \
                        TxnExecutionResult(
                            signature=txn_id,
                            is_valid=False,
                            context_id=None,
                            state_hash=None)
                    continue

                if not self._txn_is_in_valid_batch(txn_id) and \
                        self._can_fail_fast(txn_id):
                    self._txn_results[txn_id] = \
                        TxnExecutionResult(
                            signature=txn_id,
                            is_valid=False,
                            context_id=None,
                            state_hash=None)
                    no_longer_available.append(txn_id)
                    continue

                next_txn = txn
                break

            for txn_id in no_longer_available:
                del self._txns_available[txn_id]

            if next_txn is not None:
                bases = self._get_initial_state_for_transaction(next_txn)

                info = TxnInformation(txn=next_txn,
                                      state_hash=self._first_state_hash,
                                      base_context_ids=bases)
                self._scheduled.append(next_txn.header_signature)
                del self._txns_available[next_txn.header_signature]
                self._scheduled_txn_info[next_txn.header_signature] = info
                return info
            return None

    def _dependency_not_processed(self, deps):
        if any(not self._all_in_batch_have_results(d) for d in deps
               if d in self._batches_by_txn_id):
            return True
        return False

    def _txn_failed_by_dep(self, deps):
        if any(
                self._any_in_batch_are_invalid(d) for d in deps
                if d in self._batches_by_txn_id):
            return True
        return False

    def _all_in_batch_have_results(self, txn_id):
        batch = self._batches_by_txn_id[txn_id]
        return all(t.header_signature in self._txn_results
                   for t in list(batch.transactions))

    def _any_in_batch_are_invalid(self, txn_id):
        batch = self._batches_by_txn_id[txn_id]
        return any(not self._txn_results[t.header_signature].is_valid
                   for t in list(batch.transactions))

    def available(self):
        with self._condition:
            # Count the transactions that have no unresolved predecessors
            # and so could be handed out by next_transaction.

            count = 0
            for txn_id in self._txns_available:
                if not self._has_predecessors(txn_id):
                    count += 1

            return count

    def unschedule_incomplete_batches(self):
        incomplete_batches = set()
        with self._condition:
            # These transactions have never been scheduled.
            for txn_id, txn in self._txns_available.items():
                batch = self._batches_by_txn_id[txn_id]
                batch_id = batch.header_signature

                annotated_batch = self._batches_by_id[batch_id]
                if not annotated_batch.preserve:
                    incomplete_batches.add(batch_id)

            # These transactions were in flight.
            in_flight = set(self._transactions.keys()).difference(
                self._txn_results.keys())

            for txn_id in in_flight:
                batch = self._batches_by_txn_id[txn_id]
                batch_id = batch.header_signature

                annotated_batch = self._batches_by_id[batch_id]
                if not annotated_batch.preserve:
                    incomplete_batches.add(batch_id)

            # clean up the batches, including partial complete information
            for batch_id in incomplete_batches:
                annotated_batch = self._batches_by_id[batch_id]
                self._batches.remove(annotated_batch.batch)
                del self._batches_by_id[batch_id]
                for txn in annotated_batch.batch.transactions:
                    txn_id = txn.header_signature
                    del self._batches_by_txn_id[txn_id]

                    if txn_id in self._txn_results:
                        del self._txn_results[txn_id]

                    if txn_id in self._txns_available:
                        del self._txns_available[txn_id]

                    if txn_id in self._outstanding:
                        self._outstanding.remove(txn_id)

            self._condition.notify_all()

        if incomplete_batches:
            LOGGER.debug('Removed %s incomplete batches from the schedule',
                         len(incomplete_batches))

    def is_transaction_in_schedule(self, txn_signature):
        with self._condition:
            return txn_signature in self._batches_by_txn_id

    def finalize(self):
        with self._condition:
            self._final = True
            self._condition.notify_all()

    def _complete(self):
        return self._final and \
            len(self._txn_results) == len(self._batches_by_txn_id)

    def complete(self, block=True):
        with self._condition:
            if self._complete():
                return True

            if block:
                return self._condition.wait_for(self._complete)

            return False

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def count(self):
        with self._condition:
            return len(self._scheduled)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_txn_info[self._scheduled[index]]

    def cancel(self):
        with self._condition:
            if not self._cancelled and not self._final:
                contexts = [
                    tr.context_id for tr in self._txn_results.values()
                    if tr.context_id
                ]
                self._squash(self._first_state_hash,
                             contexts,
                             persist=False,
                             clean_up=True)
                self._cancelled = True
                self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
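finalize(), complete() and set_transaction_execution_result() above all follow
one Condition idiom: writers mutate state under the lock and notify_all(),
while readers block in wait_for() on a pure predicate. A runnable miniature of
that handshake (MiniSchedule and its names are illustrative):

from threading import Condition, Thread

class MiniSchedule:
    def __init__(self, expected):
        self._cond = Condition()
        self._results = {}
        self._final = False
        self._expected = expected

    def set_result(self, txn_id, result):
        with self._cond:
            self._results[txn_id] = result
            self._cond.notify_all()

    def finalize(self):
        with self._cond:
            self._final = True
            self._cond.notify_all()

    def complete(self, block=True):
        # done only when finalized AND every result is in, mirroring
        # ParallelScheduler._complete()
        def done():
            return self._final and len(self._results) == self._expected
        with self._cond:
            if done():
                return True
            return self._cond.wait_for(done) if block else False

sched = MiniSchedule(expected=2)

def producer():
    for i in (1, 2):
        sched.set_result(i, 'ok')
    sched.finalize()

Thread(target=producer).start()
assert sched.complete(block=True)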
Example #24
class DBDiskCache(DiskCache):
    @property
    def map_id(self):
        return self.__map_desc.map_id

    def __init__(self, cache_dir, map_desc, db_schema, is_concurrency=True):
        self.__map_desc = map_desc
        self.__db_path = os.path.join(cache_dir, map_desc.map_id + ".mbtiles")
        self.__conn = None

        #configs
        self.__db_schema = db_schema
        self.__has_timestamp = True

        self.__is_concurrency = is_concurrency

        if is_concurrency:
            self.__surrogate = None  #the thread that performs all DB operations; a sqlite3 connection may only be used from its creating thread

            self.__is_closed = False

            #concurrency get/put
            self.__sql_queue = []
            self.__sql_queue_lock = Lock()
            self.__sql_queue_cv = Condition(self.__sql_queue_lock)

            self.__get_lock = Lock()    #serializes concurrent get() calls

            self.__get_response = None   #the pair (data, exception)
            self.__get_response_lock = Lock()
            self.__get_response_cv = Condition(self.__get_response_lock)

    def __initDB(self):
        def getBoundsText(map_desc):
            left, bottom = map_desc.lower_corner
            right, top   = map_desc.upper_corner
            bounds = "%f,%f,%f,%f" % (left, bottom, right, top) #OpenLayers Bounds format
            return bounds

        desc = self.__map_desc
        conn = self.__conn

        #metadata
        meta_create_sql = "CREATE TABLE metadata(name TEXT PRIMARY KEY, value TEXT)"
        meta_data = (('name', desc.map_id),
                     ('type', 'overlayer'),
                     ('version', '1.0'),
                     ('description', desc.map_title),
                     ('format', desc.tile_format),
                     ('bounds', getBoundsText(desc)),
                     ('schema', self.__db_schema),
                    )
        #tiles
        tiles_create_sql = "CREATE TABLE tiles("
        tiles_create_sql += "zoom_level  INTEGER, "
        tiles_create_sql += "tile_column INTEGER, "
        tiles_create_sql += "tile_row    INTEGER, "
        tiles_create_sql += "tile_data   BLOB     NOT NULL, "
        tiles_create_sql += "timestamp   INTEGER  NOT NULL, "
        tiles_create_sql += "PRIMARY KEY (zoom_level, tile_column, tile_row))"

        #tiles_idx
        tiles_idx_create_sql = "CREATE INDEX tiles_idx on tiles(zoom_level, tile_column, tile_row)"

        #exec
        conn.execute(meta_create_sql)
        conn.execute(tiles_create_sql)
        conn.execute(tiles_idx_create_sql)
        conn.executemany("INSERT INTO metadata(name, value) VALUES(?, ?)", meta_data)
        conn.commit()


    def __getMetadata(self, name):
        try:
            sql = 'SELECT value FROM metadata WHERE name=?'
            cursor = self.__conn.execute(sql, (name,))
            row = cursor.fetchone()
            data = None if row is None else row[0]
            return data
        except Exception as ex:
            logging.warning('[%s] Get mbtiles metadata error: %s' % (self.map_id, str(ex)))
        return None

    def __tableHasColumn(self, tbl_name, col_name):
        try:
            sql = "PRAGMA table_info(%s)" % tbl_name
            cursor = self.__conn.execute(sql)
            rows = cursor.fetchall()

            for row in rows:
                if row[1] == col_name:
                    return True

        except Exception as ex:
            logging.warning("[%s] detect table '%s' has column '%s' error: %s" % (self.map_id, tbl_name, col_name, str(ex)))
        return False

    def __readConfig(self):
        #db schema from metadata
        schema = self.__getMetadata('schema')
        if schema and self.__db_schema != schema:
            logging.info("[%s] Reset db schema from %s to %s" % (self.map_id, self.__db_schema, schema))
            self.__db_schema = schema

        self.__has_timestamp = self.__tableHasColumn("tiles", "timestamp")

    #the true actions which are called by Surrogate
    def __start(self):
        if not os.path.exists(self.__db_path):
            logging.info("[%s] Initializing local cache DB..." % (self.map_id,))
            mkdirSafely(os.path.dirname(self.__db_path))
            self.__conn = sqlite3.connect(self.__db_path)
            self.__initDB()
        else:
            self.__conn = sqlite3.connect(self.__db_path)
            self.__readConfig()

        logging.info("[%s][Config] db schema: %s" % (self.map_id, self.__db_schema))
        logging.info("[%s][Config] suuport tile timestamp: %s" % (self.map_id, self.__has_timestamp))

    def __close(self):
        logging.info("[%s] Closing local cache DB..." % (self.map_id,))
        self.__conn.close()

    @classmethod
    def flipY(cls, y, level):
        return (1 << level) - 1 - y

    def __put(self, level, x, y, data):
        #sql
        if self.__db_schema == 'tms':
            y = self.flipY(y, level)

        sql = None
        if self.__has_timestamp:
            sql  = "INSERT OR REPLACE INTO tiles(zoom_level, tile_column, tile_row, tile_data, timestamp)"
            sql += " VALUES(%d, %d, %d, ?, %d)" % (level, x, y, int(time.time()))
        else:
            sql = "INSERT OR REPLACE INTO tiles(zoom_level, tile_column, tile_row, tile_data)"
            sql += " VALUES(%d, %d, %d, ?)" % (level, x, y)

        #query
        try:
            self.__conn.execute(sql, (data,))
            self.__conn.commit()
            logging.info("[%s] %s [OK]" % (self.map_id, sql))
        except Exception as ex:
            logging.info("[%s] %s [Fail]" % (self.map_id, sql))
            raise ex

    def __get(self, level, x, y):
        #sql
        if self.__db_schema == 'tms':
            y = self.flipY(y, level)

        cols = "tile_data, timestamp" if self.__has_timestamp else "tile_data"
        sql = "SELECT %s FROM tiles WHERE zoom_level=%d AND tile_column=%d AND tile_row=%d" % \
                (cols, level, x, y,)

        row = None
        try:
            #query
            cursor = self.__conn.execute(sql)
            row = cursor.fetchone()
        except Exception as ex:
            logging.info("[%s] %s [Fail]" % (self.map_id, sql))
            raise ex

        #result (tile, timestamp)
        if row is None:
            logging.info("[%s] %s [NA]" % (self.map_id, sql))
            return (None, None)
        elif self.__has_timestamp:
            logging.info("[%s] %s [OK][TS]" % (self.map_id, sql))
            return row
        else:
            logging.info("[%s] %s [OK]" % (self.map_id, sql))
            return (row[0], None)

    #the interface which are called by the user
    def start(self):
        if not self.__is_concurrency:
            self.__start()
        else:
            self.__surrogate = Thread(target=self.__runSurrogate)
            self.__surrogate.start()

    def close(self):
        if not self.__is_concurrency:
            self.__close()
        else:
            with self.__sql_queue_cv:
                self.__is_closed = True
                self.__sql_queue_cv.notify()
            self.__surrogate.join()

    def put(self, level, x, y, data):
        if not self.__is_concurrency:
            self.__put(level, x, y, data)
        else:
            with self.__sql_queue_cv:
                item = (level, x, y, data)
                self.__sql_queue.append(item)
                self.__sql_queue_cv.notify()

    def get(self, level, x, y):
        if not self.__is_concurrency:
            return self.__get(level, x, y)
        else:
            def has_response():
                return self.__get_response is not None

            with self.__get_lock:  #block further get() calls until this one completes
                #request the tile
                with self.__sql_queue_cv:
                    item = (level, x, y, None)
                    self.__sql_queue.insert(0, item)  #gets are served first
                    self.__sql_queue_cv.notify()

                #wait for the response
                res = None
                with self.__get_response_cv:
                    self.__get_response_cv.wait_for(has_response)
                    res, self.__get_response = self.__get_response, res   #swap: pop the response of this get()

                #return data
                data, ex = res
                if ex:
                    raise ex
                return data

    #the Surrogate thread
    def __runSurrogate(self):
        def has_sql_events():
            return self.__is_closed or len(self.__sql_queue)
        
        self.__start()
        try:
            while True:
                #wait events
                item = None
                with self.__sql_queue_cv:
                    self.__sql_queue_cv.wait_for(has_sql_events)
                    if self.__is_closed:
                        return
                    item = self.__sql_queue.pop(0)

                level, x, y, data = item
                #put data (a None payload marks a get request)
                if data is not None:
                    try:
                        self.__put(level, x, y, data)
                    except Exception as ex:
                        logging.error("[%s] DB put data error: %s" % (self.map_id, str(ex)))
                #get data
                else:
                    res_data, res_ex = None, None
                    try:
                        res_data = self.__get(level, x, y)
                    except Exception as ex:
                        logging.error("[%s] DB get data error: %s" % (self.map_id, str(ex)))
                        res_ex = ex

                    #notify the waiting get()
                    with self.__get_response_cv:
                        self.__get_response = (res_data, res_ex)
                        self.__get_response_cv.notify()

        finally:
            self.__close()
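The surrogate thread exists because a sqlite3 connection may only be used from
the thread that created it; put() and get() merely marshal work onto that
thread. A compact, runnable sketch of the same idea, with queue.Queue standing
in for the hand-rolled Condition-based queue (SqliteSurrogate and its names
are illustrative):

import queue
import sqlite3
import threading

class SqliteSurrogate:
    def __init__(self, path):
        self._jobs = queue.Queue()
        self._thread = threading.Thread(target=self._run, args=(path,))
        self._thread.start()

    def _run(self, path):
        # the connection is created and used on this thread only
        conn = sqlite3.connect(path)
        try:
            while True:
                job = self._jobs.get()
                if job is None:          # close sentinel
                    return
                fn, reply = job
                try:
                    reply.put((fn(conn), None))
                except Exception as ex:  # hand errors back to the caller
                    reply.put((None, ex))
        finally:
            conn.close()

    def call(self, fn):
        reply = queue.Queue(maxsize=1)
        self._jobs.put((fn, reply))
        value, ex = reply.get()          # block until the surrogate answers
        if ex:
            raise ex
        return value

    def close(self):
        self._jobs.put(None)
        self._thread.join()

db = SqliteSurrogate(':memory:')
db.call(lambda c: c.execute('CREATE TABLE t(x INTEGER)'))
db.call(lambda c: (c.execute('INSERT INTO t VALUES (1)'), c.commit()))
print(db.call(lambda c: c.execute('SELECT x FROM t').fetchone()))  # (1,)
db.close()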
class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """

    def __init__(self, squash_handler, first_state_hash, always_persist):
        self._txn_queue = deque()
        self._scheduled_transactions = []
        self._batch_statuses = {}
        self._txn_to_batch = {}
        self._batch_by_id = {}
        self._txn_results = {}
        self._in_progress_transaction = None
        self._final = False
        self._cancelled = False
        self._previous_context_id = None
        self._previous_valid_batch_c_id = None
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._previous_state_hash = first_state_hash
        # The state hashes here are the ones added in add_batch, and
        # are the state hashes that correspond with block boundaries.
        self._required_state_hashes = {}
        self._already_calculated = False
        self._always_persist = always_persist

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(
            self, txn_signature, is_valid, context_id, state_changes=None,
            events=None, data=None, error_message="", error_data=b""):
        with self._condition:
            if (self._in_progress_transaction is None
                    or self._in_progress_transaction != txn_signature):
                return

            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError(
                    "transaction not in any batches: {}".format(txn_signature))

            if txn_signature not in self._txn_results:
                self._txn_results[txn_signature] = TxnExecutionResult(
                    signature=txn_signature,
                    is_valid=is_valid,
                    context_id=context_id if is_valid else None,
                    state_hash=self._previous_state_hash if is_valid else None,
                    state_changes=state_changes,
                    events=events,
                    data=data,
                    error_message=error_message,
                    error_data=error_data)

            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                self._previous_context_id = context_id

            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=False, state_hash=None)
            if txn_signature in self._last_in_batch:
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._previous_valid_batch_c_id = self._previous_context_id
                    state_hash = self._calculate_state_root_if_required(
                        batch_id=batch_signature)
                    self._batch_statuses[batch_signature] = \
                        BatchExecutionResult(is_valid=True,
                                             state_hash=state_hash)
                else:
                    self._previous_context_id = self._previous_valid_batch_c_id

            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None, required=False):
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            preserve = required
            if not required:
                # If this is the first non-required batch, it is preserved for
                # the schedule to be completed (i.e. no empty schedules in the
                # event of unschedule_incomplete_batches being called before
                # the first batch is completed).
                preserve = _first(
                    filterfalse(lambda sb: sb.required,
                                self._batch_by_id.values())) is None

            batch_signature = batch.header_signature
            self._batch_by_id[batch_signature] = \
                _AnnotatedBatch(batch, required=required, preserve=preserve)

            if state_hash is not None:
                self._required_state_hashes[batch_signature] = state_hash
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.append(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def get_transaction_execution_results(self, batch_signature):
        with self._condition:
            batch_status = self._batch_statuses.get(batch_signature)
            if batch_status is None:
                return None

            annotated_batch = self._batch_by_id.get(batch_signature)
            if annotated_batch is None:
                return None

            results = []
            for txn in annotated_batch.batch.transactions:
                result = self._txn_results.get(txn.header_signature)
                if result is not None:
                    results.append(result)
            return results

    def count(self):
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_transactions[index]

    def _get_dependencies(self, transaction):
        header = TransactionHeader()
        header.ParseFromString(transaction.header)
        return list(header.dependencies)

    def _set_batch_result(self, txn_id, valid, state_hash):
        if txn_id not in self._txn_to_batch:
            # An incomplete transaction in progress will have been removed
            return

        batch_id = self._txn_to_batch[txn_id]
        self._batch_statuses[batch_id] = BatchExecutionResult(
            is_valid=valid,
            state_hash=state_hash)
        batch = self._batch_by_id[batch_id].batch
        for txn in batch.transactions:
            if txn.header_signature not in self._txn_results:
                self._txn_results[txn.header_signature] = TxnExecutionResult(
                    txn.header_signature, is_valid=False)

    def _get_batch_result(self, txn_id):
        batch_id = self._txn_to_batch[txn_id]
        return self._batch_statuses[batch_id]

    def _dep_is_known(self, txn_id):
        return txn_id in self._txn_to_batch

    def _in_invalid_batch(self, txn_id):
        if self._txn_to_batch[txn_id] in self._batch_statuses:
            dependency_result = self._get_batch_result(txn_id)
            return not dependency_result.is_valid
        return False

    def _handle_fail_fast(self, txn):
        self._set_batch_result(
            txn.header_signature,
            False,
            None)
        self._check_change_last_good_context_id(txn)

    def _check_change_last_good_context_id(self, txn):
        if txn.header_signature in self._last_in_batch:
            self._previous_context_id = self._previous_valid_batch_c_id

    def next_transaction(self):
        with self._condition:
            if self._in_progress_transaction is not None:
                return None

            txn = None
            while txn is None:
                try:
                    txn = self._txn_queue.popleft()
                except IndexError:
                    if self._final:
                        self._condition.notify_all()
                        raise StopIteration()
                    return None
                # Handle this transaction being invalid based on a
                # dependency.
                if any(self._dep_is_known(d) and self._in_invalid_batch(d)
                        for d in self._get_dependencies(txn)):
                    self._set_batch_result(
                        txn.header_signature,
                        False,
                        None)
                    self._check_change_last_good_context_id(txn=txn)
                    txn = None
                    continue
                # Handle fail fast.
                if self._in_invalid_batch(txn.header_signature):
                    self._handle_fail_fast(txn)
                    txn = None

            self._in_progress_transaction = txn.header_signature
            base_contexts = [] if self._previous_context_id is None \
                else [self._previous_context_id]
            txn_info = TxnInformation(
                txn=txn,
                state_hash=self._previous_state_hash,
                base_context_ids=base_contexts)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def unschedule_incomplete_batches(self):
        inprogress_batch_id = None
        with self._condition:
            # remove the in-progress transaction's batch
            if self._in_progress_transaction is not None:
                batch_id = self._txn_to_batch[self._in_progress_transaction]
                annotated_batch = self._batch_by_id[batch_id]

                # if the batch is marked preserve (which covers the case where
                # no batch has completed yet), keep it in the schedule
                if not annotated_batch.preserve:
                    self._in_progress_transaction = None
                else:
                    inprogress_batch_id = batch_id

            def in_schedule(entry):
                (batch_id, annotated_batch) = entry
                return batch_id in self._batch_statuses or \
                    annotated_batch.preserve or batch_id == inprogress_batch_id

            incomplete_batches = list(
                filterfalse(in_schedule, self._batch_by_id.items()))

            # clean up the batches, including partial complete information
            for batch_id, annotated_batch in incomplete_batches:
                for txn in annotated_batch.batch.transactions:
                    txn_id = txn.header_signature
                    if txn_id in self._txn_results:
                        del self._txn_results[txn_id]

                    if txn in self._txn_queue:
                        self._txn_queue.remove(txn)

                    del self._txn_to_batch[txn_id]

                self._last_in_batch.remove(
                    annotated_batch.batch.transactions[-1].header_signature)

                del self._batch_by_id[batch_id]

            self._condition.notify_all()

        if incomplete_batches:
            LOGGER.debug('Removed %s incomplete batches from the schedule',
                         len(incomplete_batches))

    def is_transaction_in_schedule(self, txn_signature):
        with self._condition:
            return txn_signature in self._txn_to_batch

    def finalize(self):
        with self._condition:
            self._final = True
            self._condition.notify_all()

    def _compute_merkle_root(self, required_state_root):
        """Computes the merkle root of the state changes in the context
        corresponding with _last_valid_batch_c_id as applied to
        _previous_state_hash.

        Args:
            required_state_root (str): The merkle root that these txns
                should equal.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context_id
        """

        state_hash = None
        if self._previous_valid_batch_c_id is not None:
            publishing_or_genesis = self._always_persist or \
                required_state_root is None
            state_hash = self._squash(
                state_root=self._previous_state_hash,
                context_ids=[self._previous_valid_batch_c_id],
                persist=self._always_persist, clean_up=publishing_or_genesis)
            if self._always_persist is True:
                return state_hash
            if state_hash == required_state_root:
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_valid_batch_c_id],
                             persist=True, clean_up=True)
        return state_hash

    def _calculate_state_root_if_not_already_done(self):
        if not self._already_calculated:
            if not self._last_in_batch:
                return
            last_txn_signature = self._last_in_batch[-1]
            batch_id = self._txn_to_batch[last_txn_signature]
            required_state_hash = self._required_state_hashes.get(
                batch_id)

            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
            for t_id in self._last_in_batch[::-1]:
                b_id = self._txn_to_batch[t_id]
                if self._batch_statuses[b_id].is_valid:
                    self._batch_statuses[b_id].state_hash = state_hash
                    # found the last valid batch, so break out
                    break

    def _calculate_state_root_if_required(self, batch_id):
        required_state_hash = self._required_state_hashes.get(
            batch_id)
        state_hash = None
        if required_state_hash is not None:
            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
        return state_hash

    def _complete(self):
        return self._final and \
            len(self._txn_results) == len(self._txn_to_batch)

    def complete(self, block):
        with self._condition:
            if not self._final:
                return False
            if self._complete():
                self._calculate_state_root_if_not_already_done()
                return True
            if block:
                self._condition.wait_for(self._complete)
                self._calculate_state_root_if_not_already_done()
                return True
            return False

    def cancel(self):
        with self._condition:
            if not self._cancelled and not self._final \
                    and self._previous_context_id:
                self._squash(
                    state_root=self._previous_state_hash,
                    context_ids=[self._previous_context_id],
                    persist=False,
                    clean_up=True)
                self._cancelled = True
                self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
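The finalize()/complete() pair above is a recurring Condition idiom: one thread blocks in wait_for() on a completion predicate while the executor thread records results and calls notify_all(). A distilled, self-contained sketch of that handshake (all names here are illustrative, not from the original):

from threading import Condition, Thread

class Results:
    def __init__(self, expected):
        self._cond = Condition()
        self._expected = expected
        self._results = {}

    def set_result(self, key, value):
        # executor side: record a result and wake any waiters
        with self._cond:
            self._results[key] = value
            self._cond.notify_all()

    def complete(self, block=True):
        # consumer side: optionally block until every result is in
        with self._cond:
            done = lambda: len(self._results) == self._expected
            if block:
                return self._cond.wait_for(done)
            return done()

results = Results(expected=2)

def executor():
    for key in ('txn-1', 'txn-2'):
        results.set_result(key, 'valid')

Thread(target=executor).start()
assert results.complete(block=True)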
Example #26
class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """
    def __init__(self, squash_handler, first_state_hash, always_persist):
        self._txn_queue = queue.Queue()
        self._scheduled_transactions = []
        self._batch_statuses = {}
        self._txn_to_batch = {}
        self._batch_by_id = {}
        self._txns_with_results = []
        self._in_progress_transaction = None
        self._final = False
        self._cancelled = False
        self._previous_context_id = None
        self._previous_valid_batch_c_id = None
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._previous_state_hash = first_state_hash
        # The state hashes here are the ones added in add_batch, and
        # are the state hashes that correspond with block boundaries.
        self._required_state_hashes = {}
        self._already_calculated = False
        self._always_persist = always_persist

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(self, txn_signature, is_valid,
                                         context_id):
        with self._condition:
            if (self._in_progress_transaction is None
                    or self._in_progress_transaction != txn_signature):
                raise ValueError(
                    "transaction not in progress: {}".format(txn_signature))
            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError(
                    "transaction not in any batches: {}".format(txn_signature))

            if txn_signature not in self._txns_with_results:
                self._txns_with_results.append(txn_signature)

            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                self._previous_context_id = context_id
            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=False, state_hash=None)
            if txn_signature in self._last_in_batch:
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._previous_valid_batch_c_id = self._previous_context_id
                    state_hash = self._calculate_state_root_if_required(
                        batch_id=batch_signature)
                    self._batch_statuses[batch_signature] = \
                        BatchExecutionResult(is_valid=True,
                                             state_hash=state_hash)
                else:
                    self._previous_context_id = self._previous_valid_batch_c_id

            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None):
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            batch_signature = batch.header_signature
            self._batch_by_id[batch_signature] = batch
            if state_hash is not None:
                self._required_state_hashes[batch_signature] = state_hash
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.put(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def count(self):
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_transactions[index]

    def _get_dependencies(self, transaction):
        header = TransactionHeader()
        header.ParseFromString(transaction.header)
        return list(header.dependencies)

    def _set_batch_result(self, txn_id, valid, state_hash):
        batch_id = self._txn_to_batch[txn_id]
        self._batch_statuses[batch_id] = BatchExecutionResult(
            is_valid=valid, state_hash=state_hash)
        batch = self._batch_by_id[batch_id]
        for txn in batch.transactions:
            if txn.header_signature not in self._txns_with_results:
                self._txns_with_results.append(txn.header_signature)

    def _get_batch_result(self, txn_id):
        batch_id = self._txn_to_batch[txn_id]
        return self._batch_statuses[batch_id]

    def _dep_is_known(self, txn_id):
        return txn_id in self._txn_to_batch

    def _dep_is_not_valid(self, txn_id):
        dependency_result = self._get_batch_result(txn_id)
        return not dependency_result.is_valid

    def next_transaction(self):
        with self._condition:
            if self._in_progress_transaction is not None:
                return None

            txn = None
            try:
                while txn is None:
                    txn = self._txn_queue.get(block=False)
                    # Handle this transaction being invalid based on a
                    # dependency.
                    if any(
                            self._dep_is_known(d) and self._dep_is_not_valid(d)
                            for d in self._get_dependencies(txn)):
                        self._set_batch_result(txn.header_signature, False,
                                               None)
                        txn = None
            except queue.Empty:
                return None

            self._in_progress_transaction = txn.header_signature
            base_contexts = [] if self._previous_context_id is None \
                else [self._previous_context_id]
            txn_info = TxnInformation(txn=txn,
                                      state_hash=self._previous_state_hash,
                                      base_context_ids=base_contexts)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def finalize(self):
        with self._condition:
            self._final = True
            self._condition.notify_all()

    def _compute_merkle_root(self, required_state_root):
        """Computes the merkle root of the state changes in the context
        corresponding with _last_valid_batch_c_id as applied to
        _previous_state_hash.

        Args:
            required_state_root (str): The merkle root that these txns
                should equal.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context_id
        """

        state_hash = None
        if self._previous_valid_batch_c_id is not None:
            publishing_or_genesis = self._always_persist or \
                                    required_state_root is None
            state_hash = self._squash(
                state_root=self._previous_state_hash,
                context_ids=[self._previous_valid_batch_c_id],
                persist=self._always_persist,
                clean_up=publishing_or_genesis)
            if self._always_persist is True:
                return state_hash
            if state_hash == required_state_root:
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_valid_batch_c_id],
                             persist=True,
                             clean_up=True)
        return state_hash

    def _calculate_state_root_if_not_already_done(self):
        if not self._already_calculated:
            if not self._last_in_batch:
                return
            last_txn_signature = self._last_in_batch[-1]
            batch_id = self._txn_to_batch[last_txn_signature]
            required_state_hash = self._required_state_hashes.get(batch_id)

            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
            for t_id in self._last_in_batch[::-1]:
                b_id = self._txn_to_batch[t_id]
                if self._batch_statuses[b_id].is_valid:
                    self._batch_statuses[b_id].state_hash = state_hash
                    # found the last valid batch, so break out
                    break

    def _calculate_state_root_if_required(self, batch_id):
        required_state_hash = self._required_state_hashes.get(batch_id)
        state_hash = None
        if required_state_hash is not None:
            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
        return state_hash

    def _complete(self):
        return self._final and \
               len(self._txns_with_results) == len(self._txn_to_batch)

    def complete(self, block):
        with self._condition:
            if not self._final:
                return False
            if self._complete():
                self._calculate_state_root_if_not_already_done()
                return True
            if block:
                self._condition.wait_for(self._complete)
                self._calculate_state_root_if_not_already_done()
                return True
            return False

    def cancel(self):
        with self._condition:
            self._cancelled = True
            self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
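A minimal driver sketch may help show the SerialScheduler's flow end to end. It assumes the class and its module imports (TxnInformation, BatchExecutionResult, TransactionHeader, etc.) are in scope; Txn and Batch are namedtuple stand-ins for the protobuf messages, and fake_squash is a stub for the real merkle squash handler:

from collections import namedtuple

Txn = namedtuple('Txn', ['header', 'header_signature'])
Batch = namedtuple('Batch', ['header_signature', 'transactions'])

def fake_squash(state_root, context_ids, persist, clean_up):
    # stand-in for the real merkle squash handler
    return 'new-state-root'

scheduler = SerialScheduler(fake_squash, 'genesis-root',
                            always_persist=False)
# an empty header parses cleanly as a TransactionHeader with no deps
txn = Txn(header=b'', header_signature='txn-1')
scheduler.add_batch(Batch(header_signature='batch-1', transactions=[txn]))
scheduler.finalize()

info = scheduler.next_transaction()
scheduler.set_transaction_execution_result(
    'txn-1', is_valid=True, context_id='ctx-1')
assert scheduler.complete(block=True)
result = scheduler.get_batch_execution_result('batch-1')
print(result.is_valid, result.state_hash)  # True new-state-root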
Example #27
def submit_block(block):
    with context.socket(zmq.PUSH) as sock:
        sock.connect(HOST + str(PORT + 1))
        sock.send_string(block)


thread = Thread(target=listening_thread)
thread.start()

while True:
    if template:
        print("Working on template %s" % template["hash"])
        task = parse_mining_task(template)

        mine(task.block, task.target)

        print("Block successfully mined!")
        serialized = task.block.serialize()
        response = compose_mining_result(serialized)

        submit_block(response)
        template = None

    else:
        print("Waiting for work from pool leader.")
        with template_received:
            template_received.wait_for(lambda: pending_template is not None)
            template = pending_template
            pending_template = None
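The loop above waits on template_received for work produced elsewhere; the listening_thread it starts is not shown. A hedged sketch of what that listener might look like, reusing the snippet's globals (the PULL socket type, the port, and the parse_template helper are guesses, not part of the original):

def listening_thread():
    global pending_template
    with context.socket(zmq.PULL) as sock:
        sock.bind(HOST + str(PORT))
        while True:
            work = sock.recv_string()
            with template_received:
                # parse_template is a hypothetical helper
                pending_template = parse_template(work)
                template_received.notify()

Example #28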
class ParallelScheduler(Scheduler):
    def __init__(self, squash_handler, first_state_hash, always_persist):
        self._squash = squash_handler
        self._first_state_hash = first_state_hash
        self._last_state_hash = first_state_hash
        self._condition = Condition()
        self._predecessor_tree = PredecessorTree()
        self._txn_predecessors = {}

        self._always_persist = always_persist

        self._predecessor_chain = PredecessorChain()

        # Transaction identifiers which have been scheduled.  Stored as a list,
        # since order is important; SchedulerIterator instances, for example,
        # must all return scheduled transactions in the same order.
        self._scheduled = []

        # Transactions that must be replayed but the prior result hasn't
        # been returned yet.
        self._outstanding = set()

        # Batch id for the batch with the property that the batch doesn't have
        # all txn results, and all batches prior to it have all their txn
        # results.
        self._least_batch_id_wo_results = None

        # A dict of transaction id to TxnInformation objects, containing all
        # transactions present in self._scheduled.
        self._scheduled_txn_info = {}

        # All batches in their natural order (the order they were added to
        # the scheduler).
        self._batches = []
        # The batches that have state hashes added in add_batch, used in
        # Block validation.
        self._batches_with_state_hash = {}

        # Indexes to find a batch quickly
        self._batches_by_id = {}
        self._batches_by_txn_id = {}

        # Transaction results
        self._txn_results = {}

        self._txns_available = OrderedDict()
        self._transactions = {}

        self._cancelled = False
        self._final = False

    def _find_input_dependencies(self, inputs):
        """Use the predecessor tree to find dependencies based on inputs.

        Returns: A list of transaction ids.
        """
        dependencies = []
        for address in inputs:
            dependencies.extend(
                self._predecessor_tree.find_read_predecessors(address))
        return dependencies

    def _find_output_dependencies(self, outputs):
        """Use the predecessor tree to find dependencies based on outputs.

        Returns: A list of transaction ids.
        """
        dependencies = []
        for address in outputs:
            dependencies.extend(
                self._predecessor_tree.find_write_predecessors(address))
        return dependencies

    def add_batch(self, batch, state_hash=None, required=False):
        with self._condition:
            if self._final:
                raise SchedulerError('Invalid attempt to add batch to '
                                     'finalized scheduler; batch: {}'
                                     .format(batch.header_signature))
            if not self._batches:
                self._least_batch_id_wo_results = batch.header_signature

            preserve = required
            if not required:
                # If this is the first non-required batch, it is preserved for
                # the schedule to be completed (i.e. no empty schedules in the
                # event of unschedule_incomplete_batches being called before
                # the first batch is completed).
                preserve = _first(
                    filterfalse(lambda sb: sb.required,
                                self._batches_by_id.values())) is None

            self._batches.append(batch)
            self._batches_by_id[batch.header_signature] = \
                _AnnotatedBatch(batch, required=required, preserve=preserve)
            for txn in batch.transactions:
                self._batches_by_txn_id[txn.header_signature] = batch
                self._txns_available[txn.header_signature] = txn
                self._transactions[txn.header_signature] = txn

            if state_hash is not None:
                b_id = batch.header_signature
                self._batches_with_state_hash[b_id] = state_hash

            # For dependency handling: First, we determine our dependencies
            # based on the current state of the predecessor tree.  Second,
            # we update the predecessor tree with reader and writer
            # information based on input and outputs.
            for txn in batch.transactions:
                header = TransactionHeader()
                header.ParseFromString(txn.header)

                # Calculate predecessors (transaction ids which must come
                # prior to the current transaction).
                predecessors = self._find_input_dependencies(header.inputs)
                predecessors.extend(
                    self._find_output_dependencies(header.outputs))

                txn_id = txn.header_signature
                # Update our internal state with the computed predecessors.
                self._txn_predecessors[txn_id] = set(predecessors)
                self._predecessor_chain.add_relationship(
                    txn_id=txn_id,
                    predecessors=predecessors)

                # Update the predecessor tree.
                #
                # Order of reader/writer operations is relevant.  A writer
                # may overshadow a reader.  For example, if the transaction
                # has the same input/output address, the end result will be
                # this writer (txn.header_signature) stored at the address of
                # the predecessor tree.  The reader information will have been
                # discarded.  Write operations to partial addresses will also
                # overshadow entire parts of the predecessor tree.
                #
                # Thus, the order here (inputs then outputs) will cause the
                # minimal amount of relevant information to be stored in the
                # predecessor tree, with duplicate information being
                # automatically discarded by the set_writer() call.
                for address in header.inputs:
                    self._predecessor_tree.add_reader(
                        address, txn_id)
                for address in header.outputs:
                    self._predecessor_tree.set_writer(
                        address, txn_id)

            self._condition.notify_all()

    def _is_explicit_request_for_state_root(self, batch_signature):
        return batch_signature in self._batches_with_state_hash

    def _is_implicit_request_for_state_root(self, batch_signature):
        return self._final and self._is_last_valid_batch(batch_signature)

    def _is_valid_batch(self, batch):
        for txn in batch.transactions:
            if txn.header_signature not in self._txn_results:
                raise _UnscheduledTransactionError()

            result = self._txn_results[txn.header_signature]
            if not result.is_valid:
                return False
        return True

    def _is_last_valid_batch(self, batch_signature):
        batch = self._batches_by_id[batch_signature].batch
        if not self._is_valid_batch(batch):
            return False
        index_of_next = self._batches.index(batch) + 1
        for later_batch in self._batches[index_of_next:]:
            if self._is_valid_batch(later_batch):
                return False
        return True

    def _get_contexts_for_squash(self, batch_signature):
        """Starting with the batch referenced by batch_signature, iterate back
        through the batches and for each valid batch collect the context_id.
        At the end remove contexts for txns that are other txn's predecessors.

        Args:
            batch_signature (str): The batch to start from, moving back through
                the batches in the scheduler

        Returns:
            (list): Context ids that haven't previously been used as base
                contexts.
        """

        batch = self._batches_by_id[batch_signature].batch
        index = self._batches.index(batch)
        contexts = []
        txns_added_predecessors = []
        for b in self._batches[index::-1]:
            batch_is_valid = True
            contexts_from_batch = []
            for txn in b.transactions[::-1]:
                result = self._txn_results[txn.header_signature]
                if not result.is_valid:
                    batch_is_valid = False
                    break
                else:
                    txn_id = txn.header_signature
                    if txn_id not in txns_added_predecessors:
                        txns_added_predecessors.append(
                            self._txn_predecessors[txn_id])
                        contexts_from_batch.append(result.context_id)
            if batch_is_valid:
                contexts.extend(contexts_from_batch)

        return contexts

    def _is_state_hash_correct(self, state_hash, batch_id):
        return state_hash == self._batches_with_state_hash[batch_id]

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            # This method calculates the BatchExecutionResult on the fly,
            # caching only TxnExecutionResults rather than
            # BatchExecutionResults as the SerialScheduler does
            if batch_signature not in self._batches_by_id:
                return None

            batch = self._batches_by_id[batch_signature].batch

            if not self._is_valid_batch(batch):
                return BatchExecutionResult(is_valid=False, state_hash=None)

            state_hash = None
            try:
                if self._is_explicit_request_for_state_root(batch_signature):
                    contexts = self._get_contexts_for_squash(batch_signature)
                    state_hash = self._squash(
                        self._first_state_hash,
                        contexts,
                        persist=False,
                        clean_up=False)
                    if self._is_state_hash_correct(state_hash,
                                                   batch_signature):
                        self._squash(
                            self._first_state_hash,
                            contexts,
                            persist=True,
                            clean_up=True)
                    else:
                        self._squash(
                            self._first_state_hash,
                            contexts,
                            persist=False,
                            clean_up=True)
                elif self._is_implicit_request_for_state_root(batch_signature):
                    contexts = self._get_contexts_for_squash(batch_signature)
                    state_hash = self._squash(
                        self._first_state_hash,
                        contexts,
                        persist=self._always_persist,
                        clean_up=True)
            except _UnscheduledTransactionError:
                return None

            return BatchExecutionResult(is_valid=True, state_hash=state_hash)

    def get_transaction_execution_results(self, batch_signature):
        with self._condition:
            annotated_batch = self._batches_by_id.get(batch_signature)
            if annotated_batch is None:
                return None

            results = []
            for txn in annotated_batch.batch.transactions:
                result = self._txn_results.get(txn.header_signature)
                if result is not None:
                    results.append(result)
            return results

    def _is_predecessor_of_possible_successor(self,
                                              txn_id,
                                              possible_successor):
        return self._predecessor_chain.is_predecessor_of_other(
            txn_id,
            [possible_successor])

    def _txn_has_result(self, txn_id):
        return txn_id in self._txn_results

    def _is_in_same_batch(self, txn_id_1, txn_id_2):
        return self._batches_by_txn_id[txn_id_1] == \
            self._batches_by_txn_id[txn_id_2]

    def _is_txn_to_replay(self, txn_id, possible_successor, already_seen):
        """Decide if possible_successor should be replayed.

        Args:
            txn_id (str): Id of txn in failed batch.
            possible_successor (str): Id of txn to possibly replay.
            already_seen (list): A list of possible_successors that have
                been replayed.

        Returns:
            (bool): If the possible_successor should be replayed.
        """

        is_successor = self._is_predecessor_of_possible_successor(
            txn_id,
            possible_successor)
        in_different_batch = not self._is_in_same_batch(txn_id,
                                                        possible_successor)
        has_not_been_seen = possible_successor not in already_seen

        return is_successor and in_different_batch and has_not_been_seen

    def _remove_subsequent_result_because_of_batch_failure(self, sig):
        """Remove transactions from scheduled and txn_results for
        successors of txns in a failed batch. These transactions will now,
        or in the future be rescheduled in next_transaction; giving a
        replay ability.

        Args:
            sig (str): Transaction header signature

        """

        batch = self._batches_by_txn_id[sig]
        seen = []
        for txn in batch.transactions:
            txn_id = txn.header_signature
            for poss_successor in self._scheduled.copy():
                if not self.is_transaction_in_schedule(poss_successor):
                    continue

                if self._is_txn_to_replay(txn_id, poss_successor, seen):
                    if self._txn_has_result(poss_successor):
                        del self._txn_results[poss_successor]
                        self._scheduled.remove(poss_successor)
                        self._txns_available[poss_successor] = \
                            self._transactions[poss_successor]
                    else:
                        self._outstanding.add(poss_successor)
                    seen.append(poss_successor)

    def _reschedule_if_outstanding(self, txn_signature):
        if txn_signature in self._outstanding:
            self._txns_available[txn_signature] = \
                self._transactions[txn_signature]
            self._scheduled.remove(txn_signature)
            self._outstanding.discard(txn_signature)
            return True
        return False

    def _index_of_batch(self, batch):
        batch_index = None
        try:
            batch_index = self._batches.index(batch)
        except ValueError:
            pass
        return batch_index

    def _set_least_batch_id(self, txn_signature):
        """Set the first batch id that doesn't have all results.

        Args:
            txn_signature (str): The txn identifier of the transaction with
                results being set.

        """

        batch = self._batches_by_txn_id[txn_signature]

        least_index = self._index_of_batch(
            self._batches_by_id[self._least_batch_id_wo_results].batch)

        current_index = self._index_of_batch(batch)
        all_prior = False

        if current_index <= least_index:
            return

        # Test to see if all batches from the least batch up to (but not
        # including) the current batch have results.
        if all(
                all(t.header_signature in self._txn_results
                    for t in b.transactions)
                for b in self._batches[least_index:current_index]):
            all_prior = True
        if not all_prior:
            return
        possible_least = self._batches[current_index].header_signature
        # Find the first batch from the current batch on, that doesn't have
        # all results.
        for b in self._batches[current_index:]:
            if not all(t.header_signature in self._txn_results
                       for t in b.transactions):
                possible_least = b.header_signature
                break
        self._least_batch_id_wo_results = possible_least

    def set_transaction_execution_result(
            self, txn_signature, is_valid, context_id, state_changes=None,
            events=None, data=None, error_message="", error_data=b""):
        with self._condition:
            if txn_signature not in self._scheduled:
                raise SchedulerError(
                    "transaction not scheduled: {}".format(txn_signature))

            if txn_signature not in self._batches_by_txn_id:
                return

            self._set_least_batch_id(txn_signature=txn_signature)
            if not is_valid:
                self._remove_subsequent_result_because_of_batch_failure(
                    txn_signature)
            is_rescheduled = self._reschedule_if_outstanding(txn_signature)

            if not is_rescheduled:
                self._txn_results[txn_signature] = TxnExecutionResult(
                    signature=txn_signature,
                    is_valid=is_valid,
                    context_id=context_id if is_valid else None,
                    state_hash=self._first_state_hash if is_valid else None,
                    state_changes=state_changes,
                    events=events,
                    data=data,
                    error_message=error_message,
                    error_data=error_data)

            self._condition.notify_all()

    def _has_predecessors(self, txn_id):
        for predecessor_id in self._txn_predecessors[txn_id]:
            if predecessor_id not in self._txn_results:
                return True
            # Since get_initial_state_for_transaction gets context ids not
            # just from predecessors but also in the case of an enclosing
            # writer failing, predecessors of that predecessor, this extra
            # check is needed.
            for pre_pred_id in self._txn_predecessors[predecessor_id]:
                if pre_pred_id not in self._txn_results:
                    return True

        return False

    def _is_outstanding(self, txn_id):
        return txn_id in self._outstanding

    def _txn_is_in_valid_batch(self, txn_id):
        """Returns whether the transaction is in a valid batch.

        Args:
            txn_id (str): The transaction header signature.

        Returns:
            (bool): True if the txn's batch is valid, False otherwise.
        """

        batch = self._batches_by_txn_id[txn_id]

        # Return whether every transaction in the batch with a
        # transaction result is valid
        return all(
            self._txn_results[sig].is_valid
            for sig in set(self._txn_results).intersection(
                (txn.header_signature for txn in batch.transactions)))

    def _get_initial_state_for_transaction(self, txn):
        # Collect contexts that this transaction depends upon
        # We assume that all prior txns in the batch are valid
        # or else this transaction wouldn't run. We assume that
        # the mechanism in next_transaction makes sure that each
        # predecessor txn has a result. Also any explicit
        # dependencies that could have failed this txn did so.
        contexts = []
        txn_dependencies = deque()
        txn_dependencies.extend(self._txn_predecessors[txn.header_signature])
        while txn_dependencies:
            prior_txn_id = txn_dependencies.popleft()
            if self._txn_is_in_valid_batch(prior_txn_id):
                result = self._txn_results[prior_txn_id]
                if (prior_txn_id, result.context_id) not in contexts:
                    contexts.append((prior_txn_id, result.context_id))
            else:
                txn_dependencies.extend(self._txn_predecessors[prior_txn_id])

        contexts.sort(
            key=lambda x: self._index_of_txn_in_schedule(x[0]),
            reverse=True)
        return [c_id for _, c_id in contexts]

    def _index_of_txn_in_schedule(self, txn_id):
        batch = self._batches_by_txn_id[txn_id]
        index_of_batch_in_schedule = self._batches.index(batch)
        number_of_txns_in_prior_batches = 0
        for prior in self._batches[:index_of_batch_in_schedule]:
            number_of_txns_in_prior_batches += len(prior.transactions)

        txn_index, _ = next(
            (i, t)
            for i, t in enumerate(batch.transactions)
            if t.header_signature == txn_id)

        return number_of_txns_in_prior_batches + txn_index - 1

    def _can_fail_fast(self, txn_id):
        batch_id = self._batches_by_txn_id[txn_id].header_signature
        return batch_id == self._least_batch_id_wo_results

    def next_transaction(self):
        with self._condition:
            # We return the next transaction which hasn't been scheduled and
            # is not blocked by a dependency.

            next_txn = None

            no_longer_available = []

            for txn_id, txn in self._txns_available.items():
                if (self._has_predecessors(txn_id)
                        or self._is_outstanding(txn_id)):
                    continue

                header = TransactionHeader()
                header.ParseFromString(txn.header)
                deps = tuple(header.dependencies)

                if self._dependency_not_processed(deps):
                    continue

                if self._txn_failed_by_dep(deps):
                    no_longer_available.append(txn_id)
                    self._txn_results[txn_id] = \
                        TxnExecutionResult(
                            signature=txn_id,
                            is_valid=False,
                            context_id=None,
                            state_hash=None)
                    continue

                if not self._txn_is_in_valid_batch(txn_id) and \
                        self._can_fail_fast(txn_id):
                    self._txn_results[txn_id] = \
                        TxnExecutionResult(
                            signature=txn_id,
                            is_valid=False,
                            context_id=None,
                            state_hash=None)
                    no_longer_available.append(txn_id)
                    continue

                next_txn = txn
                break

            for txn_id in no_longer_available:
                del self._txns_available[txn_id]

            if next_txn is not None:
                bases = self._get_initial_state_for_transaction(next_txn)

                info = TxnInformation(
                    txn=next_txn,
                    state_hash=self._first_state_hash,
                    base_context_ids=bases)
                self._scheduled.append(next_txn.header_signature)
                del self._txns_available[next_txn.header_signature]
                self._scheduled_txn_info[next_txn.header_signature] = info
                return info
            return None

    def _dependency_not_processed(self, deps):
        if any(not self._all_in_batch_have_results(d)
               for d in deps
               if d in self._batches_by_txn_id):
            return True
        return False

    def _txn_failed_by_dep(self, deps):
        if any(self._any_in_batch_are_invalid(d)
               for d in deps
               if d in self._batches_by_txn_id):
            return True
        return False

    def _all_in_batch_have_results(self, txn_id):
        batch = self._batches_by_txn_id[txn_id]
        return all(
            t.header_signature in self._txn_results
            for t in list(batch.transactions))

    def _any_in_batch_are_invalid(self, txn_id):
        batch = self._batches_by_txn_id[txn_id]
        return any(not self._txn_results[t.header_signature].is_valid
                   for t in list(batch.transactions))

    def available(self):
        with self._condition:
            # Count the transactions which haven't been scheduled and
            # aren't blocked by a predecessor.

            count = 0
            for txn_id in self._txns_available:
                if not self._has_predecessors(txn_id):
                    count += 1

            return count

    def unschedule_incomplete_batches(self):
        incomplete_batches = set()
        with self._condition:
            # These transactions have never been scheduled.
            for txn_id, txn in self._txns_available.items():
                batch = self._batches_by_txn_id[txn_id]
                batch_id = batch.header_signature

                annotated_batch = self._batches_by_id[batch_id]
                if not annotated_batch.preserve:
                    incomplete_batches.add(batch_id)

            # These transactions were in flight.
            in_flight = set(self._transactions.keys()).difference(
                self._txn_results.keys())

            for txn_id in in_flight:
                batch = self._batches_by_txn_id[txn_id]
                batch_id = batch.header_signature

                annotated_batch = self._batches_by_id[batch_id]
                if not annotated_batch.preserve:
                    incomplete_batches.add(batch_id)

            # clean up the batches, including partial complete information
            for batch_id in incomplete_batches:
                annotated_batch = self._batches_by_id[batch_id]
                self._batches.remove(annotated_batch.batch)
                del self._batches_by_id[batch_id]
                for txn in annotated_batch.batch.transactions:
                    txn_id = txn.header_signature
                    del self._batches_by_txn_id[txn_id]

                    if txn_id in self._txn_results:
                        del self._txn_results[txn_id]

                    if txn_id in self._txns_available:
                        del self._txns_available[txn_id]

                    if txn_id in self._outstanding:
                        self._outstanding.remove(txn_id)

            self._condition.notify_all()

        if incomplete_batches:
            LOGGER.debug('Removed %s incomplete batches from the schedule',
                         len(incomplete_batches))

    def is_transaction_in_schedule(self, txn_signature):
        with self._condition:
            return txn_signature in self._batches_by_txn_id

    def finalize(self):
        with self._condition:
            self._final = True
            self._condition.notify_all()

    def _complete(self):
        return self._final and \
            len(self._txn_results) == len(self._batches_by_txn_id)

    def complete(self, block=True):
        with self._condition:
            if self._complete():
                return True

            if block:
                return self._condition.wait_for(self._complete)

            return False

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def count(self):
        with self._condition:
            return len(self._scheduled)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_txn_info[self._scheduled[index]]

    def cancel(self):
        with self._condition:
            if not self._cancelled and not self._final:
                contexts = [
                    tr.context_id for tr in self._txn_results.values()
                    if tr.context_id
                ]
                self._squash(
                    self._first_state_hash,
                    contexts,
                    persist=False,
                    clean_up=True)
                self._cancelled = True
                self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
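A short driver sketch for the ParallelScheduler shows the predecessor tree serializing two transactions that touch the same address. It assumes the class, its module-level helpers (_AnnotatedBatch, PredecessorTree, etc.), and the protobuf TransactionHeader are in scope; Txn and Batch are namedtuple stand-ins and fake_squash is a stub:

from collections import namedtuple

Txn = namedtuple('Txn', ['header', 'header_signature'])
Batch = namedtuple('Batch', ['header_signature', 'transactions'])

def fake_squash(state_root, context_ids, persist, clean_up):
    return 'new-state-root'

# both transactions read and write the same address
header = TransactionHeader(inputs=['1a'], outputs=['1a']).SerializeToString()
txn_a = Txn(header=header, header_signature='txn-a')
txn_b = Txn(header=header, header_signature='txn-b')

scheduler = ParallelScheduler(fake_squash, 'genesis-root',
                              always_persist=False)
scheduler.add_batch(Batch('batch-a', [txn_a]))
scheduler.add_batch(Batch('batch-b', [txn_b]))

first = scheduler.next_transaction()         # txn-a is runnable
assert scheduler.next_transaction() is None  # txn-b waits on txn-a's write
scheduler.set_transaction_execution_result(
    'txn-a', is_valid=True, context_id='ctx-a')
second = scheduler.next_transaction()        # txn-b now runs on top of ctx-a
assert second.base_context_ids == ['ctx-a']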
Example #29
class _SendReceive(object):
    def __init__(self, address, futures, identity=None, dispatcher=None):
        self._dispatcher = dispatcher
        self._futures = futures
        self._address = address
        self._identity = identity
        self._event_loop = None
        self._context = None
        self._recv_queue = None
        self._socket = None
        self._condition = Condition()

    @asyncio.coroutine
    def _receive_message(self):
        """
        Internal coroutine for receiving messages
        """
        with self._condition:
            self._condition.wait_for(lambda: self._socket is not None)
        while True:
            if self._socket.getsockopt(zmq.TYPE) == zmq.ROUTER:
                identity, msg_bytes = yield from self._socket.recv_multipart()
            else:
                msg_bytes = yield from self._socket.recv()

            message = validator_pb2.Message()
            message.ParseFromString(msg_bytes)

            LOGGER.debug("receiving %s message",
                         get_enum_name(message.message_type))
            try:
                self._futures.set_result(
                    message.correlation_id,
                    future.FutureResult(message_type=message.message_type,
                                        content=message.content))
            except future.FutureCollectionKeyError:
                if self._socket.getsockopt(zmq.TYPE) == zmq.ROUTER:
                    self._dispatcher.dispatch(identity, message)
                else:
                    LOGGER.info(
                        "received a first message on the zmq dealer.")
            else:
                my_future = self._futures.get(message.correlation_id)
                LOGGER.debug("message round "
                             "trip: %s %s",
                             get_enum_name(message.message_type),
                             my_future.get_duration())
                self._futures.remove(message.correlation_id)

    @asyncio.coroutine
    def _send_message(self, identity, msg):
        LOGGER.debug("sending %s to %s",
                     get_enum_name(msg.message_type),
                     identity)

        if identity is None:
            message_bundle = [msg.SerializeToString()]
        else:
            message_bundle = [bytes(identity),
                              msg.SerializeToString()]
        yield from self._socket.send_multipart(message_bundle)

    def send_message(self, msg, identity=None):
        """
        :param msg: protobuf validator_pb2.Message
        """
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        asyncio.run_coroutine_threadsafe(self._send_message(identity, msg),
                                         self._event_loop)

    def setup(self, socket_type):
        """
        :param socket_type: zmq.DEALER or zmq.ROUTER
        """
        self._event_loop = zmq.asyncio.ZMQEventLoop()
        asyncio.set_event_loop(self._event_loop)
        self._context = zmq.asyncio.Context()
        self._socket = self._context.socket(socket_type)
        if socket_type == zmq.DEALER:
            self._socket.identity = "{}-{}".format(
                self._identity,
                hashlib.sha512(uuid.uuid4().hex.encode()
                               ).hexdigest()[:23]).encode('ascii')
            self._socket.connect(self._address)
        elif socket_type == zmq.ROUTER:
            self._dispatcher.set_send_message(self.send_message)
            self._socket.bind(self._address)
        self._recv_queue = asyncio.Queue()
        asyncio.ensure_future(self._receive_message(), loop=self._event_loop)
        with self._condition:
            self._condition.notify_all()
        self._event_loop.run_forever()

    def stop(self):
        self._event_loop.stop()
        self._socket.close()
        self._context.term()
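Both _receive_message and send_message gate on the Condition until setup() has created the socket and event loop on its own thread. A distilled, dependency-free version of that gating pattern (names are illustrative):

from threading import Condition, Thread
import time

class Loop:
    def __init__(self):
        self._cond = Condition()
        self._ready = False

    def setup(self):
        # stands in for creating the socket and event loop
        time.sleep(0.1)
        with self._cond:
            self._ready = True
            self._cond.notify_all()

    def send(self, msg):
        # block until setup() has finished on the other thread
        with self._cond:
            self._cond.wait_for(lambda: self._ready)
        print('sent', msg)

loop = Loop()
Thread(target=loop.setup).start()
loop.send('hello')  # blocks until setup() has run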
Example #30
class FileQueue(object):
    """Thread-safe queue which reads files into memory using `read_file` and
    doesn't read any more files when the total size of the cached files
    exceeds `max_size`. The size of each file (loaded by `read_file`) is given
    by `get_file_size`. Note that get() and put() are blocking operations,
    and they are fulfilled on a first-come first-served basis."""
    def __init__(self, max_size=0, read_file=None, get_file_size=None):
        self._max_size = max_size
        self.read_file = read_file if read_file else lambda x: x
        self.get_file_size = get_file_size if get_file_size else lambda x: 1

        # cur_size and queue need atomic access via mutex
        self._mutex = RLock()
        self._cur_size = 0
        self._queue = deque()
        self._closed = False

        # first process to try to add to/remove from the queue gets to do so
        # mutex to ensure that one thread doesn't try to notify the other while
        # the cv.wait_for(...) predicate is being evaluated.
        self._cv_mutex = RLock()
        self._not_full = Condition()
        self._not_empty = Condition()

    def full(self):
        with self._mutex:
            return self._cur_size >= self._max_size > 0

    def empty(self):
        with self._mutex:
            return len(self._queue) == 0

    def close(self):
        """Cancels all pending put operations."""
        with self._not_full:
            with self._not_empty:
                self._closed = True
                self._not_full.notify_all()
                self._not_empty.notify_all()

    def open(self):
        """Re-opens the queue so new put operations can be scheduled."""
        with self._not_full:
            self._closed = False

    def clear(self):
        """Close the queue and then clear it."""
        with self._not_full:
            with self._not_empty:
                with self._mutex:
                    self.close()
                    self._queue.clear()
                    self._cur_size = 0

    def put(self, path, *args, **kwargs):
        """Attempts to read the file (at `path`) and adds it to the queue.
        Blocks if the queue is full and returns `False` if the queue is closed.
        Returns whether we succeeded in adding the item to the queue."""
        def predicate():
            with self._cv_mutex:
                is_full = self.full()
                if is_full and not self._closed:
                    logger.debug("QUEUE FULL")
                return self._closed or not is_full

        logger.debug(f"TRY PUT {os.path.basename(path)}")
        with self._not_full:
            self._not_full.wait_for(predicate)
            if self._closed:
                return False
            file = self.read_file(path, *args, **kwargs)

            with self._mutex:
                self._queue.append((path, file))
                self._cur_size += self.get_file_size(file)

        # Only try to acquire self._not_empty when get() isn't evaluating its
        # wait_for predicate (indicating whether the queue is empty). If we
        # can't acquire self._not_empty, then the queue is either in the middle
        # of a successful get(), or has just started one that will succeed
        # because we just ensured the queue is non-empty. If we can acquire it,
        # notify the appropriate condition variable that the queue is non-empty.
        with self._cv_mutex:
            if self._not_empty.acquire(blocking=False):
                self._not_empty.notify()
                self._not_empty.release()

        logger.debug(f"PUT {os.path.basename(path)}")
        return True

    def get(self, expected_path=None):
        """Returns the file from the top of the queue. Blocks if the queue is
        empty. If `expected_path` is provided, a `RuntimeError` is raised if
        the path of the file returned is different from `expected_path`."""
        def predicate():
            with self._cv_mutex:
                return self._closed or not self.empty()

        if expected_path is not None:
            logger.debug(f"TRY GET {os.path.basename(expected_path)}")
        with self._not_empty:
            self._not_empty.wait_for(predicate)
            with self._mutex:
                if self.empty():
                    path, file = None, None
                else:
                    path, file = self._queue.popleft()
                    self._cur_size -= self.get_file_size(file)

        # Only try to acquire self._not_full when put() isn't evaluating its
        # wait_for predicate (indicating whether the queue is full). If we
        # can't acquire self._not_full, then the queue is either in the middle
        # of a successful put(), or has just started one that will succeed
        # because we just ensured the queue is non-full. If we can acquire it,
        # notify the appropriate condition variable that the queue is non-full.
        with self._cv_mutex:
            if self._not_full.acquire(blocking=False):
                self._not_full.notify()
                self._not_full.release()

        if path is not None:
            if expected_path is not None and path != expected_path:
                raise RuntimeError(
                    f"A race condition caused us to fetch the wrong file. "
                    f"Expected {expected_path}, but got {path} instead.")
            logger.debug(f"GET {os.path.basename(path)}")
        else:
            logger.debug(f"QUEUE EMPTY")
        return path, file
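A usage sketch with the defaults (read_file is the identity and every file counts as size 1), so at most two items are buffered at a time; it assumes FileQueue and its module-level logger are in scope:

from threading import Thread

q = FileQueue(max_size=2)

def producer():
    for path in ['a.txt', 'b.txt', 'c.txt', 'd.txt']:
        q.put(path)  # blocks while two items are already buffered
    q.close()        # wakes a consumer blocked on an empty queue

def consumer():
    while True:
        path, file = q.get()
        if path is None:  # closed and drained
            break
        print('got', path)

threads = [Thread(target=producer), Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()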
Example #31
class MPIProcessor(BaseProcessor):
    """
    Processor to distribute work using MPI
    """
    def __init__(self, builders):
        (self.comm, self.rank, self.size) = get_mpi()
        if not self.comm:
            raise Exception(
                "MPI not working properly, check your mpi4py installation and ensure this is running under mpi"
            )
        self.comm.barrier()
        super(MPIProcessor, self).__init__(builders)

    def process(self, builder_id):
        """
        Run the builder using MPI protocol.

        Args:
            builder_id (int): the index of the builder in the builders list
        """
        self.comm.barrier()
        if self.rank == 0:
            self.process_master(builder_id)
        else:
            self.process_worker()

    def setup_multithreading(self):
        """
        Setup structures for managing data to/from MPI Workers
        """
        self.data = deque()
        self.ranks = deque([i + 1 for i in range(self.size - 1)])
        self.task_count = BoundedSemaphore(self.builder.chunk_size)
        self.update_data_condition = Condition()

        self.run_update_targets = True
        self.update_targets_thread = Thread(target=self.update_targets)
        self.update_targets_thread.start()

    def process_master(self, builder_id):
        """
        Master process for MPI processing
        Handles Data IO to Stores and to MPI Workers
        """
        self.builder = self.builders[builder_id]
        self.builder.connect()

        cursor = self.builder.get_items()

        self.setup_pbars(cursor)
        self.setup_multithreading()
        self.put_tasks(builder_id)
        self.clean_up_workers()
        self.clean_up_data()
        self.builder.finalize(cursor)
        self.cleanup_pbars()

    def process_worker(self):
        """
        MPI Worker process
        """
        is_valid = True

        while is_valid:
            packet = self.comm.recv(source=0)
            if packet["type"] == "process":
                builder_id = packet["builder_id"]
                data = packet["data"]
                try:
                    result = self.builders[builder_id].process_item(data)
                    self.comm.send({
                        "type": "return",
                        "return": result
                    },
                                   dest=0)
                except Exception as e:
                    self.comm.send({"type": "error", "error": e}, dest=0)
            elif packet["type"] == "shutdown":
                is_valid = False

    def setup_pbars(self, cursor):
        """
        Sets up progress bars
        """
        total = None
        if isinstance(cursor, types.GeneratorType):
            cursor = primed(cursor)
            if hasattr(self.builder, "total"):
                total = self.builder.total
        elif hasattr(cursor, "__len__"):
            total = len(cursor)
        elif hasattr(cursor, "count"):
            total = cursor.count()

        self.get_pbar = tqdm(cursor, desc="Get Items", total=total)
        self.process_pbar = tqdm(desc="Processing Item", total=total)
        self.update_pbar = tqdm(desc="Updating Targets", total=total)

    def cleanup_pbars(self):
        """
        Cleans up the TQDM bars
        """
        self.get_pbar.close()
        self.process_pbar.close()
        self.update_pbar.close()

    def put_tasks(self, builder_id):
        """
        Submit tasks from cursor to MPI workers
        """
        # 1.) Setup thread pool
        with ThreadPoolExecutor(max_workers=self.size - 1) as executor:
            # 2.) Loop over every item wrapped in a tqdm bar
            for item in self.get_pbar:
                # 3.) Limit total number of queued tasks using a semaphore
                self.task_count.acquire()
                # 4.) Submit the item to a worker
                executor.submit(self.submit_item, builder_id, item)

    def submit_item(self, builder_id, data):
        """
        Thread to submit an item to MPI Workers and get data back

        """

        # 1.) Find free rank and take it
        mpi_rank = self.ranks.pop()
        # 2.) Submit the job to that rank
        self.comm.send(
            {
                "type": "process",
                "builder_id": builder_id,
                "data": data
            },
            dest=mpi_rank)
        # 3.) Block until the worker sends a packet back
        result = None
        while not result:
            packet = self.comm.recv(source=mpi_rank)
            if packet["type"] == "return":
                result = packet["return"]
                self.task_count.release()
            elif packet["type"] == "error":
                self.logger.error(
                    "MPI Rank {} Errored on Builder ID {}:\n{}".format(
                        mpi_rank, builder_id, packet["error"]))
                self.task_count.release()
                return
            else:
                self.task_count.release()
                return  # don't know what happened here, just quit

        # 4.) Update process progress bar
        self.process_pbar.update(1)

        # 5.) Save data
        with self.update_data_condition:
            self.data.append(result)
            self.update_data_condition.notify_all()
        # 6.) Return rank
        self.ranks.append(mpi_rank)

    def clean_up_workers(self):
        """
        Sends shutdown signal to all MPI workers
        """
        for i in range(self.size - 1):
            self.comm.send({"type": "shutdown"}, dest=i + 1)

    def clean_up_data(self):
        """
        Call back to add data into a list in thread safe manner and signal other threads to add more tasks or update_targets
        """
        self.logger.debug("Cleaning up data queue")
        try:
            with self.update_data_condition:
                self.run_update_targets = False
                self.update_data_condition.notify_all()
        except Exception as e:
            self.logger.debug(
                "Problem in updating targets at end of builder run: {}".format(
                    e))

        self.update_targets_thread.join()

    def update_targets(self):
        """
        Thread to update targets periodically
        """
        while self.run_update_targets:
            with self.update_data_condition:
                self.update_data_condition.wait_for(
                    lambda: not self.run_update_targets or len(
                        self.data) > self.builder.chunk_size)
                try:
                    self.builder.update_targets(self.data)
                    self.update_pbar.update(len(self.data))
                    self.data.clear()
                except Exception as e:
                    self.logger.debug(
                        "Problem in updating targets in builder run: {}".
                        format(e))
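As a rough idea of how this class is driven (the `builders` list and the launch command below are assumptions, not part of the snippet): every MPI rank executes the same script, and process() sends rank 0 down the master path while all other ranks loop in process_worker().

# hedged sketch; run with, e.g.: mpiexec -n 4 python run_builders.py
processor = MPIProcessor(builders)       # `builders` assumed defined elsewhere
for builder_id in range(len(builders)):
    processor.process(builder_id)        # all ranks enter this collectively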
Example #32
class NtTestBase(NetworkTablesInstance):
    """
        Object for managing a live pair of NT server/client
    """

    _wait_lock = None
    _testing_verbose_logging = True

    def shutdown(self):
        logger.info("shutting down %s", self.__class__.__name__)
        NetworkTablesInstance.shutdown(self)
        if self._wait_lock is not None:
            self._wait_init_listener()

    def disconnect(self):
        self._api.dispatcher.stop()

    def _init_common(self, proto_rev):
        # This resets the instance to be independent
        self.shutdown()
        self._api.dispatcher.setDefaultProtoRev(proto_rev)
        self.proto_rev = proto_rev

        if self._testing_verbose_logging:
            self.enableVerboseLogging()
        # self._wait_init()

    def _init_server(self, proto_rev, server_port=0):
        self._init_common(proto_rev)

        self.port = server_port

    def _init_client(self, proto_rev):
        self._init_common(proto_rev)

    def _wait_init(self):
        self._wait_lock = Condition()
        self._wait = 0
        self._wait_init_listener()

    def _wait_init_listener(self):
        self._api.addEntryListener(
            "",
            self._wait_cb,
            NetworkTablesInstance.NotifyFlags.NEW
            | NetworkTablesInstance.NotifyFlags.UPDATE
            | NetworkTablesInstance.NotifyFlags.DELETE
            | NetworkTablesInstance.NotifyFlags.FLAGS,
        )

    def _wait_cb(self, *args):
        with self._wait_lock:
            self._wait += 1
            # logger.info('Wait callback, got: %s', args)
            self._wait_lock.notify()

    @contextmanager
    def expect_changes(self, count):
        """Use this on the *other* instance that you're making
        changes on, to wait for the changes to propagate to the
        other instance"""

        if self._wait_lock is None:
            self._wait_init()

        with self._wait_lock:
            self._wait = 0

        logger.info("Begin actions")
        yield
        logger.info("Waiting for %s changes", count)

        with self._wait_lock:
            result, msg = (
                self._wait_lock.wait_for(lambda: self._wait == count, 4),
                "Timeout waiting for %s changes (got %s)" % (count, self._wait),
            )
            logger.info("expect_changes: %s %s", result, msg)
            assert result, msg
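A hedged sketch of how expect_changes is used in a test: make changes through one instance while wrapping them from the other side, so the block only returns after the entry-listener callback has counted the expected notifications (the `server`/`client` fixtures are illustrative):

# hypothetical test body using two NtTestBase instances
with server.expect_changes(2):           # waits up to 4s after the block
    table = client.getTable("test")
    table.putNumber("x", 1)
    table.putNumber("y", 2)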
Example #33
class ThreadPool:
    def __init__(self, n_threads=cpu_count(), local_state={}):
        self.threads = [Thread(target=lambda i=i: self.thread_main(i, deepcopy(local_state)))
                        for i in range(n_threads)]

        self.queue = []
        self.results = []
        self.last_req_no = -1
        self.last_finished_no = -1
        self.done_at_no = -1
        self.lock = Lock()
        self.thread_cv = Condition(self.lock)
        self.iter_cv = Condition(self.lock)
        self.exception = None

        for thread in self.threads:
            thread.start()

    def init_thread(self, local_state):
        pass

    def cleanup_thread(self, local_state):
        pass

    def execute_task(self, local_state, task):
        pass

    def thread_main(self, i, local_state):
        local_state["thread_no"] = i
        self.init_thread(local_state)
        try:
            while True:
                with self.lock:
                    self.thread_cv.wait_for(lambda: self.queue)
                    task = self.queue.pop(0)
                result = self.execute_task(local_state, task)
                with self.lock:
                    self.results.append(result)
                    self.last_finished_no += 1
                    self.iter_cv.notify()
        except ExitThread:
            pass
        except BaseException as e:
            with self.lock:
                self.exception = e
        finally:
            self.cleanup_thread(local_state)

    def defer(self, task):
        with self.lock:
            self.done_at_no = -1
            self.last_req_no += 1
            self.queue.append(task)
            self.thread_cv.notify()

    def __iter__(self):
        with self.lock:
            while self.last_finished_no <= self.done_at_no:
                self.iter_cv.wait_for(lambda: self.last_finished_no >= self.done_at_no
                        or self.results or self.exception)
                if self.exception:
                    e = self.exception
                    self.exception = None
                    raise e
                elif self.results:
                    yield self.results.pop(0)
                else:
                    break
        return  # PEP 479: return, not raise StopIteration, ends a generator cleanly

    def done(self):
        with self.lock:
            self.done_at_no = self.last_req_no

    def destroy(self):
        for thread in self.threads:
            # raise ExitThread in every thread
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread.ident),
                ctypes.py_object(ExitThread))
        with self.lock:
            # Wake up all waiting threads to handle exception
            self.thread_cv.notify_all()
        for thread in self.threads:
            thread.join()
        if self.exception:
            e = self.exception
            self.exception = None
            raise e

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.destroy()
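A minimal sketch of the intended workflow, assuming the snippet's imports and its ExitThread exception are in scope: subclass and override execute_task, submit work with defer, mark the end of input with done, then iterate to drain results (they arrive in completion order, not submission order):

class SquarePool(ThreadPool):
    def execute_task(self, local_state, task):
        # runs on a worker thread; local_state is private to that thread
        return task * task

with SquarePool(n_threads=4) as pool:    # __exit__ calls destroy()
    for n in range(10):
        pool.defer(n)
    pool.done()                          # lets __iter__ know where input ends
    print(sorted(pool))                  # [0, 1, 4, 9, ...]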
Example #34
class TestStateObserver(ExecutionStateObserver):
    __test__ = False  # To tell pytest it isn't a test class

    def __init__(self):
        self._events: List[Tuple[datetime, JobInfo, ExecutionState,
                                 ExecutionError]] = []
        self.completion_lock = Condition()

    def state_update(self, job_info: JobInfo):
        self._events.append(
            (datetime.now(), job_info, job_info.state, job_info.exec_error))
        log.info("event=[state_changed] job_info=[{}]".format(job_info))
        self._release_state_waiter()

    def last_job(self) -> JobInfo:
        """
        :return: job of the last event
        """
        return self._events[-1][1]

    def last_state(self, job_id) -> ExecutionState:
        """
        :return: last state of the specified job
        """
        return next(e[2] for e in reversed(self._events)
                    if e[1].job_id == job_id)

    def exec_state(self, event_idx: int) -> ExecutionState:
        """
        :param event_idx: event index
        :return: execution state of the event on given index
        """
        return self._events[event_idx][2]

    def exec_error(self, event_idx: int) -> ExecutionError:
        """
        :param event_idx: event index
        :return: execution error of the event on given index
        """
        return self._events[event_idx][3]

    def _release_state_waiter(self):
        with self.completion_lock:
            # Support only one-to-one thread sync to keep things simple
            self.completion_lock.notify()

    def wait_for_state(self,
                       exec_state: ExecutionState,
                       timeout: float = 1) -> bool:
        """
        Wait for receiving notification with the specified state

        :param exec_state: Waits for the state specified by this parameter
        :param timeout: Waiting interval in seconds
        :return: True when the specified state was received, False when timed out
        """
        return self._wait_for_state_condition(
            lambda: exec_state in (e[2] for e in self._events), timeout)

    def wait_for_terminal_state(self, timeout: float = 1) -> bool:
        """
        Wait for receiving notification with a terminal state

        :param timeout: Waiting interval in seconds
        :return: True when a terminal state was received, False when timed out
        """
        return self._wait_for_state_condition(
            lambda: any((e for e in self._events if e[2].is_terminal())),
            timeout)

    def _wait_for_state_condition(self, state_condition: Callable[[], bool],
                                  timeout: float):
        with self.completion_lock:
            return self.completion_lock.wait_for(state_condition, timeout)
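A hedged sketch of the observer in a test; the runner and registration API below are illustrative placeholders, only the observer calls come from the snippet:

observer = TestStateObserver()
job_runner.add_observer(observer)        # hypothetical registration hook
job_runner.start('job-1')                # hypothetical trigger
assert observer.wait_for_terminal_state(timeout=2)
assert observer.last_state('job-1').is_terminal()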
Example #35
class NtTestBase(NetworkTablesInstance):
    """
        Object for managing a live pair of NT server/client
    """

    _wait_lock = None

    def shutdown(self):
        logger.info("shutting down %s", self.__class__.__name__)
        NetworkTablesInstance.shutdown(self)
        if self._wait_lock is not None:
            self._wait_init_listener()

    def disconnect(self):
        self._api.dispatcher.stop()

    def _init_common(self, proto_rev):
        # This resets the instance to be independent
        self.shutdown()
        self._api.dispatcher.setDefaultProtoRev(proto_rev)
        self.proto_rev = proto_rev

        self.enableVerboseLogging()
        # self._wait_init()

    def _init_server(self, proto_rev, server_port=0):
        self._init_common(proto_rev)

        self.port = server_port

    def _init_client(self, proto_rev):
        self._init_common(proto_rev)

    def _wait_init(self):
        self._wait_lock = Condition()
        self._wait = 0
        self._wait_init_listener()

    def _wait_init_listener(self):
        self._api.addEntryListener(
            "",
            self._wait_cb,
            NetworkTablesInstance.NotifyFlags.NEW
            | NetworkTablesInstance.NotifyFlags.UPDATE
            | NetworkTablesInstance.NotifyFlags.DELETE
            | NetworkTablesInstance.NotifyFlags.FLAGS,
        )

    def _wait_cb(self, *args):
        with self._wait_lock:
            self._wait += 1
            # logger.info('Wait callback, got: %s', args)
            self._wait_lock.notify()

    @contextmanager
    def expect_changes(self, count):
        """Use this on the *other* instance that you're making
        changes on, to wait for the changes to propagate to the
        other instance"""

        if self._wait_lock is None:
            self._wait_init()

        with self._wait_lock:
            self._wait = 0

        logger.info("Begin actions")
        yield
        logger.info("Waiting for %s changes", count)

        with self._wait_lock:
            result, msg = (
                self._wait_lock.wait_for(lambda: self._wait == count, 4),
                "Timeout waiting for %s changes (got %s)" %
                (count, self._wait),
            )
            logger.info("expect_changes: %s %s", result, msg)
            assert result, msg
Example #36
class Pool(object):
    """
    Connection pool for pymysql.

    The initialization parameters are as follows:
    :param host: Host of MySQL server
    :param port: Port of MySQL server
    :param user: User of MySQL server
    :param password: Password of MySQL server
    :param unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
    :param db: Database of MySQL server
    :param charset: Charset of MySQL server
    :param cursorclass: Class of MySQL Cursor
    :param autocommit: auto commit mode
    :param min_size: Minimum size of connection pool
    :param max_size: Maximum size of connection pool
    :param timeout: Waiting time in a multi-threaded environment
    :param interval: Statistical cycle time
    :param stati_num: Statistical frequency
    :param multiple: Regulation standard
    :param counter: Counter
    :param accumulation: Statistical result
    """

    def __init__(self,
                 host="localhost",
                 port=3306,
                 user=None,
                 password=None,
                 unix_socket=None,
                 db=None,
                 charset="utf8",
                 cursorclass=pymysql.cursors.DictCursor,
                 autocommit=False,
                 min_size=1,
                 max_size=3,
                 timeout=10.0,
                 interval=600.0,
                 stati_num=3,
                 multiple=4,
                 counter=0,
                 accumulation=0):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.db = db
        self.charset = charset
        self.cursorclass = cursorclass
        self.autocommit = autocommit

        self.min_size = min_size
        self.max_size = max_size
        self.current_size = 0
        self.timeout = timeout

        self.unuse_list = set()
        self.inuse_list = set()
        self.lock = Lock()
        self.cond = Condition(self.lock)

        self.interval = interval
        self.stati_num = stati_num
        self.multiple = multiple
        self.counter = 0
        self.accumulation = 0

        self.unix_socket = unix_socket

    def create_conn(self):
        """Create mysql connection by pymysql and to add unuse_list"""
        c = pymysql.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            db=self.db,
            charset=self.charset,
            cursorclass=self.cursorclass,
            autocommit=self.autocommit,
            unix_socket=self.unix_socket
        )
        self.unuse_list.add(c)

    def _start(self):
        """Start thread for resize pool"""
        t = Thread(target=resize_pool, args=(self.interval, self.stati_num,
                                             self.multiple, self.counter,
                                             self.accumulation, self))
        t.start()

    def _init_pool(self):
        """Initial minimum size of pool"""
        assert (self.min_size <= self.max_size)
        for _ in range(self.min_size):
            self.create_conn()

    def init(self):
        self._init_pool()
        self._start()

    def _wait(self):
        """Waiting condition"""
        return len(self.unuse_list) > 0

    def get_conn(self):
        with self.cond:
            # Lack of resources and wait
            if len(self.unuse_list) <= 0 and \
                    self.current_size >= self.max_size:
                # note: a TimeoutError here means a release was missed
                # or max_size is far below the level of concurrency
                self.cond.wait_for(self._wait, self.timeout)
                if len(self.unuse_list) <= 0:
                    raise TimeoutError
            # Lack of resources but can created
            if len(self.unuse_list) <= 0 and \
                    self.current_size < self.max_size:
                self.create_conn()

            self.current_size += 1
            c = self.unuse_list.pop()
            self.inuse_list.add(c)
            return c

    def release(self, c):
        """Release connection from inuse_list to unuse_list"""
        with self.cond:
            self.current_size -= 1
            self.inuse_list.remove(c)
            self.unuse_list.add(c)
            self.cond.notify_all()

    def destroy(self):
        """Destroy pool"""
        for _ in range(len(self.unuse_list)):
            c = self.unuse_list.pop()
            c.close()
        for _ in range(len(self.inuse_list)):
            c = self.inuse_list.pop()
            c.close()
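A short usage sketch under the snippet's own assumptions (resize_pool and pymysql exist elsewhere): borrow with get_conn and always hand back with release, so its notify_all wakes any thread parked in wait_for:

pool = Pool(host="localhost", user="root", password="secret", db="test")
pool.init()                     # create min_size connections, start resizer
conn = pool.get_conn()
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        print(cur.fetchone())
finally:
    pool.release(conn)          # wakes threads blocked in get_conn()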
Example #37
class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """
    def __init__(self, squash_handler, first_state_hash, always_persist):
        self._txn_queue = deque()
        self._scheduled_transactions = []
        self._batch_statuses = {}
        self._txn_to_batch = {}
        self._batch_by_id = {}
        self._txn_results = {}
        self._in_progress_transaction = None
        self._final = False
        self._cancelled = False
        self._previous_context_id = None
        self._previous_valid_batch_c_id = None
        self._squash = squash_handler
        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._previous_state_hash = first_state_hash
        # The state hashes here are the ones added in add_batch, and
        # are the state hashes that correspond with block boundaries.
        self._required_state_hashes = {}
        self._already_calculated = False
        self._always_persist = always_persist

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    def set_transaction_execution_result(self,
                                         txn_signature,
                                         is_valid,
                                         context_id,
                                         state_changes=None,
                                         events=None,
                                         data=None,
                                         error_message="",
                                         error_data=b""):
        with self._condition:
            if (self._in_progress_transaction is None
                    or self._in_progress_transaction != txn_signature):
                return

            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError(
                    "transaction not in any batches: {}".format(txn_signature))

            if txn_signature not in self._txn_results:
                self._txn_results[txn_signature] = TxnExecutionResult(
                    signature=txn_signature,
                    is_valid=is_valid,
                    context_id=context_id if is_valid else None,
                    state_hash=self._previous_state_hash if is_valid else None,
                    state_changes=state_changes,
                    events=events,
                    data=data,
                    error_message=error_message,
                    error_data=error_data)

            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                self._previous_context_id = context_id
            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = \
                    BatchExecutionResult(is_valid=False, state_hash=None)
            if txn_signature in self._last_in_batch:
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._previous_valid_batch_c_id = self._previous_context_id
                    state_hash = self._calculate_state_root_if_required(
                        batch_id=batch_signature)
                    self._batch_statuses[batch_signature] = \
                        BatchExecutionResult(is_valid=True,
                                             state_hash=state_hash)
                else:
                    self._previous_context_id = self._previous_valid_batch_c_id

            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None, required=False):
        with self._condition:
            if self._final:
                raise SchedulerError("Scheduler is finalized. Cannot take"
                                     " new batches")
            preserve = required
            if not required:
                # If this is the first non-required batch, it is preserved for
                # the schedule to be completed (i.e. no empty schedules in the
                # event of unschedule_incomplete_batches being called before
                # the first batch is completed).
                preserve = _first(
                    filterfalse(lambda sb: sb.required,
                                self._batch_by_id.values())) is None

            batch_signature = batch.header_signature
            self._batch_by_id[batch_signature] = \
                _AnnotatedBatch(batch, required=required, preserve=preserve)

            if state_hash is not None:
                self._required_state_hashes[batch_signature] = state_hash
            batch_length = len(batch.transactions)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.append(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def get_transaction_execution_results(self, batch_signature):
        with self._condition:
            batch_status = self._batch_statuses.get(batch_signature)
            if batch_status is None:
                return None

            annotated_batch = self._batch_by_id.get(batch_signature)
            if annotated_batch is None:
                return None

            results = []
            for txn in annotated_batch.batch.transactions:
                result = self._txn_results.get(txn.header_signature)
                if result is not None:
                    results.append(result)
            return results

    def count(self):
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_transactions[index]

    def _get_dependencies(self, transaction):
        header = TransactionHeader()
        header.ParseFromString(transaction.header)
        return list(header.dependencies)

    def _set_batch_result(self, txn_id, valid, state_hash):
        if txn_id not in self._txn_to_batch:
            # An incomplete transaction in progress will have been removed
            return

        batch_id = self._txn_to_batch[txn_id]
        self._batch_statuses[batch_id] = BatchExecutionResult(
            is_valid=valid, state_hash=state_hash)
        batch = self._batch_by_id[batch_id].batch
        for txn in batch.transactions:
            if txn.header_signature not in self._txn_results:
                self._txn_results[txn.header_signature] = TxnExecutionResult(
                    txn.header_signature, is_valid=False)

    def _get_batch_result(self, txn_id):
        batch_id = self._txn_to_batch[txn_id]
        return self._batch_statuses[batch_id]

    def _dep_is_known(self, txn_id):
        return txn_id in self._txn_to_batch

    def _in_invalid_batch(self, txn_id):
        if self._txn_to_batch[txn_id] in self._batch_statuses:
            dependency_result = self._get_batch_result(txn_id)
            return not dependency_result.is_valid
        return False

    def _handle_fail_fast(self, txn):
        self._set_batch_result(txn.header_signature, False, None)
        self._check_change_last_good_context_id(txn)

    def _check_change_last_good_context_id(self, txn):
        if txn.header_signature in self._last_in_batch:
            self._previous_context_id = self._previous_valid_batch_c_id

    def next_transaction(self):
        with self._condition:
            if self._in_progress_transaction is not None:
                return None

            txn = None
            while txn is None:
                try:
                    txn = self._txn_queue.popleft()
                except IndexError:
                    if self._final:
                        self._condition.notify_all()
                        raise StopIteration()
                    return None
                # Handle this transaction being invalid based on a
                # dependency.
                if any(
                        self._dep_is_known(d) and self._in_invalid_batch(d)
                        for d in self._get_dependencies(txn)):
                    self._set_batch_result(txn.header_signature, False, None)
                    self._check_change_last_good_context_id(txn=txn)
                    txn = None
                    continue
                # Handle fail fast.
                if self._in_invalid_batch(txn.header_signature):
                    self._handle_fail_fast(txn)
                    txn = None

            self._in_progress_transaction = txn.header_signature
            base_contexts = [] if self._previous_context_id is None \
                else [self._previous_context_id]
            txn_info = TxnInformation(txn=txn,
                                      state_hash=self._previous_state_hash,
                                      base_context_ids=base_contexts)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def unschedule_incomplete_batches(self):
        inprogress_batch_id = None
        with self._condition:
            # remove the in-progress transaction's batch
            if self._in_progress_transaction is not None:
                batch_id = self._txn_to_batch[self._in_progress_transaction]
                annotated_batch = self._batch_by_id[batch_id]

                # if the batch is preserve or there are no completed batches,
                # keep it in the schedule
                if not annotated_batch.preserve:
                    self._in_progress_transaction = None
                else:
                    inprogress_batch_id = batch_id

            def in_schedule(entry):
                (batch_id, annotated_batch) = entry
                return batch_id in self._batch_statuses or \
                    annotated_batch.preserve or batch_id == inprogress_batch_id

            incomplete_batches = list(
                filterfalse(in_schedule, self._batch_by_id.items()))

            # clean up the batches, including partial complete information
            for batch_id, annotated_batch in incomplete_batches:
                for txn in annotated_batch.batch.transactions:
                    txn_id = txn.header_signature
                    if txn_id in self._txn_results:
                        del self._txn_results[txn_id]

                    if txn in self._txn_queue:
                        self._txn_queue.remove(txn)

                    del self._txn_to_batch[txn_id]

                self._last_in_batch.remove(
                    annotated_batch.batch.transactions[-1].header_signature)

                del self._batch_by_id[batch_id]

            self._condition.notify_all()

        if incomplete_batches:
            LOGGER.debug('Removed %s incomplete batches from the schedule',
                         len(incomplete_batches))

    def is_transaction_in_schedule(self, txn_signature):
        with self._condition:
            return txn_signature in self._txn_to_batch

    def finalize(self):
        with self._condition:
            self._final = True
            self._condition.notify_all()

    def _compute_merkle_root(self, required_state_root):
        """Computes the merkle root of the state changes in the context
        corresponding with _previous_valid_batch_c_id as applied to
        _previous_state_hash.

        Args:
            required_state_root (str): The merkle root that these txns
                should equal.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context_id
        """

        state_hash = None
        if self._previous_valid_batch_c_id is not None:
            publishing_or_genesis = self._always_persist or \
                required_state_root is None
            state_hash = self._squash(
                state_root=self._previous_state_hash,
                context_ids=[self._previous_valid_batch_c_id],
                persist=self._always_persist,
                clean_up=publishing_or_genesis)
            if self._always_persist is True:
                return state_hash
            if state_hash == required_state_root:
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_valid_batch_c_id],
                             persist=True,
                             clean_up=True)
        return state_hash

    def _calculate_state_root_if_not_already_done(self):
        if not self._already_calculated:
            if not self._last_in_batch:
                return
            last_txn_signature = self._last_in_batch[-1]
            batch_id = self._txn_to_batch[last_txn_signature]
            required_state_hash = self._required_state_hashes.get(batch_id)

            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
            for t_id in self._last_in_batch[::-1]:
                b_id = self._txn_to_batch[t_id]
                if self._batch_statuses[b_id].is_valid:
                    self._batch_statuses[b_id].state_hash = state_hash
                    # found the last valid batch, so break out
                    break

    def _calculate_state_root_if_required(self, batch_id):
        required_state_hash = self._required_state_hashes.get(batch_id)
        state_hash = None
        if required_state_hash is not None:
            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
        return state_hash

    def _complete(self):
        return self._final and \
            len(self._txn_results) == len(self._txn_to_batch)

    def complete(self, block):
        with self._condition:
            if not self._final:
                return False
            if self._complete():
                self._calculate_state_root_if_not_already_done()
                return True
            if block:
                self._condition.wait_for(self._complete)
                self._calculate_state_root_if_not_already_done()
                return True
            return False

    def cancel(self):
        with self._condition:
            if not self._cancelled and not self._final \
                    and self._previous_context_id:
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_context_id],
                             persist=False,
                             clean_up=True)
                self._cancelled = True
                self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
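complete(block=True) above is a textbook Condition pattern: a predicate over shared state, re-evaluated inside wait_for each time a producer calls notify_all. The same idiom stripped of the scheduler machinery, as a self-contained sketch:

from threading import Condition, Thread

cond = Condition()
results = {}
EXPECTED = 3

def finish(txn_id):
    with cond:
        results[txn_id] = "ok"
        cond.notify_all()              # wake anyone blocked in wait_for

for i in range(EXPECTED):
    Thread(target=finish, args=(i,)).start()

with cond:
    cond.wait_for(lambda: len(results) == EXPECTED)   # blocks like complete(block=True)
print(results)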
Example #38
class _SendReceiveThread(Thread):
    """
    Internal thread to Stream class that runs the asyncio event loop.
    """

    def __init__(self, url, futures):
        super(_SendReceiveThread, self).__init__()
        self._futures = futures
        self._url = url

        self._event_loop = None
        self._sock = None
        self._recv_queue = None
        self._send_queue = None
        self._context = None

        self._condition = Condition()

    @asyncio.coroutine
    def _receive_message(self):
        """
        internal coroutine that receives messages and puts
        them on the recv_queue
        """
        with self._condition:
            self._condition.wait_for(lambda: self._sock is not None)
        while True:
            msg_bytes = yield from self._sock.recv()
            message = validator_pb2.Message()
            message.ParseFromString(msg_bytes)
            try:
                self._futures.set_result(
                    message.correlation_id,
                    FutureResult(message_type=message.message_type,
                                 content=message.content))
                self._futures.remove(message.correlation_id)
            except FutureCollectionKeyError:
                # if we are getting an initial message, not a response
                self._recv_queue.put_nowait(message)

    @asyncio.coroutine
    def _send_message(self):
        """
        internal coroutine that sends messages from the send_queue
        """
        with self._condition:
            self._condition.wait_for(lambda: self._send_queue is not None
                                     and self._sock is not None)
        while True:
            msg = yield from self._send_queue.get()
            yield from self._sock.send_multipart([msg.SerializeToString()])

    @asyncio.coroutine
    def _put_message(self, message):
        """
        puts a message on the send_queue. Not to be accessed directly.
        :param message: protobuf generated validator_pb2.Message
        """
        with self._condition:
            self._condition.wait_for(lambda: self._send_queue is not None)
        self._send_queue.put_nowait(message)

    @asyncio.coroutine
    def _get_message(self):
        """
        get a message from the recv_queue. Not to be accessed directly.
        """
        with self._condition:
            self._condition.wait_for(lambda: self._recv_queue is not None)
        msg = yield from self._recv_queue.get()

        return msg

    def put_message(self, message):
        """
        :param message: protobuf generated validator_pb2.Message
        """
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        asyncio.run_coroutine_threadsafe(self._put_message(message),
                                         self._event_loop)

    def get_message(self):
        """
        :return message: protobuf generated validator_pb2.Message
        """
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        return asyncio.run_coroutine_threadsafe(self._get_message(),
                                                self._event_loop).result()

    def _exit_tasks(self):
        for task in asyncio.Task.all_tasks(self._event_loop):
            task.cancel()

    def shutdown(self):
        self._exit_tasks()
        self._event_loop.call_soon_threadsafe(self._event_loop.stop)
        self._sock.close()
        self._context.destroy()

    def run(self):
        self._event_loop = zmq.asyncio.ZMQEventLoop()
        asyncio.set_event_loop(self._event_loop)
        self._context = zmq.asyncio.Context()
        self._sock = self._context.socket(zmq.DEALER)
        self._sock.identity = "{}-{}".format(self.__class__.__name__,
                                             os.getpid()).encode('ascii')
        self._sock.connect('tcp://' + self._url)
        self._send_queue = asyncio.Queue(loop=self._event_loop)
        self._recv_queue = asyncio.Queue(loop=self._event_loop)
        with self._condition:
            self._condition.notify_all()
        asyncio.ensure_future(self._send_message(), loop=self._event_loop)
        asyncio.ensure_future(self._receive_message(), loop=self._event_loop)
        self._event_loop.run_forever()
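Each public method above gates on wait_for until run() has created the event loop, socket and queues, after which run() releases everyone with a single notify_all. The readiness-gate idiom on its own, without zmq:

from threading import Condition, Thread
import time

class Background(Thread):
    def __init__(self):
        super().__init__()
        self._cond = Condition()
        self._resource = None

    def run(self):
        time.sleep(0.1)                  # stands in for slow setup
        with self._cond:
            self._resource = object()    # e.g. event loop, socket, queue
            self._cond.notify_all()

    def use(self):
        with self._cond:
            self._cond.wait_for(lambda: self._resource is not None)
        return self._resource

t = Background()
t.start()
print(t.use())                           # safe even before setup finishes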
Example #39
class TaskExecutor(object):
    def __init__(self, balancer, index):
        self.balancer = balancer
        self.index = index
        self.task = None
        self.proc = None
        self.pid = None
        self.conn = None
        self.state = WorkerState.STARTING
        self.key = str(uuid.uuid4())
        self.result = AsyncResult()
        self.exiting = False
        self.killed = False
        self.thread = gevent.spawn(self.executor)
        self.cv = Condition()
        self.status_lock = RLock()

    def checkin(self, conn):
        with self.cv:
            self.balancer.logger.debug(
                'Check-in of worker #{0} (key {1})'.format(
                    self.index, self.key))
            self.conn = conn
            self.state = WorkerState.IDLE
            self.cv.notify_all()

    def put_progress(self, progress):
        st = TaskStatus(None)
        st.__setstate__(progress)
        self.task.set_state(progress=st)

    def put_status(self, status):
        with self.cv:
            # Try to collect rusage at this point, when process is still alive
            try:
                kinfo = self.balancer.dispatcher.threaded(
                    bsd.kinfo_getproc, self.pid)
                self.task.rusage = kinfo.rusage
            except LookupError:
                pass

            if status['status'] == 'ROLLBACK':
                self.task.set_state(TaskState.ROLLBACK)

            if status['status'] == 'FINISHED':
                self.result.set(status['result'])

            if status['status'] == 'FAILED':
                error = status['error']

                if error['type'] in ERROR_TYPES:
                    cls = ERROR_TYPES[error['type']]
                    exc = cls(code=error['code'],
                              message=error['message'],
                              stacktrace=error['stacktrace'],
                              extra=error.get('extra'))
                else:
                    exc = OtherException(
                        code=error['code'],
                        message=error['message'],
                        stacktrace=error['stacktrace'],
                        type=error['type'],
                        extra=error.get('extra'),
                    )

                self.result.set_exception(exc)

    def put_warning(self, warning):
        self.task.add_warning(warning)

    def run(self, task):
        def match_file(module, f):
            name, ext = os.path.splitext(f)
            return module == name and ext in ['.py', '.pyc', '.so']

        with self.cv:
            self.cv.wait_for(lambda: self.state == WorkerState.ASSIGNED)
            self.result = AsyncResult()
            self.task = task
            self.task.set_state(TaskState.EXECUTING)
            self.state = WorkerState.EXECUTING
            self.cv.notify_all()

        self.balancer.logger.debug('Actually starting task {0}'.format(
            task.id))

        filename = None
        module_name = inspect.getmodule(task.clazz).__name__
        for dir in self.balancer.dispatcher.plugin_dirs:
            found = False
            try:
                for root, _, files in os.walk(dir):
                    file = first_or_default(
                        lambda f: match_file(module_name, f), files)
                    if file:
                        filename = os.path.join(root, file)
                        found = True
                        break

                if found:
                    break
            except OSError:
                continue

        try:
            self.conn.call_sync(
                'taskproxy.run', {
                    'id': task.id,
                    'user': task.user,
                    'class': task.clazz.__name__,
                    'filename': filename,
                    'args': task.args,
                    'debugger': task.debugger,
                    'environment': task.environment,
                    'hooks': task.hooks,
                })
        except RpcException as e:
            self.balancer.logger.warning(
                'Cannot start task {0} on executor #{1}: {2}'.format(
                    task.id, self.index, str(e)))

            self.balancer.logger.warning(
                'Killing unresponsive task executor #{0} (pid {1})'.format(
                    self.index, self.proc.pid))

            self.terminate()

        try:
            self.result.get()
        except BaseException as e:
            if isinstance(e, OtherException):
                self.balancer.dispatcher.report_error(
                    'Task {0} raised invalid exception'.format(self.task.name),
                    e)

            if isinstance(e, TaskAbortException):
                self.task.set_state(TaskState.ABORTED,
                                    TaskStatus(0, 'aborted'))
            else:
                self.task.error = serialize_error(e)
                self.task.set_state(
                    TaskState.FAILED,
                    TaskStatus(0,
                               str(e),
                               extra={"stacktrace": traceback.format_exc()}))

            with self.cv:
                self.task.ended.set()

                if self.state == WorkerState.EXECUTING:
                    self.state = WorkerState.IDLE
                    self.cv.notify_all()

            self.balancer.task_exited(self.task)
            return

        with self.cv:
            self.task.result = self.result.value
            self.task.set_state(TaskState.FINISHED, TaskStatus(100, ''))
            self.task.ended.set()
            if self.state == WorkerState.EXECUTING:
                self.state = WorkerState.IDLE
                self.cv.notify_all()

        self.balancer.task_exited(self.task)

    def abort(self):
        self.balancer.logger.info("Trying to abort task #{0}".format(
            self.task.id))
        # Try to abort via RPC. If this fails, kill process
        try:
            # If the task supports the abort protocol we don't need to worry
            # about subtasks - it's the task's responsibility to kill them
            self.conn.call_sync('taskproxy.abort')
        except RpcException as err:
            self.balancer.logger.warning(
                "Failed to abort task #{0} gracefully: {1}".format(
                    self.task.id, str(err)))
            self.balancer.logger.warning("Killing process {0}".format(
                self.pid))
            self.killed = True
            self.terminate()

            # Now kill all the subtasks
            for subtask in filter(lambda t: t.parent is self.task,
                                  self.balancer.task_list):
                self.balancer.logger.warning(
                    "Aborting subtask {0} because parent task {1} died".format(
                        subtask.id, self.task.id))
                self.balancer.abort(subtask.id)

    def terminate(self):
        try:
            self.proc.terminate()
        except OSError:
            self.balancer.logger.warning(
                'Executor process with PID {0} already dead'.format(
                    self.proc.pid))

    def executor(self):
        while not self.exiting:
            try:
                self.proc = Popen([TASKWORKER_PATH, self.key],
                                  close_fds=True,
                                  preexec_fn=os.setpgrp,
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.STDOUT)

                self.pid = self.proc.pid
                self.balancer.logger.debug(
                    'Started executor #{0} as PID {1}'.format(
                        self.index, self.pid))
            except OSError:
                self.result.set_exception(
                    TaskException(errno.EFAULT, 'Cannot spawn task executor'))
                self.balancer.logger.error(
                    'Cannot spawn task executor #{0}'.format(self.index))
                return

            for line in self.proc.stdout:
                line = line.decode('utf8')
                self.balancer.logger.debug('Executor #{0}: {1}'.format(
                    self.index, line.strip()))
                if self.task:
                    self.task.output += line

            self.proc.wait()

            with self.cv:
                self.state = WorkerState.STARTING
                self.cv.notify_all()

            if self.proc.returncode == -signal.SIGTERM:
                self.balancer.logger.info(
                    'Executor process with PID {0} was terminated gracefully'.
                    format(self.proc.pid))
            else:
                self.balancer.logger.error(
                    'Executor process with PID {0} died abruptly with exit code {1}'
                    .format(self.proc.pid, self.proc.returncode))

            if self.killed:
                self.result.set_exception(
                    TaskException(errno.EFAULT, 'Task killed'))
            else:
                self.result.set_exception(
                    TaskException(errno.EFAULT, 'Task executor died'))
            gevent.sleep(1)

    def die(self):
        self.exiting = True
        if self.proc:
            self.terminate()
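run() above opens with wait_for(lambda: self.state == WorkerState.ASSIGNED): the balancer and the worker hand a small state machine back and forth under one Condition. Reduced to its essentials as a self-contained sketch:

from enum import Enum
from threading import Condition, Thread

class State(Enum):
    STARTING = 1
    ASSIGNED = 2
    EXECUTING = 3

state = State.STARTING
cv = Condition()

def worker():
    global state
    with cv:
        cv.wait_for(lambda: state is State.ASSIGNED)   # park until assigned
        state = State.EXECUTING
        cv.notify_all()                                # tell balancer we started

t = Thread(target=worker)
t.start()
with cv:
    state = State.ASSIGNED                             # balancer assigns work
    cv.notify_all()
t.join()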
Example #40
class Hasher(Operator):
    """
    Calculates the md5 hash of files coming in and outputs them.
    Hashes are calculated on separate threads.
    Directories are skipped.

    Override the process_file method to change the hashing algorithm.
    """

    # Bytes to read at a time while calculating md5 hash
    BytesToRead = 1048576

    def __init__(self,
                 max_workers: int = 10,
                 progress: Optional[HasherProgressReporter] = None):
        """
        max_workers:
            The number of threads we should use while hashing files.
        """
        Operator.__init__(self)
        self.max_workers = max_workers
        self.output_queue = None
        self.futures = {}
        self.lock = None
        self.finish_condition = None
        self.progress = progress
        self.executor = None

    def process(self, input_queue: Queue, output_queue: Queue) -> None:
        """
        input_queue:
            A Queue of File(s).

        output_queue:
            A Queue of File(s) that have their content_hash calculated.
        """
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)

        self.lock = Lock()
        self.finish_condition = Condition(self.lock)
        self.output_queue = output_queue
        self.futures = {}

        file = input_queue.get()
        while not isinstance(file, TerminateOperand):
            if file.is_directory:
                output_queue.put(file)
            else:
                self.submit_file(self.executor, file)

            file = input_queue.get()

        # Put terminator back on queue
        input_queue.put(file)

        with self.finish_condition:
            self.finish_condition.wait_for(self.is_done)

        self.executor.shutdown()

    def is_done(self):
        """
        Returns True if there are no more files to hash.
        """
        return len(self.futures.keys()) == 0

    def submit_file(self, executor: ThreadPoolExecutor, file: File):
        """
        Submits the file to have its hash calculated. On done, removes itself from the pending/hashing files.
        """
        self.report_progress(file, HasherProgressReporter.State.Started)
        future = executor.submit(self.process_file, file)

        with self.lock:
            self.futures[file.path] = future

        future.add_done_callback(self.finish_file)

    def finish_file(self, future: Future):
        """
        Finishes processing the hashed file, removes it from the pending/hashing files list,
        and puts it on the output queue.
        """
        file, exception = future.result()

        with self.lock:
            del self.futures[file.path]

        self.output_queue.put(file)

        if exception is not None:
            self.report_progress(file, HasherProgressReporter.State.Failed,
                                 exception)
        else:
            self.report_progress(file, HasherProgressReporter.State.Finished)

        with self.finish_condition:
            self.finish_condition.notify()

    def process_file(self, file: File) -> Tuple[File, Optional[BaseException]]:
        """
        Calculates the hash of the file, assumed to not be a directory.
        This method calculates the md5 hash. On failure, the hash is set to the empty string.

        Override this method to calculate a different hash.
        """

        hash_md5 = hashlib.md5()
        try:
            with open(file.path, "rb") as f:
                for chunk in iter(lambda: f.read(Hasher.BytesToRead), b""):
                    hash_md5.update(chunk)

            file.content_hash = hash_md5.hexdigest()
        except OSError as ex:
            file.content_hash = ''
            return file, ex

        return file, None

    def report_progress(self,
                        file: File,
                        state: HasherProgressReporter.State,
                        message: Optional[BaseException] = None):
        if self.progress is not None:
            self.progress.submit_file(file, state, message)
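A hedged sketch of feeding the operator; File, TerminateOperand and the exact queue type come from the surrounding framework, so treat the constructor calls below as placeholders:

from queue import Queue   # assumed compatible with the framework's Queue

input_q, output_q = Queue(), Queue()
for path in ("a.bin", "b.bin"):
    input_q.put(File(path))             # File: framework type, assumed here
input_q.put(TerminateOperand())         # sentinel that ends the processing loop

Hasher(max_workers=4).process(input_q, output_q)
while not output_q.empty():
    print(output_q.get().content_hash)  # md5 hex digest, '' on read failure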
Example #41
class _SendReceiveThread(Thread):
    """
    Internal thread to Stream class that runs the asyncio event loop.
    """
    def __init__(self, url, futures, ready_event):
        """constructor for background thread

        :param url (str): the address to connect to the validator on
        :param futures (FutureCollection): The Futures associated with
                messages sent through Stream.send
        :param ready_event (threading.Event): used to notify waiting/asking
               classes that the background thread of Stream is ready after
               a disconnect event.
        """
        super(_SendReceiveThread, self).__init__()
        self._futures = futures
        self._url = url

        self._event_loop = None
        self._sock = None
        self._monitor_sock = None
        self._monitor_fd = None
        self._recv_queue = None
        self._send_queue = None
        self._context = None
        self._ready_event = ready_event
        self._condition = Condition()

    @asyncio.coroutine
    def _receive_message(self):
        """
        internal coroutine that receives messages and puts
        them on the recv_queue
        """
        while True:
            if not self._ready_event.is_set():
                break
            msg_bytes = yield from self._sock.recv()
            message = validator_pb2.Message()
            message.ParseFromString(msg_bytes)
            try:
                self._futures.set_result(
                    message.correlation_id,
                    FutureResult(message_type=message.message_type,
                                 content=message.content))
                self._futures.remove(message.correlation_id)
            except FutureCollectionKeyError:
                # if we are getting an initial message, not a response
                if not self._ready_event.is_set():
                    break
                self._recv_queue.put_nowait(message)

    @asyncio.coroutine
    def _send_message(self):
        """
        internal coroutine that sends messages from the send_queue
        """
        while True:
            if not self._ready_event.is_set():
                break
            msg = yield from self._send_queue.get()
            yield from self._sock.send_multipart([msg.SerializeToString()])

    @asyncio.coroutine
    def _put_message(self, message):
        """
        Puts a message on the send_queue. Not to be accessed directly.
        :param message: protobuf generated validator_pb2.Message
        """
        self._send_queue.put_nowait(message)

    @asyncio.coroutine
    def _get_message(self):
        """
        Gets a message from the recv_queue. Not to be accessed directly.
        """
        with self._condition:
            self._condition.wait_for(lambda: self._recv_queue is not None)
        msg = yield from self._recv_queue.get()

        return msg

    @asyncio.coroutine
    def _monitor_disconnects(self):
        """Monitors the client socket for disconnects
        """
        yield from self._monitor_sock.recv_multipart()
        self._sock.disable_monitor()
        self._monitor_sock.disconnect(self._monitor_fd)
        self._monitor_sock.close(linger=0)
        self._monitor_sock = None
        self._sock.disconnect(self._url)
        self._ready_event.clear()
        LOGGER.debug("monitor socket received disconnect event")
        for future in self._futures.future_values():
            future.set_result(FutureError())
        for task in asyncio.Task.all_tasks(self._event_loop):
            task.cancel()
        self._event_loop.stop()
        self._send_queue = None
        self._recv_queue = None

    def put_message(self, message):
        """
        :param message: protobuf generated validator_pb2.Message
        """
        if not self._ready_event.is_set():
            return
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None and
                                     self._send_queue is not None)
        asyncio.run_coroutine_threadsafe(self._put_message(message),
                                         self._event_loop)

    def get_message(self):
        """
        :return message: concurrent.futures.Future
        """
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        return asyncio.run_coroutine_threadsafe(self._get_message(),
                                                self._event_loop)

    def _tasks_yet_to_be_done(self):
        """Gathers all the tasks (pending coroutines and futures)
        and provides a future that is done when all tasks are done.
        :return: concurrent.futures.Future
        """
        return asyncio.gather(*asyncio.Task.all_tasks(self._event_loop))

    def shutdown(self):
        """Schedules a callback to be run when all of the tasks
         have completed.
        """
        future = self._tasks_yet_to_be_done()
        future.add_done_callback(self._done_callback)

    def _done_callback(self, future):
        """Stops the event loop, closes the socket, and destroys the context

        :param future: concurrent.futures.Future not used
        """
        self._event_loop.call_soon_threadsafe(self._event_loop.stop)
        self._sock.close(linger=0)
        self._monitor_sock.close(linger=0)
        self._context.destroy()

    def run(self):
        first_time = True
        while True:
            if self._event_loop is None:
                self._event_loop = zmq.asyncio.ZMQEventLoop()
                asyncio.set_event_loop(self._event_loop)
            if self._context is None:
                self._context = zmq.asyncio.Context()
            if self._sock is None:
                self._sock = self._context.socket(zmq.DEALER)
            self._sock.identity = _generate_id()[0:16].encode('ascii')
            self._sock.connect(self._url)
            self._monitor_fd = "inproc://monitor.s-{}".format(
                _generate_id()[0:5])
            self._monitor_sock = self._sock.get_monitor_socket(
                zmq.EVENT_DISCONNECTED, addr=self._monitor_fd)
            self._send_queue = asyncio.Queue(loop=self._event_loop)
            self._recv_queue = asyncio.Queue(loop=self._event_loop)
            if first_time is False:
                self._recv_queue.put_nowait(RECONNECT_EVENT)
            with self._condition:
                self._condition.notify_all()
            asyncio.ensure_future(self._send_message(), loop=self._event_loop)
            asyncio.ensure_future(self._receive_message(),
                                  loop=self._event_loop)
            asyncio.ensure_future(self._monitor_disconnects(),
                                  loop=self._event_loop)

            self._ready_event.set()
            self._event_loop.run_forever()
            if first_time is True:
                first_time = False
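The run/put_message/get_message methods above illustrate a common hand-off: worker threads block on a threading.Condition until a background thread has created the asyncio event loop, then submit coroutines with asyncio.run_coroutine_threadsafe. A minimal, self-contained sketch of just that pattern (hypothetical LoopOwner class, not part of the code above):

import asyncio
import threading

class LoopOwner:
    def __init__(self):
        self._loop = None
        self._condition = threading.Condition()

    def run(self):
        # Runs in a dedicated thread; publish the loop under the condition.
        loop = asyncio.new_event_loop()
        with self._condition:
            self._loop = loop
            self._condition.notify_all()
        loop.run_forever()

    def submit(self, coro):
        # Called from any thread: wait until the loop exists, then schedule.
        with self._condition:
            self._condition.wait_for(lambda: self._loop is not None)
        return asyncio.run_coroutine_threadsafe(coro, self._loop)

async def echo(msg):
    return msg

owner = LoopOwner()
threading.Thread(target=owner.run, daemon=True).start()
print(owner.submit(echo("ping")).result(timeout=5))  # -> ping

Scheduling a coroutine before run_forever() has started is safe here: run_coroutine_threadsafe queues the callback, and the loop drains it once it starts running.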
class SerialScheduler(Scheduler):
    """Serial scheduler which returns transactions in the natural order.

    This scheduler will schedule one transaction at a time (only one may be
    unapplied), in the exact order provided as batches were added to the
    scheduler.

    This scheduler is intended to be used for comparison to more complex
    schedulers - for tests related to performance, correctness, etc.
    """
    def __init__(self,
                 squash_handler,
                 first_state_hash,
                 always_persist,
                 context_handlers=None):
        self._txn_queue = deque()
        self._scheduled_transactions = []
        self._batch_statuses = {}
        self._txn_to_batch = {}
        self._batch_by_id = {}
        self._txn_results = {}
        self._in_progress_transaction = None
        self._final = False
        self._cancelled = False
        self._previous_context_id = None
        self._previous_valid_batch_c_id = None
        self._squash = squash_handler
        self._merkle_root = None
        if context_handlers is not None:
            self._recompute_state_hash_handler = context_handlers[
                'recompute_state']
            self._update_state_hash = context_handlers['update_state']
            self._merkle_root = context_handlers['merkle_root']

        self._condition = Condition()
        # contains all txn.signatures where txn is
        # last in its associated batch
        self._last_in_batch = []
        self._previous_state_hash = first_state_hash
        # The state hashes here are the ones added in add_batch, and
        # are the state hashes that correspond with block boundaries.
        self._required_state_hashes = {}
        self._already_calculated = False
        self._always_persist = always_persist
        self._state_recompute_context = {'updates': None, 'deletes': None}

    def __del__(self):
        self.cancel()

    def __iter__(self):
        return SchedulerIterator(self, self._condition)

    @property
    def previous_state_hash(self):
        return self._previous_state_hash

    def set_transaction_execution_result(self,
                                         txn_signature,
                                         is_valid,
                                         context_id,
                                         state_changes=None,
                                         events=None,
                                         data=None,
                                         error_message="",
                                         error_data=b""):
        with self._condition:
            if (self._in_progress_transaction is None
                    or self._in_progress_transaction != txn_signature):
                LOGGER.debug(
                    'Received result for %s, but it was unscheduled (in_progress=%s)',
                    txn_signature[:8], self._in_progress_transaction)
                return

            self._in_progress_transaction = None

            if txn_signature not in self._txn_to_batch:
                raise ValueError(
                    "transaction not in any batches: {}".format(txn_signature))

            if txn_signature not in self._txn_results:
                LOGGER.debug(
                    'TxnExecutionResult PREV STATE=%s',
                    self._previous_state_hash[:10] if is_valid else None)
                self._txn_results[txn_signature] = TxnExecutionResult(
                    signature=txn_signature,
                    is_valid=is_valid,
                    context_id=context_id if is_valid else None,
                    state_hash=self._previous_state_hash if is_valid else None,
                    state_changes=state_changes,
                    events=events,
                    data=data,
                    error_message=error_message,
                    error_data=error_data)

            batch_signature = self._txn_to_batch[txn_signature]
            if is_valid:
                self._previous_context_id = context_id

            else:
                # txn is invalid, preemptively fail the batch
                self._batch_statuses[batch_signature] = BatchExecutionResult(
                    is_valid=False, state_hash=None)
            if txn_signature in self._last_in_batch:
                LOGGER.debug('tnx=%s last in batch', txn_signature[:8])
                if batch_signature not in self._batch_statuses:
                    # because of the else clause above, txn is valid here
                    self._previous_valid_batch_c_id = self._previous_context_id
                    state_hash = self._calculate_state_root_if_required(
                        batch_id=batch_signature)
                    LOGGER.debug(
                        'calculate_state_root_if_required -> STATE=%s',
                        state_hash[:8] if state_hash is not None else None)
                    self._batch_statuses[
                        batch_signature] = BatchExecutionResult(
                            is_valid=True, state_hash=state_hash)

                else:
                    self._previous_context_id = self._previous_valid_batch_c_id

            self._condition.notify_all()

    def add_batch(self, batch, state_hash=None, required=False):
        with self._condition:
            if self._final:
                raise SchedulerError(
                    "Scheduler is finalized. Cannot take new batches")

            preserve = required
            if not required:
                # If this is the first non-required batch, it is preserved for
                # the schedule to be completed (i.e. no empty schedules in the
                # event of unschedule_incomplete_batches being called before
                # the first batch is completed).
                preserve = _first(
                    filterfalse(lambda sb: sb.required,
                                self._batch_by_id.values())) is None

            batch_signature = batch.header_signature
            self._batch_by_id[batch_signature] = _AnnotatedBatch(
                batch, required=required, preserve=preserve)

            if state_hash is not None:
                self._required_state_hashes[batch_signature] = state_hash
            batch_length = len(batch.transactions)
            LOGGER.debug(
                "SerialScheduler::add_batch: batch=%s tnxs=%s added=%s STATE=%s",
                batch_signature[:8],
                [t.header_signature[:8] for t in batch.transactions],
                len(self._batch_by_id),
                state_hash[:10] if state_hash is not None else None)
            for idx, txn in enumerate(batch.transactions):
                if idx == batch_length - 1:
                    self._last_in_batch.append(txn.header_signature)
                self._txn_to_batch[txn.header_signature] = batch_signature
                self._txn_queue.append(txn)
            self._condition.notify_all()

    def get_batch_execution_result(self, batch_signature):
        with self._condition:
            return self._batch_statuses.get(batch_signature)

    def get_transaction_execution_results(self, batch_signature):
        with self._condition:
            batch_status = self._batch_statuses.get(batch_signature)
            if batch_status is None:
                return None

            annotated_batch = self._batch_by_id.get(batch_signature)
            if annotated_batch is None:
                return None

            results = []
            for txn in annotated_batch.batch.transactions:
                result = self._txn_results.get(txn.header_signature)
                if result is not None:
                    results.append(result)
            return results

    def count(self):
        with self._condition:
            return len(self._scheduled_transactions)

    def get_transaction(self, index):
        with self._condition:
            return self._scheduled_transactions[index]

    def _get_dependencies(self, transaction):
        header = TransactionHeader()
        header.ParseFromString(transaction.header)
        return list(header.dependencies)

    def _set_batch_result(self, txn_id, valid, state_hash):
        if txn_id not in self._txn_to_batch:
            # An incomplete transaction in progress will have been removed
            return

        batch_id = self._txn_to_batch[txn_id]
        self._batch_statuses[batch_id] = BatchExecutionResult(
            is_valid=valid, state_hash=state_hash)
        batch = self._batch_by_id[batch_id].batch
        for txn in batch.transactions:
            if txn.header_signature not in self._txn_results:
                self._txn_results[txn.header_signature] = TxnExecutionResult(
                    txn.header_signature, is_valid=False)

    def _get_batch_result(self, txn_id):
        batch_id = self._txn_to_batch[txn_id]
        return self._batch_statuses[batch_id]

    def _dep_is_known(self, txn_id):
        return txn_id in self._txn_to_batch

    def _in_invalid_batch(self, txn_id):
        if self._txn_to_batch[txn_id] in self._batch_statuses:
            dependency_result = self._get_batch_result(txn_id)
            return not dependency_result.is_valid
        return False

    def _handle_fail_fast(self, txn):
        self._set_batch_result(txn.header_signature, False, None)
        self._check_change_last_good_context_id(txn)

    def _check_change_last_good_context_id(self, txn):
        if txn.header_signature in self._last_in_batch:
            self._previous_context_id = self._previous_valid_batch_c_id

    def next_transaction(self):
        with self._condition:
            if self._in_progress_transaction is not None:
                return None

            txn = None
            while txn is None:
                try:
                    txn = self._txn_queue.popleft()
                except IndexError:
                    if self._final:
                        self._condition.notify_all()
                        raise StopIteration()
                    return None
                # Handle this transaction being invalid based on a
                # dependency.
                if any(
                        self._dep_is_known(d) and self._in_invalid_batch(d)
                        for d in self._get_dependencies(txn)):
                    self._set_batch_result(txn.header_signature, False, None)
                    self._check_change_last_good_context_id(txn=txn)
                    txn = None
                    continue
                # Handle fail fast.
                if self._in_invalid_batch(txn.header_signature):
                    self._handle_fail_fast(txn)
                    txn = None

            self._in_progress_transaction = txn.header_signature
            base_contexts = [] if self._previous_context_id is None else [
                self._previous_context_id
            ]
            # for DAG we should use real merkle root
            real_state_hash = self._merkle_root() if self._merkle_root else ''
            LOGGER.debug('next_transaction: tnx=%s PREV STATE=%s~%s \n',
                         txn.header_signature[:8],
                         self._previous_state_hash[:8], real_state_hash[:8])
            txn_info = TxnInformation(
                txn=txn,
                state_hash=self._previous_state_hash
                if real_state_hash == '' else real_state_hash,
                base_context_ids=base_contexts)
            self._scheduled_transactions.append(txn_info)
            return txn_info

    def unschedule_incomplete_batches(self):
        inprogress_batch_id = None
        with self._condition:
            # remove the in-progress transaction's batch
            if self._in_progress_transaction is not None:
                batch_id = self._txn_to_batch[self._in_progress_transaction]
                annotated_batch = self._batch_by_id[batch_id]

                # if the batch is preserved, or there are no completed
                # batches, keep it in the schedule
                if not annotated_batch.preserve:
                    LOGGER.debug('unschedule_incomplete_batches tnx=%s\n',
                                 self._in_progress_transaction)
                    self._in_progress_transaction = None
                else:
                    inprogress_batch_id = batch_id

            def in_schedule(entry):
                (batch_id, annotated_batch) = entry
                return (batch_id in self._batch_statuses
                        or annotated_batch.preserve
                        or batch_id == inprogress_batch_id)

            incomplete_batches = list(
                filterfalse(in_schedule, self._batch_by_id.items()))

            # clean up the batches, including partial complete information
            for batch_id, annotated_batch in incomplete_batches:
                for txn in annotated_batch.batch.transactions:
                    txn_id = txn.header_signature
                    if txn_id in self._txn_results:
                        del self._txn_results[txn_id]

                    if txn in self._txn_queue:
                        self._txn_queue.remove(txn)

                    del self._txn_to_batch[txn_id]

                self._last_in_batch.remove(
                    annotated_batch.batch.transactions[-1].header_signature)

                del self._batch_by_id[batch_id]

            self._condition.notify_all()

        if incomplete_batches:
            LOGGER.debug('Removed %s incomplete batches=%s from the schedule',
                         len(incomplete_batches), [
                             batch[1][0].header_signature[:8]
                             for batch in incomplete_batches
                         ])

    def num_batches(self):
        with self._condition:
            return len(self._batch_by_id)

    def check_incomplete_batches(self):
        with self._condition:
            LOGGER.debug('Found incomplete batches=%s statuses=%s tnx=%s',
                         len(self._batch_by_id), len(self._batch_statuses),
                         self._in_progress_transaction)
            return len(self._batch_statuses) < len(self._batch_by_id)

    def finalize(self):
        with self._condition:
            self._final = True
            self._condition.notify_all()

    def _compute_merkle_root(self, required_state_root):
        """Computes the merkle root of the state changes in the context
        corresponding with _previous_valid_batch_c_id as applied to
        _previous_state_hash.

        Args:
            required_state_root (str): The merkle root that these txns
                should equal.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context_id
        """
        state_hash = None
        if self._previous_valid_batch_c_id is not None:
            publishing_or_genesis = self._always_persist or required_state_root is None
            # FIXME for pool T-PROC
            LOGGER.debug(
                '_compute_merkle_root: _previous_state_hash=%s _previous_valid_batch_c_id=%s',
                self._previous_state_hash[:8],
                self._previous_valid_batch_c_id[:8])
            state_hash, updates, deletes = self._squash(
                state_root=self._previous_state_hash,
                context_ids=[self._previous_valid_batch_c_id],
                persist=self._always_persist,
                clean_up=publishing_or_genesis)
            LOGGER.debug(
                '_compute_merkle_root: publishing_or_genesis=%s state_hash=%s~%s',
                publishing_or_genesis, state_hash[:8],
                required_state_root[:8] if required_state_root else None)
            # save recomputing context
            self._state_recompute_context['updates'] = updates
            self._state_recompute_context['deletes'] = deletes
            if self._always_persist is True:
                return state_hash
            if state_hash == required_state_root or required_state_root == 'arbitration':
                # the new state is correct, so persist it
                # ('arbitration': for an external cluster block the state
                # is not checked)
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_valid_batch_c_id],
                             persist=True,
                             clean_up=True)
        return state_hash

    def recompute_merkle_root(self, previous_state_hash, context):
        """Recomputes the merkle root of the state changes in the context
        corresponding with _previous_valid_batch_c_id as applied to
        _previous_state_hash.
        Used for DAG in the chain controller context - to recompute the new
        root state from the actual previous root state.

        Args:
            previous_state_hash (str): The state root to apply the
                changes to.
            context: The recompute context holding the state updates
                and deletes.

        Returns:
            state_hash (str): The merkle root calculated from the previous
                state hash and the state changes from the context
        """

        state_hash = self._recompute_state_hash_handler(
            previous_state_hash, context)
        LOGGER.debug('RECOMPUTE merkle root: state_hash=%s', state_hash[:10])
        return state_hash

    def update_state_hash(self, old, new):
        # for DAG correct state hash
        LOGGER.debug('update_state_hash: STATE=%s->%s', old[:10], new[:10])
        self._update_state_hash(old, new)

    def _calculate_state_root_if_not_already_done(self):
        if not self._already_calculated:
            if not self._last_in_batch:
                return
            last_txn_signature = self._last_in_batch[-1]
            batch_id = self._txn_to_batch[last_txn_signature]
            required_state_hash = self._required_state_hashes.get(batch_id)
            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
            for t_id in self._last_in_batch[::-1]:
                b_id = self._txn_to_batch[t_id]
                if self._batch_statuses[b_id].is_valid:
                    self._batch_statuses[b_id].state_hash = state_hash
                    # found the last valid batch, so break out
                    break

    def _calculate_state_root_if_required(self, batch_id):
        required_state_hash = self._required_state_hashes.get(batch_id)
        LOGGER.debug(
            '_calculate_state_root_if_required: required_state_hash=%s',
            required_state_hash[:8] if required_state_hash else None)
        state_hash = None
        if required_state_hash is not None:
            # not None when a state_hash argument was passed to add_batch()
            state_hash = self._compute_merkle_root(required_state_hash)
            self._already_calculated = True
        return state_hash

    def get_state_hash_context(self):
        # for DAG only - return context for recompute new state
        return self._state_recompute_context

    def _complete(self):
        return self._final and len(self._txn_results) == len(
            self._txn_to_batch)

    def complete(self, block):
        with self._condition:
            LOGGER.debug('complete...')
            if not self._final:
                return False
            if self._complete():
                self._calculate_state_root_if_not_already_done()
                return True
            if block:
                LOGGER.debug('complete wait_for tnx=%s~%s.',
                             len(self._txn_results), len(self._txn_to_batch))
                self._condition.wait_for(self._complete)
                self._calculate_state_root_if_not_already_done()
                return True
            return False

    def cancel(self):
        with self._condition:
            if not self._cancelled and not self._final \
                    and self._previous_context_id:
                self._squash(state_root=self._previous_state_hash,
                             context_ids=[self._previous_context_id],
                             persist=False,
                             clean_up=True)
                self._cancelled = True
                self._condition.notify_all()

    def is_cancelled(self):
        with self._condition:
            return self._cancelled
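SerialScheduler.complete(block=...) above is the standard Condition.wait_for completion gate: answer immediately if the predicate already holds, otherwise block until it does. A stripped-down sketch of just that gate (hypothetical Tracker class, not part of SerialScheduler):

import threading

class Tracker:
    def __init__(self, expected):
        self._expected = expected
        self._done = 0
        self._condition = threading.Condition()

    def mark_done(self):
        with self._condition:
            self._done += 1
            self._condition.notify_all()

    def complete(self, block=False):
        # Mirrors SerialScheduler.complete: answer immediately, or block
        # on the condition until the predicate holds.
        with self._condition:
            if self._done >= self._expected:
                return True
            if block:
                self._condition.wait_for(lambda: self._done >= self._expected)
                return True
            return False

t = Tracker(expected=2)
threading.Thread(target=lambda: (t.mark_done(), t.mark_done())).start()
print(t.complete(block=True))  # -> True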
Example #43
File: main.py  Project: lukegb/middleware
class Job(object):
    def __init__(self, context):
        self.context = context
        self.anonymous = False
        self.disabled = False
        self.one_shot = False
        self.logger = None
        self.id = None
        self.label = None
        self.parent = None
        self.provides = set()
        self.requires = set()
        self.state = JobState.UNKNOWN
        self.program = None
        self.program_arguments = []
        self.pid = None
        self.pgid = None
        self.sid = None
        self.plist = None
        self.started_at = None
        self.exited_at = None
        self.keep_alive = False
        self.supports_checkin = False
        self.throttle_interval = 0
        self.exit_timeout = 10
        self.stdout_fd = None
        self.stdout_path = None
        self.stderr_fd = None
        self.stderr_path = None
        self.run_at_load = False
        self.user = None
        self.group = None
        self.umask = None
        self.last_exit_code = None
        self.failure_reason = None
        self.status_message = None
        self.environment = {}
        self.respawns = 0
        self.cv = Condition()

    @property
    def children(self):
        return (j for j in self.context.jobs if j.parent is self)

    def load(self, plist):
        self.state = JobState.STOPPED
        self.id = plist.get('ID', str(uuid.uuid4()))
        self.label = plist.get('Label')
        self.program = plist.get('Program')
        self.requires = set(plist.get('Requires', []))
        self.provides = set(plist.get('Provides', []))
        self.program_arguments = plist.get('ProgramArguments', [])
        self.stdout_path = plist.get('StandardOutPath')
        self.stderr_path = plist.get('StandardErrorPath')
        self.disabled = bool(plist.get('Disabled', False))
        self.run_at_load = bool(plist.get('RunAtLoad', False))
        self.keep_alive = bool(plist.get('KeepAlive', False))
        self.one_shot = bool(plist.get('OneShot', False))
        self.supports_checkin = bool(plist.get('SupportsCheckin', False))
        self.throttle_interval = int(plist.get('ThrottleInterval', 0))
        self.environment = plist.get('EnvironmentVariables', {})
        self.user = plist.get('UserName')
        self.group = plist.get('GroupName')
        self.umask = plist.get('Umask')
        self.logger = logging.getLogger('Job:{0}'.format(self.label))

        if first_or_default(lambda j: j.label == self.label,
                            self.context.jobs.values()):
            raise RpcException(
                errno.EEXIST,
                'Job with label {0} already exists'.format(self.label))

        if not self.program:
            self.program = self.program_arguments[0]

        if self.stdout_path:
            self.stdout_fd = os.open(self.stdout_path,
                                     os.O_WRONLY | os.O_APPEND)

        if self.stderr_path:
            self.stderr_fd = os.open(self.stderr_path,
                                     os.O_WRONLY | os.O_APPEND)

        if self.run_at_load:
            self.start()

    def load_anonymous(self, parent, pid):
        try:
            proc = bsd.kinfo_getproc(pid)
            command = proc.command
        except (LookupError, ProcessLookupError):
            # Exited too quickly, but add it anyway - it will be removed in the next event
            command = 'unknown'

        with self.cv:
            self.parent = parent
            self.id = str(uuid.uuid4())
            self.pid = pid
            self.label = 'anonymous.{0}@{1}'.format(command, self.pid)
            self.logger = logging.getLogger('Job:{0}'.format(self.label))
            self.anonymous = True
            self.state = JobState.RUNNING
            self.cv.notify_all()

    def unload(self):
        self.logger.info('Unloading job')
        del self.context.jobs[self.id]

    def start(self):
        with self.cv:
            if self.state in (JobState.STARTING, JobState.RUNNING):
                return

            if not self.requires <= self.context.provides:
                return

            self.logger.info('Starting job')

            pid = os.fork()
            if pid == 0:
                os.kill(os.getpid(), signal.SIGSTOP)

                if not self.stdout_fd and not self.stderr_fd:
                    self.stdout_fd = self.stderr_fd = os.open(
                        '/var/tmp/{0}.{1}.log'.format(self.label, os.getpid()),
                        os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)

                os.dup2(os.open('/dev/null', os.O_RDONLY), sys.stdin.fileno())
                os.dup2(self.stdout_fd, sys.stdout.fileno())
                os.dup2(self.stderr_fd, sys.stderr.fileno())

                if self.group:
                    # set the group first: setgid() would fail after
                    # privileges are dropped with setuid()
                    group = grp.getgrnam(self.group)
                    os.setgid(group.gr_gid)

                if self.user:
                    user = pwd.getpwnam(self.user)
                    os.setuid(user.pw_uid)

                if not self.program_arguments:
                    self.program_arguments = [self.program]

                bsd.closefrom(3)
                os.setsid()
                env = BASE_ENV.copy()
                env.update(self.environment)
                try:
                    os.execvpe(self.program, self.program_arguments, env)
                except:
                    os._exit(254)

            self.logger.debug('Started as PID {0}'.format(pid))
            self.pid = pid
            self.context.track_pid(self.pid)
            self.set_state(JobState.STARTING)

        os.waitpid(self.pid, os.WUNTRACED)
        os.kill(self.pid, signal.SIGCONT)

    def stop(self):
        with self.cv:
            if self.state == JobState.STOPPED:
                return

            self.logger.info('Stopping job')
            self.set_state(JobState.STOPPING)

            if not self.pid:
                self.set_state(JobState.STOPPED)
                return

            try:
                os.kill(self.pid, signal.SIGTERM)
            except ProcessLookupError:
                # Already dead
                self.set_state(JobState.STOPPED)

            if not self.cv.wait_for(lambda: self.state == JobState.STOPPED,
                                    self.exit_timeout):
                os.killpg(self.pgid, signal.SIGKILL)

            if not self.cv.wait_for(lambda: self.state == JobState.STOPPED,
                                    self.exit_timeout):
                self.logger.error('Unkillable process {0}'.format(self.pid))

    def send_signal(self, signo):
        if not self.pid:
            return

        os.kill(self.pid, signo)

    def checkin(self):
        with self.cv:
            self.logger.info('Service check-in')
            if self.supports_checkin:
                self.set_state(JobState.RUNNING)
                self.context.provide(self.provides)

    def push_status(self, status):
        with self.cv:
            self.status_message = status
            self.cv.notify_all()

        if self.label == 'org.freenas.dispatcher':
            self.context.init_dispatcher()

        self.context.emit_event(
            'serviced.job.status', {
                'ID': self.id,
                'Label': self.label,
                'Reason': self.failure_reason,
                'Anonymous': self.anonymous,
                'Message': self.status_message
            })

    def pid_event(self, ev):
        if ev.fflags & select.KQ_NOTE_EXEC:
            self.pid_exec(ev)

        if ev.fflags & select.KQ_NOTE_EXIT:
            self.pid_exit(ev)

    def pid_exec(self, ev):
        try:
            proc = bsd.kinfo_getproc(self.pid)
            argv = list(proc.argv)
            command = proc.command
        except (LookupError, ProcessLookupError):
            # Exited too quickly, exit info will be caught in another event
            return

        self.logger.debug('Job did exec() into {0}'.format(argv))

        if self.anonymous:
            # Update label for anonymous jobs
            self.label = 'anonymous.{0}@{1}'.format(command, self.pid)
            self.logger = logging.getLogger('Job:{0}'.format(self.label))

        if self.state == JobState.STARTING:
            with self.cv:
                try:
                    self.sid = os.getsid(self.pid)
                    self.pgid = os.getpgid(self.pid)
                except ProcessLookupError:
                    # Exited too quickly after exec()
                    return

                if not self.supports_checkin:
                    self.set_state(JobState.RUNNING)
                    self.context.provide(self.provides)

    def pid_exit(self, ev):
        if not self.parent:
            # We need to reap direct children
            try:
                os.waitpid(self.pid, 0)
            except BaseException as err:
                self.logger.debug('waitpid() error: {0}'.format(err))

        with self.cv:
            self.logger.info('Job has exited with code {0}'.format(ev.data))
            self.pid = None
            self.last_exit_code = ev.data

            if self.state == JobState.STOPPING:
                self.set_state(JobState.STOPPED)
            else:
                if self.one_shot and self.last_exit_code == 0:
                    self.set_state(JobState.ENDED)
                else:
                    self.failure_reason = 'Process died with exit code {0}'.format(
                        self.last_exit_code)
                    self.set_state(JobState.ERROR)

            if self.anonymous:
                del self.context.jobs[self.id]

    def set_state(self, new_state):
        # Must run locked
        if self.state != JobState.RUNNING and new_state == JobState.RUNNING:
            self.context.emit_event('serviced.job.started', {
                'ID': self.id,
                'Label': self.label,
                'Anonymous': self.anonymous
            })

        if self.state != JobState.STOPPED and new_state == JobState.STOPPED:
            self.context.emit_event('serviced.job.stopped', {
                'ID': self.id,
                'Label': self.label,
                'Anonymous': self.anonymous
            })

        if self.state != JobState.ERROR and new_state == JobState.ERROR:
            self.context.emit_event(
                'serviced.job.error', {
                    'ID': self.id,
                    'Label': self.label,
                    'Reason': self.failure_reason,
                    'Anonymous': self.anonymous
                })

        self.state = new_state
        self.cv.notify_all()

    def __getstate__(self):
        ret = {
            'ID': self.id,
            'ParentID': self.parent.id if self.parent else None,
            'Label': self.label,
            'Program': self.program,
            'ProgramArguments': self.program_arguments,
            'Provides': list(self.provides),
            'Requires': list(self.requires),
            'RunAtLoad': self.run_at_load,
            'KeepAlive': self.keep_alive,
            'State': self.state.name,
            'LastExitStatus': self.last_exit_code,
            'PID': self.pid
        }

        if self.failure_reason:
            ret['FailureReason'] = self.failure_reason

        if self.stdout_path:
            ret['StandardOutPath'] = self.stdout_path

        if self.stdout_path:
            ret['StandardErrorPath'] = self.stderr_path

        if self.environment:
            ret['EnvironmentVariables'] = self.environment

        return ret
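Job.stop above escalates from SIGTERM to SIGKILL using two timed Condition.wait_for calls. A process-free sketch of that escalation (hypothetical names; real signal delivery replaced with stubs):

import threading

class Stoppable:
    def __init__(self):
        self.cv = threading.Condition()
        self.stopped = False

    def _terminate(self):
        # Stand-in for os.kill(pid, SIGTERM); a well-behaved process
        # would eventually trigger on_exit().
        pass

    def _kill(self):
        # Stand-in for os.killpg(pgid, SIGKILL); assumed to always work here.
        self.on_exit()

    def on_exit(self):
        # Called from the event thread when the process exits.
        with self.cv:
            self.stopped = True
            self.cv.notify_all()

    def stop(self, timeout=1.0):
        with self.cv:
            self._terminate()
            if not self.cv.wait_for(lambda: self.stopped, timeout):
                self._kill()  # graceful path timed out, escalate
            if not self.cv.wait_for(lambda: self.stopped, timeout):
                raise RuntimeError('unkillable process')

Stoppable().stop()  # escalates after the first timeout, then returns

Re-entering the condition from _kill() is fine because threading.Condition uses an RLock by default.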
class _SendReceive(object):
    def __init__(self,
                 connection,
                 address,
                 futures,
                 connections,
                 zmq_identity=None,
                 dispatcher=None,
                 secured=False,
                 server_public_key=None,
                 server_private_key=None,
                 heartbeat=False,
                 heartbeat_interval=10):
        """
        Constructor for _SendReceive.

        Args:
            connection (str): A locally unique identifier for this
                thread's connection. Used to identify the connection
                in the dispatcher for transmitting responses.
            futures (future.FutureCollection): A Map of correlation ids to
                futures
            connections (dict): A dictionary that uses a sha512 hash as the
                keys and either an OutboundConnection or string identity
                as values.
            zmq_identity (bytes): Used to identify the dealer socket
            address (str): The endpoint to bind or connect to.
            dispatcher (dispatcher.Dispatcher): Used to handle messages in a
                coordinated way.
            secured (bool): Whether or not to start the socket in
                secure mode -- using zmq auth.
            server_public_key (bytes): A public key to use in verifying
                server identity as part of the zmq auth handshake.
            server_private_key (bytes): A private key corresponding to
                server_public_key used by the server socket to sign
                messages as part of the zmq auth handshake.
            heartbeat (bool): Whether or not to send ping messages.
            heartbeat_interval (int): Number of seconds between ping
                messages on an otherwise quiet connection.
        """
        self._connection = connection
        self._dispatcher = dispatcher
        self._futures = futures
        self._address = address
        self._zmq_identity = zmq_identity
        self._secured = secured
        self._server_public_key = server_public_key
        self._server_private_key = server_private_key
        self._heartbeat = heartbeat
        self._heartbeat_interval = heartbeat_interval

        self._event_loop = None
        self._context = None
        self._recv_queue = None
        self._socket = None
        self._condition = Condition()

        self._connected_identities = {}
        self._connections = connections

    @property
    def connection(self):
        return self._connection

    @asyncio.coroutine
    def _send_heartbeat(self):
        with self._condition:
            self._condition.wait_for(lambda: self._socket is not None)

        ping = PingRequest()

        while True:
            if self._socket.getsockopt(zmq.TYPE) == zmq.ROUTER:
                expired = [
                    ident for ident in self._connected_identities
                    if time.time() - self._connected_identities[ident] >
                    self._heartbeat_interval
                ]
                for zmq_identity in expired:
                    message = validator_pb2.Message(
                        correlation_id=_generate_id(),
                        content=ping.SerializeToString(),
                        message_type=validator_pb2.Message.NETWORK_PING)
                    fut = future.Future(message.correlation_id,
                                        message.content,
                                        has_callback=False)
                    self._futures.put(fut)
                    yield from self._send_message(zmq_identity, message)
            yield from asyncio.sleep(self._heartbeat_interval)

    def _received_from_identity(self, zmq_identity):
        self._connected_identities[zmq_identity] = time.time()
        connection_id = hashlib.sha512(zmq_identity).hexdigest()
        if connection_id not in self._connections:
            self._connections[connection_id] = ("ZMQ_Identity", zmq_identity)

    @asyncio.coroutine
    def _receive_message(self):
        """
        Internal coroutine for receiving messages
        """
        zmq_identity = None
        with self._condition:
            self._condition.wait_for(lambda: self._socket is not None)
        while True:
            if self._socket.getsockopt(zmq.TYPE) == zmq.ROUTER:
                zmq_identity, msg_bytes = \
                    yield from self._socket.recv_multipart()
                self._received_from_identity(zmq_identity)
            else:
                msg_bytes = yield from self._socket.recv()

            message = validator_pb2.Message()
            message.ParseFromString(msg_bytes)
            LOGGER.debug("%s receiving %s message: %s bytes", self._connection,
                         get_enum_name(message.message_type),
                         sys.getsizeof(msg_bytes))

            try:
                self._futures.set_result(
                    message.correlation_id,
                    future.FutureResult(message_type=message.message_type,
                                        content=message.content))
            except future.FutureCollectionKeyError:
                if zmq_identity is not None:
                    connection_id = hashlib.sha512(zmq_identity).hexdigest()
                else:
                    connection_id = \
                        hashlib.sha512(self._connection.encode()).hexdigest()
                self._dispatcher.dispatch(self._connection, message,
                                          connection_id)
            else:
                my_future = self._futures.get(message.correlation_id)

                LOGGER.debug("message round "
                             "trip: %s %s",
                             get_enum_name(message.message_type),
                             my_future.get_duration())

                self._futures.remove(message.correlation_id)

    @asyncio.coroutine
    def _send_message(self, identity, msg):
        LOGGER.debug("%s sending %s to %s", self._connection,
                     get_enum_name(msg.message_type),
                     identity if identity else self._address)

        if identity is None:
            message_bundle = [msg.SerializeToString()]
        else:
            message_bundle = [bytes(identity), msg.SerializeToString()]
        yield from self._socket.send_multipart(message_bundle)

    def send_message(self, msg, connection_id=None):
        """
        :param msg: protobuf validator_pb2.Message
        """
        zmq_identity = None
        if connection_id is not None:
            connection_type, connection = self._connections.get(connection_id)
            if connection_type == "ZMQ_Identity":
                zmq_identity = connection

        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        asyncio.run_coroutine_threadsafe(self._send_message(zmq_identity, msg),
                                         self._event_loop)

    def setup(self, socket_type):
        """
        :param socket_type: zmq.DEALER or zmq.ROUTER
        """
        if self._secured:
            if self._server_public_key is None or \
                    self._server_private_key is None:
                raise LocalConfigurationError("Attempting to start socket "
                                              "in secure mode, but complete "
                                              "server keys were not provided")

        self._event_loop = zmq.asyncio.ZMQEventLoop()
        asyncio.set_event_loop(self._event_loop)
        self._context = zmq.asyncio.Context()
        self._socket = self._context.socket(socket_type)

        if socket_type == zmq.DEALER:
            self._socket.identity = "{}-{}".format(
                self._zmq_identity,
                hashlib.sha512(uuid.uuid4().hex.encode()).hexdigest()
                [:23]).encode('ascii')

            if self._secured:
                # Generate ephemeral certificates for this connection
                self._socket.curve_publickey, self._socket.curve_secretkey = \
                    zmq.curve_keypair()

                self._socket.curve_serverkey = self._server_public_key

            self._dispatcher.add_send_message(self._connection,
                                              self.send_message)
            self._socket.connect(self._address)
        elif socket_type == zmq.ROUTER:
            if self._secured:
                auth = AsyncioAuthenticator(self._context)
                auth.start()
                auth.configure_curve(domain='*',
                                     location=zmq.auth.CURVE_ALLOW_ANY)

                self._socket.curve_secretkey = self._server_private_key
                self._socket.curve_publickey = self._server_public_key
                self._socket.curve_server = True

            self._dispatcher.add_send_message(self._connection,
                                              self.send_message)
            self._socket.bind(self._address)

        self._recv_queue = asyncio.Queue()

        asyncio.ensure_future(self._receive_message(), loop=self._event_loop)

        if self._heartbeat:
            asyncio.ensure_future(self._send_heartbeat(),
                                  loop=self._event_loop)

        with self._condition:
            self._condition.notify_all()
        self._event_loop.run_forever()

    def stop(self):
        self._dispatcher.remove_send_message(self._connection)
        self._event_loop.stop()
        self._socket.close()
        self._context.term()
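_send_heartbeat above pings ROUTER peers whose last message is older than the heartbeat interval. The bookkeeping reduces to a last-seen map, sketched below (hypothetical Heartbeats class, not part of _SendReceive):

import time

class Heartbeats:
    def __init__(self, interval=10):
        self._interval = interval
        self._last_seen = {}

    def seen(self, identity):
        # Mirrors _received_from_identity: record the last activity time.
        self._last_seen[identity] = time.time()

    def expired(self):
        # Peers quiet for longer than the interval should be pinged.
        now = time.time()
        return [ident for ident, t in self._last_seen.items()
                if now - t > self._interval]

hb = Heartbeats(interval=10)
hb.seen(b'peer-1')
print(hb.expired())  # -> [] until peer-1 has been quiet for 10s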
Example #45
class ComputingNode:
    """
    This is a Aml Node that wait for jobs from Main Nodes, calculate the solution and send them.

    It has an AML DDS Computing Node that implement the communication between nodes.

    This is a Mock over the real AML Computing Node, and the solution to the
    jobs is merely a human untestable routine of modifying a string to upper case.
    """
    def __init__(self, name, domain=0, store_in_file=True):
        """Create a default ComputingNode."""
        logger.construct(f'Creating ComputingNode {name}.')

        # Internal variables
        self.name_ = name
        self.time_range_ms_ = (2000, 7000)
        self.store_in_file_ = store_in_file
        self.jobs_processed_ = []

        # DDS variables
        self.dds_computing_node_ = AmlDdsComputingNode(
            name=name,
            job_process_callback=self._job_process_callback_dds_type,
            domain=domain)
        self.dds_computing_node_.init()

        # Stop variables
        self.stop_ = False
        self.cv_stop_ = Condition()

    def __del__(self):
        """TODO comment."""
        logger.construct(f'Destroying ComputingNode {self.name_}.')
        self.stop()

        if self.store_in_file_:
            self._save_jobs_in_file()

    def _save_jobs_in_file(self, file_name: str = ''):
        """TODO comment."""
        if file_name == '':
            # Check result folder exists
            checkFolders(RESULTS_FOLDER)
            file_name = f'{RESULTS_FOLDER}/solved_jobs_{"".join(self.name_.split())}.aml'

        logger.debug(
            f'Storing jobs history from node {self.name_} in file {file_name}')

        # Open file and write down solution
        with open(f'{file_name}', 'w') as file:

            # Write down pending jobs
            file.write('JOBS ANSWERED\n')
            for job, solution_job in self.jobs_processed_:
                file.write(f'{job} : {solution_job}\n')

    def stop(self):
        """
        Stop this node.

        Stop its internal DDS entities.
        Set variable stop as true.
        Awake threads waiting for stop.
        """
        logger.construct(f'Stopping ComputingNode {self.name_}.')

        # Stop DDS module
        self.dds_computing_node_.stop()

        # Set entity as stopped
        self.cv_stop_.acquire()
        self.stop_ = True
        self.cv_stop_.notify_all()
        self.cv_stop_.release()

    def run(self):
        """TODO comment."""
        logger.debug(f'Run ComputingNode {self.name_}.')

        # Thread to generate jobs randomly
        job_calculator_thread = \
            Thread(target=self._calculate_job_solution_routine)
        job_calculator_thread.start()

        # Wait to stop
        self.cv_stop_.acquire()
        self.cv_stop_.wait_for(predicate=lambda: self.stop_)
        self.cv_stop_.release()

        # Wait for all threads
        job_calculator_thread.join()

    def _calculate_job_solution_routine(self, seed=4321):
        """
        Wait for jobs and calculate their solutions.

        Loops waiting to process a job from a Main Node until this node is
        stopped; timeouts while waiting for a client are retried.

        :param seed: seed for the random generator (unused in this mock)
        """
        while not self.stop_:

            try:
                # Try to answer to a new job
                logger.info(f'{self.name_} waiting to process a job.')
                self.dds_computing_node_.process_job()

            except StopException:
                logger.info(
                    f'{self.name_} stopped while calculating a solution.')
                break

            except TimeoutException:
                logger.info(f'{self.name_} timeout waiting for client.')
                continue

    def _job_process_callback_dds_type(self, job_data):
        """TODO comment."""
        return self._job_process_callback(
            Job.from_dds_data_type(job_data)).to_dds_data_type()

    final_sentences_ = [
        'with me. ', 'alone. ', 'very well. ', 'badly.', 'like a boss.'
    ]

    def _job_process_callback(self, job: Job):
        """TODO comment."""
        # Sleep to simulate long calculation
        sleep_time = random.randint(self.time_range_ms_[0],
                                    self.time_range_ms_[1]) / 1000
        logger.user(
            f'{self.name_} calculating result for job {job.index} (approximately {sleep_time} s).'
        )
        time.sleep(sleep_time)

        # Generate solution
        logger.user(f'{self.name_} finished calculating job {job.index}.')

        solution = ComputingNode.__random_solution_generator(job)

        # Storing processed job
        self.jobs_processed_.append((job, solution))

        return solution

    @staticmethod
    def __random_solution_generator(job: Job) -> 'JobSolution':
        return JobSolution(
            job.index,
            job.data + random.choice(ComputingNode.final_sentences_))
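ComputingNode's stop_/cv_stop_ pair is a plain stop latch. A minimal sketch (hypothetical StopLatch class; threading.Event provides the same one-bit latch, but a Condition generalizes to richer predicates):

import threading

class StopLatch:
    def __init__(self):
        self._stopped = False
        self._cv = threading.Condition()

    def stop(self):
        # Mirrors ComputingNode.stop: set the flag and wake all waiters.
        with self._cv:
            self._stopped = True
            self._cv.notify_all()

    def wait(self, timeout=None):
        # Mirrors the wait in ComputingNode.run.
        with self._cv:
            return self._cv.wait_for(lambda: self._stopped, timeout)

latch = StopLatch()
threading.Timer(0.1, latch.stop).start()
latch.wait()  # returns once stop() has been called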
Example #46
File: main.py  Project: erinix/middleware
class Job(object):
    def __init__(self, context):
        self.context = context
        self.anonymous = False
        self.disabled = False
        self.one_shot = False
        self.logger = None
        self.id = None
        self.label = None
        self.parent = None
        self.provides = set()
        self.requires = set()
        self.state = JobState.UNKNOWN
        self.program = None
        self.program_arguments = []
        self.pid = None
        self.pgid = None
        self.sid = None
        self.plist = None
        self.started_at = None
        self.exited_at = None
        self.keep_alive = False
        self.supports_checkin = False
        self.throttle_interval = 0
        self.exit_timeout = 10
        self.stdout_fd = None
        self.stdout_path = None
        self.stderr_fd = None
        self.stderr_path = None
        self.run_at_load = False
        self.user = None
        self.group = None
        self.umask = None
        self.last_exit_code = None
        self.failure_reason = None
        self.status_message = None
        self.environment = {}
        self.respawns = 0
        self.cv = Condition()

    @property
    def children(self):
        return (j for j in self.context.jobs if j.parent is self)

    def load(self, plist):
        self.state = JobState.STOPPED
        self.id = plist.get('ID', str(uuid.uuid4()))
        self.label = plist.get('Label')
        self.program = plist.get('Program')
        self.requires = set(plist.get('Requires', []))
        self.provides = set(plist.get('Provides', []))
        self.program_arguments = plist.get('ProgramArguments', [])
        self.stdout_path = plist.get('StandardOutPath')
        self.stderr_path = plist.get('StandardErrorPath')
        self.disabled = bool(plist.get('Disabled', False))
        self.run_at_load = bool(plist.get('RunAtLoad', False))
        self.keep_alive = bool(plist.get('KeepAlive', False))
        self.one_shot = bool(plist.get('OneShot', False))
        self.supports_checkin = bool(plist.get('SupportsCheckin', False))
        self.throttle_interval = int(plist.get('ThrottleInterval', 0))
        self.environment = plist.get('EnvironmentVariables', {})
        self.user = plist.get('UserName')
        self.group = plist.get('GroupName')
        self.umask = plist.get('Umask')
        self.logger = logging.getLogger('Job:{0}'.format(self.label))

        if first_or_default(lambda j: j.label == self.label, self.context.jobs.values()):
            raise RpcException(errno.EEXIST, 'Job with label {0} already exists'.format(self.label))

        if not self.program:
            self.program = self.program_arguments[0]

        if self.stdout_path:
            self.stdout_fd = os.open(self.stdout_path, os.O_WRONLY | os.O_APPEND)

        if self.stderr_path:
            self.stderr_fd = os.open(self.stderr_path, os.O_WRONLY | os.O_APPEND)

        if self.run_at_load:
            self.start()

    def load_anonymous(self, parent, pid):
        try:
            proc = bsd.kinfo_getproc(pid)
            command = proc.command
        except (LookupError, ProcessLookupError):
            # Exited too quickly, but add it anyway - it will be removed in the next event
            command = 'unknown'

        with self.cv:
            self.parent = parent
            self.id = str(uuid.uuid4())
            self.pid = pid
            self.label = 'anonymous.{0}@{1}'.format(command, self.pid)
            self.logger = logging.getLogger('Job:{0}'.format(self.label))
            self.anonymous = True
            self.state = JobState.RUNNING
            self.cv.notify_all()

    def unload(self):
        self.logger.info('Unloading job')
        del self.context.jobs[self.id]

    def start(self):
        with self.cv:
            if self.state in (JobState.STARTING, JobState.RUNNING):
                return

            if not self.requires <= self.context.provides:
                return

            self.logger.info('Starting job')

            pid = os.fork()
            if pid == 0:
                os.kill(os.getpid(), signal.SIGSTOP)

                if not self.stdout_fd and not self.stderr_fd:
                    self.stdout_fd = self.stderr_fd = os.open('/var/tmp/{0}.{1}.log'.format(
                        self.label, os.getpid()),
                        os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                        0o600
                    )

                os.dup2(os.open('/dev/null', os.O_RDONLY), sys.stdin.fileno())
                os.dup2(self.stdout_fd, sys.stdout.fileno())
                os.dup2(self.stderr_fd, sys.stderr.fileno())

                if self.group:
                    # set the group first: setgid() would fail after
                    # privileges are dropped with setuid()
                    group = grp.getgrnam(self.group)
                    os.setgid(group.gr_gid)

                if self.user:
                    user = pwd.getpwnam(self.user)
                    os.setuid(user.pw_uid)

                if not self.program_arguments:
                    self.program_arguments = [self.program]

                bsd.closefrom(3)
                os.setsid()
                env = BASE_ENV.copy()
                env.update(self.environment)
                try:
                    os.execvpe(self.program, self.program_arguments, env)
                except:
                    os._exit(254)

            self.logger.debug('Started as PID {0}'.format(pid))
            self.pid = pid
            self.context.track_pid(self.pid)
            self.set_state(JobState.STARTING)

        os.waitpid(self.pid, os.WUNTRACED)
        os.kill(self.pid, signal.SIGCONT)

    def stop(self):
        with self.cv:
            if self.state == JobState.STOPPED:
                return

            self.logger.info('Stopping job')
            self.set_state(JobState.STOPPING)

            if not self.pid:
                self.set_state(JobState.STOPPED)
                return

            try:
                os.kill(self.pid, signal.SIGTERM)
            except ProcessLookupError:
                # Already dead
                self.set_state(JobState.STOPPED)

            if not self.cv.wait_for(lambda: self.state == JobState.STOPPED, self.exit_timeout):
                os.killpg(self.pgid, signal.SIGKILL)

            if not self.cv.wait_for(lambda: self.state == JobState.STOPPED, self.exit_timeout):
                self.logger.error('Unkillable process {0}'.format(self.pid))

    def send_signal(self, signo):
        if not self.pid:
            return

        os.kill(self.pid, signo)

    def checkin(self):
        with self.cv:
            self.logger.info('Service check-in')
            if self.supports_checkin:
                self.set_state(JobState.RUNNING)
                self.context.provide(self.provides)

    def push_status(self, status):
        with self.cv:
            self.status_message = status
            self.cv.notify_all()

        if self.label == 'org.freenas.dispatcher':
            self.context.init_dispatcher()

        self.context.emit_event('serviced.job.status', {
            'ID': self.id,
            'Label': self.label,
            'Reason': self.failure_reason,
            'Anonymous': self.anonymous,
            'Message': self.status_message
        })

    def pid_event(self, ev):
        if ev.fflags & select.KQ_NOTE_EXEC:
            self.pid_exec(ev)

        if ev.fflags & select.KQ_NOTE_EXIT:
            self.pid_exit(ev)

    def pid_exec(self, ev):
        try:
            proc = bsd.kinfo_getproc(self.pid)
            argv = list(proc.argv)
            command = proc.command
        except (LookupError, ProcessLookupError):
            # Exited too quickly, exit info will be caught in another event
            return

        self.logger.debug('Job did exec() into {0}'.format(argv))

        if self.anonymous:
            # Update label for anonymous jobs
            self.label = 'anonymous.{0}@{1}'.format(command, self.pid)
            self.logger = logging.getLogger('Job:{0}'.format(self.label))

        if self.state == JobState.STARTING:
            with self.cv:
                try:
                    self.sid = os.getsid(self.pid)
                    self.pgid = os.getpgid(self.pid)
                except ProcessLookupError:
                    # Exited too quickly after exec()
                    return

                if not self.supports_checkin:
                    self.set_state(JobState.RUNNING)
                    self.context.provide(self.provides)

    def pid_exit(self, ev):
        if not self.parent:
            # We need to reap direct children
            try:
                os.waitpid(self.pid, 0)
            except BaseException as err:
                self.logger.debug('waitpid() error: {0}'.format(err))

        with self.cv:
            self.logger.info('Job has exited with code {0}'.format(ev.data))
            self.pid = None
            self.last_exit_code = ev.data

            if self.state == JobState.STOPPING:
                self.set_state(JobState.STOPPED)
            else:
                if self.one_shot and self.last_exit_code == 0:
                    self.set_state(JobState.ENDED)
                else:
                    self.failure_reason = 'Process died with exit code {0}'.format(self.last_exit_code)
                    self.set_state(JobState.ERROR)

            if self.anonymous:
                del self.context.jobs[self.id]

    def set_state(self, new_state):
        # Must run locked
        if self.state != JobState.RUNNING and new_state == JobState.RUNNING:
            self.context.emit_event('serviced.job.started', {
                'ID': self.id,
                'Label': self.label,
                'Anonymous': self.anonymous
            })

        if self.state != JobState.STOPPED and new_state == JobState.STOPPED:
            self.context.emit_event('serviced.job.stopped', {
                'ID': self.id,
                'Label': self.label,
                'Anonymous': self.anonymous
            })

        if self.state != JobState.ERROR and new_state == JobState.ERROR:
            self.context.emit_event('serviced.job.error', {
                'ID': self.id,
                'Label': self.label,
                'Reason': self.failure_reason,
                'Anonymous': self.anonymous
            })

        self.state = new_state
        self.cv.notify_all()

    def __getstate__(self):
        ret = {
            'ID': self.id,
            'ParentID': self.parent.id if self.parent else None,
            'Label': self.label,
            'Program': self.program,
            'ProgramArguments': self.program_arguments,
            'Provides': list(self.provides),
            'Requires': list(self.requires),
            'RunAtLoad': self.run_at_load,
            'KeepAlive': self.keep_alive,
            'State': self.state.name,
            'LastExitStatus': self.last_exit_code,
            'PID': self.pid
        }

        if self.failure_reason:
            ret['FailureReason'] = self.failure_reason

        if self.stdout_path:
            ret['StandardOutPath'] = self.stdout_path

        if self.stderr_path:
            ret['StandardErrorPath'] = self.stderr_path

        if self.environment:
            ret['EnvironmentVariables'] = self.environment

        return ret
Example #47
class _SendReceiveThread(Thread):
    """
    Internal thread to Stream class that runs the asyncio event loop.
    """
    def __init__(self, url, futures):
        super(_SendReceiveThread, self).__init__()
        self._futures = futures
        self._url = url

        self._event_loop = None
        self._sock = None
        self._recv_queue = None
        self._send_queue = None
        self._context = None

        self._condition = Condition()

    @asyncio.coroutine
    def _receive_message(self):
        """
        internal coroutine that receives messages and puts
        them on the recv_queue
        """
        with self._condition:
            self._condition.wait_for(lambda: self._sock is not None)
        while True:
            msg_bytes = yield from self._sock.recv()
            message_list = validator_pb2.MessageList()
            message_list.ParseFromString(msg_bytes)
            for message in message_list.messages:
                try:
                    self._futures.set_result(
                        message.correlation_id,
                        FutureResult(message_type=message.message_type,
                                     content=message.content))
                except FutureCollectionKeyError:
                    # if we are getting an initial message, not a response
                    self._recv_queue.put_nowait(message)

    @asyncio.coroutine
    def _send_message(self):
        """
        internal coroutine that sends messages from the send_queue
        """
        with self._condition:
            self._condition.wait_for(lambda: self._send_queue is not None and
                                     self._sock is not None)
        while True:
            msg = yield from self._send_queue.get()
            yield from self._sock.send_multipart([msg.SerializeToString()])

    @asyncio.coroutine
    def _put_message(self, message):
        """
        puts a message on the send_queue. Not to be accessed directly.
        :param message: protobuf generated validator_pb2.Message
        """
        with self._condition:
            self._condition.wait_for(lambda: self._send_queue is not None)
        self._send_queue.put_nowait(message)

    @asyncio.coroutine
    def _get_message(self):
        """
        get a message from the recv_queue. Not to be accessed directly.
        """
        with self._condition:
            self._condition.wait_for(lambda: self._recv_queue is not None)
        msg = yield from self._recv_queue.get()

        return msg

    def put_message(self, message):
        """
        :param message: protobuf generated validator_pb2.Message
        """
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        asyncio.run_coroutine_threadsafe(self._put_message(message),
                                         self._event_loop)

    def get_message(self):
        """
        :return message: protobuf generated validator_pb2.Message
        """
        with self._condition:
            self._condition.wait_for(lambda: self._event_loop is not None)
        return asyncio.run_coroutine_threadsafe(self._get_message(),
                                                self._event_loop).result()

    def _exit_tasks(self):
        for task in asyncio.Task.all_tasks(self._event_loop):
            task.cancel()

    def shutdown(self):
        self._exit_tasks()
        self._event_loop.call_soon_threadsafe(self._event_loop.stop)
        self._sock.close()
        self._context.destroy()

    def run(self):
        self._event_loop = zmq.asyncio.ZMQEventLoop()
        asyncio.set_event_loop(self._event_loop)
        self._context = zmq.asyncio.Context()
        self._sock = self._context.socket(zmq.DEALER)
        self._sock.identity = "{}-{}".format(self.__class__.__name__,
                                             os.getpid()).encode('ascii')
        self._sock.connect('tcp://' + self._url)
        self._send_queue = asyncio.Queue()
        self._recv_queue = asyncio.Queue()
        with self._condition:
            self._condition.notify_all()
        asyncio.ensure_future(self._send_message(), loop=self._event_loop)
        asyncio.ensure_future(self._receive_message(), loop=self._event_loop)
        self._event_loop.run_forever()
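
A minimal sketch of the startup handoff used above (hypothetical names, not part of the original Stream API): run() publishes its resources from the background thread and notifies the condition, while callers block in wait_for() until those resources exist.

from threading import Condition, Thread
import time


class Worker(Thread):
    def __init__(self):
        super().__init__(daemon=True)
        self._resource = None
        self._condition = Condition()

    def run(self):
        time.sleep(0.1)                # simulate slow initialization
        with self._condition:
            self._resource = object()  # stands in for the event loop/socket
            self._condition.notify_all()

    def use(self):
        with self._condition:
            self._condition.wait_for(lambda: self._resource is not None)
            return self._resource


w = Worker()
w.start()
w.use()  # blocks until run() has published the resource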
Example #48
class AntColony:
    def __init__(self, graph, lockers, lockers_dict, delivers, delivers_dict,
                 demands, num_ants, num_iterations):
        self.graph = graph
        self.lockers = lockers
        self.lockers_dict = lockers_dict
        self.delivers = delivers
        self.delivers_dict = delivers_dict
        self.demands = list(demands)
        self.ignore_locker_demand()
        self.num_ants = num_ants
        self.num_iterations = num_iterations
        self.Alpha = 0.1

        # condition var
        self.cv = Condition()
        self.create_ants()
        self.reset()

    def reset(self):
        self.iter_count = 0
        self.best_path_cost = float('inf')
        self.best_path_routes = None
        self.best_path_mat = None
        self.last_best_path_iteration = 0

    def start(self):
        self.reset()

        while self.iter_count < self.num_iterations:
            self.iteration()
            # wait until all ants have finished their jobs
            with self.cv:
                self.cv.wait_for(self.end)
                self.avg_path_cost /= len(self.ants)
                logger.info(
                    "=================Iteration {} finish=================".
                    format(self.iter_count))
                logger.info("Best path routes in iteration {} is".format(
                    self.iter_count))
                if self.best_path_routes is not None:
                    for deliver in self.best_path_routes.keys():
                        logger.info("Deliver {} : {}".format(
                            deliver, self.best_path_routes[deliver]))
                    for locker in self.lockers:
                        logger.info("Locker {} scheme: {}".format(
                            locker.id,
                            self.locker_scheme(locker, self.best_path_routes)))
                    logger.info("cost : {}".format(self.best_path_cost))
                    self.global_updating_rule()
                else:
                    logger.info("Failed to find path routes.")

        # kill all ants
        for ant in self.ants:
            ant.kill()

    def end(self):
        return self.finish_ant_count == len(self.ants)

    def ignore_locker_demand(self):
        for locker in self.lockers:
            self.demands[locker.pos] = 0

    def create_ants(self):
        self.ants = []
        for i in range(0, self.num_ants):
            #ant = Ant(i, random.randint(0, self.graph.nodes_num - 1), self)
            ant = Ant(i, self)
            self.ants.append(ant)
            ant.start()

    def iteration(self):
        self.avg_path_cost = 0
        self.finish_ant_count = 0
        self.iter_count += 1
        logger.debug(
            "=================Iteration {} start=================".format(
                self.iter_count))
        for ant in self.ants:
            logger.debug("Ant {} started".format(ant.id))
            ant.begin_colony()

    def update(self, ant):
        with self.cv:
            self.finish_ant_count += 1
            self.avg_path_cost += ant.path_cost

            if ant.path_cost < self.best_path_cost:
                self.best_path_cost = ant.path_cost
                self.best_path_routes = ant.routes
                self.best_path_mat = ant.path_mat
                self.last_best_path_iteration = self.iter_count

            # wake the colony thread blocked in cv.wait_for(self.end)
            self.cv.notify()

    def global_updating_rule(self):
        self.graph.lock.acquire()

        delta = 1.0 / self.best_path_cost

        # for i in range(0, len(self.best_path_mat)):
        #     logger.info(self.best_path_mat[i])

        for r in range(0, self.graph.nodes_num):
            for s in range(0, self.graph.nodes_num):
                if r == s:
                    continue
                delta_rs = 0
                if self.best_path_mat[r][s] == 1:
                    delta_rs = delta
                evaporation = (1 - self.Alpha) * self.graph.tau(r, s)
                deposition = self.Alpha * delta_rs
                self.graph.update_tau(r, s, evaporation + deposition)

        self.graph.lock.release()

    def deliver_locker_by_id(self, deliver_id):
        # Python does not overload methods by signature, so the two lookups
        # need distinct names; the id variant resolves the deliver first.
        deliver = self.delivers_dict[deliver_id]
        return self.deliver_locker(deliver)

    def deliver_locker(self, deliver):
        return self.lockers_dict[deliver.locker_id]

    def locker_scheme(self, locker, path_routes):
        capacity = 0
        for deliver_id in locker.delivers:
            if deliver_id in path_routes.keys():
                path = path_routes[deliver_id]
                for pack in path:
                    capacity += pack.capacity
        return capacity
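
The start()/update() pair above is a condition-variable barrier: the colony thread sleeps in wait_for(self.end) while each Ant calls update() and increments finish_ant_count. A stripped-down sketch of that barrier (hypothetical code, not from the original):

from threading import Condition, Thread

cv = Condition()
finished = 0
NUM_WORKERS = 4

def worker():
    global finished
    with cv:
        finished += 1  # the "ant" reports completion...
        cv.notify()    # ...and wakes the coordinator

for _ in range(NUM_WORKERS):
    Thread(target=worker).start()

with cv:
    cv.wait_for(lambda: finished == NUM_WORKERS)  # the barrier
print('all workers finished')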
Example #49
class _ContextFuture(object):
    """Controls access to bytes set in the _result variable. The booleans
     that are flipped in set_result, based on whether the value is being set
     from the merkle tree or a direct set on the context manager are needed
     to later determine whether the value was set in that context or was
     looked up as a new address location from the merkle tree and then only
     read from, not set.

    In any context the lifecycle of a _ContextFuture can be several paths:

    Input:
    Address not in base:
      F -----> get from merkle database ----> get from the context
    Address in base:
            |---> set (F)
      F --->|
            |---> get
    Output:
      Doesn't exist ----> set address in context (F)

    Input + Output:
    Address not in base:

                             |-> set
      F |-> get from merkle -|
        |                    |-> get
        |                    |
        |                    |-> noop
        |--> set Can happen before the pre-fetch operation


                     |-> set (F) ---> get
                     |
                     |-> set (F) ----> set
                     |
    Address in base: |-> set (F)
      Doesn't exist -|
                     |-> get Future doesn't exist in context
                     |
                     |-> get ----> set (F)

    """

    def __init__(self, address, result=None, wait_for_tree=False):
        self.address = address
        self._result = result
        self._result_set_in_context = False
        self._condition = Condition()
        self._wait_for_tree = wait_for_tree
        self._tree_has_set = False
        self._read_only = False
        self._deleted = False

    def make_read_only(self):
        with self._condition:
            if self._wait_for_tree and not self._result_set_in_context:
                self._condition.wait_for(
                    lambda: self._tree_has_set or self._result_set_in_context)

            self._read_only = True

    def set_in_context(self):
        with self._condition:
            return self._result_set_in_context

    def deleted_in_context(self):
        with self._condition:
            return self._deleted

    def result(self):
        """Return the value at an address, optionally waiting until it is
        set from the context_manager, or set based on the pre-fetch mechanism.

        Returns:
            (bytes): The opaque value for an address.
        """

        if self._read_only:
            return self._result
        with self._condition:
            if self._wait_for_tree and not self._result_set_in_context:
                self._condition.wait_for(
                    lambda: self._tree_has_set or self._result_set_in_context)
            return self._result

    def set_deleted(self):
        self._result_set_in_context = False
        self._deleted = True

    def set_result(self, result, from_tree=False):
        """Set the addresses's value unless the future has been declared
        read only.

        Args:
            result (bytes): The value at an address.
            from_tree (bool): Whether the value is being set by a read from
                the merkle tree.

        Returns:
            None
        """

        if self._read_only:
            if not from_tree:
                LOGGER.warning("Tried to set address %s on a"
                               " read-only context.",
                               self.address)
            return

        with self._condition:
            if self._read_only:
                if not from_tree:
                    LOGGER.warning("Tried to set address %s on a"
                                   " read-only context.",
                                   self.address)
                return
            if from_tree:
                # If the result has not been set in the context, overwrite the
                # value with the value from the merkle tree. Otherwise, do
                # nothing.
                if not self._result_set_in_context:
                    self._result = result
                    self._tree_has_set = True
            else:
                self._result = result
                self._result_set_in_context = True
                self._deleted = False

            self._condition.notify_all()
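
A brief, hypothetical driver for the future above (the address value is illustrative): result() blocks in wait_for until the simulated merkle-tree prefetch publishes a value.

from threading import Thread

fut = _ContextFuture(address='aa' * 35, wait_for_tree=True)

def prefetch():
    # stands in for the merkle-tree read finishing later
    fut.set_result(b'value-from-tree', from_tree=True)

Thread(target=prefetch).start()
print(fut.result())  # blocks until set_result() flips _tree_has_set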
Example #50
class WebSocket(object):
    def __init__(self, client: socket):
        self.__client = client

        self.__input: Deque[Union[str, bytes]] = deque()
        self.__output: Deque[Tuple[Optional[Promise], int, bytes]] = deque()
        self.__ping: Dict[bytes, Promise] = {}

        self.__lock = RLock()
        self.__cv = Condition(self.__lock)

        self.__closed = False
        self.__reason = None
        self.__code = None

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __repr__(self):
        return f'<WebSocket object at {hex(id(self))}>'

    def accept(self,
               timeout: Optional[float] = None,
               validate: Optional[Callable[[...], bool]] = None) -> None:
        response = None
        error = None

        try:
            client = self.__client
            client.setblocking(False)

            timer = Timer(timeout)
            data = b''

            while len(data) < 0x1000:
                left = timer.left
                if left is not None and left < 0:
                    raise RequestTimeout('Operation timed out')
                r, w, x = select([client], [], [], left)
                if self.__client not in r:
                    continue
                data += client.recv(0x1000 - len(data))
                if len(data) > 4 and data[-4:] == b'\r\n\r\n':
                    break

            try:
                method, uri, version, headers = _parse_http(data.decode())
            except Exception as exc:
                raise BadRequest(*exc.args)

            connection = headers.get('Connection', None)
            upgrade = headers.get('Upgrade', None)
            key = headers.get('Sec-WebSocket-Key', None)

            if method != 'GET':
                raise MethodNotAllowed(
                    f'Expected method is GET, but got {method}')
            if version != '1.1':
                raise UpgradeRequired(
                    f'Expected HTTP version is 1.1, but got {version}')
            if connection != 'Upgrade':
                raise UpgradeRequired(
                    f'Expected Connection value is "Upgrade", but got "{connection}"'
                )
            if upgrade != 'websocket':
                raise BadRequest(
                    f'Expected Upgrade value is "websocket", but got "{upgrade}"'
                )
            if key is None:
                raise BadRequest('Sec-WebSocket-Key is missing')
            if validate is not None and not validate(uri, headers):
                raise BadRequest('Validation failed')

            response = (f'HTTP/1.1 101 WebSocket Upgrade\r\n'
                        f'Connection: Upgrade\r\n'
                        f'Upgrade: websocket\r\n'
                        f'Sec-WebSocket-Accept: {_create_accept_key(key)}\r\n'
                        f'\r\n')
        except BadRequest as exc:
            error = exc
            response = 'HTTP/1.1 400 Bad Request\r\n\r\n'
        except MethodNotAllowed as exc:
            error = exc
            response = 'HTTP/1.1 405 Method Not Allowed\r\n\r\n'
        except RequestTimeout as exc:
            error = exc
            response = 'HTTP/1.1 408 Request Timeout\r\n\r\n'
        except UpgradeRequired as exc:
            error = exc
            response = (f'HTTP/1.1 426 Upgrade Required\r\n'
                        f'Connection: Upgrade\r\n'
                        f'Upgrade: websocket\r\n'
                        f'\r\n')
        except Exception as exc:
            error = exc

        if response:
            try:
                self.__client.sendall(response.encode())
            except Exception as exc:
                if not error:
                    error = exc

        if error:
            self.__close()
            raise error
        else:
            Thread(target=self.__run).start()

    def __send_all(self, data: Union[bytes, bytearray]) -> None:
        self.__client.sendall(data)

    def __recv_all(self, n: int) -> bytes:
        data = b''
        while len(data) < n:
            data += self.__client.recv(n - len(data))
        return data

    def __send_packet(self, opcode: int, data: Union[bytes,
                                                     bytearray]) -> None:
        self.__client.setblocking(True)

        header = bytearray()
        header.append(0x80 | opcode)

        length = len(data)
        if length < 126:
            header.append(length)
        elif length < 0x10000:
            header.append(126)
            header.extend(length.to_bytes(2, 'big'))
        else:
            header.append(127)
            header.extend(length.to_bytes(8, 'big'))

        self.__send_all(header)
        self.__send_all(data)
        logger.debug(
            f'{self} sent {len(data)} bytes, '
            f'FIN=1, '
            f'OPCODE=0x{opcode:02x} ({OPCODE_NAME.get(opcode, "unknown")})')

    def __recv_packet(self) -> Tuple[int, int, bytes]:
        self.__client.setblocking(True)

        octet, = self.__recv_all(1)
        fin = octet >> 7
        opcode = octet & 0x0F

        octet, = self.__recv_all(1)
        mask = octet >> 7
        length = octet & 0x7F

        if length == 126:
            length = int.from_bytes(self.__recv_all(2), 'big')
        elif length == 127:
            length = int.from_bytes(self.__recv_all(8), 'big')

        if mask:
            key = self.__recv_all(4)
        else:
            key = None

        data = self.__recv_all(length)
        if key:
            data = bytes(data[i] ^ key[i % 4] for i in range(len(data)))

        logger.debug(
            f'{self} received {len(data)} bytes, '
            f'FIN={fin}, '
            f'OPCODE=0x{opcode:02x} ({OPCODE_NAME.get(opcode, "unknown")})')
        return fin, opcode, data

    def __send_message(self, opcode: int, data: Union[bytes,
                                                      bytearray]) -> None:
        self.__send_packet(opcode, data)

    def __recv_message(self) -> Tuple[int, bytes]:
        fin, opcode, data = self.__recv_packet()

        result_opcode = opcode
        result_data = data
        while not fin:
            fin, opcode, data = self.__recv_packet()
            assert opcode == 0
            result_data += data
        return result_opcode, result_data

    def __run(self):
        while True:
            with self.__lock:
                if self.__closed:
                    break

            r, w, x = select([self.__client], [self.__client], [], 0.05)

            with self.__lock:
                if len(self.__output) and self.__client in w:
                    try:
                        while len(self.__output):
                            promise, opcode, data = self.__output.popleft()
                            if not promise or not promise.cancelled:
                                self.__send_message(opcode, data)
                            if promise:
                                promise.set_result(None)
                    except Exception as exc:
                        logger.error(f'{self} send error', exc_info=exc)
                        self.__close()
                        continue

                if self.__client in r:
                    try:
                        opcode, data = self.__recv_message()
                    except Exception as exc:
                        logger.error(f'{self} recv error', exc_info=exc)
                        self.__close(REASON_PROTOCOL_ERROR)
                        continue

                    if opcode == OPCODE_PING:
                        self.__output.append((None, OPCODE_PONG, data))
                    elif opcode == OPCODE_PONG:
                        promise = self.__ping.pop(data, None)
                        if promise:
                            promise.set_result(None)
                    elif opcode == OPCODE_CLOSE:
                        self.__code = int.from_bytes(data[:2], 'big')
                        # the first two bytes carry the close code, the rest the reason
                        self.__reason = data[2:].decode(errors='ignore')
                        self.__close()
                    elif opcode == OPCODE_BINARY:
                        self.__input.append(data)
                        self.__cv.notify_all()
                    elif opcode == OPCODE_TEXT:
                        self.__input.append(data.decode(errors='ignore'))
                        self.__cv.notify_all()
                    else:
                        self.__close(REASON_NOT_SUPPORTED)

    def __close(self, reason: Optional[int] = None) -> None:
        with self.__lock:
            if self.__closed:
                return

            if reason:
                message = REASON_MESSAGE.get(reason, '')
                self.__code = reason
                self.__reason = message
                try:
                    self.__send_message(
                        OPCODE_CLOSE,
                        reason.to_bytes(2, 'big') + message.encode())
                except Exception as exc:
                    logger.warning(f'{self} close error', exc_info=exc)

            try:
                self.__client.shutdown(SHUT_WR)
            except Exception as exc:
                logger.warning(f'{self} shutdown error', exc_info=exc)

            self.__client.close()
            self.__closed = True

            self.__cv.notify_all()

    def send(self,
             data: Union[str, bytes, bytearray],
             timeout: Optional[float] = None) -> None:
        with self.__lock:
            if self.__closed:
                raise IOError('WebSocket is closed')

            promise = Promise(self.__lock)
            if isinstance(data, str):
                opcode = OPCODE_TEXT
                data = data.encode()
            else:
                opcode = OPCODE_BINARY

            self.__output.append((promise, opcode, data))
            try:
                promise.get(timeout)
            except TimeoutError:
                promise.cancel()

    def recv(self, timeout: Optional[float] = None) -> Union[str, bytes]:
        with self.__lock:
            if not self.__cv.wait_for(
                    lambda: self.__closed or len(self.__input), timeout):
                raise TimeoutError('Operation timed out')
            if self.__input:
                return self.__input.popleft()
            raise IOError('WebSocket is closed')

    def ping(self, timeout: float) -> bool:
        from random import randint

        with self.__lock:
            if self.__closed:
                raise IOError('WebSocket is closed')

            promise = Promise(self.__lock)
            data = bytes(randint(0, 255) for i in range(4))
            self.__ping[data] = promise
            self.__output.append((None, OPCODE_PING, data))

            try:
                promise.get(timeout)
            except Exception as exc:
                logger.debug(f'{self} ping error', exc_info=exc)
                del self.__ping[data]
                return False
            else:
                return True

    def close(self) -> None:
        self.__close(REASON_NORMAL)

    @property
    def closed(self):
        return self.__closed

    @property
    def code(self):
        return self.__code

    @property
    def reason(self):
        return self.__reason
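
A hedged usage sketch for the class above (host and port are placeholders): accept one TCP connection, perform the upgrade handshake, then echo messages until the peer closes.

from socket import socket, AF_INET, SOCK_STREAM

server = socket(AF_INET, SOCK_STREAM)
server.bind(('127.0.0.1', 8765))  # assumed address
server.listen(1)

client, _ = server.accept()
ws = WebSocket(client)
ws.accept(timeout=5.0)  # raises on a failed handshake
while not ws.closed:
    try:
        ws.send(ws.recv(timeout=10.0))  # echo
    except (TimeoutError, IOError):
        break
ws.close()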
Example #51
class EPuckStreamer(Thread):
    '''
    This class creates a TCP server that accepts incoming connections. Each
    TCP connection streams information about the state of the robot's
    sensors and actuators asynchronously.
    '''


    class Client(Thread):
        def __init__(self, streamer, socket):
            super().__init__()

            self.streamer = streamer
            self.socket = socket
            self._active_lock = Lock()
            self._alive = True

            self.start()

        @property
        def alive(self):
            with self._active_lock:
                return self._alive

        @alive.setter
        def alive(self, state):
            with self._active_lock:
                self._alive = state


        def _send_data(self):
            data = self.streamer.get_data(consumer = self)

            total_sent = 0
            while total_sent < len(data):
                sent = self.socket.send(data[total_sent:])
                if sent == 0:
                    raise IOError()
                total_sent += sent

        def _run(self):
            while self.alive:
                self._send_data()

        def run(self):
            try:
                self._run()
            except Exception:
                pass
            finally:
                self.alive = False
                self.socket.close()

        def close(self):
            self.alive = False
            self.join()



    def __init__(self, controller, address = 'localhost', port = 19998):
        super().__init__()
        self.controller = controller
        self.epuck = self.controller.epuck
        self.address = address
        self.port = port

        self.server_socket = socket(AF_INET, SOCK_STREAM)

        self._alive_lock = Lock()
        self._alive = True

        self._data_lock = Condition()
        self._data = None

        self._banned_consumers = []
        self.start()

    @property
    def alive(self):
        with self._alive_lock:
            return self._alive

    @alive.setter
    def alive(self, state):
        with self._alive_lock:
            self._alive = state


    def run(self):
        try:
            self.server_socket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            self.server_socket.bind((self.address, self.port))
            self.server_socket.listen(1)

            clients = []
            while self.alive:
                readable, writable, errored = select([self.server_socket], [], [], 0)
                if self.server_socket in readable:
                    client_socket, address = self.server_socket.accept()
                    client = self.Client(self, client_socket)
                    clients.append(client)

                sleep(.05)
            for client in clients:
                client.close()
        except Exception:
            pass
        finally:
            self.server_socket.close()

    def close(self):
        self.alive = False
        self.join()



    def get_data(self, consumer):
        with self._data_lock:
            self._data_lock.wait_for(
                lambda: self._data is not None and consumer not in self._banned_consumers)
            self._banned_consumers.append(consumer)
            return self._data

    def broadcast(self):
        def get_sensor_data(sensor):
            return sensor.value if sensor.enabled else False

        def get_sensors_data(sensors):
            return [get_sensor_data(sensor) for sensor in sensors]

        def get_vision_sensor_data():
            if not self.epuck.vision_sensor.enabled:
                return False
            output = BytesIO()
            with output:
                image = self.epuck.vision_sensor.value.transpose(Image.FLIP_TOP_BOTTOM)
                image.save(output, format = 'jpeg')
                data = b64encode(output.getvalue()).decode()
                return data

        data = {
            # Sensor information
            'prox_sensors' : get_sensors_data(self.epuck.prox_sensors),
            'floor_sensors' : get_sensors_data(self.epuck.floor_sensors),
            'vision_sensor' : get_vision_sensor_data(),
            'vision_sensor_params': self.epuck.vision_sensor.params,
            'light_sensor' : get_sensor_data(self.epuck.light_sensor),

            # Actuator information
            'leds' : self.epuck.leds.states,
            'motors' : self.epuck.motors.speeds,

            # Controller information
            'elapsed_time' : self.controller.elapsed_time,
            'think_time' : self.controller.think_time,
            'update_time' : self.controller.update_time,
            'steps_per_second' : self.controller.steps_per_second
        }

        data = json.dumps(data).encode()
        data = zlib.compress(struct.pack('!{}s'.format(len(data)), data))

        chunk_size = 1 << 11
        header_size = 16 + 4

        data = struct.pack('!i', len(data)) + data + bytearray(chunk_size - (len(data) + header_size) % chunk_size)

        hasher = hashlib.md5()
        hasher.update(data)
        md5sum = hasher.digest()
        data = struct.pack('!16s', md5sum) + data

        with self._data_lock:
            self._data = data
            self._banned_consumers.clear()
            self._data_lock.notify_all()
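
A hypothetical receiver-side decoder for the frame that broadcast() builds, under the assumptions visible above: a 16-byte MD5 digest, a 4-byte big-endian length, zlib-compressed JSON, and zero padding up to a 2 KiB chunk boundary.

import hashlib
import json
import struct
import zlib

def decode_frame(frame: bytes) -> dict:
    md5sum, payload = frame[:16], frame[16:]
    if hashlib.md5(payload).digest() != md5sum:
        raise ValueError('checksum mismatch')
    length = struct.unpack('!i', payload[:4])[0]
    compressed = payload[4:4 + length]  # trailing zero padding is ignored
    return json.loads(zlib.decompress(compressed).decode())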
Example #52
class PositionController:
    """
    Controls the position of the robot using motion profile based movement.
    """
    train: wpilib.drive.DifferentialDrive

    def __init__(self):
        self.trajectories = deque()

        self.right_follower = None
        self.left_follower = None

        self.lock = Lock()
        self.cond = Condition(self.lock)
        self.thread = Thread(target=self._run, daemon=True)
        self.thread.start()

        wpilib.Resource._add_global_resource(self)

    def move_to(self, x_position, y_position, angle=0, first=False):
        """
        Generate path and set path variable

        :param x_position: The x distance.
        :param y_position: The y distance.
        :param angle: The angle difference in degrees.
        :param first: Whether or not this path should be completed next.
        """
        waypoint = pf.Waypoint(float(x_position), float(y_position),
                               radians(angle))

        info, trajectory = pf.generate([pf.Waypoint(0, 0, 0), waypoint],
                                       pf.FIT_HERMITE_CUBIC, pf.SAMPLES_HIGH,
                                       0.05, 1.7, 2.0, 60.0)

        modifier = pf.modifiers.TankModifier(trajectory).modify(0.5)
        right_trajectory = modifier.getRightTrajectory()
        left_trajectory = modifier.getLeftTrajectory()

        with self.cond:
            if self.left_follower is None or self.right_follower is None:
                self.right_follower = pf.followers.EncoderFollower(
                    right_trajectory)
                self.left_follower = pf.followers.EncoderFollower(
                    left_trajectory)

            if first:
                self.trajectories.appendleft({
                    'right': right_trajectory,
                    'left': left_trajectory
                })
            else:
                self.trajectories.append({
                    'right': right_trajectory,
                    'left': left_trajectory
                })
            self.cond.notify()

    def _run(self):
        """
        Actually move the robot along the path.
        """
        while True:
            with self.cond:
                if len(self.trajectories) < 1:
                    self.cond.wait_for(lambda: len(self.trajectories) > 0)

                if self.right_follower.isFinished(
                ) and self.left_follower.isFinished():
                    trajectory = self.trajectories.popleft()
                    self.right_follower.setTrajectory(trajectory['right'])
                    self.left_follower.setTrajectory(trajectory['left'])
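
A hypothetical caller for the controller above (assumes the RobotPy/pathfinder environment the class targets; distances are illustrative):

controller = PositionController()
controller.move_to(2.0, 0.0)              # queue a straight 2 m move
controller.move_to(1.0, 1.0, angle=90)    # then a curve ending at 90 degrees
controller.move_to(0.5, 0.0, first=True)  # an urgent move jumps the queue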
Example #53
class TaskExecutor(object):
    def __init__(self, balancer, index):
        self.balancer = balancer
        self.index = index
        self.task = None
        self.proc = None
        self.pid = None
        self.conn = None
        self.state = WorkerState.STARTING
        self.key = str(uuid.uuid4())
        self.result = AsyncResult()
        self.exiting = False
        self.killed = False
        self.thread = gevent.spawn(self.executor)
        self.cv = Condition()
        self.status_lock = RLock()

    def checkin(self, conn):
        with self.cv:
            self.balancer.logger.debug("Check-in of worker #{0} (key {1})".format(self.index, self.key))
            self.conn = conn
            self.state = WorkerState.IDLE
            self.cv.notify_all()

    def put_progress(self, progress):
        st = TaskStatus(None)
        st.__setstate__(progress)
        self.task.set_state(progress=st)

    def put_status(self, status):
        with self.cv:
            # Try to collect rusage at this point, when process is still alive
            try:
                kinfo = self.balancer.dispatcher.threaded(bsd.kinfo_getproc, self.pid)
                self.task.rusage = kinfo.rusage
            except LookupError:
                pass

            if status["status"] == "ROLLBACK":
                self.task.set_state(TaskState.ROLLBACK)

            if status["status"] == "FINISHED":
                self.result.set(status["result"])

            if status["status"] == "FAILED":
                error = status["error"]

                if error["type"] in ERROR_TYPES:
                    cls = ERROR_TYPES[error["type"]]
                    exc = cls(
                        code=error["code"],
                        message=error["message"],
                        stacktrace=error["stacktrace"],
                        extra=error.get("extra"),
                    )
                else:
                    exc = OtherException(
                        code=error["code"],
                        message=error["message"],
                        stacktrace=error["stacktrace"],
                        type=error["type"],
                        extra=error.get("extra"),
                    )

                self.result.set_exception(exc)

    def put_warning(self, warning):
        self.task.add_warning(warning)

    def run(self, task):
        def match_file(module, f):
            name, ext = os.path.splitext(f)
            return module == name and ext in [".py", ".pyc", ".so"]

        with self.cv:
            self.cv.wait_for(lambda: self.state == WorkerState.ASSIGNED)
            self.result = AsyncResult()
            self.task = task
            self.task.set_state(TaskState.EXECUTING)
            self.state = WorkerState.EXECUTING
            self.cv.notify_all()

        self.balancer.logger.debug("Actually starting task {0}".format(task.id))

        filename = None
        module_name = inspect.getmodule(task.clazz).__name__
        for dir in self.balancer.dispatcher.plugin_dirs:
            found = False
            try:
                for root, _, files in os.walk(dir):
                    file = first_or_default(lambda f: match_file(module_name, f), files)
                    if file:
                        filename = os.path.join(root, file)
                        found = True
                        break

                if found:
                    break
            except OSError:
                continue

        try:
            self.conn.call_sync(
                "taskproxy.run",
                {
                    "id": task.id,
                    "user": task.user,
                    "class": task.clazz.__name__,
                    "filename": filename,
                    "args": task.args,
                    "debugger": task.debugger,
                    "environment": task.environment,
                    "hooks": task.hooks,
                },
            )
        except RpcException as e:
            self.balancer.logger.warning(
                "Cannot start task {0} on executor #{1}: {2}".format(task.id, self.index, str(e))
            )

            self.balancer.logger.warning(
                "Killing unresponsive task executor #{0} (pid {1})".format(self.index, self.proc.pid)
            )

            self.terminate()

        try:
            self.result.get()
        except BaseException as e:
            if isinstance(e, OtherException):
                self.balancer.dispatcher.report_error("Task {0} raised invalid exception".format(self.task.name), e)

            if isinstance(e, TaskAbortException):
                self.task.set_state(TaskState.ABORTED, TaskStatus(0, "aborted"))
            else:
                self.task.error = serialize_error(e)
                self.task.set_state(
                    TaskState.FAILED, TaskStatus(0, str(e), extra={"stacktrace": traceback.format_exc()})
                )

            with self.cv:
                self.task.ended.set()

                if self.state == WorkerState.EXECUTING:
                    self.state = WorkerState.IDLE
                    self.cv.notify_all()

            self.balancer.task_exited(self.task)
            return

        with self.cv:
            self.task.result = self.result.value
            self.task.set_state(TaskState.FINISHED, TaskStatus(100, ""))
            self.task.ended.set()
            if self.state == WorkerState.EXECUTING:
                self.state = WorkerState.IDLE
                self.cv.notify_all()

        self.balancer.task_exited(self.task)

    def abort(self):
        self.balancer.logger.info("Trying to abort task #{0}".format(self.task.id))
        # Try to abort via RPC. If this fails, kill process
        try:
            # If task supports abort protocol we don't need to worry about subtasks - it's task
            # responsibility to kill them
            self.conn.call_sync("taskproxy.abort")
        except RpcException as err:
            self.balancer.logger.warning("Failed to abort task #{0} gracefully: {1}".format(self.task.id, str(err)))
            self.balancer.logger.warning("Killing process {0}".format(self.pid))
            self.killed = True
            self.terminate()

            # Now kill all the subtasks
            for subtask in filter(lambda t: t.parent is self.task, self.balancer.task_list):
                self.balancer.logger.warning(
                    "Aborting subtask {0} because parent task {1} died".format(subtask.id, self.task.id)
                )
                self.balancer.abort(subtask.id)

    def terminate(self):
        try:
            self.proc.terminate()
        except OSError:
            self.balancer.logger.warning("Executor process with PID {0} already dead".format(self.proc.pid))

    def executor(self):
        while not self.exiting:
            try:
                self.proc = Popen(
                    [TASKWORKER_PATH, self.key],
                    close_fds=True,
                    preexec_fn=os.setpgrp,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                )

                self.pid = self.proc.pid
                self.balancer.logger.debug("Started executor #{0} as PID {1}".format(self.index, self.pid))
            except OSError:
                self.result.set_exception(TaskException(errno.EFAULT, "Cannot spawn task executor"))
                self.balancer.logger.error("Cannot spawn task executor #{0}".format(self.index))
                return

            for line in self.proc.stdout:
                line = line.decode("utf8")
                self.balancer.logger.debug("Executor #{0}: {1}".format(self.index, line.strip()))
                if self.task:
                    self.task.output += line

            self.proc.wait()

            with self.cv:
                self.state = WorkerState.STARTING
                self.cv.notify_all()

            if self.proc.returncode == -signal.SIGTERM:
                self.balancer.logger.info(
                    "Executor process with PID {0} was terminated gracefully".format(self.proc.pid)
                )
            else:
                self.balancer.logger.error(
                    "Executor process with PID {0} died abruptly with exit code {1}".format(
                        self.proc.pid, self.proc.returncode
                    )
                )

            if self.killed:
                self.result.set_exception(TaskException(errno.EFAULT, "Task killed"))
            else:
                self.result.set_exception(TaskException(errno.EFAULT, "Task executor died"))
            gevent.sleep(1)

    def die(self):
        self.exiting = True
        if self.proc:
            self.terminate()
Example #54
class ThreadedTcpClient:
    """The main difference with the TcpClient class is that this one
       will spawn a secondary thread that will be constantly reading
       from the network and putting everything on another buffer.
    """
    def __init__(self, proxy=None):
        self.connected = False
        self._proxy = proxy
        self._recreate_socket()

        # Support for multi-threading advantages and safety
        self.cancelled = Event()  # Has the read operation been cancelled?
        self.delay = 0.1  # Read delay when there was no data available
        self._lock = Lock()

        self._buffer = []
        self._read_thread = Thread(target=self._reading_thread, daemon=True)
        self._cv = Condition()  # Condition Variable

    def _recreate_socket(self):
        if self._proxy is None:
            self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        else:
            import socks
            self._socket = socks.socksocket(socket.AF_INET, socket.SOCK_STREAM)
            if type(self._proxy) is dict:
                self._socket.set_proxy(**self._proxy)
            else:  # tuple, list, etc.
                self._socket.set_proxy(*self._proxy)

    def connect(self, ip, port, timeout):
        """Connects to the specified IP and port number.
           'timeout' must be given in seconds
        """
        if not self.connected:
            self._socket.settimeout(timeout)
            self._socket.connect((ip, port))
            self._socket.setblocking(False)
            self.connected = True

    def close(self):
        """Closes the connection"""
        if self.connected:
            self._socket.shutdown(socket.SHUT_RDWR)
            self._socket.close()
            self.connected = False
            self._recreate_socket()

    def write(self, data):
        """Writes (sends) the specified bytes to the connected peer"""
        self._socket.sendall(data)

    def read(self, size, timeout=timedelta(seconds=5)):
        """Reads (receives) a whole block of 'size bytes
           from the connected peer.

           A timeout can be specified, which will cancel the operation if
           no data has been read in the specified time. If data was read
           and it's waiting for more, the timeout will NOT cancel the
           operation. Set to None for no timeout
        """
        with self._cv:
            print('wait for...')
            self._cv.wait_for(lambda: len(self._buffer) >= size,
                              timeout=timeout.total_seconds() if timeout else None)
            print('got', size)
            result, self._buffer = self._buffer[:size], self._buffer[size:]
            return result

    def _reading_thread(self):
        while True:
            partial = self._socket.recv(4096)
            if len(partial) == 0:
                self.connected = False
                raise ConnectionResetError(
                    'The server has closed the connection.')

            with self._cv:
                print('extended', len(partial))
                self._buffer.extend(partial)
                self._cv.notify()

    def cancel_read(self):
        """Cancels the read operation IF it hasn't yet
           started, raising a ReadCancelledError"""
        self.cancelled.set()
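
A hedged usage sketch for the client above (host and port are placeholders; note that the snippet as captured never starts _read_thread, so the caller does it explicitly):

client = ThreadedTcpClient()
client.connect('127.0.0.1', 4444, timeout=5)
client._read_thread.start()   # the class itself never starts the reader
client.write(b'ping')
data = bytes(client.read(4))  # blocks until four bytes have arrived
client.close()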