示例#1
0
def _user_process_finished(server, hooks, wandb_process, stdout_redirector,
                           stderr_redirector):
    """Run the end-of-script cleanup once and wait for the W&B process.

    Only the first call does any work; the module-level
    ``_user_process_finished_called`` flag makes later calls return
    immediately. The sequence is: fire the 'on_finished' trigger, call
    ``shutdown_async_log_thread()``, restore stdout (and stderr unless
    ``env.is_debug()`` is true), report ``hooks.exit_code`` through
    ``server.done`` and then poll ``wandb_process`` until it exits.

    A ctrl-c while polling stops the wait; if the process is still
    running afterwards it is killed.
    """
    global _user_process_finished_called
    already_ran = _user_process_finished_called
    if already_ran:
        return
    _user_process_finished_called = True

    trigger.call('on_finished')
    shutdown_async_log_thread()

    # Hand the terminal streams back; stderr stays redirected in debug mode.
    stdout_redirector.restore()
    debugging = env.is_debug()
    if not debugging:
        stderr_redirector.restore()

    termlog()
    waiting_msg = "Waiting for W&B process to finish, PID {}".format(
        wandb_process.pid)
    termlog(waiting_msg)
    server.done(hooks.exit_code)

    try:
        while True:
            if wandb_process.poll() is not None:
                break
            time.sleep(0.1)
    except KeyboardInterrupt:
        # The user asked to stop waiting; fall through to the kill check.
        pass

    still_running = wandb_process.poll() is None
    if still_running:
        termlog('Killing W&B process, PID {}'.format(wandb_process.pid))
        wandb_process.kill()
示例#2
0
def _user_process_finished(server, hooks, wandb_process, stdout_redirector,
                           stderr_redirector):
    """Run the end-of-script cleanup once and wait for the W&B process.

    Only the first call does any work; the module-level
    ``_user_process_finished_called`` flag makes later calls return
    immediately. The sequence is: fire the 'on_finished' trigger, close
    the global run's files (when a run exists), restore stdout (and stderr
    unless ``env.is_debug()`` is true), report ``hooks.exit_code`` through
    ``server.done`` and then poll ``wandb_process`` until it exits.

    The first ctrl-c during the wait only prints a notice and the wait
    resumes; a second ctrl-c kills the process if it is still running.
    """
    global _user_process_finished_called
    already_ran = _user_process_finished_called
    if already_ran:
        return
    _user_process_finished_called = True

    def _block_until_exit():
        # Poll every 100 ms until the W&B process reports an exit status.
        while True:
            if wandb_process.poll() is not None:
                return
            time.sleep(0.1)

    trigger.call('on_finished')
    if run:
        run.close_files()

    # Hand the terminal streams back; stderr stays redirected in debug mode.
    stdout_redirector.restore()
    debugging = env.is_debug()
    if not debugging:
        stderr_redirector.restore()

    termlog()
    waiting_msg = "Waiting for W&B process to finish, PID {}".format(
        wandb_process.pid)
    termlog(waiting_msg)
    server.done(hooks.exit_code)

    # First wait: a ctrl-c here is only announced, not acted on.
    try:
        _block_until_exit()
    except KeyboardInterrupt:
        notice = (
            'Sending ctrl-c to W&B process, PID {}. Press ctrl-c again to kill it.'
            .format(wandb_process.pid))
        termlog(notice)

    # Second wait: another ctrl-c force-kills a process that is still alive.
    try:
        _block_until_exit()
    except KeyboardInterrupt:
        still_running = wandb_process.poll() is None
        if still_running:
            termlog('Killing W&B process, PID {}'.format(wandb_process.pid))
            wandb_process.kill()
示例#3
0
def _fit_wrapper(self, fn, generator=None, *args, **kwargs):
    """Apply "magic" configuration overrides to a Keras ``fit`` call.

    Fires the "on_fit" trigger, then pulls ``epochs``, ``batch_size`` and
    ``callbacks`` out of ``kwargs`` and replaces ``epochs`` / ``batch_size``
    with the values from the magic config when those are set. When the
    magic config enables TensorBoard and ``self._keras_or_tfkeras`` is
    available, a TensorBoard callback logging to ``wandb.run.dir`` is
    added, honouring the ``duplicate`` / ``overwrite`` settings.

    The caller's ``callbacks`` value is never mutated: it is copied into a
    fresh list, and ``callbacks=None`` (the Keras default) is treated as
    an empty list instead of raising ``TypeError``.

    Args:
        fn: the wrapped fit function (not called in the visible portion).
        generator: optional generator argument (not used in the visible
            portion).
        *args: positional arguments for the wrapped call.
        **kwargs: keyword arguments for the wrapped call; ``epochs``,
            ``batch_size`` and ``callbacks`` are popped from it.
    """
    trigger.call("on_fit")
    # NOTE(review): keras / tfkeras / epochs / batch_size are not consumed in
    # the visible portion of this function; they are kept in case the rest
    # of the wrapper relies on them.
    keras = sys.modules.get("keras", None)
    tfkeras = sys.modules.get("tensorflow.python.keras", None)
    epochs = kwargs.pop("epochs", None)
    batch_size = kwargs.pop("batch_size", None)

    magic_epochs = _magic_get_config("keras.fit.epochs", None)
    if magic_epochs is not None:
        epochs = magic_epochs
    magic_batch_size = _magic_get_config("keras.fit.batch_size", None)
    if magic_batch_size is not None:
        batch_size = magic_batch_size

    # Copy so appending below cannot leak into the caller's list, and
    # normalise an explicit ``callbacks=None`` to an empty list.
    user_callbacks = kwargs.pop("callbacks", None)
    callbacks = [] if user_callbacks is None else list(user_callbacks)

    if _magic_get_config("keras.fit.callbacks.tensorboard.enable", None):
        if k := getattr(self, "_keras_or_tfkeras", None):
            tb_duplicate = _magic_get_config(
                "keras.fit.callbacks.tensorboard.duplicate", None)
            tb_overwrite = _magic_get_config(
                "keras.fit.callbacks.tensorboard.overwrite", None)
            tb_present = any(
                isinstance(cb, k.callbacks.TensorBoard) for cb in callbacks)
            if tb_present and tb_overwrite:
                # Overwrite mode: drop the user's TensorBoard callbacks.
                callbacks = [
                    cb for cb in callbacks
                    if not isinstance(cb, k.callbacks.TensorBoard)
                ]
            if tb_overwrite or tb_duplicate or not tb_present:
                tb_callback_kwargs = {"log_dir": wandb.run.dir}
                cb_args = (
                    "write_graph",
                    "histogram_freq",
                    "update_freq",
                    "write_grads",
                    "write_images",
                    "batch_size",
                )
                # Forward only the options that the magic config sets.
                for cb_arg in cb_args:
                    v = _magic_get_config(
                        f'keras.fit.callbacks.tensorboard.{cb_arg}', None)
                    if v is not None:
                        tb_callback_kwargs[cb_arg] = v
                tb_callback = k.callbacks.TensorBoard(**tb_callback_kwargs)
                callbacks.append(tb_callback)
示例#4
0
    def init(self):  # noqa: C901
        """Create (or reuse) the run for this init call and return it.

        Steps, in order:
          1. Fire the "on_init" trigger with ``self.kwargs``.
          2. If ``settings._noop`` is set, build a ``Dummy`` run, publish it
             as the module globals and return it without starting a backend.
          3. If re-initialisation applies (``reinit`` set, or running in
             Jupyter with ``reinit`` not explicitly ``False``), finish the
             last run on the global run stack; otherwise, if ``wandb.run``
             is already a ``Run``, return it unchanged.
          4. Launch and connect a ``Backend``, create the ``Run`` and attach
             console, library, backend, reporter and teardown hooks to it.
          5. Offline: publish a locally built run record. Online: check the
             client version, then send the run to the backend with a
             30 second timeout.
          6. Start the run in the backend, push it on the global run stack,
             publish the module globals, freeze the run and return it.

        Returns:
            The ``Dummy`` run (noop mode), the already active ``wandb.run``,
            or the newly created ``Run``.

        Raises:
            UsageError: if the backend does not answer the run request or
                answers with an error.
        """
        trigger.call("on_init", **self.kwargs)
        s = self.settings
        config = self.config
        if s._noop:
            # Disabled mode: assemble a stand-in run object by hand.
            run = Dummy()
            run.config = wandb.wandb_sdk.wandb_config.Config()
            run.config.update(config)
            run.summary = DummyDict()
            # log() only records the data into the in-memory summary.
            run.log = lambda data, *_, **__: run.summary.update(data)
            run.finish = lambda *_, **__: module.unset_globals()
            run.step = 0
            run.resumed = False
            run.disabled = True
            run.id = shortuuid.uuid()
            run.name = "dummy-" + run.id
            run.dir = "/"
            module.set_global(
                run=run,
                config=run.config,
                log=run.log,
                summary=run.summary,
                save=run.save,
                use_artifact=run.use_artifact,
                log_artifact=run.log_artifact,
                plot_table=run.plot_table,
                alert=run.alert,
            )
            return run
        # Re-init is opt-in, except in Jupyter where it is the default unless
        # reinit was explicitly set to False.
        if s.reinit or (s._jupyter and s.reinit is not False):
            if len(self._wl._global_run_stack) > 0:
                if len(self._wl._global_run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                    )

                last_id = self._wl._global_run_stack[-1]._run_id
                if s._jupyter and not s._silent:
                    ipython.display_html(
                        "Finishing last run (ID:{}) before initializing another...".format(
                            last_id
                        )
                    )

                # Finish the most recent run before starting a new one.
                self._wl._global_run_stack[-1].finish()

                if s._jupyter and not s._silent:
                    ipython.display_html(
                        "...Successfully finished last run (ID:{}). Initializing new run:<br/><br/>".format(
                            last_id
                        )
                    )
        elif isinstance(wandb.run, Run):
            # No re-init requested and a run is active: hand it back as-is.
            logger.info("wandb.init() called when a run is still active")
            return wandb.run

        # Console redirection is always requested; no file descriptors are
        # created here, so all four fds are passed on as None.
        use_redirect = True
        stdout_master_fd, stderr_master_fd = None, None
        stdout_slave_fd, stderr_slave_fd = None, None

        backend = Backend()
        backend.ensure_launched(
            settings=s,
            stdout_fd=stdout_master_fd,
            stderr_fd=stderr_master_fd,
            use_redirect=use_redirect,
        )
        backend.server_connect()
        # Make sure we are logged in
        # wandb_login._login(_backend=backend, _settings=self.settings)

        # resuming needs access to the server, check server_status()?

        run = Run(config=config, settings=s)
        run._set_console(
            use_redirect=use_redirect,
            stdout_slave_fd=stdout_slave_fd,
            stderr_slave_fd=stderr_slave_fd,
        )
        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        run._set_teardown_hooks(self._teardown_hooks)
        # TODO: pass mode to backend
        # run_synced = None

        backend._hack_set_run(run)
        backend.interface.publish_header()

        if s._offline:
            # Offline: publish a locally built run record, no reply awaited.
            run_proto = backend.interface._make_run(run)
            backend.interface._publish_run(run_proto)
            run._set_run_obj_offline(run_proto)
        else:
            # Online: the version check may return messages to store on the run.
            ret = backend.interface.communicate_check_version(
                current_version=wandb.__version__
            )
            if ret:
                if ret.upgrade_message:
                    run._set_upgraded_version_message(ret.upgrade_message)
                if ret.delete_message:
                    run._set_deleted_version_message(ret.delete_message)
                if ret.yank_message:
                    run._set_yanked_version_message(ret.yank_message)
            run._on_init()
            ret = backend.interface.communicate_run(run, timeout=30)
            error_message = None
            # A falsy reply means no usable response; a reply carrying
            # `error` means the backend rejected the run.
            if not ret:
                error_message = "Error communicating with backend"
            if ret and ret.error:
                error_message = ret.error.message
            if error_message:
                # Shutdown the backend and get rid of the logger
                # we don't need to do console cleanup at this point
                backend.cleanup()
                self.teardown()
                raise UsageError(error_message)
            run._set_run_obj(ret.run)

        # initiate run (stats and metadata probing)
        _ = backend.interface.communicate_run_start()

        # Register the run and expose it through the module-level globals.
        self._wl._global_run_stack.append(run)
        self.run = run
        self.backend = backend
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            summary=run.summary,
            save=run.save,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
            plot_table=run.plot_table,
            alert=run.alert,
        )
        self._reporter.set_context(run=run)
        run._on_start()

        run._freeze()
        return run
示例#5
0
    def init(self):  # noqa: C901
        """Create (or reuse) the run for this init call and return it.

        Steps, in order:
          1. Fire the "on_init" trigger with ``self.kwargs``.
          2. If ``settings._noop`` is set, return ``self._make_run_disabled()``
             without starting a backend.
          3. If re-initialisation applies (``reinit`` set, or running in
             Jupyter with ``reinit`` not explicitly ``False``), finish the
             last run on the global run stack; otherwise, if ``wandb.run``
             is already a ``Run``, return it unchanged.
          4. Launch and connect a ``Backend``, create the ``Run``, record
             initial telemetry and attach library, backend, reporter and
             teardown hooks to the run.
          5. Offline: publish a locally built run record. Online: check the
             client version, then send the run to the backend with a
             30 second timeout.
          6. Start the run in the backend, push it on the global run stack,
             publish the module globals, freeze the run and return it.

        Returns:
            The disabled run (noop mode), the already active ``wandb.run``,
            or the newly created ``Run``.

        Raises:
            UsageError: if the backend does not answer the run request or
                answers with an error.
        """
        assert logger
        logger.info("calling init triggers")
        trigger.call("on_init", **self.kwargs)
        s = self.settings
        sweep_config = self.sweep_config
        config = self.config
        logger.info(
            "wandb.init called with sweep_config: {}\nconfig: {}".format(
                sweep_config, config
            )
        )
        if s._noop:
            return self._make_run_disabled()
        # Re-init is opt-in, except in Jupyter where it is the default unless
        # reinit was explicitly set to False.
        if s.reinit or (s._jupyter and s.reinit is not False):
            if len(self._wl._global_run_stack) > 0:
                if len(self._wl._global_run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                    )

                last_id = self._wl._global_run_stack[-1]._run_id
                logger.info(
                    "re-initializing run, found existing run on stack: {}".format(
                        last_id
                    )
                )
                # HTML notices are shown only when all three conditions hold.
                jupyter = (
                    s._jupyter
                    and not s._silent
                    and ipython._get_python_type() == "jupyter"
                )
                if jupyter:
                    ipython.display_html(
                        "Finishing last run (ID:{}) before initializing another...".format(
                            last_id
                        )
                    )

                # Finish the most recent run before starting a new one.
                self._wl._global_run_stack[-1].finish()

                if jupyter:
                    ipython.display_html(
                        "...Successfully finished last run (ID:{}). Initializing new run:<br/><br/>".format(
                            last_id
                        )
                    )
        elif isinstance(wandb.run, Run):
            # No re-init requested and a run is active: hand it back as-is.
            logger.info("wandb.init() called when a run is still active")
            return wandb.run

        logger.info("starting backend")

        backend = Backend(settings=s)
        backend.ensure_launched()
        backend.server_connect()
        logger.info("backend started and connected")
        # Make sure we are logged in
        # wandb_login._login(_backend=backend, _settings=self.settings)

        # resuming needs access to the server, check server_status()?

        run = Run(config=config, settings=s, sweep_config=sweep_config)

        # probe the active start method
        active_start_method = None
        if s.start_method == "thread":
            active_start_method = s.start_method
        else:
            # get_start_method may be missing on the backend's multiprocessing
            # object; in that case the start method stays None.
            get_start_fn = getattr(backend._multiprocessing, "get_start_method", None)
            active_start_method = get_start_fn() if get_start_fn else None

        # Populate initial telemetry
        with telemetry.context(run=run) as tel:
            tel.cli_version = wandb.__version__
            tel.python_version = platform.python_version()
            hf_version = _huggingface_version()
            if hf_version:
                tel.huggingface_version = hf_version
            if s._jupyter:
                tel.env.jupyter = True
            if s._kaggle:
                tel.env.kaggle = True
            if s._windows:
                tel.env.windows = True
            run._telemetry_imports(tel.imports_init)

            # Record which start method is in effect (nothing if unknown).
            if active_start_method == "spawn":
                tel.env.start_spawn = True
            elif active_start_method == "fork":
                tel.env.start_fork = True
            elif active_start_method == "forkserver":
                tel.env.start_forkserver = True
            elif active_start_method == "thread":
                tel.env.start_thread = True

        logger.info("updated telemetry")

        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        run._set_teardown_hooks(self._teardown_hooks)
        # TODO: pass mode to backend
        # run_synced = None

        backend._hack_set_run(run)
        backend.interface.publish_header()

        if s._offline:
            # Offline: publish a locally built run record, no reply awaited.
            with telemetry.context(run=run) as tel:
                tel.feature.offline = True
            run_proto = backend.interface._make_run(run)
            backend.interface._publish_run(run_proto)
            run._set_run_obj_offline(run_proto)
        else:
            logger.info("communicating current version")
            ret = backend.interface.communicate_check_version(
                current_version=wandb.__version__
            )
            if ret:
                logger.info("got version response {}".format(ret))
                if ret.upgrade_message:
                    run._set_upgraded_version_message(ret.upgrade_message)
                if ret.delete_message:
                    run._set_deleted_version_message(ret.delete_message)
                if ret.yank_message:
                    run._set_yanked_version_message(ret.yank_message)
            run._on_init()
            logger.info("communicating run to backend with 30 second timeout")
            ret = backend.interface.communicate_run(run, timeout=30)
            error_message = None
            if not ret:
                # No reply within the timeout: suggest other start methods
                # unless "fork" is already the one in use.
                logger.error("backend process timed out")
                error_message = "Error communicating with wandb process"
                if active_start_method != "fork":
                    error_message += "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                    error_message += "\nor:  wandb.init(settings=wandb.Settings(start_method='thread'))"
                    error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
            if ret and ret.error:
                error_message = ret.error.message
            if error_message:
                logger.error("encountered error: {}".format(error_message))
                # Shutdown the backend and get rid of the logger
                # we don't need to do console cleanup at this point
                backend.cleanup()
                self.teardown()
                raise UsageError(error_message)
            # Past this point `ret` is truthy and carries no error.
            if ret.run.resumed:
                logger.info("run resumed")
                with telemetry.context(run=run) as tel:
                    tel.feature.resumed = True
            run._set_run_obj(ret.run)

        logger.info("starting run threads in backend")
        # initiate run (stats and metadata probing)
        # Use the online run object when set, else the offline one.
        run_obj = run._run_obj or run._run_obj_offline
        _ = backend.interface.communicate_run_start(run_obj)

        # Register the run and expose it through the module-level globals.
        self._wl._global_run_stack.append(run)
        self.run = run
        self.backend = backend
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            summary=run.summary,
            save=run.save,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
            define_metric=run._define_metric,
            plot_table=run.plot_table,
            alert=run.alert,
        )
        self._reporter.set_context(run=run)
        run._on_start()

        run._freeze()
        logger.info("run started, returning control to user process")
        return run
示例#6
0
def init(job_type=None, dir=None, config=None, project=None, entity=None, reinit=None, tags=None,
         group=None, allow_val_change=False, resume=False, force=False, tensorboard=False,
         sync_tensorboard=False, name=None, notes=None, id=None, magic=None):
    """Initialize W&B

    If called from within Jupyter, initializes a new run and waits for a call to
    `wandb.log` to begin pushing metrics.  Otherwise, spawns a new process
    to communicate with W&B.

    Most arguments are exported as environment variables before the run is
    created from the environment.

    Args:
        job_type (str, optional): The type of job running, defaults to 'train'
        config (dict, argparse, or tf.FLAGS, optional): The hyper parameters to store with the run
        project (str, optional): The project to push metrics to
        entity (str, optional): The entity to push metrics to
        dir (str, optional): An absolute path to a directory where metadata will be stored
        group (str, optional): A unique string shared by all runs in a given group
        tags (list, optional): A list of tags to apply to the run
        id (str, optional): A globally unique (per project) identifier for the run
        name (str, optional): A display name which does not have to be unique
        notes (str, optional): A multiline string associated with the run
        reinit (bool, optional): Allow multiple calls to init in the same process
        resume (bool, str, optional): Automatically resume this run if run from the same machine,
            you can also pass a unique run_id
        allow_val_change (bool, optional): Passed to `run.config.update` when `config` is applied
        tensorboard (bool, optional): Patches tensorboard like `sync_tensorboard`; slated for
            deprecation
        sync_tensorboard (bool, optional): Synchronize wandb logs to tensorboard or tensorboardX
        force (bool, optional): Force authentication with wandb, defaults to False
        magic (bool, dict, or str, optional): magic configuration as bool, dict, json string,
            yaml filename

    Returns:
        A wandb.run object for metric and config logging.

    Raises:
        LaunchError: if the run directory cannot be created or the run mode is invalid.
    """
    # Must stay the first statement: locals() has to contain only the arguments.
    trigger.call('on_init', **locals())
    global run
    global __stage_dir__

    # We allow re-initialization when we're in Jupyter or explicitly opt-in to it.
    in_jupyter = _get_python_type() != "python"
    # `!= False` is kept on purpose so any value equal to False (e.g. 0) still opts out.
    if reinit or (in_jupyter and reinit != False):  # noqa: E712
        reset_env(exclude=env.immutable_keys())
        run = None

    # TODO: deprecate tensorboard
    # NOTE(review): `and` binds tighter than `or`, so `tensorboard=True` patches
    # even when tensorboard is already patched. The parentheses only make the
    # existing evaluation explicit; confirm whether
    # `(tensorboard or sync_tensorboard) and ...` was intended.
    if tensorboard or (sync_tensorboard and len(patched["tensorboard"]) == 0):
        util.get_module("wandb.tensorboard").patch()

    sagemaker_config = util.parse_sm_config()
    tf_config = util.parse_tfjob_config()
    if group is None:
        group = os.getenv(env.RUN_GROUP)
    if job_type is None:
        job_type = os.getenv(env.JOB_TYPE)
    if sagemaker_config:
        # Set run_id and potentially grouping if we're in SageMaker
        run_id = os.getenv('TRAINING_JOB_NAME')
        if run_id:
            os.environ[env.RUN_ID] = '-'.join([
                run_id,
                os.getenv('CURRENT_HOST', socket.gethostname())])
        # Context manager so the file handle is closed deterministically.
        with open("/opt/ml/input/config/resourceconfig.json") as resource_file:
            conf = json.load(resource_file)
        if group is None and len(conf["hosts"]) > 1:
            group = os.getenv('TRAINING_JOB_NAME')
        # Set secret variables
        if os.path.exists("secrets.env"):
            with open("secrets.env", "r") as secrets_file:
                for line in secrets_file:
                    line = line.strip()
                    if not line:
                        # A blank line would otherwise fail to unpack below.
                        continue
                    key, val = line.split('=', 1)
                    os.environ[key] = val
    elif tf_config:
        cluster = tf_config.get('cluster')
        job_name = tf_config.get('task', {}).get('type')
        task_index = tf_config.get('task', {}).get('index')
        if job_name is not None and task_index is not None:
            # TODO: set run_id for resuming?
            run_id = cluster[job_name][task_index].rsplit(":")[0]
            if job_type is None:
                job_type = job_name
            if group is None and len(cluster.get("worker", [])) > 0:
                group = cluster[job_name][0].rsplit("-"+job_name, 1)[0]
    image = util.image_id_from_k8s()
    if image:
        os.environ[env.DOCKER] = image
    if project:
        os.environ[env.PROJECT] = project
    if entity:
        os.environ[env.ENTITY] = entity
    if group:
        os.environ[env.RUN_GROUP] = group
    if job_type:
        os.environ[env.JOB_TYPE] = job_type
    if tags:
        os.environ[env.TAGS] = ",".join(tags)
    if id:
        os.environ[env.RUN_ID] = id
        if name is None:
            # We do this because of https://github.com/wandb/core/issues/2170
            # to ensure that the run's name is explicitly set to match its
            # id. If we don't do this and the id is eight characters long, the
            # backend will set the name to a generated human-friendly value.
            #
            # In any case, if the user is explicitly setting `id` but not
            # `name`, their id is probably a meaningful string that we can
            # use to label the run.
            name = os.environ.get(env.NAME, id)  # environment variable takes precedence over this.
    if name:
        os.environ[env.NAME] = name
    if notes:
        os.environ[env.NOTES] = notes
    if magic is not None and magic is not False:
        if isinstance(magic, dict):
            os.environ[env.MAGIC] = json.dumps(magic)
        elif isinstance(magic, str):
            os.environ[env.MAGIC] = magic
        elif isinstance(magic, bool):
            pass
        else:
            termwarn("wandb.init called with invalid magic parameter type", repeat=False)
        from wandb import magic_impl
        magic_impl.magic_install()
    if dir:
        os.environ[env.DIR] = dir
        util.mkdir_exists_ok(wandb_dir())
    resume_path = os.path.join(wandb_dir(), wandb_run.RESUME_FNAME)
    if resume == True:  # noqa: E712 - equality kept so values equal to True behave as before
        os.environ[env.RESUME] = "auto"
    elif resume:
        os.environ[env.RESUME] = os.environ.get(env.RESUME, "allow")
        # TODO: remove allowing resume as a string in the future
        os.environ[env.RUN_ID] = id or resume
    elif os.path.exists(resume_path):
        # Not resuming: remove the resume file left in the wandb directory.
        os.remove(resume_path)
    if os.environ.get(env.RESUME) == 'auto' and os.path.exists(resume_path):
        if not os.environ.get(env.RUN_ID):
            with open(resume_path) as resume_file:
                os.environ[env.RUN_ID] = json.load(resume_file)["run_id"]

    # the following line is useful to ensure that no W&B logging happens in the user
    # process that might interfere with what they do
    # logging.basicConfig(format='user process %(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # If a thread calls wandb.init() it will get the same Run object as
    # the parent. If a child process with distinct memory space calls
    # wandb.init(), it won't get an error, but it will get a result of
    # None.
    # This check ensures that a child process can safely call wandb.init()
    # after a parent has (only the parent will create the Run object).
    # This doesn't protect against the case where the parent doesn't call
    # wandb.init but two children do.
    if run or os.getenv(env.INITED):
        return run

    if __stage_dir__ is None:
        __stage_dir__ = "wandb"
        util.mkdir_exists_ok(wandb_dir())

    try:
        signal.signal(signal.SIGQUIT, _debugger)
    except AttributeError:
        # The signal module has no SIGQUIT here; skip installing the handler.
        pass

    try:
        run = wandb_run.Run.from_environment_or_defaults()
    except IOError as e:
        termerror('Failed to create run directory: {}'.format(e))
        raise LaunchError("Could not write to filesystem.")

    run.set_environment()

    def set_global_config(run):
        global config  # because we already have a local config
        config = run.config
    set_global_config(run)
    global summary
    summary = run.summary

    # set this immediately after setting the run and the config. if there is an
    # exception after this it'll probably break the user script anyway
    os.environ[env.INITED] = '1'

    # we do these checks after setting the run and the config because users scripts
    # may depend on those things
    if sys.platform == 'win32' and run.mode != 'clirun':
        termerror(
            'To use wandb on Windows, you need to run the command "wandb run python <your_train_script>.py"')
        return run

    if in_jupyter:
        _init_jupyter(run)
    elif run.mode == 'clirun':
        pass
    elif run.mode == 'run':
        api = InternalApi()
        # let init_jupyter handle this itself
        if not in_jupyter and not api.api_key:
            termlog(
                "W&B is a tool that helps track and visualize machine learning experiments")
            if force:
                termerror(
                    "No credentials found.  Run \"wandb login\" or \"wandb off\" to disable wandb")
            else:
                if run.check_anonymous():
                    _init_headless(run)
                else:
                    termlog(
                        "No credentials found.  Run \"wandb login\" to visualize your metrics")
                    run.mode = "dryrun"
                    _init_headless(run, False)
        else:
            _init_headless(run)
    elif run.mode == 'dryrun':
        termlog(
            'Dry run mode, not syncing to the cloud.')
        _init_headless(run, False)
    else:
        termerror(
            'Invalid run mode "%s". Please unset WANDB_MODE.' % run.mode)
        raise LaunchError("The WANDB_MODE environment variable is invalid.")

    # set the run directory in the config so it actually gets persisted
    run.config.set_run_dir(run.dir)

    if sagemaker_config:
        run.config.update(sagemaker_config)
        allow_val_change = True
    if config:
        run.config.update(config, allow_val_change=allow_val_change)

    # Access history to ensure resumed is set when resuming
    run.history
    # Load the summary to support resuming
    run.summary.load()

    atexit.register(run.close_files)

    return run
示例#7
0
def _fit_wrapper(self, fn, generator=None, *args, **kwargs):
    """Forward a Keras ``fit``-style call after applying magic-config overrides.

    Values read through ``_magic_get_config`` may replace ``epochs`` and
    ``batch_size`` and may add a TensorBoard and/or a W&B callback to the
    callback list before the call is handed to ``fn``.

    Args:
        self: Object whose fit method is wrapped; only its optional
            ``_keras_or_tfkeras`` attribute is read here.
        fn: The underlying fit callable that receives the rewritten arguments.
        generator: Optional data generator. When truthy it is passed to
            ``fn`` as the first positional argument.
        *args: Extra positional arguments forwarded to ``fn`` unchanged.
        **kwargs: Keyword arguments forwarded to ``fn``. ``epochs``,
            ``batch_size`` and ``callbacks`` are popped and re-inserted after
            the overrides have been applied. ``callbacks`` may be ``None``.

    Returns:
        Whatever ``fn`` returns.
    """
    trigger.call('on_fit')
    epochs = kwargs.pop("epochs", None)
    batch_size = kwargs.pop("batch_size", None)

    magic_epochs = _magic_get_config("keras.fit.epochs", None)
    if magic_epochs is not None:
        epochs = magic_epochs
    magic_batch_size = _magic_get_config("keras.fit.batch_size", None)
    if magic_batch_size is not None:
        batch_size = magic_batch_size

    # ``callbacks=None`` is a legal argument, so normalise it to an empty
    # list. Copying also guarantees the caller's own list is never mutated by
    # the appends below.
    callbacks = list(kwargs.pop("callbacks", None) or [])

    tb_enabled = _magic_get_config("keras.fit.callbacks.tensorboard.enable",
                                   None)
    if tb_enabled:
        k = getattr(self, '_keras_or_tfkeras', None)
        if k:
            tb_duplicate = _magic_get_config(
                "keras.fit.callbacks.tensorboard.duplicate", None)
            tb_overwrite = _magic_get_config(
                "keras.fit.callbacks.tensorboard.overwrite", None)
            tb_present = any(
                isinstance(cb, k.callbacks.TensorBoard) for cb in callbacks)
            if tb_present and tb_overwrite:
                # Overwrite mode: drop the caller's TensorBoard callbacks so
                # only the one built below remains.
                callbacks = [
                    cb for cb in callbacks
                    if not isinstance(cb, k.callbacks.TensorBoard)
                ]
            if tb_overwrite or tb_duplicate or not tb_present:
                tb_callback_kwargs = {'log_dir': wandb.run.dir}
                cb_args = ('write_graph', 'histogram_freq', 'update_freq',
                           'write_grads', 'write_images', 'batch_size')
                for cb_arg in cb_args:
                    v = _magic_get_config(
                        "keras.fit.callbacks.tensorboard." + cb_arg, None)
                    if v is not None:
                        tb_callback_kwargs[cb_arg] = v
                tb_callback = k.callbacks.TensorBoard(**tb_callback_kwargs)
                callbacks.append(tb_callback)

    wandb_enabled = _magic_get_config("keras.fit.callbacks.wandb.enable", None)
    if wandb_enabled:
        wandb_duplicate = _magic_get_config(
            "keras.fit.callbacks.wandb.duplicate", None)
        wandb_overwrite = _magic_get_config(
            "keras.fit.callbacks.wandb.overwrite", None)
        wandb_present = any(
            isinstance(cb, wandb.keras.WandbCallback) for cb in callbacks)
        if wandb_present and wandb_overwrite:
            # Overwrite mode: drop the caller's W&B callbacks so only the one
            # built below remains.
            callbacks = [
                cb for cb in callbacks
                if not isinstance(cb, wandb.keras.WandbCallback)
            ]
        if wandb_overwrite or wandb_duplicate or not wandb_present:
            wandb_callback_kwargs = {}
            log_gradients = _magic_get_config(
                "keras.fit.callbacks.wandb.log_gradients", None)
            # Gradient logging is only requested when both ``x`` and ``y``
            # were supplied as keyword arguments. Compare against None rather
            # than relying on truthiness: array-like training data raises on
            # ``bool(...)`` when it holds more than one element.
            has_training_data = (kwargs.get('x') is not None
                                 and kwargs.get('y') is not None)
            if log_gradients and has_training_data:
                wandb_callback_kwargs['log_gradients'] = log_gradients
            cb_args = ("predictions", "log_weights", "data_type", "save_model",
                       "save_weights_only", "monitor", "mode", "verbose",
                       "input_type", "output_type", "log_evaluation", "labels")
            for cb_arg in cb_args:
                v = _magic_get_config("keras.fit.callbacks.wandb." + cb_arg,
                                      None)
                if v is not None:
                    wandb_callback_kwargs[cb_arg] = v
            wandb_callback = wandb.keras.WandbCallback(**wandb_callback_kwargs)
            callbacks.append(wandb_callback)

    kwargs["callbacks"] = callbacks
    if epochs is not None:
        kwargs["epochs"] = epochs
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if generator:
        return fn(generator, *args, **kwargs)
    return fn(*args, **kwargs)
示例#8
0
    def init(self):
        """Start a run, or reuse the active one, and publish it module-wide.

        Returns:
            The ``RunDummy`` created when the settings request no-op mode, the
            already active ``wandb.run`` when re-initialisation is not
            requested, or otherwise the newly started ``Run``.

        Raises:
            UsageError: If communicating the run to the backend yields no
                result, or the result carries an error message.
        """
        trigger.call("on_init", **self.kwargs)
        settings = self.settings
        init_config = self.config

        def _export_globals(active_run):
            # Expose the run and its bound helpers through the module globals.
            module.set_global(
                run=active_run,
                config=active_run.config,
                log=active_run.log,
                summary=active_run.summary,
                save=active_run.save,
                restore=active_run.restore,
                use_artifact=active_run.use_artifact,
                log_artifact=active_run.log_artifact,
                plot_table=active_run.plot_table,
            )

        # No-op mode: hand back a dummy run without touching the backend.
        if settings._noop:
            dummy_run = RunDummy()
            _export_globals(dummy_run)
            return dummy_run

        # Re-init is either requested explicitly, or implied by Jupyter unless
        # it was explicitly switched off.
        wants_reinit = settings.reinit or (
            settings._jupyter and settings.reinit is not False
        )
        if wants_reinit:
            stack_depth = len(self._wl._global_run_stack)
            if stack_depth > 1:
                wandb.termwarn(
                    "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                )
            if stack_depth > 0:
                self._wl._global_run_stack[-1].finish()
        elif wandb.run:
            logger.info("wandb.init() called when a run is still active")
            return wandb.run

        # Console redirection is on; no pty file descriptors are handed over.
        use_redirect = True
        stdout_master_fd = stderr_master_fd = None
        stdout_slave_fd = stderr_slave_fd = None

        backend = Backend()
        backend.ensure_launched(
            settings=settings,
            stdout_fd=stdout_master_fd,
            stderr_fd=stderr_master_fd,
            use_redirect=use_redirect,
        )
        backend.server_connect()
        # Login is not enforced here (the call is disabled):
        # wandb_login._login(_backend=backend, _settings=self.settings)

        # Open question kept from the original: resuming needs access to the
        # server, check server_status()?

        run = Run(config=init_config, settings=settings)
        run._set_console(
            use_redirect=use_redirect,
            stdout_slave_fd=stdout_slave_fd,
            stderr_slave_fd=stderr_slave_fd,
        )
        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        run._set_teardown_hooks(self._teardown_hooks)
        # TODO: pass mode to backend

        backend._hack_set_run(run)
        backend.interface.publish_header()

        if not settings._offline:
            # Surface any version notices returned by the version check.
            version_info = backend.interface.communicate_check_version(
                current_version=wandb.__version__)
            if version_info:
                if version_info.upgrade_message:
                    run._set_upgraded_version_message(
                        version_info.upgrade_message)
                if version_info.delete_message:
                    run._set_deleted_version_message(
                        version_info.delete_message)
                if version_info.yank_message:
                    run._set_yanked_version_message(version_info.yank_message)
            run._on_init()

            run_result = backend.interface.communicate_run(run, timeout=30)
            if not run_result:
                failure = "Error communicating with backend"
            elif run_result.error:
                failure = run_result.error.message
            else:
                failure = None
            if failure:
                # Shut the backend down and drop the logger; console cleanup
                # is not needed at this point.
                backend.cleanup()
                self.teardown()
                raise UsageError(failure)
            run._set_run_obj(run_result.run)
        else:
            offline_proto = backend.interface._make_run(run)
            backend.interface._publish_run(offline_proto)
            run._set_run_obj_offline(offline_proto)

        # Kick off the run (stats and metadata probing).
        backend.interface.communicate_run_start()

        self._wl._global_run_stack.append(run)
        self.run = run
        self.backend = backend
        _export_globals(run)
        self._reporter.set_context(run=run)
        run._on_start()

        run._freeze()
        return run
示例#9
0
    def init(self) -> Union[Run, RunDisabled, None]:  # noqa: C901
        """Create, register and start a run, then install it as the global run.

        The method fires the ``on_init`` triggers, optionally finishes or
        reuses an existing run, launches and connects a ``Backend``, records
        initial telemetry, registers the run with the backend (or publishes it
        directly in offline mode), starts it, and finally exposes it through
        ``module.set_global``.

        Returns:
            The result of ``self._make_run_disabled()`` when
            ``self.settings._noop`` is set; the existing ``wandb.run`` when a
            ``Run`` is already active and may be reused; otherwise the newly
            created ``Run``.

        Raises:
            RuntimeError: If the module-level ``logger`` is ``None``.
            UsageError: If ``communicate_run`` returns a falsy result (logged
                as a backend timeout) or a result carrying an error message.
        """
        if logger is None:
            raise RuntimeError("Logger not initialized")
        logger.info("calling init triggers")
        trigger.call("on_init", **self.kwargs)

        logger.info(
            f"wandb.init called with sweep_config: {self.sweep_config}\nconfig: {self.config}"
        )
        if self.settings._noop:
            return self._make_run_disabled()
        # Re-init is requested explicitly, or implied by Jupyter unless it was
        # explicitly set to False: finish the most recent run on the stack.
        if self.settings.reinit or (self.settings._jupyter
                                    and self.settings.reinit is not False):
            if len(self._wl._global_run_stack) > 0:
                if len(self._wl._global_run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb, "
                        "you should use multi-processing not threads"  # noqa: E501
                    )

                last_id = self._wl._global_run_stack[-1]._run_id
                logger.info(
                    f"re-initializing run, found existing run on stack: {last_id}"
                )
                # Only emit HTML progress messages when running in a
                # non-silent Jupyter session.
                jupyter = (self.settings._jupyter and not self.settings.silent
                           and ipython.in_jupyter())
                if jupyter:
                    ipython.display_html(
                        f"Finishing last run (ID:{last_id}) before initializing another..."
                    )

                self._wl._global_run_stack[-1].finish()

                if jupyter:
                    ipython.display_html(
                        f"Successfully finished last run (ID:{last_id}). Initializing new run:<br/>"
                    )
        elif isinstance(wandb.run, Run):
            # A run is already active and no re-init was requested: reuse it,
            # unless a manager is in use and this process is not the one
            # recorded in the run's ``_init_pid``.
            allow_return_run = True
            manager = self._wl._get_manager()
            if manager:
                current_pid = os.getpid()
                if current_pid != wandb.run._init_pid:
                    # We shouldn't return a stale global run if we are in a new pid
                    allow_return_run = False

            if allow_return_run:
                logger.info("wandb.init() called when a run is still active")
                with telemetry.context() as tel:
                    tel.feature.init_return_run = True
                return wandb.run

        logger.info("starting backend")

        # When a manager is available, tell it about this init before the
        # backend is created.
        manager = self._wl._get_manager()
        if manager:
            manager._inform_init(settings=self.settings,
                                 run_id=self.settings.run_id)

        backend = Backend(settings=self.settings, manager=manager)
        backend.ensure_launched()
        backend.server_connect()
        logger.info("backend started and connected")
        # Make sure we are logged in
        # wandb_login._login(_backend=backend, _settings=self.settings)

        # resuming needs access to the server, check server_status()?

        run = Run(config=self.config,
                  settings=self.settings,
                  sweep_config=self.sweep_config)

        # probe the active start method
        # "thread" is taken from the settings as-is; anything else is asked of
        # the backend's multiprocessing module when it offers
        # ``get_start_method``.
        active_start_method: Optional[str] = None
        if self.settings.start_method == "thread":
            active_start_method = self.settings.start_method
        else:
            get_start_fn = getattr(backend._multiprocessing,
                                   "get_start_method", None)
            active_start_method = get_start_fn() if get_start_fn else None

        # Populate initial telemetry
        with telemetry.context(run=run) as tel:
            tel.cli_version = wandb.__version__
            tel.python_version = platform.python_version()
            hf_version = _huggingface_version()
            if hf_version:
                tel.huggingface_version = hf_version
            # Environment flags taken from the settings.
            if self.settings._jupyter:
                tel.env.jupyter = True
            if self.settings._kaggle:
                tel.env.kaggle = True
            if self.settings._windows:
                tel.env.windows = True
            run._telemetry_imports(tel.imports_init)
            # Feature flags recorded from attributes set on this initializer.
            if self._use_sagemaker:
                tel.feature.sagemaker = True
            if self._set_init_config:
                tel.feature.set_init_config = True
            if self._set_init_name:
                tel.feature.set_init_name = True
            if self._set_init_id:
                tel.feature.set_init_id = True
            if self._set_init_tags:
                tel.feature.set_init_tags = True

            if self.settings.launch:
                tel.feature.launch = True

            # Record which start method was detected above.
            if active_start_method == "spawn":
                tel.env.start_spawn = True
            elif active_start_method == "fork":
                tel.env.start_fork = True
            elif active_start_method == "forkserver":
                tel.env.start_forkserver = True
            elif active_start_method == "thread":
                tel.env.start_thread = True

            if manager:
                tel.feature.service = True

            tel.env.maybe_mp = _maybe_mp_process(backend)

            # fixme: detected issues with settings
            # NOTE: these lookups read name-mangled private attributes of the
            # settings object straight out of its ``__dict__``.
            if self.settings.__dict__["_Settings__preprocessing_warnings"]:
                tel.issues.settings__preprocessing_warnings = True
            if self.settings.__dict__["_Settings__validation_warnings"]:
                tel.issues.settings__validation_warnings = True
            if self.settings.__dict__["_Settings__unexpected_args"]:
                tel.issues.settings__unexpected_args = True

        # Label probing uses the notebook when one is attached, otherwise the
        # main-script variant.
        if not self.settings.label_disable:
            if self.notebook:
                run._label_probe_notebook(self.notebook)
            else:
                run._label_probe_main()

        logger.info("updated telemetry")

        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        run._set_teardown_hooks(self._teardown_hooks)
        # TODO: pass mode to backend
        # run_synced = None

        backend._hack_set_run(run)
        assert backend.interface
        backend.interface.publish_header()

        # Using GitRepo() blocks & can be slow, depending on user's current git setup.
        # We don't want to block run initialization/start request, so populate run's git
        # info beforehand.
        if not self.settings.disable_git:
            run._populate_git_info()

        if self.settings._offline:
            # Offline mode: the run record is built and published directly,
            # without the request/response ``communicate_run`` exchange used
            # in the online branch below.
            with telemetry.context(run=run) as tel:
                tel.feature.offline = True
            run_proto = backend.interface._make_run(run)
            backend.interface._publish_run(run_proto)
            run._set_run_obj_offline(run_proto)
            if self.settings.resume:
                wandb.termwarn(
                    "`resume` will be ignored since W&B syncing is set to `offline`. "
                    f"Starting a new run with run id {run.id}.")
        else:
            logger.info("communicating run to backend with 30 second timeout")
            run_result = backend.interface.communicate_run(run, timeout=30)

            error_message: Optional[str] = None
            if not run_result:
                logger.error("backend process timed out")
                error_message = "Error communicating with wandb process"
                # Suggest alternative start methods unless fork is already
                # the active one.
                if active_start_method != "fork":
                    error_message += "\ntry: wandb.init(settings=wandb.Settings(start_method='fork'))"
                    error_message += "\nor:  wandb.init(settings=wandb.Settings(start_method='thread'))"
                    error_message += "\nFor more info see: https://docs.wandb.ai/library/init#init-start-error"
            elif run_result.error:
                error_message = run_result.error.message
            if error_message:
                logger.error(f"encountered error: {error_message}")

                # Shutdown the backend and get rid of the logger
                # we don't need to do console cleanup at this point
                backend.cleanup()
                self.teardown()
                raise UsageError(error_message)
            assert run_result and run_result.run
            if run_result.run.resumed:
                logger.info("run resumed")
                with telemetry.context(run=run) as tel:
                    tel.feature.resumed = True
            run._set_run_obj(run_result.run)
            run._on_init()

        logger.info("starting run threads in backend")
        # initiate run (stats and metadata probing)
        # Use whichever run object is set; presumably populated by
        # ``_set_run_obj`` / ``_set_run_obj_offline`` above -- confirm in Run.
        run_obj = run._run_obj or run._run_obj_offline

        # Fold the run object's values back into the settings, then push the
        # updated settings onto the run.
        self.settings._apply_run_start(message_to_dict(run_obj))
        run._update_settings(self.settings)
        if manager:
            manager._inform_start(settings=self.settings,
                                  run_id=self.settings.run_id)

        assert backend.interface
        assert run_obj
        _ = backend.interface.communicate_run_start(run_obj)

        # Track the run on the library's run stack and expose it, together
        # with its bound helpers, through the module-level globals.
        self._wl._global_run_stack.append(run)
        self.run = run
        self.backend = backend
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            summary=run.summary,
            save=run.save,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
            define_metric=run.define_metric,
            plot_table=run.plot_table,
            alert=run.alert,
            mark_preempting=run.mark_preempting,
        )
        self._reporter.set_context(run=run)
        run._on_start()

        run._freeze()
        logger.info("run started, returning control to user process")
        return run