class LocalDarRepository:
    """
    A collection of DAR files whose package metadata is registered into a single
    ``PackageStore`` (``self.store``).

    Only ``.dar`` files are supported; ``.daml`` and ``.dalf`` sources are rejected
    with a :class:`ValueError`. Opened DAR files are kept on an internal
    :class:`ExitStack` and released when the repository's context exits.
    """

    def __init__(self):
        self._context = ExitStack()
        self._dar_paths = []  # type: List[Path]
        # Resolved paths already accepted by add_source; used to ignore duplicates.
        self._files = set()
        self.store = PackageStore.empty()

    def _add_source(self, path: Path) -> None:
        """
        Load a single, already-resolved path into the store.

        :param path: The file to load.
        :raises ValueError: if the extension is anything other than ``.dar``.
        """
        ext = path.suffix.lower()
        if ext == '.daml':
            LOG.error('Reading metadata directly from a DAML file is not supported.')
            raise ValueError('Unsupported extension: .daml')
        elif ext == '.dalf':
            LOG.error('Reading metadata directly from a DALF file is not supported.')
            raise ValueError('Unsupported extension: .dalf')
        elif ext == '.dar':
            # perf_counter is monotonic, so it is the right clock for durations.
            dar_parse_start_time = time.perf_counter()
            dar = self._context.enter_context(DarFile(path))
            dar_package = dar.read_metadata()
            self._dar_paths.append(path)
            self.store.register_all(dar_package)
            dar_parse_end_time = time.perf_counter()
            LOG.debug('Parsed a dar in %s seconds.', dar_parse_end_time - dar_parse_start_time)
        else:
            LOG.error('Unknown extension: %s', ext)
            raise ValueError(f'Unknown extension: {ext}')

    def add_source(self, *files: Union[str, Path]) -> None:
        """
        Add one or more ``.dar`` files to the repository.

        Every path is resolved strictly (it must exist). Paths that were already
        added successfully are ignored. If loading a path fails, it is not
        remembered, so a later call can retry it.

        :param files: Files to add to the archive.
        :raises ValueError: if a file does not have a ``.dar`` extension.
        """
        for file in files:
            path = pathify(file).resolve(strict=True)
            if path in self._files:
                continue
            self._files.add(path)
            try:
                self._add_source(path)
            except BaseException:
                # Forget the path so a failed load is not silently skipped next time.
                self._files.discard(path)
                raise

    def get_daml_archives(self) -> Sequence[Path]:
        """Return the paths of the DAR files added so far, in insertion order."""
        return self._dar_paths

    def __enter__(self):
        """
        Enter the repository's resource context and return the repository.
        """
        self._context.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Close all managed resources in the reverse order that they were opened.
        """
        return self._context.__exit__(exc_type, exc_val, exc_tb)
class SyncEventSource:
    """
    Synchronous facade over an asynchronous ``EventSource``.

    Events emitted by the wrapped async source are dispatched to the synchronous
    callbacks registered with :meth:`subscribe`; each callback is run through
    ``run_sync_in_worker_thread``. Entering the context starts a blocking portal
    when none was supplied.
    """

    _subscribers: Dict[Type[Event], List[Callable[[Event], Any]]]

    def __init__(self, async_event_source: EventSource, portal: Optional[BlockingPortal] = None):
        self.portal = portal
        self._async_event_source = async_event_source
        # Subscribe exactly once, here. __enter__ must not subscribe again, or
        # events could be forwarded twice if the async source does not dedupe.
        self._async_event_source.subscribe(self._forward_async_event)
        self._exit_stack = ExitStack()
        self._subscribers = defaultdict(list)

    def __enter__(self):
        """Start a blocking portal if none was given, and return this event source."""
        self._exit_stack.__enter__()
        if not self.portal:
            portal_cm = start_blocking_portal()
            self.portal = self._exit_stack.enter_context(portal_cm)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self._exit_stack.__exit__(exc_type, exc_val, exc_tb)

    async def _forward_async_event(self, event: Event) -> None:
        # Iterate a snapshot: callbacks run in worker threads across awaits and may
        # (un)subscribe, which would otherwise mutate the list being iterated.
        for subscriber in tuple(self._subscribers.get(type(event), ())):
            await run_sync_in_worker_thread(subscriber, event)

    def subscribe(self, callback: Callable[[Event], Any],
                  event_types: Optional[Iterable[Type[Event]]] = None) -> None:
        """
        Register ``callback`` for ``event_types`` (all event types when None).

        Registering the same callback twice for an event type has no effect.
        """
        if event_types is None:
            event_types = _all_event_types

        for event_type in event_types:
            existing_callbacks = self._subscribers[event_type]
            if callback not in existing_callbacks:
                existing_callbacks.append(callback)

    def unsubscribe(
            self, callback: Callable[[Event], Any],
            event_types: Optional[Iterable[Type[Event]]] = None) -> None:
        """
        Remove ``callback`` from ``event_types`` (all event types when None).

        Event types the callback was not registered for are ignored.
        """
        if event_types is None:
            event_types = _all_event_types

        for event_type in event_types:
            existing_callbacks = self._subscribers.get(event_type, [])
            with suppress(ValueError):
                existing_callbacks.remove(callback)
class ReposTestCase(AsyncTestCase):
    """
    Test case base that writes ten random repos to a temporary JSON source file.

    Temporary files obtained through :meth:`mktempfile` are registered on an
    ExitStack that is closed in :meth:`tearDown`.
    """

    async def setUp(self):
        super().setUp()
        self.repos = [Repo(str(random.random())) for _ in range(0, 10)]
        self.__stack = ExitStack()
        # ExitStack.__enter__ returns the stack itself; ``_stack`` is the name
        # mktempfile (and possibly subclasses) use.
        self._stack = self.__stack.__enter__()
        self.temp_filename = self.mktempfile()
        self.sconfig = FileSourceConfig(filename=self.temp_filename)
        async with JSONSource(self.sconfig) as source:
            async with source() as sctx:
                for repo in self.repos:
                    await sctx.update(repo)
        contents = json.loads(Path(self.sconfig.filename).read_text())
        # Ensure there are repos in the file. Defaulting to [] makes a missing
        # label fail this assertion instead of raising TypeError from len(None).
        self.assertEqual(
            len(contents.get(self.sconfig.label, [])),
            len(self.repos),
            "ReposTestCase JSON file erroneously initialized as empty",
        )

    def tearDown(self):
        # Release in reverse order of acquisition: the temp files were created
        # after super().setUp(), so close them first; always run the base teardown.
        try:
            self.__stack.__exit__(None, None, None)
        finally:
            super().tearDown()

    def mktempfile(self):
        """Return a name from non_existant_tempfile(); its context exits at teardown."""
        return self._stack.enter_context(non_existant_tempfile())
def __enter__(self):
    """
    Enter every context-manager argument of a wrapped ``functools.partial``.

    If ``self.func`` is not a partial, ``self`` is returned unchanged. Otherwise
    each positional and keyword argument that has ``__enter__`` is entered on a
    new ExitStack (stored as ``self.__stack``, presumably closed by the class's
    ``__exit__``) and replaced by the value it yields. A new ``Pipe`` wrapping
    the resulting partial is returned.
    """
    stack = ExitStack()
    self.__stack = stack
    if not isinstance(self.func, partial):
        return self

    func, args, kwargs = self.func.func, self.func.args, self.func.keywords
    # ExitStack.__enter__ only returns the stack; kept for symmetry with __exit__.
    stack.__enter__()
    try:
        # Build a list (not a lazy generator) so positional arguments are entered
        # first and in order, before the keyword arguments.
        entered_args = [
            stack.enter_context(arg) if hasattr(arg, '__enter__') else arg
            for arg in args
        ]
        entered_kwargs = {
            kw: stack.enter_context(arg) if hasattr(arg, '__enter__') else arg
            for kw, arg in kwargs.items()
        }
    except BaseException:
        # __exit__ is not called when __enter__ raises, so release whatever was
        # already entered before propagating the error.
        stack.close()
        raise
    return Pipe(partial(func, *entered_args, **entered_kwargs))
class TestOWDataSets(WidgetTest):
    """Tests for OWDataSets with the remote and local dataset listings patched."""

    def setUp(self):
        super().setUp()
        remote = {
            ("a", "b"): {"title": "K", "size": 10},
            ("a", "c"): {"title": "T", "size": 20},
            ("a", "d"): {"title": "Y", "size": 0},
        }
        # Collect the patches on a local stack first. If anything below raises,
        # unittest skips tearDown and the ``with`` block undoes the patches,
        # instead of leaking them into later tests. On success, pop_all() moves
        # them to a new stack that tearDown closes.
        with ExitStack() as stack:
            stack.enter_context(
                mock.patch.object(owdatasets, "list_remote", lambda: remote)
            )
            stack.enter_context(
                mock.patch.object(owdatasets, "list_local", lambda: {})
            )
            self.widget = self.create_widget(
                OWDataSets,
                stored_settings={
                    "selected_id": ("a", "c"),
                    "auto_commit": False,
                }
            )  # type: OWDataSets
            self.exit = stack.pop_all()

    def tearDown(self):
        super().tearDown()
        self.exit.__exit__(None, None, None)

    def test_init(self):
        # Wait (up to 1s) for any pending background loading to finish.
        if self.widget.isBlocking():
            spy = QSignalSpy(self.widget.blockingStateChanged)
            assert spy.wait(1000)
        self.assertFalse(self.widget.isBlocking())

        model = self.widget.view.model()
        self.assertEqual(model.rowCount(), 3)

        di = self.widget.selected_dataset()
        self.assertEqual((di.prefix, di.filename), ("a", "c"))
class SourceExtractor(object):
    """A class to extract a source package to its constituent parts"""

    def __init__(self, dsc_path, dsc):
        self.dsc_path = dsc_path
        self.dsc = dsc
        self.extracted_upstream = None
        self.extracted_debianised = None
        self.unextracted_debian_md5 = None
        self.upstream_tarballs = []
        # Resources acquired by extract() should be registered here so they are
        # released when the context exits.
        self.exit_stack = ExitStack()

    def extract(self):
        """Extract the package to a new temporary directory."""
        raise NotImplementedError(self.extract)

    def __enter__(self):
        """Run extract() and return self, cleaning up if extraction fails."""
        self.exit_stack.__enter__()
        try:
            self.extract()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so release whatever
            # extract() already registered before propagating the error.
            self.exit_stack.close()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
        return False
class patch_auth:  # pylint: disable=invalid-name
    """
    Patch Google auth and gcp_pilot location lookups with fake credentials.

    Usable as a context manager, as a decorator, or via ``start()``/``stop()``.
    Patches are applied only while active (not at construction time). Fresh
    patchers are started on every entry, so a decorated function is patched on
    every call, not just the first one.
    """

    def __init__(self, project_id: str = "potato-dev", location: str = "moon-dark1", email: str = "*****@*****.**"):
        # Realistic: actual class to be accepted by clients during validation
        # But fake: with as few attributes as possible, any API call using the credential should fail
        self.credentials = Credentials(
            service_account_email=email,
            signer=None,
            token_uri="",
            project_id=project_id,
        )
        self.project_id = project_id
        self.location = location
        # Stack of the innermost activation (kept as a public attribute for
        # compatibility); _active holds one stack per nested activation.
        self.stack = ExitStack()
        self._active = []

    def _managers(self):
        """Build fresh patchers for a single activation."""
        return [
            patch("google.auth.default", return_value=(self.credentials, self.project_id)),
            patch("gcp_pilot.base.GoogleCloudPilotAPI._set_location", return_value=self.location),
            patch("gcp_pilot.base.AppEngineBasedService._set_location", return_value=self.location),
        ]

    def __enter__(self):
        with ExitStack() as pending:
            for mgr in self._managers():
                pending.enter_context(mgr)
            # All patches started; detach them so the with-block does not undo them.
            active = pending.pop_all()
        self._active.append(active)
        self.stack = active
        return active

    def start(self):
        return self.__enter__()

    def __exit__(self, typ, val, traceback):
        if not self._active:
            # Not active (e.g. stop() without start()): nothing to undo.
            return False
        active = self._active.pop()
        if self._active:
            self.stack = self._active[-1]
        return active.__exit__(typ, val, traceback)

    def stop(self):
        self.__exit__(None, None, None)

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kw):
            with self:
                return func(*args, **kw)

        return wrapper
class _ManagedPkgData(ContextDecorator): _section_re = re.compile( r'(?P<module>[A-Za-z_][A-Za-z0-9_]+[.][A-Za-z_][A-Za-z0-9_]+)' r'([:](?P<label>[A-Za-z_][A-Za-z0-9_-]))?') def __init__(self): self._stack = None self._tmp_dir = None def __enter__(self): self._stack = ExitStack() self._stack.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): self._stack.__exit__(exc_type, exc_val, exc_tb) self._tmp_dir = None self._stack = None def get_tmp_dir(self): if self._tmp_dir is None: self._tmp_dir = self._stack.enter_context( TemporaryDirectory(prefix=".histoqc_pkg_data_tmp", dir=os.getcwd())) for rsrc in {'models', 'pen', 'templates'}: package_resource_copytree('histoqc.data', rsrc, self._tmp_dir) return self._tmp_dir def inject_pkg_data_fallback(self, config: ConfigParser): """support template data packaged in module""" for section in config.sections(): m = self._section_re.match(section) if not m: continue # dispatch mf = m.group('module') s = config[section] if mf == 'HistogramModule.compareToTemplates': self._inject_HistogramModule_compareToTemplates(s) elif mf == 'ClassificationModule.byExampleWithFeatures': self._inject_ClassificationModule_byExampleWithFeatures(s) else: pass def _inject_HistogramModule_compareToTemplates(self, section): # replace example files in a compareToTemplates config section # with the histoqc package data examples if available if 'templates' in section: _templates = [] for template in map(str.strip, section['templates'].split('\n')): f_template = os.path.join(os.getcwd(), template) if not os.path.isfile(f_template): tmp = self.get_tmp_dir() f_template_pkg_data = os.path.join(tmp, template) if os.path.isfile(f_template_pkg_data): f_template = f_template_pkg_data _templates.append(f_template) section['templates'] = '\n'.join(_templates) def _inject_ClassificationModule_byExampleWithFeatures(self, section): # replace template files in a byExampleWithFeatures config section # with the histoqc package data templates if available if 
'examples' in section: # mimic the code structure in ClassificationModule.byExampleWithFeatures, # which expects example templates to be specified as pairs. # each template in the pair is separated by a ':' and pairs are separated by a newline. _lines = [] for ex in section["examples"].splitlines(): ex = re.split( r'(?<!\W[A-Za-z]):(?!\\)', ex ) # workaround for windows: don't split on i.e. C:backslash _examples = [] for example in ex: f_example = os.path.join(os.getcwd(), example) if not os.path.isfile(f_example): tmp = self.get_tmp_dir() f_example_pkg_data = os.path.join(tmp, example) if os.path.isfile(f_example_pkg_data): f_example = f_example_pkg_data _examples.append(f_example) _line = ":".join(_examples) _lines.append(_line) section['examples'] = "\n".join(_lines)
class RunAsExternal:
    """
    Run a task class in a separate Python process and check its results.

    The constructor builds a standalone script from ``basekls``, the source of
    ``task_kls`` and a small bootstrap that creates a ``Collector`` and calls
    ``<task_kls>.create(collector).run_loop(...)``. Entering the context writes
    the script to a temporary file, reserves a temporary output path, and
    returns the :meth:`run` coroutine function.
    """

    def __init__(
        self,
        task_kls,
        basekls="from photons_app.tasks.tasks import Task",
    ):
        self.script = (
            basekls
            + "\n"
            + dedent(inspect.getsource(task_kls))
            + "\n"
            + dedent(
                """
                from photons_app.errors import ApplicationCancelled, ApplicationStopped, UserQuit, PhotonsAppError
                from photons_app.tasks.runner import Runner
                from photons_app.collector import Collector
                from photons_app import helpers as hp

                import asyncio
                import signal
                import sys
                import os

                def notify():
                    os.kill(int(sys.argv[2]), signal.SIGUSR1)

                collector = Collector()
                collector.prepare(None, {})
                """
            )
            + "\n"
            + f"{task_kls.__name__}.create(collector).run_loop(collector=collector, notify=notify, output=sys.argv[1])"
        )

    def __enter__(self):
        self.exit_stack = ExitStack()
        self.exit_stack.__enter__()
        try:
            self.fle = self.exit_stack.enter_context(hp.a_temp_file())
            self.fle.write(self.script.encode())
            self.fle.flush()

            # Only a unique file name is wanted for the task's output. Remove the
            # file so run() can tell whether the child created it.
            self.out = self.exit_stack.enter_context(hp.a_temp_file())
            self.out.close()
            os.remove(self.out.name)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so clean up here.
            self.exit_stack.close()
            raise
        return self.run

    def __exit__(self, exc_typ, exc, tb):
        if hasattr(self, "exit_stack"):
            self.exit_stack.__exit__(exc_typ, exc, tb)

    async def run(
        self,
        sig=None,
        expected_stdout=None,
        expected_stderr=None,
        expected_output=None,
        extra_argv=None,
        *,
        code,
    ):
        """
        Start the script, wait until it signals readiness, then check its results.

        :param sig: optional signal sent to the child once it reports ready.
        :param expected_stdout: regex the child's stdout must match, or None.
        :param expected_stderr: regex the (traceback-redacted) stderr must match, or None.
        :param expected_output: expected contents of the output file, or None.
        :param extra_argv: extra command-line arguments for the child.
        :param code: expected exit code of the child.
        :returns: ``(stdout, redacted_stderr)``.
        :raises subprocess.TimeoutExpired: if the child does not exit within 1s.
        """
        fut = hp.create_future()

        def resolve():
            # The child may signal more than once; only the first signal counts.
            if not fut.done():
                fut.set_result(True)

        def ready(signum, frame):
            hp.get_event_loop().call_soon_threadsafe(resolve)

        pipe = subprocess.PIPE
        cmd = [
            sys.executable,
            self.fle.name,
            self.out.name,
            str(os.getpid()),
            *[str(a) for a in (extra_argv or [])],
        ]

        previous_handler = signal.signal(signal.SIGUSR1, ready)
        try:
            p = subprocess.Popen(
                cmd,
                stdout=pipe,
                stderr=pipe,
            )
            try:
                await fut

                if sig is not None:
                    p.send_signal(sig)

                # communicate() drains both pipes while waiting; wait() followed by
                # read() deadlocks if the child fills a pipe buffer.
                raw_stdout, raw_stderr = p.communicate(timeout=1)
            finally:
                if p.poll() is None:
                    # Do not leave a stray child behind (timeout, cancellation, ...).
                    p.kill()
                    p.communicate()
        finally:
            if previous_handler is not None:
                signal.signal(signal.SIGUSR1, previous_handler)

        got_stdout = raw_stdout.decode()
        got_stderr = raw_stderr.decode().split("\n")

        # Replace traceback bodies with a marker so expectations do not depend on
        # file paths and line numbers.
        redacted_stderr = []
        in_tb = False

        last_line = None
        for line in got_stderr:
            if last_line == "During handling of the above exception, another exception occurred:":
                redacted_stderr = redacted_stderr[:-5]
                last_line = line
                continue

            if in_tb and not line.startswith(" "):
                in_tb = False

            if line.startswith("Traceback"):
                in_tb = True
                redacted_stderr.append(line)
                redacted_stderr.append(" <REDACTED>")

            if not in_tb:
                redacted_stderr.append(line)

            last_line = line

        got_stderr = "\n".join(redacted_stderr)

        print("=== STDOUT:")
        print(got_stdout)
        print("=== STDERR:")
        print(got_stderr)

        def assertOutput(out, regex):
            if regex is None:
                return

            regex = dedent(regex).strip()
            out = out.strip()
            assert len(out.split("\n")) == len(regex.split("\n"))
            pytest.helpers.assertRegex(f"(?m){regex}", out)

        if expected_output is not None:
            got = None
            if os.path.exists(self.out.name):
                with open(self.out.name) as o:
                    got = o.read().strip()
            assert got == expected_output.strip()

        assertOutput(got_stdout.strip(), expected_stdout)
        assertOutput(got_stderr.strip(), expected_stderr)
        assert p.returncode == code

        return got_stdout, got_stderr
class _MeshManager: def __init__(self, report=None): self.context_stack = ExitStack() if report is not None: self._report = report self._entered = False self._overrides = {} @staticmethod def add_progress_presteps(report): report.progress_add_step("Applying Blender Mods") def _build_prop_dict(self, bstruct): props = {} for i in bstruct.bl_rna.properties: ident = i.identifier if ident == "rna_type": continue props[ident] = getattr(bstruct, ident) if getattr( i, "array_length", 0) == 0 else tuple(getattr(bstruct, ident)) return props def __enter__(self): assert self._entered is False, "_MeshManager is not reentrant" self._entered = True self.context_stack.__enter__() scene = bpy.context.scene self._report.progress_advance() self._report.progress_range = len(scene.objects) # Some modifiers like "Array" will procedurally generate new geometry that will impact # lightmap generation. The Blender Internal renderer does not seem to be smart enough to # take this into account. Thus, we temporarily apply modifiers to ALL meshes (even ones that # are not exported) such that we can generate proper lighting. mesh_type = bpy.types.Mesh for i in scene.objects: if isinstance(i.data, mesh_type) and i.is_modified( scene, "RENDER"): # Remember, storing actual pointers to the Blender objects can cause bad things to # happen because Blender's memory management SUCKS! self._overrides[i.name] = { "mesh": i.data.name, "modifiers": [] } i.data = i.to_mesh(scene, True, "RENDER", calc_tessface=False) # If the modifiers are left on the object, the lightmap bake can break under some # situations. Therefore, we now cache the modifiers and clear them away... 
if i.plasma_object.enabled: cache_mods = self._overrides[i.name]["modifiers"] for mod in i.modifiers: cache_mods.append(self._build_prop_dict(mod)) i.modifiers.clear() self._report.progress_increment() return self def __exit__(self, *exc_info): try: self.context_stack.__exit__(*exc_info) finally: data_bos, data_meshes = bpy.data.objects, bpy.data.meshes for obj_name, override in self._overrides.items(): bo = data_bos.get(obj_name) # Reapply the old mesh trash_mesh, bo.data = bo.data, data_meshes.get( override["mesh"]) data_meshes.remove(trash_mesh) # If modifiers were removed, reapply them now unless they're read-only. readonly_attributes = { ("DECIMATE", "face_count"), } for cached_mod in override["modifiers"]: mod = bo.modifiers.new(cached_mod["name"], cached_mod["type"]) for key, value in cached_mod.items(): if key in { "name", "type" } or (cached_mod["type"], key) in readonly_attributes: continue setattr(mod, key, value) self._entered = False def is_collapsed(self, bo) -> bool: return bo.name in self._overrides
class Application(object):
    """
    Command-line application wrapper: used as a context manager, it loads site
    configuration and plugins, parses arguments, creates and enters a Session,
    and records the exit code / interruption state on exit.
    """

    def __init__(self):
        super(Application, self).__init__()
        self._test_loader = Loader()
        self.set_report_stream(sys.stderr)
        self._argv = None
        self._reset_parser()
        self._positional_args = None
        self._parsed_args = None
        self._reporter = None
        self.test_loader = Loader()
        self.session = None
        self._working_directory = None
        self._interrupted = False
        self._exit_code = 0
        self._prelude_warning_records = []

    def set_working_directory(self, path):
        self._working_directory = path

    @property
    def exit_code(self):
        return self._exit_code

    def set_exit_code(self, exit_code):
        self._exit_code = exit_code

    @property
    def interrupted(self):
        return self._interrupted

    @property
    def positional_args(self):
        return self._positional_args

    @property
    def parsed_args(self):
        return self._parsed_args

    def enable_interactive(self):
        self.arg_parser.add_argument(
            '-i', '--interactive', help='Enter an interactive shell',
            action="store_true", default=False)

    def _reset_parser(self):
        self.arg_parser = cli_utils.SlashArgumentParser()

    def set_argv(self, argv):
        self._argv = list(argv)

    def _get_argv(self):
        """Return a copy of the configured argv, or ``sys.argv[1:]`` if none was set."""
        if self._argv is None:
            return sys.argv[1:]
        return self._argv[:]

    def set_report_stream(self, stream):
        if stream is not None:
            self._report_stream = stream
            self._default_reporter = ConsoleReporter(level=logbook.ERROR, stream=stream)
            self._console_handler = ConsoleHandler(stream=stream, level=logbook.ERROR)

    def set_reporter(self, reporter):
        self._reporter = reporter

    def get_reporter(self):
        """Return the explicitly set reporter, or a new console reporter."""
        returned = self._reporter
        if returned is None:
            returned = ConsoleReporter(
                level=config.root.log.console_level,
                stream=self._report_stream)

        return returned

    def __enter__(self):
        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()
        try:
            self._exit_stack.enter_context(self._prelude_logging_context())
            self._exit_stack.enter_context(self._prelude_warning_context())
            self._exit_stack.enter_context(self._sigterm_context())
            with dessert.rewrite_assertions_context():
                site.load(working_directory=self._working_directory)

            cli_utils.configure_arg_parser_by_plugins(self.arg_parser)
            cli_utils.configure_arg_parser_by_config(self.arg_parser)
            argv = cli_utils.add_pending_plugins_from_commandline(self._get_argv())

            self._parsed_args, self._positional_args = self.arg_parser.parse_known_args(argv)

            self._exit_stack.enter_context(
                cli_utils.get_modified_configuration_from_args_context(self.arg_parser, self._parsed_args)
            )

            self.session = Session(reporter=self.get_reporter(), console_stream=self._report_stream)

            trigger_hook.configure()  # pylint: disable=no-member
            plugins.manager.configure_for_parallel_mode()
            plugins.manager.activate_pending_plugins()
            cli_utils.configure_plugins_from_args(self._parsed_args)

            self._exit_stack.enter_context(self.session)
            self._emit_prelude_logs()
            self._emit_prelude_warnings()
            return self

        except BaseException:
            # Explicit BaseException (same as the bare except it replaces): also
            # covers KeyboardInterrupt/SystemExit. __exit__ is not called by the
            # with-statement when __enter__ raises, so run it here, then re-raise.
            self._emit_prelude_logs()
            self.__exit__(*sys.exc_info())
            raise

    def __exit__(self, exc_type, exc_value, exc_tb):
        exc_info = (exc_type, exc_value, exc_tb)
        try:
            debug_if_needed(exc_info)
        except Exception as e:  # pylint: disable=broad-except
            _logger.error("Failed to debug_if_needed: {!r}", e, exc_info=True, extra={'capture': False})
        if exc_value is not None:
            self._exit_code = exc_value.code if isinstance(exc_value, SystemExit) else 1

        if should_inhibit_unhandled_exception_traceback(exc_value):
            self.get_reporter().report_error_message(str(exc_value))

        elif isinstance(exc_value, Exception):
            _logger.error('Unexpected error occurred', exc_info=exc_info, extra={'capture': False})
            self.get_reporter().report_error_message('Unexpected error: {}'.format(exc_value))

        if isinstance(exc_value, exceptions.INTERRUPTION_EXCEPTIONS):
            self._interrupted = True

        if exc_type is not None:
            trigger_hook.result_summary()  # pylint: disable=no-member
        self._exit_stack.__exit__(exc_type, exc_value, exc_tb)
        self._exit_stack = None
        self._reset_parser()
        trigger_hook.app_quit()  # pylint: disable=no-member
        # The exception (if any) has been reported and reflected in exit_code.
        return True

    def _capture_native_warning(self, message, category, filename, lineno, file=None, line=None):  # pylint: disable=unused-argument
        self._prelude_warning_records.append(RecordedWarning.from_native_warning(message, category, filename, lineno))

    def _prelude_logging_context(self):
        """Retain log records emitted before the session exists."""
        self._prelude_log_handler = log.RetainedLogHandler(bubble=True, level=logbook.TRACE)
        return self._prelude_log_handler.applicationbound()

    def _prelude_warning_context(self):
        """Record native warnings emitted before the session exists."""
        capture_all_warnings()
        return warning_callback_context(self._capture_native_warning)

    def _emit_prelude_warnings(self):
        if self.session is not None:
            for warning in self._prelude_warning_records:
                if not self.session.warnings.warning_should_be_filtered(warning):
                    self.session.warnings.add(warning)

    def _emit_prelude_logs(self):
        """Flush retained prelude logs to the session log handler, or the console handler."""
        self._prelude_log_handler.disable()
        handler = None
        if self.session is not None:
            handler = self.session.logging.session_log_handler
        if handler is None:
            handler = self._console_handler
        self._prelude_log_handler.flush_to_handler(handler)

    @contextmanager
    def _sigterm_context(self):
        """Turn SIGTERM into a TerminatedException while active; restore the old handler after."""
        def handle_sigterm(*_):
            with handling_exceptions():
                raise TerminatedException('Terminated by signal')
        prev = signal.signal(signal.SIGTERM, handle_sigterm)
        try:
            yield
        finally:
            try:
                signal.signal(signal.SIGTERM, prev)
            except TypeError as e:
                # workaround for a strange issue on app cleanup. See https://bugs.python.org/issue23548
                if 'signal handler must be signal.SIG_IGN' not in str(e):
                    raise
class AbstractWebcamFilterApp(ABC):
    """
    Base class for webcam filter apps: reads frames from an image source, passes
    each through :meth:`get_output_image`, and writes the result to an output sink.

    Use as a context manager; the source and sink are entered on an ExitStack
    and closed on exit.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.bodypix_model = None
        self.output_sink = None
        self.image_source = None
        self.image_iterator = None
        self.timer = LoggingTimer()
        self.masks: List[np.ndarray] = []
        self.exit_stack = ExitStack()
        self.bodypix_result_cache_time = None
        self.bodypix_result_cache = None

    @abstractmethod
    def get_output_image(self, image_array: np.ndarray) -> np.ndarray:
        pass

    def get_mask(self, *args, **kwargs):
        return get_mask(
            *args, masks=self.masks, timer=self.timer, args=self.args, **kwargs
        )

    def get_bodypix_result(self, image_array: np.ndarray) -> BodyPixResultWrapper:
        """
        Return the model prediction for ``image_array``, reusing the previous result
        while it is younger than ``args.mask_cache_time`` seconds.
        """
        assert self.bodypix_model is not None
        current_time = time()
        if (
            self.bodypix_result_cache is not None
            and current_time < self.bodypix_result_cache_time + self.args.mask_cache_time
        ):
            return self.bodypix_result_cache
        self.bodypix_result_cache = self.bodypix_model.predict_single(image_array)
        self.bodypix_result_cache_time = current_time
        return self.bodypix_result_cache

    def __enter__(self):
        self.exit_stack.__enter__()
        try:
            self.bodypix_model = load_bodypix_model(self.args)
            self.output_sink = self.exit_stack.enter_context(get_output_sink(self.args))
            self.image_source = self.exit_stack.enter_context(get_image_source_for_args(self.args))
            self.image_iterator = iter(self.image_source)
        except BaseException:
            # __exit__ is not called when __enter__ raises: close the output sink
            # (and anything else already entered) before propagating.
            self.exit_stack.close()
            raise
        return self

    def __exit__(self, *args, **kwargs):
        return self.exit_stack.__exit__(*args, **kwargs)

    def next_frame(self):
        """Process one frame; return False once the image source is exhausted."""
        self.timer.on_frame_start(initial_step_name='in')
        try:
            image_array = next(self.image_iterator)
        except StopIteration:
            return False
        self.timer.on_step_start('model')
        output_image = self.get_output_image(image_array)
        self.timer.on_step_start('out')
        self.output_sink(output_image)
        self.timer.on_frame_end()
        return True

    def run(self):
        """Process frames until exhausted; if showing output, wait for the window to close."""
        try:
            self.timer.start()
            while self.next_frame():
                pass
            if self.args.show_output:
                LOGGER.info('waiting for window to be closed')
                while not self.output_sink.is_closed:
                    sleep(0.5)
        except KeyboardInterrupt:
            LOGGER.info('exiting')
class Monitor:
    """
    A process monitor.

    Spawns local or ssh commands. For background or non-verbose commands, output
    is captured to log files under ``<output_dir>/monitor_logs``, each starting
    with the command line. Background processes are registered on an ExitStack,
    and their Popen context is exited when the monitor's context exits.
    """

    def __init__(self, config: Config):
        self._stack = ExitStack()
        self._dir = Path(config.output_dir) / "monitor_logs"
        self._dir.mkdir(parents=True, exist_ok=True)
        self._verbose = config.verbose

    def __enter__(self):
        self._stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._stack.__exit__(exc_type, exc_value, traceback)

    def get_args(self, proc):
        return proc.args

    def check_running(self, proc) -> bool:
        return proc.poll() is None

    def wait(self, proc):
        proc.wait()

    def wait_all(self, procs):
        for proc in procs:
            self.wait(proc)

    def get_return_code(self, proc):
        return proc.returncode

    def check_success(self, proc):
        """Wait for ``proc``; raise RuntimeError if it exited with a non-zero status."""
        proc.wait()
        if proc.returncode != 0:
            raise RuntimeError(
                f"Process failed with status {proc.returncode}, args: {proc.args}"
            )

    def check_success_all(self, procs):
        for proc in procs:
            self.check_success(proc)

    def _open_log(self, name, suffix, header):
        """Create a persistent log file in the monitor directory that starts with ``header``."""
        f = NamedTemporaryFile(prefix=f"{name}.", suffix=suffix, dir=self._dir, delete=False)
        f.write(header.encode())
        # Flush before the child inherits the descriptor. Otherwise the buffered
        # header is written later, after (or interleaved with) the child's output.
        f.flush()
        return f

    def _spawn(self, command, name, bg=False, check=True, stdout_override=None, stderr_override=None) -> subprocess.Popen:
        if self._verbose:
            logging.info("Monitor - spawning %s. %s%s", command,
                         "[bg]" if bg else "", "[chk]" if check else "")

        stdin = stdout = stderr = None
        log_files = []
        if bg or not self._verbose:
            stdin = subprocess.DEVNULL
            # Don't create temp files if we are overriding
            command_str = ' '.join(map(str, command)) + "\n"
            if stdout_override is None:
                stdout = self._open_log(name, ".stdout", command_str)
                log_files.append(stdout)
            if stderr_override is None:
                stderr = self._open_log(name, ".stderr", command_str)
                log_files.append(stderr)

        if stdout_override is not None:
            stdout = stdout_override
        if stderr_override is not None:
            stderr = stderr_override

        try:
            proc = subprocess.Popen(command, stdin=stdin, stdout=stdout, stderr=stderr)
        finally:
            # The child has its own copies of the descriptors; close the parent's
            # handles (the files are kept on disk because delete=False).
            for log_file in log_files:
                log_file.close()

        if bg:
            self._stack.enter_context(proc)
        elif check:
            self.check_success(proc)
        else:
            proc.wait()
        return proc

    def spawn(self, command, bg=False, check=True, stdout_override=None, stderr_override=None) -> subprocess.Popen:
        """
        Run a command.
        """
        name = PurePath(command[0]).name
        return self._spawn(command, name=name, bg=bg, check=check,
                           stdout_override=stdout_override, stderr_override=stderr_override)

    def ssh_spawn(self, address, command, bg=False, check=True, stdout_override=None, stderr_override=None) -> subprocess.Popen:
        """
        Run a command over ssh.
        """
        cmd = ["ssh", "-o", "StrictHostKeyChecking=accept-new", "-tt"]
        if not self._verbose:
            cmd.append("-q")
        cmd += [address, "--"]
        # The remote shell re-parses the command line, so quote every argument.
        cmd += [shlex.quote(arg) for arg in command]
        name = PurePath(command[0]).name
        return self._spawn(cmd, name=f"{address}.{name}", bg=bg, check=check,
                           stdout_override=stdout_override, stderr_override=stderr_override)

    def ssh_spawn_all(self, addresses, command, bg=False, check=True) -> List[subprocess.Popen]:
        """
        Run ssh commands for all addresses, returning an array of procs
        """
        # With several hosts, start them all in the background so they run
        # concurrently; wait/check afterwards unless the caller asked for bg.
        each_bg = bg
        if len(addresses) > 1:
            each_bg = True
        procs = []
        for addr in addresses:
            procs.append(self.ssh_spawn(addr, command, bg=each_bg, check=check))
        if not bg:
            if check:
                self.check_success_all(procs)
            else:
                self.wait_all(procs)
        return procs
class FormData(aiohttpFormData):
    """FormData used to upload files to remote

    Entering the context adds every path in ``paths``: regular files as-is, and
    directories as zip archives built in a temporary directory. Opened files and
    the temporary directory are released when the context exits.
    """

    def __init__(
        self,
        paths: Optional[List[str]] = None,
        logger: 'JinaLogger' = None,
        complete: bool = False,
    ) -> None:
        super().__init__()
        self._logger = logger
        self._complete = complete
        self._cur_dir = os.getcwd()
        self.paths = paths
        self._stack = ExitStack()

    def add(self, path: Path):
        """add a field to Form

        The file is opened in binary mode and stays open (on the internal
        ExitStack) until the context exits. With ``complete=True`` the path is
        first passed through ``complete_path``, with the working directory
        captured at construction as an extra search path.

        :param path: filepath
        """
        self.add_field(
            name='files',
            value=self._stack.enter_context(
                open(
                    complete_path(path, extra_search_paths=[self._cur_dir])
                    if self._complete
                    else path,
                    'rb',
                )
            ),
            filename=path.name,
        )

    @property
    def fields(self) -> List[Any]:
        """all fields in current Form

        :return: list of fields
        """
        return self._fields

    @property
    def filenames(self) -> List[str]:
        """all filenames in current Form

        :return: list of filenames
        """
        return [os.path.basename(f[-1].name) for f in self.fields]

    def __len__(self):
        return len(self._fields)

    def __enter__(self):
        self._stack.__enter__()
        if not self.paths:
            return self
        try:
            tmpdir = self._stack.enter_context(TemporaryDirectory())
            for raw_path in self.paths:
                try:
                    # Built inside the try so that non-path entries (e.g. None) are
                    # reported and skipped rather than aborting the whole upload.
                    path = Path(raw_path)
                    filename = path.name
                    if path.is_file():
                        self.add(path)
                    elif path.is_dir():
                        # Write the archive straight into tmpdir instead of chdir-ing
                        # there, so relative paths keep resolving against the caller's
                        # working directory.
                        archive = make_archive(
                            base_name=str(Path(tmpdir) / filename),
                            format='zip',
                            root_dir=path,
                        )
                        self.add(Path(archive))
                except TypeError:
                    if self._logger is not None:
                        self._logger.error(f'invalid path passed {raw_path}')
                    continue
        except BaseException:
            # __exit__ is not called when __enter__ raises; release opened files and
            # the temporary directory before propagating.
            self._stack.close()
            raise

        if self._logger is not None:
            self._logger.info(
                (
                    f'{len(self)} file(s) ready to be uploaded: {", ".join(self.filenames)}'
                    if len(self) > 0
                    else 'No file to be uploaded'
                )
            )
        return self

    def __exit__(self, *args, **kwargs):
        self._stack.__exit__(*args, **kwargs)
class LocalDarRepository:
    """
    Deprecated: use PackageLookup instead.

    A collection of DAR files whose package metadata is registered into a single
    ``PackageStore`` (``self.store``). Only ``.dar`` files are supported; ``.daml``
    and ``.dalf`` sources are rejected with a :class:`ValueError`. Opened DAR
    files are kept on an internal :class:`ExitStack` and released when the
    repository's context exits.
    """

    def __init__(self):
        warnings.warn(
            "LocalDarRepository is deprecated; use PackageLookup instead",
            DeprecationWarning,
            stacklevel=2,
        )
        self._context = ExitStack()
        self._dar_paths = []  # type: List[Path]
        # Resolved paths already accepted by add_source; used to ignore duplicates.
        self._files = set()

        # The caller was already warned about this class; silence the deprecation
        # warnings of the (also deprecated) helpers it uses internally.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            from ..model.types_store import PackageStore

            self.store = PackageStore.empty()

    def _add_source(self, path: Path) -> None:
        """
        Load a single, already-resolved path into the store.

        :param path: The file to load.
        :raises ValueError: if the extension is anything other than ``.dar``.
        """
        ext = path.suffix.lower()
        if ext == ".daml":
            LOG.error(
                "Reading metadata directly from a DAML file is not supported.")
            raise ValueError("Unsupported extension: .daml")
        elif ext == ".dalf":
            LOG.error(
                "Reading metadata directly from a DALF file is not supported.")
            raise ValueError("Unsupported extension: .dalf")
        elif ext == ".dar":
            dar_parse_start_time = time.time()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                from ..util.dar import DarFile

                dar = self._context.enter_context(DarFile(path))
                dar_package = dar.read_metadata()
            self._dar_paths.append(path)
            self.store.register_all(dar_package)
            dar_parse_end_time = time.time()
            LOG.debug("Parsed a dar in %s seconds.", dar_parse_end_time - dar_parse_start_time)
        else:
            LOG.error("Unknown extension: %s", ext)
            raise ValueError(f"Unknown extension: {ext}")

    def add_source(self, *files: Union[str, PathLike, Path]) -> None:
        """
        Add one or more ``.dar`` files to the repository.

        Every path is resolved strictly (it must exist). Paths that were already
        added successfully are ignored. If loading a path fails, it is not
        remembered, so a later call can retry it.

        :param files: Files to add to the archive.
        :raises ValueError: if a file does not have a ``.dar`` extension.
        """
        for file in files:
            path = Path(fspath(file)).resolve(strict=True)
            if path in self._files:
                continue
            self._files.add(path)
            try:
                self._add_source(path)
            except BaseException:
                # Forget the path so a failed load is not silently skipped next time.
                self._files.discard(path)
                raise

    def get_daml_archives(self) -> Sequence[Path]:
        """Return the paths of the DAR files added so far, in insertion order."""
        return self._dar_paths

    def __enter__(self):
        """
        Enter the repository's resource context and return the repository.
        """
        self._context.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Close all managed resources in the reverse order that they were opened.
        """
        return self._context.__exit__(exc_type, exc_val, exc_tb)
class LocalDarRepository:
    """
    A collection of DAML sources whose package metadata is registered into a
    single ``PackageStore`` (``self.store``).

    ``.daml`` files are turned into temporary DARs (via ``TemporaryDar``) that
    are then added; ``.dalf`` and ``.dar`` files are parsed directly. Temporary
    and opened resources live on an internal :class:`ExitStack` and are released
    when the repository's context exits.
    """

    def __init__(self):
        self._context = ExitStack()
        self._dar_paths = []  # type: List[Path]
        # Resolved paths already accepted by add_source; used to ignore duplicates.
        self._files = set()
        self.store = PackageStore.empty()

    def _add_source(self, path: Path) -> None:
        """
        Load a single, already-resolved path into the store.

        :param path: The file to load.
        :raises ValueError: if the extension is not ``.daml``, ``.dalf`` or ``.dar``.
        """
        ext = path.suffix.lower()
        if ext == '.daml':
            # Add the DAR file(s) produced from the DAML source; they stay
            # available until the repository's context exits.
            dar_paths = self._context.enter_context(TemporaryDar(path))
            self.add_source(*dar_paths)
        elif ext == '.dalf':
            dalf_package = parse_dalf(path.read_bytes())
            self._dar_paths.append(path)
            self.store.register_all(dalf_package)
        elif ext == '.dar':
            dar = self._context.enter_context(DarFile(path))
            dar_package = dar.read_metadata()
            self._dar_paths.append(path)
            self.store.register_all(dar_package)
        else:
            LOG.error('Unknown extension: %s', ext)
            raise ValueError(f'Unknown extension: {ext}')

    def add_source(self, *files: Union[str, Path]) -> None:
        """
        Add a source file (either a .daml file, .dalf file, or a .dar file).

        Every path is resolved strictly (it must exist). Attempts to add the same
        file more than once will be ignored. If loading a path fails, it is not
        remembered, so a later call can retry it.

        :param files: Files to add to the archive.
        :raises ValueError: for any other extension.
        """
        for file in files:
            path = pathify(file).resolve(strict=True)
            if path in self._files:
                continue
            # Mark before loading: .daml sources recurse into add_source.
            self._files.add(path)
            try:
                self._add_source(path)
            except BaseException:
                # Forget the path so a failed load is not silently skipped next time.
                self._files.discard(path)
                raise

    def get_daml_archives(self) -> Sequence[Path]:
        """Return the paths of the DAR/DALF files added so far, in insertion order."""
        return self._dar_paths

    def __enter__(self):
        """
        Enter the repository's resource context and return the repository.
        """
        self._context.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Delete all managed resources in the reverse order that they were created.
        """
        return self._context.__exit__(exc_type, exc_val, exc_tb)
class SocketIOCalculator(Calculator):
    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None, unixsocket=None, timeout=None,
                 log=None, *, launch_client=None):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025 and
            65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object, str or None (default)

            logfile for communication over socket.  For debugging or
            the curious.  A str is treated as a filename which is opened
            for writing and closed together with the calculator.

        launch_client: callable or None (keyword-only)

            called as ``launch_client(atoms, properties, port=...,
            unixsocket=...)`` on the first calculation to start the
            client process.  Cannot be combined with ``calc``.

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> with SocketIOCalculator(...) as calc:
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        self._exitstack = ExitStack()
        Calculator.__init__(self)

        if calc is not None:
            if launch_client is not None:
                raise ValueError('Cannot pass both calc and launch_client')
            launch_client = FileIOSocketClientLauncher(calc)
        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        if isinstance(log, str):
            self.log = self._exitstack.enter_context(open(log, 'w'))
        else:
            self.log = log

        # We only hold these so we can pass them on to the server.
        # They may both be None as stored here.
        self._port = port
        self._unixsocket = unixsocket

        # If there is a calculator, we will launch in calculate() because
        # we are responsible for executing the external process, too, and
        # should do so before blocking.  Without a calculator we want to
        # block immediately:
        if self.launch_client is None:
            try:
                self.server = self.launch_server()
            except BaseException:
                # Do not leak the log file opened above if the server cannot start.
                self._exitstack.close()
                raise

    def todict(self):
        d = {'type': 'calculator',
             'name': 'socket-driver'}
        return d

    def launch_server(self):
        """Start a ``SocketServer`` whose lifetime is tied to this calculator's exit stack."""
        return self._exitstack.enter_context(SocketServer(
            port=self._port,
            unixsocket=self._unixsocket,
            timeout=self.timeout, log=self.log,
        ))

    def calculate(self, atoms=None, properties=None,
                  system_changes=all_changes):
        """Send ``atoms`` to the client over the socket and store the returned results.

        ``properties`` defaults to ``['energy']``; a fresh list is created per call so
        the default is never shared with (or mutated by) the client launcher.

        Raises PropertyNotImplementedError if anything other than positions or cell
        changed since the previous calculation.
        """
        if properties is None:
            properties = ['energy']

        bad = [change for change in system_changes
               if change not in self.supported_changes]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and any(bad):
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol.  '
                'Please create new socket calculator.'
                .format(bad if len(bad) > 1 else bad[0]))

        self.atoms = atoms.copy()

        if self.server is None:
            # Only reached when a client launcher was given: the server was
            # deferred in __init__ so the client can be started right after it.
            self.server = self.launch_server()
            proc = self.launch_client(atoms, properties,
                                      port=self._port,
                                      unixsocket=self._unixsocket)
            self.server.proc = proc  # XXX nasty hack

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        if self.atoms.cell.rank == 3 and any(self.atoms.pbc):
            vol = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / vol
        self.results.update(results)

    def close(self):
        """Drop the server reference and release everything on the exit stack."""
        self.server = None
        self._exitstack.close()

    def __enter__(self):
        self._exitstack.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
class HorovodStrategy(ParallelStrategy):
    """Plugin for Horovod distributed training integration."""

    distributed_backend = _StrategyType.HOROVOD

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        rank_zero_only.rank = self.global_rank
        # Created in ``setup`` and closed (then reset to ``None``) in ``teardown``.
        self._exit_stack: Optional[ExitStack] = None

    @property
    def global_rank(self) -> int:
        return hvd.rank()

    @property
    def local_rank(self) -> int:
        return hvd.local_rank()

    @property
    def world_size(self) -> int:
        return hvd.size()

    @property
    def root_device(self):
        # The device used by this process, indexed by its Horovod local rank.
        return self.parallel_devices[self.local_rank]

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, trainer: "pl.Trainer") -> None:
        """Move the model to its device and, when training, prepare optimizers for Horovod.

        Learning rates (and scheduler base LRs) are scaled by ``world_size``, model
        parameters and optimizer state are broadcast from rank 0, optimizers are wrapped
        via ``_wrap_optimizers``, and their ``skip_synchronize`` contexts are entered on
        ``self._exit_stack`` (closed again in ``teardown``).
        """
        self.model_to_device()
        super().setup(trainer)

        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()

        if not self.lightning_module.trainer.training:
            # no need to setup optimizers
            return

        def _unpack_lightning_optimizer(opt):
            return opt._optimizer if isinstance(opt, LightningOptimizer) else opt

        optimizers = self.optimizers
        optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers]

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= self.world_size

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        lr_scheduler_configs = self.lr_schedulers
        for config in lr_scheduler_configs:
            scheduler = config.scheduler
            scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
        for optimizer in optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        self.optimizers = self._wrap_optimizers(optimizers)
        for optimizer in self.optimizers:
            # Synchronization will be performed explicitly following backward()
            self._exit_stack.enter_context(optimizer.skip_synchronize())

    def barrier(self, *args, **kwargs):
        if distributed_available():
            self.join()

    def broadcast(self, obj: object, src: int = 0) -> object:
        obj = hvd.broadcast_object(obj, src)
        return obj

    def model_to_device(self):
        if self.on_gpu:
            # this can potentially be removed after #8312. Not done due to lack of horovod testing
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def join(self):
        if self.on_gpu:
            hvd.join(self.local_rank)
        else:
            hvd.join()

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world).
                Only ``None`` is supported by Horovod.
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' (or ``ReduceOp.SUM``) to calculate the sum during reduction.

        Return:
            the reduced value returned by ``hvd.allreduce``

        Raises:
            ValueError: if ``group`` is set or ``reduce_op`` is not recognized.
        """
        if group is not None:
            raise ValueError(
                "Horovod does not support allreduce using a subcommunicator at this time. Unset `group`."
            )

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self.join()
        return hvd.allreduce(tensor, op=reduce_op)

    def all_gather(
        self, result: torch.Tensor, group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
    ) -> torch.Tensor:
        if group is not None and group != dist_group.WORLD:
            raise ValueError(
                "Horovod does not support allgather using a subcommunicator at this time. Unset `group`."
            )

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self.join()
        return hvd.allgather(result)

    def post_backward(self, closure_loss: torch.Tensor) -> None:
        # synchronize all horovod optimizers.
        for optimizer in self.lightning_module.trainer.optimizers:
            optimizer.synchronize()

    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
        """Wraps optimizers to perform gradient aggregation via allreduce."""
        return [
            hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt))
            if "horovod" not in str(opt.__class__)
            else opt
            for opt in optimizers
        ]

    @staticmethod
    def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
        # Keep only the model parameters that this optimizer actually manages.
        opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
        return [(name, p) for name, p in model.named_parameters() if p in opt_params]

    def teardown(self) -> None:
        """Close the contexts entered in ``setup``, wait for all workers and free GPU memory."""
        super().teardown()
        # ``teardown`` may run without a completed ``setup`` (e.g. when ``setup`` failed before
        # creating the exit stack); in that case there is nothing to close.
        if self._exit_stack is not None:
            self._exit_stack.__exit__(None, None, None)
            self._exit_stack = None
        # Make sure all workers have finished training before returning to the user
        self.join()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
class _TestUser(object): def __init__(self, test_client, runestone_db_tools, username, password, course_name, # True if the course is free (no payment required); False otherwise. is_free=True): self.test_client = test_client self.runestone_db_tools = runestone_db_tools self.username = username self.first_name = 'test' self.last_name = 'user' self.email = self.username + '@foo.com' self.password = password self.course_name = course_name self.is_free = is_free def __enter__(self): # Registration doesn't work unless we're logged out. self.test_client.logout() # Now, post the registration. self.test_client.validate('default/user/register', 'Support Runestone Interactive' if self.is_free else 'Payment Amount', data=dict( username=self.username, first_name=self.first_name, last_name=self.last_name, # The e-mail address must be unique. email=self.email, password=self.password, password_two=self.password, # Note that ``course_id`` is (on the form) actually a course name. course_id=self.course_name, accept_tcp='on', donate='0', _next='/runestone/default/index', _formname='register', ) ) # Schedule this user for deletion. self.exit_stack_object = ExitStack() self.exit_stack = self.exit_stack_object.__enter__() self.exit_stack.callback(self._delete_user) # Record IDs db = self.runestone_db_tools.db self.course_id = db(db.courses.course_name == self.course_name).select(db.courses.id).first().id self.user_id = db(db.auth_user.username == self.username).select(db.auth_user.id).first().id return self # Clean up on exit by invoking all ``__exit__`` methods. def __exit__(self, exc_type, exc_value, traceback): self.exit_stack_object.__exit__(exc_type, exc_value, traceback) # Delete the user created by entering this context manager. TODO: This doesn't delete all the chapter progress tracking stuff. def _delete_user(self): db = self.runestone_db_tools.db # Delete the course this user registered for. 
db(( db.user_courses.course_id == self.course_id) & (db.user_courses.user_id == self.user_id) ).delete() # Delete the user. db(db.auth_user.username == self.username).delete() db.commit() def login(self): self.test_client.post('default/user/login', data=dict( username=self.username, password=self.password, _formname='login', )) def make_instructor(self, course_id=None): # If ``course_id`` isn't specified, use this user's ``course_id``. course_id = course_id or self.course_id return self.runestone_db_tools.make_instructor(self.user_id, course_id) # A context manager to update this user's profile. If a course was added, it returns that course's ID; otherwise, it returns None. @contextmanager def update_profile(self, # This parameter is passed to ``test_client.validate``. expected_string=None, # An updated username, or ``None`` to use ``self.username``. username=None, # An updated first name, or ``None`` to use ``self.first_name``. first_name=None, # An updated last name, or ``None`` to use ``self.last_name``. last_name=None, # An updated email, or ``None`` to use ``self.email``. email=None, # An updated last name, or ``None`` to use ``self.course_name``. course_name=None, section='', # A shortcut for specifying the ``expected_string``, which only applies if ``expected_string`` is not set. Use ``None`` if a course will not be added, ``True`` if the added course is free, or ``False`` if the added course is paid. is_free=None, # The value of the ``accept_tcp`` checkbox; provide an empty string to leave unchecked. The default value leaves it checked. 
accept_tcp='on'): if expected_string is None: if is_free is None: expected_string = 'Course Selection' else: expected_string = 'Support Runestone Interactive' \ if is_free else 'Payment Amount' username = username or self.username first_name = first_name or self.first_name last_name = last_name or self.last_name email = email or self.email course_name = course_name or self.course_name db = self.runestone_db_tools.db # Determine if we're adding a course. If so, delete it at the end of the test. To determine if a course is being added, the course must exist, but not be in the user's list of courses. course = db(db.courses.course_name == course_name).select(db.courses.id).first() delete_at_end = course and not db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course.id)).select(db.user_courses.id).first() # Perform the update. try: self.test_client.validate('default/user/profile', expected_string, data=dict( username=username, first_name=first_name, last_name=last_name, email=email, # Though the field is ``course_id``, it's really the course name. course_id=course_name, accept_tcp=accept_tcp, section=section, _next='/runestone/default/index', id=str(self.user_id), _formname='auth_user/' + str(self.user_id), ) ) yield course.id if delete_at_end else None finally: if delete_at_end: db = self.runestone_db_tools.db db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course.id)).delete() db.commit() # Call this after registering for a new course or adding a new course via ``update_profile`` to pay for the course. @contextmanager def make_payment(self, # The `Stripe test tokens <https://stripe.com/docs/testing#cards>`_ to use for payment. stripe_token, # The course ID of the course to pay for. None specifies ``self.course_id``. course_id=None): course_id = course_id or self.course_id # Get the signature from the HTML of the payment page. 
self.test_client.validate('default/payment') match = re.search('<input type="hidden" name="signature" value="([^ ]*)" />', self.test_client.text) signature = match.group(1) try: html = self.test_client.validate('default/payment', data=dict(stripeToken=stripe_token, signature=signature) ) assert ('Thank you for your payment' in html) or ('Payment failed' in html) yield None finally: db = self.runestone_db_tools.db db((db.user_courses.course_id == course_id) & (db.user_courses.user_id == self.user_id) ).delete() # Try to delete the payment try: db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course_id) & (db.user_courses.id == db.payments.user_courses_id)) \ .delete() except: pass db.commit() @contextmanager def hsblog(self, **kwargs): try: yield json.loads(self.test_client.validate('ajax/hsblog', data=kwargs)) finally: # Try to remove this hsblog entry. event = kwargs.get('event') div_id = kwargs.get('div_id') course = kwargs.get('course') db = self.runestone_db_tools.db criteria = ((db.useinfo.sid == self.username) & (db.useinfo.act == kwargs.get('act', '')) & (db.useinfo.div_id == div_id) & (db.useinfo.event == event) & (db.useinfo.course_id == course) ) useinfo_row = db(criteria).select(db.useinfo.id, orderby=db.useinfo.id).last() if useinfo_row: del db.useinfo[useinfo_row.id]