class Server(object):
    """Websocket server whose event loop runs on a background thread.

    Entering the server creates ``self.queue`` and a ``WebsocketServer`` bound
    to ``self.port``. It then starts an ``EventLoopThread`` that runs the
    webserver's ``server()`` coroutine. ``recv`` reads from the queue; ``send``
    schedules a broadcast on the background loop.
    """

    def __init__(self, port):
        self.context = None
        self.thread = None
        self.queue = None
        self.webserver = None
        self.port = port

    def __enter__(self):
        # Build everything on a temporary stack and keep it only once startup
        # succeeded. If EventLoopThread fails to start, the queue/webserver
        # created by event_loop_context are torn down again instead of being
        # left half-initialised.
        with ExitStack() as stack:
            stack.enter_context(self.event_loop_context())
            self.thread = EventLoopThread([self.webserver.server()])
            stack.enter_context(self.thread)
            self.context = stack.pop_all()
        return self

    def __exit__(self, *enc):
        self.context.__exit__(*enc)

    def recv(self):
        """Return the next item from ``self.queue`` (via ``queue.get()``)."""
        result = self.queue.get()
        return result

    def send(self, msg):
        """Schedule ``webserver.broadcast(msg)`` on the server's event loop.

        Returns the future from ``asyncio.run_coroutine_threadsafe``. Callers
        can wait on it or inspect its exception instead of losing it silently.
        """
        return asyncio.run_coroutine_threadsafe(
            self.webserver.broadcast(msg), self.thread.loop
        )

    @contextmanager
    def event_loop_context(self):
        """Create the queue and webserver; reset both attributes to None on exit."""
        with ExitStack() as stack:
            # Callbacks run LIFO: webserver is cleared first, then queue.
            stack.callback(lambda: setattr(self, "queue", None))
            stack.callback(lambda: setattr(self, "webserver", None))
            self.queue = Queue()
            self.webserver = WebsocketServer(self.queue, self.port)
            yield
class ReposTestCase(AsyncTestCase):
    """Test case that seeds a temporary JSON source with ten random repos."""

    async def setUp(self):
        super().setUp()
        self.repos = [Repo(str(random.random())) for _ in range(10)]
        self.__stack = ExitStack()
        self._stack = self.__stack.__enter__()
        try:
            self.temp_filename = self.mktempfile()
            self.sconfig = FileSourceConfig(filename=self.temp_filename)
            async with JSONSource(self.sconfig) as source:
                async with source() as sctx:
                    for repo in self.repos:
                        await sctx.update(repo)
            contents = json.loads(Path(self.sconfig.filename).read_text())
            # Ensure there are repos in the file. A missing label counts as
            # empty, so the failure carries the message below rather than a
            # TypeError from len(None).
            self.assertEqual(
                len(contents.get(self.sconfig.label, [])),
                len(self.repos),
                "ReposTestCase JSON file erroneously initialized as empty",
            )
        except BaseException:
            # unittest does not call tearDown when setUp fails, so release the
            # temporary file registered on the stack here.
            self.__stack.__exit__(None, None, None)
            raise

    def tearDown(self):
        super().tearDown()
        self.__stack.__exit__(None, None, None)

    def mktempfile(self):
        """Enter ``non_existant_tempfile()`` on the test's stack and return its value.

        The context is exited in tearDown.
        """
        return self._stack.enter_context(non_existant_tempfile())
class Server(object):
    """Websocket server whose event loop runs on a background thread.

    Entering the server creates ``self.queue`` and a ``WebsocketServer`` bound
    to ``self.port``. It then starts an ``EventLoopThread`` that runs the
    webserver's ``server()`` coroutine. ``recv`` reads from the queue; ``send``
    schedules a broadcast on the background loop.
    """

    def __init__(self, port):
        self.context = None
        self.thread = None
        self.queue = None
        self.webserver = None
        self.port = port

    def __enter__(self):
        # Build everything on a temporary stack and keep it only once startup
        # succeeded. If EventLoopThread fails to start, the queue/webserver
        # created by event_loop_context are torn down again instead of being
        # left half-initialised.
        with ExitStack() as stack:
            stack.enter_context(self.event_loop_context())
            self.thread = EventLoopThread([self.webserver.server()])
            stack.enter_context(self.thread)
            self.context = stack.pop_all()
        return self

    def __exit__(self, *enc):
        self.context.__exit__(*enc)

    def recv(self):
        """Return the next item from ``self.queue`` (via ``queue.get()``)."""
        result = self.queue.get()
        return result

    def send(self, msg):
        """Schedule ``webserver.broadcast(msg)`` on the server's event loop.

        Returns the future from ``asyncio.run_coroutine_threadsafe``. Callers
        can wait on it or inspect its exception instead of losing it silently.
        """
        return asyncio.run_coroutine_threadsafe(
            self.webserver.broadcast(msg), self.thread.loop
        )

    @contextmanager
    def event_loop_context(self):
        """Create the queue and webserver; reset both attributes to None on exit."""
        with ExitStack() as stack:
            # Callbacks run LIFO: webserver is cleared first, then queue.
            stack.callback(lambda: setattr(self, 'queue', None))
            stack.callback(lambda: setattr(self, 'webserver', None))
            self.queue = Queue()
            self.webserver = WebsocketServer(self.queue, self.port)
            yield
class LocalDarRepository:
    """Local collection of DAR archives whose metadata feeds a ``PackageStore``.

    Each ``DarFile`` opened by ``add_source`` is registered on an internal
    ExitStack and exited when the repository is exited.
    """

    def __init__(self):
        self._context = ExitStack()
        self._dar_paths = []  # type: List[Path]
        self._files = set()
        self.store = PackageStore.empty()

    def _add_source(self, path: Path) -> None:
        """Load one resolved *path* into the store; only ``.dar`` is supported.

        :raises ValueError: for ``.daml``, ``.dalf`` or any other extension.
        """
        ext = path.suffix.lower()
        if ext == '.daml':
            LOG.error('Reading metadata directly from a DAML file is not supported.')
            raise ValueError('Unsupported extension: .daml')
        elif ext == '.dalf':
            LOG.error('Reading metadata directly from a DALF file is not supported.')
            raise ValueError('Unsupported extension: .dalf')
        elif ext == '.dar':
            # perf_counter is monotonic, unlike time.time(), which makes it the
            # right clock for measuring an elapsed duration.
            dar_parse_start_time = time.perf_counter()
            dar = self._context.enter_context(DarFile(path))
            dar_package = dar.read_metadata()
            self._dar_paths.append(path)
            self.store.register_all(dar_package)
            dar_parse_end_time = time.perf_counter()
            LOG.debug('Parsed a dar in %s seconds.', dar_parse_end_time - dar_parse_start_time)
        else:
            LOG.error('Unknown extension: %s', ext)
            raise ValueError(f'Unknown extension: {ext}')

    def add_source(self, *files: Union[str, Path]) -> None:
        """
        Add one or more source files.

        Only ``.dar`` archives are currently supported; ``.daml``, ``.dalf`` and
        any other extension raise ``ValueError``. Adding a file that was already
        added successfully is a no-op.

        :param files: Files to add to the archive.
        :raises ValueError: if a file has an unsupported extension.
        """
        for file in files:
            # strict=True makes a nonexistent path fail here, before any loading.
            path = pathify(file).resolve(strict=True)
            if path not in self._files:
                self._add_source(path)
                # Record the path only after it loaded successfully, so a file
                # that failed can be retried instead of being silently skipped.
                self._files.add(path)

    def get_daml_archives(self) -> Sequence[Path]:
        """Return the paths of the DAR files added so far (the internal list)."""
        return self._dar_paths

    def __enter__(self):
        """
        Enter the internal ExitStack and return this repository.
        """
        self._context.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Delete all managed resources in the reverse order that they were created.
        """
        self._context.__exit__(exc_type, exc_val, exc_tb)
class capture_sql(ContextDecorator):
    """
    Capture SQL executed on ALL databases listed in settings

    ```
    with capture_sql() as capture:
        # do some queries
    capture.print_sql(with_traceback=False)
    ```
    """

    def __init__(self):
        self.query_contexts = {}
        self.wrapper = utils.CursorDebugWrapper

    def __enter__(self):
        # Remember the wrapper installed *now*, not at construction time, so
        # that nested or reused instances restore the right one on exit.
        self.wrapper = utils.CursorDebugWrapper
        utils.CursorDebugWrapper = CursorDebugWrapperWithTraceback
        try:
            with ExitStack() as stack:
                for db in settings.DATABASES:
                    context = CaptureQueriesContext(connections[db])
                    self.query_contexts[db] = context
                    stack.enter_context(context)
                # Every context entered: keep them until __exit__.
                self._stack = stack.pop_all()
        except BaseException:
            # The temporary stack already exited the contexts that were
            # entered; also undo the global wrapper patch before re-raising.
            utils.CursorDebugWrapper = self.wrapper
            raise
        return self

    def __exit__(self, *exc_details):
        utils.CursorDebugWrapper = self.wrapper
        self._stack.__exit__(*exc_details)

    @property
    def queries_by_db(self):
        """Map each database alias to the queries captured on it."""
        return {
            db: context.captured_queries
            for db, context in self.query_contexts.items()
        }

    def print_sql(self, with_traceback=False, width=150):
        """Print captured queries grouped by database, wrapped to *width*."""
        for db, queries in self.queries_by_db.items():
            if not queries:
                continue
            print(
                f'\n------- Queries for Database "{db}" ({len(queries)}) ---------'
            )
            for q in queries:
                out = f"{q['sql']} (took {q['time']})"
                print('\n{}'.format(indent('\n'.join(wrap(out, width)), '\t')))
                if with_traceback:
                    print('\n{}'.format(indent(''.join(q['traceback']), '\t\t')))

    def print_sql_chronologically(self, with_traceback=False, width=150):
        """Print queries from all databases ordered by their ``'start'`` value."""
        all_queries = sorted([(db, query)
                              for db, queries in self.queries_by_db.items()
                              for query in queries],
                             key=lambda q: q[1]['start'])
        for db, query in all_queries:
            out = f"({db}) {query['sql']} (took {query['time']})"
            print('\n{}'.format(indent('\n'.join(wrap(out, width)), '\t')))
            if with_traceback:
                print('\n{}'.format(indent(''.join(query['traceback']), '\t\t')))
class VFS:
    """A list of files plus an ExitStack owning their resources.

    Exiting the VFS unwinds the stack.
    """

    def __init__(self):
        self.exit_stack = ExitStack()
        self.files = []

    def __enter__(self):
        return self

    def __exit__(self, *exc_details):
        stack = self.exit_stack
        stack.__exit__(*exc_details)
class MultiPage:
    """
    Multi-page output, for formats that support it.

    Usage is similar to `matplotlib.backends.backend_pdf.PdfPages`::

        with MultiPage(path, metadata=...) as mp:
            mp.savefig(fig1)
            mp.savefig(fig2)

    Note that the only other method of `PdfPages` that is implemented is
    `close`, and that empty files are not created -- as if the *keep_empty*
    argument to `PdfPages` was always False.

    The output stream is opened lazily, by the first `savefig` call.
    """

    def __init__(self, path_or_stream=None, format=None, *, metadata=None):
        self._stack = ExitStack()
        self._renderer = None

        def _make_renderer():
            # Open the destination; the stack closes it on close()/__exit__.
            out = self._stack.enter_context(
                cbook.open_file_cm(path_or_stream, "wb"))
            # Format: explicit argument, else the file suffix, else the rc
            # default (each fallback evaluated only when still needed).
            fmt = format
            if not fmt:
                fmt = Path(getattr(out, "name", "")).suffix[1:]
            if not fmt:
                fmt = rcParams["savefig.format"]
            fmt = fmt.lower()
            factories = {
                "pdf": GraphicsContextRendererCairo._for_pdf_output,
                "ps": GraphicsContextRendererCairo._for_ps_output,
            }
            renderer = factories[fmt](out, 1, 1, 1)
            self._renderer = renderer
            # Registered after the stream, so _finish runs before the stream
            # is closed (ExitStack unwinds LIFO).
            self._stack.callback(renderer._finish)
            renderer._set_metadata(copy.copy(metadata))

        self._make_renderer = _make_renderer

    def savefig(self, figure, **kwargs):
        # FIXME[Upstream]: Not all kwargs are supported here -- but I plan to
        # deprecate them upstream.
        if self._renderer is None:
            self._make_renderer()
        figure.set_dpi(72)
        size = figure.canvas.get_width_height()
        self._renderer._set_size(*size, kwargs.get("dpi", 72))
        with _LOCK:
            figure.draw(self._renderer)
        self._renderer._show_page()

    def close(self):
        return self._stack.__exit__(None, None, None)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return self._stack.__exit__(*args)
class EventLoopThread(object):
    """Run an asyncio event loop, and the servers it starts, on a worker thread.

    ``servers_to_start`` is a list of awaitables. Each is run to completion on
    the new loop and its result is kept in ``self.servers``. On exit, every
    server is closed (``close()`` then ``await wait_closed()``), and the loop
    is stopped and closed.
    """

    def __init__(self, servers_to_start):
        self.context = None
        self.executor = None
        self.loop = None
        self.servers_to_start = servers_to_start
        self.servers = []
        # Set by start_loop when starting the servers raised.
        self._start_failed = False

    def __enter__(self):
        # Keep the executor and loop only once both started. On failure the
        # temporary stack shuts the executor down instead of leaking it.
        with ExitStack() as stack:
            self.executor = stack.enter_context(
                ThreadPoolExecutor(max_workers=1))
            stack.enter_context(self.event_loop_context())
            self.context = stack.pop_all()
        return self

    def __exit__(self, *enc):
        self.context.__exit__(*enc)
        self.context = None
        self.executor = None
        self.loop = None

    def start_loop(self, event):
        """Worker-thread body: start the servers, set *event*, run the loop.

        If startup raises, servers already started are closed and the loop
        is closed. *event* is still set, so the waiting thread is not blocked
        forever, and the exception propagates through the executor future.
        """
        logger.info('starting eventloop server')
        loop = None
        try:
            loop = asyncio.new_event_loop()
            self.loop = loop
            asyncio.set_event_loop(loop)
            for server_starter in self.servers_to_start:
                server = loop.run_until_complete(server_starter)
                self.servers.append(server)
        except BaseException:
            self._start_failed = True
            try:
                for server in self.servers:
                    server.close()
                if loop is not None:
                    loop.close()
            finally:
                event.set()
            raise
        loop.call_soon(event.set)
        try:
            loop.run_forever()
        finally:
            loop.close()

    def stop_loop(self):
        """Schedule server shutdown; invoked on the loop thread."""
        logger.info('stopping eventloop server')
        self.loop.create_task(self._close_connections())

    @contextmanager
    def event_loop_context(self):
        """Start the loop on the executor, wait until it runs, stop it on exit."""
        event = Event()
        self._start_failed = False
        future = self.executor.submit(self.start_loop, event)
        event.wait()
        if self._start_failed:
            # Re-raise the startup exception from the worker thread.
            future.result()
        logger.info('started eventloop')
        try:
            yield
        finally:
            self.loop.call_soon_threadsafe(self.stop_loop)
            logger.info('stopped eventloop')

    async def _close_connections(self):
        """Close every started server, then stop the loop even if closing fails."""
        # Native coroutine: @asyncio.coroutine was removed in Python 3.11.
        try:
            for server in self.servers:
                server.close()
                await server.wait_closed()
        finally:
            self.loop.stop()
class EventLoopThread(object):
    """Run an asyncio event loop, and the servers it starts, on a worker thread.

    ``servers_to_start`` is a list of awaitables. Each is run to completion on
    the new loop and its result is kept in ``self.servers``. On exit, every
    server is closed (``close()`` then ``await wait_closed()``), and the loop
    is stopped and closed.
    """

    def __init__(self, servers_to_start):
        self.context = None
        self.executor = None
        self.loop = None
        self.servers_to_start = servers_to_start
        self.servers = []
        # Set by start_loop when starting the servers raised.
        self._start_failed = False

    def __enter__(self):
        # Keep the executor and loop only once both started. On failure the
        # temporary stack shuts the executor down instead of leaking it.
        with ExitStack() as stack:
            self.executor = stack.enter_context(ThreadPoolExecutor(max_workers=1))
            stack.enter_context(self.event_loop_context())
            self.context = stack.pop_all()
        return self

    def __exit__(self, *enc):
        self.context.__exit__(*enc)
        self.context = None
        self.executor = None
        self.loop = None

    def start_loop(self, event):
        """Worker-thread body: start the servers, set *event*, run the loop.

        If startup raises, servers already started are closed and the loop
        is closed. *event* is still set, so the waiting thread is not blocked
        forever, and the exception propagates through the executor future.
        """
        logger.info("starting eventloop server")
        loop = None
        try:
            loop = asyncio.new_event_loop()
            self.loop = loop
            asyncio.set_event_loop(loop)
            for server_starter in self.servers_to_start:
                server = loop.run_until_complete(server_starter)
                self.servers.append(server)
        except BaseException:
            self._start_failed = True
            try:
                for server in self.servers:
                    server.close()
                if loop is not None:
                    loop.close()
            finally:
                event.set()
            raise
        loop.call_soon(event.set)
        try:
            loop.run_forever()
        finally:
            loop.close()

    def stop_loop(self):
        """Schedule server shutdown; invoked on the loop thread."""
        logger.info("stopping eventloop server")
        self.loop.create_task(self._close_connections())

    @contextmanager
    def event_loop_context(self):
        """Start the loop on the executor, wait until it runs, stop it on exit."""
        event = Event()
        self._start_failed = False
        future = self.executor.submit(self.start_loop, event)
        event.wait()
        if self._start_failed:
            # Re-raise the startup exception from the worker thread.
            future.result()
        logger.info("started eventloop")
        try:
            yield
        finally:
            self.loop.call_soon_threadsafe(self.stop_loop)
            logger.info("stopped eventloop")

    async def _close_connections(self):
        """Close every started server, then stop the loop even if closing fails."""
        # Native coroutine: @asyncio.coroutine was removed in Python 3.11.
        try:
            for server in self.servers:
                server.close()
                await server.wait_closed()
        finally:
            self.loop.stop()
class ContentTree(dict):
    """Dictionary whose entries can be backed by entered context managers.

    A value stored through ``set_context`` is whatever entering the given
    context manager returned. All such managers are exited, in reverse order,
    when the tree itself is exited.
    """

    def __init__(self, iterable=None):
        super().__init__(iterable or ())
        self._es = ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        unwind = self._es.__exit__
        unwind(exc_type, exc_val, exc_tb)

    def set_context(self, key, value):
        entered = self._es.enter_context(value)
        self[key] = entered
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        # Enter on a temporary stack so that, if one manager fails to enter,
        # the ones already entered are exited (with the exception details)
        # before the error propagates, instead of staying open forever.
        with ExitStack() as stack:
            for context_manager in self.context_managers:
                stack.enter_context(context_manager)
            self.stack = stack.pop_all()

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)
class SyncEventSource:
    """Synchronous facade over an asynchronous ``EventSource``.

    Events from the async source are forwarded to synchronous subscriber
    callbacks, each run via ``run_sync_in_worker_thread``. If no portal is
    given, one is started with ``start_blocking_portal`` on ``__enter__`` and
    exited with this object's exit stack on ``__exit__``.
    """

    _subscribers: Dict[Type[Event], List[Callable[[Event], Any]]]

    def __init__(self, async_event_source: EventSource, portal: Optional[BlockingPortal] = None):
        self.portal = portal
        self._async_event_source = async_event_source
        self._async_event_source.subscribe(self._forward_async_event)
        self._exit_stack = ExitStack()
        self._subscribers = defaultdict(list)

    def __enter__(self):
        self._exit_stack.__enter__()
        if not self.portal:
            portal_cm = start_blocking_portal()
            self.portal = self._exit_stack.enter_context(portal_cm)
            # NOTE(review): __init__ already subscribed this callback; this
            # second subscribe relies on the async source de-duplicating it.
            self._async_event_source.subscribe(self._forward_async_event)
        # Return self so ``with SyncEventSource(...) as source:`` binds the
        # instance instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._exit_stack.__exit__(exc_type, exc_val, exc_tb)

    async def _forward_async_event(self, event: Event) -> None:
        """Run every subscriber registered for ``type(event)`` in a worker thread."""
        for subscriber in self._subscribers.get(type(event), ()):
            await run_sync_in_worker_thread(subscriber, event)

    def subscribe(self, callback: Callable[[Event], Any],
                  event_types: Optional[Iterable[Type[Event]]] = None) -> None:
        """Register *callback* for *event_types* (all types if None); duplicates are ignored."""
        if event_types is None:
            event_types = _all_event_types

        for event_type in event_types:
            existing_callbacks = self._subscribers[event_type]
            if callback not in existing_callbacks:
                existing_callbacks.append(callback)

    def unsubscribe(
            self, callback: Callable[[Event], Any],
            event_types: Optional[Iterable[Type[Event]]] = None) -> None:
        """Remove *callback* from *event_types* (all types if None); missing entries are ignored."""
        if event_types is None:
            event_types = _all_event_types

        for event_type in event_types:
            existing_callbacks = self._subscribers.get(event_type, [])
            with suppress(ValueError):
                existing_callbacks.remove(callback)
class _CallbackManager(Callback):
    """
    Sequential execution of callback functions.

    Execute Callback functions at certain points.

    Args:
        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.

    Raises:
        TypeError: if an element of ``callbacks`` is not a ``Callback``.
    """

    def __init__(self, callbacks):
        self._callbacks, self._stack = [], None
        if isinstance(callbacks, Callback):
            self._callbacks.append(callbacks)
        elif callbacks is not None:
            for cb in callbacks:
                if not isinstance(cb, Callback):
                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
                self._callbacks.append(cb)

    def __enter__(self):
        if self._stack is None:
            # Enter on a temporary stack: if one callback fails to enter, the
            # ones already entered are exited before the error propagates.
            with ExitStack() as stack:
                self._callbacks = [
                    stack.enter_context(cb) for cb in self._callbacks
                ]
                self._stack = stack.pop_all()
        return self

    def __exit__(self, *err):
        # Reset the stack so the manager can be entered again later.
        stack, self._stack = self._stack, None
        if stack is None:
            return False
        return stack.__exit__(*err)

    def begin(self, run_context):
        """Called once before network training."""
        for cb in self._callbacks:
            cb.begin(run_context)

    def epoch_begin(self, run_context):
        """Called before each epoch begin."""
        for cb in self._callbacks:
            cb.epoch_begin(run_context)

    def epoch_end(self, run_context):
        """Called after each epoch finished."""
        for cb in self._callbacks:
            cb.epoch_end(run_context)

    def step_begin(self, run_context):
        """Called before each step begin."""
        for cb in self._callbacks:
            cb.step_begin(run_context)

    def step_end(self, run_context):
        """Called after each step finished."""
        for cb in self._callbacks:
            cb.step_end(run_context)

    def end(self, run_context):
        """Called once after network training."""
        for cb in self._callbacks:
            cb.end(run_context)
class BaseWorker(metaclass=ABCMeta):
    """Reference-counted context-manager base for workers.

    The first ``__enter__`` calls ``_open()``; nested enters only bump the
    counter. The matching final ``__exit__`` unwinds ``self.stack``, where
    subclasses can register cleanups.
    """

    def __init__(self, *, mirrors=None):
        super().__init__()
        self.__open = 0
        self.stack = ExitStack()

    def assert_open(self):
        assert self.__open

    def __enter__(self):
        self.__open += 1
        with ExitStack() as guard:
            # If _open() raises, run our own __exit__: it rolls the counter
            # back and unwinds anything _open() registered on self.stack.
            guard.push(self)
            if self.__open == 1:
                self._open()
            guard.pop_all()
        return self

    def _open(self):
        """Hook run on the outermost enter; subclasses override."""
        pass

    def __exit__(self, et, ev, tb):
        self.__open -= 1
        if self.__open:
            return False
        else:
            return self.stack.__exit__(et, ev, tb)
class MockIterEntryPoints(AsyncTestCase):
    """Test case that replaces ``pkg_resources.iter_entry_points`` with a fake.

    Expects ``self.entrypoints`` (set by subclasses) to map an entry-point
    group to ``{name: loaded_object}``. Each fake entry point is a
    ``MagicMock`` whose ``name`` is the key and whose ``load()`` returns the
    value.
    """

    def iter_entry_points(self, entrypoint):
        group = self.entrypoints[entrypoint]
        for name, loaded in group.items():
            fake = MagicMock()
            fake.name = name
            fake.load.return_value = loaded
            yield fake

    async def setUp(self):
        stack = ExitStack()
        self.exit_stack = stack.__enter__()
        replacement = patch(
            "pkg_resources.iter_entry_points", new=self.iter_entry_points
        )
        self.exit_stack.enter_context(replacement)

    async def tearDown(self):
        self.exit_stack.close()
class BlockMetric:
    """Enable tracking on a block of code.

    Each entry in ``trackers`` is called as ``tracker(client, metric)``, and
    the resulting context manager wraps the block.
    """

    #: Trackers activated during the execution of the block of code
    trackers = [executions, errors, processing_time]

    def __init__(self, client, metric):
        self.client = client
        self.metric = metric

    def __enter__(self):
        # Enter on a temporary stack: if a tracker fails to start, the ones
        # already started are exited instead of being left running.
        with ExitStack() as stack:
            for tracker in self.trackers:
                stack.enter_context(tracker(self.client, self.metric))
            self.stack = stack.pop_all()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The stack's return value is deliberately ignored, so exceptions
        # from the block always propagate.
        self.stack.__exit__(exc_type, exc_val, exc_tb)
class VFS:
    """A list of files plus an ExitStack owning their resources."""

    files: List[File]

    def __init__(self) -> None:
        self.files = []
        self.exit_stack = ExitStack()

    def __enter__(self) -> "VFS":
        return self

    def __exit__(self, *exc_details) -> None:  # type: ignore[no-untyped-def]
        self.exit_stack.__exit__(*exc_details)

    def filter(self, path: str) -> List[File]:
        """Return the files whose name matches the glob-like *path*.

        ``*`` and ``**`` both match any run of characters (including ``/``).
        All other characters match literally. The match is anchored at the
        start of the name only.
        """
        # Escape the literal segments so regex metacharacters such as "." in
        # "*.py" no longer match an arbitrary character.
        literal_parts = path.replace("**", "*").split("*")
        pattern = re.compile(".*".join(re.escape(part) for part in literal_parts))
        return [f for f in self.files if pattern.match(f.name)]
class WPRobotBase(object):
    """Robot base whose attached devices are entered as context managers.

    Entering calls ``wiringpi2.wiringPiSetupGpio()`` and then enters every
    attached device in attachment order. Exiting exits them in reverse.
    """

    def __init__(self):
        self.context = None
        self.devices = []

    def __enter__(self):
        wiringpi2.wiringPiSetupGpio()
        # Enter on a temporary stack: if one device fails, the devices
        # already entered are exited before the error propagates.
        with ExitStack() as stack:
            for device in self.devices:
                stack.enter_context(device)
            self.context = stack.pop_all()
        return self

    def __exit__(self, *exc):
        self.context.__exit__(*exc)

    def attach_device(self, device):
        """Register *device* to be entered with the robot; return it unchanged."""
        self.devices.append(device)
        return device
class ReposTestCase(AsyncTestCase):
    """Test case with a temp JSON source of ten random repos and patched loaders."""

    async def setUp(self):
        super().setUp()
        self.repos = [Repo(str(random.random())) for _ in range(10)]
        self._stack = ExitStack().__enter__()
        try:
            self.temp_filename = self.mktempfile()
            self.sconfig = FileSourceConfig(filename=self.temp_filename)
            async with JSONSource(self.sconfig) as source:
                async with source() as sctx:
                    for repo in self.repos:
                        await sctx.update(repo)
            contents = json.loads(Path(self.sconfig.filename).read_text())
            # Ensure there are repos in the file. A missing label counts as
            # empty, so the failure carries the message below rather than a
            # TypeError from len(None).
            self.assertEqual(
                len(contents.get(self.sconfig.label, [])),
                len(self.repos),
                "ReposTestCase JSON file erroneously initialized as empty",
            )
            # TODO(p3) For some reason patching Model.load doesn't work
            # self._stack.enter_context(patch("dffml.model.model.Model.load",
            #                                 new=model_load))
            self._stack.enter_context(
                patch.object(
                    ModelCMD,
                    "arg_model",
                    new=ModelCMD.arg_model.modify(type=model_load),
                )
            )
            self._stack.enter_context(
                patch("dffml.feature.feature.Feature.load", new=feature_load)
            )
            self._stack.enter_context(
                patch("dffml.df.base.OperationImplementation.load", new=opimp_load)
            )
            self._stack.enter_context(
                patch("dffml.df.types.Operation.load", new=op_load)
            )
        except BaseException:
            # unittest does not call tearDown when setUp fails. Undo the
            # patches and release the temp file here so they do not leak
            # into later tests.
            self._stack.__exit__(None, None, None)
            raise

    def tearDown(self):
        super().tearDown()
        self._stack.__exit__(None, None, None)

    def mktempfile(self):
        """Enter ``non_existant_tempfile()`` on the test's stack and return its value.

        The context is exited in tearDown.
        """
        return self._stack.enter_context(non_existant_tempfile())
class RobotBase(object):
    """Robot base whose attached devices are entered as context managers.

    Entering calls ``GPIO.setmode(GPIO.BCM)`` and then enters every attached
    device. Exiting exits the devices in reverse order and always finishes
    with ``GPIO.cleanup()``.
    """

    def __init__(self):
        self.context = None
        self.devices = []

    def __enter__(self):
        GPIO.setmode(GPIO.BCM)
        try:
            # Enter on a temporary stack: if one device fails, the devices
            # already entered are exited before the error propagates.
            with ExitStack() as stack:
                for device in self.devices:
                    stack.enter_context(device)
                self.context = stack.pop_all()
        except BaseException:
            # __exit__ will not run when __enter__ fails; release GPIO here.
            GPIO.cleanup()
            raise
        return self

    def __exit__(self, *exc):
        try:
            self.context.__exit__(*exc)
        finally:
            # Run even if a device's exit raised.
            GPIO.cleanup()

    def attach_device(self, device):
        """Register *device* to be entered with the robot; return it unchanged."""
        self.devices.append(device)
        return device
class SentiData:
    """Holds ``distant_docs``, ``distant_labels`` and ``unsup_docs`` lists.

    Also owns an ExitStack that is unwound when the object is exited.
    """

    def __init__(self):
        self.distant_docs, self.distant_labels, self.unsup_docs = [], [], []
        self._stack = ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *exc_details):
        stack = self._stack
        return stack.__exit__(*exc_details)
class TestOWDataSets(WidgetTest):
    def setUp(self):
        super().setUp()
        remote = {
            ("a", "b"): {"title": "K", "size": 10},
            ("a", "c"): {"title": "T", "size": 20},
            ("a", "d"): {"title": "Y", "size": 0},
        }
        self.exit = ExitStack()
        self.exit.__enter__()
        try:
            self.exit.enter_context(
                mock.patch.object(owdatasets, "list_remote", lambda: remote)
            )
            self.exit.enter_context(
                mock.patch.object(owdatasets, "list_local", lambda: {})
            )

            self.widget = self.create_widget(
                OWDataSets,
                stored_settings={
                    "selected_id": ("a", "c"),
                    "auto_commit": False,
                },
            )  # type: OWDataSets
        except BaseException:
            # unittest skips tearDown when setUp fails; undo the patches here
            # so they do not leak into later tests.
            self.exit.__exit__(None, None, None)
            raise

    def tearDown(self):
        super().tearDown()
        self.exit.__exit__(None, None, None)

    def test_init(self):
        if self.widget.isBlocking():
            spy = QSignalSpy(self.widget.blockingStateChanged)
            # Not a bare assert: under ``python -O`` the whole statement,
            # including the wait itself, would be stripped.
            self.assertTrue(spy.wait(1000))
        self.assertFalse(self.widget.isBlocking())

        model = self.widget.view.model()
        self.assertEqual(model.rowCount(), 3)

        di = self.widget.selected_dataset()
        self.assertEqual((di.prefix, di.filename), ("a", "c"))
class SourceExtractor(object):
    """A class to extract a source package to its constituent parts"""

    def __init__(self, dsc_path, dsc):
        self.dsc_path = dsc_path
        self.dsc = dsc
        self.extracted_upstream = None
        self.extracted_debianised = None
        self.unextracted_debian_md5 = None
        self.upstream_tarballs = []
        self.exit_stack = ExitStack()

    def extract(self):
        """Extract the package to a new temporary directory."""
        raise NotImplementedError(self.extract)

    def __enter__(self):
        with ExitStack() as guard:
            # If extract() fails, run our own __exit__ so resources it already
            # registered on exit_stack are released before the error
            # propagates. The with statement will not call __exit__ for us.
            guard.push(self)
            self.exit_stack.__enter__()
            self.extract()
            guard.pop_all()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
        return False
class patch_auth:  # pylint: disable=invalid-name
    """Patch Google auth so clients receive inert fake credentials.

    Usable as a context manager, through ``start()``/``stop()``, or as a
    function decorator. The patches are active only between entering and
    exiting.
    """

    def __init__(self, project_id: str = "potato-dev", location: str = "moon-dark1", email: str = "*****@*****.**"):
        # Realistic: actual class to be accepted by clients during validation
        # But fake: with as few attributes as possible, any API call using the credential should fail
        credentials = Credentials(
            service_account_email=email,
            signer=None,
            token_uri="",
            project_id=project_id,
        )
        # The patchers are only created here and are entered on every
        # __enter__. Constructing the object has no side effects, and it can
        # be entered repeatedly (e.g. once per call of a decorated function).
        self.managers = [
            patch("google.auth.default", return_value=(credentials, project_id)),
            patch("gcp_pilot.base.GoogleCloudPilotAPI._set_location", return_value=location),
            patch("gcp_pilot.base.AppEngineBasedService._set_location", return_value=location),
        ]
        self.stack = ExitStack()

    def __enter__(self):
        # Fresh stack per activation; if one patch fails to start, the ones
        # already started are undone before the error propagates.
        with ExitStack() as stack:
            for mgr in self.managers:
                stack.enter_context(mgr)
            self.stack = stack.pop_all()
        return self.stack

    def start(self):
        return self.__enter__()

    def __exit__(self, typ, val, traceback):
        return self.stack.__exit__(typ, val, traceback)

    def stop(self):
        self.__exit__(None, None, None)

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kw):
            with self:
                return func(*args, **kw)

        return wrapper
def _exit(self, exc_type, exc_val, exc_tb, daemon=False): stack = ExitStack() # called last @stack.push def exit_loop(exc_type, exc_val, exc_tb): if self.client.is_closed: return self._loop.__exit__(exc_type, exc_val, exc_tb) if threading.current_thread() is threading.main_thread(): # TODO the main thread is not necessarily the last thread to finish. # Should the signal handler be removed in case it isn't? stack.push(self._signal_ctx) # called first # exit the client with the given daemon-ness, maybe leading the client to close @stack.push def exit_client(exc_type, exc_val, exc_tb): return self._call(self.client._aexit(exc_type, exc_val, exc_tb, daemon=daemon)) return stack.__exit__(exc_type, exc_val, exc_tb)
class TestDiscoverApi(BaseApiTest): def setup_method(self, test_method): super().setup_method(test_method) # XXX: This should use the ``discover`` dataset directly, but that will # require some updates to the test base classes to work correctly. self.__dataset_manager = ExitStack() for dataset_name in ["events", "transactions"]: self.__dataset_manager.enter_context(dataset_manager(dataset_name)) self.app.post = partial(self.app.post, headers={"referer": "test"}) self.project_id = self.event["project_id"] self.base_time = datetime.utcnow().replace(minute=0, second=0, microsecond=0) self.trace_id = uuid.UUID("7400045b-25c4-43b8-8591-4600aa83ad04") self.span_id = "8841662216cc598b" self.generate_event() self.generate_transaction() def teardown_method(self, test_method): self.__dataset_manager.__exit__(None, None, None) def generate_event(self): self.dataset = get_dataset("events") self.write_events([self.event]) def generate_transaction(self): self.dataset = get_dataset("transactions") processed = ( enforce_table_writer(self.dataset).get_stream_loader(). 
get_processor().process_message( ( 2, "insert", { "project_id": self.project_id, "event_id": uuid.uuid4().hex, "deleted": 0, "datetime": (self.base_time).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), "platform": "python", "retention_days": settings.DEFAULT_RETENTION_DAYS, "data": { "received": calendar.timegm((self.base_time).timetuple()), "type": "transaction", "transaction": "/api/do_things", "start_timestamp": datetime.timestamp(self.base_time), "timestamp": datetime.timestamp(self.base_time), "tags": { # Sentry "environment": u"prød", "sentry:release": "1", "sentry:dist": "dist1", "url": "http://127.0.0.1:/query", # User "foo": "baz", "foo.bar": "qux", "os_name": "linux", }, "user": { "email": "*****@*****.**", "ip_address": "8.8.8.8", "geo": { "city": "San Francisco", "region": "CA", "country_code": "US", }, }, "contexts": { "trace": { "trace_id": self.trace_id.hex, "span_id": self.span_id, "op": "http", }, "device": { "online": True, "charging": True, "model_id": "Galaxy", }, }, "measurements": { "lcp": { "value": 32.129 }, "lcp.elementSize": { "value": 4242 }, }, "sdk": { "name": "sentry.python", "version": "0.13.4", "integrations": ["django"], }, "request": { "url": "http://127.0.0.1:/query", "headers": [ ["Accept-Encoding", "identity"], ["Content-Length", "398"], ["Host", "127.0.0.1:"], ["Referer", "tagstore.something"], ["Trace", "8fa73032d-1"], ], "data": "", "method": "POST", "env": { "SERVER_PORT": "1010", "SERVER_NAME": "snuba" }, }, "spans": [{ "op": "db", "trace_id": self.trace_id.hex, "span_id": self.span_id + "1", "parent_span_id": None, "same_process_as_parent": True, "description": "SELECT * FROM users", "data": {}, "timestamp": calendar.timegm((self.base_time).timetuple()), }], }, }, ), KafkaMessageMetadata(0, 0, self.base_time), )) self.write_processed_messages([processed]) def test_raw_data(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["type", "tags[custom_tag]", 
"release"], "conditions": [["type", "!=", "transaction"]], "orderby": "timestamp", "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert len(data["data"]) == 1, data assert data["data"][0] == { "type": "error", "tags[custom_tag]": "custom_value", "release": None, } response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": [ "type", "trace_id", "tags[foo]", "group_id", "release", "sdk_name", "geo_city", ], "conditions": [["type", "=", "transaction"]], "orderby": "timestamp", "limit": 1, }), ) data = json.loads(response.data) assert response.status_code == 200 assert len(data["data"]) == 1, data assert data["data"][0] == { "type": "transaction", "trace_id": str(self.trace_id), "tags[foo]": "baz", "group_id": 0, "release": "1", "geo_city": "San Francisco", "sdk_name": "sentry.python", } def test_aggregations(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", None, "count"]], "groupby": ["project_id", "tags[custom_tag]"], "conditions": [["type", "!=", "transaction"]], "orderby": "count", "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert data["data"] == [{ "count": 1, "tags[custom_tag]": "custom_value", "project_id": self.project_id, }] response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "groupby": ["project_id", "tags[foo]", "trace_id"], "conditions": [["type", "=", "transaction"]], "orderby": "count", "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert data["data"] == [{ "count": 1, "tags[foo]": "baz", "project_id": self.project_id, "trace_id": str(self.trace_id), }] def test_handles_columns_from_other_dataset(self): response = self.app.post( "/query", data=json.dumps({ "dataset": 
"discover", "project": self.project_id, "aggregations": [ ["count()", "", "count"], ["uniq", ["group_id"], "uniq_group_id"], ["uniq", ["exception_stacks.type"], "uniq_ex_stacks"], ], "conditions": [ ["type", "=", "transaction"], ["group_id", "=", 2], ["duration", ">=", 0], ], "groupby": ["type"], "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert data["data"] == [{ "type": "transaction", "count": 0, "uniq_group_id": 0, "uniq_ex_stacks": 0 }] response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["uniq", ["trace_id"], "uniq_trace_id"]], "conditions": [["type", "=", "error"]], "groupby": ["type", "group_id"], "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert data["data"] == [{ "type": "error", "group_id": self.event["group_id"], "uniq_trace_id": 0 }] def test_geo_column_condition(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [ ["duration", ">=", 0], ["geo_country_code", "=", "MX"], ], "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert data["data"] == [{"count": 0}] response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [ ["duration", ">=", 0], ["geo_country_code", "=", "US"], ["geo_region", "=", "CA"], ["geo_city", "=", "San Francisco"], ], "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert data["data"] == [{"count": 1}] def test_exception_stack_column_condition(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [ ["exception_stacks.type", "LIKE", "Arithmetic%"], 
["exception_frames.filename", "LIKE", "%.java"], ], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"] == [{"count": 1}] def test_exception_stack_column_boolean_condition(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count", None, "count"]], "debug": True, "conditions": [ [ [ "or", [ [ "equals", [ "exception_stacks.type", "'ArithmeticException'", ], ], [ "equals", [ "exception_stacks.type", "'RuntimeException'" ], ], ], ], "=", 1, ], ], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"] == [{"count": 1}] def test_exception_stack_column_boolean_condition_with_arrayjoin(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count", None, "count"]], "arrayjoin": "exception_stacks.type", "groupby": "exception_stacks.type", "debug": True, "conditions": [ [ [ "or", [ [ "equals", [ "exception_stacks.type", "'ArithmeticException'", ], ], [ "equals", [ "exception_stacks.type", "'RuntimeException'" ], ], ], ], "=", 1, ], ], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"] == [{ "count": 1, "exception_stacks.type": "ArithmeticException" }] def test_exception_stack_column_boolean_condition_arrayjoin_function(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": [[ "arrayJoin", ["exception_stacks.type"], "exception_stacks.type", ]], "aggregations": [["count", None, "count"]], "groupby": "exception_stacks.type", "debug": True, "conditions": [ [ [ "or", [ [ "equals", [ "exception_stacks.type", "'ArithmeticException'", ], ], [ "equals", [ "exception_stacks.type", "'RuntimeException'" ], ], ], ], "=", 1, ], ], "limit": 1000, }), ) assert response.status_code == 200 data = 
json.loads(response.data) assert data["data"] == [{ "count": 1, "exception_stacks.type": "ArithmeticException" }] def test_tags_key_boolean_condition(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "turbo": False, "consistent": False, "aggregations": [["count", None, "count"]], "conditions": [ [ [ "or", [ [ "equals", [["ifNull", ["tags[foo]", "''"]], "'baz'"], ], [ "equals", [["ifNull", ["tags[foo.bar]", "''"]], "'qux'"], ], ], ], "=", 1, ], ["project_id", "IN", [self.project_id]], ], "groupby": "tags_key", "orderby": ["-count", "tags_key"], "having": [[ "tags_key", "NOT IN", ["trace", "trace.ctx", "trace.span", "project"], ]], "project": [self.project_id], "limit": 10, }), ) assert response.status_code == 200 def test_os_fields_condition(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [ ["duration", ">=", 0], ["contexts[os.build]", "LIKE", "x86%"], ["contexts[os.kernel_version]", "LIKE", "10.1%"], ], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"] == [{"count": 0}] def test_http_fields(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [["duration", ">=", 0]], "groupby": ["http_method", "http_referer", "tags[url]"], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"] == [{ "http_method": "POST", "http_referer": "tagstore.something", "tags[url]": "http://127.0.0.1:/query", "count": 1, }] response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [["group_id", ">=", 0]], "groupby": ["http_method", "http_referer", "tags[url]"], "limit": 1000, }), ) assert response.status_code 
== 200 data = json.loads(response.data) assert data["data"] == [{ "http_method": "POST", "http_referer": "tagstore.something", "tags[url]": "http://127.0.0.1:/query", "count": 1, }] def test_device_fields_condition(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [ ["duration", ">=", 0], ["contexts[device.charging]", "=", "True"], ["contexts[device.model_id]", "=", "Galaxy"], ], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"][0]["count"] == 1 response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["count()", "", "count"]], "conditions": [ ["type", "=", "error"], ["contexts[device.charging]", "=", "True"], ["contexts[device.model_id]", "=", "Galaxy"], ], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"][0]["count"] == 1 def test_device_boolean_fields_context_vs_promoted_column(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["contexts[device.charging]"], "aggregations": [["count()", "", "count"]], "conditions": [["duration", ">=", 0]], "groupby": ["contexts[device.charging]"], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"][0]["contexts[device.charging]"] == "True" assert data["data"][0]["count"] == 1 response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["contexts[device.charging]"], "aggregations": [["count()", "", "count"]], "conditions": [["type", "=", "error"]], "groupby": ["contexts[device.charging]"], "limit": 1000, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"][0]["contexts[device.charging]"] == 
"True" assert data["data"][0]["count"] == 1 def test_is_handled(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["exception_stacks.mechanism_handled"], "conditions": [ ["type", "=", "error"], [["notHandled", []], "=", 1], ], "limit": 5, }), ) assert response.status_code == 200 data = json.loads(response.data) assert data["data"][0]["exception_stacks.mechanism_handled"] == [0] def test_having(self): result = json.loads( self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "groupby": "primary_hash", "conditions": [["type", "!=", "transaction"]], "having": [["times_seen", "=", 1]], "aggregations": [["count()", "", "times_seen"]], }), ).data) assert len(result["data"]) == 1 def test_time(self): result = json.loads( self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["project_id"], "groupby": ["time", "project_id"], "conditions": [["type", "!=", "transaction"]], }), ).data) assert len(result["data"]) == 1 result = json.loads( self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["project_id"], "groupby": ["time", "project_id"], "conditions": [["duration", ">=", 0]], }), ).data) assert len(result["data"]) == 1 def test_transaction_group_ids(self): result = json.loads( self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["group_id"], "conditions": [ ["type", "=", "transaction"], ["duration", ">=", 0], ], }), ).data) assert result["data"][0]["group_id"] == 0 result = json.loads( self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["group_id"], "conditions": [ ["type", "=", "transaction"], ["duration", ">=", 0], ["group_id", "IN", (1, 2, 3, 4)], ], }), ).data) assert result["data"] == [] def 
test_contexts(self): result = json.loads( self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "conditions": [["type", "=", "error"]], "selected_columns": ["contexts[device.online]"], }), ).data) assert result["data"] == [{"contexts[device.online]": "True"}] result = json.loads( self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "conditions": [["duration", ">=", 0]], "selected_columns": ["contexts[device.online]"], }), ).data) assert result["data"] == [{"contexts[device.online]": "True"}] def test_ast_impossible_queries(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [["apdex(duration, 300)", None, "apdex_duration_300"]], "groupby": ["project_id", "tags[foo]"], "conditions": [], "orderby": "apdex_duration_300", "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert data["data"] == [{ "apdex_duration_300": 1, "tags[foo]": "baz", "project_id": self.project_id }] def test_count_null_user_consistency(self): response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [ ["uniq", "user", "uniq_user"], ["count", None, "count"], ], "groupby": ["group_id", "user"], "conditions": [], "orderby": "uniq_user", "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert len(data["data"]) == 1 assert data["data"][0]["uniq_user"] == 0 response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "aggregations": [ ["uniq", "user", "uniq_user"], ["count", None, "count"], ], "groupby": ["trace_id", "user_email"], "conditions": [], "orderby": "uniq_user", "limit": 1000, }), ) data = json.loads(response.data) assert response.status_code == 200 assert len(data["data"]) == 1 # Should now count '' user as Null, which is 0 assert 
data["data"][0]["uniq_user"] == 0 def test_individual_measurement(self) -> None: response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": [ "event_id", "measurements[lcp]", "measurements[lcp.elementSize]", "measurements[asd]", ], "limit": 1, }), ) data = json.loads(response.data) assert response.status_code == 200, response.data assert len(data["data"]) == 1, data assert data["data"][0]["measurements[lcp]"] == 32.129 assert data["data"][0]["measurements[lcp.elementSize]"] == 4242 assert data["data"][0]["measurements[asd]"] is None response = self.app.post( "/query", data=json.dumps({ "dataset": "discover", "project": self.project_id, "selected_columns": ["group_id", "measurements[lcp]"], "limit": 1, }), ) data = json.loads(response.data) assert response.status_code == 200, response.data assert len(data["data"]) == 1, data assert "measurements[lcp]" in data["data"][0] assert data["data"][0]["measurements[lcp]"] is None
class HorovodStrategy(ParallelStrategy):
    """Plugin for Horovod distributed training integration."""

    distributed_backend = _StrategyType.HOROVOD

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        rank_zero_only.rank = self.global_rank
        # Holds the ``skip_synchronize`` contexts entered in ``setup``; stays
        # ``None`` until ``setup`` has created it.
        self._exit_stack: Optional[ExitStack] = None

    @property
    def global_rank(self) -> int:
        """Rank of this process among all Horovod workers (``hvd.rank()``)."""
        return hvd.rank()

    @property
    def local_rank(self) -> int:
        """Rank of this process on its node (``hvd.local_rank()``)."""
        return hvd.local_rank()

    @property
    def world_size(self) -> int:
        """Total number of Horovod workers (``hvd.size()``)."""
        return hvd.size()

    @property
    def root_device(self):
        """Entry of ``parallel_devices`` selected by this process's local rank."""
        return self.parallel_devices[self.local_rank]

    @property
    def distributed_sampler_kwargs(self):
        """Keyword arguments for a distributed sampler: replica count and this process's rank."""
        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, trainer: "pl.Trainer") -> None:
        """Move the model to its device and, when training, prepare optimizers for Horovod.

        While training this scales learning rates by the world size, broadcasts
        parameters and optimizer state from rank 0, wraps the optimizers in
        ``hvd.DistributedOptimizer`` and defers their gradient synchronization
        to ``post_backward``.
        """
        self.model_to_device()

        super().setup(trainer)

        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()

        if not self.lightning_module.trainer.training:
            # no need to setup optimizers
            return

        def _unpack_lightning_optimizer(opt):
            return opt._optimizer if isinstance(opt, LightningOptimizer) else opt

        optimizers = self.optimizers
        optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers]

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= self.world_size

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        lr_scheduler_configs = self.lr_schedulers
        for config in lr_scheduler_configs:
            scheduler = config.scheduler
            scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
        for optimizer in optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        self.optimizers = self._wrap_optimizers(optimizers)
        for optimizer in self.optimizers:
            # Synchronization will be performed explicitly following backward()
            self._exit_stack.enter_context(optimizer.skip_synchronize())

    def barrier(self, *args, **kwargs):
        """Wait for all workers (via ``join``) when distributed execution is available."""
        if distributed_available():
            self.join()

    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcast ``obj`` from rank ``src`` to all workers and return it."""
        obj = hvd.broadcast_object(obj, src)
        return obj

    def model_to_device(self):
        """Move the model to ``root_device``, selecting it as the current CUDA device on GPU."""
        if self.on_gpu:
            # this can potentially be removed after #8312. Not done due to lack of horovod testing
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def join(self):
        """Block until all Horovod workers reach this point."""
        if self.on_gpu:
            hvd.join(self.local_rank)
        else:
            hvd.join()

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world).
                Horovod only supports the default; any other value raises ``ValueError``.
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' (or ``ReduceOp.SUM``) to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged

        Raises:
            ValueError: if ``group`` is set or ``reduce_op`` is not recognized.
        """
        if group is not None:
            raise ValueError(
                "Horovod does not support allreduce using a subcommunicator at this time. Unset `group`."
            )

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self.join()
        return hvd.allreduce(tensor, op=reduce_op)

    def all_gather(self, result: torch.Tensor, group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False) -> torch.Tensor:
        """Gather ``result`` from all workers; 0-d tensors are reshaped to 1-d first.

        Raises:
            ValueError: if ``group`` is anything other than ``None`` or the world group.
        """
        if group is not None and group != dist_group.WORLD:
            raise ValueError(
                "Horovod does not support allgather using a subcommunicator at this time. Unset `group`."
            )

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self.join()
        return hvd.allgather(result)

    def post_backward(self, closure_loss: torch.Tensor) -> None:
        # synchronize all horovod optimizers.
        for optimizer in self.lightning_module.trainer.optimizers:
            optimizer.synchronize()

    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
        """Wraps optimizers to perform gradient aggregation via allreduce."""
        # Optimizers whose class name already mentions horovod are left unwrapped.
        return [
            hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt))
            if "horovod" not in str(opt.__class__)
            else opt
            for opt in optimizers
        ]

    @staticmethod
    def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
        """Return ``(name, parameter)`` pairs of ``model`` that belong to one of ``optimizer``'s param groups."""
        opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
        return [(name, p) for name, p in model.named_parameters() if p in opt_params]

    def teardown(self) -> None:
        """Release the deferred-synchronization contexts, wait for all workers and free GPU memory."""
        super().teardown()
        # ``_exit_stack`` is only created part-way through ``setup``; if setup never
        # ran (or failed before that point) there is nothing to close.
        if self._exit_stack is not None:
            self._exit_stack.__exit__(None, None, None)
            self._exit_stack = None
        # Make sure all workers have finished training before returning to the user
        self.join()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
class Scope:
    """
    Context manager that runs registered cleanup callbacks when it exits.

    Callbacks registered with ``on_exit_do`` always run on exit; callbacks
    registered with ``on_error_do`` run only when the scope is left because of
    an exception. Context managers passed to ``add`` are entered immediately and
    exited together with the callbacks, in reverse order of registration.
    """

    _thread_locals = threading.local()

    @frozen
    class _ExitHandler:
        # Zero-argument callable invoked when the scope exits.
        callback: Callable[[], Any]
        # When True, an ``Exception`` raised by ``callback`` is swallowed.
        ignore_errors: bool = True

        def __exit__(self, exc_type, exc_value, exc_traceback):
            try:
                self.callback()
            except Exception:
                if self.ignore_errors:
                    return None
                raise
            return None

    @frozen
    class _ErrorHandler(_ExitHandler):
        def __exit__(self, exc_type, exc_value, exc_traceback):
            # Only act when the scope is being left because of an exception.
            if not exc_type:
                return None
            return super().__exit__(exc_type, exc_value, exc_traceback)

    def __init__(self):
        self._stack = ExitStack()
        self.enabled = True

    def on_error_do(self, callback: Callable, *args,
            kwargs: Optional[Dict[str, Any]] = None, ignore_errors: bool = False):
        """
        Schedules ``callback(*args, **kwargs)`` to run only if the scope exits
        because of an exception.

        When ignore_errors is True, an exception raised by the callback is ignored.
        """
        self._register_callback(self._ErrorHandler, callback=callback,
            args=args, kwargs=kwargs, ignore_errors=ignore_errors)

    def on_exit_do(self, callback: Callable, *args,
            kwargs: Optional[Dict[str, Any]] = None, ignore_errors: bool = False):
        """
        Schedules ``callback(*args, **kwargs)`` to run whenever the scope exits.

        When ignore_errors is True, an exception raised by the callback is ignored.
        """
        self._register_callback(self._ExitHandler, callback=callback,
            args=args, kwargs=kwargs, ignore_errors=ignore_errors)

    def _register_callback(self, handler_type, callback: Callable,
            args: Tuple[Any] = None, kwargs: Dict[str, Any] = None,
            ignore_errors: bool = False):
        # Bind the arguments up front so the handler can call it without any.
        bound = partial(callback, *args, **(kwargs or {})) \
            if (args or kwargs) else callback
        handler = handler_type(bound, ignore_errors=ignore_errors)
        self._stack.push(handler)

    def add(self, cm: ContextManager[T]) -> T:
        """
        Enters ``cm`` now and schedules its exit for when the scope exits.

        Returns: the value produced by cm.__enter__()
        """
        return self._stack.enter_context(cm)

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def close(self):
        """Exits the scope as if its block had finished without an exception."""
        self.__exit__(None, None, None)

    def __enter__(self) -> Scope:
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # A disabled scope keeps its registered callbacks and runs none of them.
        if self.enabled:
            self._stack.__exit__(exc_type, exc_value, exc_traceback)
            # Swap in a fresh, empty stack so that repeated exits are harmless.
            self._stack.pop_all()

    @classmethod
    def current(cls) -> Scope:
        """
        Returns the scope installed by ``as_current`` in the calling thread.

        Raises AttributeError if ``as_current`` was never entered in this thread.
        """
        return cls._thread_locals.current

    @contextmanager
    def as_current(self):
        """Makes this scope the thread's current one for the duration of the block."""
        saved = getattr(self._thread_locals, 'current', None)
        self._thread_locals.current = self
        try:
            yield
        finally:
            self._thread_locals.current = saved
class Rollback:
    """Collects undo callbacks and runs them when a ``with`` block fails.

    Callbacks registered with ``add``/``do`` are pushed onto an ``ExitStack``,
    so when the block exits with an exception they run in reverse order of
    registration. Nothing runs when the block exits normally.
    """

    @attr.attrs
    class Handler:
        # Callable invoked with no arguments during rollback.
        callback = attr.attrib()
        # Per-handler switch, toggled by ``Rollback.enable/disable(name)``.
        enabled = attr.attrib(default=True)
        # When True, an ``Exception`` raised by ``callback`` is swallowed.
        ignore_errors = attr.attrib(default=False)

        def __call__(self):
            if self.enabled:
                try:
                    self.callback()
                # ``Exception`` rather than a bare ``except``: KeyboardInterrupt
                # and SystemExit must propagate even when ignore_errors is set.
                except Exception:
                    if not self.ignore_errors:
                        raise

    def __init__(self):
        self._handlers = {}
        self._stack = ExitStack()
        self.enabled = True

    def add(self, callback, *args,
            name=None, enabled=True, ignore_errors=False,
            fwd_kwargs=None, **kwargs):
        """Registers ``callback(*args, **kwargs)`` to run on rollback.

        Entries of ``fwd_kwargs`` are merged into ``kwargs``. This allows
        forwarding keywords such as ``name`` that this method would otherwise
        consume itself.

        Returns the handler name, which is usable with ``enable``/``disable``.
        The name defaults to the hash of the (possibly ``partial``-wrapped)
        callback.

        Raises AssertionError if the name is already registered.
        """
        if args or kwargs or fwd_kwargs:
            if fwd_kwargs:
                kwargs.update(fwd_kwargs)
            callback = partial(callback, *args, **kwargs)
        name = name or hash(callback)
        if name in self._handlers:
            # Explicit check instead of ``assert`` so it still holds under
            # ``python -O``; AssertionError is kept for compatibility.
            raise AssertionError(
                'Rollback handler %r is already registered' % (name,))
        handler = self.Handler(callback,
            enabled=enabled, ignore_errors=ignore_errors)
        self._handlers[name] = handler
        self._stack.callback(handler)
        return name

    do = add # readability alias

    def enable(self, name=None):
        """Enables one named handler, or the whole rollback when ``name`` is not given."""
        if name:
            self._handlers[name].enabled = True
        else:
            self.enabled = True

    def disable(self, name=None):
        """Disables one named handler, or the whole rollback when ``name`` is not given."""
        if name:
            self._handlers[name].enabled = False
        else:
            self.enabled = False

    def clean(self):
        # NOTE(review): ``__exit__`` returns immediately when ``type`` is None,
        # so this call currently runs no handlers - confirm the intent.
        self.__exit__(None, None, None)

    def __enter__(self):
        return self

    # pylint: disable=redefined-builtin
    def __exit__(self, type=None, value=None, traceback=None):
        # Roll back only when leaving because of an exception, and only if enabled.
        if type is None:
            return
        if not self.enabled:
            return
        self._stack.__exit__(type, value, traceback)
class Worker:
    """Context manager driving an autopkgtest virtualization server.

    The server (``autopkgtest-virt-*``, ``adt-virt-*`` or the bare name) is
    driven by writing one command line to its stdin and reading one reply
    line from its stdout. Entering the context starts the server, checks its
    capabilities, opens a session and installs a command wrapper in the
    session's scratch directory. Exiting asks the server to quit, terminates
    it and waits for it.
    """

    def __init__(self, argv):
        self.__cached_copies = {}
        self.__command_wrapper_enabled = False
        self.__dpkg_architecture = None
        self.call_argv = None
        self.capabilities = set()
        self.command_wrapper = None
        self.argv = argv
        # Guest scratch directory reported by the 'open' command.
        self.scratch = None
        self.stack = ExitStack()
        # NOTE(review): this default literal looks redacted in this copy — confirm.
        self.user = '******'
        self.virt_process = None

    def __enter__(self):
        try:
            self._start()
        except BaseException:
            # ``__exit__`` is not called when ``__enter__`` raises, so unwind the
            # callbacks registered so far here; otherwise the server process and
            # its pipes would be leaked.
            try:
                self.stack.close()
            except Exception:
                # Don't let a cleanup failure (for instance writing 'quit' to a
                # server that already died) replace the original error.
                logger.warning('Error while cleaning up failed worker',
                               exc_info=True)
            raise
        return self

    def _start(self):
        """Start the virt server, validate it, open a session and install the wrapper."""
        argv = list(map(os.path.expanduser, self.argv))

        for prefix in ('autopkgtest-virt-', 'adt-virt-', ''):
            if shutil.which(prefix + argv[0]):
                argv[0] = prefix + argv[0]
                break
        else:
            raise WorkerError('virtualization provider %r not found' % argv[0])

        logger.info('Starting worker: %r', argv)
        self.virt_process = subprocess.Popen(
            argv,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
        # Exit callbacks run in reverse order: write 'quit', flush it,
        # terminate, then Popen.__exit__ closes the pipes and waits.
        self.stack.enter_context(self.virt_process)
        self.stack.callback(self.virt_process.terminate)
        # FIXME: timed wait for a response?
        self.stack.callback(self.virt_process.stdin.flush)
        self.stack.callback(self.virt_process.stdin.write, 'quit\n')

        line = self.virt_process.stdout.readline()

        if line != 'ok\n':
            raise WorkerError('Virtual machine {!r} failed to start: '
                              '{}'.format(argv, line.strip()))

        line = self._command('capabilities')

        if not line.startswith('ok '):
            # Both placeholders need an argument; a missing argv used to turn
            # this into an IndexError instead of a WorkerError.
            raise WorkerError('Virtual machine {!r} failed to report '
                              'capabilities: {}'.format(argv, line.strip()))

        for word in line.split()[1:]:
            self.capabilities.add(word)
            if word.startswith('suggested-normal-user='):
                # NOTE(review): this span was garbled in the source copy and has
                # been reconstructed from context — confirm against upstream.
                self.user = word[len('suggested-normal-user='):]

        if 'root-on-testbed' not in self.capabilities:
            raise WorkerError('Virtual machine {!r} does not have '
                              'root-on-testbed capability: {}'.format(
                                  argv, line.strip()))

        if ('isolation-machine' not in self.capabilities and
                'isolation-container' not in self.capabilities):
            raise WorkerError('Virtual machine {!r} does not have '
                              'sufficient isolation: {}'.format(
                                  argv, line.strip()))

        line = self._command('open')

        if not line.startswith('ok '):
            raise WorkerError('Failed to open virtual machine session {!r}: '
                              '{}'.format(argv, line.strip()))

        self.scratch = line[3:].rstrip('\n')

        line = self._command('print-execute-command')

        if not line.startswith('ok '):
            raise WorkerError('Failed to get virtual machine {!r} command '
                              'wrapper: {}'.format(argv, line.strip()))

        # Reply format: 'ok ' followed by comma-separated, URL-quoted argv items.
        wrapper_argv = line.rstrip('\n').split(None, 1)[1].split(',')
        self.call_argv = list(map(urllib.parse.unquote, wrapper_argv))

        if not self.call_argv:
            raise WorkerError('Virtual machine {!r} command wrapper did not '
                              'provide any arguments: {}'.format(
                                  argv, line.strip()))

        wrapper = '{}/vectis-command-wrapper'.format(self.scratch)
        self.copy_to_guest(_WRAPPER, wrapper)
        self.check_call(['chmod', '+x', wrapper])
        self.command_wrapper = wrapper

    def _command(self, command):
        """Send one protocol command line to the virt server and return its reply line."""
        self.virt_process.stdin.write(command + '\n')
        self.virt_process.stdin.flush()
        return self.virt_process.stdout.readline()

    def call(self, argv, **kwargs):
        """Run ``argv`` in the guest via the command wrapper; return its exit status."""
        logger.info('%r', argv)
        return subprocess.call(self.call_argv + list(argv), **kwargs)

    def check_call(self, argv, **kwargs):
        """Run ``argv`` in the guest; raise CalledProcessError on non-zero exit."""
        logger.info('%r', argv)
        subprocess.check_call(self.call_argv + list(argv), **kwargs)

    def check_output(self, argv, **kwargs):
        """Run ``argv`` in the guest and return its stdout."""
        logger.info('%r', argv)
        return subprocess.check_output(self.call_argv + list(argv), **kwargs)

    def copy_to_guest(self, host_path, guest_path, *, cache=False):
        """Copy ``host_path`` to ``guest_path``.

        With ``cache=True``, a copy of the same host path to the same guest path
        that was already made with ``cache=True`` is skipped.
        """
        assert host_path is not None
        assert guest_path is not None

        if cache and self.__cached_copies.get(host_path) == guest_path:
            return

        if not os.path.exists(host_path):
            raise WorkerError('Cannot copy host:{!r} to guest: it does '
                              'not exist'.format(host_path))

        line = self._command('copydown {} {}'.format(
            urllib.parse.quote(host_path),
            urllib.parse.quote(guest_path),
        ))

        if line != 'ok\n':
            raise WorkerError('Failed to copy host:{!r} to guest:{!r}: '
                              '{}'.format(host_path, guest_path, line.strip()))

        if cache:
            self.__cached_copies[host_path] = guest_path

    def copy_to_host(self, guest_path, host_path):
        """Copy ``guest_path`` from the guest to ``host_path``."""
        if self.call(['test', '-e', guest_path]) != 0:
            raise WorkerError('Cannot copy guest:{!r} to host: it does '
                              'not exist'.format(guest_path))

        line = self._command('copyup {} {}'.format(
            urllib.parse.quote(guest_path),
            urllib.parse.quote(host_path),
        ))

        if line != 'ok\n':
            raise WorkerError('Failed to copy guest:{!r} to host:{!r}: '
                              '{}'.format(guest_path, host_path, line.strip()))

    def open_shell(self):
        """Ask the server to open an interactive shell; log a warning on failure."""
        line = self._command('shell')

        if line != 'ok\n':
            logger.warning('Unable to open a shell in guest: %s', line.strip())

    @property
    def dpkg_architecture(self):
        """Guest's ``dpkg --print-architecture`` output, queried once and cached."""
        if self.__dpkg_architecture is None:
            self.__dpkg_architecture = self.check_output(
                ['dpkg', '--print-architecture'],
                universal_newlines=True).strip()

        return self.__dpkg_architecture

    def __exit__(self, et, ev, tb):
        return self.stack.__exit__(et, ev, tb)

    def set_up_apt(self, suite, components=()):
        """Write the guest's apt sources for ``suite`` and its ancestors, then ``apt-get update``."""
        with TemporaryDirectory() as tmp:
            with AtomicWriter(os.path.join(tmp, 'sources.list')) as writer:
                for ancestor in suite.hierarchy:
                    if components:
                        filtered_components = (set(components) &
                                               set(ancestor.all_components))
                    else:
                        filtered_components = ancestor.components

                    writer.write(textwrap.dedent('''
                    deb {mirror} {suite} {components}
                    deb-src {mirror} {suite} {components}
                    ''').format(
                        components=' '.join(filtered_components),
                        mirror=ancestor.mirror,
                        suite=ancestor.apt_suite,
                    ))

                    if ancestor.apt_key is not None:
                        self.copy_to_guest(
                            ancestor.apt_key,
                            '/etc/apt/trusted.gpg.d/' +
                            os.path.basename(ancestor.apt_key))

            self.copy_to_guest(os.path.join(tmp, 'sources.list'),
                               '/etc/apt/sources.list')
            self.check_call([
                'env', 'DEBIAN_FRONTEND=noninteractive',
                'apt-get', '-y', 'update',
            ])
class _TestUser(object): def __init__(self, test_client, runestone_db_tools, username, password, course_name, # True if the course is free (no payment required); False otherwise. is_free=True): self.test_client = test_client self.runestone_db_tools = runestone_db_tools self.username = username self.first_name = 'test' self.last_name = 'user' self.email = self.username + '@foo.com' self.password = password self.course_name = course_name self.is_free = is_free def __enter__(self): # Registration doesn't work unless we're logged out. self.test_client.logout() # Now, post the registration. self.test_client.validate('default/user/register', 'Support Runestone Interactive' if self.is_free else 'Payment Amount', data=dict( username=self.username, first_name=self.first_name, last_name=self.last_name, # The e-mail address must be unique. email=self.email, password=self.password, password_two=self.password, # Note that ``course_id`` is (on the form) actually a course name. course_id=self.course_name, accept_tcp='on', donate='0', _next='/runestone/default/index', _formname='register', ) ) # Schedule this user for deletion. self.exit_stack_object = ExitStack() self.exit_stack = self.exit_stack_object.__enter__() self.exit_stack.callback(self._delete_user) # Record IDs db = self.runestone_db_tools.db self.course_id = db(db.courses.course_name == self.course_name).select(db.courses.id).first().id self.user_id = db(db.auth_user.username == self.username).select(db.auth_user.id).first().id return self # Clean up on exit by invoking all ``__exit__`` methods. def __exit__(self, exc_type, exc_value, traceback): self.exit_stack_object.__exit__(exc_type, exc_value, traceback) # Delete the user created by entering this context manager. TODO: This doesn't delete all the chapter progress tracking stuff. def _delete_user(self): db = self.runestone_db_tools.db # Delete the course this user registered for. 
db(( db.user_courses.course_id == self.course_id) & (db.user_courses.user_id == self.user_id) ).delete() # Delete the user. db(db.auth_user.username == self.username).delete() db.commit() def login(self): self.test_client.post('default/user/login', data=dict( username=self.username, password=self.password, _formname='login', )) def make_instructor(self, course_id=None): # If ``course_id`` isn't specified, use this user's ``course_id``. course_id = course_id or self.course_id return self.runestone_db_tools.make_instructor(self.user_id, course_id) # A context manager to update this user's profile. If a course was added, it returns that course's ID; otherwise, it returns None. @contextmanager def update_profile(self, # This parameter is passed to ``test_client.validate``. expected_string=None, # An updated username, or ``None`` to use ``self.username``. username=None, # An updated first name, or ``None`` to use ``self.first_name``. first_name=None, # An updated last name, or ``None`` to use ``self.last_name``. last_name=None, # An updated email, or ``None`` to use ``self.email``. email=None, # An updated last name, or ``None`` to use ``self.course_name``. course_name=None, section='', # A shortcut for specifying the ``expected_string``, which only applies if ``expected_string`` is not set. Use ``None`` if a course will not be added, ``True`` if the added course is free, or ``False`` if the added course is paid. is_free=None, # The value of the ``accept_tcp`` checkbox; provide an empty string to leave unchecked. The default value leaves it checked. 
accept_tcp='on'): if expected_string is None: if is_free is None: expected_string = 'Course Selection' else: expected_string = 'Support Runestone Interactive' \ if is_free else 'Payment Amount' username = username or self.username first_name = first_name or self.first_name last_name = last_name or self.last_name email = email or self.email course_name = course_name or self.course_name db = self.runestone_db_tools.db # Determine if we're adding a course. If so, delete it at the end of the test. To determine if a course is being added, the course must exist, but not be in the user's list of courses. course = db(db.courses.course_name == course_name).select(db.courses.id).first() delete_at_end = course and not db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course.id)).select(db.user_courses.id).first() # Perform the update. try: self.test_client.validate('default/user/profile', expected_string, data=dict( username=username, first_name=first_name, last_name=last_name, email=email, # Though the field is ``course_id``, it's really the course name. course_id=course_name, accept_tcp=accept_tcp, section=section, _next='/runestone/default/index', id=str(self.user_id), _formname='auth_user/' + str(self.user_id), ) ) yield course.id if delete_at_end else None finally: if delete_at_end: db = self.runestone_db_tools.db db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course.id)).delete() db.commit() # Call this after registering for a new course or adding a new course via ``update_profile`` to pay for the course. @contextmanager def make_payment(self, # The `Stripe test tokens <https://stripe.com/docs/testing#cards>`_ to use for payment. stripe_token, # The course ID of the course to pay for. None specifies ``self.course_id``. course_id=None): course_id = course_id or self.course_id # Get the signature from the HTML of the payment page. 
self.test_client.validate('default/payment') match = re.search('<input type="hidden" name="signature" value="([^ ]*)" />', self.test_client.text) signature = match.group(1) try: html = self.test_client.validate('default/payment', data=dict(stripeToken=stripe_token, signature=signature) ) assert ('Thank you for your payment' in html) or ('Payment failed' in html) yield None finally: db = self.runestone_db_tools.db db((db.user_courses.course_id == course_id) & (db.user_courses.user_id == self.user_id) ).delete() # Try to delete the payment try: db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course_id) & (db.user_courses.id == db.payments.user_courses_id)) \ .delete() except: pass db.commit() @contextmanager def hsblog(self, **kwargs): try: yield json.loads(self.test_client.validate('ajax/hsblog', data=kwargs)) finally: # Try to remove this hsblog entry. event = kwargs.get('event') div_id = kwargs.get('div_id') course = kwargs.get('course') db = self.runestone_db_tools.db criteria = ((db.useinfo.sid == self.username) & (db.useinfo.act == kwargs.get('act', '')) & (db.useinfo.div_id == div_id) & (db.useinfo.event == event) & (db.useinfo.course_id == course) ) useinfo_row = db(criteria).select(db.useinfo.id, orderby=db.useinfo.id).last() if useinfo_row: del db.useinfo[useinfo_row.id]
class KiwoomOpenApiPlusEventHandler(KiwoomOpenApiPlusEventHandlerFunctions):
    """Re-entrant context manager that connects this handler's ``On*`` slots to a control.

    The outermost ``__enter__`` connects every implemented ``On*`` slot to the
    same-named attribute of ``control``. The matching outermost ``__exit__``,
    or any ``__exit__`` that carries an exception, disconnects them, calls
    ``on_exit`` and runs the callbacks registered with ``add_callback``.
    """

    def __init__(self, control):
        self._control = control
        self._observer = QueueBasedIterableObserver()
        # Number of currently active ``with`` blocks on this handler.
        self._enter_count = 0
        # True between ``enter`` and the next effective ``exit``.
        self._should_exit = False
        # Callbacks registered with ``add_callback``; unwound by ``exit``.
        self._stack = ExitStack()
        # Reentrant: ``__enter__``/``__exit__`` hold it while calling ``enter``/``exit``.
        self._lock = RLock()

    @property
    def control(self):
        return self._control

    @property
    def observer(self):
        return self._observer

    @classmethod
    def names(cls):
        """Names of the ``On*`` attributes defined by KiwoomOpenApiPlusEventHandlerFunctions."""
        names = [
            name
            for name in dir(KiwoomOpenApiPlusEventHandlerFunctions)
            if name.startswith("On")
        ]
        return names

    def slots(self):
        """``(name, bound method)`` pairs for the ``On*`` slots this instance implements."""
        names = self.names()
        slots = [getattr(self, name) for name in names]
        names_and_slots_implemented = [
            (name, slot)
            for name, slot in zip(names, slots)
            if isimplemented(slot)
        ]
        return names_and_slots_implemented

    def connect(self):
        for name, slot in self.slots():
            getattr(self.control, name).connect(slot)

    def disconnect(self):
        for name, slot in self.slots():
            getattr(self.control, name).disconnect(slot)

    def on_enter(self):
        """Hook called after the slots are connected; no-op by default."""
        pass

    def on_exit(self, exc_type=None, exc_value=None, traceback=None):
        """Hook called after the slots are disconnected; no-op by default."""
        pass

    def add_callback(self, callback, *args, **kwargs):
        """Schedule ``callback(*args, **kwargs)`` to run on the next effective ``exit``."""
        self._stack.callback(callback, *args, **kwargs)

    def enter(self):
        with self._lock:
            self.connect()
            self.on_enter()
            self._should_exit = True

    def exit(self, exc_type=None, exc_value=None, traceback=None):
        # Only the first call after ``enter`` does anything; later calls are no-ops.
        with self._lock:
            if self._should_exit:
                self.disconnect()
                self.on_exit(exc_type, exc_value, traceback)
                self._stack.__exit__(exc_type, exc_value, traceback)
                self._should_exit = False

    def stop(self):
        return self.observer.stop()

    def close(self):
        self.exit()
        self.stop()

    def __enter__(self):
        with self._lock:
            if self._enter_count == 0:
                self.enter()
            self._enter_count += 1
            return self

    def __exit__(self, exc_type, exc_value, traceback):
        with self._lock:
            # Balance the counter on the error path too; otherwise it never
            # returns to 0 after an exception and a later ``with handler:``
            # would not reconnect the slots.
            if self._enter_count > 0:
                self._enter_count -= 1
            # On an exception tear down immediately, regardless of nesting depth.
            if exc_type is not None or self._enter_count == 0:
                return self.exit(exc_type, exc_value, traceback)
        return

    def __iter__(self):
        # NOTE(review): ``return`` inside ``with self`` runs ``__exit__`` before the
        # iterator is used; if the handler was not already entered, the slots are
        # disconnected by the time iteration starts — confirm this is intended.
        with self:
            return iter(self.observer)
class ScopedServiceProvider(IServiceProvider):
    """Service provider backed by a ``ServicesMap``, optionally scoped under a parent.

    A provider created with ``parent=None`` is the root and guards its exit
    stack with an ``RLock``. Child scopes (see ``scope()``) use a no-op
    ``nullcontext`` instead.
    """

    def __init__(self, services: ServicesMap, parent=None):
        super().__init__()
        self._services = services
        self._exit_stack = None  # created lazily by enter()
        self._scoped_cache = {}
        self._parent = parent
        # hack: instances are re-classed to ServiceProvider (defined elsewhere;
        # presumably a subclass of this class -- confirm).
        self.__class__ = ServiceProvider
        if parent is None:  # root provider
            self._lock = RLock()
        else:
            self._lock = nullcontext()

    def _get_service_info(self, key) -> IServiceInfo:
        """Return the registered ``IServiceInfo`` for ``key``.

        If ``key`` is not registered, the resolver stored under
        ``Symbols.missing_resolver`` is asked to produce one.
        """
        try:
            return self._services[key]
        except KeyError:
            pass
        # load missing resolver and resolve service info.
        resolver: IServiceInfoResolver = self._services[
            Symbols.missing_resolver].get(self)
        return resolver.get(self, key)

    def _get_service_info_list(self, key) -> List[IServiceInfo]:
        """Return every ``IServiceInfo`` registered under ``key``."""
        return self._services.get_many(key)

    def __getitem__(self, key):
        _logger.debug('get service by key: %r', key)
        service_info = self._get_service_info(key)
        try:
            return service_info.get(self)
        except ServiceNotFoundError as err:
            # Prepend this key so the error carries the full resolve chain;
            # chain the original so its traceback is not lost.
            raise ServiceNotFoundError(key, *err.resolve_chain) from err

    def get(self, key, d=None) -> Any:
        '''
        get a service by key.

        returns `d` when `key` itself is not found (resolve chain of length 1);
        a missing *dependency* of `key` still raises `ServiceNotFoundError`.
        '''
        try:
            return self[key]
        except ServiceNotFoundError as err:
            if len(err.resolve_chain) == 1:
                return d
            raise

    def get_many(self, key) -> List[Any]:
        '''
        get services by key.

        ### example

        when you registered multi services with the same key,
        you can get them all:

        ``` py
        provider.register_value('a', 1)
        provider.register_value('a', 2)
        assert provider.get_many('a') == [2, 1] # rev order
        ```
        '''
        _logger.debug('get services by key: %r', key)
        service_infos = self._get_service_info_list(key)
        try:
            return [si.get(self) for si in service_infos]
        except ServiceNotFoundError as err:
            raise ServiceNotFoundError(key, *err.resolve_chain) from err

    def enter(self, context):
        '''
        enter the context.

        returns the result of the `context.__enter__()` method.
        the context is exited when this provider's `__exit__` runs.
        '''
        with self._lock:
            if self._exit_stack is None:
                self._exit_stack = ExitStack()
            return self._exit_stack.enter_context(context)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        with self._lock:
            if self._exit_stack is not None:
                self._exit_stack.__exit__(*args)
                self._exit_stack = None

    def register_service_info(self, key, service_info: IServiceInfo):
        '''
        register a `IServiceInfo` by key.

        raises `TypeError` if `service_info` is not an `IServiceInfo`.
        '''
        if not isinstance(service_info, IServiceInfo):
            raise TypeError('service_info must be instance of IServiceInfo.')
        _logger.debug('register %r with key %r', service_info, key)
        self._services[key] = service_info

    def register(self, key, factory, lifetime):
        '''
        register a service factory by key with the given `LifeTime`.

        `factory` is a callable taking zero or one argument;
        if it takes one, the `IServiceProvider` is passed as that argument.
        '''
        return self.register_service_info(
            key, ServiceInfo(self, key, factory, lifetime))

    def register_singleton(self, key, factory):
        '''
        register a service factory by key with `LifeTime.singleton`.

        `factory` is a callable taking zero or one argument;
        if it takes one, the `IServiceProvider` is passed as that argument.
        '''
        return self.register(key, factory, LifeTime.singleton)

    def register_scoped(self, key, factory):
        '''
        register a service factory by key with `LifeTime.scoped`.

        `factory` is a callable taking zero or one argument;
        if it takes one, the `IServiceProvider` is passed as that argument.
        '''
        return self.register(key, factory, LifeTime.scoped)

    def register_transient(self, key, factory):
        '''
        register a service factory by key with `LifeTime.transient`.

        `factory` is a callable taking zero or one argument;
        if it takes one, the `IServiceProvider` is passed as that argument.
        '''
        return self.register(key, factory, LifeTime.transient)

    def register_value(self, key, value):
        '''
        register a value by key.

        equals `register_transient(key, lambda ioc: value)`
        '''
        return self.register_service_info(key, ValueServiceInfo(value))

    def register_group(self, key, keys: list):
        '''
        register a grouped `key` for get other `keys`.

        the `keys` can be a ref and you can update it later.

        for example:

        ``` py
        provider.register_value('str', 'name')
        provider.register_value('int', 1)
        provider.register_group('any', ['str', 'int'])
        assert provider['any'] == ('name', 1)
        ```

        equals `register_transient(key, lambda ioc: tuple(ioc[k] for k in keys))`
        '''
        return self.register_service_info(key, GroupedServiceInfo(keys))

    def register_bind(self, new_key, target_key):
        '''
        bind `new_key` to `target_key` so
        you can use `new_key` as key to get value from service provider.

        equals `register_transient(new_key, lambda ioc: ioc[target_key])`
        '''
        return self.register_service_info(new_key, BindedServiceInfo(target_key))

    def scope(self):
        '''
        create a scoped service provider.

        the new scope is entered into this provider's exit stack, so it is
        exited when this provider exits.
        NOTE(review): every scope stays referenced by that exit stack until
        then, so creating many scopes on a long-lived root grows memory.
        '''
        sp = ScopedServiceProvider(self._services.scope(), self)
        self.enter(sp)
        return sp

    @property
    def builder(self):
        '''
        get a new `ServiceProviderBuilder` wrapper for this `ServiceProvider`.
        '''
        from .builder import ServiceProviderBuilder
        return ServiceProviderBuilder(self)
class RunAsExternal:
    """Run a task class in a child Python process and check its behavior.

    The generated script consists of the ``basekls`` import line, the source
    of ``task_kls``, a fixed preamble and a line that runs the task. The child
    receives the output-file path as ``argv[1]`` and this process's pid as
    ``argv[2]``, and its ``notify()`` sends SIGUSR1 back to this process.

    Usage: ``with RunAsExternal(Task) as run: await run(code=0, ...)``.
    """

    def __init__(
        self,
        task_kls,
        basekls="from photons_app.tasks.tasks import Task",
    ):
        self.script = (
            basekls
            + "\n"
            + dedent(inspect.getsource(task_kls))
            + "\n"
            + dedent(
                """
            from photons_app.errors import ApplicationCancelled, ApplicationStopped, UserQuit, PhotonsAppError
            from photons_app.tasks.runner import Runner
            from photons_app.collector import Collector
            from photons_app import helpers as hp
            import asyncio
            import signal
            import sys
            import os

            def notify():
                os.kill(int(sys.argv[2]), signal.SIGUSR1)

            collector = Collector()
            collector.prepare(None, {})
            """
            )
            + "\n"
            + f"{task_kls.__name__}.create(collector).run_loop(collector=collector, notify=notify, output=sys.argv[1])"
        )

    def __enter__(self):
        """Write the script to a temp file, reserve an output path, return ``self.run``."""
        # Build on a local stack so a failure part-way through still removes
        # the temp files; __exit__ is not called when __enter__ raises.
        with ExitStack() as stack:
            self.fle = stack.enter_context(hp.a_temp_file())
            self.fle.write(self.script.encode())
            self.fle.flush()

            self.out = stack.enter_context(hp.a_temp_file())
            # Only the unique name is kept; the file itself is removed so
            # whether the child creates it can be detected in run().
            self.out.close()
            os.remove(self.out.name)

            # Success: hand the cleanup callbacks over to __exit__.
            self.exit_stack = stack.pop_all()
        return self.run

    def __exit__(self, exc_typ, exc, tb):
        if hasattr(self, "exit_stack"):
            self.exit_stack.__exit__(exc_typ, exc, tb)

    async def run(
        self,
        sig=None,
        expected_stdout=None,
        expected_stderr=None,
        expected_output=None,
        extra_argv=None,
        *,
        code,
    ):
        """Start the script, wait for its ``notify()``, optionally signal it, then check results.

        :param sig: signal sent to the child once it has notified.
        :param expected_stdout: multi-line regex the stdout must match line-for-line.
        :param expected_stderr: like ``expected_stdout``, checked against the
            stderr with traceback bodies redacted.
        :param expected_output: expected (stripped) contents of the output file.
        :param extra_argv: extra command-line arguments for the child.
        :param code: expected child exit code.
        :returns: ``(stdout, redacted_stderr)``.
        :raises subprocess.TimeoutExpired: if the child does not finish within
            1 second after notifying (the child is killed first).

        NOTE(review): if the child never calls ``notify()`` this waits forever.
        """
        fut = hp.create_future()

        def resolve():
            # Ignore repeated notifications instead of raising InvalidStateError.
            if not fut.done():
                fut.set_result(True)

        def ready(signum, frame):
            hp.get_event_loop().call_soon_threadsafe(resolve)

        cmd = [
            sys.executable,
            self.fle.name,
            self.out.name,
            str(os.getpid()),
            *[str(a) for a in (extra_argv or [])],
        ]

        previous_handler = signal.signal(signal.SIGUSR1, ready)
        p = None
        try:
            p = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            await fut

            if sig is not None:
                p.send_signal(sig)

            # communicate() drains both pipes while waiting, so a chatty child
            # cannot block on a full pipe buffer; it also closes the pipes.
            raw_stdout, raw_stderr = p.communicate(timeout=1)
        finally:
            if p is not None and p.poll() is None:
                # Timed out or cancelled: don't leave the child running (it
                # could otherwise signal us after the handler is restored).
                p.kill()
                p.communicate()
            # signal.signal returns None when the old handler was not set
            # from Python; it cannot be reinstalled in that case.
            if previous_handler is not None:
                signal.signal(signal.SIGUSR1, previous_handler)

        got_stdout = raw_stdout.decode()
        got_stderr = self._redact_stderr(raw_stderr.decode())

        print("=== STDOUT:")
        print(got_stdout)
        print("=== STDERR:")
        print(got_stderr)

        if expected_output is not None:
            got = None
            if os.path.exists(self.out.name):
                with open(self.out.name) as o:
                    got = o.read().strip()
            assert got == expected_output.strip()

        self._assert_output(got_stdout.strip(), expected_stdout)
        self._assert_output(got_stderr.strip(), expected_stderr)

        assert p.returncode == code

        return got_stdout, got_stderr

    @staticmethod
    def _redact_stderr(stderr):
        """Replace traceback bodies in ``stderr`` with a ``<REDACTED>`` marker.

        The "Traceback" header and the final exception line are kept. When a
        chained-exception separator is seen, the 5 preceding kept lines and
        the line after the separator are dropped, so only the last traceback
        remains.
        """
        redacted = []
        in_tb = False
        last_line = None
        for line in stderr.split("\n"):
            if last_line == "During handling of the above exception, another exception occurred:":
                redacted = redacted[:-5]
                last_line = line
                continue

            # An unindented line ends the indented frame listing.
            if in_tb and not line.startswith(" "):
                in_tb = False

            if line.startswith("Traceback"):
                in_tb = True
                redacted.append(line)
                redacted.append(" <REDACTED>")

            if not in_tb:
                redacted.append(line)

            last_line = line
        return "\n".join(redacted)

    @staticmethod
    def _assert_output(out, regex):
        """Assert ``out`` has as many lines as ``regex`` and matches it (multiline mode).

        Does nothing when ``regex`` is None.
        """
        if regex is None:
            return
        regex = dedent(regex).strip()
        out = out.strip()
        assert len(out.split("\n")) == len(regex.split("\n"))
        pytest.helpers.assertRegex(f"(?m){regex}", out)
class _ManagedPkgData(ContextDecorator): _section_re = re.compile( r'(?P<module>[A-Za-z_][A-Za-z0-9_]+[.][A-Za-z_][A-Za-z0-9_]+)' r'([:](?P<label>[A-Za-z_][A-Za-z0-9_-]))?') def __init__(self): self._stack = None self._tmp_dir = None def __enter__(self): self._stack = ExitStack() self._stack.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): self._stack.__exit__(exc_type, exc_val, exc_tb) self._tmp_dir = None self._stack = None def get_tmp_dir(self): if self._tmp_dir is None: self._tmp_dir = self._stack.enter_context( TemporaryDirectory(prefix=".histoqc_pkg_data_tmp", dir=os.getcwd())) for rsrc in {'models', 'pen', 'templates'}: package_resource_copytree('histoqc.data', rsrc, self._tmp_dir) return self._tmp_dir def inject_pkg_data_fallback(self, config: ConfigParser): """support template data packaged in module""" for section in config.sections(): m = self._section_re.match(section) if not m: continue # dispatch mf = m.group('module') s = config[section] if mf == 'HistogramModule.compareToTemplates': self._inject_HistogramModule_compareToTemplates(s) elif mf == 'ClassificationModule.byExampleWithFeatures': self._inject_ClassificationModule_byExampleWithFeatures(s) else: pass def _inject_HistogramModule_compareToTemplates(self, section): # replace example files in a compareToTemplates config section # with the histoqc package data examples if available if 'templates' in section: _templates = [] for template in map(str.strip, section['templates'].split('\n')): f_template = os.path.join(os.getcwd(), template) if not os.path.isfile(f_template): tmp = self.get_tmp_dir() f_template_pkg_data = os.path.join(tmp, template) if os.path.isfile(f_template_pkg_data): f_template = f_template_pkg_data _templates.append(f_template) section['templates'] = '\n'.join(_templates) def _inject_ClassificationModule_byExampleWithFeatures(self, section): # replace template files in a byExampleWithFeatures config section # with the histoqc package data templates if available if 
'examples' in section: # mimic the code structure in ClassificationModule.byExampleWithFeatures, # which expects example templates to be specified as pairs. # each template in the pair is separated by a ':' and pairs are separated by a newline. _lines = [] for ex in section["examples"].splitlines(): ex = re.split( r'(?<!\W[A-Za-z]):(?!\\)', ex ) # workaround for windows: don't split on i.e. C:backslash _examples = [] for example in ex: f_example = os.path.join(os.getcwd(), example) if not os.path.isfile(f_example): tmp = self.get_tmp_dir() f_example_pkg_data = os.path.join(tmp, example) if os.path.isfile(f_example_pkg_data): f_example = f_example_pkg_data _examples.append(f_example) _line = ":".join(_examples) _lines.append(_line) section['examples'] = "\n".join(_lines)