コード例 #1
0
ファイル: iterators.py プロジェクト: 4Catalyzer/nolearn_utils
def make_buffer_for_iterator_with_thread(gen, n_workers, buffer_size):
    """Prefetch items from ``gen`` on background threads and yield them.

    Worker threads pull items from ``gen`` into an internal queue that holds
    at most ``buffer_size`` items; this generator yields them in the order in
    which they were produced.

    :param gen: iterator to consume. Plain Python generators are not
        thread-safe, so every ``next(gen)`` call is serialized by a lock.
    :param n_workers: number of worker threads to start (at least one).
    :param buffer_size: maximum number of prefetched items (at least one).
    :raises Exception: an exception raised by ``gen`` (other than
        ``StopIteration``) is re-raised here, after the items produced before
        it have been yielded.
    """
    wait_time = 0.02
    buffer_size = max(1, buffer_size)
    generator_queue = Queue()
    _stop = threading.Event()
    # Guards next(gen) *and* the matching put(). Concurrent next() on a
    # generator raises "generator already executing". Doing the put under the
    # same lock guarantees every fetched item is queued before _stop can be
    # set, so the consumer never exits with an item still in flight.
    gen_lock = threading.Lock()
    errors = []

    def generator_task():
        while not _stop.is_set():
            with gen_lock:
                has_room = (not _stop.is_set()
                            and generator_queue.qsize() < buffer_size)
                if has_room:
                    try:
                        generator_queue.put(next(gen))
                    except (StopIteration, KeyboardInterrupt):
                        _stop.set()
                        return
                    except Exception as exc:
                        # Hand the error to the consumer instead of letting
                        # this thread die silently.
                        errors.append(exc)
                        _stop.set()
                        return
            if not has_room:
                time.sleep(wait_time)

    generator_threads = [threading.Thread(target=generator_task)
                         for _ in range(max(1, n_workers))]
    for thread in generator_threads:
        thread.start()

    try:
        while not _stop.is_set() or not generator_queue.empty():
            if not generator_queue.empty():
                yield generator_queue.get()
            else:
                time.sleep(wait_time)
        if errors:
            raise errors[0]
    finally:
        # Also reached when the caller stops iterating early (close()/GC):
        # tell the workers to exit instead of polling forever.
        _stop.set()
コード例 #2
0
def generator_to_async_generator(get_iterable):
    """
    Turn a generator or iterable into an async generator.

    This works by running the generator in a background thread.
    The new async generator yields `Future` objects (to wait for new items)
    and `AsyncGeneratorItem` wrappers around the original items. Every item,
    including the last ones produced, is wrapped.

    If the background iteration raises, the exception is re-raised from this
    async generator after the items produced before it have been yielded.

    :param get_iterable: Function that returns a generator or iterable when
        called.
    """
    q = Queue()
    f = Future()
    l = RLock()
    quitting = False
    # State written by the background thread. These are lists because a
    # nested function cannot rebind outer names without ``nonlocal``.
    finished = [False]
    errors = []

    def runner():
        """
        Consume the generator in background thread.
        When items are received, they'll be pushed to the queue and the
        Future is set. The Future is always resolved when this function ends,
        normally or with an exception, so the consumer never waits forever.
        """
        try:
            for item in get_iterable():
                with l:
                    q.put(item)
                    if not f.done():
                        f.set_result(None)

                # When this async generator was cancelled (closed), stop this
                # thread.
                if quitting:
                    break
        except BaseException as e:
            errors.append(e)
        finally:
            with l:
                finished[0] = True
                if not f.done():
                    f.set_result(None)

    # Start background thread.
    run_in_executor(runner, _daemon=True)

    try:
        while True:
            # Wait for next item(s): yield Future.
            yield From(f)

            # Items received. Yield all items so far. Completion is checked
            # under the same lock as the drain: once `finished` is seen, every
            # item has already been queued, so nothing is left behind and no
            # Future is created that would never be resolved.
            with l:
                while not q.empty():
                    yield AsyncGeneratorItem(q.get())

                if finished[0]:
                    break

                f = Future()

        if errors:
            raise errors[0]

    finally:
        # When this async generator is closed (GeneratorExit exception), stop
        # the background thread as well - we don't need it anymore.
        quitting = True
コード例 #3
0
def generator_to_async_generator(get_iterable):
    """
    Turn a generator or iterable into an async generator.

    This works by running the generator in a background thread.
    The new async generator yields `Future` objects (to wait for new items)
    and `AsyncGeneratorItem` wrappers around the original items. Every item,
    including the last ones produced, is wrapped.

    If the background iteration raises, the exception is re-raised from this
    async generator after the items produced before it have been yielded.

    :param get_iterable: Function that returns a generator or iterable when
        called.
    """
    q = Queue()
    f = Future()
    l = RLock()
    quitting = False
    # State written by the background thread. These are lists because a
    # nested function cannot rebind outer names without ``nonlocal``.
    finished = [False]
    errors = []

    def runner():
        """
        Consume the generator in background thread.
        When items are received, they'll be pushed to the queue and the
        Future is set. The Future is always resolved when this function ends,
        normally or with an exception, so the consumer never waits forever.
        """
        try:
            for item in get_iterable():
                with l:
                    q.put(item)
                    if not f.done():
                        f.set_result(None)

                # When this async generator was cancelled (closed), stop this
                # thread.
                if quitting:
                    break
        except BaseException as e:
            errors.append(e)
        finally:
            with l:
                finished[0] = True
                if not f.done():
                    f.set_result(None)

    # Start background thread.
    run_in_executor(runner, _daemon=True)

    try:
        while True:
            # Wait for next item(s): yield Future.
            yield From(f)

            # Items received. Yield all items so far. Completion is checked
            # under the same lock as the drain: once `finished` is seen, every
            # item has already been queued, so nothing is left behind and no
            # Future is created that would never be resolved.
            with l:
                while not q.empty():
                    yield AsyncGeneratorItem(q.get())

                if finished[0]:
                    break

                f = Future()

        if errors:
            raise errors[0]

    finally:
        # When this async generator is closed (GeneratorExit exception), stop
        # the background thread as well - we don't need it anymore.
        quitting = True
コード例 #4
0
    def iter_entries(handle):
        """Yield every entry of the disc image in ``handle``, breadth-first.

        Directories are yielded as well as descended into; the "." and ".."
        records are skipped.

        :param handle: a file object (rewound to offset 0 when it is
            seekable), or anything else ``PyCdlib.open`` accepts (presumably
            a path - confirm with callers).
        """
        iso = pycdlib.PyCdlib()

        can_rewind = hasattr(handle, 'seek') and handle.seekable()
        if can_rewind:
            handle.seek(0)
            iso.open_fp(handle)
        else:
            iso.open(handle)
        # NOTE(review): ``iso`` is never closed here.

        # Use Joliet names only when there are no Rock Ridge extensions.
        use_joliet = iso.joliet_vd is not None and iso.rock_ridge is None

        pending = Queue()
        pending.put(iso.get_entry('/', use_joliet))

        while not pending.empty():
            for entry in pending.get().children:
                if entry.is_dot() or entry.is_dotdot():
                    continue
                if entry.is_dir():
                    pending.put(entry)
                yield entry
コード例 #5
0
ファイル: handler.py プロジェクト: swift-nav/libsbp
    class _SBPQueueIterator(six.Iterator):
        """
        Upstream iterator fed by a handler callback.

        Calling the instance, as in ``it(msg, **metadata)``, enqueues a
        ``(msg, metadata)`` pair without blocking. Iterating it dequeues those
        pairs, blocking until one is available. After ``breakiter()`` the
        iterator stops once the queue has been drained.
        """

        def __init__(self, maxsize):
            self._broken = False
            self._queue = Queue(maxsize)

        def __iter__(self):
            return self

        def __call__(self, msg, **metadata):
            # Non-blocking put; presumably raises queue.Full when the bounded
            # queue has no room.
            entry = (msg, metadata)
            self._queue.put(entry, block=False)

        def breakiter(self):
            """Mark the iterator as finished and wake a blocked consumer."""
            self._broken = True
            # None is the wake-up sentinel; wait at most one second for room.
            self._queue.put(None, block=True, timeout=1.0)

        def __next__(self):
            pending = self._queue
            # Once broken, an empty queue means nothing is left to deliver.
            if self._broken and pending.empty():
                raise StopIteration
            entry = pending.get(block=True)
            if entry is None and self._broken:
                raise StopIteration
            return entry
コード例 #6
0
    class _SBPQueueIterator(six.Iterator):
        """
        Upstream iterator fed by a handler callback.

        Calling the instance, as in ``it(msg, **metadata)``, enqueues a
        ``(msg, metadata)`` pair without blocking. Iterating it dequeues those
        pairs, blocking until one is available. After ``breakiter()`` the
        iterator stops once the queue has been drained.
        """

        def __init__(self, maxsize):
            self._broken = False
            self._queue = Queue(maxsize)

        def __iter__(self):
            return self

        def __call__(self, msg, **metadata):
            # Non-blocking put; presumably raises queue.Full when the bounded
            # queue has no room.
            entry = (msg, metadata)
            self._queue.put(entry, block=False)

        def breakiter(self):
            """Mark the iterator as finished and wake a blocked consumer."""
            self._broken = True
            # None is the wake-up sentinel; wait at most one second for room.
            self._queue.put(None, block=True, timeout=1.0)

        def __next__(self):
            pending = self._queue
            # Once broken, an empty queue means nothing is left to deliver.
            if self._broken and pending.empty():
                raise StopIteration
            entry = pending.get(block=True)
            if entry is None and self._broken:
                raise StopIteration
            return entry
コード例 #7
0
class SimpleThreadPool:
    """Fixed-size pool of ``WorkerThread`` workers fed from a bounded task queue.

    Workers are started lazily by the first ``add_task`` call. Call
    ``wait_completion`` before ``get_result``.
    """

    def __init__(self, num_threads=5):
        self._num_threads = num_threads
        # Bounded: add_task blocks once 2000 tasks are pending.
        self._queue = Queue(2000)
        self._lock = Lock()
        self._active = False
        self._workers = list()
        self._finished = False

    def add_task(self, func, *args, **kwargs):
        """Queue ``(func, args, kwargs)`` for the workers, starting them on first use."""
        # Double-checked so the lock is only taken until the pool is active.
        if not self._active:
            with self._lock:
                if not self._active:
                    self._active = True
                    for _ in range(self._num_threads):
                        worker = WorkerThread(self._queue)
                        self._workers.append(worker)
                        worker.start()

        # New work invalidates any earlier "all tasks done" state.
        self._finished = False
        self._queue.put((func, args, kwargs))

    def release(self):
        """Block, polling once per second, until the task queue is empty."""
        while not self._queue.empty():
            time.sleep(1)

    def wait_completion(self):
        """Block on ``Queue.join()`` and then allow ``get_result``.

        Assumes ``WorkerThread`` calls ``task_done()`` for every task it
        takes; otherwise this never returns.
        """
        self._queue.join()
        self._finished = True

    def get_result(self):
        """Collect each worker's ``get_result()``.

        :returns: ``{'success_all': bool, 'detail': list}``, where
            ``success_all`` is True when element ``[1]`` of every worker result
            is 0.
        :raises RuntimeError: if called before ``wait_completion`` (or after
            new tasks were added). Previously this was an ``assert``, which
            disappears under ``python -O``.
        """
        if not self._finished:
            raise RuntimeError('get_result() called before wait_completion()')
        detail = [worker.get_result() for worker in self._workers]
        succ_all = all(tp[1] == 0 for tp in detail)
        return {'success_all': succ_all, 'detail': detail}
コード例 #8
0
class ThreadManager(six.with_metaclass(Singleton, object)):
    """ ThreadManager provides thread on demand

    Starts ``num_thread`` daemon threads, each running ``worker(self.queue)``,
    and feeds them ``(func, params)`` tasks through that shared queue.
    """

    NUM_THREAD = 4  # Default number of threads

    def __init__(self, num_thread=NUM_THREAD):
        """ Create num_thread daemon threads consuming self.queue """

        self.queue = Queue()

        self.thread_list = []

        for _ in range(num_thread):
            t = Thread(target=worker, args=(self.queue, ))
            # Attribute form: Thread.setDaemon() is deprecated since Python 3.10.
            t.daemon = True
            t.start()

            self.thread_list.append(t)

    def add_task(self, func, params):
        """
        Add a task to perform
        :param func: function to call
        :param params : tuple of parameters
        """

        self.queue.put((func, params))

    def clear(self):
        """ clear pending task

        Uses a non-blocking get. The worker threads consume the same queue, so
        the last item can disappear between an ``empty()`` check and a
        blocking ``get()``, which would then wait forever.
        """

        while True:
            try:
                self.queue.get_nowait()
            except six.moves.queue.Empty:
                break
コード例 #9
0
class Search(TracePosterior):
    """
    Systematic search over a model's execution traces, built on Trace and
    Poutine machinery.

    :param callable model: Probabilistic model defined as a function.
    :param int max_tries: The maximum number of times to try completing a trace from the queue.
    """
    def __init__(self, model, max_tries=1e6):
        """
        Store the model and the retry budget.

        :param callable model: Probabilistic model defined as a function.
        :param int max_tries: The maximum number of times to try completing a
            trace from the queue (default 1e6, coerced with ``int``).
        """
        self.model = model
        self.max_tries = int(max_tries)

    def _traces(self, *args, **kwargs):
        """
        Run the model until the queue of partial traces is exhausted.

        Enumerating every queued trace and collecting the marginal histogram
        amounts to exact inference.

        :returns: Iterator of ``(trace, trace.log_pdf())`` pairs.
        :rtype: Generator[:class:`pyro.Trace`]
        """
        # Seed a standard-library queue with one empty trace; poutine.queue
        # presumably pushes further partial traces onto it as it explores.
        self.queue = Queue()
        self.queue.put(poutine.Trace())

        queued_model = poutine.queue(self.model, queue=self.queue,
                                     max_tries=self.max_tries)
        traced_model = poutine.trace(queued_model)
        while not self.queue.empty():
            trace = traced_model.get_trace(*args, **kwargs)
            yield trace, trace.log_pdf()
コード例 #10
0
ファイル: tasks_win.py プロジェクト: ciwei100000/mozjs-debian
def run_all_tests(tests, prefix, pb, options):
    """
    Uses scatter-gather to a thread-pool to manage children.

    Starts ``options.worker_count`` worker threads, each paired with a
    watchdog thread through its own queue, plus one thread that feeds
    ``tests`` into the shared task queue. Yields every result the workers
    report, and calls ``pb.poke()`` whenever no result arrives within one
    progress-bar update interval.

    :param tests: iterable of tests, consumed on a background thread.
    :param prefix: forwarded unchanged to every worker.
    :param pb: progress bar; only ``poke()`` is called on it here.
    :param options: must provide ``worker_count``, ``timeout``,
        ``run_skipped`` and ``show_cmd``.
    """
    qTasks, qResults = Queue(), Queue()

    workers = []
    watchdogs = []
    for _ in range(options.worker_count):
        qWatch = Queue()
        watcher = Thread(target=_do_watch, args=(qWatch, options.timeout))
        # Attribute form: Thread.setDaemon() is deprecated since Python 3.10.
        watcher.daemon = True
        watcher.start()
        watchdogs.append(watcher)
        worker = Thread(target=_do_work,
                        args=(qTasks, qResults, qWatch, prefix,
                              options.run_skipped, options.timeout,
                              options.show_cmd))
        worker.daemon = True
        worker.start()
        workers.append(worker)

    # Insert all jobs into the queue, followed by the queue-end
    # marker, one per worker. This will not block on growing the
    # queue, only on waiting for more items in the generator. The
    # workers are already started, however, so this will process as
    # fast as we can produce tests from the filesystem.
    def _do_push(num_workers, qTasks):
        for test in tests:
            qTasks.put(test)
        for _ in range(num_workers):
            qTasks.put(EndMarker)

    pusher = Thread(target=_do_push, args=(len(workers), qTasks))
    pusher.daemon = True
    pusher.start()

    # Read from the results. Each worker sends one EndMarker when it is done.
    ended = 0
    delay = ProgressBar.update_granularity().total_seconds()
    while ended < len(workers):
        try:
            result = qResults.get(block=True, timeout=delay)
            if result is EndMarker:
                ended += 1
            else:
                yield result
        except Empty:
            pb.poke()

    # Cleanup and exit.
    pusher.join()
    for worker in workers:
        worker.join()
    for watcher in watchdogs:
        watcher.join()
    assert qTasks.empty(), "Send queue not drained"
    assert qResults.empty(), "Result queue not drained"
コード例 #11
0
ファイル: importer.py プロジェクト: mcanthony/Photini
 def copy_selected(self):
     """Import the files selected in the file list through an ImportWorker thread.

     Files are requested one at a time via the ``import_file`` signal. Each
     reply is read from ``item_queue`` as an ``(item, camera_file)`` pair,
     saved to ``item['dest_path']`` (when ``camera_file`` is set) and opened
     in the image list. The loop runs while the copy button stays checked.
     The newest ``item['timestamp']`` is stored as 'last_transfer' in the
     config store.
     """
     if self.import_worker:
         # user has clicked while upload is still cancelling
         self.copy_button.setChecked(False)
         return
     copy_list = []
     for item in self.file_list_widget.selectedItems():
         # The lookup key is the first whitespace-separated word of the item's text.
         name = item.text().split()[0]
         copy_list.append(self.file_data[name])
     if not copy_list:
         self.copy_button.setChecked(False)
         return
     # create separate thread to import images
     item_queue = Queue()
     self.import_worker = ImportWorker(self.source, item_queue)
     self.import_file.connect(self.import_worker.import_file)
     self.import_worker.thread.start()
     last_transfer = datetime.min
     last_path = None
     # Request the first file; from here on, copy_list[0] is the file currently being fetched.
     self.import_file.emit(copy_list.pop(0))
     while self.copy_button.isChecked():
         # Keep the GUI responsive (and let the user uncheck the button) while waiting.
         QtWidgets.QApplication.processEvents()
         if not self.import_worker.thread.isRunning():
             # user has closed program
             # NOTE(review): returns without the cleanup at the end of this method.
             return
         if item_queue.empty():
             # NOTE(review): busy-waits (no sleep) until the worker replies.
             continue
         item, camera_file = item_queue.get()
         if item is None:
             # import failed
             self._fail()
             break
         timestamp = item['timestamp']
         dest_path = item['dest_path']
         if last_transfer < timestamp:
             last_transfer = timestamp
             last_path = dest_path
         if copy_list:
             # start fetching next file
             # (requested before saving this one, so the fetch overlaps the save)
             self.import_file.emit(copy_list[0])
         if camera_file:
             camera_file.save(dest_path)
         self.image_list.open_file(dest_path)
         if not copy_list:
             break
         # The file just requested above is now the one in flight.
         copy_list.pop(0)
     if last_path:
         self.config_store.set(self.config_section, 'last_transfer',
                               last_transfer.isoformat(' '))
         self.image_list.done_opening(last_path)
     self.show_file_list()
     # Tear down the worker: disconnect the signal, stop its thread and wait for it.
     self.import_file.disconnect()
     self.import_worker.thread.quit()
     self.import_worker.thread.wait()
     self.import_worker = None
     self.copy_button.setChecked(False)
コード例 #12
0
class CallbackHandler(object):
    """ handles callback for event for single url with pause support

    While paused, events are queued. ``resume()`` replays them in arrival
    order and ``flush()`` discards them. Exceptions raised by the callback are
    logged, never propagated.
    """
    def __init__(self, url, subscription_id, callback, paused):
        self.url = url
        self.subscription_id = subscription_id
        self.callback = callback
        self.paused = paused
        # Events received while paused, in arrival order.
        self.event_q = Queue()

    def _invoke(self, event):
        # Run the callback. Failures are logged (the traceback at debug level)
        # so one bad event cannot stop delivery of the others.
        try:
            self.callback(event)
        except Exception as e:
            logger.debug("Traceback:\n%s", traceback.format_exc())
            # logger.warn() is a deprecated alias of logger.warning().
            logger.warning("failed to execute event callback: %s", e)

    def flush(self):
        # flush any pending events (no callback triggered on flush)
        while not self.event_q.empty():
            self.event_q.get()

    def pause(self):
        # pause callbacks
        logger.debug("pausing url: %s", self.url)
        self.paused = True

    def resume(self):
        # trigger all callbacks before setting pause to false
        logger.debug("resume (queue-size %s) url %s", self.event_q.qsize(),
                     self.url)
        while not self.event_q.empty():
            self._invoke(self.event_q.get())
        self.paused = False

    def execute_callback(self, event):
        # execute callback or queue event if currently paused
        if self.paused:
            self.event_q.put(event)
        else:
            self._invoke(event)
コード例 #13
0
 def has_cycles(self):
     '''
     Check if the graph has cycles or not.

     Breadth-first traversal that records both the edges traversed and the
     nodes discovered (processed). From stackoverflow: if an unexplored edge
     leads to a previously found node, or a node is dequeued a second time
     (it was reached along two different paths), then the graph has cycles.

     Every connected component is traversed, starting with the leaf nodes,
     so graphs without a leaf node, disconnected graphs and empty graphs are
     all handled.

     :returns: True if a cycle was found, otherwise False.
     '''
     discovered_nodes = set()
     traversed_edges = set()
     # Leaf nodes first; sorted() is stable, so the node order is otherwise kept.
     start_candidates = sorted(self.nodes, key=lambda node: not node.is_leaf())
     for start_node in start_candidates:
         if start_node.name in discovered_nodes:
             # Already reached from an earlier start node (same component).
             continue
         q = Queue()
         q.put(start_node)
         while not q.empty():
             current_node = q.get()
             if DEBUG:
                 print("Current Node: ", current_node)
                 print("Discovered Nodes before adding Current Node: ",
                       discovered_nodes)
             if current_node.name in discovered_nodes:
                 # We have a cycle!
                 if DEBUG:
                     print('Dequeued node already processed: %s' % current_node)
                 return True
             discovered_nodes.add(current_node.name)
             if DEBUG:
                 print("Discovered Nodes after adding Current Node: ",
                       discovered_nodes)
             for neighbour in current_node.neighbours:
                 # Since this is undirected and we want to record the edges
                 # we have traversed, key each edge by its sorted end names.
                 edge = tuple(sorted([current_node.name, neighbour.name]))
                 if edge not in traversed_edges:
                     # This is a new edge...
                     if neighbour.name in discovered_nodes:
                         return True
                 # Now place all neighbour nodes on the q
                 # and record this edge as traversed
                 if neighbour.name not in discovered_nodes:
                     if DEBUG:
                         print('Enqueuing: %s' % neighbour)
                     q.put(neighbour)
                 traversed_edges.add(edge)
     return False
コード例 #14
0
def _extract_features_parallel_per_sample(kind_to_df_map, settings, column_id,
                                          column_value):
    """
    Parallelize the feature extraction per kind and per sample.

    As the splitting of the dataframes per kind along column_id is quite costly, we settled for an async map in this
    function. The result objects are kept in a list in order of submission, from which they are retrieved in that
    same order.

    The process pool is always joined, and it is terminated first if a job fails or the caller is interrupted, so no
    worker processes are leaked.

    :param kind_to_df_map: The time series to compute the features for in our internal format
    :type kind_to_df_map: dict of pandas.DataFrame

    :param column_id: The name of the id column to group by.
    :type column_id: str
    :param column_value: The name for the column keeping the value itself.
    :type column_value: str

    :param settings: settings object that controls which features are calculated
    :type settings: tsfresh.feature_extraction.settings.FeatureExtractionSettings

    :return: The (maybe imputed) DataFrame containing extracted features.
    :rtype: pandas.DataFrame
    """
    partial_extract_features_for_one_time_series = partial(
        _extract_features_for_one_time_series,
        column_id=column_id,
        column_value=column_value,
        settings=settings)
    pool = Pool(settings.n_processes)

    try:
        # Submit map jobs per kind per sample. Only this process touches the
        # results, so a plain list (not a thread-safe Queue) keeps the order.
        map_results = []
        for kind, df_kind in kind_to_df_map.items():
            df_grouped_by_id = df_kind.groupby(column_id)
            map_results.append(
                pool.map_async(partial_extract_features_for_one_time_series,
                               [(kind, df_group)
                                for _, df_group in df_grouped_by_id],
                               chunksize=settings.chunksize))

        pool.close()

        # Wait for the jobs to complete and concatenate the partial results
        dfs_per_kind = []
        for map_result in map_results:
            dfs = map_result.get()
            dfs_per_kind.append(pd.concat(dfs, axis=0).astype(np.float64))
    except BaseException:
        # Don't leave worker processes running when a job fails.
        pool.terminate()
        raise
    finally:
        pool.join()

    return pd.concat(dfs_per_kind, axis=1).astype(np.float64)
コード例 #15
0
ファイル: koji_source.py プロジェクト: JayZ12138/pushsource
    def __iter__(self):
        """Iterate over push items.

        - Yields :ref:`~pushsource.RpmPushItem` instances for RPMs

        All koji requests are queued up front and fetched on ``self._threads``
        worker threads; push items are then yielded as their futures complete.

        Raises the first exception raised by a fetch thread, or
        ``futures.TimeoutError`` if a fetch thread is still running after
        ``self._timeout`` or the push-item futures don't complete in time.
        """

        # Queue holding all requests we need to make to koji.
        # We try to fetch as much as we can early to make efficient use
        # of multicall.
        koji_queue = Queue()

        # We'll need to obtain all RPMs referenced by filename
        for rpm_filename in self._rpm:
            koji_queue.put(GetRpmCommand(ident=rpm_filename))

        # We'll need to obtain all builds from which we want modules,
        # as well as the archives from those
        for build_id in self._module_build:
            koji_queue.put(GetBuildCommand(ident=build_id, list_archives=True))

        # Put some threads to work on the queue.
        fetch_exceptions = []
        fetch_threads = [
            Thread(
                name="koji-%s-fetch-%s" % (id(self), i),
                target=self._do_fetch,
                args=(koji_queue, fetch_exceptions),
            )
            for i in range(0, self._threads)
        ]

        # Wait for all fetches to finish
        for t in fetch_threads:
            t.start()
        for t in fetch_threads:
            t.join(self._timeout)

        # Re-raise exceptions, if any.
        # If we got more than one, we're only propagating the first.
        if fetch_exceptions:
            raise fetch_exceptions[0]

        # join() returns silently when its timeout expires. A thread that is
        # still alive may still be processing requests it already took from
        # the queue, so an empty queue does not prove the fetches finished.
        # Fail loudly rather than continue with incomplete results.
        stalled = [t.name for t in fetch_threads if t.is_alive()]
        if stalled:
            raise futures.TimeoutError(
                "Timed out waiting for koji fetch threads: %s" % ", ".join(stalled)
            )

        # The queue must be empty now
        assert koji_queue.empty()

        push_items_fs = self._modulemd_futures() + self._rpm_futures()

        completed_fs = futures.as_completed(push_items_fs, timeout=self._timeout)
        for f in completed_fs:
            # If an exception occurred, this is where it will be raised.
            for pushitem in f.result():
                yield pushitem
コード例 #16
0
ファイル: ftdi.py プロジェクト: nccgroup/umap2
class USBFtdiInterface(USBInterface):
    """Vendor-specific interface with one bulk OUT and one bulk IN endpoint.

    Data received on EP1 is prefixed with two bytes and queued; the queued
    replies are sent on EP3 when the host asks for data.
    """
    name = 'FtdiInterface'

    def __init__(self, app, phy, interface_number):
        def bulk_endpoint(number, direction, handler):
            # Both endpoints differ only in number, direction and handler.
            return USBEndpoint(
                app=app,
                phy=phy,
                number=number,
                direction=direction,
                transfer_type=USBEndpoint.transfer_type_bulk,
                sync_type=USBEndpoint.sync_type_none,
                usage_type=USBEndpoint.usage_type_data,
                max_packet_size=0x40,
                interval=0,
                handler=handler,
            )

        endpoint_list = [
            bulk_endpoint(1, USBEndpoint.direction_out, self.handle_data_available),
            bulk_endpoint(3, USBEndpoint.direction_in, self.handle_ep3_buffer_available),
        ]
        super(USBFtdiInterface, self).__init__(
            app=app,
            phy=phy,
            interface_number=interface_number,
            interface_alternate=0,
            interface_class=USBClass.VendorSpecific,
            interface_subclass=0xff,
            interface_protocol=0xff,
            interface_string_index=0,
            endpoints=endpoint_list,
        )
        # Replies waiting to be sent on EP3.
        self.txq = Queue()

    def handle_data_available(self, data):
        self.debug('received string (%d): %s' % (len(data), data))
        # Prefix presumably mimics the two FTDI status bytes - confirm.
        self.txq.put(b'\x01\x00' + data)

    def handle_ep3_buffer_available(self):
        if self.txq.empty():
            return
        self.send_on_endpoint(3, self.txq.get())
コード例 #17
0
ファイル: call_python_client.py プロジェクト: mposa/drake
    def _handle_messages_threaded(self):
        """Execute messages in this thread while a daemon thread reads them.

        A daemon "Producer" thread puts messages from
        ``self._read_next_message()`` into a queue. This (calling) thread
        executes them in order, calling ``self.scope_globals['pause'](0.01)``
        between batches.

        The loop ends when ``self._done`` becomes True, which happens when:
        - the producer has seen every queued message taken (``queue.join()``);
        - the user presses Ctrl+C;
        - executing a message raises. This also sets ``self._had_error`` and
          prints the traceback to stderr.
        """
        # Handles messages in a threaded fashion.
        queue = Queue()

        def producer_loop():
            # Read messages from file, and queue them for execution.
            for msg in self._read_next_message():
                queue.put(msg)
                # Check if an error occurred.
                if self._done:
                    break
            # Wait until the queue empties out to signal completion from the
            # producer's side.
            # The consumer calls task_done() right after get() and before
            # executing the message. join() can therefore return while the
            # last message is still running; the consumer finishes it before
            # it re-checks self._done.
            if not self._done:
                queue.join()
                self._done = True

        producer = Thread(name="Producer", target=producer_loop)
        # @note Previously, when trying to do `queue.clear()` in the consumer,
        # and `queue.join()` in the producer, there would be intermittent
        # deadlocks. By demoting the producer to a daemon, I (eric.c) have not
        # yet encountered a deadlock.
        producer.daemon = True
        producer.start()

        # Consume.
        # TODO(eric.cousineau): Trying to quit via Ctrl+C is awkward (but kinda
        # works). Is there a way to have `plt.pause` handle Ctrl+C differently?
        try:
            pause = self.scope_globals['pause']
            while not self._done:
                # Process messages.
                while not queue.empty():
                    msg = queue.get()
                    queue.task_done()
                    self._execute_message(msg)
                # Spin busy for a bit, let matplotlib (or whatever) flush its
                # event queue.
                pause(0.01)
        except KeyboardInterrupt:
            # User pressed Ctrl+C.
            self._done = True
            print("Quitting")
        except Exception as e:
            # We encountered an error, and must stop.
            # (`e` is unused; the full traceback is printed below.)
            self._done = True
            self._had_error = True
            traceback.print_exc(file=sys.stderr)
            sys.stderr.write("  Stopping (--stop_on_error)\n")
コード例 #18
0
ファイル: test_container.py プロジェクト: candlerb/pato
def test_thread_safe_object_creation(c):
    """
    If two threads try to fetch the object at the same time,
    only one instance should be created.
    This also tests assigning an existing function as a service.

    The helper threads are daemons. A failing assertion leaves them blocked
    on ``cin.get()``, and as daemons they cannot keep the test process alive.
    """
    cin = Queue()
    cout = Queue()
    def test_factory(username, password):
        cout.put("ready")
        cin.get()
        res = libtest.sample.Foo(username, password)
        cout.put("done")
        return res

    c['test_factory'] = test_factory
    c.load_yaml("""
a:
    :: <test_factory>
    username: abc
    password: xyz
""")
    def run(q):
        q.put("starting")
        q.put(c['a'])

    def start_fetcher(q):
        # Daemon so a blocked thread never keeps the interpreter alive.
        t = Thread(target=run, kwargs={"q": q})
        t.daemon = True
        t.start()
        return t

    q1 = Queue()
    t1 = start_fetcher(q1)
    assert cout.get(True, 2) == "ready"
    assert q1.get(True, 2) == "starting"
    # Now t1 is waiting inside factory method

    q2 = Queue()
    t2 = start_fetcher(q2)
    assert q2.get(True, 2) == "starting"

    cin.put("go")
    assert cout.get(True, 2) == "done"
    t1.join(2)
    t2.join(2)
    assert cout.empty()

    res1 = q1.get(True, 2)
    res2 = q2.get(True, 2)
    # This also implies that test_factory was only called once
    # because otherwise t2 would hang waiting on cin
    assert isinstance(res1, libtest.sample.Foo)
    assert res1 is res2
コード例 #19
0
ファイル: ftdi.py プロジェクト: agdlgv/sahara_emulator
class USBFtdiInterface(USBInterface):
    """Vendor-specific interface with one bulk OUT and one bulk IN endpoint.

    Data received on EP1 is prefixed with two bytes and queued; the queued
    replies are sent on EP3 when the host asks for data.
    """
    name = 'FtdiInterface'

    def __init__(self, app, phy, interface_number):
        def bulk_endpoint(number, direction, handler):
            # Both endpoints differ only in number, direction and handler.
            return USBEndpoint(
                app=app,
                phy=phy,
                number=number,
                direction=direction,
                transfer_type=USBEndpoint.transfer_type_bulk,
                sync_type=USBEndpoint.sync_type_none,
                usage_type=USBEndpoint.usage_type_data,
                max_packet_size=0x40,
                interval=0,
                handler=handler,
            )

        endpoint_list = [
            bulk_endpoint(1, USBEndpoint.direction_out, self.handle_data_available),
            bulk_endpoint(3, USBEndpoint.direction_in, self.handle_ep3_buffer_available),
        ]
        super(USBFtdiInterface, self).__init__(
            app=app,
            phy=phy,
            interface_number=interface_number,
            interface_alternate=0,
            interface_class=USBClass.VendorSpecific,
            interface_subclass=0xff,
            interface_protocol=0xff,
            interface_string_index=0,
            endpoints=endpoint_list,
        )
        # Replies waiting to be sent on EP3.
        self.txq = Queue()

    def handle_data_available(self, data):
        self.debug('received string (%d): %s' % (len(data), data))
        # Prefix presumably mimics the two FTDI status bytes - confirm.
        self.txq.put(b'\x01\x00' + data)

    def handle_ep3_buffer_available(self):
        if self.txq.empty():
            return
        self.send_on_endpoint(3, self.txq.get())
コード例 #20
0
ファイル: test_poutines.py プロジェクト: zippeurfou/pyro
class QueueHandlerMixedTest(TestCase):
    """Tests ``poutine.queue`` on a model mixing continuous and discrete sites."""

    def setUp(self):
        # Continuous site "x", then discrete site "y", then continuous site "z".
        def model():
            p = torch.tensor([0.5])
            loc, scale = torch.zeros(1), torch.ones(1)
            x = pyro.sample("x", Normal(loc, scale))  # sampled before the discrete site
            y = pyro.sample("y", Bernoulli(p))
            z = pyro.sample("z", Normal(loc, scale))  # sampled after the discrete site
            return {"x": x, "y": y, "z": z}

        self.sites = ["x", "y", "z", "_INPUT", "_RETURN"]
        self.model = model
        # Seed the queue with a single empty trace.
        self.queue = Queue()
        self.queue.put(poutine.Trace())

    def test_queue_single(self):
        traced = poutine.trace(poutine.queue(self.model, queue=self.queue))
        trace = traced.get_trace()
        for site in self.sites:
            assert site in trace

    def test_queue_enumerate(self):
        traced = poutine.trace(poutine.queue(self.model, queue=self.queue))
        traces = []
        # Keep running until every queued partial trace has been completed.
        while not self.queue.empty():
            traces.append(traced.get_trace())
        assert len(traces) == 2

        def sample_values(trace):
            return {name: node['value'].view(-1).item()
                    for name, node in trace.nodes.items()
                    if node['type'] == 'sample'}

        values = [sample_values(trace) for trace in traces]

        # One path per value of the Bernoulli site y.
        assert {value["y"] for value in values} == {0, 1}

        # x is sampled before the discrete site, so both paths share it.
        assert values[0]["x"] == values[1]["x"]

        # z is sampled after the discrete site, so it differs between the
        # paths (almost surely).
        assert values[0]["z"] != values[1]["z"]
コード例 #21
0
ファイル: call_python_client.py プロジェクト: psprecher/drake
    def _handle_messages_threaded(self):
        """Execute messages in this thread while a daemon thread reads them.

        A daemon "Producer" thread puts messages from
        ``self._read_next_message()`` into a queue. This (calling) thread
        executes them in order, calling ``self.scope_globals['pause'](0.01)``
        between batches.

        The loop ends when ``self._done`` becomes True, which happens when:
        - the producer has seen every queued message taken (``queue.join()``);
        - the user presses Ctrl+C;
        - executing a message raises. This also sets ``self._had_error`` and
          prints the traceback to stderr.
        """
        # Handles messages in a threaded fashion.
        queue = Queue()

        def producer_loop():
            # Read messages from file, and queue them for execution.
            for msg in self._read_next_message():
                queue.put(msg)
                # Check if an error occurred.
                if self._done:
                    break
            # Wait until the queue empties out to signal completion from the
            # producer's side.
            # The consumer calls task_done() right after get() and before
            # executing the message. join() can therefore return while the
            # last message is still running; the consumer finishes it before
            # it re-checks self._done.
            if not self._done:
                queue.join()
                self._done = True

        producer = Thread(name="Producer", target=producer_loop)
        # @note Previously, when trying to do `queue.clear()` in the consumer,
        # and `queue.join()` in the producer, there would be intermittent
        # deadlocks. By demoting the producer to a daemon, I (eric.c) have not
        # yet encountered a deadlock.
        producer.daemon = True
        producer.start()

        # Consume.
        # TODO(eric.cousineau): Trying to quit via Ctrl+C is awkward (but kinda
        # works). Is there a way to have `plt.pause` handle Ctrl+C differently?
        try:
            pause = self.scope_globals['pause']
            while not self._done:
                # Process messages.
                while not queue.empty():
                    msg = queue.get()
                    queue.task_done()
                    self._execute_message(msg)
                # Spin busy for a bit, let matplotlib (or whatever) flush its
                # event queue.
                pause(0.01)
        except KeyboardInterrupt:
            # User pressed Ctrl+C.
            self._done = True
            print("Quitting")
        except Exception as e:
            # We encountered an error, and must stop.
            # (`e` is unused; the full traceback is printed below.)
            self._done = True
            self._had_error = True
            traceback.print_exc(file=sys.stderr)
            sys.stderr.write("  Stopping (--stop_on_error)\n")
コード例 #22
0
ファイル: test_poutines.py プロジェクト: lewisKit/pyro
class QueueHandlerMixedTest(TestCase):
    """Tests for ``poutine.queue`` on a model mixing continuous and discrete sites.

    The model samples a continuous site ``x``, a discrete Bernoulli site ``y``
    and another continuous site ``z``. ``setUp`` seeds the queue with a single
    empty ``poutine.Trace``.
    """

    def setUp(self):

        # Simple model with 1 continuous + 1 discrete + 1 continuous variable.
        def model():
            p = torch.tensor([0.5])
            loc = torch.zeros(1)
            scale = torch.ones(1)

            x = pyro.sample("x", Normal(loc, scale))  # Before the discrete variable.
            y = pyro.sample("y", Bernoulli(p))
            z = pyro.sample("z", Normal(loc, scale))  # After the discrete variable.
            return dict(x=x, y=y, z=z)

        # Site names every recorded trace is expected to contain.
        self.sites = ["x", "y", "z", "_INPUT", "_RETURN"]
        self.model = model
        self.queue = Queue()
        self.queue.put(poutine.Trace())

    def test_queue_single(self):
        """One run through the queued model records every expected site."""
        f = poutine.trace(poutine.queue(self.model, queue=self.queue))
        tr = f.get_trace()
        for name in self.sites:
            assert name in tr

    def test_queue_enumerate(self):
        """Draining the queue yields one trace per value of the discrete site."""
        f = poutine.trace(poutine.queue(self.model, queue=self.queue))
        trs = []
        while not self.queue.empty():
            trs.append(f.get_trace())
        assert len(trs) == 2

        # Sampled value of every sample site, per trace.
        values = [
            {name: node['value'].view(-1).item()
             for name, node in tr.nodes.items()
             if node['type'] == 'sample'}
            for tr in trs
        ]

        expected_ys = {0, 1}
        actual_ys = {value["y"] for value in values}
        assert actual_ys == expected_ys

        # Check that x was sampled the same on each path.
        assert values[0]["x"] == values[1]["x"]

        # Check that z was sampled differently on each path.
        assert values[0]["z"] != values[1]["z"]  # Almost surely true.
コード例 #23
0
ファイル: audio.py プロジェクト: bkerler/sahara_emulator
class AudioStreaming(object):
    """Streaming-interface handler: sends queued chunks and logs received data.

    Outgoing chunks wait in ``txq``. Each time the PHY reports that the TX
    endpoint can take data, one chunk is sent. If nothing is queued, eight
    zero bytes are sent instead (presumably silence; confirm with the audio
    format in use).
    """

    def __init__(self, app, phy, tx_ep, rx_ep):
        self.app, self.phy = app, phy
        self.tx_ep, self.rx_ep = tx_ep, rx_ep
        # Outgoing chunks, consumed one per buffer_available() call.
        self.txq = Queue()

    def buffer_available(self):
        """Send the next queued chunk on the TX endpoint, or zero padding."""
        payload = (b'\x00\x00\x00\x00\x00\x00\x00\x00' if self.txq.empty()
                   else self.txq.get())
        self.phy.send_on_endpoint(self.tx_ep, payload)

    def data_available(self, data):
        """Log how many bytes arrived on the streaming endpoint."""
        message = '[AudioStreaming] Got %#x bytes on streaming endpoint' % (len(data))
        self.app.logger.info(message)
コード例 #24
0
ファイル: audio.py プロジェクト: nccgroup/umap2
class AudioStreaming(object):
    """Streaming-interface handler that feeds the TX endpoint from a queue.

    Received data is only logged. When the TX queue is empty, eight zero
    bytes are sent in place of real data (assumed to be silence — verify
    against the audio format).
    """

    def __init__(self, app, phy, tx_ep, rx_ep):
        self.app = app
        self.phy = phy
        self.tx_ep = tx_ep
        self.rx_ep = rx_ep
        # Chunks waiting to go out on tx_ep.
        self.txq = Queue()

    def buffer_available(self):
        """Called when the TX endpoint can accept data; sends one chunk."""
        if not self.txq.empty():
            self.phy.send_on_endpoint(self.tx_ep, self.txq.get())
            return
        self.phy.send_on_endpoint(self.tx_ep, b'\x00\x00\x00\x00\x00\x00\x00\x00')

    def data_available(self, data):
        """Record the size of a chunk received on the streaming endpoint."""
        logger = self.app.logger
        logger.info('[AudioStreaming] Got %#x bytes on streaming endpoint' % (len(data)))
コード例 #25
0
def _build_droot_impact(destroy_handler):
    """Map every destroyed variable, and every view of it, to its root.

    For each apply node in ``destroy_handler.destroyers`` and each entry of
    its ``op.destroy_map``, the destroyed input is followed up the
    ``view_i`` chain to the non-view variable it ultimately views (its root).
    Every variable reachable from that root through ``view_o`` is then
    collected breadth-first.

    Returns:
        droot: destroyed root and each of its views -> the root.
        impact: root -> OrderedSet of all views of the root plus the root.
        root_destroyer: root -> the apply node that destroys it.

    Raises:
        NotImplementedError: if a destroy_map entry lists more than one input.
        InconsistencyError: if two destroyers target the same root.
    """
    from collections import deque

    droot = {}   # destroyed view + nonview variables -> foundation
    impact = {}  # destroyed nonview variable -> it + all views of it
    root_destroyer = {}  # root -> destroyer apply
    view_i = destroy_handler.view_i
    view_o = destroy_handler.view_o

    for app in destroy_handler.destroyers:
        for input_idx_list in app.op.destroy_map.values():
            if len(input_idx_list) != 1:
                raise NotImplementedError()
            destroyed = app.inputs[input_idx_list[0]]

            # Find the non-view variable which is ultimately viewed by the
            # destroyed input.
            input_root = destroyed
            parent = view_i.get(input_root)
            while parent is not None:
                input_root = parent
                parent = view_i.get(input_root)

            if input_root in droot:
                raise InconsistencyError(
                    "Multiple destroyers of %s" % input_root)
            droot[input_root] = input_root
            root_destroyer[input_root] = app

            # Breadth-first walk over view_o collecting every variable that
            # is (transitively) a view of input_root. This is a purely local
            # worklist, so a deque is used rather than a thread-safe Queue.
            input_impact = OrderedSet()
            pending = deque([input_root])
            while pending:
                v = pending.popleft()
                for n in view_o.get(v, []):
                    input_impact.add(n)
                    pending.append(n)

            for v in input_impact:
                assert v not in droot
                droot[v] = input_root

            impact[input_root] = input_impact
            impact[input_root].add(input_root)

    return droot, impact, root_destroyer
コード例 #26
0
ファイル: thread.py プロジェクト: new07/pypeln
class _InputQueue(object):
    """Input queue of a pipeline stage that knows when its producers are done.

    ``total_done`` is the number of DONE markers (see ``done()``) that must
    be received before the stage may finish. Iteration ends once all of them
    have arrived and the queue is empty, or as soon as
    ``pipeline_namespace.error`` is set.
    """

    def __init__(self, maxsize, total_done, pipeline_namespace, **kwargs):
        self.queue = Queue(maxsize = maxsize, **kwargs)
        self.lock = Lock()
        self.namespace = _get_namespace()
        # DONE markers still outstanding.
        self.namespace.remaining = total_done
        self.pipeline_namespace = pipeline_namespace

    def __iter__(self):
        while not self.is_done():
            item = self.get()
            if self.pipeline_namespace.error:
                return
            if utils.is_continue(item):
                continue
            yield item

    def get(self):
        """Return the next real item, or ``utils.CONTINUE`` when there is none.

        A timeout or a DONE marker both produce ``utils.CONTINUE``; a DONE
        marker additionally decrements the outstanding-marker count.
        """
        try:
            item = self.queue.get(timeout = utils.TIMEOUT)
        except (Empty, Full):
            return utils.CONTINUE

        if utils.is_done(item):
            with self.lock:
                self.namespace.remaining -= 1
            return utils.CONTINUE
        return item

    def is_done(self):
        """True once every DONE marker arrived and nothing is left queued."""
        if self.namespace.remaining != 0:
            return False
        return self.queue.empty()

    def put(self, x):
        """Enqueue a data item."""
        self.queue.put(x)

    def done(self):
        """Signal that one upstream producer has finished."""
        self.queue.put(utils.DONE)
コード例 #27
0
ファイル: destroyhandler.py プロジェクト: 12190143/Theano
def _build_droot_impact(destroy_handler):
    """Map every destroyed variable, and every view of it, to its root.

    For each apply node in ``destroy_handler.destroyers`` and each entry of
    its ``op.destroy_map``, the destroyed input is followed up the
    ``view_i`` chain to the non-view variable it ultimately views (its root).
    Every variable reachable from that root through ``view_o`` is then
    collected breadth-first.

    Returns:
        droot: destroyed root and each of its views -> the root.
        impact: root -> OrderedSet of all views of the root plus the root.
        root_destroyer: root -> the apply node that destroys it.

    Raises:
        NotImplementedError: if a destroy_map entry lists more than one input.
        InconsistencyError: if two destroyers target the same root.
    """
    from collections import deque

    droot = {}   # destroyed view + nonview variables -> foundation
    impact = {}  # destroyed nonview variable -> it + all views of it
    root_destroyer = {}  # root -> destroyer apply
    view_i = destroy_handler.view_i
    view_o = destroy_handler.view_o

    for app in destroy_handler.destroyers:
        for input_idx_list in app.op.destroy_map.values():
            if len(input_idx_list) != 1:
                raise NotImplementedError()
            destroyed = app.inputs[input_idx_list[0]]

            # Find the non-view variable which is ultimately viewed by the
            # destroyed input.
            input_root = destroyed
            parent = view_i.get(input_root)
            while parent is not None:
                input_root = parent
                parent = view_i.get(input_root)

            if input_root in droot:
                raise InconsistencyError(
                    "Multiple destroyers of %s" % input_root)
            droot[input_root] = input_root
            root_destroyer[input_root] = app

            # Breadth-first walk over view_o collecting every variable that
            # is (transitively) a view of input_root. This is a purely local
            # worklist, so a deque is used rather than a thread-safe Queue.
            input_impact = OrderedSet()
            pending = deque([input_root])
            while pending:
                v = pending.popleft()
                for n in view_o.get(v, []):
                    input_impact.add(n)
                    pending.append(n)

            for v in input_impact:
                assert v not in droot
                droot[v] = input_root

            impact[input_root] = input_impact
            impact[input_root].add(input_root)

    return droot, impact, root_destroyer
コード例 #28
0
class SimpleThreadPool:
    """Pool of ``WorkerThread`` objects consuming ``(func, args, kwargs)`` tasks.

    Workers are created lazily by the first ``add_task`` call and share one
    bounded queue of 2000 entries.
    """

    def __init__(self, num_threads=3):
        self._num_threads = num_threads
        self._queue = Queue(2000)  # put() blocks once 2000 tasks are pending
        self._lock = Lock()
        self._active = False
        self._workers = []
        self._finished = False

    def add_task(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)``, starting the workers if needed."""
        if not self._active:
            self._start_workers()
        self._queue.put((func, args, kwargs))

    def _start_workers(self):
        """Spawn ``_num_threads`` workers unless another caller already did."""
        with self._lock:
            # Re-check under the lock so only the first concurrent caller spawns.
            if self._active:
                return
            self._workers = []
            self._active = True
            for _ in range(self._num_threads):
                worker = WorkerThread(self._queue)
                self._workers.append(worker)
                worker.start()

    def release(self):
        """Poll once per second until the task queue is empty."""
        while not self._queue.empty():
            time.sleep(1)

    def wait_completion(self):
        """Wait for every queued task, then post one stop sentinel per worker."""
        self._queue.join()
        self._finished = True
        # The tasks are finished, so every worker thread must be told to exit
        # to keep them from hanging.
        for _ in range(self._num_threads):
            self._queue.put((None, None, None))
        self._active = False

    def complete(self):
        """Mark the pool finished so that ``get_result`` may be called."""
        self._finished = True

    def get_result(self):
        """Sum the (success, fail) counts reported by each worker."""
        assert self._finished
        per_worker = [worker.get_result() for worker in self._workers]
        return {
            'success_num': sum(tp[0] for tp in per_worker),
            'fail_num': sum(tp[1] for tp in per_worker),
        }
コード例 #29
0
def remote2local_sync_delete(src, dst, **kwargs):
    """
    For a download sync run with --delete: delete local files that have no
    corresponding object on COS.

    Walks the local tree rooted at ``dst['Path']`` breadth-first. For every
    file, ``src['Client'].head_object`` is called on the matching key under
    ``src['Path']`` in ``src['Bucket']``. If that call fails with status 404,
    the local file is removed. Other COS errors leave the file untouched.
    ``kwargs`` is accepted but not used here.

    Returns:
        ``[0, success_num, fail_num]``, giving the number of files deleted
        and the number of failed deletions. Returns ``[-1, 0, 0]`` if an
        unexpected error aborted the walk; that error is logged as a warning.
    """
    from collections import deque

    # Pending [local_dir, cos_prefix] pairs. The walk is single-threaded, so
    # a plain deque is enough (no need for a thread-safe Queue).
    pending = deque([[dst['Path'], src['Path']]])
    success_num = 0
    fail_num = 0
    # Breadth-first walk of the local directory tree
    try:
        while pending:
            [local_path, cos_path] = pending.popleft()
            local_path = to_unicode(local_path)
            cos_path = to_unicode(cos_path)
            if not cos_path.endswith('/'):
                cos_path += "/"
            if not local_path.endswith('/'):
                local_path += "/"
            cos_path = cos_path.lstrip('/')
            # Entries of the current local directory
            for filename in os.listdir(local_path):
                filepath = os.path.join(local_path, filename)
                if os.path.isdir(filepath):
                    pending.append([filepath, cos_path + filename])
                    continue
                try:
                    src['Client'].head_object(Bucket=src['Bucket'],
                                              Key=cos_path + filename)
                except CosServiceError as e:
                    if e.get_status_code() == 404:
                        # Object is missing remotely: delete the local copy.
                        try:
                            os.remove(filepath)
                            logger.info(
                                u"Delete {file}".format(file=filepath))
                            success_num += 1
                        except Exception:
                            logger.info(u"Delete {file} fail".format(
                                file=filepath))
                            fail_num += 1
    except Exception as e:
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning(e)
        return [-1, 0, 0]
    return [0, success_num, fail_num]
コード例 #30
0
ファイル: gadgetfs_phy.py プロジェクト: webstorage119/umap2
class InEpThread(EndpointThread):
    """Endpoint thread for an IN endpoint: writes queued data to ``ep.fd``."""

    def __init__(self, phy, ep):
        super(InEpThread, self).__init__(phy, ep)
        # Data waiting to be written to the endpoint's file descriptor.
        self.queue = Queue()

    def send(self, data):
        """Queue ``data`` for a later write by ``io_op``."""
        self.queue.put(data)

    def handling_write(self):
        """Return True while queued data is still waiting to be written."""
        has_pending = not self.queue.empty()
        return has_pending

    def io_op(self):
        """Write one queued chunk to the endpoint.

        Waits at most 0.1 s for data; on timeout, returns without writing.
        """
        try:
            chunk = self.queue.get(block=True, timeout=0.1)
        except Empty:
            return
        os.write(self.ep.fd, chunk)
コード例 #31
0
ファイル: gadgetfs_phy.py プロジェクト: nccgroup/umap2
class InEpThread(EndpointThread):
    """IN-endpoint worker that drains a send queue into the endpoint fd."""

    def __init__(self, phy, ep):
        super(InEpThread, self).__init__(phy, ep)
        self.queue = Queue()  # outgoing data, written by io_op()

    def send(self, data):
        """Schedule ``data`` to be written to the endpoint."""
        self.queue.put(data)

    def handling_write(self):
        """Whether any data is still queued for writing."""
        return not self.queue.empty()

    def io_op(self):
        """Fetch one chunk (waiting up to 0.1 s) and write it to the endpoint."""
        try:
            data = self.queue.get(timeout=0.1)
        except Empty:
            pass
        else:
            os.write(self.ep.fd, data)
コード例 #32
0
ファイル: amqp.py プロジェクト: kuldat/anypubsub
class AmqpSubscriber(Subscriber):
    """Iterate over message bodies delivered via AMQP from a set of exchanges.

    On construction a queue is declared, bound to every exchange in
    ``exchanges`` and consumed through ``callback``. The callback acks each
    delivery and buffers its body in ``messages``.
    """

    def __init__(self, amqp_chan, exchanges):
        self.channel = amqp_chan
        self.messages = Queue(maxsize=0)  # maxsize=0: unbounded buffer
        qname, _, _ = self.channel.queue_declare()
        for exchange in exchanges:
            self.channel.queue_bind(qname, exchange)
        self.channel.basic_consume(queue=qname, callback=self.callback)

    def callback(self, msg):
        """Ack ``msg`` and buffer its body for the iterator."""
        self.channel.basic_ack(msg.delivery_tag)
        self.messages.put_nowait(msg.body)

    def __iter__(self):
        return self

    def next(self):
        """Return the next buffered body, servicing the channel until one exists."""
        while True:
            if not self.messages.empty():
                return self.messages.get_nowait()
            self.channel.wait()

    __next__ = next   # PY3
コード例 #33
0
class AmqpSubscriber(Subscriber):
    """Subscriber that yields bodies of AMQP messages published to ``exchanges``."""

    def __init__(self, amqp_chan, exchanges):
        self.channel = amqp_chan
        # Bodies received by callback() and not yet returned by next().
        self.messages = Queue(maxsize=0)
        declared = self.channel.queue_declare()
        qname, _, _ = declared
        for exchange in exchanges:
            self.channel.queue_bind(qname, exchange)
        self.channel.basic_consume(queue=qname, callback=self.callback)

    def callback(self, msg):
        """Delivery handler: acknowledge first, then buffer the body."""
        self.channel.basic_ack(msg.delivery_tag)
        self.messages.put_nowait(msg.body)

    def __iter__(self):
        return self

    def next(self):
        """Block on the channel until a body is buffered, then return it."""
        buffered = self.messages
        while buffered.empty():
            self.channel.wait()
        return buffered.get_nowait()

    __next__ = next  # PY3
コード例 #34
0
ファイル: py.py プロジェクト: pombredanne/syn-1
def hangwatch(timeout, func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` in a worker thread, bounded by ``timeout``.

    Returns None when ``func`` completes in time (its return value is
    discarded).

    Raises:
        RuntimeError: if the call is still running after ``timeout`` seconds.
        Exception: whatever ``func`` raised. Its formatted traceback is
            written with ``eprint`` first.
    """
    def target(queue):
        try:
            func(*args, **kwargs)
        except Exception as e:
            # Hand the traceback info and the exception to the calling thread
            # as a single item; returning here ends the worker thread.
            queue.put((sys.exc_info(), e))

    q = Queue()
    thread = threading.Thread(target=target, args=(q, ))
    # Daemon: if func hangs, the stuck worker must not keep the interpreter
    # alive after the RuntimeError below has been raised.
    thread.daemon = True

    thread.start()
    thread.join(timeout)
    if thread.is_alive():
        raise RuntimeError(
            'Operation did not terminate within {} seconds'.format(timeout))

    if not q.empty():
        info, e = q.get(block=False)
        eprint(''.join(traceback.format_exception(*info)))
        raise e
コード例 #35
0
ファイル: py.py プロジェクト: mbodenhamer/syn
def hangwatch(timeout, func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` in a worker thread, bounded by ``timeout``.

    Returns None when ``func`` completes in time (its return value is
    discarded).

    Raises:
        RuntimeError: if the call is still running after ``timeout`` seconds.
        Exception: whatever ``func`` raised. Its formatted traceback is
            written with ``eprint`` first.
    """
    def target(queue):
        try:
            func(*args, **kwargs)
        except Exception as e:
            # Hand the traceback info and the exception to the calling thread
            # as a single item; returning here ends the worker thread.
            queue.put((sys.exc_info(), e))

    q = Queue()
    thread = threading.Thread(target=target, args=(q,))
    # Daemon: if func hangs, the stuck worker must not keep the interpreter
    # alive after the RuntimeError below has been raised.
    thread.daemon = True

    thread.start()
    thread.join(timeout)
    if thread.is_alive():
        raise RuntimeError('Operation did not terminate within {} seconds'
                           .format(timeout))

    if not q.empty():
        info, e = q.get(block=False)
        eprint(''.join(traceback.format_exception(*info)))
        raise e
コード例 #36
0
ファイル: threadPool.py プロジェクト: nirs/vdsm
class ThreadPool:

    """Flexible thread pool class.  Creates a pool of threads, then
    accepts tasks that will be dispatched to the next available
    thread."""

    log = logging.getLogger("storage.ThreadPool")

    def __init__(self, name, numThreads, waitTimeout=3, maxTasks=100):

        """Initialize the thread pool with numThreads workers.

        :param name: prefix of the worker thread names ("<name>/<n>").
        :param numThreads: number of WorkerThread instances to start.
        :param waitTimeout: seconds getNextTask() waits for a task before
            returning an all-None tuple.
        :param maxTasks: capacity of the task queue; queueTask() blocks
            while the queue is full.
        """

        self.log.debug(
            "Enter - name: %s, numThreads: %s, waitTimeout: %s, " "maxTasks: %s",
            name,
            numThreads,
            waitTimeout,
            maxTasks,
        )
        self._name = name
        self._count = itertools.count()  # suffixes for unique thread names
        self.__threads = []
        self._taskThread = {}
        self.__resizeLock = threading.Condition(threading.Lock())
        self.__runningTasksLock = threading.Condition(threading.Lock())
        self.__tasks = Queue(maxTasks)
        self.__isJoining = False
        self.__runningTasks = 0
        self.__waitTimeout = waitTimeout
        self.setThreadCount(numThreads)

    def setRunningTask(self, addTask):

        """ Internal method to increase or decrease a counter of current
        executing tasks."""

        with self.__runningTasksLock:
            if addTask:
                self.__runningTasks += 1
            else:
                self.__runningTasks -= 1
            self.log.debug("Number of running tasks: %s", self.__runningTasks)

    def getRunningTasks(self):
        """Return the number of tasks currently marked as running."""
        return self.__runningTasks

    def setThreadCount(self, newNumThreads):

        """ External method to set the current pool size.  Acquires
        the resizing lock, then calls the internal version to do real
        work.  Returns False (and does nothing) while the pool is joining,
        True otherwise."""

        # Can't change the thread count if we're shutting down the pool!
        if self.__isJoining:
            return False

        with self.__resizeLock:
            self.__setThreadCountNolock(newNumThreads)
        return True

    def __setThreadCountNolock(self, newNumThreads):

        """Set the current pool size, spawning or terminating threads
        if necessary.  Internal use only; assumes the resizing lock is
        held."""

        # If we need to grow the pool, do so
        while newNumThreads > len(self.__threads):
            name = "%s/%d" % (self._name, next(self._count))
            newThread = WorkerThread(self, name)
            self.__threads.append(newThread)
            newThread.start()
        # If we need to shrink the pool, ask the oldest threads to go away
        while newNumThreads < len(self.__threads):
            self.__threads[0].goAway()
            del self.__threads[0]

    def getThreadCount(self):

        """Return the number of threads in the pool."""

        with self.__resizeLock:
            return len(self.__threads)

    def queueTask(self, id, task, args=None, taskCallback=None):

        """Insert a task into the queue.  task must be callable;
        args and taskCallback can be None.  Returns False if the pool is
        joining or task is not callable, True once the task is queued
        (blocking while the queue is full)."""

        if self.__isJoining:
            return False
        if not callable(task):
            return False

        self.__tasks.put((id, task, args, taskCallback))

        return True

    def getNextTask(self):

        """ Retrieve the next task from the task queue.  For use
        only by WorkerThread objects contained in the pool.

        Returns an (id, task, args, callback) tuple, or four Nones if no
        task arrived within waitTimeout seconds."""

        try:
            task_id, cmd, args, callback = self.__tasks.get(
                True, self.__waitTimeout)
        except Empty:
            return None, None, None, None

        return task_id, cmd, args, callback

    def stopThread(self):
        """Queue an all-None task tuple (the value getNextTask also
        returns on timeout)."""
        self.__tasks.put((None, None, None, None))

    def joinAll(self, waitForTasks=True, waitForThreads=True):

        """ Clear the task queue and terminate all pooled threads,
        optionally allowing the tasks and threads to finish."""

        # Mark the pool as joining to prevent any more task queuing
        self.__isJoining = True

        # Wait until the task queue has drained.  This only guarantees the
        # tasks were dequeued, not that they finished running.
        # NOTE(review): with no live workers this loop never ends.
        if waitForTasks:
            while not self.__tasks.empty():
                sleep(0.1)

        with self.__resizeLock:
            # Tell all the threads to quit, then wait until they have exited
            if waitForThreads:
                for t in self.__threads:
                    t.goAway()
                for t in self.__threads:
                    t.join()
            self.__setThreadCountNolock(0)

            # Reset the pool for potential reuse
            self.__isJoining = False
コード例 #37
0
ファイル: log.py プロジェクト: H4dr1en/trains
class TaskHandler(BufferingHandler):
    """Logging handler that batches records and ships them to the task backend.

    Records are buffered (``BufferingHandler``) and converted into
    ``events.TaskLogEvent`` objects on flush. Consecutive messages of the
    same level are merged when they fall within one second and their
    combined size stays under the max event size. Batches are sent from a
    background daemon thread. In offline mode they are instead appended as
    JSON lines to ``log.jsonl`` in the task's offline folder.
    """
    __flush_max_history_seconds = 30.  # flush once the oldest buffered record is this old
    __wait_for_flush_timeout = 10.  # close(wait=True) join timeout when requests are pending
    __max_event_size = 1024 * 1024  # longer messages are split across several events
    __once = False
    __offline_filename = 'log.jsonl'

    @property
    def task_id(self):
        return self._task_id

    @task_id.setter
    def task_id(self, value):
        self._task_id = value

    def __init__(self, task, capacity=buffer_capacity, connect_logger=True):
        super(TaskHandler, self).__init__(capacity)
        self.task_id = task.id
        self.session = task.session
        self.last_timestamp = 0
        self.counter = 1
        self._last_event = None
        self._exit_event = None
        self._queue = None
        self._thread = None
        self._pending = 0
        self._offline_log_filename = None
        self._connect_logger = connect_logger
        if task.is_offline():
            offline_folder = Path(task.get_offline_mode_folder())
            offline_folder.mkdir(parents=True, exist_ok=True)
            self._offline_log_filename = offline_folder / self.__offline_filename

    def shouldFlush(self, record):
        """
        Should the handler flush its buffer

        Returns true if the buffer is up to capacity, or if the oldest
        buffered record is older than __flush_max_history_seconds. Always
        returns False once the handler has no task id (e.g. after close()).
        """
        if self._task_id is None:
            return False

        # if we need to add handlers to the base_logger,
        # it will not automatically create stream one when first used, so we must manually configure it.
        if self._connect_logger and not TaskHandler.__once:
            base_logger = getLogger()
            if len(base_logger.handlers) == 1 and isinstance(
                    base_logger.handlers[0], TaskHandler):
                if record.name != 'console' and not record.name.startswith(
                        'clearml.'):
                    base_logger.removeHandler(self)
                    basicConfig()
                    base_logger.addHandler(self)
                    TaskHandler.__once = True
            else:
                TaskHandler.__once = True

        # if we passed the max buffer
        if len(self.buffer) >= self.capacity:
            return True

        # if the first entry in the log was too long ago.
        # noinspection PyBroadException
        try:
            if len(self.buffer) and (time.time() - self.buffer[0].created
                                     ) > self.__flush_max_history_seconds:
                return True
        except Exception:
            pass

        return False

    def _record_to_event(self, record):
        # type: (LogRecord) -> events.TaskLogEvent
        """Convert one record into the list of events that are now complete.

        The record's message may be merged into the pending ``_last_event``;
        that pending event is kept back and only returned once a later event
        replaces it. Returns None when the handler has no task id.
        """
        if self._task_id is None:
            return None
        # Millisecond timestamps; records sharing a millisecond get a
        # running offset so their timestamps stay distinct.
        timestamp = int(record.created * 1000)
        if timestamp == self.last_timestamp:
            timestamp += self.counter
            self.counter += 1
        else:
            self.last_timestamp = timestamp
            self.counter = 1

        # ignore backspaces (they are often used)
        full_msg = record.getMessage().replace('\x08', '')

        return_events = []
        while full_msg:
            msg = full_msg[:self.__max_event_size]
            full_msg = full_msg[self.__max_event_size:]
            # unite all records in a single second
            if self._last_event and timestamp - self._last_event.timestamp < 1000 and \
                    len(self._last_event.msg) + len(msg) < self.__max_event_size and \
                    record.levelname.lower() == str(self._last_event.level):
                self._last_event.msg += '\n' + msg
                continue

            # start a new event; the previous one (if any) is complete, return it.
            new_event = events.TaskLogEvent(task=self.task_id,
                                            timestamp=timestamp,
                                            level=record.levelname.lower(),
                                            worker=self.session.worker,
                                            msg=msg)
            if self._last_event:
                return_events.append(self._last_event)

            self._last_event = new_event

        return return_events

    def flush(self):
        """Turn buffered records into one batch request and queue it for sending."""
        if self._task_id is None:
            return

        if not self.buffer:
            return

        buffer = None
        self.acquire()
        if self.buffer:
            buffer = self.buffer
            self.buffer = []
        self.release()

        if not buffer:
            return

        # noinspection PyBroadException
        try:
            # Include the pending (possibly merged) last event as well.
            record_events = [
                r for record in buffer for r in self._record_to_event(record)
            ] + [self._last_event]
            self._last_event = None
            batch_requests = events.AddBatchRequest(
                requests=[events.AddRequest(e) for e in record_events if e])
        except Exception:
            self.__log_stderr(
                "WARNING: clearml.log - Failed logging task to backend ({:d} lines)"
                .format(len(buffer)))
            batch_requests = None

        if batch_requests and batch_requests.requests:
            self._pending += 1
            self._add_to_queue(batch_requests)

    def _create_thread_queue(self):
        """Lazily create the request queue and start the daemon sender thread."""
        if self._queue:
            return

        self._queue = Queue()
        self._exit_event = Event()
        self._exit_event.clear()
        # multiple workers could be supported as well
        self._thread = Thread(target=self._daemon)
        self._thread.daemon = True
        self._thread.start()

    def _add_to_queue(self, request):
        self._create_thread_queue()
        self._queue.put(request)

    def close(self, wait=False):
        """Flush, stop the sender thread and close the handler.

        With ``wait``, the sender thread is joined for 1 s if the queue is
        empty, otherwise for __wait_for_flush_timeout seconds. A message is
        logged if requests remain after that.
        """
        # self.__log_stderr('Closing {} wait={}'.format(os.getpid(), wait))
        # flush pending logs
        if not self._task_id:
            return
        # avoid deadlocks just skip the lock, we are shutting down anyway
        self.lock = None

        self.flush()
        # shut down the TaskHandler, from this point onwards. No events will be logged
        _thread = self._thread
        self._thread = None
        if self._queue:
            self._exit_event.set()
            self._queue.put(None)
        self._task_id = None

        if wait and _thread:
            # noinspection PyBroadException
            try:
                timeout = 1. if self._queue.empty(
                ) else self.__wait_for_flush_timeout
                _thread.join(timeout=timeout)
                if not self._queue.empty():
                    self.__log_stderr(
                        'Flush timeout {}s exceeded, dropping last {} lines'.
                        format(timeout, self._queue.qsize()))
                # self.__log_stderr('Closing {} wait done'.format(os.getpid()))
            except Exception:
                pass
        # call super and remove the handler
        super(TaskHandler, self).close()

    def _send_events(self, a_request):
        """Send one batch request, or append it to the offline log file.

        On an unexpected error the request is put back on the queue so the
        sender retries it.
        """
        try:
            self._pending -= 1

            if self._offline_log_filename:
                with open(self._offline_log_filename.as_posix(), 'at') as f:
                    f.write(
                        json.dumps([b.to_dict()
                                    for b in a_request.requests]) + '\n')
                return

            # if self._thread is None:
            #     self.__log_stderr('Task.close() flushing remaining logs ({})'.format(self._pending))
            res = self.session.send(a_request)
            if res and not res.ok():
                self.__log_stderr(
                    "failed logging task to backend ({:d} lines, {})".format(
                        len(a_request.requests), str(res.meta)),
                    level=WARNING)
        except MaxRequestSizeError:
            self.__log_stderr(
                "failed logging task to backend ({:d} lines) log size exceeded limit"
                .format(len(a_request.requests)),
                level=WARNING)
        except Exception as ex:
            self.__log_stderr(
                "Retrying, failed logging task to backend ({:d} lines): {}".
                format(len(a_request.requests), ex))
            # we should push ourselves back into the thread pool
            if self._queue:
                self._pending += 1
                self._queue.put(a_request)

    def _daemon(self):
        """Sender loop: block on the queue until the exit event is set, then
        drain whatever is left without blocking."""
        # multiple daemons are supported
        leave = self._exit_event.wait(0)
        request = True
        while not leave or request:
            # pull from queue
            request = None
            if self._queue:
                # noinspection PyBroadException
                try:
                    request = self._queue.get(block=not leave)
                except Exception:
                    pass
            if request:
                self._send_events(request)
            leave = self._exit_event.wait(0)
        # self.__log_stderr('leaving {}'.format(os.getpid()))

    @staticmethod
    def __log_stderr(msg, level=INFO):
        # output directly to stderr, make sure we do not catch it.
        write = sys.stderr._original_write if hasattr(
            sys.stderr, '_original_write') else sys.stderr.write
        write('{asctime} - {name} - {levelname} - {message}\n'.format(
            asctime=Formatter().formatTime(makeLogRecord({})),
            name='clearml.log',
            levelname=getLevelName(level),
            message=msg))

    @classmethod
    def report_offline_session(cls, task, folder):
        """Replay an offline-mode log file to the backend for ``task``.

        Each line of ``<folder>/log.jsonl`` holds a JSON list of event dicts,
        in the format written by ``_send_events``. The ``task`` key is dropped
        from every event, and each line is sent as one batch for ``task.id``.
        Lines that cannot be read or parsed are reported with ``warning`` and
        skipped.

        :return: False if the offline log file does not exist, True otherwise.
        """
        filename = Path(folder) / cls.__offline_filename
        if not filename.is_file():
            return False
        with open(filename.as_posix(), 'rt') as f:
            i = 0  # 0-based index of the line being processed
            while True:
                # noinspection PyBroadException
                try:
                    line = f.readline()
                    if not line:
                        break
                    list_requests = json.loads(line)
                    for r in list_requests:
                        r.pop('task', None)
                    i += 1
                except StopIteration:
                    break
                except Exception as ex:
                    warning('Failed reporting log, line {} [{}]'.format(i, ex))
                    i += 1
                    # Skip the bad line. Falling through would resend the
                    # previous line's events (or fail on an unbound name
                    # if the first line is bad).
                    continue
                batch_requests = events.AddBatchRequest(requests=[
                    events.TaskLogEvent(task=task.id, **r)
                    for r in list_requests
                ])
                if batch_requests.requests:
                    res = task.session.send(batch_requests)
                    if res and not res.ok():
                        warning(
                            "failed logging task to backend ({:d} lines, {})".
                            format(len(batch_requests.requests),
                                   str(res.meta)))
        return True
コード例 #38
0
ファイル: server.py プロジェクト: psav/riggerlib
class Rigger(object):
    """ A Rigger event framework instance.

    The Rigger object holds all configuration and instances of plugins. By default Rigger accepts
    a configuration file name to parse, though it is perfectly acceptable to pass the configuration
    into the ``self.config`` attribute.

    Constructing an instance starts two worker threads: ``global_queue_processor`` (running
    ``process_queue``) and ``background_queue_processor`` (running ``process_background_queue``).

    Args:
        config_file: A configuration file holding all of Rigger's base and plugin configuration.
    """
    def __init__(self, config_file):
        self.gdl = threading.Lock()
        self.pre_callbacks = defaultdict(dict)
        self.post_callbacks = defaultdict(dict)
        self.plugins = {}
        self.config_file = config_file
        self.squash_exceptions = False
        self.initialized = False
        self._task_list = {}
        self._queue_lock = threading.Lock()
        self._global_queue = Queue()
        self._background_queue = Queue()
        self._server_shutdown = False
        self._zmq_event_handler_shutdown = False
        self._global_queue_shutdown = False
        self._background_queue_shutdown = False

        globt = threading.Thread(target=self.process_queue, name="global_queue_processor")
        globt.start()
        bgt = threading.Thread(
            target=self.process_background_queue, name="background_queue_processor")
        bgt.start()

    def process_queue(self):
        """
        The ``process_queue`` thread manages taking events on and off of the global queue.
        Both zmq (TCP) and in-object fire_hooks place task ids onto the global_queue and these
        are both handled by the same handler called ``process_hook``. If there is an exception
        during processing, the exception is logged and execution continues.

        A task that has been removed from ``_task_list`` (for example by a zmq ``task_delete``
        request) before or while it is processed is tolerated rather than crashing this thread.
        """
        while not self._global_queue_shutdown:
            while not self._global_queue.empty():
                with self._queue_lock:
                    tid = self._global_queue.get()
                    # The task may already have been deleted by the zmq handler thread.
                    task = self._task_list.get(tid)
                    if task is not None:
                        task.status = Task.RUNNING
                if task is not None:
                    obj = task.json_dict
                    try:
                        loc, glo = self.process_hook(obj['hook_name'], **obj['data'])
                        combined_dict = {}
                        combined_dict.update(glo)
                        combined_dict.update(loc)
                        task.output = combined_dict
                    except Exception as e:
                        self.log_message(e)
                with self._queue_lock:
                    # Always balance the get() above so that Queue.join() in stop_server
                    # cannot hang.
                    self._global_queue.task_done()
                    if task is not None:
                        task.status = Task.FINISHED
                if task is not None and not task.json_dict.get('grab_result', None):
                    # pop() rather than del: the zmq thread may have removed it already.
                    self._task_list.pop(tid, None)
            time.sleep(0.1)

    def process_background_queue(self):
        """
        The ``process_background_queue`` manages the hooks which have been backgrounded. In this
        respect the tasks that are completed are not required to continue with the test and as such
        can be forgotten about. An example of this would be something that sends an email, or tars
        up files; it has all the information it needs and the main process doesn't need to wait for
        it to complete.
        """
        while not self._background_queue_shutdown:
            while not self._background_queue.empty():
                obj = self._background_queue.get()
                try:
                    local, globals_updates = self.process_callbacks(obj['cb'], obj['kwargs'])
                    with self.gdl:
                        self.global_data = recursive_update(self.global_data, globals_updates)
                except Exception as e:
                    self.log_message(e)
                self._background_queue.task_done()
            time.sleep(0.1)

    def zmq_event_handler(self, zmq_socket_address):
        """
        The ``zmq_event_handler`` thread receives (and responds to) updates from the
        zmq socket, which is normally embedded in the web server running alongside this
        riggerlib instance, in its own process.

        The socket is a REP socket, so exactly one reply is sent for every request received.
        Supported ``event_name`` values are ``fire_hook``, ``task_check``, ``task_delete``,
        ``shutdown`` and ``ping``; anything else (or a request without ``event_name``) gets a
        ``BAD REQUEST`` reply.

        Args:
            zmq_socket_address: The address to bind the REP socket to, e.g. ``tcp://host:port``.
        """
        ctx = zmq.Context()
        zmq_socket = ctx.socket(zmq.REP)
        zmq_socket.set(zmq.RCVTIMEO, 300)
        zmq_socket.bind(zmq_socket_address)

        def zmq_reply(message, **extra):
            payload = {'message': message}
            payload.update(extra)
            zmq_socket.send_json(payload)
        bad_request = partial(zmq_reply, 'BAD REQUEST')

        try:
            while not self._zmq_event_handler_shutdown:
                try:
                    json_dict = zmq_socket.recv_json()
                except zmq.Again:
                    # Receive timeout: loop around to re-check the shutdown flag.
                    continue

                try:
                    event_name = json_dict['event_name']
                except (KeyError, TypeError):
                    # The single reply for this request has been sent; do not fall through
                    # with an undefined/stale event_name (a second send would fail on REP).
                    bad_request()
                    continue

                if event_name == 'fire_hook':
                    tid = self._fire_internal_hook(json_dict)
                    if tid:
                        zmq_reply('OK', tid=tid)
                    else:
                        bad_request()
                elif event_name == 'task_check':
                    try:
                        tid = json_dict['tid']
                        task = self._task_list[tid]
                    except (KeyError, TypeError):
                        zmq_reply('NOT FOUND')
                        continue
                    extra = {
                        "tid": tid,
                        "status": task.status,
                    }
                    if json_dict.get('grab_result'):
                        extra["output"] = task.output
                    zmq_reply('OK', **extra)
                elif event_name == 'task_delete':
                    tid = json_dict.get('tid')
                    try:
                        del self._task_list[tid]
                    except (KeyError, TypeError):
                        # Deleting an unknown task is reported as success, as before.
                        pass
                    zmq_reply('OK', tid=tid)
                elif event_name == 'shutdown':
                    zmq_reply('OK')
                    # We gotta initiate server stop from here and stop this thread
                    self._server_shutdown = True
                    break
                elif event_name == 'ping':
                    zmq_reply('PONG')
                else:
                    bad_request()
        finally:
            zmq_socket.close()

    def read_config(self, config_file):
        """
        Reads in the config file and parses the yaml data into ``self.config``.

        Args:
            config_file: A configuration file holding all of Rigger's base and plugin configuration.

        Exits the process with status 127 (after printing a message) if the file cannot be
        opened or parsed.
        """
        try:
            with open(config_file, "r") as stream:
                # safe_load: plain yaml.load() without a Loader is an error on PyYAML >= 6,
                # and the configuration needs only plain YAML types.
                data = yaml.safe_load(stream)
        except IOError:
            print("!!! Configuration file could not be loaded...exiting")
            sys.exit(127)
        except Exception as e:
            print(e)
            print("!!! Error parsing Configuration file")
            sys.exit(127)
        self.config = data

    def parse_config(self):
        """
        Reads ``self.config_file``, sets up the plugin instances and starts the server
        (if enabled in the configuration).
        """
        self.read_config(self.config_file)
        self.setup_plugin_instances()
        self.start_server()

    def setup_plugin_instances(self):
        """
        Sets up instances into a dict called ``self.instances`` and instantiates each
        instance of the plugin. It also sets the ``self._threaded`` option to determine
        if plugins will be processed synchronously or asynchronously.
        """
        self.instances = {}
        self._threaded = self.config.get("threaded", False)
        plugins = self.config.get("plugins", {})
        for ident, config in plugins.items():
            self.setup_instance(ident, config)

    def setup_instance(self, ident, config):
        """
        Sets up a single instance into the ``self.instances`` dict. If the plugin does
        not exist (or the config names no plugin), a message is logged.

        Args:
            ident: A plugin instance identifier.
            config: Configuration dict from the yaml.
        """
        # Default to None: a dict default is unhashable and ``in self.plugins`` would raise.
        plugin_name = config.get('plugin')
        if plugin_name in self.plugins:
            obj = self.plugins[plugin_name]
            if obj:
                obj_instance = obj(ident, config, self)
                self.instances[ident] = RiggerPluginInstance(ident, obj_instance, config)
        else:
            msg = "Plugin [{}] was not found, "\
                  "disabling instance [{}]".format(plugin_name, ident)
            self.log_message(msg)

    def start_server(self):
        """
        Starts the ZMQ server if the ``server_enabled`` is True in the config.
        """
        self._server_hostname = self.config.get('server_address', '127.0.0.1')
        self._server_port = self.config.get('server_port', 21212)
        self._server_enable = self.config.get('server_enabled', False)
        if self._server_enable:
            zmq_socket_address = 'tcp://{}:{}'.format(self._server_hostname, self._server_port)
            # set up receiver thread for zmq event handling
            zeh = threading.Thread(
                target=self.zmq_event_handler, args=(zmq_socket_address,), name="zmq_event_handler")
            zeh.start()
            exect = threading.Thread(target=self.await_shutdown, name="executioner")
            exect.start()

    def await_shutdown(self):
        """Poll ``_server_shutdown`` every 0.3 s and call ``stop_server`` once it is set."""
        while not self._server_shutdown:
            time.sleep(0.3)
        self.stop_server()

    def stop_server(self):
        """
        Responsible for the following:
            - stopping the zmq event handler (unless already stopped through 'terminate')
            - stopping the global queue
            - stopping the background queue

        Finishes by raising ``SystemExit`` in the calling thread.
        """
        self.log_message("Shutdown initiated : {}".format(self._server_hostname))
        # The order here is important
        self._zmq_event_handler_shutdown = True
        self._global_queue.join()
        self._global_queue_shutdown = True
        self._background_queue.join()
        self._background_queue_shutdown = True
        raise SystemExit

    def fire_hook(self, hook_name, **kwargs):
        """
        Parses the hook information into a dict for passing to process_hook. This is used
        to enable both the TCP and in-object fire_hook methods to use the same process_hook
        method call.

        Args:
            hook_name: The name of the hook to fire.
            kwargs: The kwargs to pass to the hooks.

        """
        json_dict = {'hook_name': hook_name, 'data': kwargs}
        self._fire_internal_hook(json_dict)

    def _fire_internal_hook(self, json_dict):
        """
        Wraps ``json_dict`` in a ``Task``, records it in ``_task_list`` and queues its id.

        Returns:
            The task id (``task.tid.hexdigest()``), or None if the global queue is falsy.
        """
        task = Task(json_dict)
        tid = task.tid.hexdigest()
        self._task_list[tid] = task
        # NOTE(review): a Queue object is normally always truthy, so the None branch looks
        # unreachable -- confirm before relying on it.
        if self._global_queue:
            with self._queue_lock:
                self._global_queue.put(tid)
            return tid
        else:
            return None

    def process_hook(self, hook_name, **kwargs):
        """
        Takes a hook_name and a selection of kwargs and fires off the appropriate callbacks.

        This function is the guts of Rigger and is responsible for running the callback and
        hook functions. It first loads some blank dicts to collect the updates for the local
        and global namespaces. After this, it loads the pre_callback functions along with
        the kwargs into the callback collector processor.

        The return values are then classified into local and global dicts and updates proceed.
        After this, the plugin hooks themselves are then run using the same methodology. Their
        return values are merged with the existing dicts and then the same process happens
        for the post_callbacks.

        Note: If the instance of the plugin has been marked as a background instance, any hooks
              which are called in that instance will be backgrounded. The hook will also not
              be able to return any data to the post-hook callback, although updates to globals
              will be processed as and when the backgrounded task is completed.

        Args:
            hook_name: The name of the hook to fire.
            kwargs: The kwargs to pass to the hooks.

        Returns:
            A ``(kwargs, global_data)`` tuple, or None if ``self.initialized`` is false.
        """
        if not self.initialized:
            return
        kwargs_updates = {}
        globals_updates = {}
        kwargs.update({'config': self.config})

        # First fire off any pre-hook callbacks
        if self.pre_callbacks.get(hook_name):
            kwargs_updates, globals_updates = self.process_callbacks(
                self.pre_callbacks[hook_name].values(), kwargs)

            # Now we can update the kwargs passed to the real hook with the updates
            with self.gdl:
                self.global_data = recursive_update(self.global_data, globals_updates)
            kwargs = recursive_update(kwargs, kwargs_updates)

        # Now fire off each plugin hook
        event_hooks = []
        for instance_name, instance in self.instances.items():
            callbacks = instance.obj.callbacks
            enabled = instance.data.get('enabled', None)
            if callbacks.get(hook_name) and enabled:
                cb = callbacks[hook_name]
                # Backgrounded either per instance (config) or per hook (cb['bg']).
                if instance.data.get('background', False) or cb['bg']:
                    self._background_queue.put({'cb': [cb], 'kwargs': kwargs})
                else:
                    event_hooks.append(cb)
        kwargs_updates, globals_updates = self.process_callbacks(event_hooks, kwargs)

        # One more update for the post_hook callback
        with self.gdl:
            self.global_data = recursive_update(self.global_data, globals_updates)
        kwargs = recursive_update(kwargs, kwargs_updates)

        # Finally any post-hook callbacks
        if self.post_callbacks.get(hook_name):
            kwargs_updates, globals_updates = self.process_callbacks(
                self.post_callbacks[hook_name].values(), kwargs)
            with self.gdl:
                self.global_data = recursive_update(self.global_data, globals_updates)
            kwargs = recursive_update(kwargs, kwargs_updates)
        return kwargs, self.global_data

    def process_callbacks(self, callback_collection, kwargs):
        """
        Processes a collection of callbacks or hooks for a particular event, namely pre, hook or
        post.

        The functions are passed in as an array to ``callback_collection`` and process callbacks
        first iterates each function and ensures that each one has the correct arguments available
        to it. If not, an Exception is raised. Then, depending on whether Threading is enabled or
        not, the functions are either run sequentially, or loaded into a ThreadPool and executed
        asynchronously.

        A parameter counts as required when it has no default and is not a ``*args`` /
        ``**kwargs`` catch-all; it must be present in ``self.global_data`` or ``kwargs``.

        The returned local and global updates are either collected and processed sequentially, as
        in the case of the non-threaded behaviour, or collected at the end of the
        callback_collection processing and handled there.

        Note:
            It is impossible to predict the order of the functions being run. If the order is
            important, it is advised to create a second event hook that will be fired before the
            other. Rigger has no concept of hook or callback order and is unlikely to ever have.

        Args:
            callback_collection: A list of callback dicts (as built by ``create_callback``).
            kwargs: A set of kwargs to pass to the functions.

        Returns: A tuple of local and global namespace updates.

        Raises:
            Exception: If a callback is missing required arguments.
        """
        loc_collect = {}
        glo_collect = {}
        results_list = []
        pool = ThreadPool(10) if self._threaded else None
        try:
            for cb in callback_collection:
                params = cb['args']
                required_args = [
                    name for name, param in params.items()
                    if param.default is param.empty
                    and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)]
                missing = list(set(required_args).difference(set(self.global_data.keys()))
                               .difference(set(kwargs.keys())))
                if missing:
                    raise Exception('Function {} is missing kwargs {}'
                                    .format(cb['func'].__name__, missing))
                new_kwargs = self.build_kwargs(params, kwargs)
                if pool is not None:
                    results_list.append(pool.apply_async(cb['func'], [], new_kwargs))
                else:
                    obtain_result = self.handle_results(cb['func'], [], new_kwargs)
                    loc_collect, glo_collect = self.handle_collects(
                        obtain_result, loc_collect, glo_collect)
        finally:
            # Release the pool's worker threads even if a callback was rejected above.
            if pool is not None:
                pool.close()
                pool.join()

        for result in results_list:
            obtain_result = self.handle_results(result.get, [], {})
            loc_collect, glo_collect = self.handle_collects(
                obtain_result, loc_collect, glo_collect)
        return loc_collect, glo_collect

    def handle_results(self, call, args, kwargs):
        """
        Handles results and depending on configuration, squashes exceptions and logs or
        returns the obtained result.

        Args:
            call: The function call.
            args: The positional arguments.
            kwargs: The keyword arguments.

        Returns: The obtained result of the callback or hook (None if an exception was
            squashed).
        """
        try:
            obtain_result = call(*args, **kwargs)
        except Exception:
            # Only ordinary exceptions are squashed; SystemExit/KeyboardInterrupt propagate.
            if self.squash_exceptions:
                obtain_result = None
                self.handle_failure(sys.exc_info())
            else:
                raise

        return obtain_result

    def handle_collects(self, result, loc_collect, glo_collect):
        """
        Handles extracting the information from the hook/callback result.

        If the hook/callback returns None, then the dicts are returned unaltered, else
        they are updated with local, global namespace updates.

        Args:
            result: The result to process.
            loc_collect: The local namespace updates collection.
            glo_collect: The global namespace updates collection.
        Returns: A tuple containing the local and global updates.
        """
        if result:
            if result[0]:
                loc_collect = recursive_update(loc_collect, result[0])
            if result[1]:
                glo_collect = recursive_update(glo_collect, result[1])
        return loc_collect, glo_collect

    def build_kwargs(self, args, kwargs):
        """
        Builds a new kwargs from a list of allowed args.

        Functions only receive a single set of kwargs, and so the global and local namespaces
        have to be collapsed. In this way, the local overrides the global namespace, hence if
        a key exists in both local and global, the local value will be passed to the function
        under the key name and the global value will be forgotten.

        The args parameter ensures that only the expected arguments are supplied.

        Args:
            args: A list of allowed argument names
            kwargs: A dict of kwargs from the local namespace.
        Returns: A consolidated global/local namespace with local overrides.
        """
        returned_args = {}
        returned_args.update({
            name: self.global_data[name] for name in args
            if name in self.global_data})
        returned_args.update({
            name: kwargs[name] for name in args
            if name in kwargs})
        return returned_args

    def register_hook_callback(self, hook_name=None, ctype="pre", callback=None, name=None):
        """
        Registers pre and post callbacks.

        Takes a callback function and assigns it to the hook_name with an optional identifier.
        The optional identifier makes it possible to hot bind functions into hooks and to
        remove them at a later date with ``unregister_hook_callback``.

        Args:
            hook_name: The name of the event hook to respond to.
            ctype: The call back type, either ``pre`` or ``post``.
            callback: The callback function.
            name: An optional name for the callback instance binding. When omitted, a SHA-1
                hex digest of the current time, hook name and parameter list is used.
        """
        if hook_name and callback:
            callback_instance = self.create_callback(callback)
            if not name:
                seed = str(time.time()) + hook_name + str(callback_instance['args'])
                # hashlib needs bytes on Python 3.
                name = hashlib.sha1(seed.encode('utf-8')).hexdigest()
            if ctype == "pre":
                self.pre_callbacks[hook_name][name] = callback_instance
            elif ctype == "post":
                self.post_callbacks[hook_name][name] = callback_instance

    def unregister_hook_callback(self, hook_name, ctype, name):
        """
        Unregisters a pre or post callback.

        If the binding has a known name, this function allows the removal of a binding.

        Args:
            hook_name: The event hook name.
            ctype: The callback type, either ``pre`` or ``post``.
            name: The name of the callback instance binding.

        Raises:
            KeyError: If no binding with that name exists.
        """
        if ctype == "pre":
            del self.pre_callbacks[hook_name][name]
        elif ctype == "post":
            del self.post_callbacks[hook_name][name]

    def register_plugin(self, cls, plugin_name=None):
        """ Registers a plugin class to a name.

        Multiple instances of the same plugin can be used in Rigger, ``self.plugins``
        stores un-initialized class definitions to be used by ``setup_instance``.

        Args:
            cls: The class.
            plugin_name: The name of the plugin.
        """
        if plugin_name in self.plugins:
            print("Plugin name already taken [{}]".format(plugin_name))
        elif plugin_name is None:
            print("Plugin name cannot be None")
        else:
            self.plugins[plugin_name] = cls

    def get_instance_obj(self, name):
        """
        Gets the instance object for a given ident name.

        Args:
            name: The ident name of the instance.
        Returns: The object of the instance, or None if unknown.
        """
        if name in self.instances:
            return self.instances[name].obj
        else:
            return None

    def get_instance_data(self, name):
        """
        Gets the instance data(config) for a given ident name.

        Args:
            name: The ident name of the instance.
        Returns: The data(config) of the instance, or None if unknown.
        """
        if name in self.instances:
            return self.instances[name].data
        else:
            return None

    def configure_plugin(self, name, *args, **kwargs):
        """
        Attempts to configure an instance, passing it the args and kwargs.

        Args:
            name: The ident name of the instance.
            args: The positional args.
            kwargs: The keyword arguments.
        """
        obj = self.get_instance_obj(name)
        if obj:
            obj.configure(*args, **kwargs)

    @staticmethod
    def create_callback(callback, bg=False):
        """
        Simple function to inspect a function and return it along with its parameters wrapped
        up in a nice dict. This forms a callback object.

        Args:
            callback: The callback function.
            bg: Whether the callback should be run on the background queue.
        Returns: A dict with ``func``, ``args`` (``inspect.signature(...).parameters``) and ``bg``.
        """
        params = signature(callback).parameters
        return {
            'func': callback,
            'args': params,
            'bg': bg
        }

    def handle_failure(self, exc):
        """
        Handles an exception. It is expected that this be overridden.
        """
        self.log_message(exc)

    def log_message(self, message):
        """
        "Logs" a message. It is expected that this be overridden.
        """
        print(message)
コード例 #39
0
ファイル: nikon.py プロジェクト: komodo108/sequoia-ptpy
class Nikon(object):
    '''This class implements Nikon's PTP operations.'''
    def __init__(self, *args, **kwargs):
        '''
        Initialise Nikon-specific state after delegating to the base initialiser.

        All positional and keyword arguments are forwarded unchanged.
        '''
        logger.debug('Init Nikon')
        super(Nikon, self).__init__(*args, **kwargs)
        # TODO: expose the choice to poll or not Nikon events
        self.__no_polling = False
        # Stop flag set by _nikon_shutdown, plus the polling thread (None until a session starts).
        self.__nikon_event_shutdown, self.__nikon_event_proc = Event(), None

    @contextmanager
    def session(self):
        '''
        Manage Nikon session with context manager.
        '''
        # When raw device, do not perform
        if self.__no_polling:
            with super(Nikon, self).session():
                yield
            return
        # Within a normal PTP session
        with super(Nikon, self).session():
            # launch a polling thread
            self.__event_queue = Queue()
            self.__nikon_event_proc = Thread(name='NikonEvtPolling',
                                             target=self.__nikon_poll_events)
            self.__nikon_event_proc.daemon = False
            atexit.register(self._nikon_shutdown)
            self.__nikon_event_proc.start()

            try:
                yield
            finally:
                self._nikon_shutdown()

    def _shutdown(self):
        '''Stop Nikon event polling first, then run the parent class's shutdown.'''
        self._nikon_shutdown()
        parent = super(Nikon, self)
        parent._shutdown()

    def _nikon_shutdown(self):
        '''
        Set the Nikon polling stop event and wait up to 2 s for the polling
        thread, if one was started and is still running.
        '''
        logger.debug('Shutdown Nikon events')
        self.__nikon_event_shutdown.set()

        poller = self.__nikon_event_proc
        # Joining a thread that was never started (or already finished) is pointless.
        if poller is None or not poller.is_alive():
            return
        poller.join(2)

    def _PropertyCode(self, **product_properties):
        props = {
            'ShootingBank': 0xD010,
            'ShootingBankNameA': 0xD011,
            'ShootingBankNameB': 0xD012,
            'ShootingBankNameC': 0xD013,
            'ShootingBankNameD': 0xD014,
            'ResetBank0': 0xD015,
            'RawCompression': 0xD016,
            'WhiteBalanceAutoBias': 0xD017,
            'WhiteBalanceTungstenBias': 0xD018,
            'WhiteBalanceFluorescentBias': 0xD019,
            'WhiteBalanceDaylightBias': 0xD01A,
            'WhiteBalanceFlashBias': 0xD01B,
            'WhiteBalanceCloudyBias': 0xD01C,
            'WhiteBalanceShadeBias': 0xD01D,
            'WhiteBalanceColorTemperature': 0xD01E,
            'WhiteBalancePresetNo': 0xD01F,
            'WhiteBalancePresetName0': 0xD020,
            'WhiteBalancePresetName1': 0xD021,
            'WhiteBalancePresetName2': 0xD022,
            'WhiteBalancePresetName3': 0xD023,
            'WhiteBalancePresetName4': 0xD024,
            'WhiteBalancePresetVal0': 0xD025,
            'WhiteBalancePresetVal1': 0xD026,
            'WhiteBalancePresetVal2': 0xD027,
            'WhiteBalancePresetVal3': 0xD028,
            'WhiteBalancePresetVal4': 0xD029,
            'ImageSharpening': 0xD02A,
            'ToneCompensation': 0xD02B,
            'ColorModel': 0xD02C,
            'HueAdjustment': 0xD02D,
            'NonCPULensDataFocalLength': 0xD02E,
            'NonCPULensDataMaximumAperture': 0xD02F,
            'ShootingMode': 0xD030,
            'JPEGCompressionPolicy': 0xD031,
            'ColorSpace': 0xD032,
            'AutoDXCrop': 0xD033,
            'FlickerReduction': 0xD034,
            'RemoteMode': 0xD035,
            'VideoMode': 0xD036,
            'NikonEffectMode': 0xD037,
            'Mode': 0xD038,
            'CSMMenuBankSelect': 0xD040,
            'MenuBankNameA': 0xD041,
            'MenuBankNameB': 0xD042,
            'MenuBankNameC': 0xD043,
            'MenuBankNameD': 0xD044,
            'ResetBank': 0xD045,
            'A1AFCModePriority': 0xD048,
            'A2AFSModePriority': 0xD049,
            'A3GroupDynamicAF': 0xD04A,
            'A4AFActivation': 0xD04B,
            'FocusAreaIllumManualFocus': 0xD04C,
            'FocusAreaIllumContinuous': 0xD04D,
            'FocusAreaIllumWhenSelected': 0xD04E,
            'FocusAreaWrap': 0xD04F,
            'VerticalAFON': 0xD050,
            'AFLockOn': 0xD051,
            'FocusAreaZone': 0xD052,
            'EnableCopyright': 0xD053,
            'ISOAuto': 0xD054,
            'EVISOStep': 0xD055,
            'EVStep': 0xD056,
            'EVStepExposureComp': 0xD057,
            'ExposureCompensation': 0xD058,
            'CenterWeightArea': 0xD059,
            'ExposureBaseMatrix': 0xD05A,
            'ExposureBaseCenter': 0xD05B,
            'ExposureBaseSpot': 0xD05C,
            'LiveViewAFArea': 0xD05D,
            'AELockMode': 0xD05E,
            'AELAFLMode': 0xD05F,
            'LiveViewAFFocus': 0xD061,
            'MeterOff': 0xD062,
            'SelfTimer': 0xD063,
            'MonitorOff': 0xD064,
            'ImgConfTime': 0xD065,
            'AutoOffTimers': 0xD066,
            'AngleLevel': 0xD067,
            'D1ShootingSpeed': 0xD068,
            'D2MaximumShots': 0xD069,
            'ExposureDelayMode': 0xD06A,
            'LongExposureNoiseReduction': 0xD06B,
            'FileNumberSequence': 0xD06C,
            'ControlPanelFinderRearControl': 0xD06D,
            'ControlPanelFinderViewfinder': 0xD06E,
            'D7Illumination': 0xD06F,
            'NrHighISO': 0xD070,
            'SHSetCHGUIDDisp': 0xD071,
            'ArtistName': 0xD072,
            'NikonCopyrightInfo': 0xD073,
            'FlashSyncSpeed': 0xD074,
            'FlashShutterSpeed': 0xD075,
            'E3AAFlashMode': 0xD076,
            'E4ModelingFlash': 0xD077,
            'BracketSet': 0xD078,
            'E6ManualModeBracketing': 0xD079,
            'BracketOrder': 0xD07A,
            'E8AutoBracketSelection': 0xD07B,
            'BracketingSet': 0xD07C,
            'F1CenterButtonShootingMode': 0xD080,
            'CenterButtonPlaybackMode': 0xD081,
            'F2Multiselector': 0xD082,
            'F3PhotoInfoPlayback': 0xD083,
            'F4AssignFuncButton': 0xD084,
            'F5CustomizeCommDials': 0xD085,
            'ReverseCommandDial': 0xD086,
            'ApertureSetting': 0xD087,
            'MenusAndPlayback': 0xD088,
            'F6ButtonsAndDials': 0xD089,
            'NoCFCard': 0xD08A,
            'CenterButtonZoomRatio': 0xD08B,
            'FunctionButton2': 0xD08C,
            'AFAreaPoint': 0xD08D,
            'NormalAFOn': 0xD08E,
            'CleanImageSensor': 0xD08F,
            'ImageCommentString': 0xD090,
            'ImageCommentEnable': 0xD091,
            'ImageRotation': 0xD092,
            'ManualSetLensNo': 0xD093,
            'MovScreenSize': 0xD0A0,
            'MovVoice': 0xD0A1,
            'MovMicrophone': 0xD0A2,
            'MovFileSlot': 0xD0A3,
            'MovRecProhibitCondition': 0xD0A4,
            'ManualMovieSetting': 0xD0A6,
            'MovQuality': 0xD0A7,
            'LiveViewScreenDisplaySetting': 0xD0B2,
            'MonitorOffDelay': 0xD0B3,
            'Bracketing': 0xD0C0,
            'AutoExposureBracketStep': 0xD0C1,
            'AutoExposureBracketProgram': 0xD0C2,
            'AutoExposureBracketCount': 0xD0C3,
            'WhiteBalanceBracketStep': 0xD0C4,
            'WhiteBalanceBracketProgram': 0xD0C5,
            'LensID': 0xD0E0,
            'LensSort': 0xD0E1,
            'LensType': 0xD0E2,
            'FocalLengthMin': 0xD0E3,
            'FocalLengthMax': 0xD0E4,
            'MaxApAtMinFocalLength': 0xD0E5,
            'MaxApAtMaxFocalLength': 0xD0E6,
            'FinderISODisp': 0xD0F0,
            'AutoOffPhoto': 0xD0F2,
            'AutoOffMenu': 0xD0F3,
            'AutoOffInfo': 0xD0F4,
            'SelfTimerShootNum': 0xD0F5,
            'VignetteCtrl': 0xD0F7,
            'AutoDistortionControl': 0xD0F8,
            'SceneMode': 0xD0F9,
            'SceneMode2': 0xD0FD,
            'SelfTimerInterval': 0xD0FE,
            'NikonExposureTime': 0xD100,
            'ACPower': 0xD101,
            'WarningStatus': 0xD102,
            'MaximumShots': 0xD103,
            'AFLockStatus': 0xD104,
            'AELockStatus': 0xD105,
            'FVLockStatus': 0xD106,
            'AutofocusLCDTopMode2': 0xD107,
            'AutofocusArea': 0xD108,
            'FlexibleProgram': 0xD109,
            'LightMeter': 0xD10A,
            'RecordingMedia': 0xD10B,
            'USBSpeed': 0xD10C,
            'CCDNumber': 0xD10D,
            'CameraOrientation': 0xD10E,
            'GroupPtnType': 0xD10F,
            'FNumberLock': 0xD110,
            'ExposureApertureLock': 0xD111,
            'TVLockSetting': 0xD112,
            'AVLockSetting': 0xD113,
            'IllumSetting': 0xD114,
            'FocusPointBright': 0xD115,
            'ExternalFlashAttached': 0xD120,
            'ExternalFlashStatus': 0xD121,
            'ExternalFlashSort': 0xD122,
            'ExternalFlashMode': 0xD123,
            'ExternalFlashCompensation': 0xD124,
            'NewExternalFlashMode': 0xD125,
            'FlashExposureCompensation': 0xD126,
            'HDRMode': 0xD130,
            'HDRHighDynamic': 0xD131,
            'HDRSmoothing': 0xD132,
            'OptimizeImage': 0xD140,
            'Saturation': 0xD142,
            'BWFillerEffect': 0xD143,
            'BWSharpness': 0xD144,
            'BWContrast': 0xD145,
            'BWSettingType': 0xD146,
            'Slot2SaveMode': 0xD148,
            'RawBitMode': 0xD149,
            'ActiveDLighting': 0xD14E,
            'FlourescentType': 0xD14F,
            'TuneColourTemperature': 0xD150,
            'TunePreset0': 0xD151,
            'TunePreset1': 0xD152,
            'TunePreset2': 0xD153,
            'TunePreset3': 0xD154,
            'TunePreset4': 0xD155,
            'BeepOff': 0xD160,
            'AutofocusMode': 0xD161,
            'AFAssist': 0xD163,
            'PADVPMode': 0xD164,
            'ImageReview': 0xD165,
            'AFAreaIllumination': 0xD166,
            'NikonFlashMode': 0xD167,
            'FlashCommanderMode': 0xD168,
            'FlashSign': 0xD169,
            '_ISOAuto': 0xD16A,
            'RemoteTimeout': 0xD16B,
            'GridDisplay': 0xD16C,
            'FlashModeManualPower': 0xD16D,
            'FlashModeCommanderPower': 0xD16E,
            'AutoFP': 0xD16F,
            'DateImprintSetting': 0xD170,
            'DateCounterSelect': 0xD171,
            'DateCountData': 0xD172,
            'DateCountDisplaySetting': 0xD173,
            'RangeFinderSetting': 0xD174,
            'CSMMenu': 0xD180,
            'WarningDisplay': 0xD181,
            'BatteryCellKind': 0xD182,
            'ISOAutoHiLimit': 0xD183,
            'DynamicAFArea': 0xD184,
            'ContinuousSpeedHigh': 0xD186,
            'InfoDispSetting': 0xD187,
            'PreviewButton': 0xD189,
            'PreviewButton2': 0xD18A,
            'AEAFLockButton2': 0xD18B,
            'IndicatorDisp': 0xD18D,
            'CellKindPriority': 0xD18E,
            'BracketingFramesAndSteps': 0xD190,
            'LiveViewMode': 0xD1A0,
            'LiveViewDriveMode': 0xD1A1,
            'LiveViewStatus': 0xD1A2,
            'LiveViewImageZoomRatio': 0xD1A3,
            'LiveViewProhibitCondition': 0xD1A4,
            'MovieShutterSpeed': 0xD1A8,
            'MovieFNumber': 0xD1A9,
            'MovieISO': 0xD1AA,
            'LiveViewMovieMode': 0xD1AC,
            'ExposureDisplayStatus': 0xD1B0,
            'ExposureIndicateStatus': 0xD1B1,
            'InfoDispErrStatus': 0xD1B2,
            'ExposureIndicateLightup': 0xD1B3,
            'FlashOpen': 0xD1C0,
            'FlashCharged': 0xD1C1,
            'FlashMRepeatValue': 0xD1D0,
            'FlashMRepeatCount': 0xD1D1,
            'FlashMRepeatInterval': 0xD1D2,
            'FlashCommandChannel': 0xD1D3,
            'FlashCommandSelfMode': 0xD1D4,
            'FlashCommandSelfCompensation': 0xD1D5,
            'FlashCommandSelfValue': 0xD1D6,
            'FlashCommandAMode': 0xD1D7,
            'FlashCommandACompensation': 0xD1D8,
            'FlashCommandAValue': 0xD1D9,
            'FlashCommandBMode': 0xD1DA,
            'FlashCommandBCompensation': 0xD1DB,
            'FlashCommandBValue': 0xD1DC,
            'ApplicationMode': 0xD1F0,
            'ActiveSlot': 0xD1F2,
            'ActivePicCtrlItem': 0xD200,
            'ChangePicCtrlItem': 0xD201,
            'MovieNrHighISO': 0xD236,
            'D241': 0xD241,
            'D244': 0xD244,
            'D247': 0xD247,
            'GUID': 0xD24F,
            'D250': 0xD250,
            'D251': 0xD251,
            'ISO': 0xF002,
            'ImageCompression': 0xF009,
            'NikonImageSize': 0xF00A,
            'NikonWhiteBalance': 0xF00C,
            # TODO: Are these redundant? Or product-specific?
            '_LongExposureNoiseReduction': 0xF00D,
            'HiISONoiseReduction': 0xF00E,
            '_ActiveDLighting': 0xF00F,
            '_MovQuality': 0xF01C,
        }
        product_properties.update(props)
        return super(Nikon, self)._PropertyCode(**product_properties)

    def _OperationCode(self, **product_operations):
        '''Delegate to the parent's `_OperationCode`, adding Nikon opcodes.

        Product-specific operations are forwarded as extra keyword
        arguments. Reusing a name defined here raises TypeError
        (duplicate keyword argument).
        '''
        parent_operation_code = super(Nikon, self)._OperationCode
        return parent_operation_code(
            GetProfileAllData=0x9006,
            SendProfileData=0x9007,
            DeleteProfile=0x9008,
            SetProfileData=0x9009,
            AdvancedTransfer=0x9010,
            GetFileInfoInBlock=0x9011,
            Capture=0x90C0,
            AFDrive=0x90C1,
            SetControlMode=0x90C2,
            DelImageSDRAM=0x90C3,
            GetLargeThumb=0x90C4,
            CurveDownload=0x90C5,
            CurveUpload=0x90C6,
            CheckEvents=0x90C7,
            DeviceReady=0x90C8,
            SetPreWBData=0x90C9,
            GetVendorPropCodes=0x90CA,
            AFCaptureSDRAM=0x90CB,
            GetPictCtrlData=0x90CC,
            SetPictCtrlData=0x90CD,
            DelCstPicCtrl=0x90CE,
            GetPicCtrlCapability=0x90CF,
            GetPreviewImg=0x9200,
            StartLiveView=0x9201,
            EndLiveView=0x9202,
            GetLiveViewImg=0x9203,
            MfDrive=0x9204,
            ChangeAFArea=0x9205,
            AFDriveCancel=0x9206,
            InitiateCaptureRecInMedia=0x9207,
            GetVendorStorageIDs=0x9209,
            StartMovieRecInCard=0x920A,
            EndMovieRec=0x920B,
            TerminateCapture=0x920C,
            GetDevicePTPIPInfo=0x90E0,
            GetPartialObjectHiSpeed=0x9400,
            GetDevicePropEx=0x9504,
            **product_operations
        )

    def _ResponseCode(self, **product_responses):
        '''Delegate to the parent's `_ResponseCode`, adding Nikon responses.

        Product-specific responses are forwarded as extra keyword
        arguments. Reusing a name defined here raises TypeError
        (duplicate keyword argument).
        '''
        parent_response_code = super(Nikon, self)._ResponseCode
        return parent_response_code(
            HardwareError=0xA001,
            OutOfFocus=0xA002,
            ChangeCameraModeFailed=0xA003,
            InvalidStatus=0xA004,
            SetPropertyNotSupported=0xA005,
            WbResetError=0xA006,
            DustReferenceError=0xA007,
            ShutterSpeedBulb=0xA008,
            MirrorUpSequence=0xA009,
            CameraModeNotAdjustFNumber=0xA00A,
            NotLiveView=0xA00B,
            MfDriveStepEnd=0xA00C,
            MfDriveStepInsufficiency=0xA00E,
            AdvancedTransferCancel=0xA022,
            BulbReleaseBusy=0xA200,
            **product_responses
        )

    def _EventCode(self, **product_events):
        '''Delegate to the parent's `_EventCode`, adding Nikon event codes.

        Product-specific events are forwarded as extra keyword arguments.
        Reusing a name defined here raises TypeError (duplicate keyword
        argument).
        '''
        parent_event_code = super(Nikon, self)._EventCode
        return parent_event_code(
            ObjectAddedInSDRAM=0xC101,
            CaptureCompleteRecInSdram=0xC102,
            AdvancedTransfer=0xC103,
            PreviewImageAdded=0xC104,
            **product_events
        )

    def _FilesystemType(self, **product_filesystem_types):
        '''Delegate to the parent's `_FilesystemType` unchanged.

        Nikon defines no filesystem types of its own. Product-specific ones
        are passed straight through.
        '''
        parent = super(Nikon, self)
        return parent._FilesystemType(**product_filesystem_types)

    def _NikonEvent(self):
        '''Build the constructor used to parse Nikon CheckEvents data.

        The layout is a UInt16-prefixed array of records. Each record holds
        an `EventCode` and a UInt32 `Parameter`.
        '''
        event_record = Struct(
            'EventCode' / self._EventCode,
            'Parameter' / self._UInt32,
        )
        return PrefixedArray(self._UInt16, event_record)

    def _set_endian(self, endian):
        '''Set endianness via the parent, then build the Nikon event parser.

        The parent call comes first because `_NikonEvent` reads
        `_EventCode`, `_UInt16` and `_UInt32`. Those are presumably prepared
        by the parent's `_set_endian`.

        Afterwards `self._NikonEvent` is an instance attribute holding the
        built constructor, which shadows the method of the same name.
        NOTE(review): a second call would try to call that constructor
        object, so this looks safe to call only once.
        '''
        logger.debug('Set Nikon endianness')
        super(Nikon, self)._set_endian(endian)
        build_nikon_event = self._NikonEvent
        self._NikonEvent = build_nikon_event()

    # TODO: Add event queue over all transports and extensions.
    def check_events(self):
        '''Check Nikon specific event

        Sends the vendor CheckEvents operation with no parameters. Any data
        phase in the reply is parsed with the Nikon event constructor.
        '''
        request = Container(
            OperationCode='CheckEvents',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        )
        return self._parse_if_data(self.recv(request), self._NikonEvent)

    # TODO: Provide a single camera agnostic command that will trigger a camera
    def capture(self):
        '''Nikon specific capture

        Issues the vendor Capture operation as a message without a data
        phase, and returns the transport's response.
        '''
        request = Container(
            OperationCode='Capture',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        )
        return self.mesg(request)

    def af_capture_sdram(self):
        '''Nikon specific autofocus and capture to SDRAM

        Issues the vendor AFCaptureSDRAM operation as a message without a
        data phase, and returns the transport's response.
        '''
        request = Container(
            OperationCode='AFCaptureSDRAM',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        )
        return self.mesg(request)

    def event(self, wait=False):
        '''Check Nikon or PTP events

        If `wait` this function is blocking. Otherwise it may return None.

        Events queued by the Nikon polling thread are returned first. The
        parent's `event` is consulted only when that queue is empty.
        '''
        # TODO: Do something reasonable on wait=True
        # TODO: Join queues to preserve order of Nikon and PTP events.
        if self.__event_queue.empty():
            return super(Nikon, self).event(wait=wait)

        # NOTE(review): `block=not wait` looks inverted relative to the
        # docstring. It only works because the queue was just seen to be
        # non-empty. It is kept as-is to preserve existing behavior.
        timeout = None if wait else 0.001
        return self.__event_queue.get(block=not wait, timeout=timeout)

    def __nikon_poll_events(self):
        '''Poll events, adding them to a queue.

        Calls `check_events` every 3 seconds until the shutdown event is set
        or `_main_thread_alive()` returns false. Every returned event is
        appended to the Nikon event queue. Any exception is logged, and
        polling continues. The shutdown event is cleared before returning.
        '''
        def should_poll():
            return (not self.__nikon_event_shutdown.is_set()
                    and _main_thread_alive())

        while should_poll():
            try:
                # A falsy result (no events) iterates as an empty sequence.
                for evt in self.check_events() or []:
                    logger.debug('Event queued')
                    self.__event_queue.put(evt)
            except Exception as e:
                logger.error(e)
            sleep(3)
        self.__nikon_event_shutdown.clear()
コード例 #40
0
class USBTransport(object):
    '''Implement USB transport.

    On construction, USB devices are probed for an interface of class
    PTP_USB_CLASS that has bulk-in, bulk-out and interrupt-in endpoints. The
    first such device that can be claimed is used. A polling thread then
    reads the interrupt endpoint and queues raw event data for `event()`.
    '''
    def __init__(self, device=None):
        '''Instantiate the first available PTP device over USB

        :param device: One of three forms:
            - None: probe every camera returned by `find_usb_cameras()`.
            - A string: probe the cameras returned by
              `find_usb_cameras(name=device)`.
            - A device object: use that device directly.
        :raises PTPError: if no candidate device could be set up and claimed.
        '''
        logger.debug('Init USB')
        self.__setup_constructors()
        # If no device is specified, find all devices claiming to be Cameras
        # and get the USB endpoints for the first one that works.
        if device is None:
            logger.debug('No device provided, probing all USB devices.')
        if isinstance(device, six.string_types):
            name = device
            logger.debug(
                'Device name provided, probing all USB devices for {}.'.format(
                    name))
            device = None
        else:
            name = None
        devs = ([device] if
                (device is not None) else find_usb_cameras(name=name))

        self.__acquire_camera(devs)

        self.__event_queue = Queue()
        self.__event_shutdown = Event()
        # Locks for different end points.
        self.__inep_lock = RLock()
        self.__intep_lock = RLock()
        self.__outep_lock = RLock()
        self.__event_proc = Thread(name='EvtPolling',
                                   target=self.__poll_events)
        # Non-daemon thread. The polling loop also stops on its own once the
        # main thread is gone (see __poll_events).
        self.__event_proc.daemon = False
        atexit.register(self._shutdown)
        self.__event_proc.start()

    def __available_cameras(self, devs):
        '''Yield once per device in `devs` that exposes a usable PTP interface.

        Before each yield, `__setup_device` has stored that device, its
        interface and its endpoints on `self`, so the consumer can try to
        claim it. Exhaustion is reported by the consumer, not here. That lets
        the consumer tell "no PTP device" apart from "none could be claimed".
        '''
        for dev in devs:
            if self.__setup_device(dev):
                logger.debug('Found USB PTP device {}'.format(dev))
                yield

    def __acquire_camera(self, devs):
        '''From the cameras given, get the first one that does not fail

        :raises PTPError: if no device exposes a PTP interface, or if every
            device that does could not be claimed.
        '''
        found_any = False
        for _ in self.__available_cameras(devs):
            found_any = True
            try:
                if self.__dev.is_kernel_driver_active(
                        self.__intf.bInterfaceNumber):
                    try:
                        self.__dev.detach_kernel_driver(
                            self.__intf.bInterfaceNumber)
                    except usb.core.USBError:
                        message = ('Could not detach kernel driver. '
                                   'Maybe the camera is mounted?')
                        logger.error(message)
                logger.debug('Claiming {}'.format(repr(self.__dev)))
                usb.util.claim_interface(self.__dev, self.__intf)
            except Exception as e:
                # This candidate failed; try the next one.
                logger.debug('{}'.format(e))
                continue
            break
        else:
            if found_any:
                message = 'Could not acquire any camera.'
            else:
                message = 'No USB PTP device found.'
            logger.error(message)
            raise PTPError(message)

    def _shutdown(self):
        '''Stop the polling thread and release the claimed USB interface.'''
        logger.debug('Shutdown request')
        self.__event_shutdown.set()
        # Free USB resource on shutdown.

        # Only join a running thread.
        if self.__event_proc.is_alive():
            self.__event_proc.join(2)

        logger.debug('Release {}'.format(repr(self.__dev)))
        usb.util.release_interface(self.__dev, self.__intf)

    # Helper methods.
    # ---------------------
    def __setup_device(self, dev):
        '''Get endpoints for a device. True on success.

        Looks for an interface of class PTP_USB_CLASS that has bulk-in,
        bulk-out and interrupt-in endpoints. On success, the endpoints,
        configuration, device and interface are stored on `self`. On
        failure, they are all left as None.
        '''
        self.__inep = None
        self.__outep = None
        self.__intep = None
        self.__cfg = None
        self.__dev = None
        self.__intf = None
        # Attempt to find the USB in, out and interrupt endpoints for a PTP
        # interface.
        for cfg in dev:
            for intf in cfg:
                if intf.bInterfaceClass == PTP_USB_CLASS:
                    for ep in intf:
                        ep_type = endpoint_type(ep.bmAttributes)
                        ep_dir = endpoint_direction(ep.bEndpointAddress)
                        if ep_type == ENDPOINT_TYPE_BULK:
                            if ep_dir == ENDPOINT_IN:
                                self.__inep = ep
                            elif ep_dir == ENDPOINT_OUT:
                                self.__outep = ep
                        elif ((ep_type == ENDPOINT_TYPE_INTR)
                              and (ep_dir == ENDPOINT_IN)):
                            self.__intep = ep
                # An interface is only usable with all three endpoints.
                if not (self.__inep and self.__outep and self.__intep):
                    self.__inep = None
                    self.__outep = None
                    self.__intep = None
                else:
                    logger.debug('Found {}'.format(repr(self.__inep)))
                    logger.debug('Found {}'.format(repr(self.__outep)))
                    logger.debug('Found {}'.format(repr(self.__intep)))
                    self.__cfg = cfg
                    self.__dev = dev
                    self.__intf = intf
                    return True
        return False

    def __setup_constructors(self):
        '''Set endianness and create transport-specific constructors.'''
        # Set endianness of constructors before using them.
        self._set_endian('little')

        self.__Length = Int32ul
        self.__Type = Enum(
            Int16ul,
            default=Pass,
            Undefined=0x0000,
            Command=0x0001,
            Data=0x0002,
            Response=0x0003,
            Event=0x0004,
        )
        # This is just a convenience constructor to get the size of a header.
        self.__Code = Int16ul
        self.__Header = Struct(
            'Length' / self.__Length,
            'Type' / self.__Type,
            'Code' / self.__Code,
            'TransactionID' / self._TransactionID,
        )
        # These are the actual constructors for parsing and building.
        self.__CommandHeader = Struct(
            'Length' / self.__Length,
            'Type' / self.__Type,
            'OperationCode' / self._OperationCode,
            'TransactionID' / self._TransactionID,
        )
        self.__ResponseHeader = Struct(
            'Length' / self.__Length,
            'Type' / self.__Type,
            'ResponseCode' / self._ResponseCode,
            'TransactionID' / self._TransactionID,
        )
        self.__EventHeader = Struct(
            'Length' / self.__Length,
            'Type' / self.__Type,
            'EventCode' / self._EventCode,
            'TransactionID' / self._TransactionID,
        )
        # Apparently nobody uses the SessionID field. Even though it is
        # specified in ISO15740:2013(E), no device respects it and the session
        # number is implicit over USB.
        self.__Param = Range(0, 5, self._Parameter)
        self.__CommandTransactionBase = Struct(
            Embedded(self.__CommandHeader), 'Payload' /
            Bytes(lambda ctx, h=self.__Header: ctx.Length - h.sizeof()))
        # On build, Length is computed from the payload size plus the header.
        self.__CommandTransaction = ExprAdapter(
            self.__CommandTransactionBase,
            encoder=lambda obj, ctx, h=self.__Header: Container(
                Length=len(obj.Payload) + h.sizeof(), **obj),
            decoder=lambda obj, ctx: obj,
        )
        self.__ResponseTransactionBase = Struct(
            Embedded(self.__ResponseHeader), 'Payload' /
            Bytes(lambda ctx, h=self.__Header: ctx.Length - h.sizeof()))
        self.__ResponseTransaction = ExprAdapter(
            self.__ResponseTransactionBase,
            encoder=lambda obj, ctx, h=self.__Header: Container(
                Length=len(obj.Payload) + h.sizeof(), **obj),
            decoder=lambda obj, ctx: obj,
        )

    def __parse_response(self, usbdata):
        '''Helper method for parsing USB data.

        Returns a Container with SessionID and TransactionID plus one of:
        ResponseCode/Parameter (Response), EventCode/Parameter (Event), or
        OperationCode/Data (any other type, i.e. a data phase).
        '''
        # Build up container with all PTP info.
        usbdata = bytearray(usbdata)
        transaction = self.__ResponseTransaction.parse(usbdata)
        response = Container(
            SessionID=self.session_id,
            TransactionID=transaction.TransactionID,
        )
        if transaction.Type == 'Response':
            response['ResponseCode'] = transaction.ResponseCode
            response['Parameter'] = self.__Param.parse(transaction.Payload)
        elif transaction.Type == 'Event':
            event = self.__EventHeader.parse(usbdata[0:self.__Header.sizeof()])
            response['EventCode'] = event.EventCode
            response['Parameter'] = self.__Param.parse(transaction.Payload)
        else:
            command = self.__CommandHeader.parse(
                usbdata[0:self.__Header.sizeof()])
            response['OperationCode'] = command.OperationCode
            response['Data'] = transaction.Payload
        return response

    @staticmethod
    def __is_transient_usb_error(error):
        '''True if a USBError is a timeout or busy device (worth one retry).'''
        import errno
        # pyusb's libusb1 backend maps libusb errors onto the platform's
        # errno values, and ETIMEDOUT is not 110 everywhere. The literal
        # Linux codes are kept as well, so nothing that matched before
        # stops matching.
        return error.errno in (110, 16, errno.ETIMEDOUT, errno.EBUSY)

    def __recv(self, event=False, wait=False, raw=False):
        '''Helper method for receiving data.

        Reads from the interrupt endpoint if `event` is true, and from the
        bulk-in endpoint otherwise. Reading continues until the length
        announced in the header has arrived.

        :returns: the raw bytes if `raw` is true, else a parsed Container.
            An event read that times out or finds the device busy returns
            None.
        :raises PTPError: if the transfer type is not Response, Data or
            Event.
        '''
        # TODO: clear stalls automatically
        ep = self.__intep if event else self.__inep
        lock = self.__intep_lock if event else self.__inep_lock
        with lock:
            try:
                usbdata = ep.read(ep.wMaxPacketSize, timeout=0 if wait else 5)
            except usb.core.USBError as e:
                # Ignore timeout or busy device once.
                if not self.__is_transient_usb_error(e):
                    raise
                if event:
                    return None
                usbdata = ep.read(ep.wMaxPacketSize, timeout=5000)
            header = self.__ResponseHeader.parse(
                bytearray(usbdata[0:self.__Header.sizeof()]))
            if header.Type not in ['Response', 'Data', 'Event']:
                raise PTPError(
                    'Unexpected USB transfer type. '
                    'Expected Response, Event or Data but received {}'.format(
                        header.Type))
            while len(usbdata) < header.Length:
                usbdata += ep.read(ep.wMaxPacketSize, timeout=5000)
        if raw:
            return usbdata
        else:
            return self.__parse_response(usbdata)

    def __send(self, ptp_container, event=False):
        '''Helper method for sending data.

        A write that times out or finds the device busy is retried once with
        a longer timeout. Any other USBError propagates to the caller.
        '''
        ep = self.__intep if event else self.__outep
        lock = self.__intep_lock if event else self.__outep_lock
        transaction = self.__CommandTransaction.build(ptp_container)
        with lock:
            try:
                ep.write(transaction, timeout=1)
            except usb.core.USBError as e:
                # Ignore timeout or busy device once. Other errors are real
                # failures and must not silently drop the write.
                if not self.__is_transient_usb_error(e):
                    raise
                ep.write(transaction, timeout=5000)

    def __send_request(self, ptp_container):
        '''Send PTP request without checking answer.'''
        # Don't modify original container to keep abstraction barrier.
        ptp = Container(**ptp_container)
        # Container(**...) is a shallow copy, so the Parameter list must be
        # copied too. Otherwise trimming it would also trim the caller's.
        parameters = list(ptp_container.Parameter)
        # Don't send unused (trailing falsy) parameters
        while parameters and not parameters[-1]:
            parameters.pop()
        ptp['Parameter'] = parameters

        # Send request
        ptp['Type'] = 'Command'
        ptp['Payload'] = self.__Param.build(ptp.Parameter)
        self.__send(ptp)

    def __send_data(self, ptp_container, data):
        '''Send data without checking answer.'''
        # Don't modify original container to keep abstraction barrier.
        ptp = Container(**ptp_container)
        # Send data
        ptp['Type'] = 'Data'
        ptp['Payload'] = data
        self.__send(ptp)

    # Actual implementation
    # ---------------------
    def send(self, ptp_container, data):
        '''Transfer operation with dataphase from initiator to responder'''
        datalen = len(data)
        logger.debug('SEND {} {} bytes{}'.format(
            ptp_container.OperationCode,
            datalen,
            ' ' + str(list(map(hex, ptp_container.Parameter)))
            if ptp_container.Parameter else '',
        ))
        self.__send_request(ptp_container)
        self.__send_data(ptp_container, data)
        # Get response and sneak in implicit SessionID and missing parameters.
        response = self.__recv()
        logger.debug('SEND {} {} bytes {}{}'.format(
            ptp_container.OperationCode,
            datalen,
            response.ResponseCode,
            ' ' + str(list(map(hex, response.Parameter)))
            if response.Parameter else '',
        ))
        return response

    def recv(self, ptp_container):
        '''Transfer operation with dataphase from responder to initiator.

        :raises PTPError: if the data phase does not match the request or
            the response.
        '''
        logger.debug('RECV {}{}'.format(
            ptp_container.OperationCode,
            ' ' + str(list(map(hex, ptp_container.Parameter)))
            if ptp_container.Parameter else '',
        ))
        self.__send_request(ptp_container)
        dataphase = self.__recv()
        if hasattr(dataphase, 'Data'):
            response = self.__recv()
            if ((ptp_container.OperationCode != dataphase.OperationCode)
                    or (ptp_container.TransactionID != dataphase.TransactionID)
                    or (ptp_container.SessionID != dataphase.SessionID)
                    or (dataphase.TransactionID != response.TransactionID)
                    or (dataphase.SessionID != response.SessionID)):
                raise PTPError(
                    'Dataphase does not match with requested operation.')
            response['Data'] = dataphase.Data
        else:
            response = dataphase

        logger.debug('RECV {} {}{}{}'.format(
            ptp_container.OperationCode,
            response.ResponseCode,
            ' {} bytes'.format(len(response.Data)) if hasattr(
                response, 'Data') else '',
            ' ' + str(list(map(hex, response.Parameter)))
            if response.Parameter else '',
        ))
        return response

    def mesg(self, ptp_container):
        '''Transfer operation without dataphase.'''
        logger.debug('MESG {}{}'.format(
            ptp_container.OperationCode,
            ' ' + str(list(map(hex, ptp_container.Parameter)))
            if ptp_container.Parameter else '',
        ))
        self.__send_request(ptp_container)
        # Get response and sneak in implicit SessionID and missing parameters
        # for FullResponse.
        response = self.__recv()
        logger.debug('MESG {} {}{}'.format(
            ptp_container.OperationCode,
            response.ResponseCode,
            ' ' + str(list(map(hex, response.Parameter)))
            if response.Parameter else '',
        ))
        return response

    def event(self, wait=False):
        '''Check event.

        If `wait` this function is blocking. Otherwise it may return None.
        '''
        evt = None
        usbdata = None
        timeout = None if wait else 0.001
        # NOTE(review): with the empty() pre-check, an empty queue returns
        # None even when `wait` is true, and `block=not wait` looks inverted.
        # This is kept as-is because callers may rely on never blocking here.
        if not self.__event_queue.empty():
            usbdata = self.__event_queue.get(block=not wait, timeout=timeout)
        if usbdata is not None:
            evt = self.__parse_response(usbdata)

        return evt

    def __poll_events(self):
        '''Poll events, adding them to a queue.

        Runs until shutdown is requested or the main thread is gone. Raw
        interrupt-endpoint data is queued unparsed. `event()` parses it.
        '''
        while not self.__event_shutdown.is_set() and _main_thread_alive():
            evt = self.__recv(event=True, wait=False, raw=True)
            if evt is not None:
                logger.debug('Event queued')
                self.__event_queue.put(evt)
コード例 #41
0
ファイル: acisession.py プロジェクト: datacenter/acitoolkit
class Subscriber(threading.Thread):
    """
    Thread responsible for event subscriptions.
    Issues subscriptions, creates the websocket, and refreshes the
    subscriptions before timer expiry.  It also reissues the
    subscriptions when the APIC login is refreshed.
    """
    def __init__(self, apic):
        """
        :param apic: Session object used to issue HTTP GETs to the APIC.
                     It must provide ``get(url)``; ``ipaddr`` and ``token``
                     are also read when a websocket is opened.
        """
        threading.Thread.__init__(self)
        self._apic = apic
        # Subscribed URL -> APIC subscription id, or None when the last
        # attempt to subscribe that URL failed.
        self._subscriptions = {}
        self._ws = None
        self._ws_url = None
        # Seconds run() sleeps between subscription refreshes.
        self._refresh_time = 30
        # Raw JSON event strings not yet sorted by _process_event_q().
        self._event_q = Queue()
        # Subscribed URL -> list of decoded events not yet consumed.
        self._events = {}
        self._exit = False
        self.event_handler_thread = None

    def exit(self):
        """
        Indicate that the thread should exit.

        The flag is only checked once per refresh interval in run(), so the
        thread may keep running for up to ``_refresh_time`` seconds.
        """
        self._exit = True

    @staticmethod
    def _subscription_error_response():
        """Build the synthetic 404 response returned when subscribing fails."""
        resp = requests.Response()
        resp.status_code = 404
        # requests stores the raw body as bytes; a str here makes
        # resp.text / resp.json() raise TypeError on Python 3.
        resp._content = b'{"error": "Could not send subscription to APIC"}'
        return resp

    def _send_subscription(self, url, only_new=False):
        """
        Send the subscription for the specified URL.

        :param url: URL string to issue the subscription
        :param only_new: If True, discard the objects returned in the initial
                         subscription response instead of queueing each of
                         them as an event.
        :returns: The APIC response, or a synthetic 404 response when the
                  subscription could not be established.
        """
        try:
            resp = self._apic.get(url)
        except ConnectionError:
            self._subscriptions[url] = None
            logging.error('Could not send subscription to APIC for url %s', url)
            return self._subscription_error_response()
        if not resp.ok:
            self._subscriptions[url] = None
            logging.error('Could not send subscription to APIC for url %s', url)
            return self._subscription_error_response()
        resp_data = json.loads(resp.text)
        if 'subscriptionId' not in resp_data:
            logging.error('Did not receive proper subscription response from APIC for url %s response: %s',
                          url, resp_data)
            return self._subscription_error_response()
        subscription_id = resp_data['subscriptionId']
        self._subscriptions[url] = subscription_id
        if not only_new:
            # Replay every object already present as its own one-item event,
            # in the order the APIC returned them.
            for item in resp_data['imdata']:
                event = {"totalCount": "1",
                         "subscriptionId": [subscription_id],
                         "imdata": [item]}
                self._event_q.put(json.dumps(event))
        return resp

    def refresh_subscriptions(self):
        """
        Refresh all of the subscriptions.

        Re-opens the websocket if it was dropped, retries subscriptions that
        previously failed, and resubscribes everything if a refresh is
        rejected by the APIC.
        """
        # Snapshot the current subscriptions in case of changes while we are
        # refreshing.  dict() copies in a single call instead of iterating the
        # live dict, which raises RuntimeError if another thread changes its
        # size mid-iteration.
        current_subscriptions = dict(self._subscriptions)

        # Refresh the subscriptions
        for subscription in current_subscriptions:
            if self._ws is not None:
                if not self._ws.connected:
                    logging.warning('Websocket not established on subscription refresh. Re-establishing websocket')
                    self._open_web_socket('https://' in subscription)
            try:
                subscription_id = self._subscriptions[subscription]
            except KeyError:
                logging.warning('Subscription has been removed while trying to refresh')
                continue
            if subscription_id is None:
                # Previous attempt failed; try to subscribe again.
                self._send_subscription(subscription)
                continue
            refresh_url = '/api/subscriptionRefresh.json?id=' + str(subscription_id)
            resp = self._apic.get(refresh_url)
            if not resp.ok:
                logging.warning('Could not refresh subscription: %s', refresh_url)
                # Try to resubscribe
                self._resubscribe()

    def _open_web_socket(self, use_secure=True):
        """
        Opens the web socket connection with the APIC.

        Closes an existing connected websocket (and stops its event handler
        thread) first, then starts a new EventHandler daemon thread for the
        new connection.

        :param use_secure: Boolean indicating whether the web socket
                           should be secure.  Default is True.
        """
        sslopt = {}
        if use_secure:
            # Certificate verification is disabled for the secure socket.
            sslopt['cert_reqs'] = ssl.CERT_NONE
            self._ws_url = 'wss://%s/socket%s' % (self._apic.ipaddr,
                                                  self._apic.token)
        else:
            self._ws_url = 'ws://%s/socket%s' % (self._apic.ipaddr,
                                                 self._apic.token)

        if self._ws is not None:
            if self._ws.connected:
                self._ws.close()
                self.event_handler_thread.exit()
        try:
            self._ws = create_connection(self._ws_url, sslopt=sslopt)
            if not self._ws.connected:
                logging.error('Unable to open websocket connection')
            self.event_handler_thread = EventHandler(self)
            self.event_handler_thread.daemon = True
            self.event_handler_thread.start()
        except WebSocketException:
            logging.error('Unable to open websocket connection due to WebSocketException')
        except socket.error:
            logging.error('Unable to open websocket connection due to Socket Error')

    def _resubscribe(self):
        """
        Reissue the subscriptions.
        Used to when the APIC login timeout occurs and a new subscription
        must be issued instead of simply a refresh.  Not meant to be called
        directly by end user applications.
        """
        # File any already-received events under the old subscription ids
        # before those ids are forgotten.
        self._process_event_q()
        urls = list(self._subscriptions)
        self._subscriptions = {}
        for url in urls:
            self.subscribe(url, only_new=True)

    def _process_event_q(self):
        """
        Put the event into correct bucket based on URLs that have been
        subscribed.

        Drains ``_event_q``.  Events that are not valid JSON are logged and
        dropped.  An event carrying several subscription ids is filed once per
        id, each URL receiving its own copy; ids that match no subscription
        are filed under the key None.
        """
        while not self._event_q.empty():
            raw_event = self._event_q.get()
            try:
                event = json.loads(raw_event)
            except ValueError:
                logging.error('Non-JSON event: %s', raw_event)
                continue
            # Find the URL for each subscription id in this event
            subscription_ids = event['subscriptionId']
            num_subscriptions = len(subscription_ids)
            for i in range(num_subscriptions):
                wanted_id = str(subscription_ids[i])
                url = None
                for k in self._subscriptions:
                    if self._subscriptions[k] == wanted_id:
                        url = k
                        break
                self._events.setdefault(url, []).append(event)
                if num_subscriptions > 1:
                    event = copy.deepcopy(event)

    def subscribe(self, url, only_new=False):
        """
        Subscribe to a particular APIC URL.  Used internally by the
        Class and Instance subscriptions.

        :param url: URL string to send as a subscription
        :param only_new: If True, do not queue the objects returned in the
                         initial subscription response as events.
        :returns: The subscription response, or None if already subscribed.
        """
        logging.info('Subscribing to url: %s', url)
        # Check if already subscribed.  If so, skip
        if url in self._subscriptions:
            return

        if self._ws is not None:
            if not self._ws.connected:
                self._open_web_socket('https://' in url)

        resp = self._send_subscription(url, only_new=only_new)
        return resp

    def is_subscribed(self, url):
        """
        Check if subscribed to a particular APIC URL.

        :param url: URL string to send as a subscription
        """
        return url in self._subscriptions

    def has_events(self, url):
        """
        Check if a particular APIC URL subscription has any events.
        Used internally by the Class and Instance subscriptions.

        :param url: URL string to check for pending events
        """
        self._process_event_q()
        return bool(self._events.get(url))

    def get_event_count(self, url):
        """
        Check the number of subscription events for a particular APIC URL

        :param url: URL string to check for pending events
        :returns: Integer number of events in event queue
        """
        self._process_event_q()
        return len(self._events.get(url, ()))

    def get_event(self, url):
        """
        Get an event for a particular APIC URL subscription.
        Used internally by the Class and Instance subscriptions.

        :param url: URL string to get pending event
        :returns: The oldest pending decoded event for ``url``.
        :raises ValueError: if no event has ever been filed for ``url``.
        :raises IndexError: if events were filed for ``url`` but all have
                            already been consumed.
        """
        self._process_event_q()
        if url not in self._events:
            raise ValueError
        event = self._events[url].pop(0)
        logging.debug('Event received %s', event)
        return event

    def unsubscribe(self, url):
        """
        Unsubscribe from a particular APIC URL.  Used internally by the
        Class and Instance subscriptions.

        :param url: URL string to unsubscribe
        :raises ValueError: if ``url`` carries no ``subscription=yes`` query.
        """
        logging.info('Unsubscribing from url: %s', url)
        if url not in self._subscriptions:
            return
        if '&subscription=yes' in url:
            unsubscribe_url = url.split('&subscription=yes')[0] + '&subscription=no'
        elif '?subscription=yes' in url:
            unsubscribe_url = url.split('?subscription=yes')[0] + '?subscription=no'
        else:
            raise ValueError('No subscription string in URL being unsubscribed')
        resp = self._apic.get(unsubscribe_url)
        if not resp.ok:
            logging.warning('Could not unsubscribe from url: %s', unsubscribe_url)
        # Chew up any outstanding events
        while self.has_events(url):
            self.get_event(url)
        del self._subscriptions[url]
        # The websocket may never have been opened (self._ws is None); only
        # close it once the last subscription is gone and it exists.
        if not self._subscriptions and self._ws is not None:
            self._ws.close(timeout=0)

    def run(self):
        """Periodically refresh subscriptions until exit() is called."""
        while not self._exit:
            # Sleep for some interval and send subscription list
            time.sleep(self._refresh_time)
            try:
                self.refresh_subscriptions()
            except ConnectionError:
                logging.error('Could not refresh subscriptions due to ConnectionError')
コード例 #42
0
def test_instances_deployed(mock_get_paasta_api_client, mock__log):
    # Route every status_instance call on the patched API client through
    # mock_status_instance_side_effect.
    api_client = Mock()
    mock_get_paasta_api_client.return_value = api_client
    api_client.service.status_instance.side_effect = \
        mock_status_instance_side_effect

    instances_deployed = mark_for_deployment.instances_deployed
    event = Event()
    event.set()
    cluster_data = mark_for_deployment.ClusterData(cluster='cluster',
                                                   service='service1',
                                                   git_sha='somesha',
                                                   instances_queue=Queue())

    # (instances fed in, instance expected on the output queue or None when
    # the output queue must stay empty)
    cases = [
        (['instance1'], None),
        (['instance1', 'instance2'], 'instance2'),
        (['instance3'], None),
        (['instance4'], None),
        (['instance4.1'], None),
        (['instance5', 'instance1'], None),
        (['instance6'], 'instance6'),
        (['notaninstance'], 'notaninstance'),
        (['api_error'], 'api_error'),
        (['instance7'], None),
        (['instance8'], None),
    ]
    for index, (queued, expected) in enumerate(cases):
        # The first case reuses the queue ClusterData was built with.
        if index:
            cluster_data.instances_queue = Queue()
        for name in queued:
            cluster_data.instances_queue.put(name)
        instances_out = Queue()
        instances_deployed(cluster_data, instances_out, event)
        assert cluster_data.instances_queue.empty()
        if expected is None:
            assert instances_out.empty()
        else:
            assert instances_out.get(block=False) == expected
コード例 #43
0
ファイル: test_poutines.py プロジェクト: Magica-Chen/pyro
class QueuePoutineDiscreteTest(TestCase):
    """Tests for poutine.queue on a small discrete-latent Gaussian-mixture HMM."""

    def setUp(self):

        # simple Gaussian-mixture HMM
        def model():
            ps = pyro.param("ps", Variable(torch.Tensor([[0.8], [0.3]])))
            mu = pyro.param("mu", Variable(torch.Tensor([[-0.1], [0.9]])))
            sigma = Variable(torch.ones(1, 1))

            latents = [Variable(torch.ones(1))]
            observes = []
            for t in range(3):

                latents.append(
                    pyro.sample("latent_{}".format(str(t)),
                                Bernoulli(ps[latents[-1][0].long().data])))

                observes.append(
                    pyro.observe("observe_{}".format(str(t)),
                                 Normal(mu[latents[-1][0].long().data], sigma),
                                 pyro.ones(1)))
            return latents

        # Every site name the model's trace is expected to contain.
        self.sites = ["observe_{}".format(str(t)) for t in range(3)] + \
                     ["latent_{}".format(str(t)) for t in range(3)] + \
                     ["_INPUT", "_RETURN"]
        self.model = model
        # Seed the queue with one empty trace; poutine.queue consumes it.
        self.queue = Queue()
        self.queue.put(poutine.Trace())

    def test_queue_single(self):
        f = poutine.trace(poutine.queue(self.model, queue=self.queue))
        tr = f.get_trace()
        for name in self.sites:
            assert name in tr

    def test_queue_enumerate(self):
        f = poutine.trace(poutine.queue(self.model, queue=self.queue))
        trs = []
        while not self.queue.empty():
            trs.append(f.get_trace())
        # Three binary latents -> 2 ** 3 distinct executions.
        assert len(trs) == 2 ** 3

        true_latents = {(i1, i2, i3)
                        for i1 in range(2)
                        for i2 in range(2)
                        for i3 in range(2)}

        tr_latents = []
        for tr in trs:
            tr_latents.append(tuple([int(tr.nodes[name]["value"].view(-1).data[0]) for name in tr
                                     if tr.nodes[name]["type"] == "sample" and
                                     not tr.nodes[name]["is_observed"]]))

        assert true_latents == set(tr_latents)

    def test_queue_max_tries(self):
        f = poutine.queue(self.model, queue=self.queue, max_tries=3)
        # assertRaises, unlike `assert False`, still fails when f() does not
        # raise even if Python runs with -O (which strips assert statements).
        with self.assertRaises(ValueError):
            f()
コード例 #44
0
class Canon(EOSPropertiesMixin, object):
    '''This class implements Canon's PTP operations.'''
    def __init__(self, *args, **kwargs):
        '''Initialise the base camera, then the EOS event-polling state.'''
        logger.debug('Init Canon')
        super(Canon, self).__init__(*args, **kwargs)
        # Polling thread handle and its stop flag, used by session() and
        # _eos_shutdown().
        self.__eos_event_proc = None
        self.__eos_event_shutdown = Event()
        # TODO: expose the choice to poll or not Canon events
        self.__no_polling = False

    @contextmanager
    def session(self):
        '''
        Manage Canon session with context manager.

        Inside a normal session the camera is put in remote mode with event
        mode 1, and a non-daemon thread running ``__eos_poll_events`` is
        started; it is shut down via ``_eos_shutdown`` when the session ends.
        '''
        # When raw device, do not perform
        if self.__no_polling:
            with super(Canon, self).session():
                yield
            return
        # Within a normal PTP session
        with super(Canon, self).session():
            # Set up remote mode and extended event info
            self.eos_set_remote_mode(1)
            self.eos_event_mode(1)
            # _eos_shutdown() from a previous session left this flag set.
            # Clear it so a new polling thread (presumably gated on it) is not
            # told to stop before it starts.
            self.__eos_event_shutdown.clear()
            # And launch a polling thread
            self.__event_queue = Queue()
            self.__eos_event_proc = Thread(
                name='EOSEvtPolling',
                target=self.__eos_poll_events
            )
            self.__eos_event_proc.daemon = False
            # Non-daemon thread: make sure it is stopped at interpreter exit.
            atexit.register(self._eos_shutdown)
            self.__eos_event_proc.start()

            try:
                yield
            finally:
                self._eos_shutdown()

    def _shutdown(self):
        '''Stop EOS event polling first, then run the parent shutdown.'''
        self._eos_shutdown()
        parent = super(Canon, self)
        parent._shutdown()

    def _eos_shutdown(self):
        '''Signal the EOS polling thread to stop and wait briefly for it.'''
        logger.debug('Shutdown EOS events request')
        self.__eos_event_shutdown.set()

        poller = self.__eos_event_proc
        # Only join a thread that exists and is still running; give it at
        # most 2 seconds.
        if poller is not None and poller.is_alive():
            poller.join(2)

    def _PropertyCode(self, **product_properties):
        '''Add Canon device property codes to the parent's property codes.

        The Canon codes below, together with any `product_properties` from a
        subclass, are passed as keyword arguments to the parent class's
        `_PropertyCode`. Because they are keywords of a single call, a
        subclass repeating one of these names raises TypeError.
        '''
        return super(Canon, self)._PropertyCode(
            BeepMode=0xD001,
            ViewfinderMode=0xD003,
            ImageQuality=0xD006,
            CanonImageSize=0xD008,
            CanonFlashMode=0xD00A,
            TvAvSetting=0xD00C,
            MeteringMode=0xD010,
            MacroMode=0xD011,
            FocusingPoint=0xD012,
            CanonWhiteBalance=0xD013,
            ISOSpeed=0xD01C,
            Aperture=0xD01D,
            ShutterSpeed=0xD01E,
            ExpCompensation=0xD01F,
            Zoom=0xD02A,
            SizeQualityMode=0xD02C,
            FlashMemory=0xD031,
            CameraModel=0xD032,
            CameraOwner=0xD033,
            UnixTime=0xD034,
            ViewfinderOutput=0xD036,
            RealImageWidth=0xD039,
            PhotoEffect=0xD040,
            AssistLight=0xD041,
            **product_properties
        )

    def _OperationCode(self, **product_operations):
        '''Add Canon operation codes to the parent's operation codes.

        The Canon codes below, together with any `product_operations` from a
        subclass, are passed as keyword arguments to the parent class's
        `_OperationCode`. Every name must map to a distinct code so that
        decoding a code yields a single unambiguous name.
        '''
        return super(Canon, self)._OperationCode(
            GetObjectSize=0x9001,
            SetObjectArchive=0x9002,
            KeepDeviceOn=0x9003,
            LockDeviceUI=0x9004,
            UnlockDeviceUI=0x9005,
            GetObjectHandleByName=0x9006,
            InitiateReleaseControl=0x9008,
            TerminateReleaseControl=0x9009,
            TerminatePlaybackMode=0x900A,
            ViewfinderOn=0x900B,
            ViewfinderOff=0x900C,
            DoAeAfAwb=0x900D,
            GetCustomizeSpec=0x900E,
            GetCustomizeItemInfo=0x900F,
            GetCustomizeData=0x9010,
            SetCustomizeData=0x9011,
            GetCaptureStatus=0x9012,
            CheckEvent=0x9013,
            FocusLock=0x9014,
            FocusUnlock=0x9015,
            GetLocalReleaseParam=0x9016,
            SetLocalReleaseParam=0x9017,
            AskAboutPcEvf=0x9018,
            SendPartialObject=0x9019,
            InitiateCaptureInMemory=0x901A,
            GetPartialObjectEx=0x901B,
            SetObjectTime=0x901C,
            GetViewfinderImage=0x901D,
            GetObjectAttributes=0x901E,
            ChangeUSBProtocol=0x901F,
            GetChanges=0x9020,
            GetObjectInfoEx=0x9021,
            InitiateDirectTransfer=0x9022,
            TerminateDirectTransfer=0x9023,
            SendObjectInfoByPath=0x9024,
            SendObjectByPath=0x9025,
            # NOTE: "Tansfer" is misspelled but is a public name callers may
            # already use, so it is kept as-is.
            InitiateDirectTansferEx=0x9026,
            GetAncillaryObjectHandles=0x9027,
            GetTreeInfo=0x9028,
            GetTreeSize=0x9029,
            NotifyProgress=0x902A,
            NotifyCancelAccepted=0x902B,
            GetDirectory=0x902D,
            SetPairingInfo=0x9030,
            GetPairingInfo=0x9031,
            DeletePairingInfo=0x9032,
            GetMACAddress=0x9033,
            SetDisplayMonitor=0x9034,
            PairingComplete=0x9035,
            GetWirelessMAXChannel=0x9036,
            EOSGetStorageIDs=0x9101,
            EOSGetStorageInfo=0x9102,
            EOSGetObjectInfo=0x9103,
            EOSGetObject=0x9104,
            EOSDeleteObject=0x9105,
            EOSFormatStore=0x9106,
            EOSGetPartialObject=0x9107,
            EOSGetDeviceInfoEx=0x9108,
            EOSGetObjectInfoEx=0x9109,
            EOSGetThumbEx=0x910A,
            EOSSendPartialObject=0x910B,
            EOSSetObjectAttributes=0x910C,
            EOSGetObjectTime=0x910D,
            EOSSetObjectTime=0x910E,
            EOSRemoteRelease=0x910F,
            EOSSetDevicePropValueEx=0x9110,
            EOSGetRemoteMode=0x9113,
            EOSSetRemoteMode=0x9114,
            EOSSetEventMode=0x9115,
            EOSGetEvent=0x9116,
            EOSTransferComplete=0x9117,
            EOSCancelTransfer=0x9118,
            EOSResetTransfer=0x9119,
            EOSPCHDDCapacity=0x911A,
            EOSSetUILock=0x911B,
            EOSResetUILock=0x911C,
            EOSKeepDeviceOn=0x911D,
            EOSSetNullPacketMode=0x911E,
            EOSUpdateFirmware=0x911F,
            EOSTransferCompleteDT=0x9120,
            EOSCancelTransferDT=0x9121,
            EOSSetWftProfile=0x9122,
            # Was 0x9122, colliding with EOSSetWftProfile; 0x9123 is the gap
            # between SetWftProfile (0x9122) and SetProfileToWft (0x9124).
            EOSGetWftProfile=0x9123,
            EOSSetProfileToWft=0x9124,
            EOSBulbStart=0x9125,
            EOSBulbEnd=0x9126,
            EOSRequestDevicePropValue=0x9127,
            EOSRemoteReleaseOn=0x9128,
            EOSRemoteReleaseOff=0x9129,
            EOSInitiateViewfinder=0x9151,
            EOSTerminateViewfinder=0x9152,
            EOSGetViewFinderImage=0x9153,
            EOSDoAf=0x9154,
            EOSDriveLens=0x9155,
            EOSDepthOfFieldPreview=0x9156,
            EOSClickWB=0x9157,
            EOSZoom=0x9158,
            EOSZoomPosition=0x9159,
            EOSSetLiveAfFrame=0x915a,
            EOSAfCancel=0x9160,
            EOSFAPIMessageTX=0x91FE,
            EOSFAPIMessageRX=0x91FF,
            **product_operations
        )

    def _ObjectFormatCode(self, **product_object_formats):
        '''Add Canon object format codes (raw CRW/CRW3 and MOV) to the parent's.

        Any `product_object_formats` from a subclass are forwarded to the
        parent class's `_ObjectFormatCode` in the same call.
        '''
        return super(Canon, self)._ObjectFormatCode(
            CRW=0xB101,
            CRW3=0xB103,
            MOV=0xB104,
            **product_object_formats
        )

    def _ResponseCode(self, **product_responses):
        '''Forward response codes unchanged; Canon adds none of its own.'''
        parent = super(Canon, self)
        return parent._ResponseCode(**product_responses)

    def _EventCode(self, **product_events):
        '''Add Canon (non-EOS) PTP event codes to the parent's event codes.

        Any `product_events` from a subclass are forwarded to the parent
        class's `_EventCode` in the same call.
        '''
        return super(Canon, self)._EventCode(
            CanonDeviceInfoChanged=0xC008,
            CanonRequestObjectTransfer=0xC009,
            CameraModeChanged=0xC00C,
            **product_events
        )

    def _FilesystemType(self, **product_filesystem_types):
        '''Forward filesystem types unchanged; Canon adds none of its own.'''
        parent = super(Canon, self)
        return parent._FilesystemType(**product_filesystem_types)

    def _EOSEventCode(self):
        '''Return the Enum used to decode Canon EOS event codes.

        The codes are read with `self._UInt32`, so the byte order follows
        whatever that field uses. Codes not listed here fall through via
        `default=Pass` instead of failing to decode.
        '''
        return Enum(
            self._UInt32,
            default=Pass,
            EmptyEvent=0x0000,
            RequestGetEvent=0xC101,
            ObjectAdded=0xC181,
            ObjectRemoved=0xC182,
            RequestGetObjectInfoEx=0xC183,
            StorageStatusChanged=0xC184,
            StorageInfoChanged=0xC185,
            RequestObjectTransfer=0xC186,
            ObjectInfoChangedEx=0xC187,
            ObjectContentChanged=0xC188,
            DevicePropChanged=0xC189,
            AvailListChanged=0xC18A,
            CameraStatusChanged=0xC18B,
            WillSoonShutdown=0xC18D,
            ShutdownTimerUpdated=0xC18E,
            RequestCancelTransfer=0xC18F,
            RequestObjectTransferDT=0xC190,
            RequestCancelTransferDT=0xC191,
            StoreAdded=0xC192,
            StoreRemoved=0xC193,
            BulbExposureTime=0xC194,
            RecordingTime=0xC195,
            InnerDevelopParam=0xC196,
            RequestObjectTransferDevelop=0xC197,
            GPSLogOutputProgress=0xC198,
            GPSLogOutputComplete=0xC199,
            TouchTrans=0xC19A,
            RequestObjectTransferExInfo=0xC19B,
            PowerZoomInfoChanged=0xC19D,
            RequestPushMode=0xC19F,
            RequestObjectTransferTS=0xC1A2,
            AfResult=0xC1A3,
            CTGInfoCheckComplete=0xC1A4,
            OLCInfoChanged=0xC1A5,
            ObjectAddedEx64=0xC1A7,
            ObjectInfoChangedEx64=0xC1A8,
            RequestObjectTransfer64=0xC1A9,
            RequestObjectTransferFTP64=0xC1AB,
            ImportFailed=0xC1AF,
            BlePairing=0xC1B0,
            RequestObjectTransferFTP=0xC1F1,
            Unknown=0xFFFF,
        )

    def _EOSPropertyCode(self):
        '''Return the Enum used to decode Canon EOS device property codes.

        The codes are read with `self._UInt32`, so the byte order follows
        whatever that field uses. Codes not listed here fall through via
        `default=Pass` instead of failing to decode.
        '''
        # NOTE(review): `ContinousAFMode` below is misspelled, but it is a
        # public name callers may already use, so it is left unchanged.
        return Enum(
            self._UInt32,
            default=Pass,
            Aperture=0xD101,
            ShutterSpeed=0xD102,
            ISO=0xD103,
            ExposureCompensation=0xD104,
            ShootingMode=0xD105,
            DriveMode=0xD106,
            ExposureMeteringMode=0xD107,
            AutoFocusMode=0xD108,
            WhiteBalance=0xD109,
            ColorTemperature=0xD10A,
            WhiteBalanceAdjustA=0xD10B,
            WhiteBalanceAdjustB=0xD10C,
            WhiteBalanceXA=0xD10D,
            WhiteBalanceXB=0xD10E,
            ColorSpace=0xD10F,
            PictureStyle=0xD110,
            BatteryPower=0xD111,
            BatterySelect=0xD112,
            CameraTime=0xD113,
            AutoPowerOff=0xD114,
            Owner=0xD115,
            ModelID=0xD116,
            PTPExtensionVersion=0xD119,
            DPOFVersion=0xD11A,
            AvailableShots=0xD11B,
            CaptureDestination=0xD11C,
            BracketMode=0xD11D,
            CurrentStorage=0xD11E,
            CurrentFolder=0xD11F,
            ImageFormat=0xD120,
            ImageFormatCF=0xD121,
            ImageFormatSD=0xD122,
            ImageFormatHDD=0xD123,
            CompressionS=0xD130,
            CompressionM1=0xD131,
            CompressionM2=0xD132,
            CompressionL=0xD133,
            AEModeDial=0xD138,
            AEModeCustom=0xD139,
            MirrorUpSetting=0xD13A,
            HighlightTonePriority=0xD13B,
            AFSelectFocusArea=0xD13C,
            HDRSetting=0xD13D,
            PCWhiteBalance1=0xD140,
            PCWhiteBalance2=0xD141,
            PCWhiteBalance3=0xD142,
            PCWhiteBalance4=0xD143,
            PCWhiteBalance5=0xD144,
            MWhiteBalance=0xD145,
            MWhiteBalanceEx=0xD146,
            PictureStyleStandard=0xD150,
            PictureStylePortrait=0xD151,
            PictureStyleLandscape=0xD152,
            PictureStyleNeutral=0xD153,
            PictureStyleFaithful=0xD154,
            PictureStyleBlackWhite=0xD155,
            PictureStyleAuto=0xD156,
            PictureStyleUserSet1=0xD160,
            PictureStyleUserSet2=0xD161,
            PictureStyleUserSet3=0xD162,
            PictureStyleParam1=0xD170,
            PictureStyleParam2=0xD171,
            PictureStyleParam3=0xD172,
            HighISOSettingNoiseReduction=0xD178,
            MovieServoAF=0xD179,
            ContinuousAFValid=0xD17A,
            Attenuator=0xD17B,
            UTCTime=0xD17C,
            Timezone=0xD17D,
            Summertime=0xD17E,
            FlavorLUTParams=0xD17F,
            CustomFunc1=0xD180,
            CustomFunc2=0xD181,
            CustomFunc3=0xD182,
            CustomFunc4=0xD183,
            CustomFunc5=0xD184,
            CustomFunc6=0xD185,
            CustomFunc7=0xD186,
            CustomFunc8=0xD187,
            CustomFunc9=0xD188,
            CustomFunc10=0xD189,
            CustomFunc11=0xD18A,
            CustomFunc12=0xD18B,
            CustomFunc13=0xD18C,
            CustomFunc14=0xD18D,
            CustomFunc15=0xD18E,
            CustomFunc16=0xD18F,
            CustomFunc17=0xD190,
            CustomFunc18=0xD191,
            CustomFunc19=0xD192,
            InnerDevelop=0xD193,
            MultiAspect=0xD194,
            MovieSoundRecord=0xD195,
            MovieRecordVolume=0xD196,
            WindCut=0xD197,
            ExtenderType=0xD198,
            OLCInfoVersion=0xD199,
            CustomFuncEx=0xD1A0,
            MyMenu=0xD1A1,
            MyMenuList=0xD1A2,
            WftStatus=0xD1A3,
            WftInputTransmission=0xD1A4,
            HDDDirectoryStructure=0xD1A5,
            BatteryInfo=0xD1A6,
            AdapterInfo=0xD1A7,
            LensStatus=0xD1A8,
            QuickReviewTime=0xD1A9,
            CardExtension=0xD1AA,
            TempStatus=0xD1AB,
            ShutterCounter=0xD1AC,
            SpecialOption=0xD1AD,
            PhotoStudioMode=0xD1AE,
            SerialNumber=0xD1AF,
            EVFOutputDevice=0xD1B0,
            EVFMode=0xD1B1,
            DepthOfFieldPreview=0xD1B2,
            EVFSharpness=0xD1B3,
            EVFWBMode=0xD1B4,
            EVFClickWBCoeffs=0xD1B5,
            EVFColorTemp=0xD1B6,
            ExposureSimMode=0xD1B7,
            EVFRecordStatus=0xD1B8,
            LvAfSystem=0xD1BA,
            MovSize=0xD1BB,
            LvViewTypeSelect=0xD1BC,
            MirrorDownStatus=0xD1BD,
            MovieParam=0xD1BE,
            MirrorLockupState=0xD1BF,
            FlashChargingState=0xD1C0,
            AloMode=0xD1C1,
            FixedMovie=0xD1C2,
            OneShotRawOn=0xD1C3,
            ErrorForDisplay=0xD1C4,
            AEModeMovie=0xD1C5,
            BuiltinStroboMode=0xD1C6,
            StroboDispState=0xD1C7,
            StroboETTL2Metering=0xD1C8,
            ContinousAFMode=0xD1C9,
            MovieParam2=0xD1CA,
            StroboSettingExpComposition=0xD1CB,
            MovieParam3=0xD1CC,
            LVMedicalRotate=0xD1CF,
            Artist=0xD1D0,
            Copyright=0xD1D1,
            BracketValue=0xD1D2,
            FocusInfoEx=0xD1D3,
            DepthOfField=0xD1D4,
            Brightness=0xD1D5,
            LensAdjustParams=0xD1D6,
            EFComp=0xD1D7,
            LensName=0xD1D8,
            AEB=0xD1D9,
            StroboSetting=0xD1DA,
            StroboWirelessSetting=0xD1DB,
            StroboFiring=0xD1DC,
            LensID=0xD1DD,
            LCDBrightness=0xD1DE,
            CADarkBright=0xD1DF,
        )

    def _EOSEventRecords(self):
        '''Return a Range construct parsing a sequence of EOS event records.

        Each element is built by `self._EOSEventRecord`, so the byte order
        follows that record definition.
        '''
        # A dataphase length is a 32-bit unsigned value, so the number of
        # records it can hold is bounded by that maximum.
        max_records = 0xFFFFFFFF
        return Range(0, max_records, self._EOSEventRecord)

    def _EOSEventRecord(self):
        '''Build the constructor for a single EOS event record.

        Layout: a UInt32 ``Bytes`` field (presumably the record's total size
        in bytes, including itself -- TODO confirm), the EOS event code, then
        a ``Record`` whose layout is chosen by the event code. Only
        AvailListChanged and DevicePropChanged are decoded; any other event
        code falls back to a raw UInt8 array for the rest of the record.
        '''
        return Struct(
            'Bytes' / self._UInt32,
            Embedded(Struct(
                'EventCode' / self._EOSEventCode,
                'Record' / Switch(
                    lambda ctx: ctx.EventCode,
                    {
                        'AvailListChanged':
                        Embedded(Struct(
                            'PropertyCode' / self._EOSPropertyCode,
                            'Enumeration' / Array(
                                # TODO: Verify if this is actually an
                                # enumeration.
                                # Bytes left after Bytes, EventCode and
                                # PropertyCode (assumed 4 bytes each).
                                lambda ctx: ctx._._.Bytes - 12,
                                self._UInt8
                            )
                        )),
                        'DevicePropChanged':
                        Embedded(Struct(
                            'PropertyCode' / self._EOSPropertyCode,
                            # Look up the value's data type from the property
                            # code. NOTE(review): raises KeyError for unknown
                            # codes unless _EOSDataTypeCode supplies a default
                            # -- confirm.
                            'DataTypeCode' / Computed(
                                lambda ctx: self._EOSDataTypeCode[ctx.PropertyCode]
                            ),
                            # Unknown type (None): keep the remaining bytes
                            # raw; otherwise parse with the generic data-type
                            # constructor.
                            'Value' / Switch(
                                lambda ctx: ctx.DataTypeCode,
                                {
                                    None: Array(
                                        lambda ctx: ctx._._.Bytes - 12,
                                        self._UInt8
                                    )
                                },
                                default=self._DataType
                            ),
                        )),
                        # TODO: 'EmptyEvent',
                        # TODO: 'RequestGetEvent',
                        # TODO: 'ObjectAdded',
                        # TODO: 'ObjectRemoved',
                        # TODO: 'RequestGetObjectInfoEx',
                        # TODO: 'StorageStatusChanged',
                        # TODO: 'StorageInfoChanged',
                        # TODO: 'RequestObjectTransfer',
                        # TODO: 'ObjectInfoChangedEx',
                        # TODO: 'ObjectContentChanged',
                        # 'DevicePropChanged' and 'AvailListChanged' are
                        # decoded above.
                        # TODO: 'CameraStatusChanged',
                        # TODO: 'WillSoonShutdown',
                        # TODO: 'ShutdownTimerUpdated',
                        # TODO: 'RequestCancelTransfer',
                        # TODO: 'RequestObjectTransferDT',
                        # TODO: 'RequestCancelTransferDT',
                        # TODO: 'StoreAdded',
                        # TODO: 'StoreRemoved',
                        # TODO: 'BulbExposureTime',
                        # TODO: 'RecordingTime',
                        # TODO: 'InnerDevelopParam',
                        # TODO: 'RequestObjectTransferDevelop',
                        # TODO: 'GPSLogOutputProgress',
                        # TODO: 'GPSLogOutputComplete',
                        # TODO: 'TouchTrans',
                        # TODO: 'RequestObjectTransferExInfo',
                        # TODO: 'PowerZoomInfoChanged',
                        # TODO: 'RequestPushMode',
                        # TODO: 'RequestObjectTransferTS',
                        # TODO: 'AfResult',
                        # TODO: 'CTGInfoCheckComplete',
                        # TODO: 'OLCInfoChanged',
                        # TODO: 'ObjectAddedEx64',
                        # TODO: 'ObjectInfoChangedEx64',
                        # TODO: 'RequestObjectTransfer64',
                        # TODO: 'RequestObjectTransferFTP64',
                        # TODO: 'ImportFailed',
                        # TODO: 'BlePairing',
                        # TODO: 'RequestObjectTransferFTP',
                        # TODO: 'Unknown',
                    },
                    # Undecoded events: raw bytes after Bytes and EventCode
                    # (assumed 4 bytes each).
                    default=Array(
                        lambda ctx: ctx._.Bytes - 8,
                        self._UInt8
                    )
                )
            ))
        )

    def _EOSDeviceInfo(self):
        '''Build the constructor for the EOSGetDeviceInfoEx data phase.

        Three UInt32-count-prefixed arrays: supported EOS event codes,
        supported EOS device property codes, and a trailing list of UInt32
        values whose meaning is still unknown (field name ``TODO``).
        '''
        def counted(element):
            # Array whose element count precedes it as a UInt32.
            return PrefixedArray(self._UInt32, element)

        return Struct(
            'EventsSupported' / counted(self._EOSEventCode),
            'DevicePropertiesSupported' / counted(self._EOSPropertyCode),
            'TODO' / counted(self._UInt32),
        )

    # TODO: Decode Canon specific events and properties.
    def _set_endian(self, endian):
        '''Rebuild the Canon EOS constructors for `endian`.

        The parent is configured first with ``explicit=True`` (presumably so
        base constructors such as ``_UInt32`` exist for the EOS ones to use).
        Each EOS builder method is then replaced on the instance by the
        constructor it returns. Order matters: the event-record constructors
        use the code enums, and ``_EOSEventRecords`` wraps
        ``_EOSEventRecord``. Finally, the parent is called again with
        ``explicit=False``.
        '''
        logger.debug('Set Canon endianness')
        super(Canon, self)._set_endian(endian, explicit=True)
        builders = (
            '_EOSPropertyCode',
            '_EOSEventCode',
            '_EOSImageSize',
            '_EOSImageType',
            '_EOSImageCompression',
            '_EOSImageFormat',
            '_EOSEventRecord',
            '_EOSEventRecords',
        )
        for attr in builders:
            setattr(self, attr, getattr(self, attr)())
        super(Canon, self)._set_endian(endian, explicit=False)


    # TODO: implement GetObjectSize
    # TODO: implement SetObjectArchive

    def keep_device_on(self):
        '''Send KeepDeviceOn so a non-EOS camera stays on.

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='KeepDeviceOn',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    # TODO: implement LockDeviceUI
    # TODO: implement UnlockDeviceUI
    # TODO: implement GetObjectHandleByName
    # TODO: implement InitiateReleaseControl
    # TODO: implement TerminateReleaseControl
    # TODO: implement TerminatePlaybackMode
    # TODO: implement ViewfinderOn
    # TODO: implement ViewfinderOff
    # TODO: implement DoAeAfAwb
    # TODO: implement GetCustomizeSpec
    # TODO: implement GetCustomizeItemInfo
    # TODO: implement GetCustomizeData
    # TODO: implement SetCustomizeData
    # TODO: implement GetCaptureStatus
    # TODO: implement CheckEvent
    # TODO: implement FocusLock
    # TODO: implement FocusUnlock
    # TODO: implement GetLocalReleaseParam
    # TODO: implement SetLocalReleaseParam
    # TODO: implement AskAboutPcEvf
    # TODO: implement SendPartialObject
    # TODO: implement InitiateCaptureInMemory
    # TODO: implement GetPartialObjectEx
    # TODO: implement SetObjectTime
    # TODO: implement GetViewfinderImage
    # TODO: implement GetObjectAttributes
    # TODO: implement ChangeUSBProtocol
    # TODO: implement GetChanges
    # TODO: implement GetObjectInfoEx
    # TODO: implement InitiateDirectTransfer
    # TODO: implement TerminateDirectTransfer
    # TODO: implement SendObjectInfoByPath
    # TODO: implement SendObjectByPath
    # TODO: implement InitiateDirectTansferEx
    # TODO: implement GetAncillaryObjectHandles
    # TODO: implement GetTreeInfo
    # TODO: implement GetTreeSize
    # TODO: implement NotifyProgress
    # TODO: implement NotifyCancelAccepted
    # TODO: implement GetDirectory
    # TODO: implement SetPairingInfo
    # TODO: implement GetPairingInfo
    # TODO: implement DeletePairingInfo
    # TODO: implement GetMACAddress
    # TODO: implement SetDisplayMonitor
    # TODO: implement PairingComplete
    # TODO: implement GetWirelessMAXChannel
    # TODO: implement EOSGetStorageIDs
    # TODO: implement EOSGetStorageInfo
    # TODO: implement EOSGetObjectInfo
    # TODO: implement EOSGetObject
    # TODO: implement EOSDeleteObject
    # TODO: implement EOSFormatStore
    # TODO: implement EOSGetPartialObject
    # EOSGetDeviceInfoEx is implemented below as eos_get_device_info.

    def eos_get_device_info(self):
        '''Get EOS camera device information (EOSGetDeviceInfoEx).

        The data phase, if any, is parsed with the ``_EOSDeviceInfo`` layout:
        supported events, supported device properties, and an unidentified
        trailing UInt32 list. Returns whatever ``_parse_if_data`` yields.
        '''
        ptp = Container(
            OperationCode='EOSGetDeviceInfoEx',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[0x00100000]
        )
        response = self.recv(ptp)
        # Unlike _EOSEventRecords (used by eos_get_event), _set_endian does
        # not pre-build _EOSDeviceInfo, so this is normally still the builder
        # method. Build the constructor on demand; reuse it as-is if something
        # has already replaced the method with a built constructor.
        device_info = self._EOSDeviceInfo
        if not hasattr(device_info, 'parse'):
            device_info = device_info()
        return self._parse_if_data(response, device_info)

    # TODO: implement EOSGetObjectInfoEx
    # TODO: implement EOSGetThumbEx
    # TODO: implement EOSSendPartialObject
    # TODO: implement EOSSetObjectAttributes
    # TODO: implement EOSGetObjectTime
    # TODO: implement EOSSetObjectTime

    def eos_remote_release(self):
        '''Trigger the shutter remotely on an EOS camera (EOSRemoteRelease).

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSRemoteRelease',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    # TODO: implement EOSSetDevicePropValueEx
    # TODO: implement EOSGetRemoteMode

    def eos_set_remote_mode(self, mode):
        '''Set the remote mode of an EOS camera (EOSSetRemoteMode).

        `mode` is sent verbatim as the single operation parameter. Returns
        the response from ``mesg``.
        '''
        # TODO: Add automatic translation of remote mode codes and names.
        return self.mesg(Container(
            OperationCode='EOSSetRemoteMode',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[mode],
        ))

    def eos_event_mode(self, mode):
        '''Set the event mode of an EOS camera (EOSSetEventMode).

        The Canon extension uses this to enrich the events returned by the
        camera, and to let the initiator poll for events when convenient.
        `mode` is sent verbatim as the single parameter. Returns the response
        from ``mesg``.
        '''
        # TODO: Add automatic translation of event mode codes and names.
        return self.mesg(Container(
            OperationCode='EOSSetEventMode',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[mode],
        ))

    def eos_get_event(self):
        '''Poll an EOS camera for pending events (EOSGetEvent).

        The data phase, if present, is parsed with ``_EOSEventRecords``;
        returns whatever ``_parse_if_data`` yields.
        '''
        raw = self.recv(Container(
            OperationCode='EOSGetEvent',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))
        return self._parse_if_data(raw, self._EOSEventRecords)

    def eos_transfer_complete(self, handle):
        '''Signal the end of a transfer to an EOS camera.

        Sends EOSTransferComplete with the object `handle` as its only
        parameter. Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSTransferComplete',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[handle],
        ))

    # TODO: implement EOSCancelTransfer
    # TODO: implement EOSResetTransfer

    def eos_pc_hdd_capacity(self, todo0=0xfffffff8, todo1=0x1000, todo2=0x1):
        '''Report PC hard-drive capacity to an EOS camera (EOSPCHDDCapacity).

        The meaning of the three parameters is not yet known; they are sent
        in order as given. Returns the response from ``mesg``.
        '''
        # TODO: Figure out what to send exactly.
        params = [todo0, todo1, todo2]
        return self.mesg(Container(
            OperationCode='EOSPCHDDCapacity',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=params,
        ))

    def eos_set_ui_lock(self):
        '''Lock the EOS camera's user interface (EOSSetUILock).

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSSetUILock',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    def eos_reset_ui_lock(self):
        '''Unlock the EOS camera's user interface (EOSResetUILock).

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSResetUILock',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    def eos_keep_device_on(self):
        '''Send EOSKeepDeviceOn so an EOS camera stays on.

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSKeepDeviceOn',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    # TODO: implement EOSSetNullPacketMode
    # TODO: implement EOSUpdateFirmware
    # TODO: implement EOSTransferCompleteDT
    # TODO: implement EOSCancelTransferDT
    # TODO: implement EOSSetWftProfile
    # TODO: implement EOSGetWftProfile
    # TODO: implement EOSSetProfileToWft

    # TODO: implement method convenience method for bulb captures
    def eos_bulb_start(self):
        '''Start a bulb exposure on an EOS camera (EOSBulbStart).

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSBulbStart',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    def eos_bulb_end(self):
        '''Finish a bulb exposure on an EOS camera (EOSBulbEnd).

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSBulbEnd',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    def eos_request_device_prop_value(self, device_property):
        '''Send EOSRequestDevicePropValue for `device_property` on EOS cameras.

        `device_property` is sent verbatim as the single operation parameter.
        Presumably this asks the camera to report the property's current
        value (e.g. via a later event) -- TODO confirm. Returns the response
        from ``mesg``.
        '''
        ptp = Container(
            OperationCode='EOSRequestDevicePropValue',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[device_property]
        )
        response = self.mesg(ptp)
        return response

    def eos_remote_release_on(self, full=False, m=False, x=0):
        '''
        Remote control shutter press for EOS cameras

        Equivalent to pressing the shutter button: all the way in when
        `full`, half-way otherwise. `m` selects the Canon EOS M variant,
        which only has a full press and uses its own argument (0x3).
        `x` is passed as the second parameter; its meaning is unknown.
        '''
        if m:
            press = 0x3
        elif full:
            press = 0x2
        else:
            press = 0x1
        return self.mesg(Container(
            OperationCode='EOSRemoteReleaseOn',
            SessionID=self._session,
            TransactionID=self._transaction,
            # TODO: figure out what x means.
            Parameter=[press, x],
        ))

    def eos_remote_release_off(self, full=False, m=False):
        '''
        Remote control shutter release for EOS cameras

        Equivalent to letting go of the shutter button: from all the way in
        when `full`, from half-way otherwise. `m` selects the Canon EOS M
        variant, which only has a full press and uses its own argument (0x3).
        '''
        if m:
            press = 0x3
        elif full:
            press = 0x2
        else:
            press = 0x1
        return self.mesg(Container(
            OperationCode='EOSRemoteReleaseOff',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[press],
        ))

    # TODO: implement EOSInitiateViewfinder
    # TODO: implement EOSTerminateViewfinder
    # EOSGetViewFinderImage is implemented below as eos_get_viewfinder_image.

    def eos_get_viewfinder_image(self):
        '''Get viewfinder image for EOS cameras.

        Issues EOSGetViewFinderImage and returns the raw result of ``recv``;
        the data phase is not parsed here.
        '''
        ptp = Container(
            OperationCode='EOSGetViewFinderImage',
            SessionID=self._session,
            TransactionID=self._transaction,
            # TODO: Find out what this parameter does.
            Parameter=[0x00100000]
        )
        return self.recv(ptp)

    def eos_do_af(self):
        '''Run auto-focus (EOSDoAf) when the lens is set to AF.

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSDoAf',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    def eos_drive_lens(self, infinity=True, step=2):
        '''
        Drive lens focus on EOS cameras with an auto-focus lens on.

        `step` must lie in [-3, 3]; a negative step reverses `infinity`.
        With `infinity` true the focal plane moves away from the camera.
        The magnitude of `step` is qualitatively 1 "fine", 2 "normal",
        3 "coarse". The parameter sent is the magnitude, with bit 0x8000 set
        for the "towards infinity" direction.

        Raises ValueError when `step` is outside [-3, 3].
        '''
        if step not in range(-3, 4):
            raise ValueError(
                'The step must be within [-3, 3].'
            )

        # A negative step means "the other way": flip direction, keep size.
        if step < 0:
            infinity, step = not infinity, -step
        direction_bit = 0x8000 if infinity else 0x0000

        return self.mesg(Container(
            OperationCode='EOSDriveLens',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[direction_bit | step],
        ))

    # TODO: implement EOSDepthOfFieldPreview
    # TODO: implement EOSClickWB
    # TODO: implement EOSZoom
    # TODO: implement EOSZoomPosition
    # TODO: implement EOSSetLiveAfFrame

    def eos_af_cancel(self):
        '''Stop auto-focus driving on an EOS camera (EOSAfCancel).

        Returns the response from ``mesg``.
        '''
        return self.mesg(Container(
            OperationCode='EOSAfCancel',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[],
        ))

    # TODO: implement EOSFAPIMessageTX
    # TODO: implement EOSFAPIMessageRX

    def event(self, wait=False):
        '''Check Canon or PTP events

        If `wait` this function is blocking. Otherwise it may return None.

        Queued Canon EOS events (filled by the polling thread) take priority;
        when that queue is empty the parent class's ``event`` is used.
        '''
        # TODO: Do something reasonable on wait=True
        # TODO: Join queues to preserve order of Canon and PTP events.
        if self.__event_queue.empty():
            return super(Canon, self).event(wait=wait)
        # The queue was non-empty a moment ago. Always block: forever when
        # `wait`, briefly otherwise. Previously `block=not wait` made a
        # waiting call non-blocking and ignored its timeout, so a concurrent
        # consumer draining the queue made it fail instead of waiting.
        timeout = None if wait else 0.001
        return self.__event_queue.get(block=True, timeout=timeout)

    def __eos_poll_events(self):
        '''Background loop that pushes polled EOS events onto the event queue.

        Calls ``eos_get_event`` and sleeps 0.2 s between polls, until the EOS
        shutdown event is set or the main thread is no longer alive. Errors
        from a poll are logged and polling continues. On exit the shutdown
        event is cleared.
        '''
        while True:
            if self.__eos_event_shutdown.is_set() or not _main_thread_alive():
                break
            try:
                for evt in self.eos_get_event() or ():
                    logger.debug('Event queued')
                    logger.debug(evt)
                    self.__event_queue.put(evt)
            except Exception as e:
                logger.error(e)
            sleep(0.2)
        self.__eos_event_shutdown.clear()
コード例 #45
0
ファイル: ip.py プロジェクト: Parrot-Developers/sequoia-ptpy
class IPTransport(object):
    '''Implement IP transport.'''
    def __init__(self, device=None):
        '''Prepare a PTP/IP transport for `device`.

        `device` is either a ``(host, port)`` tuple or a bare host. IP
        discovery is not implemented, so ``device=None`` raises
        NotImplementedError. No connection is opened here; connections are
        made when an implicit session is first needed. The session is closed
        at interpreter exit via an ``atexit`` hook.
        '''
        self.__setup_constructors()
        logger.debug('Init IP')

        self.__dev = device
        if device is None:
            raise NotImplementedError(
                'IP discovery not implemented. Please provide a device.'
            )
        self.__device = device

        # Session state: "open" marks a usable implicit session, "shutdown"
        # marks one that is being torn down.
        self.__implicit_session_open, self.__implicit_session_shutdown = (
            Event(), Event()
        )
        self.__check_session_lock, self.__transaction_lock = Lock(), Lock()
        self.__event_queue = Queue()

        atexit.register(self._shutdown)

    def _shutdown(self):
        '''Close the implicit session (registered with atexit).

        Any Exception raised while closing is logged rather than propagated.
        '''
        try:
            self.__close_implicit_session()
        except Exception as err:
            logger.error(err)

    @contextmanager
    def __implicit_session(self):
        '''Run the enclosed block inside an implicit PTP/IP session.

        If no implicit session is open, one is opened here and closed again
        when the block exits. Otherwise the existing session is reused and
        left open for the call that opened it. ``__check_session_lock``
        serialises the "is a session open?" check and the opening itself.

        Raises PTPError('Failed to open PTP/IP implicit session') if opening
        fails. As before, the same error is raised if the enclosed block
        raises while this call owns the session.
        '''
        self.__check_session_lock.acquire()
        if self.__implicit_session_open.is_set():
            # Reuse the open session; whoever opened it will close it.
            self.__check_session_lock.release()
            yield
            return

        # Track lock ownership explicitly. The previous `locked()` test in
        # `finally` only told whether *some* thread held the lock, so after
        # this call had released it, another thread's lock could be released.
        holding_lock = True
        try:
            self.__open_implicit_session()
            self.__check_session_lock.release()
            holding_lock = False
            yield
        except Exception as e:
            logger.error(e)
            # NOTE(review): this also relabels errors raised by the enclosed
            # block as an "open" failure; kept because callers may rely on
            # receiving PTPError.
            raise PTPError('Failed to open PTP/IP implicit session')
        finally:
            if self.__implicit_session_open.is_set():
                self.__close_implicit_session()
            if holding_lock:
                self.__check_session_lock.release()

    def __open_implicit_session(self):
        '''Establish an implicit session with the responder.

        Clears the shutdown flag, opens the command and event connections
        via ``__setup_connection``, and marks the session as open. It then
        starts the event-polling ("EvtPolling") and ping/pong ("PingPong")
        threads.
        '''
        self.__implicit_session_shutdown.clear()

        # `device` is either a (host, port) pair or a bare host that uses the
        # default port. isinstance also accepts tuple subclasses such as
        # namedtuples, which `type(...) is tuple` rejected.
        if isinstance(self.__device, tuple):
            host, port = self.__device
            self.__setup_connection(host, port)
        else:
            self.__setup_connection(self.__device)

        self.__implicit_session_open.set()

        # Prepare Event and Probe threads.
        # NOTE(review): non-daemon threads keep the interpreter alive until
        # they exit, and atexit hooks (e.g. _shutdown) run only after that --
        # confirm the worker loops stop on their own.
        self.__event_proc = Thread(
            name='EvtPolling',
            target=self.__poll_events
        )
        self.__event_proc.daemon = False

        self.__ping_pong_proc = Thread(
            name='PingPong',
            target=self.__ping_pong
        )
        self.__ping_pong_proc.daemon = False

        # Launch Event and Probe threads
        self.__event_proc.start()
        self.__ping_pong_proc.start()

    def __close_implicit_session(self):
        '''Terminate the implicit session with the responder, if one is open.

        Sets the shutdown flag, then returns early if no session is open.
        Otherwise it waits up to 2 s for each still-running worker thread,
        then shuts down and closes the event and command sockets, and marks
        the session as closed. A socket that is already disconnected
        (ENOTCONN) is tolerated; any other socket error propagates.
        '''
        import errno

        self.__implicit_session_shutdown.set()

        if not self.__implicit_session_open.is_set():
            return

        # Only join running threads.
        for proc in (self.__event_proc, self.__ping_pong_proc):
            if proc.is_alive():
                proc.join(2)

        logger.debug('Close connections for {}'.format(repr(self.__dev)))
        # "Not connected" is platform specific (e.g. 107 on Linux, 57 on
        # macOS, WSAENOTCONN on Windows); the old code hard-coded Linux's 107.
        not_connected = {
            errno.ENOTCONN,
            getattr(errno, 'WSAENOTCONN', errno.ENOTCONN),
        }
        for con in (self.__evtcon, self.__cmdcon):
            try:
                con.shutdown(socket.SHUT_RDWR)
            except socket.error as e:
                if e.errno not in not_connected:
                    raise
        self.__evtcon.close()
        self.__cmdcon.close()

        self.__implicit_session_open.clear()

    def __setup_connection(self, host=None, port=15740):
        '''Establish the PTP/IP command and event connections to host:port.

        Performs the InitCommand handshake on a new command connection. It
        then performs the InitEvent handshake on a new event connection,
        quoting the ConnectionNumber from the command ACK.

        Raises PTPError if either request is answered with InitFail or with
        an unexpected packet type, or if the peer closes the connection
        before a full reply arrives.
        '''
        logger.debug(
            'Establishing PTP/IP connection with {}:{}'
            .format(host, port)
        )
        # NOTE(review): setdefaulttimeout is process-wide; it affects every
        # socket created afterwards, not only these two connections.
        socket.setdefaulttimeout(5)
        hdrlen = self.__Header.sizeof()

        def recv_packet(con):
            '''Read exactly one whole PTP/IP packet from connection `con`.

            The header is read first to learn the packet Length (which
            includes the header), then the rest. A fixed-size recv could
            truncate a long ACK or consume the start of the next packet.
            '''
            sock = actual_socket(con)

            def fill(data, size):
                # Keep reading until `data` holds `size` bytes.
                while len(data) < size:
                    chunk = sock.recv(size - len(data))
                    if not chunk:
                        raise PTPError(
                            'PTP/IP connection closed during initialization'
                        )
                    data += chunk
                return data

            data = fill(bytes(), hdrlen)
            length = self.__Header.parse(data[0:hdrlen]).Length
            return fill(data, length)

        # Command Connection Establishment
        self.__cmdcon = create_connection((host, port))
        # Send InitCommand
        # TODO: Allow users to identify as an arbitrary initiator.
        init_cmd_req_payload = self.__InitCommand.build(
            Container(
                InitiatorGUID=16*[0xFF],
                InitiatorFriendlyName='PTPy',
                InitiatorProtocolVersion=Container(
                    Major=100,
                    Minor=0,
                ),
            ))
        init_cmd_req = self.__Packet.build(
            Container(
                Type='InitCommand',
                Payload=init_cmd_req_payload,
            )
        )
        actual_socket(self.__cmdcon).sendall(init_cmd_req)
        # Get ACK/NACK
        init_cmd_req_rsp = recv_packet(self.__cmdcon)
        init_cmd_rsp_hdr = self.__Header.parse(
            init_cmd_req_rsp[0:hdrlen]
        )

        if init_cmd_rsp_hdr.Type == 'InitCommandAck':
            cmd_ack = self.__InitCommandACK.parse(init_cmd_req_rsp[hdrlen:])
            logger.debug(
                'Command connection ({}) established'
                .format(cmd_ack.ConnectionNumber)
            )
        elif init_cmd_rsp_hdr.Type == 'InitFail':
            cmd_nack = self.__InitFail.parse(init_cmd_req_rsp[hdrlen:])
            msg = 'InitCommand failed, Reason: {}'.format(
                cmd_nack
            )
            logger.error(msg)
            raise PTPError(msg)
        else:
            msg = 'Unexpected response Type to InitCommand : {}'.format(
                init_cmd_rsp_hdr.Type
            )
            logger.error(msg)
            raise PTPError(msg)

        # Event Connection Establishment
        self.__evtcon = create_connection((host, port))
        self.__evtcon.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        # SO_KEEPALIVE is a socket-level (SOL_SOCKET) option. At the
        # IPPROTO_TCP level the same number selects an unrelated TCP option
        # (e.g. TCP_DEFER_ACCEPT on Linux), so keepalive was never enabled.
        self.__evtcon.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        # Send InitEvent
        payload = self.__InitEvent.build(Container(
            ConnectionNumber=cmd_ack.ConnectionNumber,
        ))
        evt_req = self.__Packet.build(
            Container(
                Type='InitEvent',
                Payload=payload,
            )
        )
        actual_socket(self.__evtcon).sendall(evt_req)
        # Get ACK/NACK
        init_evt_req_rsp = recv_packet(self.__evtcon)
        init_evt_rsp_hdr = self.__Header.parse(
            init_evt_req_rsp[0:hdrlen]
        )

        if init_evt_rsp_hdr.Type == 'InitEventAck':
            logger.debug(
                'Event connection ({}) established'
                .format(cmd_ack.ConnectionNumber)
            )
        elif init_evt_rsp_hdr.Type == 'InitFail':
            evt_nack = self.__InitFail.parse(init_evt_req_rsp[hdrlen:])
            msg = 'InitEvent failed, Reason: {}'.format(
                evt_nack
            )
            logger.error(msg)
            raise PTPError(msg)
        else:
            msg = 'Unexpected response Type to InitEvent : {}'.format(
                init_evt_rsp_hdr.Type
            )
            logger.error(msg)
            raise PTPError(msg)

    # Helper methods.
    # ---------------------
    def __setup_constructors(self):
        '''Set endianness and create transport-specific constructors.

        Calls ``self._set_endian('little')`` first, because several of the
        PTP/IP constructors below reuse ``_Parameter``, ``_OperationCode``,
        ``_TransactionID``, ``_ResponseCode`` and ``_EventCode`` (presumably
        set up by ``_set_endian``). All PTP/IP framing fields here are
        little-endian (``Int16ul``/``Int32ul``/``Int64ul``).
        '''
        # Set endianness of constructors before using them.
        self._set_endian('little')

        # Packet header: total packet length (header included) + packet type.
        self.__Length = Int32ul
        self.__Type = Enum(
            Int32ul,
            Undefined=0x00000000,
            InitCommand=0x00000001,
            InitCommandAck=0x00000002,
            InitEvent=0x00000003,
            InitEventAck=0x00000004,
            InitFail=0x00000005,
            Command=0x00000006,
            Response=0x00000007,
            Event=0x00000008,
            StartData=0x00000009,
            Data=0x0000000A,
            Cancel=0x0000000B,
            EndData=0x0000000C,
            Ping=0x0000000D,
            Pong=0x0000000E,
        )
        self.__Header = Struct(
            'Length' / self.__Length,
            'Type' / self.__Type,
        )
        # Up to 5 parameters on commands/responses, up to 3 on events.
        self.__Param = Range(0, 5, self._Parameter)
        self.__EventParam = Range(0, 3, self._Parameter)
        # Header followed by an opaque payload of (Length - header size) bytes.
        self.__PacketBase = Struct(
            Embedded(self.__Header),
            'Payload' / Bytes(
                lambda ctx, h=self.__Header: ctx.Length - h.sizeof()),
        )
        # When building, Length is filled in from the payload size, so callers
        # only supply Type and Payload. NOTE(review): passing a Length in
        # `obj` as well would be a duplicate keyword -- confirm none do.
        self.__Packet = ExprAdapter(
            self.__PacketBase,
            encoder=lambda obj, ctx, h=self.__Header: Container(
                Length=len(obj.Payload) + h.sizeof(),
                **obj
            ),
            decoder=lambda obj, ctx: obj,
        )
        # PTP/IP string: 16-bit little-endian code units, read until a NUL
        # unit or 40 units, whichever comes first. Decoding drops everything
        # from the first NUL on; encoding appends a NUL unless the string is
        # empty. NOTE(review): characters above U+FFFF do not fit one Int16ul.
        self.__PTPIPString = ExprAdapter(
            RepeatUntil(
                lambda obj, ctx, lst:
                six.unichr(obj) in '\x00' or len(lst) == 40, Int16ul
            ),
            encoder=lambda obj, ctx:
            [] if len(obj) == 0 else[ord(c) for c in six.text_type(obj)]+[0],
            decoder=lambda obj, ctx:
            u''.join(
                [six.unichr(o) for o in obj]
            ).split('\x00')[0],
        )
        # PTP/IP packets
        # Command
        self.__ProtocolVersion = Struct(
            'Major' / Int16ul,
            'Minor' / Int16ul,
        )
        self.__InitCommand = Embedded(Struct(
            'InitiatorGUID' / Array(16, Int8ul),
            'InitiatorFriendlyName' / self.__PTPIPString,
            'InitiatorProtocolVersion' / self.__ProtocolVersion,
        ))
        self.__InitCommandACK = Embedded(Struct(
            'ConnectionNumber' / Int32ul,
            'ResponderGUID' / Array(16, Int8ul),
            'ResponderFriendlyName' / self.__PTPIPString,
            'ResponderProtocolVersion' / self.__ProtocolVersion,
        ))
        # Event
        self.__InitEvent = Embedded(Struct(
            'ConnectionNumber' / Int32ul,
        ))
        # Common to Events and Command requests
        self.__Reason = Enum(
            # TODO: Verify these codes...
            Int32ul,
            Undefined=0x0000,
            RejectedInitiator=0x0001,
            Busy=0x0002,
            Unspecified=0x0003,
        )
        self.__InitFail = Embedded(Struct(
            'Reason' / self.__Reason,
        ))

        # Direction of the data phase that follows a Command.
        self.__DataphaseInfo = Enum(
            Int32ul,
            Undefined=0x00000000,
            In=0x00000001,
            Out=0x00000002,
        )
        self.__Command = Embedded(Struct(
            'DataphaseInfo' / self.__DataphaseInfo,
            'OperationCode' / self._OperationCode,
            'TransactionID' / self._TransactionID,
            'Parameter' / self.__Param,
        ))
        self.__Response = Embedded(Struct(
            'ResponseCode' / self._ResponseCode,
            'TransactionID' / self._TransactionID,
            'Parameter' / self.__Param,
        ))
        self.__Event = Embedded(Struct(
            'EventCode' / self._EventCode,
            'TransactionID' / self._TransactionID,
            'Parameter' / self.__EventParam,
        ))
        # Announces a data phase and its total size in bytes (64-bit).
        self.__StartData = Embedded(Struct(
            'TransactionID' / self._TransactionID,
            'TotalDataLength' / Int64ul,
        ))
        # TODO: Fix packing and unpacking dataphase data
        # Data/EndData: the chunk is the rest of the packet after the header
        # and TransactionID. Assumes ctx._.Length is the enclosing packet's
        # header Length -- TODO confirm for the construct version in use.
        self.__Data = Embedded(Struct(
            'TransactionID' / self._TransactionID,
            'Data' / Bytes(
                lambda ctx:
                ctx._.Length -
                self.__Header.sizeof() -
                self._TransactionID.sizeof()
            ),
        ))
        self.__EndData = Embedded(Struct(
            'TransactionID' / self._TransactionID,
            'Data' / Bytes(
                lambda ctx:
                ctx._.Length -
                self.__Header.sizeof() -
                self._TransactionID.sizeof()
            ),
        ))
        self.__Cancel = Embedded(Struct(
            'TransactionID' / self._TransactionID,
        ))
        # Convenience construct for parsing packets: header, then a payload
        # chosen by the header Type (types not listed, e.g. Ping/Pong/Cancel,
        # parse no payload).
        # NOTE(review): Debugger presumably drops into an interactive debugger
        # when parsing fails -- confirm this is intended outside development.

        self.__PacketPayload = Debugger(Struct(
            'Header' / Embedded(self.__Header),
            'Payload' / Embedded(Switch(
                lambda ctx: ctx.Type,
                {
                    'InitCommand': self.__InitCommand,
                    'InitCommandAck': self.__InitCommandACK,
                    'InitEvent': self.__InitEvent,
                    'InitFail': self.__InitFail,
                    'Command': self.__Command,
                    'Response': self.__Response,
                    'Event': self.__Event,
                    'StartData': self.__StartData,
                    'Data': self.__Data,
                    'EndData': self.__EndData,
                },
                default=Pass,
            ))
        ))

    def __parse_response(self, ipdata):
        '''Parse one raw PTP/IP packet and tag it with the session ID.

        The PTP/IP packet layouts carry no SessionID field, so the
        transport's current ``session_id`` is inserted into the parsed
        container before it is returned.
        '''
        parsed = self.__PacketPayload.parse(ipdata)
        # PTP/IP sessions are implicit; attach the ID callers expect.
        parsed['SessionID'] = self.session_id
        return parsed

    def __recv(self, event=False, wait=False, raw=False):
        '''Helper method for receiving packets.

        Reads from the event connection when `event` is true, otherwise from
        the command connection. StartData/Data/EndData packets of the current
        transaction are accumulated and returned as one parsed container with
        Type 'Data' once EndData arrives. Any other accepted packet ends the
        read and is returned parsed, or as raw bytes when `raw` is true.
        `wait` is currently unused.

        Returns None for events when the read times out or the event
        connection yields no data.

        Raises PTPError if the command connection drops, the connection
        closes in the middle of a packet, or an unexpected packet type
        arrives.
        '''
        hdrlen = self.__Header.sizeof()
        with self.__implicit_session():
            ip = (
                actual_socket(self.__evtcon)
                if event
                else actual_socket(self.__cmdcon)
            )
            data = bytes()
            # Set from StartData. Initialised so that a stray Data/EndData
            # packet is ignored instead of raising UnboundLocalError.
            expected = None
            current_transaction = None
            while True:
                try:
                    ipdata = ip.recv(hdrlen)
                except socket.timeout:
                    if event:
                        return None
                    else:
                        ipdata = ip.recv(hdrlen)

                if len(ipdata) == 0:
                    if event:
                        # Nothing available on the event connection.
                        return None
                    raise PTPError('Command connection dropped')

                # Read a single entire header
                ipdata = self.__recv_exactly(ip, ipdata, hdrlen)
                header = self.__Header.parse(
                    ipdata[0:hdrlen]
                )
                # Read a single entire packet
                ipdata = self.__recv_exactly(ip, ipdata, header.Length)
                # Run sanity checks.
                if header.Type not in [
                        'Cancel',
                        'Data',
                        'Event',
                        'Response',
                        'StartData',
                        'EndData',
                ]:
                    raise PTPError(
                        'Unexpected PTP/IP packet type {}'
                        .format(header.Type)
                    )
                if header.Type not in ['StartData', 'Data', 'EndData']:
                    break
                else:
                    response = self.__parse_response(ipdata)

                if header.Type == 'StartData':
                    expected = response.TotalDataLength
                    current_transaction = response.TransactionID
                elif (
                        header.Type == 'Data' and
                        response.TransactionID == current_transaction
                ):
                    data += response.Data
                elif (
                        header.Type == 'EndData' and
                        response.TransactionID == current_transaction
                ):
                    data += response.Data
                    datalen = len(data)
                    if datalen != expected:
                        logger.warning(
                            '{} data than expected {}/{}'
                            .format(
                                'More' if datalen > expected else 'Less',
                                datalen,
                                expected
                            )
                        )
                    response['Data'] = data
                    response['Type'] = 'Data'
                    return response

        if raw:
            # TODO: Deal with raw Data packets??
            return ipdata
        else:
            return self.__parse_response(ipdata)

    def __recv_exactly(self, ip, ipdata, length):
        '''Read from socket `ip` until `ipdata` holds at least `length` bytes.

        Returns the extended buffer. Raises PTPError if `recv` returns no
        data (peer closed the connection) before enough bytes arrive, which
        would otherwise loop forever.
        '''
        while len(ipdata) < length:
            chunk = ip.recv(length - len(ipdata))
            if not chunk:
                raise PTPError(
                    'Connection dropped while reading a PTP/IP packet'
                )
            ipdata += chunk
        return ipdata

    def __send(self, ptp_container, event=False):
        '''Build `ptp_container` into a packet and write it to a socket.

        The event connection is used when `event` is true, the command
        connection otherwise.
        '''
        packet = self.__Packet.build(ptp_container)
        connection = self.__evtcon if event else self.__cmdcon
        ip = actual_socket(connection)
        # A standard socket's `sendall` returns None on success (and raises
        # on failure), so this loop only repeats if a non-None value comes
        # back.
        while ip.sendall(packet) is not None:
            logger.debug('Failed to send {} packet'.format(ptp_container.Type))

    def __send_request(self, ptp_container):
        '''Send PTP request without checking answer.

        The request is always sent with its parameter list padded with
        zeros up to five entries. The caller's container and its Parameter
        list are left unmodified.
        '''
        # Don't modify original container to keep abstraction barrier.
        ptp = Container(**ptp_container)

        # Send unused parameters always. Build a new list: `+=` would extend
        # the caller's Parameter list in place, since the Container copy
        # above is shallow and shares the same list object.
        params = list(ptp.Parameter)
        ptp['Parameter'] = params + [0] * (5 - len(params))

        # Send request
        ptp['Type'] = 'Command'
        ptp['DataphaseInfo'] = 'In'
        ptp['Payload'] = self.__Command.build(ptp)
        self.__send(ptp)

    def __send_data(self, ptp_container, data):
        '''Send `data` as an outgoing data phase, without waiting for a reply.

        A copy of `ptp_container` is tagged as an 'Out' Data packet carrying
        `data` as its payload; the caller's container is not modified.
        '''
        outgoing = Container(**ptp_container)
        outgoing['Type'] = 'Data'
        outgoing['DataphaseInfo'] = 'Out'
        outgoing['Payload'] = data
        self.__send(outgoing)

    # Actual implementation
    # ---------------------
    def send(self, ptp_container, data):
        '''Transfer operation with dataphase from initiator to responder.

        Inside the implicit session and under the transaction lock, sends the
        request followed by `data`, then returns the parsed response.
        '''
        opcode = ptp_container.OperationCode
        params = ptp_container.Parameter
        param_text = ' ' + str(list(map(hex, params))) if params else ''
        logger.debug('SEND {}{}'.format(opcode, param_text))
        with self.__implicit_session(), self.__transaction_lock:
            self.__send_request(ptp_container)
            self.__send_data(ptp_container, data)
            # The response carries the implicit SessionID added on parsing.
            return self.__recv()

    def recv(self, ptp_container):
        '''Transfer operation with dataphase from responder to initiator.

        Sends the request, then reads the reply. If the reply carries Data, a
        second read fetches the response, and the data is attached to it
        under 'Data' after checking that transaction and session IDs agree.
        Otherwise the first reply is returned as-is.

        Raises PTPError when the data phase does not match the request.
        '''
        opcode = ptp_container.OperationCode
        params = ptp_container.Parameter
        param_text = ' ' + str(list(map(hex, params))) if params else ''
        logger.debug('RECV {}{}'.format(opcode, param_text))
        with self.__implicit_session(), self.__transaction_lock:
            self.__send_request(ptp_container)
            dataphase = self.__recv()
            if not hasattr(dataphase, 'Data'):
                return dataphase

            response = self.__recv()
            mismatch = (
                ptp_container.TransactionID != dataphase.TransactionID or
                ptp_container.SessionID != dataphase.SessionID or
                dataphase.TransactionID != response.TransactionID or
                dataphase.SessionID != response.SessionID
            )
            if mismatch:
                raise PTPError(
                    'Dataphase does not match with requested operation'
                )
            response['Data'] = dataphase.Data
            return response

    def mesg(self, ptp_container):
        '''Transfer operation without dataphase.

        OpenSession opens the implicit session before the request and closes
        it again if the response code is not 'OK'. A successful
        CloseSession closes the implicit session afterwards.
        '''
        op = ptp_container['OperationCode']
        is_open = op == 'OpenSession'
        if is_open:
            self.__open_implicit_session()

        with self.__implicit_session(), self.__transaction_lock:
            self.__send_request(ptp_container)
            # The response carries the implicit SessionID added on parsing.
            response = self.__recv()

        rc = response['ResponseCode']
        failed_open = is_open and rc != 'OK'
        closed = op == 'CloseSession' and rc == 'OK'
        if failed_open or closed:
            self.__close_implicit_session()

        return response

    def event(self, wait=False):
        '''Check event.

        If `wait` this function is blocking. Otherwise it may return None.

        Returns the parsed event container, or None when `wait` is false and
        no event is queued.
        '''
        evt = None
        ipdata = None
        # The previous `get(block=not wait, ...)` was inverted and guarded by
        # `empty()`, so `wait=True` never actually blocked.
        if wait:
            ipdata = self.__event_queue.get(block=True)
        elif not self.__event_queue.empty():
            ipdata = self.__event_queue.get(block=False)

        if ipdata is not None:
            evt = self.__parse_response(ipdata)

        return evt

    def __poll_events(self):
        '''Poll events, adding them to a queue.

        Runs until shutdown is requested, the implicit session is no longer
        open, or the main thread has exited. Raw event packets are pushed
        onto the event queue. An OSError with errno 9 raised after the
        session has closed ends polling; any other OSError propagates.
        '''
        logger.debug('Start')
        while True:
            keep_polling = (
                not self.__implicit_session_shutdown.is_set() and
                self.__implicit_session_open.is_set() and
                _main_thread_alive()
            )
            if not keep_polling:
                break
            try:
                evt = self.__recv(event=True, wait=False, raw=True)
            except OSError as e:
                if e.errno == 9 and not self.__implicit_session_open.is_set():
                    break
                raise
            if evt is not None:
                logger.debug('Event queued')
                self.__event_queue.put(evt)
            sleep(5e-3)
        logger.debug('Stop')

    def __ping_pong(self):
        '''Placeholder keep-alive loop.

        Runs while the implicit session is open and no shutdown is requested
        (and the main thread is alive), waking every 0.1 s and logging 'PING'
        roughly every 10 seconds. No ping/pong packets are actually sent yet.
        '''
        logger.debug('Start')
        last_ping = time()
        while (
                not self.__implicit_session_shutdown.is_set() and
                self.__implicit_session_open.is_set() and
                _main_thread_alive()
        ):
            elapsed = time() - last_ping
            if elapsed > 10:
                logger.debug('PING')
                # TODO: implement Ping Pong
                last_ping = time()
            sleep(0.10)
        logger.debug('Stop')
コード例 #46
0
ファイル: usb.py プロジェクト: Parrot-Developers/sequoia-ptpy
class USBTransport(object):
    '''Implement USB transport.

    Requests and outgoing data are written to the bulk OUT endpoint,
    responses and incoming data are read from the bulk IN endpoint, and
    events are read from the interrupt IN endpoint by a background thread
    that queues them for `event()`.
    '''
    def __init__(self, *args, **kwargs):
        '''Instantiate the first available PTP device over USB.

        Keyword arguments:
        device -- the USB device to use; or a device name string, which is
                  passed to `find_usb_cameras(name=...)`; or None (default)
                  to probe all devices returned by `find_usb_cameras`.

        Positional arguments and other keyword arguments are ignored.
        Raises PTPError if no PTP device can be found or claimed.
        '''
        device = kwargs.get('device', None)
        logger.debug('Init USB')
        self.__setup_constructors()
        # If no device is specified, find all devices claiming to be Cameras
        # and get the USB endpoints for the first one that works.
        if device is None:
            logger.debug('No device provided, probing all USB devices.')
        if isinstance(device, six.string_types):
            name = device
            logger.debug(
                'Device name provided, probing all USB devices for {}.'
                .format(name)
            )
            device = None
        else:
            name = None
        devs = (
            [device] if (device is not None)
            else find_usb_cameras(name=name)
        )
        self.__claimed = False
        self.__acquire_camera(devs)

        self.__event_queue = Queue()
        self.__event_shutdown = Event()
        # Locks for different end points.
        self.__inep_lock = RLock()
        self.__intep_lock = RLock()
        self.__outep_lock = RLock()
        # Slightly redundant transaction lock to avoid catching other request's
        # response
        self.__transaction_lock = RLock()

        # Non-daemon polling thread; `_shutdown`, registered with atexit,
        # signals it to stop and joins it.
        self.__event_proc = Thread(
            name='EvtPolling',
            target=self.__poll_events
        )
        self.__event_proc.daemon = False
        atexit.register(self._shutdown)
        self.__event_proc.start()

    def __available_cameras(self, devs):
        '''Yield once for each device in `devs` accepted by `__setup_device`.

        At each yield the device and its endpoints are stored on self.
        Raises PTPError if no device in `devs` exposes a usable PTP
        interface. If devices were found but the consumer never stopped the
        iteration, the generator simply finishes so the caller can report
        that none could be acquired.
        '''
        found = False
        for dev in devs:
            if self.__setup_device(dev):
                found = True
                logger.debug('Found USB PTP device {}'.format(dev))
                yield
        if not found:
            message = 'No USB PTP device found.'
            logger.error(message)
            raise PTPError(message)

    def __acquire_camera(self, devs):
        '''From the cameras given, get the first one that does not fail.

        Detaches an active kernel driver when possible, claims the PTP
        interface and resets the device. Raises PTPError if no device is
        found or none can be claimed.
        '''
        for _ in self.__available_cameras(devs):
            # Stop system drivers
            try:
                if self.__dev.is_kernel_driver_active(
                        self.__intf.bInterfaceNumber):
                    try:
                        self.__dev.detach_kernel_driver(
                            self.__intf.bInterfaceNumber)
                    except usb.core.USBError:
                        message = (
                            'Could not detach kernel driver. '
                            'Maybe the camera is mounted?'
                        )
                        logger.error(message)
            except NotImplementedError as e:
                logger.debug('Ignoring unimplemented function: {}'.format(e))
            # Claim camera
            try:
                logger.debug('Claiming {}'.format(repr(self.__dev)))
                usb.util.claim_interface(self.__dev, self.__intf)
                self.__claimed = True
            except Exception as e:
                logger.warning('Failed to claim PTP device: {}'.format(e))
                continue
            self.__dev.reset()
            break
        else:
            # Reached when devices were found but none could be claimed.
            message = (
                'Could not acquire any camera.'
            )
            logger.error(message)
            raise PTPError(message)

    def _shutdown(self):
        '''Stop event polling and release the claimed USB interface.

        Registered with atexit. Waits up to 2 seconds for the polling
        thread. Errors while releasing the interface are logged, not raised.
        '''
        logger.debug('Shutdown request')
        self.__event_shutdown.set()

        # Only join a running thread.
        if self.__event_proc.is_alive():
            self.__event_proc.join(2)

        # Free USB resource on shutdown.
        try:
            if self.__claimed:
                logger.debug('Release {}'.format(repr(self.__dev)))
                usb.util.release_interface(self.__dev, self.__intf)
        except Exception as e:
            logger.warning(e)

    # Helper methods.
    # ---------------------
    @staticmethod
    def __is_ignorable_usb_error(e):
        '''Return True for USB errors treated as transient.

        These are errno 110, 16 or 5 (ETIMEDOUT, EBUSY, EIO on Linux), or an
        error without errno whose message mentions 'timeout' or 'busy'. The
        message may be bytes (Python 2) or str (Python 3); calling
        `.decode()` on a str would raise AttributeError.
        '''
        if e.errno in (110, 16, 5):
            return True
        if e.errno is None:
            message = e.strerror
            if isinstance(message, bytes):
                message = message.decode('utf-8', 'replace')
            message = message or ''
            return 'timeout' in message or 'busy' in message
        return False

    def __setup_device(self, dev):
        '''Get endpoints for a device. True on success.

        Looks for an interface of class PTP_USB_CLASS that has a bulk IN, a
        bulk OUT and an interrupt IN endpoint. On success stores the
        endpoints, configuration, device and interface on self. Otherwise
        they are left as None.
        '''
        self.__inep = None
        self.__outep = None
        self.__intep = None
        self.__cfg = None
        self.__dev = None
        self.__intf = None
        # Attempt to find the USB in, out and interrupt endpoints for a PTP
        # interface.
        for cfg in dev:
            for intf in cfg:
                if intf.bInterfaceClass == PTP_USB_CLASS:
                    for ep in intf:
                        ep_type = endpoint_type(ep.bmAttributes)
                        ep_dir = endpoint_direction(ep.bEndpointAddress)
                        if ep_type == ENDPOINT_TYPE_BULK:
                            if ep_dir == ENDPOINT_IN:
                                self.__inep = ep
                            elif ep_dir == ENDPOINT_OUT:
                                self.__outep = ep
                        elif ((ep_type == ENDPOINT_TYPE_INTR) and
                                (ep_dir == ENDPOINT_IN)):
                            self.__intep = ep
                # Only accept an interface that provides all three endpoints.
                if not (self.__inep and self.__outep and self.__intep):
                    self.__inep = None
                    self.__outep = None
                    self.__intep = None
                else:
                    logger.debug('Found {}'.format(repr(self.__inep)))
                    logger.debug('Found {}'.format(repr(self.__outep)))
                    logger.debug('Found {}'.format(repr(self.__intep)))
                    self.__cfg = cfg
                    self.__dev = dev
                    self.__intf = intf
                    return True
        return False

    def __setup_constructors(self):
        '''Set endianness and create transport-specific constructors.'''
        # Set endianness of constructors before using them.
        self._set_endian('little')

        self.__Length = Int32ul
        self.__Type = Enum(
                Int16ul,
                default=Pass,
                Undefined=0x0000,
                Command=0x0001,
                Data=0x0002,
                Response=0x0003,
                Event=0x0004,
                )
        # This is just a convenience constructor to get the size of a header.
        self.__Code = Int16ul
        self.__Header = Struct(
                'Length' / self.__Length,
                'Type' / self.__Type,
                'Code' / self.__Code,
                'TransactionID' / self._TransactionID,
                )
        # These are the actual constructors for parsing and building.
        self.__CommandHeader = Struct(
                'Length' / self.__Length,
                'Type' / self.__Type,
                'OperationCode' / self._OperationCode,
                'TransactionID' / self._TransactionID,
                )
        self.__ResponseHeader = Struct(
                'Length' / self.__Length,
                'Type' / self.__Type,
                'ResponseCode' / self._ResponseCode,
                'TransactionID' / self._TransactionID,
                )
        self.__EventHeader = Struct(
                'Length' / self.__Length,
                'Type' / self.__Type,
                'EventCode' / self._EventCode,
                'TransactionID' / self._TransactionID,
                )
        # Apparently nobody uses the SessionID field. Even though it is
        # specified in ISO15740:2013(E), no device respects it and the session
        # number is implicit over USB.
        self.__Param = Range(0, 5, self._Parameter)
        self.__CommandTransactionBase = Struct(
                Embedded(self.__CommandHeader),
                'Payload' / Bytes(
                    lambda ctx, h=self.__Header: ctx.Length - h.sizeof()
                )
        )
        # The encoder fills in Length from the payload size when building.
        self.__CommandTransaction = ExprAdapter(
                self.__CommandTransactionBase,
                encoder=lambda obj, ctx, h=self.__Header: Container(
                    Length=len(obj.Payload) + h.sizeof(),
                    **obj
                    ),
                decoder=lambda obj, ctx: obj,
                )
        self.__ResponseTransactionBase = Struct(
                Embedded(self.__ResponseHeader),
                'Payload' / Bytes(
                    lambda ctx, h=self.__Header: ctx.Length - h.sizeof())
                )
        self.__ResponseTransaction = ExprAdapter(
                self.__ResponseTransactionBase,
                encoder=lambda obj, ctx, h=self.__Header: Container(
                    Length=len(obj.Payload) + h.sizeof(),
                    **obj
                    ),
                decoder=lambda obj, ctx: obj,
                )

    def __parse_response(self, usbdata):
        '''Helper method for parsing USB data.

        Returns a Container with SessionID and TransactionID plus, depending
        on the transfer type: ResponseCode and Parameter (Response),
        EventCode and Parameter (Event), or OperationCode and Data (others).
        '''
        # Build up container with all PTP info.
        logger.debug('Transaction:')
        usbdata = bytearray(usbdata)
        if logger.isEnabledFor(logging.DEBUG):
            for line in hexdump(
                    six.binary_type(usbdata[:512]),
                    result='generator'
            ):
                logger.debug(line)
        transaction = self.__ResponseTransaction.parse(usbdata)
        response = Container(
            SessionID=self.session_id,
            TransactionID=transaction.TransactionID,
        )
        logger.debug('Interpreting {} transaction'.format(transaction.Type))
        if transaction.Type == 'Response':
            response['ResponseCode'] = transaction.ResponseCode
            response['Parameter'] = self.__Param.parse(transaction.Payload)
        elif transaction.Type == 'Event':
            event = self.__EventHeader.parse(
                usbdata[0:self.__Header.sizeof()]
            )
            response['EventCode'] = event.EventCode
            response['Parameter'] = self.__Param.parse(transaction.Payload)
        else:
            command = self.__CommandHeader.parse(
                usbdata[0:self.__Header.sizeof()]
            )
            response['OperationCode'] = command.OperationCode
            response['Data'] = transaction.Payload
        return response

    def __recv(self, event=False, wait=False, raw=False):
        '''Helper method for receiving data.

        Reads one complete USB transfer from the bulk IN endpoint, or from
        the interrupt IN endpoint when `event` is true. Up to five reads are
        made to collect a full header, then the rest of the transfer is read
        in chunks of up to 64 kB. `wait` is currently unused.

        Returns the raw bytes (``array.array('B')``) when `raw` is true,
        otherwise the parsed Container. For events, returns None on an
        ignorable USB error (see `__is_ignorable_usb_error`) or an empty
        read.

        Raises PTPError on an empty non-event read or an unexpected transfer
        type. Other USB errors are logged and re-raised.
        '''
        # TODO: clear stalls automatically
        ep = self.__intep if event else self.__inep
        lock = self.__intep_lock if event else self.__inep_lock
        header_size = self.__Header.sizeof()
        usbdata = array.array('B', [])
        with lock:
            tries = 0
            # Attempt to read a header
            while len(usbdata) < header_size and tries < 5:
                if tries > 0:
                    logger.debug('Data smaller than a header')
                    logger.debug(
                        'Requesting {} bytes of data'
                        .format(ep.wMaxPacketSize)
                    )
                try:
                    usbdata += ep.read(
                        ep.wMaxPacketSize
                    )
                except usb.core.USBError as e:
                    # Return None on timeout or busy for events
                    if self.__is_ignorable_usb_error(e):
                        if event:
                            return None
                        else:
                            logger.warning('Ignored exception: {}'.format(e))
                    else:
                        logger.error(e)
                        raise
                tries += 1
            logger.debug('Read {} bytes of data'.format(len(usbdata)))

            if len(usbdata) == 0:
                if event:
                    return None
                else:
                    raise PTPError('Empty USB read')

            if (
                    logger.isEnabledFor(logging.DEBUG) and
                    len(usbdata) < header_size
            ):
                logger.debug('Incomplete header')
                for line in hexdump(
                        six.binary_type(bytearray(usbdata)),
                        result='generator'
                ):
                    logger.debug(line)

            header = self.__ResponseHeader.parse(
                bytearray(usbdata[0:header_size])
            )
            if header.Type not in ['Response', 'Data', 'Event']:
                raise PTPError(
                    'Unexpected USB transfer type. '
                    'Expected Response, Event or Data but received {}'
                    .format(header.Type)
                )
            while len(usbdata) < header.Length:
                usbdata += ep.read(
                    min(
                        header.Length - len(usbdata),
                        # Up to 64kB
                        64 * 2**10
                    )
                )
        if raw:
            return usbdata
        else:
            return self.__parse_response(usbdata)

    def __send(self, ptp_container, event=False):
        '''Helper method for sending data.

        Builds `ptp_container` into a USB transaction and writes it in
        chunks of up to 64 kB to the bulk OUT endpoint (interrupt endpoint
        when `event` is true). One ignorable USB error (see
        `__is_ignorable_usb_error`) is tolerated and the remaining bytes are
        retried. Any other USB error, or a second one, is logged and
        re-raised.
        '''
        ep = self.__intep if event else self.__outep
        lock = self.__intep_lock if event else self.__outep_lock
        transaction = self.__CommandTransaction.build(ptp_container)
        # Up to 64kB per write.
        chunk_size = 64 * 2**10
        with lock:
            sent = 0
            retried = False
            while sent < len(transaction):
                try:
                    # Accumulate: `write` reports only the bytes written by
                    # this call. Assigning instead of adding re-sent the same
                    # chunk forever for transfers larger than one chunk.
                    sent += ep.write(transaction[sent:sent + chunk_size])
                except usb.core.USBError as e:
                    # Ignore timeout or busy device once.
                    if retried or not self.__is_ignorable_usb_error(e):
                        logger.error(e)
                        raise
                    retried = True
                    logger.warning('Ignored USBError {}'.format(e.errno))

    def __send_request(self, ptp_container):
        '''Send PTP request without checking answer.

        Trailing zero (falsy) parameters are not sent. The caller's container
        and its Parameter list are left unmodified.
        '''
        # Don't modify original container to keep abstraction barrier.
        ptp = Container(**ptp_container)
        # Work on a copy of the list: the Container copy above is shallow, so
        # popping from ptp.Parameter would trim the caller's list as well.
        params = list(ptp.Parameter)
        # Don't send unused parameters
        while params and not params[-1]:
            params.pop()
        ptp['Parameter'] = params

        # Send request
        ptp['Type'] = 'Command'
        ptp['Payload'] = self.__Param.build(params)
        self.__send(ptp)

    def __send_data(self, ptp_container, data):
        '''Send data without checking answer.'''
        # Don't modify original container to keep abstraction barrier.
        ptp = Container(**ptp_container)
        # Send data
        ptp['Type'] = 'Data'
        ptp['Payload'] = data
        self.__send(ptp)

    @property
    def _dev(self):
        '''The USB device, or None once shutdown has been requested.'''
        return None if self.__event_shutdown.is_set() else self.__dev

    @_dev.setter
    def _dev(self, value):
        raise ValueError('Read-only property')

    # Actual implementation
    # ---------------------
    def send(self, ptp_container, data):
        '''Transfer operation with dataphase from initiator to responder.

        Under the transaction lock, sends the request and `data`, then
        returns the parsed response.
        '''
        datalen = len(data)
        logger.debug('SEND {} {} bytes{}'.format(
            ptp_container.OperationCode,
            datalen,
            ' ' + str(list(map(hex, ptp_container.Parameter)))
            if ptp_container.Parameter else '',
        ))
        with self.__transaction_lock:
            self.__send_request(ptp_container)
            self.__send_data(ptp_container, data)
            # Get response and sneak in implicit SessionID and missing
            # parameters.
            response = self.__recv()
        # Log the response's own parameters (the previous check looked at
        # the request's Parameter list instead).
        logger.debug('SEND {} {} bytes {}{}'.format(
            ptp_container.OperationCode,
            datalen,
            response.ResponseCode,
            ' ' + str(list(map(hex, response.Parameter)))
            if response.Parameter else '',
        ))
        return response

    def recv(self, ptp_container):
        '''Transfer operation with dataphase from responder to initiator.

        If the first reply carries Data, reads the response as well and
        checks that session ID, transaction ID and operation code match the
        request. On a mismatch the device is reset and PTPError is raised.
        '''
        logger.debug('RECV {}{}'.format(
            ptp_container.OperationCode,
            ' ' + str(list(map(hex, ptp_container.Parameter)))
            if ptp_container.Parameter else '',
        ))
        with self.__transaction_lock:
            self.__send_request(ptp_container)
            dataphase = self.__recv()
            if hasattr(dataphase, 'Data'):
                response = self.__recv()
                if not (ptp_container.SessionID ==
                        dataphase.SessionID ==
                        response.SessionID):
                    self.__dev.reset()
                    raise PTPError(
                        'Dataphase session ID missmatch: {}, {}, {}.'
                        .format(
                            ptp_container.SessionID,
                            dataphase.SessionID,
                            response.SessionID
                        )
                    )
                if not (ptp_container.TransactionID ==
                        dataphase.TransactionID ==
                        response.TransactionID):
                    self.__dev.reset()
                    raise PTPError(
                        'Dataphase transaction ID missmatch: {}, {}, {}.'
                        .format(
                            ptp_container.TransactionID,
                            dataphase.TransactionID,
                            response.TransactionID
                        )
                    )
                if not (ptp_container.OperationCode ==
                        dataphase.OperationCode):
                    self.__dev.reset()
                    raise PTPError(
                        'Dataphase operation code missmatch: {}, {}.'.
                        format(
                            ptp_container.OperationCode,
                            dataphase.OperationCode
                        )
                    )

                response['Data'] = dataphase.Data
            else:
                response = dataphase

        logger.debug('RECV {} {}{}{}'.format(
            ptp_container.OperationCode,
            response.ResponseCode,
            ' {} bytes'.format(len(response.Data))
            if hasattr(response, 'Data') else '',
            ' ' + str(list(map(hex, response.Parameter)))
            if response.Parameter else '',
        ))
        return response

    def mesg(self, ptp_container):
        '''Transfer operation without dataphase.'''
        logger.debug('MESG {}{}'.format(
            ptp_container.OperationCode,
            ' ' + str(list(map(hex, ptp_container.Parameter)))
            if ptp_container.Parameter else '',
        ))
        with self.__transaction_lock:
            self.__send_request(ptp_container)
            # Get response and sneak in implicit SessionID and missing
            # parameters for FullResponse.
            response = self.__recv()
        logger.debug('MESG {} {}{}'.format(
            ptp_container.OperationCode,
            response.ResponseCode,
            ' ' + str(list(map(hex, response.Parameter)))
            if response.Parameter else '',
        ))
        return response

    def event(self, wait=False):
        '''Check event.

        If `wait` this function is blocking. Otherwise it may return None.
        '''
        evt = None
        usbdata = None
        if wait:
            usbdata = self.__event_queue.get(block=True)
        elif not self.__event_queue.empty():
            usbdata = self.__event_queue.get(block=False)

        if usbdata is not None:
            evt = self.__parse_response(usbdata)

        return evt

    def __poll_events(self):
        '''Poll events, adding them to a queue.

        Runs until shutdown is requested or the main thread exits. USB errors
        are logged. errno 19 stops polling (the code treats it as a
        disconnect); other exceptions are logged and polling continues.
        '''
        while not self.__event_shutdown.is_set() and _main_thread_alive():
            try:
                evt = self.__recv(event=True, wait=False, raw=True)
                if evt is not None:
                    logger.debug('Event queued')
                    self.__event_queue.put(evt)
            except usb.core.USBError as e:
                logger.error(
                    '{} polling exception: {}'.format(repr(self.__dev), e)
                )
                # check if disconnected
                if e.errno == 19:
                    break
            except Exception as e:
                logger.error(
                    '{} polling exception: {}'.format(repr(self.__dev), e)
                )
コード例 #47
0
ファイル: mass_storage.py プロジェクト: nccgroup/umap2
class ScsiDevice(USBBaseActor):
    '''
    Implementation of a subset of the SCSI protocol.

    Data placed on ``self.rx`` is consumed by a background daemon thread
    (``handle_data_loop``), which either treats it as a new Command Block
    Wrapper or as payload of a pending WRITE(10). Everything the device sends
    back (response data and status wrappers) is queued on ``self.tx``.
    '''
    name = 'ScsiDevice'

    def __init__(self, app, disk_image):
        '''
        :param app: application object, forwarded to USBBaseActor
        :param disk_image: backing store; this class uses its ``block_size``,
            ``get_sector_count()``, ``get_sector_data(lba)`` and
            ``put_sector_data(lba, data)``
        '''
        super(ScsiDevice, self).__init__(app, None)
        self.disk_image = disk_image
        # Opcode -> handler. The commented-out opcodes have stub handlers
        # further below that raise NotImplementedError.
        self.handlers = {
            ScsiCmds.INQUIRY: self.handle_inquiry,
            ScsiCmds.REQUEST_SENSE: self.handle_request_sense,
            ScsiCmds.TEST_UNIT_READY: self.handle_test_unit_ready,
            ScsiCmds.READ_CAPACITY_10: self.handle_read_capacity_10,
            # ScsiCmds.SEND_DIAGNOSTIC: self.handle_send_diagnostic,
            ScsiCmds.PREVENT_ALLOW_MEDIUM_REMOVAL: self.handle_prevent_allow_medium_removal,
            ScsiCmds.WRITE_10: self.handle_write_10,
            ScsiCmds.READ_10: self.handle_read_10,
            # ScsiCmds.WRITE_6: self.handle_write_6,
            # ScsiCmds.READ_6: self.handle_read_6,
            # ScsiCmds.VERIFY_10: self.handle_verify_10,
            ScsiCmds.MODE_SENSE_6: self.handle_mode_sense_6,
            ScsiCmds.MODE_SENSE_10: self.handle_mode_sense_10,
            ScsiCmds.READ_FORMAT_CAPACITIES: self.handle_read_format_capacities,
            ScsiCmds.SYNCHRONIZE_CACHE: self.handle_synchronize_cache,
            ScsiCmds.READ_CAPACITY_16: self.handle_read_capacity_16,
        }
        # Must exist before handle_reset(), which reads it.
        self.is_write_in_progress = False
        self.handle_reset()
        self.stop_event = Event()
        self.thread = Thread(target=self.handle_data_loop)
        self.thread.daemon = True
        self.thread.start()

    def handle_reset(self):
        '''
        Reset the device state.

        Any partially received write is flushed to the disk image first, then
        the write bookkeeping is cleared and fresh tx/rx queues are created.
        '''
        self.debug('handling reset')
        if self.is_write_in_progress and self.write_data:
            self.disk_image.put_sector_data(self.write_base_lba, self.write_data)
        self.is_write_in_progress = False
        self.write_cbw = None
        self.write_base_lba = 0
        self.write_length = 0
        self.write_data = b''
        self.tx = Queue()
        self.rx = Queue()

    def stop(self):
        '''Ask the data loop thread to exit (does not wait for it).'''
        self.stop_event.set()

    def handle_data_loop(self):
        '''
        Body of the worker thread: process rx data until stop() is called.

        Polls with a short sleep rather than blocking on the queue so that
        the stop event is noticed promptly.
        '''
        while not self.stop_event.is_set():
            # handle_reset() rebinds self.rx. Take one reference per iteration
            # so that empty() and get() operate on the same queue; otherwise
            # get() could block forever on a freshly created, empty queue.
            rx = self.rx
            if not rx.empty():
                data = rx.get()
                self.handle_data(data)
            else:
                time.sleep(0.0001)

    def handle_data(self, data):
        '''
        Process one chunk of incoming data.

        While a WRITE(10) is pending the chunk is write payload; otherwise it
        is parsed as a Command Block Wrapper and dispatched by opcode. A
        handler's non-None return value is queued as response data, followed
        by a PASSED status. A handler exception or an unknown opcode queues a
        FAILED status instead.
        '''
        if self.is_write_in_progress:
            self.handle_write_data(data)
        else:
            cbw = CommandBlockWrapper(data)
            opcode = cbw.opcode
            if opcode in self.handlers:
                try:
                    resp = self.handlers[opcode](cbw)
                    if resp is not None:
                        self.tx.put(resp)
                    self.tx.put(scsi_status(cbw, ScsiCmdStatus.COMMAND_PASSED))
                except Exception as ex:
                    self.warning('exception while processing opcode %#x' % (opcode))
                    self.warning(ex)
                    self.tx.put(scsi_status(cbw, ScsiCmdStatus.COMMAND_FAILED))
            else:
                self.error('No handler for opcode %#x, return CSW with ScsiCmdStatus.COMMAND_FAILED' % (opcode))
                self.tx.put(scsi_status(cbw, ScsiCmdStatus.COMMAND_FAILED))

    def handle_write_data(self, data):
        '''
        Accumulate WRITE(10) payload.

        Once at least ``write_length`` bytes have arrived, the data is stored
        at ``write_base_lba`` and the deferred PASSED status for the original
        WRITE(10) CBW is queued.
        '''
        self.write_data += data
        self.debug('Got %#x bytes of SCSI write data, written so far: %#x' % (len(data), len(self.write_data)))
        if len(self.write_data) >= self.write_length:
            self.info('Got all write data')
            # done writing
            self.disk_image.put_sector_data(self.write_base_lba, self.write_data)
            self.is_write_in_progress = False
            self.write_data = b''
            self.tx.put(scsi_status(self.write_cbw, ScsiCmdStatus.COMMAND_PASSED))

    @mutable('scsi_inquiry_response')
    def handle_inquiry(self, cbw):
        '''Build the standard INQUIRY response.'''
        self.debug('SCSI Inquiry, data: %s' % hexlify(cbw.cb[1:]))
        peripheral = 0x00  # SBC
        RMB = 0x80  # Removable
        version = 0x00
        response_data_format = 0x01
        config = (0x00, 0x00, 0x00)
        vendor_id = b'MBYDCOR '
        product_id = b'UMAP2 DISK IMAG '
        product_revision_level = b'8.02'
        part1 = struct.pack('BBBB', peripheral, RMB, version, response_data_format)
        part2 = struct.pack('BBB', *config) + vendor_id + product_id + product_revision_level
        # Length byte counts only what follows it.
        length = struct.pack('B', len(part2))
        response = part1 + length + part2
        return response

    @mutable('scsi_request_sense_response')
    def handle_request_sense(self, cbw):
        '''Build an 18-byte fixed-format sense response with constant content.'''
        self.debug('SCSI Request Sense, data: %s' % hexlify(cbw.cb[1:]))
        response_code = 0x70
        valid = 0x00
        # Byte 2 carries the flag bits and the sense key (here 0x06).
        filemark = 0x06
        information = 0x00000000
        command_info = 0x00000000
        additional_sense_code = 0x3a
        additional_sense_code_qualifier = 0x00
        field_replaceable_unit_code = 0x00
        sense_key_specific = b'\x00\x00\x00'

        part1 = struct.pack('<BBBI', response_code, valid, filemark, information)
        part2 = struct.pack(
            '<IBBB',
            command_info,
            additional_sense_code,
            additional_sense_code_qualifier,
            field_replaceable_unit_code
        )
        part2 += sense_key_specific
        # "Additional sense length": number of bytes after this field.
        length = struct.pack('B', len(part2))
        response = part1 + length + part2
        return response

    @mutable('scsi_test_unit_ready_response')
    def handle_test_unit_ready(self, cbw):
        '''TEST UNIT READY: always ready; no response data.'''
        self.debug('SCSI Test Unit Ready, logical unit number: %02x' % (cbw.cb[1]))

    @mutable('scsi_read_capacity_10_response')
    def handle_read_capacity_10(self, cbw):
        '''READ CAPACITY(10): big-endian (last LBA, block size).'''
        # .. todo: is the length correct?
        # NOTE(review): the field is the address of the *last* block; this
        # assumes get_sector_count() already returns count - 1 — confirm.
        self.debug('SCSI Read Capacity(10), data: %s' % hexlify(cbw.cb[1:]))
        lastlba = self.disk_image.get_sector_count()
        length = self.disk_image.block_size
        response = struct.pack('>II', lastlba, length)
        return response

    @mutable('scsi_read_capacity_16_response')
    def handle_read_capacity_16(self, cbw):
        '''READ CAPACITY(16): 64-bit last LBA and 32-bit block size.'''
        # .. todo: is the length correct?
        # NOTE(review): the leading 0x9e, 0x10 bytes look like the command's
        # opcode/service action echoed back; the standard response appears to
        # start directly with the 8-byte LBA — confirm before relying on it.
        self.debug('SCSI Read Capacity(16), data: %s' % hexlify(cbw.cb[1:]))
        lastlba = self.disk_image.get_sector_count()
        length = self.disk_image.block_size
        response = struct.pack('>BBQIBB', 0x9e, 0x10, lastlba, length, 0x00, 0x00)
        return response

    @mutable('scsi_send_diagnostic_response')
    def handle_send_diagnostic(self, cbw):
        raise NotImplementedError('yet...')

    @mutable('scsi_prevent_allow_medium_removal_response')
    def handle_prevent_allow_medium_removal(self, cbw):
        '''PREVENT/ALLOW MEDIUM REMOVAL: accepted and ignored.'''
        self.debug('SCSI Prevent/Allow Removal')

    @mutable('scsi_write_10_response')
    def handle_write_10(self, cbw):
        '''
        WRITE(10): record the target range and switch to write-data mode.

        No status is returned here directly; handle_write_data queues the
        PASSED status once all payload bytes have arrived.
        '''
        self.debug('SCSI Write (10), data: %s' % hexlify(cbw.cb[1:]))

        base_lba = struct.unpack('>I', cbw.cb[2:6])[0]
        num_blocks = struct.unpack('>H', cbw.cb[7:9])[0]

        self.debug('SCSI Write (10), lba %#x + %#x block(s)' % (base_lba, num_blocks))

        # save for later
        self.write_cbw = cbw
        self.write_base_lba = base_lba
        self.write_length = num_blocks * self.disk_image.block_size
        self.debug('SCSI Write (10) total expected length: %#x' % (self.write_length))
        self.is_write_in_progress = True

    def handle_read_10(self, cbw):
        '''READ(10): queue the requested sectors directly on tx, one per block.'''
        base_lba, group, num_blocks = struct.unpack('>IBH', cbw.cb[2:9])
        self.debug('SCSI Read (10), lba %#x + %#x block(s)' % (base_lba, num_blocks))
        for block_num in range(num_blocks):
            data = self.disk_image.get_sector_data(base_lba + block_num)
            self.tx.put(data)

    @mutable('scsi_write_6_response')
    def handle_write_6(self, cbw):
        raise NotImplementedError('yet...')

    @mutable('scsi_read_6_response')
    def handle_read_6(self, cbw):
        raise NotImplementedError('yet...')

    @mutable('scsi_verify_10_response')
    def handle_verify_10(self, cbw):
        raise NotImplementedError('yet...')

    def _build_page0_report(self, page, data):
        '''Mode page in page_0 format: page code, 1-byte length, data.'''
        report = struct.pack('BB', page, len(data))
        report += data
        return report

    def _build_subpage_report(self, page, subpage, data):
        '''Mode page in sub_page format: page|0x40 (SPF), subpage, 2-byte length, data.'''
        report = struct.pack('>BBH', page | 0x40, subpage, len(data))
        report += data
        return report

    def _report_header(self, mode_type, mode_data_length):
        '''
        Build the MODE SENSE(6) or MODE SENSE(10) parameter header.

        :param mode_type: 6 or 10 (anything other than 6 builds the 10 form)
        :param mode_data_length: length of the page data following the header
        '''
        # Based on seagate 100293068h.pdf
        medium_type = 0x00
        flags = 0x00
        block_descriptor_len = 0x00
        if mode_type == 6:  # Table 292
            header_data = struct.pack('>3B', medium_type, flags, block_descriptor_len)
            total_len = struct.pack('B', len(header_data) + mode_data_length)
        else:  # Table 293
            longlba = 0x00
            header_data = struct.pack('>BBBBH', medium_type, flags, longlba, 0, block_descriptor_len)
            total_len = struct.pack('>H', len(header_data) + mode_data_length)
        # The length field counts the bytes after itself.
        return total_len + header_data

    def _build_page_report(self, page, subpage, data):
        '''Use page_0 format when subpage is None, sub_page format otherwise.'''
        if subpage is None:
            report = self._build_page0_report(page, data)
        else:
            report = self._build_subpage_report(page, subpage, data)
        return report

    def handle_scsi_mode_sense(self, mode_type, page, subpage, alloc_len, ctrl, with_header=True):
        '''
        Build a MODE SENSE response (shared by the 6- and 10-byte variants).

        :param mode_type: 6 or 10, selects the header format
        :param page: page code (already masked to 6 bits by the callers)
        :param subpage: subpage code
        :param alloc_len: allocation length (currently not enforced)
        :param ctrl: control byte (currently unused)
        :param with_header: prepend the parameter header; recursive calls
            that collect several pages pass False
        :return: response bytes
        '''
        # .. todo: implement response for unsupported pages
        self.debug('SCSI Mode Sense(%d), page %#x subpage %#x' % (mode_type, page, subpage))
        report = None
        # wish there was a switch :(
        if page == 0x1c:
            # case: informational exceptions control (table 314)
            if subpage == 0x00:
                # NOTE(review): passing 0x00 (not None) selects the sub_page
                # format; subpage 0 is normally reported in page_0 format.
                data = struct.pack('>BBII', 0x00, 0x05, 0x00, 0x00)
                report = self._build_page_report(page, 0x00, data)
            # case: background control (table 300)
            elif subpage == 0x01:
                data = struct.pack('>BBHHHHH', 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
                report = self._build_page_report(page, subpage, data)
            elif subpage == 0xff:
                report = self.handle_scsi_mode_sense(mode_type, 0x1c, 0x00, alloc_len, ctrl, False)
                report += self.handle_scsi_mode_sense(mode_type, 0x1c, 0x01, alloc_len, ctrl, False)
        # case: all pages
        elif page == 0x3f:
            # return all pages that we got ...
            report = self.handle_scsi_mode_sense(mode_type, 0x1c, 0xff, alloc_len, ctrl, False)
        if report is None:
            # Default behaviour, taken from the previous implementation; this
            # should probably be changed. It must be bytes: it is concatenated
            # with the packed header below (str + bytes fails on Python 3).
            report = b'\x07\x00\x00\x00\x00\x00\x00\x00'
        if with_header:
            self.debug('SCSI mode sense (%d) - adding header' % (mode_type))
            report = self._report_header(mode_type, len(report)) + report
        return report

    @mutable('scsi_mode_sense_6_response')
    def handle_mode_sense_6(self, cbw):
        '''MODE SENSE(6): parse CDB bytes 2-5 and delegate.'''
        # .. todo: DBD, PC
        page, subpage, alloc_len, control = struct.unpack('>4B', cbw.cb[2:6])
        page &= 0x3f  # drop the PC bits
        return self.handle_scsi_mode_sense(6, page, subpage, alloc_len, control)

    @mutable('scsi_mode_sense_10_response')
    def handle_mode_sense_10(self, cbw):
        '''MODE SENSE(10): parse CDB bytes 2-9 and delegate.'''
        # .. todo: LLBA, DBD, PC
        page, subpage, _, _, _, alloc_len, control = struct.unpack('>5BHB', cbw.cb[2:10])
        page &= 0x3f  # drop the PC bits
        return self.handle_scsi_mode_sense(10, page, subpage, alloc_len, control)

    @mutable('scsi_read_format_capacities')
    def handle_read_format_capacities(self, cbw):
        '''READ FORMAT CAPACITIES: 4-byte header plus one 8-byte descriptor.'''
        self.debug('SCSI Read Format Capacity')
        # header
        response = struct.pack('>I', 8)
        num_sectors = 0x1000
        reserved = 0x1000
        sector_size = self.disk_image.block_size
        response += struct.pack('>IHH', num_sectors, reserved, sector_size)
        return response

    @mutable('scsi_synchronize_cache_response')
    def handle_synchronize_cache(self, cbw):
        '''SYNCHRONIZE CACHE(10): nothing to flush; no response data.'''
        self.debug('Synchronize Cache (10)')
コード例 #48
0
ファイル: extraction.py プロジェクト: bolicc/tsfresh
def _extract_features_parallel_per_sample(kind_to_df_map,
                                          column_id, column_value,
                                          default_fc_parameters,
                                          kind_to_fc_parameters=None,
                                          chunksize=defaults.CHUNKSIZE,
                                          n_processes=defaults.N_PROCESSES, show_warnings=defaults.SHOW_WARNINGS,
                                          disable_progressbar=defaults.DISABLE_PROGRESSBAR,
                                          impute_function=defaults.IMPUTE_FUNCTION):
    """
    Parallelize the feature extraction per kind and per sample.

    As the splitting of the dataframes per kind along column_id is quite costly, we settled for an async map in this
    function. One lazy map result is submitted per kind; the results are consumed in order of submission and the
    per-kind feature frames are joined column-wise.

    :param kind_to_df_map: The time series to compute the features for in our internal format
    :type kind_to_df_map: dict of pandas.DataFrame

    :param column_id: The name of the id column to group by.
    :type column_id: str

    :param column_value: The name for the column keeping the value itself.
    :type column_value: str

    :param default_fc_parameters: mapping from feature calculator names to parameters. Only those names
           which are keys in this dict will be calculated. See the class:`ComprehensiveFCParameters` for
           more information.
    :type default_fc_parameters: dict

    :param kind_to_fc_parameters: mapping from kind names to objects of the same type as the ones for
            default_fc_parameters. If you put a kind as a key here, the fc_parameters
            object (which is the value), will be used instead of the default_fc_parameters.
    :type kind_to_fc_parameters: dict

    :param chunksize: The size of one chunk for the parallelisation. If falsy, a chunk size is computed
            separately for every kind with ``_calculate_best_chunksize``.
    :type chunksize: None or int

    :param n_processes: The number of processes to use for parallelisation.
    :type n_processes: int

    :param show_warnings: Show warnings during the feature extraction (needed for debugging of calculators).
    :type show_warnings: bool

    :param disable_progressbar: Do not show a progressbar while doing the calculation.
    :type disable_progressbar: bool

    :param impute_function: None, if no imputing should happen or the function to call for imputing.
    :type impute_function: None or function

    :return: The (maybe imputed) DataFrame containing extracted features.
    :rtype: pandas.DataFrame
    """
    partial_extract_features_for_one_time_series = partial(_extract_features_for_one_time_series,
                                                           column_id=column_id,
                                                           column_value=column_value,
                                                           default_fc_parameters=default_fc_parameters,
                                                           kind_to_fc_parameters=kind_to_fc_parameters,
                                                           show_warnings=show_warnings)
    pool = Pool(n_processes)
    try:
        total_number_of_expected_results = 0

        # Submit map jobs per kind per sample; keep the lazy results in submission order.
        map_results = []

        for kind, df_kind in kind_to_df_map.items():
            df_grouped_by_id = df_kind.groupby(column_id)

            total_number_of_expected_results += len(df_grouped_by_id)

            # Resolve the chunk size per kind. Overwriting `chunksize` itself (as before) froze the value
            # computed for the first kind and reused it for every later kind, whose group count may differ.
            kind_chunksize = chunksize or _calculate_best_chunksize(df_grouped_by_id, n_processes)

            map_results.append(
                pool.imap_unordered(
                    partial_extract_features_for_one_time_series,
                    [(kind, df_group) for _, df_group in df_grouped_by_id],
                    chunksize=kind_chunksize
                )
            )

        # No further tasks will be submitted.
        pool.close()

        result = pd.DataFrame()

        # Wait for the jobs to complete and concatenate the partial results, with a progress bar.
        with tqdm(total=total_number_of_expected_results, desc="Feature Extraction",
                  disable=disable_progressbar) as progress_bar:
            # Wrap each map result in a generator that ticks the progress bar every time pd.concat
            # pulls the next per-sample frame from it.
            def iterable_with_tqdm_update(results):
                for element in results:
                    progress_bar.update(1)
                    yield element

            for map_result in map_results:
                dfs_kind = iterable_with_tqdm_update(map_result)
                df_tmp = pd.concat(dfs_kind, axis=0).astype(np.float64)

                # Impute the result if requested
                if impute_function is not None:
                    impute_function(df_tmp)

                result = pd.concat([result, df_tmp], axis=1).astype(np.float64)
    except BaseException:
        # Do not leave worker processes behind when extraction fails or is interrupted.
        pool.terminate()
        raise
    finally:
        pool.join()

    return result
コード例 #49
0
ファイル: process_helpers.py プロジェクト: caperna/mhctools
def run_multiple_commands_redirect_stdout(
        multiple_args_dict,
        print_commands=True,
        process_limit=0,
        polling_freq=1,
        **kwargs):
    """
    Run multiple shell commands in parallel, write each of their
    stdout output to files associated with each command.

    Parameters
    ----------
    multiple_args_dict : dict
        A dictionary whose keys are files and values are args list.
        Run each args list as a subprocess and write stdout to the
        corresponding file.

    print_commands : bool
        Print shell commands before running them.

    process_limit : int
        Limit the number of concurrent processes to this number. 0
        (or any value <= 0) if there is no limit.

    polling_freq : int
        Number of seconds between checking for done processes, if
        we have a process limit.

    **kwargs
        Extra keyword arguments forwarded to every ``AsyncProcess``.
    """
    assert len(multiple_args_dict) > 0
    assert all(len(args) > 0 for args in multiple_args_dict.values())
    assert all(hasattr(f, 'name') for f in multiple_args_dict.keys())
    start_time = time.time()
    # Processes handed out but not yet reaped, oldest first. Everything runs
    # on this thread, so a plain list suffices; the previous queue.Queue
    # needed its private ``.queue`` deque mutated directly to drop finished
    # entries, bypassing the queue's own bookkeeping.
    processes = []

    def add_process(process):
        if print_commands:
            print(" ".join(process.args), ">",
                  process.redirect_stdout_file.name)
        processes.append(process)

    for f, args in multiple_args_dict.items():
        p = AsyncProcess(
            args,
            redirect_stdout_file=f,
            **kwargs)
        # A limit <= 0 means unlimited (same as Queue(maxsize<=0)).
        while process_limit > 0 and len(processes) >= process_limit:
            # Are there any done processes? Reap each one as soon as it is
            # found, in order, exactly like the original scan.
            finished = []
            for possibly_done in processes:
                if possibly_done.poll() is not None:
                    possibly_done.wait()
                    finished.append(possibly_done)
            if finished:
                for process_to_remove in finished:
                    processes.remove(process_to_remove)
                break
            # Check again after polling_freq seconds if none were done.
            time.sleep(polling_freq)
        add_process(p)

    # Wait for all the rest of the processes, oldest first.
    for process in processes:
        process.wait()

    elapsed_time = time.time() - start_time
    logging.info(
        "Ran %d commands in %0.4f seconds",
        len(multiple_args_dict),
        elapsed_time)
コード例 #50
0
ファイル: smartcard.py プロジェクト: hongyunnchen/umap2
class USBSmartcardInterface(USBInterface):
    '''
    USB interface of an emulated CCID smartcard reader.

    Endpoint 1 (bulk OUT) receives PC_to_RDR commands, which are dispatched
    through ``self.operations``; responses go out on endpoint 2 (bulk IN).
    Endpoint 3 (interrupt IN) drains ``self.int_q``.
    '''
    name = 'SmartcardInterface'

    def __init__(self, app, phy):
        descriptors = {
            DescriptorType.hid: self.get_icc_descriptor
        }
        # Clock frequencies advertised by the reader; the descriptor reports
        # their count as bNumClockSupported.
        self.clock_frequencies = [
            0x00003267, 0x000064ce, 0x0000c99d, 0x0001933a, 0x00032674, 0x00064ce7,
            0x000c99ce, 0x00025cd7, 0x0003f011, 0x00004334, 0x00008669, 0x00010cd1,
            0x000219a2, 0x00043345, 0x0008668a, 0x0002a00b, 0x00003073, 0x000060e6,
            0x0000c1cc, 0x00018399, 0x00030732, 0x00060e63, 0x000122b3, 0x0001e47f,
            0x00015006, 0x00009736, 0x0000fc04, 0x00002853, 0x000050a5, 0x0000a14a,
            0x00014295, 0x00028529, 0x000078f8, 0x0000493e, 0x0000927c, 0x000124f8,
            0x000249f0, 0x000493e0, 0x000927c0, 0x0001b774, 0x0002dc6c, 0x000030d4,
            0x000061a8, 0x0000c350, 0x000186a0, 0x00030d40, 0x00061a80, 0x0001e848,
            0x0000dbba, 0x00016e36, 0x0000f424, 0x00006ddd, 0x0000b71b
        ]

        self.data_rates = []

        self.clock_freq = self.clock_frequencies[0]
        self.data_rate = 0 if not self.data_rates else self.data_rates[0]

        endpoints = [
            # CCID command pipe
            USBEndpoint(
                app=app,
                phy=phy,
                number=1,
                direction=USBEndpoint.direction_out,
                transfer_type=USBEndpoint.transfer_type_bulk,
                sync_type=USBEndpoint.sync_type_none,
                usage_type=USBEndpoint.usage_type_data,
                max_packet_size=0x40,
                interval=0,
                handler=self.handle_data_available
            ),
            # CCID response pipe
            USBEndpoint(
                app=app,
                phy=phy,
                number=2,
                direction=USBEndpoint.direction_in,
                transfer_type=USBEndpoint.transfer_type_bulk,
                sync_type=USBEndpoint.sync_type_none,
                usage_type=USBEndpoint.usage_type_data,
                max_packet_size=0x40,
                interval=0,
                handler=None
            ),
            # CCID event notification pipe
            USBEndpoint(
                app=app,
                phy=phy,
                number=3,
                direction=USBEndpoint.direction_in,
                transfer_type=USBEndpoint.transfer_type_interrupt,
                sync_type=USBEndpoint.sync_type_none,
                usage_type=USBEndpoint.usage_type_data,
                max_packet_size=8,
                interval=0,
                handler=self.handle_buffer_available
            ),
        ]

        # TODO: un-hardcode interface_string_index
        super(USBSmartcardInterface, self).__init__(
            app=app,
            phy=phy,
            interface_number=0,
            interface_alternate=0,
            interface_class=USBClass.SmartCard,
            interface_subclass=0,
            interface_protocol=0,
            interface_string_index=0,
            endpoints=endpoints,
            descriptors=descriptors,
            device_class=USBSmartcardClass(app, phy)
        )

        # Current protocol and its parameter block, as reported by
        # Get/Reset/SetParameters.
        self.proto = 0
        self.abProtocolDataStructure = b'\x11\x00\x00\x0a\x00'
        self.clock_status = 0x00
        # Messages for the interrupt endpoint; one is queued up front
        # (presumably a slot-change notification — confirm).
        self.int_q = Queue()
        self.int_q.put(b'\x50\x03')

        self.operations = {
            PcToRdrOpcode.IccPowerOn: self.handle_PcToRdr_IccPowerOn,
            PcToRdrOpcode.IccPowerOff: self.handle_PcToRdr_IccPowerOff,
            PcToRdrOpcode.GetSlotStatus: self.handle_PcToRdr_GetSlotStatus,
            PcToRdrOpcode.XfrBlock: self.handle_PcToRdr_XfrBlock,
            PcToRdrOpcode.GetParameters: self.handle_PcToRdr_GetParameters,
            PcToRdrOpcode.ResetParameters: self.handle_PcToRdr_ResetParameters,
            PcToRdrOpcode.SetParameters: self.handle_PcToRdr_SetParameters,
            PcToRdrOpcode.Escape: self.handle_PcToRdr_Escape,
            PcToRdrOpcode.IccClock: self.handle_PcToRdr_IccClock,
            PcToRdrOpcode.T0APDU: self.handle_PcToRdr_T0APDU,
            PcToRdrOpcode.Secure: self.handle_PcToRdr_Secure,
            PcToRdrOpcode.Mechanical: self.handle_PcToRdr_Mechanical,
            PcToRdrOpcode.Abort: self.handle_PcToRdr_Abort,
            PcToRdrOpcode.SetDataRateAndClock_Frequency: self.handle_PcToRdr_SetDataRateAndClock_Frequency,
        }

    @mutable('smartcard_IccPowerOn_response')
    def handle_PcToRdr_IccPowerOn(self, slot, seq, data):
        '''Power on the card: reply with a data block carrying a fixed ATR.'''
        # Entropia Universe Gold card
        # Taken from http://ludovic.rousseau.free.fr/softwares/pcsc-tools/smartcard_list.txt
        # (An alternative ATR used previously:
        #  b'\x3b\x6e\x00\x00\x80\x31\x80\x66\xb0\x84\x12\x01\x6e\x01\x83\x00\x90\x00')
        abData = b'\x3B\x6B\x00\x00\x00\x31\xC0\x64\xA9\xEC\x01\x00\x82\x90\x00'
        return R2P_DataBlock(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            chain_param=0x00,
            data=abData
        )

    @mutable('smartcard_IccPowerOff_response')
    def handle_PcToRdr_IccPowerOff(self, slot, seq, data):
        '''
        Power off the card: reply with the slot status.

        .. todo:: validate the slot number received in the bulk OUT message
        '''
        return R2P_SlotStatus(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            clock_status=self.clock_status
        )

    @mutable('smartcard_GetSlotStatus_response')
    def handle_PcToRdr_GetSlotStatus(self, slot, seq, data):
        '''Reply with the slot status and current clock status.'''
        return R2P_SlotStatus(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            clock_status=self.clock_status
        )

    @mutable('smartcard_XfrBlock_response')
    def handle_PcToRdr_XfrBlock(self, slot, seq, data):
        '''
        Reply to every transferred block with the fixed status word 6A 82.

        .. todo:: check the response again later,
        '''
        abData = b'\x6a\x82'
        return R2P_DataBlock(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            chain_param=0x00,
            data=abData
        )

    @mutable('smartcard_GetParameters_response')
    def handle_PcToRdr_GetParameters(self, slot, seq, data):
        '''Report the current protocol and its parameter block.'''
        return R2P_Parameters(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            proto=self.proto,
            data=self.abProtocolDataStructure
        )

    @mutable('smartcard_ResetParameters_response')
    def handle_PcToRdr_ResetParameters(self, slot, seq, data):
        '''Report the current protocol parameters (nothing is actually reset).'''
        return R2P_Parameters(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            proto=self.proto,
            data=self.abProtocolDataStructure
        )

    @mutable('smartcard_SetParameters_response')
    def handle_PcToRdr_SetParameters(self, slot, seq, data):
        '''
        Store the protocol (byte 7) and its parameter block, then echo them.

        Protocol 0 takes 5 parameter bytes, protocol 1 takes 7; both start at
        offset 10. Other protocol values keep the previous parameter block.
        '''
        self.proto = data[7]
        if self.proto == 0:
            self.abProtocolDataStructure = data[10:15]
        elif self.proto == 1:
            self.abProtocolDataStructure = data[10:17]
        return R2P_Parameters(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            proto=self.proto,
            data=self.abProtocolDataStructure
        )

    @mutable('smartcard_Escape_response')
    def handle_PcToRdr_Escape(self, slot, seq, data):
        '''
        Reply with an empty escape response.

        .. todo:: should check the data parameter
        '''
        return R2P_Escape(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            data=b''
        )

    @mutable('smartcard_IccClock_response')
    def handle_PcToRdr_IccClock(self, slot, seq, data):
        '''Accept the clock command (not acted upon) and report slot status.'''
        # bClockCommand = data[7]
        return R2P_SlotStatus(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            clock_status=self.clock_status
        )

    @mutable('smartcard_T0APDU_response')
    def handle_PcToRdr_T0APDU(self, slot, seq, data):
        '''Accept T=0 APDU settings (not acted upon) and report slot status.'''
        # bmChange, bClassGetResponse, bClassEnvelope = struct.unpack('<BBB', data[7:10])
        return R2P_SlotStatus(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            clock_status=self.clock_status
        )

    @mutable('smartcard_Secure_response')
    def handle_PcToRdr_Secure(self, slot, seq, data):
        '''
        Parse the Secure command header and reply with an empty data block.

        .. todo:: to complete that, go over section 6.1.11 (PIN
                  verification/modification data is not handled yet)
        '''
        # bBWI and wLevelParameter directly follow the 7-byte header that
        # handle_data_available unpacks. The previous code called
        # struct.unpack without a buffer, which always raised.
        bBWI, wLevelParameter = struct.unpack('<BH', data[7:10])
        self.debug('Secure: bBWI=%#x, wLevelParameter=%#x' % (bBWI, wLevelParameter))
        return R2P_DataBlock(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            chain_param=0x00,
            data=b''
        )

    @mutable('smartcard_Mechanical_response')
    def handle_PcToRdr_Mechanical(self, slot, seq, data):
        '''
        Report slot status; the mechanical command itself is ignored.

        .. todo:: handling
        '''
        return R2P_SlotStatus(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            clock_status=self.clock_status
        )

    @mutable('smartcard_Abort_response')
    def handle_PcToRdr_Abort(self, slot, seq, data):
        '''
        Report slot status; nothing is actually aborted.

        .. todo:: handling
        '''
        return R2P_SlotStatus(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            clock_status=self.clock_status
        )

    @mutable('smartcard_SetDataRateAndClock_Frequency_response')
    def handle_PcToRdr_SetDataRateAndClock_Frequency(self, slot, seq, data):
        '''Store the requested clock frequency and data rate (bytes 10-17) and echo them.'''
        self.clock_freq, self.data_rate = struct.unpack('<II', data[10:18])
        return R2P_DataRateAndClockFrequency(
            slot=slot,
            seq=seq,
            status=0x00,
            error=0x80,
            freq=self.clock_freq,
            rate=self.data_rate
        )

    @mutable('smartcard_scd_icc_descriptor')
    def get_icc_descriptor(self, *args):
        '''Build the smartcard class descriptor, prefixed with its length byte.'''
        bDescriptorType = 0x21
        bcdCCID = 0x0110
        bMaxSlotIndex = 0x00
        bVoltageSupport = 0x07
        dwProtocols = 0x00000003
        dwDefaultClock = 0x00000ea6
        dwMaximumClock = 0x00001d4c
        bNumClockSupported = len(self.clock_frequencies)
        dwDataRate = 0x00002760
        dwMaxDataRate = 0x0004c4b4
        bNumDataRatesSupported = len(self.data_rates)
        dwMaxIFSD = 0x000000fe
        dwSynchProtocols = 0x00000000
        dwMechanical = 0x00000000
        dwFeatures = 0x00010030
        dwMaxCCIDMessageLength = 0x0000010f
        bClassGetResponse = 0x00
        bClassEnvelope = 0x00
        wLcdLayout = 0x0000
        bPinSupport = 0x00
        bMaxCCIDBusySlots = 0x01

        response = struct.pack(
            '<BHBBIIIBIIBIIIIIBBHBB',
            bDescriptorType,
            bcdCCID,
            bMaxSlotIndex,
            bVoltageSupport,
            dwProtocols,
            dwDefaultClock,
            dwMaximumClock,
            bNumClockSupported,
            dwDataRate,
            dwMaxDataRate,
            bNumDataRatesSupported,
            dwMaxIFSD,
            dwSynchProtocols,
            dwMechanical,
            dwFeatures,
            dwMaxCCIDMessageLength,
            bClassGetResponse,
            bClassEnvelope,
            wLcdLayout,
            bPinSupport,
            bMaxCCIDBusySlots
        )

        # bLength includes the length byte itself.
        response = struct.pack('B', len(response) + 1) + response
        return response

    def handle_data_available(self, data):
        '''
        Handle a command received on endpoint 1.

        Unpacks the 7-byte header (opcode, dwLength, slot, seq), records it in
        session_data, dispatches to the matching handler and sends any
        non-empty response on endpoint 2.
        '''
        self.usb_function_supported()
        opcode, length, slot, seq = struct.unpack('<BIBB', data[:7])
        if opcode in self.operations:
            handler = self.operations[opcode]
            self.session_data['bSlot'] = data[5]
            self.session_data['bSeq'] = data[6]
            self.session_data['data'] = data
            response = handler(slot, seq, data)
        else:
            self.error('Received Smartcard command not understood')
            response = b''
        if response:
            self.send_on_endpoint(2, response)

    def handle_buffer_available(self):
        '''Send the next queued interrupt message on endpoint 3, or an empty packet.'''
        if not self.int_q.empty():
            buff = self.int_q.get()
            self.debug('Sending data to host: %s' % (hexlify(buff)))
            self.send_on_endpoint(3, buff)
        else:
            self.send_on_endpoint(3, b'')
コード例 #51
0
class Nikon(object):
    '''This class implements Nikon's PTP operations.

    Every hook defers to ``super(Nikon, self)``, so this class must be
    combined with a base that implements them (presumably the generic PTP
    class -- confirm). Inside a ``session()`` a background thread polls
    Nikon's CheckEvents operation and queues the results for ``event()``.
    '''

    def __init__(self, *args, **kwargs):
        logger.debug('Init Nikon')
        super(Nikon, self).__init__(*args, **kwargs)
        # TODO: expose the choice to poll or not Nikon events
        self.__no_polling = False
        self.__nikon_event_shutdown = Event()
        self.__nikon_event_proc = None
        # Created up front so `event()` works even when no polling session
        # was opened (previously this raised AttributeError). `session()`
        # replaces it with a fresh queue for each polling session.
        self.__event_queue = Queue()

    @contextmanager
    def session(self):
        '''
        Manage Nikon session with context manager.

        Opens the underlying session and, unless polling is disabled, runs a
        polling thread for its duration. The thread is stopped when the
        context exits.
        '''
        # When raw device, do not perform
        if self.__no_polling:
            with super(Nikon, self).session():
                yield
            return
        # Within a normal PTP session
        with super(Nikon, self).session():
            # launch a polling thread
            self.__event_queue = Queue()
            # Re-arm the stop flag: any earlier shutdown (a previous session,
            # `_shutdown`, a repeated call) leaves it set, which would make
            # the new polling thread exit immediately.
            self.__nikon_event_shutdown.clear()
            self.__nikon_event_proc = Thread(
                name='NikonEvtPolling',
                target=self.__nikon_poll_events
            )
            self.__nikon_event_proc.daemon = False
            atexit.register(self._nikon_shutdown)
            self.__nikon_event_proc.start()

            try:
                yield
            finally:
                self._nikon_shutdown()
                # The session is over, so the exit hook is no longer needed.
                # Drop it so repeated sessions do not accumulate hooks (each
                # holding a reference to self). atexit.unregister is
                # Python 3 only, hence the guard.
                unregister = getattr(atexit, 'unregister', None)
                if unregister is not None:
                    unregister(self._nikon_shutdown)

    def _shutdown(self):
        self._nikon_shutdown()
        super(Nikon, self)._shutdown()

    def _nikon_shutdown(self):
        '''Ask the polling thread to stop and wait up to 2 s for it.'''
        logger.debug('Shutdown Nikon events')
        self.__nikon_event_shutdown.set()

        # Only join a running thread.
        if self.__nikon_event_proc and self.__nikon_event_proc.is_alive():
            self.__nikon_event_proc.join(2)

    def _PropertyCode(self, **product_properties):
        props = {
            'ShootingBank': 0xD010,
            'ShootingBankNameA': 0xD011,
            'ShootingBankNameB': 0xD012,
            'ShootingBankNameC': 0xD013,
            'ShootingBankNameD': 0xD014,
            'ResetBank0': 0xD015,
            'RawCompression': 0xD016,
            'WhiteBalanceAutoBias': 0xD017,
            'WhiteBalanceTungstenBias': 0xD018,
            'WhiteBalanceFluorescentBias': 0xD019,
            'WhiteBalanceDaylightBias': 0xD01A,
            'WhiteBalanceFlashBias': 0xD01B,
            'WhiteBalanceCloudyBias': 0xD01C,
            'WhiteBalanceShadeBias': 0xD01D,
            'WhiteBalanceColorTemperature': 0xD01E,
            'WhiteBalancePresetNo': 0xD01F,
            'WhiteBalancePresetName0': 0xD020,
            'WhiteBalancePresetName1': 0xD021,
            'WhiteBalancePresetName2': 0xD022,
            'WhiteBalancePresetName3': 0xD023,
            'WhiteBalancePresetName4': 0xD024,
            'WhiteBalancePresetVal0': 0xD025,
            'WhiteBalancePresetVal1': 0xD026,
            'WhiteBalancePresetVal2': 0xD027,
            'WhiteBalancePresetVal3': 0xD028,
            'WhiteBalancePresetVal4': 0xD029,
            'ImageSharpening': 0xD02A,
            'ToneCompensation': 0xD02B,
            'ColorModel': 0xD02C,
            'HueAdjustment': 0xD02D,
            'NonCPULensDataFocalLength': 0xD02E,
            'NonCPULensDataMaximumAperture': 0xD02F,
            'ShootingMode': 0xD030,
            'JPEGCompressionPolicy': 0xD031,
            'ColorSpace': 0xD032,
            'AutoDXCrop': 0xD033,
            'FlickerReduction': 0xD034,
            'RemoteMode': 0xD035,
            'VideoMode': 0xD036,
            'NikonEffectMode': 0xD037,
            'Mode': 0xD038,
            'CSMMenuBankSelect': 0xD040,
            'MenuBankNameA': 0xD041,
            'MenuBankNameB': 0xD042,
            'MenuBankNameC': 0xD043,
            'MenuBankNameD': 0xD044,
            'ResetBank': 0xD045,
            'A1AFCModePriority': 0xD048,
            'A2AFSModePriority': 0xD049,
            'A3GroupDynamicAF': 0xD04A,
            'A4AFActivation': 0xD04B,
            'FocusAreaIllumManualFocus': 0xD04C,
            'FocusAreaIllumContinuous': 0xD04D,
            'FocusAreaIllumWhenSelected': 0xD04E,
            'FocusAreaWrap': 0xD04F,
            'VerticalAFON': 0xD050,
            'AFLockOn': 0xD051,
            'FocusAreaZone': 0xD052,
            'EnableCopyright': 0xD053,
            'ISOAuto': 0xD054,
            'EVISOStep': 0xD055,
            'EVStep': 0xD056,
            'EVStepExposureComp': 0xD057,
            'ExposureCompensation': 0xD058,
            'CenterWeightArea': 0xD059,
            'ExposureBaseMatrix': 0xD05A,
            'ExposureBaseCenter': 0xD05B,
            'ExposureBaseSpot': 0xD05C,
            'LiveViewAFArea': 0xD05D,
            'AELockMode': 0xD05E,
            'AELAFLMode': 0xD05F,
            'LiveViewAFFocus': 0xD061,
            'MeterOff': 0xD062,
            'SelfTimer': 0xD063,
            'MonitorOff': 0xD064,
            'ImgConfTime': 0xD065,
            'AutoOffTimers': 0xD066,
            'AngleLevel': 0xD067,
            'D1ShootingSpeed': 0xD068,
            'D2MaximumShots': 0xD069,
            'ExposureDelayMode': 0xD06A,
            'LongExposureNoiseReduction': 0xD06B,
            'FileNumberSequence': 0xD06C,
            'ControlPanelFinderRearControl': 0xD06D,
            'ControlPanelFinderViewfinder': 0xD06E,
            'D7Illumination': 0xD06F,
            'NrHighISO': 0xD070,
            'SHSetCHGUIDDisp': 0xD071,
            'ArtistName': 0xD072,
            'NikonCopyrightInfo': 0xD073,
            'FlashSyncSpeed': 0xD074,
            'FlashShutterSpeed': 0xD075,
            'E3AAFlashMode': 0xD076,
            'E4ModelingFlash': 0xD077,
            'BracketSet': 0xD078,
            'E6ManualModeBracketing': 0xD079,
            'BracketOrder': 0xD07A,
            'E8AutoBracketSelection': 0xD07B,
            'BracketingSet': 0xD07C,
            'F1CenterButtonShootingMode': 0xD080,
            'CenterButtonPlaybackMode': 0xD081,
            'F2Multiselector': 0xD082,
            'F3PhotoInfoPlayback': 0xD083,
            'F4AssignFuncButton': 0xD084,
            'F5CustomizeCommDials': 0xD085,
            'ReverseCommandDial': 0xD086,
            'ApertureSetting': 0xD087,
            'MenusAndPlayback': 0xD088,
            'F6ButtonsAndDials': 0xD089,
            'NoCFCard': 0xD08A,
            'CenterButtonZoomRatio': 0xD08B,
            'FunctionButton2': 0xD08C,
            'AFAreaPoint': 0xD08D,
            'NormalAFOn': 0xD08E,
            'CleanImageSensor': 0xD08F,
            'ImageCommentString': 0xD090,
            'ImageCommentEnable': 0xD091,
            'ImageRotation': 0xD092,
            'ManualSetLensNo': 0xD093,
            'MovScreenSize': 0xD0A0,
            'MovVoice': 0xD0A1,
            'MovMicrophone': 0xD0A2,
            'MovFileSlot': 0xD0A3,
            'MovRecProhibitCondition': 0xD0A4,
            'ManualMovieSetting': 0xD0A6,
            'MovQuality': 0xD0A7,
            'LiveViewScreenDisplaySetting': 0xD0B2,
            'MonitorOffDelay': 0xD0B3,
            'Bracketing': 0xD0C0,
            'AutoExposureBracketStep': 0xD0C1,
            'AutoExposureBracketProgram': 0xD0C2,
            'AutoExposureBracketCount': 0xD0C3,
            'WhiteBalanceBracketStep': 0xD0C4,
            'WhiteBalanceBracketProgram': 0xD0C5,
            'LensID': 0xD0E0,
            'LensSort': 0xD0E1,
            'LensType': 0xD0E2,
            'FocalLengthMin': 0xD0E3,
            'FocalLengthMax': 0xD0E4,
            'MaxApAtMinFocalLength': 0xD0E5,
            'MaxApAtMaxFocalLength': 0xD0E6,
            'FinderISODisp': 0xD0F0,
            'AutoOffPhoto': 0xD0F2,
            'AutoOffMenu': 0xD0F3,
            'AutoOffInfo': 0xD0F4,
            'SelfTimerShootNum': 0xD0F5,
            'VignetteCtrl': 0xD0F7,
            'AutoDistortionControl': 0xD0F8,
            'SceneMode': 0xD0F9,
            'SceneMode2': 0xD0FD,
            'SelfTimerInterval': 0xD0FE,
            'NikonExposureTime': 0xD100,
            'ACPower': 0xD101,
            'WarningStatus': 0xD102,
            'MaximumShots': 0xD103,
            'AFLockStatus': 0xD104,
            'AELockStatus': 0xD105,
            'FVLockStatus': 0xD106,
            'AutofocusLCDTopMode2': 0xD107,
            'AutofocusArea': 0xD108,
            'FlexibleProgram': 0xD109,
            'LightMeter': 0xD10A,
            'RecordingMedia': 0xD10B,
            'USBSpeed': 0xD10C,
            'CCDNumber': 0xD10D,
            'CameraOrientation': 0xD10E,
            'GroupPtnType': 0xD10F,
            'FNumberLock': 0xD110,
            'ExposureApertureLock': 0xD111,
            'TVLockSetting': 0xD112,
            'AVLockSetting': 0xD113,
            'IllumSetting': 0xD114,
            'FocusPointBright': 0xD115,
            'ExternalFlashAttached': 0xD120,
            'ExternalFlashStatus': 0xD121,
            'ExternalFlashSort': 0xD122,
            'ExternalFlashMode': 0xD123,
            'ExternalFlashCompensation': 0xD124,
            'NewExternalFlashMode': 0xD125,
            'FlashExposureCompensation': 0xD126,
            'HDRMode': 0xD130,
            'HDRHighDynamic': 0xD131,
            'HDRSmoothing': 0xD132,
            'OptimizeImage': 0xD140,
            'Saturation': 0xD142,
            'BWFillerEffect': 0xD143,
            'BWSharpness': 0xD144,
            'BWContrast': 0xD145,
            'BWSettingType': 0xD146,
            'Slot2SaveMode': 0xD148,
            'RawBitMode': 0xD149,
            'ActiveDLighting': 0xD14E,
            'FlourescentType': 0xD14F,
            'TuneColourTemperature': 0xD150,
            'TunePreset0': 0xD151,
            'TunePreset1': 0xD152,
            'TunePreset2': 0xD153,
            'TunePreset3': 0xD154,
            'TunePreset4': 0xD155,
            'BeepOff': 0xD160,
            'AutofocusMode': 0xD161,
            'AFAssist': 0xD163,
            'PADVPMode': 0xD164,
            'ImageReview': 0xD165,
            'AFAreaIllumination': 0xD166,
            'NikonFlashMode': 0xD167,
            'FlashCommanderMode': 0xD168,
            'FlashSign': 0xD169,
            '_ISOAuto': 0xD16A,
            'RemoteTimeout': 0xD16B,
            'GridDisplay': 0xD16C,
            'FlashModeManualPower': 0xD16D,
            'FlashModeCommanderPower': 0xD16E,
            'AutoFP': 0xD16F,
            'DateImprintSetting': 0xD170,
            'DateCounterSelect': 0xD171,
            'DateCountData': 0xD172,
            'DateCountDisplaySetting': 0xD173,
            'RangeFinderSetting': 0xD174,
            'CSMMenu': 0xD180,
            'WarningDisplay': 0xD181,
            'BatteryCellKind': 0xD182,
            'ISOAutoHiLimit': 0xD183,
            'DynamicAFArea': 0xD184,
            'ContinuousSpeedHigh': 0xD186,
            'InfoDispSetting': 0xD187,
            'PreviewButton': 0xD189,
            'PreviewButton2': 0xD18A,
            'AEAFLockButton2': 0xD18B,
            'IndicatorDisp': 0xD18D,
            'CellKindPriority': 0xD18E,
            'BracketingFramesAndSteps': 0xD190,
            'LiveViewMode': 0xD1A0,
            'LiveViewDriveMode': 0xD1A1,
            'LiveViewStatus': 0xD1A2,
            'LiveViewImageZoomRatio': 0xD1A3,
            'LiveViewProhibitCondition': 0xD1A4,
            'MovieShutterSpeed': 0xD1A8,
            'MovieFNumber': 0xD1A9,
            'MovieISO': 0xD1AA,
            'LiveViewMovieMode': 0xD1AC,
            'ExposureDisplayStatus': 0xD1B0,
            'ExposureIndicateStatus': 0xD1B1,
            'InfoDispErrStatus': 0xD1B2,
            'ExposureIndicateLightup': 0xD1B3,
            'FlashOpen': 0xD1C0,
            'FlashCharged': 0xD1C1,
            'FlashMRepeatValue': 0xD1D0,
            'FlashMRepeatCount': 0xD1D1,
            'FlashMRepeatInterval': 0xD1D2,
            'FlashCommandChannel': 0xD1D3,
            'FlashCommandSelfMode': 0xD1D4,
            'FlashCommandSelfCompensation': 0xD1D5,
            'FlashCommandSelfValue': 0xD1D6,
            'FlashCommandAMode': 0xD1D7,
            'FlashCommandACompensation': 0xD1D8,
            'FlashCommandAValue': 0xD1D9,
            'FlashCommandBMode': 0xD1DA,
            'FlashCommandBCompensation': 0xD1DB,
            'FlashCommandBValue': 0xD1DC,
            'ApplicationMode': 0xD1F0,
            'ActiveSlot': 0xD1F2,
            'ActivePicCtrlItem': 0xD200,
            'ChangePicCtrlItem': 0xD201,
            'MovieNrHighISO': 0xD236,
            'D241': 0xD241,
            'D244': 0xD244,
            'D247': 0xD247,
            'GUID': 0xD24F,
            'D250': 0xD250,
            'D251': 0xD251,
            'ISO': 0xF002,
            'ImageCompression': 0xF009,
            'NikonImageSize': 0xF00A,
            'NikonWhiteBalance': 0xF00C,
            # TODO: Are these redundant? Or product-specific?
            '_LongExposureNoiseReduction': 0xF00D,
            'HiISONoiseReduction': 0xF00E,
            '_ActiveDLighting': 0xF00F,
            '_MovQuality': 0xF01C,
        }
        # NOTE(review): the vendor codes above are applied over
        # `product_properties`, so a product subclass cannot override a name
        # defined here; confirm this precedence is intended.
        product_properties.update(props)
        return super(Nikon, self)._PropertyCode(
            **product_properties
        )

    def _OperationCode(self, **product_operations):
        return super(Nikon, self)._OperationCode(
            GetProfileAllData=0x9006,
            SendProfileData=0x9007,
            DeleteProfile=0x9008,
            SetProfileData=0x9009,
            AdvancedTransfer=0x9010,
            GetFileInfoInBlock=0x9011,
            Capture=0x90C0,
            AFDrive=0x90C1,
            SetControlMode=0x90C2,
            DelImageSDRAM=0x90C3,
            GetLargeThumb=0x90C4,
            CurveDownload=0x90C5,
            CurveUpload=0x90C6,
            CheckEvents=0x90C7,
            DeviceReady=0x90C8,
            SetPreWBData=0x90C9,
            GetVendorPropCodes=0x90CA,
            AFCaptureSDRAM=0x90CB,
            GetPictCtrlData=0x90CC,
            SetPictCtrlData=0x90CD,
            DelCstPicCtrl=0x90CE,
            GetPicCtrlCapability=0x90CF,
            GetPreviewImg=0x9200,
            StartLiveView=0x9201,
            EndLiveView=0x9202,
            GetLiveViewImg=0x9203,
            MfDrive=0x9204,
            ChangeAFArea=0x9205,
            AFDriveCancel=0x9206,
            InitiateCaptureRecInMedia=0x9207,
            GetVendorStorageIDs=0x9209,
            StartMovieRecInCard=0x920A,
            EndMovieRec=0x920B,
            TerminateCapture=0x920C,
            GetDevicePTPIPInfo=0x90E0,
            GetPartialObjectHiSpeed=0x9400,
            GetDevicePropEx=0x9504,
            **product_operations
        )

    def _ResponseCode(self, **product_responses):
        return super(Nikon, self)._ResponseCode(
            HardwareError=0xA001,
            OutOfFocus=0xA002,
            ChangeCameraModeFailed=0xA003,
            InvalidStatus=0xA004,
            SetPropertyNotSupported=0xA005,
            WbResetError=0xA006,
            DustReferenceError=0xA007,
            ShutterSpeedBulb=0xA008,
            MirrorUpSequence=0xA009,
            CameraModeNotAdjustFNumber=0xA00A,
            NotLiveView=0xA00B,
            MfDriveStepEnd=0xA00C,
            MfDriveStepInsufficiency=0xA00E,
            AdvancedTransferCancel=0xA022,
            **product_responses
        )

    def _EventCode(self, **product_events):
        return super(Nikon, self)._EventCode(
            ObjectAddedInSDRAM=0xC101,
            CaptureCompleteRecInSdram=0xC102,
            AdvancedTransfer=0xC103,
            PreviewImageAdded=0xC104,
            **product_events
        )

    def _FilesystemType(self, **product_filesystem_types):
        return super(Nikon, self)._FilesystemType(
            **product_filesystem_types
        )

    def _NikonEvent(self):
        # UInt16-prefixed array of (EventCode, UInt32 Parameter) records.
        return PrefixedArray(
            self._UInt16,
            Struct(
                'EventCode' / self._EventCode,
                'Parameter' / self._UInt32,
            )
        )

    def _set_endian(self, endian):
        logger.debug('Set Nikon endianness')
        super(Nikon, self)._set_endian(endian)
        # Replaces the factory method with the built construct on the instance.
        self._NikonEvent = self._NikonEvent()

    # TODO: Add event queue over all transports and extensions.
    def check_events(self):
        '''Check Nikon specific event'''
        ptp = Container(
            OperationCode='CheckEvents',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[]
        )
        response = self.recv(ptp)
        return self._parse_if_data(response, self._NikonEvent)

    # TODO: Provide a single camera agnostic command that will trigger a camera
    def capture(self):
        '''Nikon specific capture'''
        ptp = Container(
            OperationCode='Capture',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[]
        )
        return self.mesg(ptp)

    def af_capture_sdram(self):
        '''Nikon specific autofocus and capture to SDRAM'''
        ptp = Container(
            OperationCode='AFCaptureSDRAM',
            SessionID=self._session,
            TransactionID=self._transaction,
            Parameter=[]
        )
        return self.mesg(ptp)

    def event(self, wait=False):
        '''Check Nikon or PTP events

        If `wait` this function is blocking. Otherwise it may return None.

        Events queued by the Nikon polling thread are returned first. Only
        when that queue is empty is the parent class's `event` consulted.
        '''
        # TODO: Do something reasonable on wait=True
        # TODO: Join queues to preserve order of Nikon and PTP events.
        if not self.__event_queue.empty():
            # Honour `wait`: the original passed `block=not wait`, i.e. it
            # blocked only when the caller asked not to block.
            return self.__event_queue.get(block=wait)
        return super(Nikon, self).event(wait=wait)

    def __nikon_poll_events(self):
        '''Poll events, adding them to a queue.

        Runs until shutdown is requested or the main thread exits. The stop
        flag is left set on exit; `session()` re-arms it before starting a
        new poller.
        '''
        while (not self.__nikon_event_shutdown.is_set() and
               _main_thread_alive()):
            try:
                evts = self.check_events()
                if evts:
                    for evt in evts:
                        logger.debug('Event queued')
                        self.__event_queue.put(evt)
            except Exception as e:
                logger.error(e)
            # Pause between polls but wake immediately on shutdown. A plain
            # 3 s sleep outlasted the 2 s join in `_nikon_shutdown`, leaving
            # the thread polling after the session was closed.
            self.__nikon_event_shutdown.wait(3)
コード例 #52
0
ファイル: test_poutines.py プロジェクト: lewisKit/pyro
class QueueHandlerDiscreteTest(TestCase):
    """Tests for ``poutine.queue`` on a small discrete-latent HMM."""

    def setUp(self):
        """Build the model, its expected site names and a one-trace queue."""

        def model():
            # Simple Gaussian-mixture HMM: three Bernoulli latents, each
            # followed by a Normal observation fixed to ones(1).
            probs = pyro.param("probs", torch.tensor([[0.8], [0.3]]))
            loc = pyro.param("loc", torch.tensor([[-0.1], [0.9]]))
            scale = torch.ones(1, 1)

            latents = [torch.ones(1)]
            observes = []
            for step in range(3):
                # transition: the previous latent selects the Bernoulli row
                prev_state = latents[-1][0].long().data
                latents.append(
                    pyro.sample("latent_{}".format(step),
                                Bernoulli(probs[prev_state])))

                # emission: the latent just drawn selects the Normal mean
                cur_state = latents[-1][0].long().data
                observes.append(
                    pyro.sample("observe_{}".format(step),
                                Normal(loc[cur_state], scale),
                                obs=torch.ones(1)))
            return latents

        steps = range(3)
        self.sites = (["observe_{}".format(step) for step in steps] +
                      ["latent_{}".format(step) for step in steps] +
                      ["_INPUT", "_RETURN"])
        self.model = model
        self.queue = Queue()
        self.queue.put(poutine.Trace())

    def test_queue_single(self):
        """Starting from one empty queued trace, every expected site is recorded."""
        traced = poutine.trace(poutine.queue(self.model, queue=self.queue))
        trace = traced.get_trace()
        for site in self.sites:
            assert site in trace

    def test_queue_enumerate(self):
        """Draining the queue enumerates all 2**3 latent assignments exactly once."""
        traced = poutine.trace(poutine.queue(self.model, queue=self.queue))
        traces = []
        while not self.queue.empty():
            traces.append(traced.get_trace())
        assert len(traces) == 2 ** 3

        expected = {(a, b, c)
                    for a in range(2) for b in range(2) for c in range(2)}

        observed = []
        for trace in traces:
            nodes = trace.nodes
            observed.append(tuple(
                int(nodes[site]["value"].view(-1).item())
                for site in trace
                if nodes[site]["type"] == "sample"
                and not nodes[site]["is_observed"]))

        assert expected == set(observed)

    def test_queue_max_tries(self):
        """With max_tries=3 the queued model raises ValueError."""
        bounded = poutine.queue(self.model, queue=self.queue, max_tries=3)
        with pytest.raises(ValueError):
            bounded()
コード例 #53
0
class TestFetcherProcessor(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        """Create the shared fetcher, processor and queues; start httpbin and a proxy.

        httpbin runs in a subprocess on port 14887 and ``pyproxy`` on port
        14830 (username/password ``binux``/``123456``). Both are stopped in
        ``tearDownClass``.
        """
        cls.projectdb = ProjectDB([os.path.join(os.path.dirname(__file__), 'data_fetcher_processor_handler.py')])
        # `async` is a reserved word from Python 3.7 on, so the literal
        # keyword `async=False` is a SyntaxError there. Passing it through **
        # sends the same keyword argument while keeping this module parseable.
        cls.fetcher = Fetcher(None, None, **{'async': False})
        cls.status_queue = Queue()
        cls.newtask_queue = Queue()
        cls.result_queue = Queue()
        cls.httpbin_thread = utils.run_in_subprocess(httpbin.app.run, port=14887, passthrough_errors=False)
        cls.httpbin = 'http://127.0.0.1:14887'
        cls.proxy_thread = subprocess.Popen(['pyproxy', '--username=binux',
                                             '--password=123456', '--port=14830',
                                             '--debug'], close_fds=True)
        cls.proxy = '127.0.0.1:14830'
        cls.processor = Processor(projectdb=cls.projectdb,
                                  inqueue=None,
                                  status_queue=cls.status_queue,
                                  newtask_queue=cls.newtask_queue,
                                  result_queue=cls.result_queue)
        cls.project_name = 'data_fetcher_processor_handler'
        # Give the httpbin and proxy subprocesses a moment to start.
        time.sleep(0.5)

    @classmethod
    def tearDownClass(cls):
        """Stop the proxy and httpbin subprocesses started in setUpClass."""
        proxy = cls.proxy_thread
        proxy.terminate()
        proxy.wait()

        server = cls.httpbin_thread
        server.terminate()
        server.join()

    def crawl(self, url=None, track=None, **kwargs):
        """Run one task through the fetcher and processor; return what they emitted.

        When ``url`` is None and a ``callback`` is given, the callback name is
        turned into a URL with ``dataurl.encode`` (presumably a ``data:`` URL,
        so no network fetch is needed -- confirm).

        Returns ``(status, newtasks, result)``: the last item drained from the
        status, new-task and result queues. The defaults are ``None``, ``[]``
        and ``None`` when the corresponding queue was empty.
        """
        callback = kwargs.get('callback')
        if url is None and callback:
            url = dataurl.encode(utils.text(callback))

        project = self.processor.project_manager.get(self.project_name)
        assert project, "can't find project: %s" % self.project_name
        handler = project['instance']
        handler._reset()

        task = handler.crawl(url, **kwargs)
        if isinstance(task, list):
            task = task[0]
        task['track'] = track

        fetched = self.fetcher.fetch(task)
        self.processor.on_task(task, fetched)

        last_status = None
        while not self.status_queue.empty():
            last_status = self.status_queue.get()

        last_newtasks = []
        while not self.newtask_queue.empty():
            last_newtasks = self.newtask_queue.get()

        last_result = None
        while not self.result_queue.empty():
            _, last_result = self.result_queue.get()

        return last_status, last_newtasks, last_result

    def status_ok(self, status, type):
        """Return the ``ok`` flag recorded in ``status['track'][type]``.

        A falsy ``status`` (e.g. None when nothing was reported) and missing
        keys both count as not ok (False).
        """
        if status:
            stage = status.get('track', {}).get(type, {})
            return stage.get('ok', False)
        return False

    def assertStatusOk(self, status):
        """Assert that both the fetch and the process stages reported ok.

        The stage's track dict is used as the failure message. A missing
        (None) status now fails the assertion instead of raising
        AttributeError while building that message.
        """
        track = (status or {}).get('track', {})
        self.assertTrue(self.status_ok(status, 'fetch'), track.get('fetch'))
        self.assertTrue(self.status_ok(status, 'process'), track.get('process'))

    def __getattr__(self, name):
        # Fallback for attributes not otherwise defined: return the
        # attribute's own name as a string. This lets tests write, e.g.,
        # `callback=self.json` and pass the callback name 'json' (compare
        # test_10, which expects result == 'not_send_status').
        # NOTE(review): it also hides typos -- any missing attribute on the
        # test case evaluates to its name instead of raising AttributeError.
        return name

    def test_10_not_status(self):
        """The not_send_status callback reports no status, one new task and its name as result."""
        outcome = self.crawl(callback=self.not_send_status)
        status, newtasks, result = outcome

        self.assertIsNone(status)
        self.assertEqual(len(newtasks), 1, newtasks)
        self.assertEqual(result, 'not_send_status')

    def test_20_url_deduplicated(self):
        """url_deduplicated: both stages ok, empty fetch details, two new tasks, no result."""
        status, newtasks, result = self.crawl(callback=self.url_deduplicated)

        self.assertStatusOk(status)
        fetch = status['track']['fetch']
        self.assertIsNone(fetch['error'])
        self.assertIsNone(fetch['content'])
        self.assertFalse(fetch['headers'])
        self.assertFalse(status['track']['process']['logs'])
        self.assertEqual(len(newtasks), 2, newtasks)
        self.assertIsNone(result)

    def test_30_catch_status_code_error(self):
        """HTTP error status codes with and without catch_http_error.

        With the ``json`` callback a 418 response fails both stages and the
        HTTPError appears in the process logs. With ``catch_http_error`` the
        fetch still fails but processing succeeds, one new task is produced
        and the status code becomes the result. The same holds for a 302 when
        redirects are disabled.
        """
        base = self.httpbin
        status, newtasks, result = self.crawl(base + '/status/418', callback=self.json)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertFalse(self.status_ok(status, 'process'))
        fetch = status['track']['fetch']
        self.assertIn('HTTP 418', fetch['error'])
        self.assertTrue(fetch['content'], '')
        self.assertTrue(fetch['headers'])
        process = status['track']['process']
        self.assertTrue(process['logs'])
        self.assertIn('HTTPError: HTTP 418', process['logs'])
        self.assertFalse(newtasks)

        for code in (400, 500):
            status, newtasks, result = self.crawl(base + '/status/%d' % code,
                                                  callback=self.catch_http_error)
            self.assertFalse(self.status_ok(status, 'fetch'))
            self.assertTrue(self.status_ok(status, 'process'))
            self.assertEqual(len(newtasks), 1, newtasks)
            self.assertEqual(result, code)

        status, newtasks, result = self.crawl(base + '/status/302',
                                              allow_redirects=False,
                                              callback=self.catch_http_error)
        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertTrue(self.status_ok(status, 'process'))
        self.assertEqual(len(newtasks), 1, newtasks)
        self.assertEqual(result, 302)

    def test_40_method(self):
        """DELETE works on /delete. On /get it fails the fetch and the result is 405."""
        base = self.httpbin
        status, newtasks, _ = self.crawl(base + '/delete', method='DELETE', callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)

        status, newtasks, result = self.crawl(base + '/get', method='DELETE',
                                              callback=self.catch_http_error)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertTrue(self.status_ok(status, 'process'))
        self.assertTrue(newtasks)
        self.assertEqual(result, 405)

    def test_50_params(self):
        """Query parameters, including a non-ASCII key, are echoed back in ``args``."""
        query = {
            'roy': 'binux',
            u'中文': '.',
        }
        expected = dict(query)
        status, newtasks, result = self.crawl(self.httpbin + '/get', params=query,
                                              callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['args'], expected)

    def test_60_data(self):
        """POSTed form data, including a non-ASCII key, is echoed back in ``form``."""
        form = {
            'roy': 'binux',
            u'中文': '.',
        }
        expected = dict(form)
        status, newtasks, result = self.crawl(self.httpbin + '/post', data=form,
                                              callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['form'], expected)

    def test_70_redirect(self):
        """A followed redirect records the final URL as ``redirect_url``."""
        base = self.httpbin
        status, newtasks, _ = self.crawl(base + '/redirect-to?url=/get', callback=self.json)

        self.assertStatusOk(status)
        redirect_url = status['track']['fetch']['redirect_url']
        self.assertEqual(redirect_url, base + '/get')
        self.assertFalse(newtasks)

    def test_80_redirect_too_many(self):
        """Ten chained redirects fail the fetch with status 599 and a redirect error."""
        status, newtasks, _ = self.crawl(self.httpbin + '/redirect/10', callback=self.json)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertFalse(self.status_ok(status, 'process'))
        self.assertFalse(newtasks)
        fetch = status['track']['fetch']
        self.assertEqual(fetch['status_code'], 599)
        self.assertIn('redirects followed', fetch['error'])

    def test_90_files(self):
        """Uploading this test file with PUT shows it under httpbin's ``files``."""
        filename = os.path.basename(__file__)
        # Read via a context manager: the bare open().read() leaked the handle.
        with open(__file__) as source:
            content = source.read()
        status, newtasks, result = self.crawl(self.httpbin+'/put', method='PUT',
                                              files={filename: content},
                                              callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertIn(filename, result['files'])

    def test_a100_files_with_data(self):
        """A PUT carrying both a file and form data echoes both back."""
        filename = os.path.basename(__file__)
        # Read via a context manager: the bare open().read() leaked the handle.
        with open(__file__) as source:
            content = source.read()
        status, newtasks, result = self.crawl(self.httpbin+'/put', method='PUT',
                                              files={filename: content},
                                              data={
                                                  'roy': 'binux',
                                                  # FIXME: a non-ASCII key does not work here yet:
                                                  #'中文': '.',
                                              },
                                              callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['form'], {'roy': 'binux'})
        self.assertIn(filename, result['files'])

    def test_a110_headers(self):
        """Custom request headers reach httpbin (echoed with title-cased names)."""
        sent_headers = {
            'a': 'b',
            'C-d': 'e-F',
        }
        status, newtasks, result = self.crawl(self.httpbin + '/get',
                                              headers=sent_headers,
                                              callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        echoed = result['headers']
        self.assertEqual(echoed.get('A'), 'b')
        self.assertEqual(echoed.get('C-D'), 'e-F')

    def test_a115_user_agent(self):
        """The ``user_agent`` option becomes the User-Agent request header."""
        agent = 'binux'
        status, newtasks, result = self.crawl(self.httpbin + '/get',
                                              user_agent=agent, callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['headers'].get('User-Agent'), agent)


    def test_a120_cookies(self):
        """Cookies passed to crawl are sent in the Cookie request header."""
        jar = {
            'a': 'b',
            'C-d': 'e-F'
        }
        status, newtasks, result = self.crawl(self.httpbin + '/get',
                                              cookies=jar, callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        cookie_header = result['headers'].get('Cookie')
        for pair in ('a=b', 'C-d=e-F'):
            self.assertIn(pair, cookie_header)

    def test_a130_cookies_with_headers(self):
        """An explicit Cookie header and the ``cookies`` option are merged."""
        status, newtasks, result = self.crawl(self.httpbin + '/get',
                                              headers={
                                                  'Cookie': 'g=h; I=j',
                                              },
                                              cookies={
                                                  'a': 'b',
                                                  'C-d': 'e-F'
                                              }, callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        cookie_header = result['headers'].get('Cookie')
        for pair in ('g=h', 'I=j', 'a=b', 'C-d=e-F'):
            self.assertIn(pair, cookie_header)

    def test_a140_response_cookie(self):
        """Cookies set by the response are returned by the ``cookies`` callback."""
        target = self.httpbin + '/cookies/set?k1=v1&k2=v2'
        status, newtasks, result = self.crawl(target, callback=self.cookies)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result, {'k1': 'v1', 'k2': 'v2'})

    def test_a145_redirect_cookie(self):
        """Cookies set before a redirect are sent to the redirect target."""
        target = self.httpbin + '/cookies/set?k1=v1&k2=v2'
        status, newtasks, result = self.crawl(target, callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['cookies'], {'k1': 'v1', 'k2': 'v2'})

    def test_a150_timeout(self):
        """A 2 s response with a 1 s timeout fails both stages after about 1 s."""
        status, newtasks, _ = self.crawl(self.httpbin + '/delay/2', timeout=1,
                                         callback=self.json)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertFalse(self.status_ok(status, 'process'))
        self.assertFalse(newtasks)
        elapsed = status['track']['fetch']['time']
        self.assertEqual(int(elapsed), 1)

    def test_a160_etag(self):
        """With an ``etag``, /cache yields ok stages but an empty (falsy) result."""
        cache_url = self.httpbin + '/cache'
        status, newtasks, result = self.crawl(cache_url, etag='abc', callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertFalse(result)

    def test_a170_last_modified(self):
        """Sending last_modified to /cache leaves the callback with an empty result."""
        cache_url = self.httpbin + '/cache'
        st, tasks, res = self.crawl(cache_url, last_modified='0', callback=self.json)

        self.assertStatusOk(st)
        self.assertFalse(tasks)
        self.assertFalse(res)

    def test_a180_save(self):
        """The ``save`` payload reaches the ``get_save`` callback unchanged."""
        payload = {'roy': 'binux', u'中文': 'value'}
        st, tasks, res = self.crawl(callback=self.get_save, save=payload)

        self.assertStatusOk(st)
        self.assertFalse(tasks)
        # Compare against a fresh literal, not ``payload`` itself.
        self.assertEqual(res, {'roy': 'binux', u'中文': 'value'})

    def test_a190_taskid(self):
        """An explicit ``taskid`` is kept in the reported status."""
        wanted_id = 'binux-taskid'
        st, tasks, res = self.crawl(callback=self.get_save, taskid=wanted_id)

        self.assertStatusOk(st)
        self.assertEqual(st['taskid'], 'binux-taskid')
        self.assertFalse(tasks)
        self.assertFalse(res)

    def test_a200_no_proxy(self):
        """``proxy=False`` lets the request succeed despite a fetcher-wide proxy.

        The shared fetcher's proxy is swapped for the test proxy and restored
        in ``finally``, so a failing assertion cannot leak it into later tests.
        """
        old_proxy = self.fetcher.proxy
        self.fetcher.proxy = self.proxy
        try:
            status, newtasks, result = self.crawl(self.httpbin+'/get',
                                                  params={
                                                      'test': 'a200'
                                                  }, proxy=False, callback=self.json)

            self.assertStatusOk(status)
            self.assertFalse(newtasks)
        finally:
            self.fetcher.proxy = old_proxy

    def test_a210_proxy_failed(self):
        """Via the fetcher-wide proxy without credentials, the request gets 403.

        The shared fetcher's proxy is restored in ``finally``, so a failing
        assertion cannot leak the test proxy into later tests.
        """
        old_proxy = self.fetcher.proxy
        self.fetcher.proxy = self.proxy
        try:
            status, newtasks, result = self.crawl(self.httpbin+'/get',
                                                  params={
                                                      'test': 'a210'
                                                  }, callback=self.catch_http_error)

            self.assertFalse(self.status_ok(status, 'fetch'))
            self.assertTrue(self.status_ok(status, 'process'))
            self.assertEqual(len(newtasks), 1, newtasks)
            self.assertEqual(result, 403)
        finally:
            self.fetcher.proxy = old_proxy

    def test_a220_proxy_ok(self):
        """With ``username``/``password`` params, the proxied request gets 200.

        The shared fetcher's proxy is restored in ``finally``, so a failing
        assertion cannot leak the test proxy into later tests.
        """
        old_proxy = self.fetcher.proxy
        self.fetcher.proxy = self.proxy
        try:
            status, newtasks, result = self.crawl(self.httpbin+'/get',
                                                  params={
                                                      'test': 'a220',
                                                      'username': '******',
                                                      'password': '******',
                                                  }, callback=self.catch_http_error)

            self.assertStatusOk(status)
            self.assertEqual(result, 200)
        finally:
            self.fetcher.proxy = old_proxy

    def test_a230_proxy_parameter_fail(self):
        """A per-task ``proxy`` without credentials ends in HTTP 403."""
        query = {'test': 'a230'}
        st, tasks, res = self.crawl(self.httpbin+'/get', params=query,
                                    proxy=self.proxy,
                                    callback=self.catch_http_error)

        self.assertFalse(self.status_ok(st, 'fetch'))
        self.assertTrue(self.status_ok(st, 'process'))
        self.assertEqual(res, 403)

    def test_a240_proxy_parameter_ok(self):
        """A POST with credentials through the per-task ``proxy`` gets 200."""
        form = {
            'test': 'a240',
            'username': '******',
            'password': '******',
        }
        st, tasks, res = self.crawl(self.httpbin+'/post', method='POST',
                                    data=form, proxy=self.proxy,
                                    callback=self.catch_http_error)

        self.assertStatusOk(st)
        self.assertEqual(res, 200)

    def test_a250_proxy_userpass(self):
        """Credentials embedded in the proxy address (user:pass@host) work."""
        proxy_with_auth = 'binux:123456@' + self.proxy
        st, tasks, res = self.crawl(self.httpbin+'/post', method='POST',
                                    data={'test': 'a250'},
                                    proxy=proxy_with_auth,
                                    callback=self.catch_http_error)

        self.assertStatusOk(st)
        self.assertEqual(res, 200)

    def test_a260_process_save(self):
        """A value saved while processing is carried forward through ``track``."""
        first, _, _ = self.crawl(callback=self.set_process_save)

        self.assertStatusOk(first)
        saved = first['track']['save']
        self.assertIn('roy', saved)
        self.assertEqual(saved['roy'], 'binux')

        # Feed the first run's track into a second crawl that reads it back.
        second, _, res = self.crawl(callback=self.get_process_save,
                                    track=first['track'])

        self.assertStatusOk(second)
        self.assertIn('roy', res)
        self.assertEqual(res['roy'], 'binux')


    def test_zzz_links(self):
        """The ``links`` callback on /links/10/0 yields 9 new tasks and no result."""
        page = self.httpbin + '/links/10/0'
        st, tasks, res = self.crawl(page, callback=self.links)

        self.assertStatusOk(st)
        self.assertEqual(len(tasks), 9, tasks)
        self.assertFalse(res)

    def test_zzz_html(self):
        """The ``html`` callback extracts the heading of httpbin's /html page."""
        page = self.httpbin + '/html'
        st, tasks, res = self.crawl(page, callback=self.html)

        self.assertStatusOk(st)
        self.assertFalse(tasks)
        self.assertEqual(res, 'Herman Melville - Moby-Dick')

    def test_zzz_etag_enabled(self):
        """Re-crawling /cache with the previous track yields an empty result."""
        cache_url = self.httpbin + '/cache'

        first, _, res = self.crawl(cache_url, callback=self.json)
        self.assertStatusOk(first)
        self.assertTrue(res)

        second, tasks, res = self.crawl(cache_url, track=first['track'],
                                        callback=self.json)
        self.assertStatusOk(second)
        self.assertFalse(tasks)
        self.assertFalse(res)

    def test_zzz_etag_not_working(self):
        """If the previous process stage failed, /cache is fetched in full again."""
        cache_url = self.httpbin + '/cache'

        first, _, res = self.crawl(cache_url, callback=self.json)
        self.assertStatusOk(first)
        self.assertTrue(res)

        # Mark the earlier processing as failed before reusing its track.
        previous_track = first['track']
        previous_track['process']['ok'] = False
        second, _, res = self.crawl(cache_url, track=previous_track,
                                    callback=self.json)
        self.assertStatusOk(second)
        self.assertTrue(res)

    def test_zzz_unexpected_crawl_argument(self):
        """An unknown crawl() keyword (``cookie``) raises TypeError."""
        cache_url = self.httpbin + '/cache'
        with self.assertRaisesRegexp(TypeError, "unexpected keyword argument"):
            self.crawl(cache_url, cookie={}, callback=self.json)

    def test_zzz_curl_get(self):
        """A pasted ``curl`` GET command is parsed and its custom header sent."""
        status, newtasks, result = self.crawl("curl '"+self.httpbin+'''/get' -H 'DNT: 1' -H 'Accept-Encoding: gzip, deflate, sdch' -H 'Accept-Language: en,zh-CN;q=0.8,zh;q=0.6' -H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2272.17 Safari/537.36' -H 'Binux-Header: Binux-Value' -H 'Accept: */*' -H 'Cookie: _gauges_unique_year=1; _gauges_unique=1; _ga=GA1.2.415471573.1419316591' -H 'Connection: keep-alive' --compressed''', callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

        # assertTrue(x, msg) only checks truthiness (the 2nd arg is the
        # failure message); compare the echoed header value itself.
        self.assertEqual(result['headers'].get('Binux-Header'), 'Binux-Value')

    def test_zzz_curl_post(self):
        """A pasted ``curl`` form POST is parsed and its body sent."""
        status, newtasks, result = self.crawl("curl '"+self.httpbin+'''/post' -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' -H 'Accept-Encoding: gzip, deflate' -H 'Accept-Language: en,zh-CN;q=0.8,zh;q=0.6' -H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2272.17 Safari/537.36' -H 'Content-Type: application/x-www-form-urlencoded' -H 'Accept: */*' -H 'Cookie: _gauges_unique_year=1; _gauges_unique=1; _ga=GA1.2.415471573.1419316591' -H 'Connection: keep-alive' -H 'DNT: 1' --data 'Binux-Key=%E4%B8%AD%E6%96%87+value' --compressed''', callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

        # assertTrue(x, msg) only checks truthiness (the 2nd arg is the
        # failure message); compare the decoded form value itself. The
        # expected value is a unicode literal, like the other non-ASCII
        # literals in this file.
        self.assertEqual(result['form'].get('Binux-Key'), u'中文 value')

    def test_zzz_curl_put(self):
        """A pasted multipart ``curl`` PUT uploads the ``fileUpload1`` part."""
        command = "curl '" + self.httpbin + '''/put' -X PUT -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' -H 'Accept-Encoding: gzip, deflate, sdch' -H 'Accept-Language: en,zh-CN;q=0.8,zh;q=0.6' -H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2272.17 Safari/537.36' -H 'Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryYlkgyaA7SRGOQYUG' -H 'Accept: */*' -H 'Cookie: _gauges_unique_year=1; _gauges_unique=1; _ga=GA1.2.415471573.1419316591' -H 'Connection: keep-alive' -H 'DNT: 1' --data-binary $'------WebKitFormBoundaryYlkgyaA7SRGOQYUG\r\nContent-Disposition: form-data; name="Binux-Key"\r\n\r\n%E4%B8%AD%E6%96%87+value\r\n------WebKitFormBoundaryYlkgyaA7SRGOQYUG\r\nContent-Disposition: form-data; name="fileUpload1"; filename="1"\r\nContent-Type: application/octet-stream\r\n\r\n\r\n------WebKitFormBoundaryYlkgyaA7SRGOQYUG--\r\n' --compressed'''
        st, tasks, res = self.crawl(command, callback=self.json)
        self.assertStatusOk(st)
        self.assertTrue(res)

        self.assertIn('fileUpload1', res['files'], res)

    def test_zzz_curl_no_url(self):
        """A curl command without a URL is rejected with a 'no URL' TypeError."""
        command = '''curl -X PUT -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' --compressed'''
        with self.assertRaisesRegexp(TypeError, 'no URL'):
            _status, _newtasks, _result = self.crawl(command, callback=self.json)

    def test_zzz_curl_bad_option(self):
        """``-v`` is rejected as an unknown curl option, at the end or mid-command."""
        commands = (
            '''curl '%s/put' -X PUT -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' -v''' % self.httpbin,
            '''curl '%s/put' -X PUT -v -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' ''' % self.httpbin,
        )
        for command in commands:
            with self.assertRaisesRegexp(TypeError, 'Unknow curl option'):
                _status, _newtasks, _result = self.crawl(command, callback=self.json)


    def test_zzz_robots_txt(self):
        """With robots_txt=True, crawling /deny reports 403 to the callback."""
        denied = self.httpbin + '/deny'
        _, _, res = self.crawl(denied, robots_txt=True,
                               callback=self.catch_http_error)

        self.assertEqual(res, 403)


    def test_zzz_connect_timeout(self):
        """``connect_timeout`` bounds how long an unreachable host blocks a fetch.

        The target must be an address whose TCP handshake hangs rather than
        being answered or refused. 1.1.1.1, the original target, now hosts a
        public service that accepts connections on port 80, so the test no
        longer measured the timeout. 10.255.255.1 is a private address that
        is typically not routed. NOTE(review): confirm it black-holes in the
        CI network.
        """
        start_time = time.time()
        status, newtasks, result = self.crawl('http://10.255.255.1/', connect_timeout=5,
                                              callback=self.catch_http_error)
        elapsed = time.time() - start_time
        # Separate bound checks report the actual elapsed time on failure.
        self.assertGreaterEqual(elapsed, 5)
        self.assertLessEqual(elapsed, 6)
コード例 #54
0
class TestFetcherProcessor(unittest.TestCase):
    """End-to-end tests: build a task with the handler, fetch it, process it.

    ``setUpClass`` starts a local httpbin on port 14887 and a ``pyproxy``
    subprocess on port 14830 (user ``binux``). Callbacks are referenced by
    name: ``__getattr__`` turns any unknown attribute, such as ``self.json``,
    into the string ``'json'``, which ``crawl`` forwards as ``callback``.
    """

    @classmethod
    def setUpClass(cls):
        """Start httpbin and the proxy, and build the fetcher/processor pair."""
        cls.projectdb = ProjectDB([os.path.join(os.path.dirname(__file__), 'data_fetcher_processor_handler.py')])
        # ``async`` is a reserved word from Python 3.7 on, so ``async=False``
        # is a SyntaxError there. Unpacking a dict passes the same keyword
        # argument and keeps the module parseable on every version.
        cls.fetcher = Fetcher(None, None, **{'async': False})
        cls.status_queue = Queue()
        cls.newtask_queue = Queue()
        cls.result_queue = Queue()
        cls.httpbin_thread = utils.run_in_subprocess(httpbin.app.run, port=14887)
        cls.httpbin = 'http://127.0.0.1:14887'
        cls.proxy_thread = subprocess.Popen(['pyproxy', '--username=binux',
                                             '--password=123456', '--port=14830',
                                             '--debug'], close_fds=True)
        cls.proxy = '127.0.0.1:14830'
        cls.processor = Processor(projectdb=cls.projectdb,
                                  inqueue=None,
                                  status_queue=cls.status_queue,
                                  newtask_queue=cls.newtask_queue,
                                  result_queue=cls.result_queue)
        cls.project_name = 'data_fetcher_processor_handler'
        # Presumably gives the httpbin/proxy subprocesses time to start
        # listening before the first test runs.
        time.sleep(0.5)

    @classmethod
    def tearDownClass(cls):
        """Stop the proxy and httpbin subprocesses started in setUpClass."""
        try:
            cls.proxy_thread.terminate()
            cls.proxy_thread.wait()
        finally:
            # Stop httpbin even if shutting down the proxy failed.
            cls.httpbin_thread.terminate()
            cls.httpbin_thread.join()

    def crawl(self, url=None, track=None, **kwargs):
        """Create a task with the handler's ``crawl``, fetch it, then process it.

        :param url: URL to crawl. When it is omitted and a ``callback`` is
            given, the callback name is encoded with ``dataurl.encode`` and
            used as the URL.
        :param track: value stored as ``task['track']``, e.g. a previous
            status's track to exercise etag/save carry-over.
        :param kwargs: forwarded to the handler instance's ``crawl``.
        :returns: ``(status, newtasks, result)``, the last item drained from
            each output queue, or ``None`` / ``[]`` / ``None`` when that
            queue stayed empty.
        """
        if url is None and kwargs.get('callback'):
            url = dataurl.encode(utils.text(kwargs.get('callback')))

        project_data = self.processor.project_manager.get(self.project_name)
        assert project_data, "can't find project: %s" % self.project_name
        instance = project_data['instance']
        instance._reset()
        task = instance.crawl(url, **kwargs)
        if isinstance(task, list):
            task = task[0]
        task['track'] = track
        result = self.fetcher.fetch(task)
        self.processor.on_task(task, result)

        # Drain each queue and keep only the most recent item.
        status = None
        while not self.status_queue.empty():
            status = self.status_queue.get()
        newtasks = []
        while not self.newtask_queue.empty():
            newtasks = self.newtask_queue.get()
        result = None
        while not self.result_queue.empty():
            _, result = self.result_queue.get()
        return status, newtasks, result

    def status_ok(self, status, type):
        """Return the ``ok`` flag of stage ``type`` in ``status['track']``.

        Returns False when ``status`` is falsy or the stage is missing.
        (``type`` shadows the builtin but is kept for keyword compatibility.)
        """
        if not status:
            return False
        return status.get('track', {}).get(type, {}).get('ok', False)

    def assertStatusOk(self, status):
        """Assert that both the fetch and process stages of ``status`` succeeded."""
        # ``status`` may be None (see status_ok). Build the failure messages
        # from an empty dict so the check fails as an assertion instead of
        # raising AttributeError while computing the message.
        track = (status or {}).get('track', {})
        self.assertTrue(self.status_ok(status, 'fetch'), track.get('fetch'))
        self.assertTrue(self.status_ok(status, 'process'), track.get('process'))

    def __getattr__(self, name):
        """Map unknown attributes (callback references) to their own name."""
        return name

    def test_10_not_status(self):
        """A callback that sends no status still yields a new task and a result."""
        status, newtasks, result = self.crawl(callback=self.not_send_status)

        self.assertIsNone(status)
        self.assertEqual(len(newtasks), 1, newtasks)
        self.assertEqual(result, 'not_send_status')

    def test_20_url_deduplicated(self):
        """The deduplication callback succeeds with empty fetch details."""
        status, newtasks, result = self.crawl(callback=self.url_deduplicated)

        self.assertStatusOk(status)
        self.assertIsNone(status['track']['fetch']['error'])
        self.assertIsNone(status['track']['fetch']['content'])
        self.assertFalse(status['track']['fetch']['headers'])
        self.assertFalse(status['track']['process']['logs'])
        self.assertEqual(len(newtasks), 2, newtasks)
        self.assertIsNone(result)

    def test_30_catch_status_code_error(self):
        """Non-2xx responses fail the fetch; ``catch_http_error`` still processes them."""
        status, newtasks, result = self.crawl(self.httpbin+'/status/418', callback=self.json)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertFalse(self.status_ok(status, 'process'))
        self.assertIn('HTTP 418', status['track']['fetch']['error'])
        self.assertTrue(status['track']['fetch']['content'], '')
        self.assertTrue(status['track']['fetch']['headers'])
        self.assertTrue(status['track']['process']['logs'])
        self.assertIn('HTTPError: HTTP 418', status['track']['process']['logs'])
        self.assertFalse(newtasks)


        status, newtasks, result = self.crawl(self.httpbin+'/status/400', callback=self.catch_http_error)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertTrue(self.status_ok(status, 'process'))
        self.assertEqual(len(newtasks), 1, newtasks)
        self.assertEqual(result, 400)

        status, newtasks, result = self.crawl(self.httpbin+'/status/500', callback=self.catch_http_error)
        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertTrue(self.status_ok(status, 'process'))
        self.assertEqual(len(newtasks), 1, newtasks)
        self.assertEqual(result, 500)

        status, newtasks, result = self.crawl(self.httpbin+'/status/302',
                                              allow_redirects=False,
                                              callback=self.catch_http_error)
        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertTrue(self.status_ok(status, 'process'))
        self.assertEqual(len(newtasks), 1, newtasks)
        self.assertEqual(result, 302)

    def test_40_method(self):
        """DELETE works against /delete and is rejected with 405 on /get."""
        status, newtasks, result = self.crawl(self.httpbin+'/delete', method='DELETE', callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)

        status, newtasks, result = self.crawl(self.httpbin+'/get', method='DELETE', callback=self.catch_http_error)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertTrue(self.status_ok(status, 'process'))
        self.assertTrue(newtasks)
        self.assertEqual(result, 405)

    def test_50_params(self):
        """Query ``params``, including a non-ASCII key, reach the server."""
        status, newtasks, result = self.crawl(self.httpbin+'/get', params={
            'roy': 'binux',
            u'中文': '.',
        }, callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['args'], {'roy': 'binux', u'中文': '.'})

    def test_60_data(self):
        """Form ``data``, including a non-ASCII key, reaches the server."""
        status, newtasks, result = self.crawl(self.httpbin+'/post', data={
            'roy': 'binux',
            u'中文': '.',
        }, callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['form'], {'roy': 'binux', u'中文': '.'})

    def test_70_redirect(self):
        """Redirects are followed and the final URL is recorded."""
        status, newtasks, result = self.crawl(self.httpbin+'/redirect-to?url=/get', callback=self.json)

        self.assertStatusOk(status)
        self.assertEqual(status['track']['fetch']['redirect_url'], self.httpbin+'/get')
        self.assertFalse(newtasks)

    def test_80_redirect_too_many(self):
        """Exceeding the redirect limit fails the fetch with status 599."""
        status, newtasks, result = self.crawl(self.httpbin+'/redirect/10', callback=self.json)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertFalse(self.status_ok(status, 'process'))
        self.assertFalse(newtasks)
        self.assertEqual(status['track']['fetch']['status_code'], 599)
        self.assertIn('redirects followed', status['track']['fetch']['error'])

    def test_90_files(self):
        """A file passed through ``files`` is uploaded."""
        # Read with a context manager so the handle is closed promptly.
        with open(__file__) as source_file:
            file_content = source_file.read()
        status, newtasks, result = self.crawl(self.httpbin+'/put', method='PUT',
                                              files={os.path.basename(__file__): file_content},
                                              callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertIn(os.path.basename(__file__), result['files'])

    def test_a100_files_with_data(self):
        """``files`` and ``data`` can be sent in the same request."""
        with open(__file__) as source_file:
            file_content = source_file.read()
        status, newtasks, result = self.crawl(self.httpbin+'/put', method='PUT',
                                              files={os.path.basename(__file__): file_content},
                                              data={
                                                  'roy': 'binux',
                                                  #'中文': '.', # FIXME: not work
                                              },
                                              callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['form'], {'roy': 'binux'})
        self.assertIn(os.path.basename(__file__), result['files'])

    def test_a110_headers(self):
        """Custom request headers are sent."""
        status, newtasks, result = self.crawl(self.httpbin+'/get',
                                              headers={
                                                  'a': 'b',
                                                  'C-d': 'e-F',
                                              }, callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['headers'].get('A'), 'b')
        self.assertEqual(result['headers'].get('C-D'), 'e-F')

    def test_a120_cookies(self):
        """``cookies`` are sent in the Cookie header."""
        status, newtasks, result = self.crawl(self.httpbin+'/get',
                                              cookies={
                                                  'a': 'b',
                                                  'C-d': 'e-F'
                                              }, callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertIn('a=b', result['headers'].get('Cookie'))
        self.assertIn('C-d=e-F', result['headers'].get('Cookie'))

    def test_a130_cookies_with_headers(self):
        """An explicit Cookie header and ``cookies`` are merged."""
        status, newtasks, result = self.crawl(self.httpbin+'/get',
                                              headers={
                                                  'Cookie': 'g=h; I=j',
                                              },
                                              cookies={
                                                  'a': 'b',
                                                  'C-d': 'e-F'
                                              }, callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertIn('g=h', result['headers'].get('Cookie'))
        self.assertIn('I=j', result['headers'].get('Cookie'))
        self.assertIn('a=b', result['headers'].get('Cookie'))
        self.assertIn('C-d=e-F', result['headers'].get('Cookie'))

    def test_a140_response_cookie(self):
        """The ``cookies`` callback reports cookies set via httpbin."""
        status, newtasks, result = self.crawl(self.httpbin+'/cookies/set?k1=v1&k2=v2',
                                              callback=self.cookies)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result, {'k1': 'v1', 'k2': 'v2'})

    def test_a145_redirect_cookie(self):
        """Cookies set by /cookies/set appear in the JSON the callback parses."""
        status, newtasks, result = self.crawl(self.httpbin+'/cookies/set?k1=v1&k2=v2',
                                              callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result['cookies'], {'k1': 'v1', 'k2': 'v2'})

    def test_a150_timeout(self):
        """A 2s server delay with timeout=1 fails both stages after ~1s."""
        status, newtasks, result = self.crawl(self.httpbin+'/delay/2', timeout=1, callback=self.json)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertFalse(self.status_ok(status, 'process'))
        self.assertFalse(newtasks)
        self.assertEqual(int(status['track']['fetch']['time']), 1)

    def test_a160_etag(self):
        """Sending an etag to /cache leaves the callback with an empty result."""
        status, newtasks, result = self.crawl(self.httpbin+'/cache', etag='abc', callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertFalse(result)

    def test_a170_last_modifed(self):
        """Sending a last-modified value to /cache leaves an empty result."""
        # NOTE(review): the ``last_modifed`` spelling is kept as-is; confirm
        # which keyword spelling the handler's crawl() accepts.
        status, newtasks, result = self.crawl(self.httpbin+'/cache', last_modifed='0', callback=self.json)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertFalse(result)

    def test_a180_save(self):
        """The ``save`` payload reaches the ``get_save`` callback unchanged."""
        status, newtasks, result = self.crawl(callback=self.get_save,
                                              save={'roy': 'binux', u'中文': 'value'})

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result, {'roy': 'binux', u'中文': 'value'})

    def test_a190_taskid(self):
        """An explicit ``taskid`` is kept in the reported status."""
        status, newtasks, result = self.crawl(callback=self.get_save,
                                              taskid='binux-taskid')

        self.assertStatusOk(status)
        self.assertEqual(status['taskid'], 'binux-taskid')
        self.assertFalse(newtasks)
        self.assertFalse(result)

    def test_a200_no_proxy(self):
        """``proxy=False`` lets the request succeed despite a fetcher-wide proxy."""
        old_proxy = self.fetcher.proxy
        self.fetcher.proxy = self.proxy
        # Restore the shared fetcher's proxy even if an assertion fails.
        try:
            status, newtasks, result = self.crawl(self.httpbin+'/get',
                                                  params={
                                                      'test': 'a200'
                                                  }, proxy=False, callback=self.json)

            self.assertStatusOk(status)
            self.assertFalse(newtasks)
        finally:
            self.fetcher.proxy = old_proxy

    def test_a210_proxy_failed(self):
        """Via the fetcher-wide proxy without credentials, the request gets 403."""
        old_proxy = self.fetcher.proxy
        self.fetcher.proxy = self.proxy
        # Restore the shared fetcher's proxy even if an assertion fails.
        try:
            status, newtasks, result = self.crawl(self.httpbin+'/get',
                                                  params={
                                                      'test': 'a210'
                                                  }, callback=self.catch_http_error)

            self.assertFalse(self.status_ok(status, 'fetch'))
            self.assertTrue(self.status_ok(status, 'process'))
            self.assertEqual(len(newtasks), 1, newtasks)
            self.assertEqual(result, 403)
        finally:
            self.fetcher.proxy = old_proxy

    def test_a220_proxy_ok(self):
        """With ``username``/``password`` params, the proxied request gets 200."""
        old_proxy = self.fetcher.proxy
        self.fetcher.proxy = self.proxy
        # Restore the shared fetcher's proxy even if an assertion fails.
        try:
            status, newtasks, result = self.crawl(self.httpbin+'/get',
                                                  params={
                                                      'test': 'a220',
                                                      'username': '******',
                                                      'password': '******',
                                                  }, callback=self.catch_http_error)

            self.assertStatusOk(status)
            self.assertEqual(result, 200)
        finally:
            self.fetcher.proxy = old_proxy

    def test_a230_proxy_parameter_fail(self):
        """A per-task ``proxy`` without credentials ends in HTTP 403."""
        status, newtasks, result = self.crawl(self.httpbin+'/get',
                                              params={
                                                  'test': 'a230',
                                              }, proxy=self.proxy,
                                              callback=self.catch_http_error)

        self.assertFalse(self.status_ok(status, 'fetch'))
        self.assertTrue(self.status_ok(status, 'process'))
        self.assertEqual(result, 403)

    def test_a240_proxy_parameter_ok(self):
        """A POST with credentials through the per-task ``proxy`` gets 200."""
        status, newtasks, result = self.crawl(self.httpbin+'/post',
                                              method='POST',
                                              data={
                                                  'test': 'a240',
                                                  'username': '******',
                                                  'password': '******',
                                              }, proxy=self.proxy,
                                              callback=self.catch_http_error)

        self.assertStatusOk(status)
        self.assertEqual(result, 200)

    def test_a250_proxy_userpass(self):
        """Credentials embedded in the proxy address (user:pass@host) work."""
        status, newtasks, result = self.crawl(self.httpbin+'/post',
                                              method='POST',
                                              data={
                                                  'test': 'a250',
                                              }, proxy='binux:123456@'+self.proxy,
                                              callback=self.catch_http_error)

        self.assertStatusOk(status)
        self.assertEqual(result, 200)

    def test_a260_process_save(self):
        """A value saved while processing is carried forward through ``track``."""
        status, newtasks, result = self.crawl(callback=self.set_process_save)

        self.assertStatusOk(status)
        self.assertIn('roy', status['track']['save'])
        self.assertEqual(status['track']['save']['roy'], 'binux')

        status, newtasks, result = self.crawl(callback=self.get_process_save,
                                              track=status['track'])

        self.assertStatusOk(status)
        self.assertIn('roy', result)
        self.assertEqual(result['roy'], 'binux')


    def test_zzz_links(self):
        """The ``links`` callback on /links/10/0 yields 9 new tasks."""
        status, newtasks, result = self.crawl(self.httpbin+'/links/10/0', callback=self.links)

        self.assertStatusOk(status)
        self.assertEqual(len(newtasks), 9, newtasks)
        self.assertFalse(result)

    def test_zzz_html(self):
        """The ``html`` callback extracts the heading of httpbin's /html page."""
        status, newtasks, result = self.crawl(self.httpbin+'/html', callback=self.html)

        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertEqual(result, 'Herman Melville - Moby-Dick')

    def test_zzz_etag_enabled(self):
        """Re-crawling /cache with the previous track yields an empty result."""
        status, newtasks, result = self.crawl(self.httpbin+'/cache', callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

        status, newtasks, result = self.crawl(self.httpbin+'/cache',
                                              track=status['track'], callback=self.json)
        self.assertStatusOk(status)
        self.assertFalse(newtasks)
        self.assertFalse(result)

    def test_zzz_etag_not_working(self):
        """If the previous process stage failed, /cache is fetched in full again."""
        status, newtasks, result = self.crawl(self.httpbin+'/cache', callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

        status['track']['process']['ok'] = False
        status, newtasks, result = self.crawl(self.httpbin+'/cache',
                                              track=status['track'], callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

    def test_zzz_unexpected_crawl_argument(self):
        """An unknown crawl() keyword (``cookie``) raises TypeError."""
        with self.assertRaisesRegexp(TypeError, "unexpected keyword argument"):
            self.crawl(self.httpbin+'/cache', cookie={}, callback=self.json)

    def test_zzz_curl_get(self):
        """A pasted ``curl`` GET command is parsed and its custom header sent."""
        status, newtasks, result = self.crawl("curl '"+self.httpbin+'''/get' -H 'DNT: 1' -H 'Accept-Encoding: gzip, deflate, sdch' -H 'Accept-Language: en,zh-CN;q=0.8,zh;q=0.6' -H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2272.17 Safari/537.36' -H 'Binux-Header: Binux-Value' -H 'Accept: */*' -H 'Cookie: _gauges_unique_year=1; _gauges_unique=1; _ga=GA1.2.415471573.1419316591' -H 'Connection: keep-alive' --compressed''', callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

        # assertTrue(x, msg) only checks truthiness; compare the value itself.
        self.assertEqual(result['headers'].get('Binux-Header'), 'Binux-Value')

    def test_zzz_curl_post(self):
        """A pasted ``curl`` form POST is parsed and its body sent."""
        status, newtasks, result = self.crawl("curl '"+self.httpbin+'''/post' -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' -H 'Accept-Encoding: gzip, deflate' -H 'Accept-Language: en,zh-CN;q=0.8,zh;q=0.6' -H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2272.17 Safari/537.36' -H 'Content-Type: application/x-www-form-urlencoded' -H 'Accept: */*' -H 'Cookie: _gauges_unique_year=1; _gauges_unique=1; _ga=GA1.2.415471573.1419316591' -H 'Connection: keep-alive' -H 'DNT: 1' --data 'Binux-Key=%E4%B8%AD%E6%96%87+value' --compressed''', callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

        # assertTrue(x, msg) only checks truthiness; compare the decoded form
        # value itself, as a unicode literal like the file's other non-ASCII
        # literals.
        self.assertEqual(result['form'].get('Binux-Key'), u'中文 value')

    def test_zzz_curl_put(self):
        """A pasted multipart ``curl`` PUT uploads the ``fileUpload1`` part."""
        status, newtasks, result = self.crawl("curl '"+self.httpbin+'''/put' -X PUT -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' -H 'Accept-Encoding: gzip, deflate, sdch' -H 'Accept-Language: en,zh-CN;q=0.8,zh;q=0.6' -H 'User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2272.17 Safari/537.36' -H 'Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryYlkgyaA7SRGOQYUG' -H 'Accept: */*' -H 'Cookie: _gauges_unique_year=1; _gauges_unique=1; _ga=GA1.2.415471573.1419316591' -H 'Connection: keep-alive' -H 'DNT: 1' --data-binary $'------WebKitFormBoundaryYlkgyaA7SRGOQYUG\r\nContent-Disposition: form-data; name="Binux-Key"\r\n\r\n%E4%B8%AD%E6%96%87+value\r\n------WebKitFormBoundaryYlkgyaA7SRGOQYUG\r\nContent-Disposition: form-data; name="fileUpload1"; filename="1"\r\nContent-Type: application/octet-stream\r\n\r\n\r\n------WebKitFormBoundaryYlkgyaA7SRGOQYUG--\r\n' --compressed''', callback=self.json)
        self.assertStatusOk(status)
        self.assertTrue(result)

        self.assertIn('fileUpload1', result['files'], result)

    def test_zzz_curl_no_url(self):
        """A curl command without a URL is rejected with a 'no URL' TypeError."""
        with self.assertRaisesRegexp(TypeError, 'no URL'):
            status, newtasks, result = self.crawl(
                '''curl -X PUT -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' --compressed''',
                callback=self.json)

    def test_zzz_curl_bad_option(self):
        """``-v`` is rejected as an unknown curl option, at the end or mid-command."""
        with self.assertRaisesRegexp(TypeError, 'Unknow curl option'):
            status, newtasks, result = self.crawl(
                '''curl '%s/put' -X PUT -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' -v''' % self.httpbin,
                callback=self.json)

        with self.assertRaisesRegexp(TypeError, 'Unknow curl option'):
            status, newtasks, result = self.crawl(
                '''curl '%s/put' -X PUT -v -H 'Origin: chrome-extension://hgmloofddffdnphfgcellkdfbfbjeloo' ''' % self.httpbin,
                callback=self.json)


    def test_zzz_robots_txt(self):
        """With robots_txt=True, crawling /deny reports 403 to the callback."""
        status, newtasks, result = self.crawl(self.httpbin+'/deny', robots_txt=True, callback=self.catch_http_error)

        self.assertEqual(result, 403)
コード例 #55
0
class Subscriber(threading.Thread):
    """
    Thread responsible for event subscriptions.
    Issues subscriptions, creates the websocket, and refreshes the
    subscriptions before timer expiry.  It also reissues the
    subscriptions when the APIC login is refreshed.

    Event JSON strings placed on ``self._event_q`` are sorted into per-URL
    lists by ``_process_event_q`` whenever a caller asks for events.
    """
    def __init__(self, apic):
        """
        :param apic: session object used to reach the APIC.  This class
                     calls its ``get(url)`` method and reads its ``ipaddr``
                     and ``token`` attributes.
        """
        threading.Thread.__init__(self)
        self._apic = apic
        # url -> subscription id; None when the last attempt failed
        self._subscriptions = {}
        self._ws = None
        self._ws_url = None
        # seconds run() sleeps between subscription refreshes
        self._refresh_time = 30
        # raw JSON event strings not yet sorted into self._events
        self._event_q = Queue()
        # url -> list of decoded events not yet handed to a caller
        self._events = {}
        self._exit = False
        self.event_handler_thread = None

    def exit(self):
        """
        Indicate that the thread should exit.
        """
        self._exit = True

    @staticmethod
    def _error_response():
        """
        Build the synthetic response returned when a subscription fails.

        :returns: requests.Response with status 404 and a JSON error body
        """
        resp = requests.Response()
        resp.status_code = 404
        # requests keeps the raw payload as bytes; a str here breaks
        # resp.text / resp.json() on Python 3.
        resp._content = b'{"error": "Could not send subscription to APIC"}'
        return resp

    def _send_subscription(self, url, only_new=False):
        """
        Send the subscription for the specified URL.

        :param url: URL string to issue the subscription
        :param only_new: when False, every object returned in the initial
                         subscription response is queued as an event
        :returns: the APIC response on success, otherwise a synthetic 404
                  response
        """
        try:
            resp = self._apic.get(url)
        except ConnectionError:
            resp = None
        if resp is None or not resp.ok:
            # Remember the URL with no id so refresh_subscriptions() retries.
            self._subscriptions[url] = None
            logging.error('Could not send subscription to APIC for url %s',
                          url)
            return self._error_response()

        resp_data = json.loads(resp.text)
        if 'subscriptionId' not in resp_data:
            # Unlike the failures above, the URL is not recorded here.
            logging.error(
                'Did not receive proper subscription response from APIC for url %s response: %s',
                url, resp_data)
            return self._error_response()

        subscription_id = resp_data['subscriptionId']
        self._subscriptions[url] = subscription_id
        if not only_new:
            # Re-shape each object from the initial response into a
            # single-object event so it is delivered like any other event.
            for mo in resp_data['imdata']:
                event = {
                    "totalCount": "1",
                    "subscriptionId": [subscription_id],
                    "imdata": [mo]
                }
                self._event_q.put(json.dumps(event))
        return resp

    def refresh_subscriptions(self):
        """
        Refresh all of the subscriptions.

        Reopens the websocket if it has dropped, resends subscriptions whose
        last attempt failed, and refreshes the rest.  If a refresh is
        rejected, every subscription is reissued through _resubscribe().
        """
        # Snapshot with one dict copy so that subscriptions added or removed
        # while we are refreshing do not disturb the loop below.
        current_subscriptions = dict(self._subscriptions)

        for subscription in current_subscriptions:
            if self._ws is not None and not self._ws.connected:
                logging.warning(
                    'Websocket not established on subscription refresh. Re-establishing websocket'
                )
                # NOTE(review): subscription keys look like APIC paths, so
                # this test may never be True (forcing ws://) — confirm.
                self._open_web_socket('https://' in subscription)
            try:
                subscription_id = self._subscriptions[subscription]
            except KeyError:
                logging.warning(
                    'Subscription has been removed while trying to refresh')
                continue
            if subscription_id is None:
                # Earlier attempt failed; try subscribing again.
                self._send_subscription(subscription)
                continue
            refresh_url = '/api/subscriptionRefresh.json?id=' + str(
                subscription_id)
            resp = self._apic.get(refresh_url)
            if not resp.ok:
                logging.warning('Could not refresh subscription: %s',
                                refresh_url)
                # Try to resubscribe
                self._resubscribe()

    def _open_web_socket(self, use_secure=True):
        """
        Opens the web socket connection with the APIC and starts a new
        EventHandler thread for it.

        :param use_secure: Boolean indicating whether the web socket
                           should be secure.  Default is True.
        """
        sslopt = {}
        if use_secure:
            # Certificate verification is disabled for the secure socket.
            sslopt['cert_reqs'] = ssl.CERT_NONE
            self._ws_url = 'wss://%s/socket%s' % (self._apic.ipaddr,
                                                  self._apic.token)
        else:
            self._ws_url = 'ws://%s/socket%s' % (self._apic.ipaddr,
                                                 self._apic.token)

        if self._ws is not None and self._ws.connected:
            self._ws.close()
            self.event_handler_thread.exit()
        try:
            self._ws = create_connection(self._ws_url, sslopt=sslopt)
            if not self._ws.connected:
                logging.error('Unable to open websocket connection')
            self.event_handler_thread = EventHandler(self)
            self.event_handler_thread.daemon = True
            self.event_handler_thread.start()
        except WebSocketException:
            logging.error(
                'Unable to open websocket connection due to WebSocketException'
            )
        except socket.error:
            logging.error(
                'Unable to open websocket connection due to Socket Error')

    def _resubscribe(self):
        """
        Reissue the subscriptions.
        Used when the APIC login timeout occurs and a new subscription
        must be issued instead of simply a refresh.  Not meant to be called
        directly by end user applications.
        """
        # Sort pending events while the old subscription ids still map to
        # their URLs.
        self._process_event_q()
        urls = list(self._subscriptions)
        self._subscriptions = {}
        for url in urls:
            self.subscribe(url, only_new=True)

    def _url_for_subscription_id(self, subscription_id):
        """
        Return the first subscribed URL whose id equals str(subscription_id),
        or None when no subscription matches.
        """
        wanted = str(subscription_id)
        for url, known_id in self._subscriptions.items():
            if known_id == wanted:
                return url
        return None

    def _process_event_q(self):
        """
        Put the event into correct bucket based on URLs that have been
        subscribed.

        An event naming several subscription ids is stored once per id, each
        bucket receiving its own deep copy.  Events whose id matches no
        subscription are stored under the key None.
        """
        while not self._event_q.empty():
            raw_event = self._event_q.get()
            try:
                event = json.loads(raw_event)
            except ValueError:
                logging.error('Non-JSON event: %s', raw_event)
                continue
            subscription_ids = list(event['subscriptionId'])
            num_subscriptions = len(subscription_ids)
            for subscription_id in subscription_ids:
                url = self._url_for_subscription_id(subscription_id)
                self._events.setdefault(url, []).append(event)
                if num_subscriptions > 1:
                    # Give the next bucket an independent copy.
                    event = copy.deepcopy(event)

    def subscribe(self, url, only_new=False):
        """
        Subscribe to a particular APIC URL.  Used internally by the
        Class and Instance subscriptions.

        :param url: URL string to send as a subscription
        :param only_new: when True, objects already present at subscription
                         time are not queued as events
        :returns: the response from _send_subscription, or None if the URL
                  is already subscribed
        """
        logging.info('Subscribing to url: %s', url)
        # Check if already subscribed.  If so, skip
        if url in self._subscriptions:
            return

        if self._ws is not None and not self._ws.connected:
            self._open_web_socket('https://' in url)

        return self._send_subscription(url, only_new=only_new)

    def is_subscribed(self, url):
        """
        Check if subscribed to a particular APIC URL.

        :param url: URL string to send as a subscription
        """
        return url in self._subscriptions

    def has_events(self, url):
        """
        Check if a particular APIC URL subscription has any events.
        Used internally by the Class and Instance subscriptions.

        :param url: URL string to check for pending events
        """
        self._process_event_q()
        return bool(self._events.get(url))

    def get_event_count(self, url):
        """
        Check the number of subscription events for a particular APIC URL

        :param url: URL string to check for pending events
        :returns: Integer number of events in event queue
        """
        self._process_event_q()
        return len(self._events.get(url, []))

    def get_event(self, url):
        """
        Get an event for a particular APIC URL subscription.
        Used internally by the Class and Instance subscriptions.

        :param url: URL string to get pending event
        :returns: the oldest pending event (decoded JSON)
        :raises ValueError: if no events have ever been stored for url
        :raises IndexError: if the url's event list is currently empty
        """
        self._process_event_q()
        if url not in self._events:
            raise ValueError
        event = self._events[url].pop(0)
        logging.debug('Event received %s', event)
        return event

    def unsubscribe(self, url):
        """
        Unsubscribe from a particular APIC URL.  Used internally by the
        Class and Instance subscriptions.

        :param url: URL string to unsubscribe
        :raises ValueError: if url carries no ``subscription=yes`` marker
        """
        logging.info('Unsubscribing from url: %s', url)
        if url not in self._subscriptions:
            return
        if '&subscription=yes' in url:
            unsubscribe_url = url.split(
                '&subscription=yes')[0] + '&subscription=no'
        elif '?subscription=yes' in url:
            unsubscribe_url = url.split(
                '?subscription=yes')[0] + '?subscription=no'
        else:
            raise ValueError(
                'No subscription string in URL being unsubscribed')
        resp = self._apic.get(unsubscribe_url)
        if not resp.ok:
            logging.warning('Could not unsubscribe from url: %s',
                            unsubscribe_url)
        # Chew up any outstanding events
        while self.has_events(url):
            self.get_event(url)
        del self._subscriptions[url]
        # Close the websocket once nothing is subscribed; it may never have
        # been opened, in which case there is nothing to close.
        if not self._subscriptions and self._ws is not None:
            self._ws.close(timeout=0)

    def run(self):
        """
        Thread body: every _refresh_time seconds refresh the subscriptions,
        until exit() is called.
        """
        while not self._exit:
            # Sleep for some interval and send subscription list
            time.sleep(self._refresh_time)
            try:
                self.refresh_subscriptions()
            except ConnectionError:
                logging.error(
                    'Could not refresh subscriptions due to ConnectionError')
コード例 #56
0
ファイル: mass_storage.py プロジェクト: Manouchehri/umap2
class ScsiDevice(USBBaseActor):
    '''
    Implementation of subset of the SCSI protocol
    '''
    name = 'SCSI Stack'

    def __init__(self, app, disk_image):
        """
        :param app: application object, passed through to USBBaseActor
        :param disk_image: backing image; this class calls its
            get_sector_count(), get_sector_data() and put_sector_data()
            methods and reads its block_size attribute
        """
        super(ScsiDevice, self).__init__(app, None)
        self.disk_image = disk_image
        # opcode -> handler(cbw); a handler returns response data or None
        self.handlers = {
            ScsiCmds.INQUIRY: self.handle_inquiry,
            ScsiCmds.REQUEST_SENSE: self.handle_request_sense,
            ScsiCmds.TEST_UNIT_READY: self.handle_test_unit_ready,
            ScsiCmds.READ_CAPACITY_10: self.handle_read_capacity_10,
            # ScsiCmds.SEND_DIAGNOSTIC: self.handle_send_diagnostic,
            ScsiCmds.PREVENT_ALLOW_MEDIUM_REMOVAL: self.handle_prevent_allow_medium_removal,
            ScsiCmds.WRITE_10: self.handle_write_10,
            ScsiCmds.READ_10: self.handle_read_10,
            # ScsiCmds.WRITE_6: self.handle_write_6,
            # ScsiCmds.READ_6: self.handle_read_6,
            # ScsiCmds.VERIFY_10: self.handle_verify_10,
            ScsiCmds.MODE_SENSE_6: self.handle_mode_sense_6,
            ScsiCmds.MODE_SENSE_10: self.handle_mode_sense_10,
            ScsiCmds.READ_FORMAT_CAPACITIES: self.handle_read_format_capacities,
            ScsiCmds.SYNCHRONIZE_CACHE: self.handle_synchronize_cache,
        }
        # tx: data/status produced for the host; rx: data received from it
        self.tx = Queue()
        self.rx = Queue()
        # State of an in-progress WRITE(10), set by handle_write_10 and
        # consumed by handle_write_data.
        self.is_write_in_progress = False
        self.write_cbw = None
        self.write_base_lba = 0
        self.write_length = 0
        self.write_data = b''
        # Start the worker only after every attribute it reads exists:
        # handle_data_loop -> handle_data reads is_write_in_progress.
        self.stop_event = Event()
        self.thread = Thread(target=self.handle_data_loop)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """Ask the data-handling thread to finish its loop."""
        event = self.stop_event
        event.set()

    def handle_data_loop(self):
        """
        Worker-thread body: pass each chunk received on self.rx to
        handle_data() until stop() is called.
        """
        try:
            from queue import Empty
        except ImportError:  # Python 2
            from Queue import Empty
        # Event.is_set() replaces the deprecated isSet() alias.
        while not self.stop_event.is_set():
            try:
                # Block for data instead of busy-polling; the timeout bounds
                # how long it takes to notice stop().
                data = self.rx.get(timeout=0.01)
            except Empty:
                continue
            self.handle_data(data)

    def handle_data(self, data):
        """
        Dispatch one chunk from the host.

        While a WRITE is in progress the chunk is write payload; otherwise it
        is parsed as a command block wrapper and routed to its handler.  A
        handler's non-None result is queued, followed by status 0; if the
        handler raises, status 2 is queued instead.

        :raises Exception: if no handler is registered for the opcode
        """
        if self.is_write_in_progress:
            self.handle_write_data(data)
            return

        cbw = CommandBlockWrapper(data)
        opcode = cbw.opcode
        handler = self.handlers.get(opcode)
        if handler is None:
            raise Exception('No handler for opcode %#x' % (opcode))

        try:
            reply = handler(cbw)
            if reply is not None:
                self.tx.put(reply)
            self.tx.put(scsi_status(cbw, 0))
        except Exception as ex:
            self.warning('exception while processing opcode %#x' % (opcode))
            self.warning(ex)
            self.tx.put(scsi_status(cbw, 2))

    def handle_write_data(self, data):
        """
        Accumulate write payload; once write_length bytes have arrived,
        store them at write_base_lba and queue status 0 for the WRITE.
        """
        self.debug('got %#x bytes of SCSI write data' % (len(data)))
        self.write_data += data
        if len(self.write_data) < self.write_length:
            # more payload still expected
            return
        self.disk_image.put_sector_data(self.write_base_lba, self.write_data)
        self.is_write_in_progress = False
        self.write_data = b''
        self.tx.put(scsi_status(self.write_cbw, 0))

    @mutable('scsi_inquiry_response')
    def handle_inquiry(self, cbw):
        """
        Build the INQUIRY response: four header bytes, a length byte
        counting everything after it, then config bytes and ID strings.
        """
        self.debug('SCSI Inquiry, data: %s' % hexlify(cbw.cb[1:]))
        header = struct.pack(
            'BBBB',
            0x00,  # peripheral: SBC
            0x80,  # RMB: removable
            0x00,  # version
            0x01,  # response data format
        )
        trailer = b''.join([
            struct.pack('BBB', 0x00, 0x00, 0x00),  # config
            b'MBYDCOR ',           # vendor id
            b'UMAP2 DISK IMAG ',   # product id
            b'8.02',               # product revision level
        ])
        return header + struct.pack('B', len(trailer)) + trailer

    @mutable('scsi_request_sense_response')
    def handle_request_sense(self, cbw):
        """
        Build the REQUEST SENSE response: an 8-byte head, then a length byte
        counting the remaining bytes, then the trailing sense fields.
        """
        self.debug('SCSI Request Sense, data: %s' % hexlify(cbw.cb[1:]))
        # NOTE(review): in SPC fixed-format sense data, byte 1 is obsolete
        # and byte 2 holds the sense key, so 0x06 below is likely the sense
        # key rather than a filemark flag — confirm.
        head = struct.pack(
            '<BBBI',
            0x70,        # response code
            0x00,        # "valid"
            0x06,        # "filemark"
            0x00000000,  # information
        )
        tail = struct.pack(
            '<IBBB',
            0x00000000,  # command-specific information
            0x3a,        # additional sense code
            0x00,        # additional sense code qualifier
            0x00,        # field replaceable unit code
        ) + b'\x00\x00\x00'  # sense-key specific
        return head + struct.pack('B', len(tail)) + tail

    @mutable('scsi_test_unit_ready_response')
    def handle_test_unit_ready(self, cbw):
        """Log TEST UNIT READY; no response data (status only)."""
        lun = cbw.cb[1]
        self.debug('SCSI Test Unit Ready, logical unit number: %02x' % (lun))

    @mutable('scsi_read_capacity_10_response')
    def handle_read_capacity_10(self, cbw):
        """
        Build the READ CAPACITY (10) response: two big-endian 32-bit words,
        the value of disk_image.get_sector_count() followed by the block
        length in bytes.
        """
        self.debug('SCSI Read Capacity, data: %s' % hexlify(cbw.cb[1:]))
        # NOTE(review): READ CAPACITY (10) reports the *last* LBA; this relies
        # on get_sector_count() already returning count - 1 — confirm.
        lastlba = self.disk_image.get_sector_count()
        # Report the image's real block size, the same value handle_write_10
        # uses to size writes, instead of a hard-coded 0x200.
        block_length = self.disk_image.block_size
        return struct.pack('>II', lastlba, block_length)

    @mutable('scsi_send_diagnostic_response')
    def handle_send_diagnostic(self, cbw):
        """SEND DIAGNOSTIC is not supported (not registered in handlers)."""
        message = 'yet...'
        raise NotImplementedError(message)

    @mutable('scsi_prevent_allow_medium_removal_response')
    def handle_prevent_allow_medium_removal(self, cbw):
        """Acknowledge PREVENT/ALLOW MEDIUM REMOVAL; nothing is returned."""
        self.debug('SCSI Prevent/Allow Removal')
        return None

    @mutable('scsi_write_10_response')
    def handle_write_10(self, cbw):
        """
        Start a WRITE (10): decode LBA and block count from the CDB and
        switch into write mode so handle_data routes the following chunks
        to handle_write_data.
        """
        self.debug('SCSI Write (10), data: %s' % hexlify(cbw.cb[1:]))

        (base_lba,) = struct.unpack('>I', cbw.cb[2:6])
        (num_blocks,) = struct.unpack('>H', cbw.cb[7:9])

        self.debug('SCSI Write (10), lba %#x + %#x block(s)' % (base_lba, num_blocks))

        # Remember the command so handle_write_data can finish it.
        # NOTE(review): with num_blocks == 0 write_length is 0, yet write mode
        # is still entered, so the next chunk received is consumed as write
        # data — confirm hosts never send zero-length writes.
        self.write_cbw = cbw
        self.write_base_lba = base_lba
        self.write_length = self.disk_image.block_size * num_blocks
        self.is_write_in_progress = True

    def handle_read_10(self, cbw):
        """Queue the requested sectors, one tx item per block."""
        (base_lba,) = struct.unpack('>I', cbw.cb[2:6])
        (num_blocks,) = struct.unpack('>H', cbw.cb[7:9])
        self.debug('SCSI Read (10), lba %#x + %#x block(s)' % (base_lba, num_blocks))
        for lba in range(base_lba, base_lba + num_blocks):
            self.tx.put(self.disk_image.get_sector_data(lba))

    @mutable('scsi_write_6_response')
    def handle_write_6(self, cbw):
        """WRITE (6) is not supported (not registered in handlers)."""
        message = 'yet...'
        raise NotImplementedError(message)

    @mutable('scsi_read_6_response')
    def handle_read_6(self, cbw):
        """READ (6) is not supported (not registered in handlers)."""
        message = 'yet...'
        raise NotImplementedError(message)

    @mutable('scsi_verify_10_response')
    def handle_verify_10(self, cbw):
        """VERIFY (10) is not supported (not registered in handlers)."""
        message = 'yet...'
        raise NotImplementedError(message)

    def handle_scsi_mode_sense(self, cbw):
        """
        Build a MODE SENSE response for the page code in CDB byte 2.

        Page 0x1c gets a header plus a fixed mode page; page 0x3f and all
        other pages get a fixed 8-byte response.
        """
        page = cbw.cb[2] & 0x3f

        self.debug('SCSI Mode Sense, page code 0x%02x' % page)

        medium_type = 0x00
        device_specific_param = 0x00

        if page == 0x1c:
            body = struct.pack('BBB', medium_type, device_specific_param, 0x00)
            body += b'\x1c\x06\x00\x05\x00\x00\x00\x00'
            return struct.pack('<B', len(body)) + body

        if page == 0x3f:
            # .. todo:: this seems awfully wrong
            # NOTE(review): the response is only 8 bytes, yet it claims a
            # length of 0x45 and an 8-byte block descriptor that is not sent.
            length, block_descriptor_len = 0x45, 0x08
        else:
            length, block_descriptor_len = 0x07, 0x00
        mode_page = 0x00000000
        return struct.pack('<BBBBI', length, medium_type, device_specific_param,
                           block_descriptor_len, mode_page)

    @mutable('scsi_mode_sense_6_response')
    def handle_mode_sense_6(self, cbw):
        """MODE SENSE (6): delegate to the shared mode-sense builder."""
        response = self.handle_scsi_mode_sense(cbw)
        return response

    @mutable('scsi_mode_sense_10_response')
    def handle_mode_sense_10(self, cbw):
        """MODE SENSE (10): delegate to the shared mode-sense builder."""
        # NOTE(review): this reuses the 6-byte-style header layout — confirm
        # hosts accept it for MODE SENSE (10).
        response = self.handle_scsi_mode_sense(cbw)
        return response

    @mutable('scsi_read_format_capacities')
    def handle_read_format_capacities(self, cbw):
        """
        Build a fixed READ FORMAT CAPACITIES response: a 4-byte header
        carrying the value 8, then one 8-byte descriptor.
        """
        self.debug('SCSI Read Format Capacity')
        header = struct.pack('>I', 8)
        descriptor = struct.pack(
            '>IHH',
            0x1000,  # number of sectors
            0x1000,  # "reserved"
            0x200,   # sector size
        )
        return header + descriptor

    @mutable('scsi_synchronize_cache_response')
    def handle_synchronize_cache(self, cbw):
        self.debug('Synchronize Cache (10)')