Esempio n. 1
0
    def init(self):
        """Create, register and start a new run, returning it.

        Fires the ``on_init`` trigger with ``self.kwargs``, then:

        * In no-op mode (``settings._noop``), installs a ``RunDummy`` as the
          module-level globals and returns it. No backend is launched.
        * Reinitialisation applies when ``settings.reinit`` is truthy, or when
          running under Jupyter and ``reinit`` is not explicitly ``False``.
          In that case the most recently started run is finished first,
          with a warning if more than one run is on the stack. Otherwise, if
          ``wandb.run`` is already set, that run is returned unchanged.
        * Launches and connects a ``Backend``, builds a ``Run`` wired to it
          and publishes the header.
        * Offline, the run record is published locally. Online, the client
          version is checked (setting upgrade/yank/delete messages) and the
          run is registered with a 30 second timeout.
        * Starts the run on the backend, pushes it on the library's run
          stack, exposes it through the module-level globals, starts and
          freezes it.

        Returns:
            The new ``Run``, the existing ``wandb.run``, or a ``RunDummy`` in
            no-op mode.

        Raises:
            UsageError: online only, when run registration yields no reply or
                an error message. ``backend.cleanup()`` and
                ``self.teardown()`` are called before raising.
        """
        trigger.call("on_init", **self.kwargs)
        settings = self.settings
        run_config = self.config

        def _install_globals(active_run):
            # Point the module-level shortcuts (wandb.run, wandb.log, ...)
            # at ``active_run``.
            module.set_global(
                run=active_run,
                config=active_run.config,
                log=active_run.log,
                summary=active_run.summary,
                save=active_run.save,
                restore=active_run.restore,
                use_artifact=active_run.use_artifact,
                log_artifact=active_run.log_artifact,
                plot_table=active_run.plot_table,
            )

        if settings._noop:
            dummy = RunDummy()
            _install_globals(dummy)
            return dummy

        if settings.reinit or (settings._jupyter and settings.reinit is not False):
            run_stack = self._wl._global_run_stack
            if run_stack:
                if len(run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                    )
                run_stack[-1].finish()
        elif wandb.run:
            logger.info("wandb.init() called when a run is still active")
            return wandb.run

        # Console redirection is on, but no explicit fds are handed over.
        redirect = True

        backend = Backend()
        backend.ensure_launched(
            settings=settings,
            stdout_fd=None,
            stderr_fd=None,
            use_redirect=redirect,
        )
        backend.server_connect()
        # NOTE: the login step (wandb_login._login) is currently disabled here.
        # TODO: resuming needs access to the server; check server_status()?

        new_run = Run(config=run_config, settings=settings)
        new_run._set_console(
            use_redirect=redirect,
            stdout_slave_fd=None,
            stderr_slave_fd=None,
        )
        new_run._set_library(self._wl)
        new_run._set_backend(backend)
        new_run._set_reporter(self._reporter)
        new_run._set_teardown_hooks(self._teardown_hooks)
        # TODO: pass mode to backend

        backend._hack_set_run(new_run)
        backend.interface.publish_header()

        if settings._offline:
            proto = backend.interface._make_run(new_run)
            backend.interface._publish_run(proto)
            new_run._set_run_obj_offline(proto)
        else:
            version_info = backend.interface.communicate_check_version()
            if version_info:
                if version_info.upgrade_message:
                    new_run._set_upgrade_version_message(version_info.upgrade_message)
                # A yanked or deleted release is flagged in red at header and
                # footer; a deletion notice wins over a yank notice.
                warning = version_info.delete_message or version_info.yank_message
                if warning:
                    new_run._set_check_version_message(click.style(warning, fg="red"))
            new_run._on_init()
            reply = backend.interface.communicate_run(new_run, timeout=30)
            if not reply:
                failure = "Error communicating with backend"
            elif reply.error:
                failure = reply.error.message
            else:
                failure = None
            if failure:
                # Only the backend and logger need shutting down here;
                # console cleanup is not needed at this point.
                backend.cleanup()
                self.teardown()
                raise UsageError(failure)
            new_run._set_run_obj(reply.run)

        # Kick off the run on the backend (stats and metadata probing).
        backend.interface.communicate_run_start()

        self._wl._global_run_stack.append(new_run)
        self.run = new_run
        self.backend = backend
        _install_globals(new_run)
        self._reporter.set_context(run=new_run)
        new_run._on_start()

        new_run._freeze()
        return new_run
Esempio n. 2
0
    def init(self):
        """Create, register and start a new managed run, returning it.

        Steps, in order:

        * If ``settings.reinit`` is set and runs already exist, joins the most
          recently started one, warning when more than one exists.
        * In ``"noop"`` mode returns ``None`` (a dummy run is not implemented
          yet).
        * Depending on ``settings.console``, creates stdout/stderr fd pairs:
          ``io_wrap.wandb_pty`` for ``"iowrap"`` or
          ``lib_console.win32_create_pipe`` for ``"_win32"``. The master ends
          go to the backend and the slave ends go to the run.
        * Launches and connects a ``Backend`` for ``settings.mode`` and calls
          ``wandb.login`` with it.
        * Builds a ``RunManaged`` and registers it with the backend according
          to the mode:
          - ``"online"``: synchronous with a 30 s timeout.
          - ``"offline"`` / ``"dryrun"``: via ``send_run`` only.
          - ``"async"`` / ``"run"``: a 10 s synchronous attempt followed by
            ``send_run``.
        * Records the run on the library's run stack, exposes it through the
          module-level globals and starts it.

        Returns:
            The new ``RunManaged``, or ``None`` in ``"noop"`` mode.

        Raises:
            UsageError: in ``"online"`` mode when run registration returns no
                result (``None``) or a result carrying an error.
                ``backend.cleanup()`` and ``self._wl.on_finish()`` are called
                before raising.
        """
        s = self.settings
        config = self.config

        if s.reinit:
            if len(self._wl._global_run_stack) > 0:
                if len(self._wl._global_run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                    )
                self._wl._global_run_stack[-1].join()

        if s.mode == "noop":
            # TODO(jhr): return dummy object
            return None

        console = s.console
        use_redirect = True
        stdout_master_fd, stderr_master_fd = None, None
        stdout_slave_fd, stderr_slave_fd = None, None
        if console == "iowrap":
            stdout_master_fd, stdout_slave_fd = io_wrap.wandb_pty(resize=False)
            stderr_master_fd, stderr_slave_fd = io_wrap.wandb_pty(resize=False)
        elif console == "_win32":
            # Not used right now
            stdout_master_fd, stdout_slave_fd = lib_console.win32_create_pipe()
            stderr_master_fd, stderr_slave_fd = lib_console.win32_create_pipe()

        backend = Backend(mode=s.mode)
        backend.ensure_launched(
            settings=s,
            stdout_fd=stdout_master_fd,
            stderr_fd=stderr_master_fd,
            use_redirect=use_redirect,
        )
        backend.server_connect()
        # Make sure we are logged in
        wandb.login(backend=backend)

        # resuming needs access to the server, check server_status()?

        run = RunManaged(config=config, settings=s)
        run._set_console(
            use_redirect=use_redirect,
            stdout_slave_fd=stdout_slave_fd,
            stderr_slave_fd=stderr_slave_fd,
        )
        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        # TODO: pass mode to backend

        backend._hack_set_run(run)

        if s.mode == "online":
            ret = backend.interface.send_run_sync(run, timeout=30)
            # send_run_sync can return None when no reply arrives (e.g. the
            # timeout expires). Without this check, ``ret.HasField`` would
            # raise AttributeError and skip the backend cleanup below.
            # TODO: make the backend log stacktraces on catastrophic failure
            error_message = None
            if ret is None:
                error_message = "Error communicating with backend"
            elif ret.HasField("error"):
                error_message = ret.error.message
            if error_message is not None:
                # Shutdown the backend and get rid of the logger
                # we don't need to do console cleanup at this point
                backend.cleanup()
                self._wl.on_finish()
                raise UsageError(error_message)
            run._set_run_obj(ret.run)
        elif s.mode in ("offline", "dryrun"):
            backend.interface.send_run(run)
        elif s.mode in ("async", "run"):
            # The synchronous reply is not used yet.
            # TODO: on network error, do async run save
            backend.interface.send_run_sync(run, timeout=10)
            backend.interface.send_run(run)

        self._wl._global_run_stack.append(run)
        self.run = run
        self.backend = backend
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            join=run.join,
            summary=run.summary,
            save=run.save,
            restore=run.restore,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
        )
        self._reporter.set_context(run=run)
        run._on_start()

        return run