Example #1
class LocalDarRepository:
    def __init__(self):
        self._context = ExitStack()
        self._dar_paths = []  # type: List[Path]
        self._files = set()
        self.store = PackageStore.empty()

    def _add_source(self, path: Path) -> None:
        ext = path.suffix.lower()

        if ext == '.daml':
            LOG.error('Reading metadata directly from a DAML file is not supported.')
            raise ValueError('Unsupported extension: .daml')

        elif ext == '.dalf':
            LOG.error('Reading metadata directly from a DALF file is not supported.')
            raise ValueError('Unsupported extension: .dalf')

        elif ext == '.dar':
            dar_parse_start_time = time.time()
            dar = self._context.enter_context(DarFile(path))
            dar_package = dar.read_metadata()
            self._dar_paths.append(path)
            self.store.register_all(dar_package)
            dar_parse_end_time = time.time()
            LOG.debug('Parsed a dar in %s seconds.', dar_parse_end_time - dar_parse_start_time)

        else:
            LOG.error('Unknown extension: %s', ext)
            raise ValueError(f'Unknown extension: {ext}')

    def add_source(self, *files: Union[str, Path]) -> None:
        """
        Add a source file (either a .daml file, .dalf file, or a .dar file).

        Attempts to add the same file more than once will be ignored.

        :param files: Files to add to the archive.
        """
        for file in files:
            path = pathify(file).resolve(strict=True)
            if path not in self._files:
                self._files.add(path)
                self._add_source(path)

    def get_daml_archives(self) -> Sequence[Path]:
        return self._dar_paths

    def __enter__(self):
        """
        Does nothing.
        """
        self._context.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Delete all managed resources in the reverse order that they were created.
        """
        self._context.__exit__(exc_type, exc_val, exc_tb)
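A minimal usage sketch for the class above, assuming it is importable from its package (the .dar file name is illustrative); the with-block guarantees every DarFile entered on the internal ExitStack is closed on exit:

with LocalDarRepository() as repo:
    repo.add_source('model.dar')  # duplicate adds are silently ignored
    for archive in repo.get_daml_archives():
        print(archive)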
Example #2
class SyncEventSource:
    _subscribers: Dict[Type[Event], List[Callable[[Event], Any]]]

    def __init__(self,
                 async_event_source: EventSource,
                 portal: Optional[BlockingPortal] = None):
        self.portal = portal
        self._async_event_source = async_event_source
        self._async_event_source.subscribe(self._forward_async_event)
        self._exit_stack = ExitStack()
        self._subscribers = defaultdict(list)

    def __enter__(self):
        self._exit_stack.__enter__()
        if not self.portal:
            portal_cm = start_blocking_portal()
            self.portal = self._exit_stack.enter_context(portal_cm)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._exit_stack.__exit__(exc_type, exc_val, exc_tb)

    async def _forward_async_event(self, event: Event) -> None:
        for subscriber in self._subscribers.get(type(event), ()):
            await run_sync_in_worker_thread(subscriber, event)

    def subscribe(self,
                  callback: Callable[[Event], Any],
                  event_types: Optional[Iterable[Type[Event]]] = None) -> None:
        if event_types is None:
            event_types = _all_event_types

        for event_type in event_types:
            existing_callbacks = self._subscribers[event_type]
            if callback not in existing_callbacks:
                existing_callbacks.append(callback)

    def unsubscribe(
            self,
            callback: Callable[[Event], Any],
            event_types: Optional[Iterable[Type[Event]]] = None) -> None:
        if event_types is None:
            event_types = _all_event_types

        for event_type in event_types:
            existing_callbacks = self._subscribers.get(event_type, [])
            with suppress(ValueError):
                existing_callbacks.remove(callback)
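A hedged usage sketch for SyncEventSource; event_source stands in for any EventSource implementation, and entering the with-block starts a blocking portal when none was supplied:

with SyncEventSource(async_event_source=event_source) as sync_source:
    sync_source.subscribe(lambda event: print('received', event))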
Example #3
class ReposTestCase(AsyncTestCase):
    async def setUp(self):
        super().setUp()
        self.repos = [Repo(str(random.random())) for _ in range(0, 10)]
        self.__stack = ExitStack()
        self._stack = self.__stack.__enter__()
        self.temp_filename = self.mktempfile()
        self.sconfig = FileSourceConfig(filename=self.temp_filename)
        async with JSONSource(self.sconfig) as source:
            async with source() as sctx:
                for repo in self.repos:
                    await sctx.update(repo)
        contents = json.loads(Path(self.sconfig.filename).read_text())
        # Ensure there are repos in the file
        self.assertEqual(
            len(contents.get(self.sconfig.label)),
            len(self.repos),
            "ReposTestCase JSON file erroneously initialized as empty",
        )

    def tearDown(self):
        super().tearDown()
        self.__stack.__exit__(None, None, None)

    def mktempfile(self):
        return self._stack.enter_context(non_existant_tempfile())
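The setUp/tearDown pairing above can also be expressed with unittest's addCleanup, which closes the stack even if setUp itself fails partway through; a self-contained sketch:

import unittest
from contextlib import ExitStack
from tempfile import TemporaryDirectory

class CleanupStyleTestCase(unittest.TestCase):
    def setUp(self):
        self._stack = ExitStack()
        # registered first, so the stack closes even if later setUp lines raise
        self.addCleanup(self._stack.close)
        self.tmpdir = self._stack.enter_context(TemporaryDirectory())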
Example #4
    def __enter__(self):
        stack = ExitStack()
        self.__stack = stack

        if not isinstance(self.func, partial):
            return self

        func, args, kwargs = self.func.func, self.func.args, self.func.keywords
        stack.__enter__()  # ExitStack.__enter__ just returns the stack; kept for symmetry with __exit__
        args = (
            stack.enter_context(arg) if hasattr(arg, '__enter__') else arg
            for arg in args
        )
        kwargs = {
            kw: stack.enter_context(arg) if hasattr(arg, '__enter__') else arg
            for kw, arg in kwargs.items()
        }
        return Pipe(partial(func, *args, **kwargs))
Example #5
class TestOWDataSets(WidgetTest):
    def setUp(self):
        super().setUp()
        remote = {
            ("a", "b"): {"title": "K", "size": 10},
            ("a", "c"): {"title": "T", "size": 20},
            ("a", "d"): {"title": "Y", "size": 0},
        }
        self.exit = ExitStack()
        self.exit.__enter__()
        self.exit.enter_context(
            mock.patch.object(owdatasets, "list_remote", lambda: remote)
        )
        self.exit.enter_context(
            mock.patch.object(owdatasets, "list_local", lambda: {})
        )

        self.widget = self.create_widget(
            OWDataSets, stored_settings={
                "selected_id": ("a", "c"),
                "auto_commit": False,
            }
        )  # type: OWDataSets

    def tearDown(self):
        super().tearDown()
        self.exit.__exit__(None, None, None)

    def test_init(self):
        if self.widget.isBlocking():
            spy = QSignalSpy(self.widget.blockingStateChanged)
            assert spy.wait(1000)
        self.assertFalse(self.widget.isBlocking())

        model = self.widget.view.model()
        self.assertEqual(model.rowCount(), 3)

        di = self.widget.selected_dataset()
        self.assertEqual((di.prefix, di.filename), ("a", "c"))
Example #6
class SourceExtractor(object):
    """A class to extract a source package to its constituent parts"""
    def __init__(self, dsc_path, dsc):
        self.dsc_path = dsc_path
        self.dsc = dsc
        self.extracted_upstream = None
        self.extracted_debianised = None
        self.unextracted_debian_md5 = None
        self.upstream_tarballs = []
        self.exit_stack = ExitStack()

    def extract(self):
        """Extract the package to a new temporary directory."""
        raise NotImplementedError(self.extract)

    def __enter__(self):
        self.exit_stack.__enter__()
        self.extract()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
        return False
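A hypothetical concrete subclass, sketched to show how extract() is expected to park temporary resources on exit_stack so that __exit__ reclaims them (the class name and unpacking step are illustrative):

from tempfile import TemporaryDirectory

class TarballSourceExtractor(SourceExtractor):
    def extract(self):
        # the directory lives until exit_stack closes in __exit__
        self.tempdir = self.exit_stack.enter_context(
            TemporaryDirectory(prefix='source-extract-'))
        # ... unpack self.dsc_path into self.tempdir here ...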
Example #7
class patch_auth:  # pylint: disable=invalid-name
    def __init__(self,
                 project_id: str = "potato-dev",
                 location: str = "moon-dark1",
                 email: str = "*****@*****.**"):
        # Realistic: actual class to be accepted by clients during validation
        # But fake: with as few attributes as possible, any API call using the credential should fail
        credentials = Credentials(
            service_account_email=email,
            signer=None,
            token_uri="",
            project_id=project_id,
        )
        managers = [
            patch("google.auth.default",
                  return_value=(credentials, project_id)),
            patch("gcp_pilot.base.GoogleCloudPilotAPI._set_location",
                  return_value=location),
            patch("gcp_pilot.base.AppEngineBasedService._set_location",
                  return_value=location),
        ]
        self.stack = ExitStack()
        for mgr in managers:
            self.stack.enter_context(mgr)

    def __enter__(self):
        return self.stack.__enter__()

    def start(self):
        return self.__enter__()

    def __exit__(self, typ, val, traceback):
        return self.stack.__exit__(typ, val, traceback)

    def stop(self):
        self.__exit__(None, None, None)

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kw):
            with self:
                return func(*args, **kw)

        return wrapper
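A hedged sketch of the three interchangeable styles this helper supports (test bodies elided):

@patch_auth(project_id='my-project')
def test_decorated():
    ...  # google.auth.default is patched inside

def test_with_block():
    with patch_auth():
        ...

def test_start_stop():
    auth = patch_auth()
    auth.start()
    try:
        ...
    finally:
        auth.stop()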
Example #8
class _ManagedPkgData(ContextDecorator):

    _section_re = re.compile(
        r'(?P<module>[A-Za-z_][A-Za-z0-9_]+[.][A-Za-z_][A-Za-z0-9_]+)'
        r'([:](?P<label>[A-Za-z_][A-Za-z0-9_-]+))?')

    def __init__(self):
        self._stack = None
        self._tmp_dir = None

    def __enter__(self):
        self._stack = ExitStack()
        self._stack.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._stack.__exit__(exc_type, exc_val, exc_tb)
        self._tmp_dir = None
        self._stack = None

    def get_tmp_dir(self):
        if self._tmp_dir is None:
            self._tmp_dir = self._stack.enter_context(
                TemporaryDirectory(prefix=".histoqc_pkg_data_tmp",
                                   dir=os.getcwd()))
            for rsrc in {'models', 'pen', 'templates'}:
                package_resource_copytree('histoqc.data', rsrc, self._tmp_dir)
        return self._tmp_dir

    def inject_pkg_data_fallback(self, config: ConfigParser):
        """support template data packaged in module"""
        for section in config.sections():
            m = self._section_re.match(section)
            if not m:
                continue

            # dispatch
            mf = m.group('module')
            s = config[section]
            if mf == 'HistogramModule.compareToTemplates':
                self._inject_HistogramModule_compareToTemplates(s)

            elif mf == 'ClassificationModule.byExampleWithFeatures':
                self._inject_ClassificationModule_byExampleWithFeatures(s)

            else:
                pass

    def _inject_HistogramModule_compareToTemplates(self, section):
        # replace example files in a compareToTemplates config section
        # with the histoqc package data examples if available
        if 'templates' in section:
            _templates = []
            for template in map(str.strip, section['templates'].split('\n')):
                f_template = os.path.join(os.getcwd(), template)

                if not os.path.isfile(f_template):
                    tmp = self.get_tmp_dir()
                    f_template_pkg_data = os.path.join(tmp, template)
                    if os.path.isfile(f_template_pkg_data):
                        f_template = f_template_pkg_data

                _templates.append(f_template)
            section['templates'] = '\n'.join(_templates)

    def _inject_ClassificationModule_byExampleWithFeatures(self, section):
        # replace template files in a byExampleWithFeatures config section
        # with the histoqc package data templates if available
        if 'examples' in section:
            # mimic the code structure in ClassificationModule.byExampleWithFeatures,
            # which expects example templates to be specified as pairs.
            # each template in the pair is separated by a ':' and pairs are separated by a newline.
            _lines = []
            for ex in section["examples"].splitlines():
                ex = re.split(
                    r'(?<!\W[A-Za-z]):(?!\\)', ex
                )  # workaround for windows: don't split on i.e. C:backslash
                _examples = []
                for example in ex:
                    f_example = os.path.join(os.getcwd(), example)

                    if not os.path.isfile(f_example):
                        tmp = self.get_tmp_dir()
                        f_example_pkg_data = os.path.join(tmp, example)
                        if os.path.isfile(f_example_pkg_data):
                            f_example = f_example_pkg_data

                    _examples.append(f_example)
                _line = ":".join(_examples)
                _lines.append(_line)
            section['examples'] = "\n".join(_lines)
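Because _ManagedPkgData subclasses ContextDecorator, a single module-level instance can wrap an entry point; a hedged sketch (the function and its config argument are illustrative):

pkg_data = _ManagedPkgData()

@pkg_data
def main(config):
    # the temporary package-data directory exists only for the duration of this call
    pkg_data.inject_pkg_data_fallback(config)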
Example #9
class RunAsExternal:
    def __init__(
        self,
        task_kls,
        basekls="from photons_app.tasks.tasks import Task",
    ):
        self.script = (
            basekls + "\n" + dedent(inspect.getsource(task_kls)) + "\n" +
            dedent("""
            from photons_app.errors import ApplicationCancelled, ApplicationStopped, UserQuit, PhotonsAppError
            from photons_app.tasks.runner import Runner
            from photons_app.collector import Collector
            from photons_app import helpers as hp

            import asyncio
            import signal
            import sys
            import os

            def notify():
                os.kill(int(sys.argv[2]), signal.SIGUSR1)

            collector = Collector()
            collector.prepare(None, {})
        """) + "\n" +
            f"{task_kls.__name__}.create(collector).run_loop(collector=collector, notify=notify, output=sys.argv[1])"
        )

    def __enter__(self):
        self.exit_stack = ExitStack()
        self.exit_stack.__enter__()

        self.fle = self.exit_stack.enter_context(hp.a_temp_file())
        self.fle.write(self.script.encode())
        self.fle.flush()

        self.out = self.exit_stack.enter_context(hp.a_temp_file())
        self.out.close()
        os.remove(self.out.name)

        return self.run

    def __exit__(self, exc_typ, exc, tb):
        if hasattr(self, "exit_stack"):
            self.exit_stack.__exit__(exc_typ, exc, tb)

    async def run(
        self,
        sig=None,
        expected_stdout=None,
        expected_stderr=None,
        expected_output=None,
        extra_argv=None,
        *,
        code,
    ):

        fut = hp.create_future()

        def ready(signum, frame):
            hp.get_event_loop().call_soon_threadsafe(fut.set_result, True)

        signal.signal(signal.SIGUSR1, ready)

        pipe = subprocess.PIPE
        cmd = [
            sys.executable,
            self.fle.name,
            self.out.name,
            str(os.getpid()),
            *[str(a) for a in (extra_argv or [])],
        ]

        p = subprocess.Popen(
            cmd,
            stdout=pipe,
            stderr=pipe,
        )

        await fut

        if sig is not None:
            p.send_signal(sig)

        p.wait(timeout=1)

        got_stdout = p.stdout.read().decode()
        got_stderr = p.stderr.read().decode().split("\n")

        redacted_stderr = []
        in_tb = False

        last_line = None
        for line in got_stderr:
            if last_line == "During handling of the above exception, another exception occurred:":
                redacted_stderr = redacted_stderr[:-5]
                last_line = line
                continue

            if in_tb and not line.startswith(" "):
                in_tb = False

            if line.startswith("Traceback"):
                in_tb = True
                redacted_stderr.append(line)
                redacted_stderr.append("  <REDACTED>")

            if not in_tb:
                redacted_stderr.append(line)

            last_line = line

        got_stderr = "\n".join(redacted_stderr)

        print("=== STDOUT:")
        print(got_stdout)
        print("=== STDERR:")
        print(got_stderr)

        def assertOutput(out, regex):
            if regex is None:
                return

            regex = dedent(regex).strip()

            out = out.strip()
            regex = regex.strip()
            assert len(out.split("\n")) == len(regex.split("\n"))
            pytest.helpers.assertRegex(f"(?m){regex}", out)

        if expected_output is not None:
            got = None
            if os.path.exists(self.out.name):
                with open(self.out.name) as o:
                    got = o.read().strip()
            assert got == expected_output.strip()

        assertOutput(got_stdout.strip(), expected_stdout)
        assertOutput(got_stderr.strip(), expected_stderr)

        assert p.returncode == code

        return got_stdout, got_stderr
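A hedged sketch of intended use inside an async test; task_kls is any Task subclass and the expectations are illustrative:

async def test_task_stops_on_sigterm(task_kls):
    with RunAsExternal(task_kls) as run:
        await run(sig=signal.SIGTERM, expected_output='', code=0)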
Example #10
class _MeshManager:
    def __init__(self, report=None):
        self.context_stack = ExitStack()
        if report is not None:
            self._report = report
        self._entered = False
        self._overrides = {}

    @staticmethod
    def add_progress_presteps(report):
        report.progress_add_step("Applying Blender Mods")

    def _build_prop_dict(self, bstruct):
        props = {}
        for i in bstruct.bl_rna.properties:
            ident = i.identifier
            if ident == "rna_type":
                continue
            props[ident] = getattr(bstruct, ident) if getattr(
                i, "array_length", 0) == 0 else tuple(getattr(bstruct, ident))
        return props

    def __enter__(self):
        assert self._entered is False, "_MeshManager is not reentrant"
        self._entered = True

        self.context_stack.__enter__()

        scene = bpy.context.scene
        self._report.progress_advance()
        self._report.progress_range = len(scene.objects)

        # Some modifiers like "Array" will procedurally generate new geometry that will impact
        # lightmap generation. The Blender Internal renderer does not seem to be smart enough to
        # take this into account. Thus, we temporarily apply modifiers to ALL meshes (even ones that
        # are not exported) such that we can generate proper lighting.
        mesh_type = bpy.types.Mesh
        for i in scene.objects:
            if isinstance(i.data, mesh_type) and i.is_modified(
                    scene, "RENDER"):
                # Remember, storing actual pointers to the Blender objects can cause bad things to
                # happen because Blender's memory management SUCKS!
                self._overrides[i.name] = {
                    "mesh": i.data.name,
                    "modifiers": []
                }
                i.data = i.to_mesh(scene, True, "RENDER", calc_tessface=False)

                # If the modifiers are left on the object, the lightmap bake can break under some
                # situations. Therefore, we now cache the modifiers and clear them away...
                if i.plasma_object.enabled:
                    cache_mods = self._overrides[i.name]["modifiers"]
                    for mod in i.modifiers:
                        cache_mods.append(self._build_prop_dict(mod))
                    i.modifiers.clear()
            self._report.progress_increment()
        return self

    def __exit__(self, *exc_info):
        try:
            self.context_stack.__exit__(*exc_info)
        finally:
            data_bos, data_meshes = bpy.data.objects, bpy.data.meshes
            for obj_name, override in self._overrides.items():
                bo = data_bos.get(obj_name)

                # Reapply the old mesh
                trash_mesh, bo.data = bo.data, data_meshes.get(
                    override["mesh"])
                data_meshes.remove(trash_mesh)

                # If modifiers were removed, reapply them now unless they're read-only.
                readonly_attributes = {
                    ("DECIMATE", "face_count"),
                }
                for cached_mod in override["modifiers"]:
                    mod = bo.modifiers.new(cached_mod["name"],
                                           cached_mod["type"])
                    for key, value in cached_mod.items():
                        if key in {
                                "name", "type"
                        } or (cached_mod["type"], key) in readonly_attributes:
                            continue
                        setattr(mod, key, value)
            self._entered = False

    def is_collapsed(self, bo) -> bool:
        return bo.name in self._overrides
Example #11
class Application(object):

    def __init__(self):
        super(Application, self).__init__()
        self._test_loader = Loader()
        self.set_report_stream(sys.stderr)
        self._argv = None
        self._reset_parser()
        self._positional_args = None
        self._parsed_args = None
        self._reporter = None
        self.test_loader = Loader()
        self.session = None
        self._working_directory = None
        self._interrupted = False
        self._exit_code = 0
        self._prelude_warning_records = []

    def set_working_directory(self, path):
        self._working_directory = path

    @property
    def exit_code(self):
        return self._exit_code

    def set_exit_code(self, exit_code):
        self._exit_code = exit_code

    @property
    def interrupted(self):
        return self._interrupted

    @property
    def positional_args(self):
        return self._positional_args

    @property
    def parsed_args(self):
        return self._parsed_args

    def enable_interactive(self):
        self.arg_parser.add_argument(
            '-i', '--interactive', help='Enter an interactive shell',
            action="store_true", default=False)

    def _reset_parser(self):
        self.arg_parser = cli_utils.SlashArgumentParser()

    def set_argv(self, argv):
        self._argv = list(argv)

    def _get_argv(self):
        if self._argv is None:
            return sys.argv[1:]
        return self._argv[:]

    def set_report_stream(self, stream):
        if stream is not None:
            self._report_stream = stream
            self._default_reporter = ConsoleReporter(level=logbook.ERROR, stream=stream)
            self._console_handler = ConsoleHandler(stream=stream, level=logbook.ERROR)

    def set_reporter(self, reporter):
        self._reporter = reporter

    def get_reporter(self):
        returned = self._reporter
        if returned is None:
            returned = ConsoleReporter(
                level=config.root.log.console_level,
                stream=self._report_stream)

        return returned

    def __enter__(self):
        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()
        try:
            self._exit_stack.enter_context(self._prelude_logging_context())
            self._exit_stack.enter_context(self._prelude_warning_context())
            self._exit_stack.enter_context(self._sigterm_context())
            with dessert.rewrite_assertions_context():
                site.load(working_directory=self._working_directory)

            cli_utils.configure_arg_parser_by_plugins(self.arg_parser)
            cli_utils.configure_arg_parser_by_config(self.arg_parser)
            argv = cli_utils.add_pending_plugins_from_commandline(self._get_argv())

            self._parsed_args, self._positional_args = self.arg_parser.parse_known_args(argv)

            self._exit_stack.enter_context(
                cli_utils.get_modified_configuration_from_args_context(self.arg_parser, self._parsed_args)
                )

            self.session = Session(reporter=self.get_reporter(), console_stream=self._report_stream)

            trigger_hook.configure() # pylint: disable=no-member
            plugins.manager.configure_for_parallel_mode()
            plugins.manager.activate_pending_plugins()
            cli_utils.configure_plugins_from_args(self._parsed_args)


            self._exit_stack.enter_context(self.session)
            self._emit_prelude_logs()
            self._emit_prelude_warnings()
            return self

        except:
            self._emit_prelude_logs()
            self.__exit__(*sys.exc_info())
            raise

    def __exit__(self, exc_type, exc_value, exc_tb):
        exc_info = (exc_type, exc_value, exc_tb)
        try:
            debug_if_needed(exc_info)
        except Exception as e:  # pylint: disable=broad-except
            _logger.error("Failed to debug_if_needed: {!r}", e, exc_info=True, extra={'capture': False})
        if exc_value is not None:
            self._exit_code = exc_value.code if isinstance(exc_value, SystemExit) else 1

            if should_inhibit_unhandled_exception_traceback(exc_value):
                self.get_reporter().report_error_message(str(exc_value))

            elif isinstance(exc_value, Exception):
                _logger.error('Unexpected error occurred', exc_info=exc_info, extra={'capture': False})
                self.get_reporter().report_error_message('Unexpected error: {}'.format(exc_value))

            if isinstance(exc_value, exceptions.INTERRUPTION_EXCEPTIONS):
                self._interrupted = True

        if exc_type is not None:
            trigger_hook.result_summary() # pylint: disable=no-member
        self._exit_stack.__exit__(exc_type, exc_value, exc_tb)
        self._exit_stack = None
        self._reset_parser()
        trigger_hook.app_quit()  # pylint: disable=no-member
        return True

    def _capture_native_warning(self, message, category, filename, lineno, file=None, line=None): # pylint: disable=unused-argument
        self._prelude_warning_records.append(RecordedWarning.from_native_warning(message, category, filename, lineno))

    def _prelude_logging_context(self):
        self._prelude_log_handler = log.RetainedLogHandler(bubble=True, level=logbook.TRACE)
        return self._prelude_log_handler.applicationbound()

    def _prelude_warning_context(self):
        capture_all_warnings()
        return warning_callback_context(self._capture_native_warning)

    def _emit_prelude_warnings(self):
        if self.session is not None:
            for warning in self._prelude_warning_records:
                if not self.session.warnings.warning_should_be_filtered(warning):
                    self.session.warnings.add(warning)

    def _emit_prelude_logs(self):
        self._prelude_log_handler.disable()
        handler = None
        if self.session is not None:
            handler = self.session.logging.session_log_handler
        if handler is None:
            handler = self._console_handler
        self._prelude_log_handler.flush_to_handler(handler)

    @contextmanager
    def _sigterm_context(self):
        def handle_sigterm(*_):
            with handling_exceptions():
                raise TerminatedException('Terminated by signal')

        prev = signal.signal(signal.SIGTERM, handle_sigterm)
        try:
            yield
        finally:
            try:
                signal.signal(signal.SIGTERM, prev)
            except TypeError as e:
                #workaround for a strange issue on app cleanup. See https://bugs.python.org/issue23548
                if 'signal handler must be signal.SIG_IGN' not in str(e):
                    raise
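A hedged sketch of driving the application; note that __exit__ above always returns True, so failures inside the with-block are converted into exit_code instead of propagating:

with Application() as app:
    ...  # configure and run the session here
sys.exit(app.exit_code)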
Example #12
class AbstractWebcamFilterApp(ABC):
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.bodypix_model = None
        self.output_sink = None
        self.image_source = None
        self.image_iterator = None
        self.timer = LoggingTimer()
        self.masks: List[np.ndarray] = []
        self.exit_stack = ExitStack()
        self.bodypix_result_cache_time = None
        self.bodypix_result_cache = None

    @abstractmethod
    def get_output_image(self, image_array: np.ndarray) -> np.ndarray:
        pass

    def get_mask(self, *args, **kwargs):
        return get_mask(
            *args, masks=self.masks, timer=self.timer, args=self.args, **kwargs
        )

    def get_bodypix_result(self, image_array: np.ndarray) -> BodyPixResultWrapper:
        assert self.bodypix_model is not None
        current_time = time()
        if (
            self.bodypix_result_cache is not None
            and current_time < self.bodypix_result_cache_time + self.args.mask_cache_time
        ):
            return self.bodypix_result_cache
        self.bodypix_result_cache = self.bodypix_model.predict_single(image_array)
        self.bodypix_result_cache_time = current_time
        return self.bodypix_result_cache

    def __enter__(self):
        self.exit_stack.__enter__()
        self.bodypix_model = load_bodypix_model(self.args)
        self.output_sink = self.exit_stack.enter_context(get_output_sink(self.args))
        self.image_source = self.exit_stack.enter_context(get_image_source_for_args(self.args))
        self.image_iterator = iter(self.image_source)
        return self

    def __exit__(self, *args, **kwargs):
        self.exit_stack.__exit__(*args, **kwargs)

    def next_frame(self):
        self.timer.on_frame_start(initial_step_name='in')
        try:
            image_array = next(self.image_iterator)
        except StopIteration:
            return False
        self.timer.on_step_start('model')
        output_image = self.get_output_image(image_array)
        self.timer.on_step_start('out')
        self.output_sink(output_image)
        self.timer.on_frame_end()
        return True

    def run(self):
        try:
            self.timer.start()
            while self.next_frame():
                pass
            if self.args.show_output:
                LOGGER.info('waiting for window to be closed')
                while not self.output_sink.is_closed:
                    sleep(0.5)
        except KeyboardInterrupt:
            LOGGER.info('exiting')
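The smallest concrete app is a passthrough filter; a hedged sketch (args comes from the surrounding argparse setup):

class PassthroughApp(AbstractWebcamFilterApp):
    def get_output_image(self, image_array):
        return image_array  # echo frames unchanged

def main(args):
    with PassthroughApp(args) as app:  # enters sink and source on exit_stack
        app.run()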
Example #13
class Monitor:
    """
    A process monitor.
    """
    def __init__(self, config: Config):
        self._stack = ExitStack()

        self._dir = Path(config.output_dir) / "monitor_logs"
        self._dir.mkdir(parents=True, exist_ok=True)

        self._verbose = config.verbose

    def __enter__(self):
        self._stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._stack.__exit__(exc_type, exc_value, traceback)

    def get_args(self, proc):
        return proc.args

    def check_running(self, proc) -> bool:
        return proc.poll() is None

    def wait(self, proc):
        proc.wait()

    def wait_all(self, procs):
        for proc in procs:
            self.wait(proc)

    def get_return_code(self, proc):
        return proc.returncode

    def check_success(self, proc):
        proc.wait()
        if proc.returncode != 0:
            raise RuntimeError(
                f"Process failed with status {proc.returncode}, args: {proc.args}"
            )

    def check_success_all(self, procs):
        for proc in procs:
            self.check_success(proc)

    def _spawn(self,
               command,
               name,
               bg=False,
               check=True,
               stdout_override=None,
               stderr_override=None) -> subprocess.Popen:
        if self._verbose:
            logging.info(f"Monitor - spawning {command}. " +
                         ("[bg]" if bg else "") + ("[chk]" if check else ""))

        if bg or not self._verbose:
            stdin = subprocess.DEVNULL
            # Don't create temp files if we are overriding
            command_str = ' '.join(map(str, command)) + "\n"
            if stdout_override is None:
                f = NamedTemporaryFile(prefix=f"{name}.",
                                       suffix=".stdout",
                                       dir=self._dir,
                                       delete=False)
                f.write(command_str.encode())
                stdout = f
            if stderr_override is None:
                f = NamedTemporaryFile(prefix=f"{name}.",
                                       suffix=".stderr",
                                       dir=self._dir,
                                       delete=False)
                f.write(command_str.encode())
                stderr = f
        else:
            stdin = None
            stdout = None
            stderr = None

        if stdout_override is not None:
            stdout = stdout_override
        if stderr_override is not None:
            stderr = stderr_override

        proc = subprocess.Popen(command,
                                stdin=stdin,
                                stdout=stdout,
                                stderr=stderr)

        if bg:
            self._stack.enter_context(proc)
        else:
            if check:
                self.check_success(proc)
            else:
                proc.wait()

        return proc

    def spawn(self,
              command,
              bg=False,
              check=True,
              stdout_override=None,
              stderr_override=None) -> subprocess.Popen:
        """
        Run a command.
        """

        name = PurePath(command[0]).name
        return self._spawn(command,
                           name=name,
                           bg=bg,
                           check=check,
                           stdout_override=stdout_override,
                           stderr_override=stderr_override)

    def ssh_spawn(self,
                  address,
                  command,
                  bg=False,
                  check=True,
                  stdout_override=None,
                  stderr_override=None) -> subprocess.Popen:
        """
        Run a command over ssh.
        """

        cmd = ["ssh", "-o", "StrictHostKeyChecking=accept-new", "-tt"]
        if not self._verbose:
            cmd.append("-q")
        cmd += [address, "--"]
        cmd += [shlex.quote(arg) for arg in command]

        name = PurePath(command[0]).name

        return self._spawn(cmd,
                           name=f"{address}.{name}",
                           bg=bg,
                           check=check,
                           stdout_override=stdout_override,
                           stderr_override=stderr_override)

    def ssh_spawn_all(self,
                      addresses,
                      command,
                      bg=False,
                      check=True) -> List[subprocess.Popen]:
        """
        Run ssh commands for all addresses, returning an array of procs
        """

        each_bg = bg
        if len(addresses) > 1:
            each_bg = True

        procs = []
        for addr in addresses:
            procs.append(self.ssh_spawn(addr, command, bg=each_bg,
                                        check=check))

        if not bg:
            if check:
                self.check_success_all(procs)
            else:
                self.wait_all(procs)

        return procs
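A hedged usage sketch; background processes entered on the stack are Popen context managers, so the Monitor's __exit__ waits on them (commands and hosts are illustrative):

with Monitor(config) as monitor:
    server = monitor.spawn(['./server', '--port', '8080'], bg=True)
    monitor.ssh_spawn_all(['host-a', 'host-b'], ['./client'], bg=False)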
Example #14
class FormData(aiohttpFormData):
    """FormData used to upload files to remote"""
    def __init__(
        self,
        paths: Optional[List[str]] = None,
        logger: 'JinaLogger' = None,
        complete: bool = False,
    ) -> None:
        super().__init__()
        self._logger = logger
        self._complete = complete
        self._cur_dir = os.getcwd()
        self.paths = paths
        self._stack = ExitStack()

    def add(self, path: Path):
        """add a field to Form

        :param path: filepath
        """
        self.add_field(
            name='files',
            value=self._stack.enter_context(
                open(
                    complete_path(path, extra_search_paths=[self._cur_dir])
                    if self._complete else path,
                    'rb',
                )),
            filename=path.name,
        )

    @property
    def fields(self) -> List[Any]:
        """all fields in current Form

        :return: list of fields
        """
        return self._fields

    @property
    def filenames(self) -> List[str]:
        """all filenames in current Form

        :return: list of filenames
        """
        return [os.path.basename(f[-1].name) for f in self.fields]

    def __len__(self):
        return len(self._fields)

    def __enter__(self):
        self._stack.__enter__()
        if not self.paths:
            return self
        tmpdir = self._stack.enter_context(TemporaryDirectory())
        self._stack.enter_context(change_cwd(tmpdir))
        for path in map(Path, self.paths):
            try:
                filename = path.name
                if path.is_file():
                    self.add(path)
                elif path.is_dir():
                    make_archive(base_name=filename,
                                 format='zip',
                                 root_dir=path)
                    self.add(Path(tmpdir) / f'{filename}.zip')
            except TypeError:
                self._logger.error(f'invalid path passed {path}')
                continue
        self._logger.info((
            f'{len(self)} file(s) ready to be uploaded: {", ".join(self.filenames)}'
            if len(self) > 0 else 'No file to be uploaded'))
        return self

    def __exit__(self, *args, **kwargs):
        self._stack.__exit__(*args, **kwargs)
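A hedged usage sketch; files are opened on the stack, directories are zipped into a temporary directory, and both live only as long as the with-block (paths are illustrative):

with FormData(paths=['config.yml', 'my_package/'], logger=logger) as form:
    print(form.filenames)
    # ... pass `form` as the body of an aiohttp upload request here ...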
Example #15
class LocalDarRepository:
    def __init__(self):
        warnings.warn(
            "LocalDarRepository is deprecated; use PackageLookup instead",
            DeprecationWarning,
            stacklevel=2,
        )
        self._context = ExitStack()
        self._dar_paths = []  # type: List[Path]
        self._files = set()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)

            from ..model.types_store import PackageStore

            self.store = PackageStore.empty()

    def _add_source(self, path: Path) -> None:
        ext = path.suffix.lower()

        if ext == ".daml":
            LOG.error(
                "Reading metadata directly from a DAML file is not supported.")
            raise ValueError("Unsupported extension: .daml")

        elif ext == ".dalf":
            LOG.error(
                "Reading metadata directly from a DALF file is not supported.")
            raise ValueError("Unsupported extension: .dalf")

        elif ext == ".dar":
            dar_parse_start_time = time.time()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)

                from ..util.dar import DarFile

                dar = self._context.enter_context(DarFile(path))
            dar_package = dar.read_metadata()
            self._dar_paths.append(path)
            self.store.register_all(dar_package)
            dar_parse_end_time = time.time()
            LOG.debug("Parsed a dar in %s seconds.",
                      dar_parse_end_time - dar_parse_start_time)

        else:
            LOG.error("Unknown extension: %s", ext)
            raise ValueError(f"Unknown extension: {ext}")

    def add_source(self, *files: Union[str, PathLike, Path]) -> None:
        """
        Add a source file (either a .daml file, .dalf file, or a .dar file).

        Attempts to add the same file more than once will be ignored.

        :param files: Files to add to the archive.
        """
        for file in files:
            path = Path(fspath(file)).resolve(strict=True)
            if path not in self._files:
                self._files.add(path)
                self._add_source(path)

    def get_daml_archives(self) -> Sequence[Path]:
        return self._dar_paths

    def __enter__(self):
        """
        Does nothing.
        """
        self._context.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Delete all managed resources in the reverse order that they were created.
        """
        self._context.__exit__(exc_type, exc_val, exc_tb)
Example #16
class LocalDarRepository:
    def __init__(self):
        self._context = ExitStack()
        self._dar_paths = []  # type: List[Path]
        self._files = set()
        self.store = PackageStore.empty()

    def _add_source(self, path: Path) -> None:
        ext = path.suffix.lower()

        if ext == '.daml':
            # dar_create_start_time = time.time()
            dar_paths = self._context.enter_context(TemporaryDar(path))
            self.add_source(*dar_paths)
            # LOG.info('Compiled a dar in % seconds.', time.time() - dar_create_start_time)

        elif ext == '.dalf':
            dalf_package = parse_dalf(path.read_bytes())
            self._dar_paths.append(path)
            self.store.register_all(dalf_package)

        elif ext == '.dar':
            # dar_parse_start_time = time.time()
            dar = self._context.enter_context(DarFile(path))
            dar_package = dar.read_metadata()
            self._dar_paths.append(path)
            self.store.register_all(dar_package)
            # dar_parse_end_time = time.time()
            # LOG.info('Parsed a dar in %s seconds.', dar_parse_end_time - dar_parse_start_time)

        else:
            LOG.error('Unknown extension: %s', ext)
            raise ValueError(f'Unknown extension: {ext}')

    def add_source(self, *files: Union[str, Path]) -> None:
        """
        Add a source file (either a .daml file, .dalf file, or a .dar file).

        Attempts to add the same file more than once will be ignored.

        :param files: Files to add to the archive.
        """
        for file in files:
            path = pathify(file).resolve(strict=True)
            if path not in self._files:
                self._files.add(path)
                self._add_source(path)

        # stores = list()
        # files = list(files)
        # while files:
        #     file = files.pop(0)
        #     ext = pathify(file).suffix.lower()
        #
        #     if ext == '.daml':
        #         # dar_create_start_time = time.time()
        #         dar_paths = self._context.enter_context(TemporaryDar(file))
        #         #LOG.info('Compiled a dar in % seconds.', time.time() - dar_create_start_time)
        #         files.extend(dar_paths)
        #     elif ext == '.dalf':
        #         with open(file, 'rb') as f:
        #             contents = f.read()
        #         stores.append(parse_dalf('pkg0', contents))
        #     elif ext == '.dar':
        #         #dar_parse_start_time = time.time()
        #         dar = self._context.enter_context(DarFile(file))
        #         stores.append(dar.read_metadata())
        #         #dar_parse_end_time = time.time()
        #         #LOG.info('Parsed a dar in %s seconds.', dar_parse_end_time - dar_parse_start_time)
        #     else:
        #         LOG.error('Unknown extension: %s', ext)
        #         raise ValueError(f'Unknown extension: {ext}')
        #
        # reduce(lambda store, other_store: store.register_all(other_store), stores, self.store)

    def get_daml_archives(self) -> Sequence[Path]:
        return self._dar_paths

    def __enter__(self):
        """
        Does nothing.
        """
        self._context.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Delete all managed resources in the reverse order that they were created.
        """
        self._context.__exit__(exc_type, exc_val, exc_tb)
Example #17
class SocketIOCalculator(Calculator):
    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self,
                 calc=None,
                 port=None,
                 unixsocket=None,
                 timeout=None,
                 log=None,
                 *,
                 launch_client=None):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> with SocketIOCalculator(...) as calc:
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        self._exitstack = ExitStack()
        Calculator.__init__(self)

        if calc is not None:
            if launch_client is not None:
                raise ValueError('Cannot pass both calc and launch_client')
            launch_client = FileIOSocketClientLauncher(calc)
        self.launch_client = launch_client
        #self.calc = calc
        self.timeout = timeout
        self.server = None

        if isinstance(log, str):
            self.log = self._exitstack.enter_context(open(log, 'w'))
        else:
            self.log = log

        # We only hold these so we can pass them on to the server.
        # They may both be None as stored here.
        self._port = port
        self._unixsocket = unixsocket

        # If there is a calculator, we will launch in calculate() because
        # we are responsible for executing the external process, too, and
        # should do so before blocking.  Without a calculator we want to
        # block immediately:
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        d = {'type': 'calculator', 'name': 'socket-driver'}
        #if self.calc is not None:
        #    d['calc'] = self.calc.todict()
        return d

    def launch_server(self):
        return self._exitstack.enter_context(
            SocketServer(
                #launch_client=launch_client,
                port=self._port,
                unixsocket=self._unixsocket,
                timeout=self.timeout,
                log=self.log,
            ))

    def calculate(self,
                  atoms=None,
                  properties=['energy'],
                  system_changes=all_changes):
        bad = [
            change for change in system_changes
            if change not in self.supported_changes
        ]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and any(bad):
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol.  '
                'Please create new socket calculator.'.format(
                    bad if len(bad) > 1 else bad[0]))

        self.atoms = atoms.copy()

        if self.server is None:
            self.server = self.launch_server()
            proc = self.launch_client(atoms,
                                      properties,
                                      port=self._port,
                                      unixsocket=self._unixsocket)
            self.server.proc = proc  # XXX nasty hack

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        if self.atoms.cell.rank == 3 and any(self.atoms.pbc):
            vol = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / vol
        self.results.update(results)

    def close(self):
        try:
            self.server = None
        finally:
            self._exitstack.close()

    def __enter__(self):
        self._exitstack.__enter__()
        return self

    def __exit__(self, type, value, traceback):
        self.close()
Example #18
class HorovodStrategy(ParallelStrategy):
    """Plugin for Horovod distributed training integration."""

    distributed_backend = _StrategyType.HOROVOD

    def __init__(
        self,
        accelerator: Optional[
            "pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        rank_zero_only.rank = self.global_rank
        self._exit_stack: Optional[ExitStack] = None

    @property
    def global_rank(self) -> int:
        return hvd.rank()

    @property
    def local_rank(self) -> int:
        return hvd.local_rank()

    @property
    def world_size(self) -> int:
        return hvd.size()

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.world_size,
                                          rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, trainer: "pl.Trainer") -> None:
        self.model_to_device()

        super().setup(trainer)

        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()

        if not self.lightning_module.trainer.training:
            # no need to setup optimizers
            return

        def _unpack_lightning_optimizer(opt):
            return opt._optimizer if isinstance(opt,
                                                LightningOptimizer) else opt

        optimizers = self.optimizers
        optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers]

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= self.world_size

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        lr_scheduler_configs = self.lr_schedulers
        for config in lr_scheduler_configs:
            scheduler = config.scheduler
            scheduler.base_lrs = [
                lr * self.world_size for lr in scheduler.base_lrs
            ]

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(self.lightning_module.state_dict(),
                                 root_rank=0)
        for optimizer in optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        self.optimizers = self._wrap_optimizers(optimizers)
        for optimizer in self.optimizers:
            # Synchronization will be performed explicitly following backward()
            self._exit_stack.enter_context(optimizer.skip_synchronize())

    def barrier(self, *args, **kwargs):
        if distributed_available():
            self.join()

    def broadcast(self, obj: object, src: int = 0) -> object:
        obj = hvd.broadcast_object(obj, src)
        return obj

    def model_to_device(self):
        if self.on_gpu:
            # this can potentially be removed after #8312. Not done due to lack of horovod testing
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def join(self):
        if self.on_gpu:
            hvd.join(self.local_rank)
        else:
            hvd.join()

    def reduce(self,
               tensor,
               group: Optional[Any] = None,
               reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor, in which case the output remains unchanged
        """
        if group is not None:
            raise ValueError(
                "Horovod does not support allreduce using a subcommunicator at this time. Unset `group`."
            )

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self.join()
        return hvd.allreduce(tensor, op=reduce_op)

    def all_gather(self,
                   result: torch.Tensor,
                   group: Optional[Any] = dist_group.WORLD,
                   sync_grads: bool = False) -> torch.Tensor:
        if group is not None and group != dist_group.WORLD:
            raise ValueError(
                "Horovod does not support allgather using a subcommunicator at this time. Unset `group`."
            )

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self.join()
        return hvd.allgather(result)

    def post_backward(self, closure_loss: torch.Tensor) -> None:
        # synchronize all horovod optimizers.
        for optimizer in self.lightning_module.trainer.optimizers:
            optimizer.synchronize()

    def _wrap_optimizers(
            self,
            optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
        """Wraps optimizers to perform gradient aggregation via allreduce."""
        return [
            hvd.DistributedOptimizer(
                opt,
                named_parameters=self._filter_named_parameters(
                    self.lightning_module, opt))
            if "horovod" not in str(opt.__class__) else opt
            for opt in optimizers
        ]

    @staticmethod
    def _filter_named_parameters(
            model: nn.Module,
            optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
        opt_params = {
            p
            for group in optimizer.param_groups
            for p in group.get("params", [])
        }
        return [(name, p) for name, p in model.named_parameters()
                if p in opt_params]

    def teardown(self) -> None:
        super().teardown()
        self._exit_stack.__exit__(None, None, None)
        self._exit_stack = None
        # Make sure all workers have finished training before returning to the user
        self.join()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
Example #19
class _TestUser(object):
    def __init__(self, test_client, runestone_db_tools, username, password, course_name,
        # True if the course is free (no payment required); False otherwise.
        is_free=True):

        self.test_client = test_client
        self.runestone_db_tools = runestone_db_tools
        self.username = username
        self.first_name = 'test'
        self.last_name = 'user'
        self.email = self.username + '@foo.com'
        self.password = password
        self.course_name = course_name
        self.is_free = is_free

    def __enter__(self):
        # Registration doesn't work unless we're logged out.
        self.test_client.logout()
        # Now, post the registration.
        self.test_client.validate('default/user/register',
            'Support Runestone Interactive' if self.is_free else 'Payment Amount',
            data=dict(
                username=self.username,
                first_name=self.first_name,
                last_name=self.last_name,
                # The e-mail address must be unique.
                email=self.email,
                password=self.password,
                password_two=self.password,
                # Note that ``course_id`` is (on the form) actually a course name.
                course_id=self.course_name,
                accept_tcp='on',
                donate='0',
                _next='/runestone/default/index',
                _formname='register',
            )
        )

        # Schedule this user for deletion.
        self.exit_stack_object = ExitStack()
        self.exit_stack = self.exit_stack_object.__enter__()
        self.exit_stack.callback(self._delete_user)

        # Record IDs
        db = self.runestone_db_tools.db
        self.course_id = db(db.courses.course_name == self.course_name).select(db.courses.id).first().id
        self.user_id = db(db.auth_user.username == self.username).select(db.auth_user.id).first().id

        return self

    # Clean up on exit by invoking all ``__exit__`` methods.
    def __exit__(self, exc_type, exc_value, traceback):
        self.exit_stack_object.__exit__(exc_type, exc_value, traceback)

    # Delete the user created by entering this context manager. TODO: This doesn't delete all of the user's chapter progress records.
    def _delete_user(self):
        db = self.runestone_db_tools.db
        # Delete the course this user registered for.
        db((db.user_courses.course_id == self.course_id) & (db.user_courses.user_id == self.user_id)).delete()
        # Delete the user.
        db(db.auth_user.username == self.username).delete()
        db.commit()

    def login(self):
        self.test_client.post('default/user/login', data=dict(
            username=self.username,
            password=self.password,
            _formname='login',
        ))

    def make_instructor(self, course_id=None):
        # If ``course_id`` isn't specified, use this user's ``course_id``.
        course_id = course_id or self.course_id
        return self.runestone_db_tools.make_instructor(self.user_id, course_id)

    # A context manager to update this user's profile. If a course was added, it yields that course's ID; otherwise, it yields None.
    @contextmanager
    def update_profile(self,
        # This parameter is passed to ``test_client.validate``.
        expected_string=None,
        # An updated username, or ``None`` to use ``self.username``.
        username=None,
        # An updated first name, or ``None`` to use ``self.first_name``.
        first_name=None,
        # An updated last name, or ``None`` to use ``self.last_name``.
        last_name=None,
        # An updated email, or ``None`` to use ``self.email``.
        email=None,
        # An updated course name, or ``None`` to use ``self.course_name``.
        course_name=None,
        section='',
        # A shortcut for specifying the ``expected_string``, which only applies if ``expected_string`` is not set. Use ``None`` if a course will not be added, ``True`` if the added course is free, or ``False`` if the added course is paid.
        is_free=None,
        # The value of the ``accept_tcp`` checkbox; provide an empty string to leave unchecked. The default value leaves it checked.
        accept_tcp='on'):

        if expected_string is None:
            if is_free is None:
                expected_string = 'Course Selection'
            else:
                expected_string = 'Support Runestone Interactive' \
                    if is_free else 'Payment Amount'
        username = username or self.username
        first_name = first_name or self.first_name
        last_name = last_name or self.last_name
        email = email or self.email
        course_name = course_name or self.course_name

        db = self.runestone_db_tools.db
        # Determine if we're adding a course. If so, delete it at the end of the test. To determine if a course is being added, the course must exist, but not be in the user's list of courses.
        course = db(db.courses.course_name == course_name).select(db.courses.id).first()
        delete_at_end = course and not db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course.id)).select(db.user_courses.id).first()

        # Perform the update.
        try:
            self.test_client.validate('default/user/profile',
                expected_string,
                data=dict(
                    username=username,
                    first_name=first_name,
                    last_name=last_name,
                    email=email,
                    # Though the field is ``course_id``, it's really the course name.
                    course_id=course_name,
                    accept_tcp=accept_tcp,
                    section=section,
                    _next='/runestone/default/index',
                    id=str(self.user_id),
                    _formname='auth_user/' + str(self.user_id),
                )
            )

            yield course.id if delete_at_end else None
        finally:
            if delete_at_end:
                db = self.runestone_db_tools.db
                db((db.user_courses.user_id == self.user_id) & (db.user_courses.course_id == course.id)).delete()
                db.commit()

    # Call this after registering for a new course or adding a new course via ``update_profile`` to pay for the course.
    @contextmanager
    def make_payment(self,
        # The `Stripe test tokens <https://stripe.com/docs/testing#cards>`_ to use for payment.
        stripe_token,
        # The course ID of the course to pay for. None specifies ``self.course_id``.
        course_id=None):

        course_id = course_id or self.course_id

        # Get the signature from the HTML of the payment page.
        self.test_client.validate('default/payment')
        match = re.search('<input type="hidden" name="signature" value="([^ ]*)" />',
                          self.test_client.text)
        signature = match.group(1)

        try:
            html = self.test_client.validate('default/payment',
                data=dict(stripeToken=stripe_token, signature=signature)
            )
            assert ('Thank you for your payment' in html) or ('Payment failed' in html)

            yield None

        finally:
            db = self.runestone_db_tools.db
            # Try to delete the payment first, while the ``user_courses`` row
            # it joins against still exists; deleting ``user_courses`` first
            # would leave this query with nothing to match.
            try:
                db((db.user_courses.user_id == self.user_id) &
                   (db.user_courses.course_id == course_id) &
                   (db.user_courses.id == db.payments.user_courses_id)) \
                   .delete()
            except Exception:
                pass
            # Delete the course registration itself.
            db((db.user_courses.course_id == course_id) &
               (db.user_courses.user_id == self.user_id)).delete()
            db.commit()

    @contextmanager
    def hsblog(self, **kwargs):
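        # Post an event to the ``ajax/hsblog`` endpoint and yield the decoded
        # JSON response; the ``finally`` block below removes the most recent
        # matching ``useinfo`` row so the test leaves no trace.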
        try:
            yield json.loads(self.test_client.validate('ajax/hsblog',
                                                       data=kwargs))
        finally:
            # Try to remove this hsblog entry.
            event = kwargs.get('event')
            div_id = kwargs.get('div_id')
            course = kwargs.get('course')
            db = self.runestone_db_tools.db
            criteria = ((db.useinfo.sid == self.username) &
                        (db.useinfo.act == kwargs.get('act', '')) &
                        (db.useinfo.div_id == div_id) &
                        (db.useinfo.event == event) &
                        (db.useinfo.course_id == course) )
            useinfo_row = db(criteria).select(db.useinfo.id, orderby=db.useinfo.id).last()
            if useinfo_row:
                del db.useinfo[useinfo_row.id]
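# For context, a minimal usage sketch, assuming ``test_client`` and
# ``runestone_db_tools`` fixtures plus an existing course named
# 'test_course_1'; every name here is an illustrative assumption.
with _TestUser(test_client, runestone_db_tools,
               'bob', 'top_secret', 'test_course_1') as test_user:
    test_user.login()
    with test_user.hsblog(event='page', act='view', div_id='index',
                          course=test_user.course_name) as response:
        assert response
# On exit, the ExitStack callback registered in ``__enter__`` deletes the user.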