Esempio n. 1
0
def _attach(
    attach_id: Optional[str] = None,
    run_id: Optional[str] = None,
) -> Union[Run, RunDisabled, None]:
    """Attach to a run currently executing in another process/thread.

    ``attach_id`` takes precedence; ``run_id`` is only used when
    ``attach_id`` is not given (falsy).

    Arguments:
        attach_id: (str, optional) The id of the run or an attach identifier
            that maps to a run.
        run_id: (str, optional) The id of the run to attach to.

    Returns:
        A ``Run`` bound to a newly launched backend and populated from the
        backend's attach response.

    Raises:
        UsageError: if no (non-empty) id is given, the logger could not be
            set up, the backend returns no attach response, or the response
            carries an error message.
    """
    attach_id = attach_id or run_id
    # Reject both None and "" - an empty id cannot identify a run.
    if not attach_id:
        raise UsageError("attach_id or run_id must be specified")
    wandb._assert_is_user_process()

    _wl = wandb_setup._setup()

    _set_logger(_wl._get_logger())
    if logger is None:
        raise UsageError("logger is not initialized")

    manager = _wl._get_manager()
    if manager:
        manager._inform_attach(attach_id=attach_id)

    # Work on a copy so the overrides below are not applied to the setup
    # object's own settings.
    settings: Settings = copy.copy(_wl._settings)
    settings.update(run_id=attach_id, source=Source.INIT)

    # TODO: consolidate this codepath with wandb.init()
    backend = Backend(settings=settings, manager=manager)
    backend.ensure_launched()
    backend.server_connect()
    logger.info("attach backend started and connected")

    run = Run(settings=settings)
    run._set_library(_wl)
    run._set_backend(backend)
    backend._hack_set_run(run)
    assert backend.interface

    resp = backend.interface.communicate_attach(attach_id)
    if not resp:
        raise UsageError(
            "Failed to attach to run {!r}: no response from the backend".format(
                attach_id
            )
        )
    # resp is known to be truthy here, so only the error payload needs checking.
    if resp.error and resp.error.message:
        raise UsageError("bad: {}".format(resp.error.message))
    run._set_run_obj(resp.run)
    return run
Esempio n. 2
0
 def no_retry_4xx(e):
     """Retry predicate that refuses to retry HTTP 4xx (client) errors.

     Returns True (retry) for any exception that is not a
     ``requests.HTTPError``, for HTTP errors with no attached response, and
     for HTTP errors whose status code is outside the 4xx range.

     Raises:
         UsageError: for a 4xx response. The message is the first entry of
             the JSON body's ``errors`` list. If the body is not JSON or is
             not shaped that way, the HTTP error's own text is used instead.
     """
     if not isinstance(e, requests.HTTPError):
         return True
     response = e.response
     # requests permits an HTTPError without a response; there is no status
     # code to inspect, so treat it like any other retryable failure.
     if response is None:
         return True
     if not 400 <= response.status_code < 500:
         return True
     try:
         message = json.loads(response.content)["errors"][0]["message"]
     except (ValueError, KeyError, IndexError, TypeError):
         # A non-JSON or unexpectedly shaped body previously escaped as a bare
         # JSONDecodeError/KeyError; surface the HTTP error itself instead.
         message = str(e)
     raise UsageError(message)
Esempio n. 3
0
    def prompt_api_key(self):
        """Ask the user for an API key and make it the active key.

        Delegates the prompt to ``self._prompt_api_key()``, which yields a
        ``(key, status)`` pair. The pair is passed on to
        ``self.update_session``, and the key is stored on ``self._key``.

        Raises:
            UsageError: when the status is ``ApiKeyStatus.NOTTY``. The message
                points at ``wandb login`` in CLI-only mode and at
                ``wandb.login(key=...)`` otherwise.
        """
        api_key, key_status = self._prompt_api_key()
        if key_status == ApiKeyStatus.NOTTY:
            if self._settings._cli_only_mode:
                directive = "wandb login [your_api_key]"
            else:
                directive = "wandb.login(key=[your_api_key])"
            raise UsageError("api_key not configured (no-tty). call " + directive)

        self.update_session(api_key, status=key_status)
        self._key = api_key
Esempio n. 4
0
 def prompt_api_key(self):
     """Ask the user for an API key and make it the active key.

     Builds an ``Api`` client from ``self._settings`` and hands the prompt to
     ``apikey.prompt_api_key``. When ``self._settings.force`` is set, both
     ``no_offline`` and ``no_create`` are passed as true.

     Raises:
         UsageError: when ``apikey.prompt_api_key`` returns ``False``.
     """
     settings = self._settings
     api = Api(settings)
     forced = settings.force
     key = apikey.prompt_api_key(settings, api=api, no_offline=forced, no_create=forced)
     if key is False:
         raise UsageError("api_key not configured (no-tty).  Run wandb login")
     self.update_session(key)
     self._key = key
Esempio n. 5
0
def downsample(values: Sequence, target_length: int) -> list:
    """Downsample 1d values to ``target_length`` items, including start and end.

    Item ``i`` of the result is ``values[floor(i * (len(values) - 1) /
    (target_length - 1))]``, i.e. the source index is rounded down. The index
    is computed with exact integer arithmetic. A float ratio can land a hair
    below an integer and pick the wrong neighbour; integer arithmetic always
    keeps the first and last items.

    Values can be any iterable, including a generator; it is consumed into a
    list. If there are fewer values than ``target_length``, all of them are
    returned (as a list).

    Raises:
        UsageError: if ``target_length`` is not greater than 1.
    """
    if not target_length > 1:
        raise UsageError("target_length must be > 1")
    values = list(values)
    if len(values) < target_length:
        return values
    last_index = len(values) - 1
    steps = target_length - 1
    return [values[i * last_index // steps] for i in range(target_length)]
    def prompt_api_key(self):
        """Ask the user for an API key and make it the active key.

        Builds an ``Api`` client from ``self._settings`` and hands the prompt
        to ``apikey.prompt_api_key``. When ``self._settings.force`` is set,
        both ``no_offline`` and ``no_create`` are passed as true.

        Raises:
            UsageError: when ``apikey.prompt_api_key`` returns ``False``
                (presumably no TTY to prompt on, per the message). The message
                suggests ``wandb login`` in CLI-only mode and
                ``wandb.login(key=...)`` otherwise.
        """
        settings = self._settings
        api = Api(settings)
        forced = settings.force
        key = apikey.prompt_api_key(
            settings,
            api=api,
            no_offline=forced,
            no_create=forced,
        )
        if key is False:
            if settings._cli_only_mode:
                directive = "wandb login [your_api_key]"
            else:
                directive = "wandb.login(key=[your_api_key])"
            raise UsageError("api_key not configured (no-tty). call " + directive)

        self.update_session(key)
        self._key = key
def parse_config(params, exclude=None, include=None):
    """Turn ``params`` into a dict, optionally keeping or dropping keys.

    Args:
        params: either a string, which is treated as a config file path and
            loaded with ``config_util.dict_from_config_file(...,
            must_exist=True)``, or anything ``_to_dict`` accepts.
        exclude: keys to drop from the result.
        include: keys to keep; every other key is dropped.

    Returns:
        The mapping produced by ``_to_dict``, filtered into a new dict when
        ``include`` or ``exclude`` is given.

    Raises:
        UsageError: if both ``exclude`` and ``include`` are truthy.
    """
    if exclude and include:
        raise UsageError("Expected at most only one of exclude or include")
    source = params
    if isinstance(source, six.string_types):
        source = config_util.dict_from_config_file(source, must_exist=True)
    config = _to_dict(source)
    # include and exclude are mutually exclusive (checked above), so at most
    # one filter applies.
    if include:
        return {
            name: val for name, val in six.iteritems(config) if name in include
        }
    if exclude:
        return {
            name: val for name, val in six.iteritems(config) if name not in exclude
        }
    return config
Esempio n. 8
0
    def init(self):  # noqa: C901
        """Start a new run, or return an existing/disabled one.

        Outline of what the code below does:

        * Runs the ``"on_init"`` triggers with ``self.kwargs``.
        * If ``settings._noop`` is set, returns ``self._make_run_disabled()``.
        * If reinit is requested (or we are in Jupyter and ``reinit`` was not
          explicitly ``False``), finishes the most recent run on the global
          run stack first; otherwise, if ``wandb.run`` is already a ``Run``,
          returns it unchanged.
        * Launches and connects a ``Backend``, creates the ``Run`` and records
          initial telemetry.
        * Offline: publishes the run proto via
          ``backend.interface._publish_run``. Online: checks the library
          version with the backend, then sends the run with a 30 second
          timeout.
        * Starts the run in the backend, pushes it onto the global run stack,
          exposes it through ``module.set_global`` and freezes it.

        Returns:
            The new ``Run``, the already-active ``wandb.run``, or whatever
            ``self._make_run_disabled()`` returns in no-op mode.

        Raises:
            UsageError: if the backend gives no reply to ``communicate_run``
                or reports an error. The backend is cleaned up and
                ``self.teardown()`` is called before raising.
        """
        # Requires the global ``logger`` to have been set before init().
        assert logger
        logger.info("calling init triggers")
        trigger.call("on_init", **self.kwargs)
        s = self.settings
        sweep_config = self.sweep_config
        config = self.config
        logger.info(
            "wandb.init called with sweep_config: {}\nconfig: {}".format(
                sweep_config, config
            )
        )
        if s._noop:
            return self._make_run_disabled()
        # In Jupyter, reinit is on by default unless the user passed reinit=False.
        if s.reinit or (s._jupyter and s.reinit is not False):
            if len(self._wl._global_run_stack) > 0:
                if len(self._wl._global_run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                    )

                last_id = self._wl._global_run_stack[-1]._run_id
                logger.info(
                    "re-initializing run, found existing run on stack: {}".format(
                        last_id
                    )
                )
                # Only show HTML progress messages in a real, non-silent notebook.
                jupyter = (
                    s._jupyter
                    and not s._silent
                    and ipython._get_python_type() == "jupyter"
                )
                if jupyter:
                    ipython.display_html(
                        "Finishing last run (ID:{}) before initializing another...".format(
                            last_id
                        )
                    )

                # Finish the most recent run before starting the new one.
                self._wl._global_run_stack[-1].finish()

                if jupyter:
                    ipython.display_html(
                        "...Successfully finished last run (ID:{}). Initializing new run:<br/><br/>".format(
                            last_id
                        )
                    )
        elif isinstance(wandb.run, Run):
            # Not reinitializing: hand back the run that is already active.
            logger.info("wandb.init() called when a run is still active")
            return wandb.run

        logger.info("starting backend")

        backend = Backend(settings=s)
        backend.ensure_launched()
        backend.server_connect()
        logger.info("backend started and connected")
        # Make sure we are logged in
        # wandb_login._login(_backend=backend, _settings=self.settings)

        # resuming needs access to the server, check server_status()?

        run = Run(config=config, settings=s, sweep_config=sweep_config)

        # probe the active start method
        # "thread" is reported as-is; otherwise ask the backend's
        # multiprocessing module, if it exposes get_start_method().
        active_start_method = None
        if s.start_method == "thread":
            active_start_method = s.start_method
        else:
            get_start_fn = getattr(backend._multiprocessing, "get_start_method", None)
            active_start_method = get_start_fn() if get_start_fn else None

        # Populate initial telemetry
        with telemetry.context(run=run) as tel:
            tel.cli_version = wandb.__version__
            tel.python_version = platform.python_version()
            hf_version = _huggingface_version()
            if hf_version:
                tel.huggingface_version = hf_version
            if s._jupyter:
                tel.env.jupyter = True
            if s._kaggle:
                tel.env.kaggle = True
            if s._windows:
                tel.env.windows = True
            run._telemetry_imports(tel.imports_init)

            if active_start_method == "spawn":
                tel.env.start_spawn = True
            elif active_start_method == "fork":
                tel.env.start_fork = True
            elif active_start_method == "forkserver":
                tel.env.start_forkserver = True
            elif active_start_method == "thread":
                tel.env.start_thread = True

        logger.info("updated telemetry")

        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        run._set_teardown_hooks(self._teardown_hooks)
        # TODO: pass mode to backend
        # run_synced = None

        backend._hack_set_run(run)
        backend.interface.publish_header()

        if s._offline:
            with telemetry.context(run=run) as tel:
                tel.feature.offline = True
            run_proto = backend.interface._make_run(run)
            backend.interface._publish_run(run_proto)
            run._set_run_obj_offline(run_proto)
        else:
            logger.info("communicating current version")
            ret = backend.interface.communicate_check_version(
                current_version=wandb.__version__
            )
            # Surface any upgrade/delete/yank notices returned by the check.
            if ret:
                logger.info("got version response {}".format(ret))
                if ret.upgrade_message:
                    run._set_upgraded_version_message(ret.upgrade_message)
                if ret.delete_message:
                    run._set_deleted_version_message(ret.delete_message)
                if ret.yank_message:
                    run._set_yanked_version_message(ret.yank_message)
            run._on_init()
            logger.info("communicating run to backend with 30 second timeout")
            ret = backend.interface.communicate_run(run, timeout=30)
            error_message = None
            # A falsy ret is treated as a backend timeout.
            if not ret:
                logger.error("backend process timed out")
                error_message = "Error communicating with wandb process"
                # Suggest alternative start methods unless fork is already in use.
                if active_start_method != "fork":
                    error_message += "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                    error_message += "\nor:  wandb.init(settings=wandb.Settings(start_method='thread'))"
                    error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
            if ret and ret.error:
                error_message = ret.error.message
            if error_message:
                logger.error("encountered error: {}".format(error_message))
                # Shutdown the backend and get rid of the logger
                # we don't need to do console cleanup at this point
                backend.cleanup()
                self.teardown()
                raise UsageError(error_message)
            # ret is truthy here: a falsy ret set error_message and raised above.
            if ret.run.resumed:
                logger.info("run resumed")
                with telemetry.context(run=run) as tel:
                    tel.feature.resumed = True
            run._set_run_obj(ret.run)

        logger.info("starting run threads in backend")
        # initiate run (stats and metadata probing)
        # Use whichever run proto was set above for the online/offline path.
        run_obj = run._run_obj or run._run_obj_offline
        _ = backend.interface.communicate_run_start(run_obj)

        self._wl._global_run_stack.append(run)
        self.run = run
        self.backend = backend
        # Expose the new run's methods through the module-level wandb API.
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            summary=run.summary,
            save=run.save,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
            define_metric=run._define_metric,
            plot_table=run.plot_table,
            alert=run.alert,
        )
        self._reporter.set_context(run=run)
        run._on_start()

        run._freeze()
        logger.info("run started, returning control to user process")
        return run
Esempio n. 9
0
def torch_trace_handler():
    """Create a trace handler for traces generated by the PyTorch profiler.

    Provide the result as the ``on_trace_ready`` argument to
    ``torch.profiler.profile``:

    ```python
    torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
    ```

    Traces are written to a ``pytorch_traces`` directory inside the current
    run's directory, so that profiler charts & tables can be viewed in your
    run dashboard on wandb.ai. Calling this more than once in the same run
    reuses that directory.

    Please note that `wandb.init()` must be called before this function is
    invoked. The PyTorch (torch) version must also be at least 1.9, in order
    to ensure stability of their Profiler API.

    Args:
        None

    Returns:
        The handler returned by ``torch.profiler.tensorboard_trace_handler``
        for the run's ``pytorch_traces`` directory.

    Raises:
        UsageError if wandb.init() hasn't been called before profiling.
        Error if the torch version is less than 1.9.0 or cannot be parsed.

    Examples:
    ```python
    run = wandb.init()
    run.config.id = "profile_code"

    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=wandb.profiler.torch_trace_handler(),
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for i, batch in enumerate(dataloader):
            if i >= 5:
                break
            train(batch)
            prof.step()
    ```
    """
    import re

    torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
    torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

    # Only the leading "major.minor[.patch]" matters. Local and pre-release
    # suffixes ("+cu113", "a0", ".dev2021...") would break a plain int()
    # parse. A missing patch counts as 0 so that "1.9" compares as 1.9.0.
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", torch.__version__)
    version = (
        tuple(int(part or 0) for part in match.groups()) if match else (0, 0, 0)
    )

    if version < (1, 9, 0):
        raise Error(
            "torch version must be at least 1.9 in order to use the PyTorch Profiler API."
            f"\nVersion of torch currently installed: {torch.__version__}"
        )

    try:
        run_dir = wandb.run.dir  # type: ignore
    except AttributeError:
        # wandb.run has no .dir (e.g. it is None) when init() was not called.
        raise UsageError(
            "Please call `wandb.init()` before `wandb.profiler.torch_trace_handler()`"
        ) from None

    logdir = os.path.join(run_dir, "pytorch_traces")
    # exist_ok: the handler may be created several times within one run.
    os.makedirs(logdir, exist_ok=True)

    with telemetry.context() as tel:
        tel.feature.torch_profiler_trace = True

    return torch_profiler.tensorboard_trace_handler(logdir)
Esempio n. 10
0
    def init(self) -> Union[Run, RunDisabled, None]:  # noqa: C901
        """Start a new run, or return an existing/disabled one.

        Outline of what the code below does:

        * Runs the ``"on_init"`` triggers. Returns
          ``self._make_run_disabled()`` when ``settings._noop`` is set.
        * If reinit is requested (or in Jupyter, unless ``reinit`` is
          explicitly ``False``), finishes the most recent run on the global
          run stack. Otherwise an already-active ``wandb.run`` is returned.
          The exception is when a service manager exists and the current pid
          differs from the run's ``_init_pid``.
        * Informs the manager (if any), launches and connects a ``Backend``,
          builds the ``Run`` and records telemetry.
        * Offline: publishes the run proto via
          ``backend.interface._publish_run``; ``resume`` is ignored with a
          warning. Online: sends the run to the backend with a 30 second
          timeout.
        * Applies the run-start info to the settings and informs the manager.
          Then starts the run in the backend, registers it globally and
          freezes it.

        Returns:
            The new ``Run``, the already-active ``wandb.run``, or the result
            of ``self._make_run_disabled()``.

        Raises:
            RuntimeError: if the logger has not been initialized.
            UsageError: if the backend gives no reply or reports an error
                while communicating the run. The backend is cleaned up and
                ``self.teardown()`` is called first.
        """
        if logger is None:
            raise RuntimeError("Logger not initialized")
        logger.info("calling init triggers")
        trigger.call("on_init", **self.kwargs)

        logger.info(
            f"wandb.init called with sweep_config: {self.sweep_config}\nconfig: {self.config}"
        )
        if self.settings._noop:
            return self._make_run_disabled()
        # In Jupyter, reinit is on by default unless the user passed reinit=False.
        if self.settings.reinit or (self.settings._jupyter
                                    and self.settings.reinit is not False):
            if len(self._wl._global_run_stack) > 0:
                if len(self._wl._global_run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb, "
                        "you should use multi-processing not threads"  # noqa: E501
                    )

                last_id = self._wl._global_run_stack[-1]._run_id
                logger.info(
                    f"re-initializing run, found existing run on stack: {last_id}"
                )
                # Only show HTML progress messages in a real, non-silent notebook.
                jupyter = (self.settings._jupyter and not self.settings.silent
                           and ipython.in_jupyter())
                if jupyter:
                    ipython.display_html(
                        f"Finishing last run (ID:{last_id}) before initializing another..."
                    )

                # Finish the most recent run before starting the new one.
                self._wl._global_run_stack[-1].finish()

                if jupyter:
                    ipython.display_html(
                        f"Successfully finished last run (ID:{last_id}). Initializing new run:<br/>"
                    )
        elif isinstance(wandb.run, Run):
            allow_return_run = True
            manager = self._wl._get_manager()
            if manager:
                current_pid = os.getpid()
                if current_pid != wandb.run._init_pid:
                    # We shouldn't return a stale global run if we are in a new pid
                    allow_return_run = False

            if allow_return_run:
                logger.info("wandb.init() called when a run is still active")
                with telemetry.context() as tel:
                    tel.feature.init_return_run = True
                return wandb.run

        logger.info("starting backend")

        manager = self._wl._get_manager()
        if manager:
            manager._inform_init(settings=self.settings,
                                 run_id=self.settings.run_id)

        backend = Backend(settings=self.settings, manager=manager)
        backend.ensure_launched()
        backend.server_connect()
        logger.info("backend started and connected")
        # Make sure we are logged in
        # wandb_login._login(_backend=backend, _settings=self.settings)

        # resuming needs access to the server, check server_status()?

        run = Run(config=self.config,
                  settings=self.settings,
                  sweep_config=self.sweep_config)

        # probe the active start method
        # "thread" is reported as-is; otherwise ask the backend's
        # multiprocessing module, if it exposes get_start_method().
        active_start_method: Optional[str] = None
        if self.settings.start_method == "thread":
            active_start_method = self.settings.start_method
        else:
            get_start_fn = getattr(backend._multiprocessing,
                                   "get_start_method", None)
            active_start_method = get_start_fn() if get_start_fn else None

        # Populate initial telemetry
        with telemetry.context(run=run) as tel:
            tel.cli_version = wandb.__version__
            tel.python_version = platform.python_version()
            hf_version = _huggingface_version()
            if hf_version:
                tel.huggingface_version = hf_version
            if self.settings._jupyter:
                tel.env.jupyter = True
            if self.settings._kaggle:
                tel.env.kaggle = True
            if self.settings._windows:
                tel.env.windows = True
            run._telemetry_imports(tel.imports_init)
            if self._use_sagemaker:
                tel.feature.sagemaker = True
            if self._set_init_config:
                tel.feature.set_init_config = True
            if self._set_init_name:
                tel.feature.set_init_name = True
            if self._set_init_id:
                tel.feature.set_init_id = True
            if self._set_init_tags:
                tel.feature.set_init_tags = True

            if self.settings.launch:
                tel.feature.launch = True

            if active_start_method == "spawn":
                tel.env.start_spawn = True
            elif active_start_method == "fork":
                tel.env.start_fork = True
            elif active_start_method == "forkserver":
                tel.env.start_forkserver = True
            elif active_start_method == "thread":
                tel.env.start_thread = True

            if manager:
                tel.feature.service = True

            tel.env.maybe_mp = _maybe_mp_process(backend)

            # fixme: detected issues with settings
            # NOTE(review): these keys are Settings' name-mangled private
            # attributes; a rename inside Settings would raise KeyError here.
            if self.settings.__dict__["_Settings__preprocessing_warnings"]:
                tel.issues.settings__preprocessing_warnings = True
            if self.settings.__dict__["_Settings__validation_warnings"]:
                tel.issues.settings__validation_warnings = True
            if self.settings.__dict__["_Settings__unexpected_args"]:
                tel.issues.settings__unexpected_args = True

        if not self.settings.label_disable:
            if self.notebook:
                run._label_probe_notebook(self.notebook)
            else:
                run._label_probe_main()

        logger.info("updated telemetry")

        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        run._set_teardown_hooks(self._teardown_hooks)
        # TODO: pass mode to backend
        # run_synced = None

        backend._hack_set_run(run)
        assert backend.interface
        backend.interface.publish_header()

        # Using GitRepo() blocks & can be slow, depending on user's current git setup.
        # We don't want to block run initialization/start request, so populate run's git
        # info beforehand.
        if not self.settings.disable_git:
            run._populate_git_info()

        if self.settings._offline:
            with telemetry.context(run=run) as tel:
                tel.feature.offline = True
            run_proto = backend.interface._make_run(run)
            backend.interface._publish_run(run_proto)
            run._set_run_obj_offline(run_proto)
            if self.settings.resume:
                wandb.termwarn(
                    "`resume` will be ignored since W&B syncing is set to `offline`. "
                    f"Starting a new run with run id {run.id}.")
        else:
            logger.info("communicating run to backend with 30 second timeout")
            run_result = backend.interface.communicate_run(run, timeout=30)

            error_message: Optional[str] = None
            # A falsy run_result is treated as a backend timeout.
            if not run_result:
                logger.error("backend process timed out")
                error_message = "Error communicating with wandb process"
                # Suggest alternative start methods unless fork is already in use.
                if active_start_method != "fork":
                    error_message += "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                    error_message += "\nor:  wandb.init(settings=wandb.Settings(start_method='thread'))"
                    error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
            elif run_result.error:
                error_message = run_result.error.message
            if error_message:
                logger.error(f"encountered error: {error_message}")

                # Shutdown the backend and get rid of the logger
                # we don't need to do console cleanup at this point
                backend.cleanup()
                self.teardown()
                raise UsageError(error_message)
            assert run_result and run_result.run
            if run_result.run.resumed:
                logger.info("run resumed")
                with telemetry.context(run=run) as tel:
                    tel.feature.resumed = True
            run._set_run_obj(run_result.run)
            run._on_init()

        logger.info("starting run threads in backend")
        # initiate run (stats and metadata probing)
        # Use whichever run proto was set above for the online/offline path.
        run_obj = run._run_obj or run._run_obj_offline

        # Fold the run-start info back into the settings before starting.
        self.settings._apply_run_start(message_to_dict(run_obj))
        run._update_settings(self.settings)
        if manager:
            manager._inform_start(settings=self.settings,
                                  run_id=self.settings.run_id)

        assert backend.interface
        assert run_obj
        _ = backend.interface.communicate_run_start(run_obj)

        self._wl._global_run_stack.append(run)
        self.run = run
        self.backend = backend
        # Expose the new run's methods through the module-level wandb API.
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            summary=run.summary,
            save=run.save,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
            define_metric=run.define_metric,
            plot_table=run.plot_table,
            alert=run.alert,
            mark_preempting=run.mark_preempting,
        )
        self._reporter.set_context(run=run)
        run._on_start()

        run._freeze()
        logger.info("run started, returning control to user process")
        return run