コード例 #1
0
ファイル: wandb_run.py プロジェクト: rguerrettaz/client
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 wandb_dir=None):
        """Create a Run, filling in defaults for anything not supplied.

        Arguments:
            run_id (str): Run ID (stored in the GQL "name" attribute). A new
                one from ``generate_id()`` is used when falsy.
            mode (str): Run mode; ``'run'`` when falsy.
            dir (str): Run directory. When None, ``run_dir_path`` chooses it
                (with ``dry=True`` if the mode is ``'dryrun'``); otherwise its
                absolute path is used. The directory is created via
                ``self._mkdir()``.
            config (Config): Run config; a fresh ``Config()`` when None.
            sweep_id (str): Sweep ID, stored as-is.
            storage_id (str): GQL ID of the run, stored as-is.
            description (str): Run description. If ``self.description`` ends
                up empty, the run ID is used instead.
            resume (str): Resume setting; ``'never'`` when falsy.
            program (str): Program path. When falsy, ``__main__.__file__`` is
                used, or a placeholder string if that is unavailable.
            wandb_dir (str): Stored as-is on ``self.wandb_dir``.
        """
        # self.id ends up in the "name" attribute in GQL.
        self.id = run_id or generate_id()
        self.resume = resume or 'never'
        self.mode = mode or 'run'

        if program:
            self.program = program
        else:
            # Fall back to the script being executed, if there is one.
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # Likely `python -c`, an embedded interpreter, or similar.
                self.program = '<python with no main file>'
        self.wandb_dir = wandb_dir

        if dir is not None:
            self._dir = os.path.abspath(dir)
        else:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        self._mkdir()

        self.config = Config() if config is None else config

        # The GQL ID of the run.
        self.storage_id = storage_id
        # Socket server; currently only available in headless mode.
        self.socket = None

        if description is not None:
            self.description = description
        # RunManager() may already have written an empty description.md, so an
        # empty description has to be replaced as well, not only a missing one.
        if not self.description:
            self.description = self.id

        self.sweep_id = sweep_id

        # Internal state, initially empty.
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._user_accessed_summary = False
        self._examples = None
コード例 #2
0
ファイル: wandb_run.py プロジェクト: arroadie/client
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        Settings are read from ``environment`` (``os.environ`` when None):

        - run ID from WANDB_RUN_ID,
        - resume setting from WANDB_RESUME,
        - storage (GQL) ID from WANDB_RUN_STORAGE_ID,
        - run mode from WANDB_MODE,
        - run directory from WANDB_RUN_DIR,
        - sweep ID from WANDB_SWEEP_ID,
        - program path from WANDB_PROGRAM,
        - wandb directory from WANDB_DIR.

        Any variable that is unset is passed to the constructor as None, so the
        constructor's own defaults apply. In particular this method does not
        force the mode to "dryrun".

        The config comes from ``Config.from_environment_or_defaults()``. That
        call does not receive ``environment``.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get('WANDB_RUN_ID')
        resume = environment.get('WANDB_RESUME')
        storage_id = environment.get('WANDB_RUN_STORAGE_ID')
        mode = environment.get('WANDB_MODE')
        run_dir = environment.get('WANDB_RUN_DIR')
        sweep_id = environment.get('WANDB_SWEEP_ID')
        program = environment.get('WANDB_PROGRAM')
        wandb_dir = environment.get('WANDB_DIR')
        config = Config.from_environment_or_defaults()
        # Pass everything by keyword so the call cannot silently misbind if
        # __init__'s positional parameters are reordered or extended.
        return cls(run_id=run_id,
                   mode=mode,
                   dir=run_dir,
                   config=config,
                   sweep_id=sweep_id,
                   storage_id=storage_id,
                   program=program,
                   wandb_dir=wandb_dir,
                   resume=resume)
コード例 #3
0
def run(ctx, program, args, id, resume, dir, configs, message, name, notes,
        show, tags, run_group, job_type):
    """CLI handler: launch ``program`` with ``args`` as a W&B run.

    Options given on the command line take precedence. Any left unset (run
    ID, description, tags, group, job type, name, notes, resume) are filled
    from the corresponding environment variables.

    The handler prompts for an API key if none is configured, then sets the
    run up through ``run_manager.RunManager`` and starts the user process. If
    setup raises ``run_manager.Error``, the error is reported, the full
    traceback is logged and the process exits with status 1.

    Arguments:
        configs (str): Comma-separated config file paths, or empty.
        tags (str): Comma-separated tags; empty entries are dropped.
        show (bool): When truthy, sets the show-run flag in the child's env.
    """
    wandb.ensure_configured()
    config_paths = configs.split(',') if configs else []
    config = Config(config_paths=config_paths,
                    wandb_dir=dir or wandb.wandb_dir())
    tags = [tag for tag in tags.split(",") if tag] if tags else None

    # populate run parameters from env if not specified
    id = id or os.environ.get(env.RUN_ID)
    message = message or os.environ.get(env.DESCRIPTION)
    tags = tags or env.get_tags()
    run_group = run_group or os.environ.get(env.RUN_GROUP)
    job_type = job_type or os.environ.get(env.JOB_TYPE)
    name = name or os.environ.get(env.NAME)
    notes = notes or os.environ.get(env.NOTES)
    resume = resume or os.environ.get(env.RESUME)

    run = wandb_run.Run(run_id=id,
                        mode='clirun',
                        config=config,
                        description=message,
                        program=program,
                        tags=tags,
                        group=run_group,
                        job_type=job_type,
                        name=name,
                        notes=notes,
                        resume=resume)
    run.enable_logging()

    # Environment handed to the user process.
    environ = dict(os.environ)
    if configs:
        environ[env.CONFIG_PATHS] = configs
    if show:
        environ[env.SHOW_RUN] = 'True'

    if not run.api.api_key:
        util.prompt_api_key(run.api, input_callback=click.prompt)

    try:
        rm = run_manager.RunManager(run)
        rm.init_run(environ)
    except run_manager.Error:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        wandb.termerror(
            'An Exception was raised during setup, see %s for full traceback.'
            % util.get_log_file_path())
        wandb.termerror(str(exc_value))
        if 'permission' in str(exc_value):
            wandb.termerror(
                'Are you sure you provided the correct API key to "wandb login"?'
            )
        # format_exception returns strings that already end in newlines, so
        # concatenate them as-is; joining with '\n' added blank lines.
        lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logger.error(''.join(lines))
        sys.exit(1)
    rm.run_user_process(program, args, environ)
コード例 #4
0
ファイル: wandb_run.py プロジェクト: youtang1993/client
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        These values are read from ``environment`` (``os.environ`` when None)
        using the ``env`` constants:

        - run ID (``env.RUN_ID``),
        - resume setting (``env.RESUME``),
        - storage/GQL ID (``env.RUN_STORAGE_ID``),
        - mode (``env.MODE``),
        - group (``env.RUN_GROUP``),
        - job type (``env.JOB_TYPE``),
        - run directory (``env.RUN_DIR``),
        - sweep ID (``env.SWEEP_ID``),
        - program (``env.PROGRAM``),
        - description (``env.DESCRIPTION``).

        Program arguments, the wandb directory and tags come from
        ``env.get_args()``, ``env.get_dir()`` and ``env.get_tags()``. The
        config comes from ``Config.from_environment_or_defaults()``.

        Mode handling: if the mode is unset and W&B is disabled
        (``InternalApi().disabled()``), ``"dryrun"`` is used. Otherwise the
        value, possibly None, is passed on unchanged and a message is printed
        when W&B is disabled.

        Unset values are passed to the constructor as None, so its own defaults
        apply.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get(env.RUN_ID)
        resume = environment.get(env.RESUME)
        storage_id = environment.get(env.RUN_STORAGE_ID)
        mode = environment.get(env.MODE)
        disabled = InternalApi().disabled()
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termwarn(
                "WANDB_MODE is set to run, but W&B was disabled.  Run `wandb on` to remove this message"
            )
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory.  Run `wandb on` to enable cloud syncing.'
            )

        group = environment.get(env.RUN_GROUP)
        job_type = environment.get(env.JOB_TYPE)
        run_dir = environment.get(env.RUN_DIR)
        sweep_id = environment.get(env.SWEEP_ID)
        program = environment.get(env.PROGRAM)
        description = environment.get(env.DESCRIPTION)
        # NOTE(review): env.get_args/get_dir/get_tags, InternalApi() and
        # Config.from_environment_or_defaults() are not given `environment`,
        # so they presumably read os.environ even when a custom mapping is
        # passed in. Confirm before relying on a custom `environment`.
        args = env.get_args()
        wandb_dir = env.get_dir()
        tags = env.get_tags()
        config = Config.from_environment_or_defaults()
        # Pass everything by keyword so the call cannot silently misbind if
        # __init__'s positional parameters are reordered or extended.
        return cls(run_id=run_id,
                   mode=mode,
                   dir=run_dir,
                   group=group,
                   job_type=job_type,
                   config=config,
                   sweep_id=sweep_id,
                   storage_id=storage_id,
                   program=program,
                   description=description,
                   args=args,
                   wandb_dir=wandb_dir,
                   tags=tags,
                   resume=resume)
コード例 #5
0
ファイル: cli.py プロジェクト: mikewlange/client-1
def run(ctx, program, args, id, resume, dir, configs, message, name, notes,
        show, tags, run_group, job_type):
    """CLI handler: launch ``program`` with ``args`` as a W&B run.

    The handler builds a ``clirun``-mode Run from the command-line options,
    calls ``run.check_anonymous()``, sets the run up through
    ``run_manager.RunManager`` and then starts the user process. If setup
    raises ``run_manager.Error``, the error is reported, the full traceback is
    logged and the process exits with status 1.

    Arguments:
        configs (str): Comma-separated config file paths, or empty.
        tags (str): Comma-separated tags; empty entries are dropped.
        show (bool): When truthy, sets the show-run flag in the child's env.
    """
    wandb.ensure_configured()
    config_paths = configs.split(',') if configs else []
    config = Config(config_paths=config_paths,
                    wandb_dir=dir or wandb.wandb_dir())
    tags = [tag for tag in tags.split(",") if tag] if tags else None
    run = wandb_run.Run(run_id=id,
                        mode='clirun',
                        config=config,
                        description=message,
                        program=program,
                        tags=tags,
                        group=run_group,
                        job_type=job_type,
                        name=name,
                        notes=notes,
                        resume=resume)
    run.enable_logging()

    # Environment handed to the user process.
    environ = dict(os.environ)
    if configs:
        environ[env.CONFIG_PATHS] = configs
    if show:
        environ[env.SHOW_RUN] = 'True'
    run.check_anonymous()

    try:
        rm = run_manager.RunManager(run)
        rm.init_run(environ)
    except run_manager.Error:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        wandb.termerror(
            'An Exception was raised during setup, see %s for full traceback.'
            % util.get_log_file_path())
        wandb.termerror(str(exc_value))
        if 'permission' in str(exc_value):
            wandb.termerror(
                'Are you sure you provided the correct API key to "wandb login"?'
            )
        # format_exception returns strings that already end in newlines, so
        # concatenate them as-is; joining with '\n' added blank lines.
        lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        logger.error(''.join(lines))
        sys.exit(1)
    rm.run_user_process(program, args, environ)
コード例 #6
0
ファイル: wandb_run.py プロジェクト: rguerrettaz/client
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        Settings are read from ``environment`` (``os.environ`` when None):

        - run ID from WANDB_RUN_ID,
        - resume setting from WANDB_RESUME,
        - storage (GQL) ID from WANDB_RUN_STORAGE_ID,
        - run mode from WANDB_MODE,
        - run directory from WANDB_RUN_DIR,
        - sweep ID from WANDB_SWEEP_ID,
        - program path from WANDB_PROGRAM,
        - wandb directory from WANDB_DIR.

        Mode handling: if WANDB_MODE is unset and ``Api().disabled()`` is
        true, ``"dryrun"`` is used. Otherwise the value, possibly None, is
        passed on unchanged and a message is printed when W&B is disabled.

        Other unset values are passed to the constructor as None, so its own
        defaults apply.

        The config comes from ``Config.from_environment_or_defaults()``. That
        call does not receive ``environment``.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get('WANDB_RUN_ID')
        resume = environment.get('WANDB_RESUME')
        storage_id = environment.get('WANDB_RUN_STORAGE_ID')
        mode = environment.get('WANDB_MODE')
        disabled = Api().disabled()
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termlog(
                "WARNING: WANDB_MODE is set to run, but W&B was disabled.  Run `wandb on` to remove this message"
            )
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory.  Run `wandb on` to enable cloud syncing.'
            )

        run_dir = environment.get('WANDB_RUN_DIR')
        sweep_id = environment.get('WANDB_SWEEP_ID')
        program = environment.get('WANDB_PROGRAM')
        wandb_dir = environment.get('WANDB_DIR')
        config = Config.from_environment_or_defaults()
        # Pass everything by keyword so the call cannot silently misbind if
        # __init__'s positional parameters are reordered or extended.
        return cls(run_id=run_id,
                   mode=mode,
                   dir=run_dir,
                   config=config,
                   sweep_id=sweep_id,
                   storage_id=storage_id,
                   program=program,
                   wandb_dir=wandb_dir,
                   resume=resume)
コード例 #7
0
ファイル: vertical_nn.py プロジェクト: Huizerd/slayerPytorch

if __name__ == "__main__":
    # Parse for configuration file
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="vertical_nn.yaml",
        help="Select configuration file",
    )
    args = vars(parser.parse_args())

    # Config
    config = Config(config_paths=[args["config"]])
    wandb.init(config=config, project="vertical")

    # Environment
    env = gym.make(
        config["environment"]["name"],
        obs_noise=config["environment"]["obsNoise"],
        init_rand=config["environment"]["initRand"],
        init_state=config["environment"]["initState"],
        delay=config["environment"]["delay"],
        reward_mods=config["environment"]["rewardMods"],
        state_bounds=[config["environment"]["altBounds"], None],
        total_steps=config["environment"]["steps"],
        goal_obs=config["environment"]["goalObs"],
        state_obs=config["environment"]["stateObs"],
        action_bounds=config["environment"]["actionBounds"],
コード例 #8
0
ファイル: wandb_run.py プロジェクト: gampx/client
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 group=None,
                 job_type=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 args=None,
                 wandb_dir=None,
                 tags=None,
                 name=None,
                 notes=None,
                 api=None):
        """Create a Run.

        Arguments:
            run_id (str): Run ID, the GQL "name". Generated with
                ``util.generate_id()`` when falsy.
            mode (str): Run mode; ``'run'`` when falsy.
            dir (str): Run directory. When None, ``run_dir_path`` chooses it
                (``dry=True`` if the mode is ``'dryrun'``); otherwise its
                absolute path is used. Created via ``self._mkdir()``.
            group (str): Run group, stored as-is.
            job_type (str): Job type, stored as-is.
            config (Config): Run config; a fresh ``Config()`` when None.
            sweep_id (str): Sweep ID, stored as-is.
            storage_id (str): The run's GQL "id".
            description (str): This is the old, deprecated style of description: the run's
                name followed by a newline, followed by multiline notes.
                Passing it prints a deprecation warning. When falsy, the
                contents of ``self.description_path`` are used if that file
                exists.
            resume (str): Resume setting; ``'never'`` when falsy. With
                ``'auto'`` the run ID is written as JSON to ``RESUME_FNAME``
                inside ``wandb.wandb_dir()``.
            program (str): Program path. When falsy, ``__main__.__file__`` is
                used, or a placeholder string if that is unavailable.
            args (list): Program arguments; ``sys.argv[1:]`` when None.
            wandb_dir (str): Stored as-is on ``self.wandb_dir``.
            tags (list): Run tags; an empty list when falsy.
            name (str): Display name (GQL "display_name"); applied only when
                not None.
            notes (str): Run notes; applied only when not None.
            api: API object for this run, stored as ``self._api``. It is bound
                to this run's ID via ``set_current_run_id``.

        Raises:
            RuntimeError: if ``api`` is already bound to a different run ID.
        """
        # Identifier mapping to GQL: storage_id -> "id", id -> "name",
        # _name -> "display_name".
        self.storage_id = storage_id
        self.id = run_id or util.generate_id()
        self._name = None
        self.notes = None

        self.resume = resume or 'never'
        self.mode = mode or 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        # Set when history is first accessed.
        self.resumed = False

        if api:
            bound_run_id = api.current_run_id
            if bound_run_id and bound_run_id != self.id:
                raise RuntimeError(
                    'Api object passed to run {} is already being used by run {}'
                    .format(self.id, bound_run_id))
            api.set_current_run_id(self.id)
        self._api = api

        if dir is not None:
            self._dir = os.path.abspath(dir)
        else:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        self._mkdir()

        # Name and notes used to share one field. When they have no values of
        # their own they fall back to _name_and_description. description.md is
        # deliberately not rewritten when they change, to discourage relying on
        # the combined field so it can be dropped later. This must be set before
        # name/notes, which take precedence over it.
        self._name_and_description = None
        if description:
            wandb.termwarn(
                'Run.description is deprecated. Please use wandb.init(notes="long notes") instead.'
            )
            self._name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as description_file:
                self._name_and_description = description_file.read()

        if name is not None:
            self.name = name
        if notes is not None:
            self.notes = notes

        if program:
            self.program = program
        else:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # Likely `python -c`, an embedded interpreter, or similar.
                self.program = '<python with no main file>'
        self.args = sys.argv[1:] if args is None else args
        self.wandb_dir = wandb_dir

        # Tag the configure_scope() scope with project, entity and URL.
        with configure_scope() as scope:
            self.project = self.api.settings("project")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            # TODO: Move this somewhere outside of init
            try:
                run_url = self.get_url(self.api, network=False)
            except CommError:
                pass
            else:
                scope.set_tag("url", run_url)

        if self.resume == "auto":
            # Persist this run's ID to RESUME_FNAME under wandb.wandb_dir().
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as resume_file:
                resume_file.write(json.dumps({"run_id": self.id}))

        self.config = Config() if config is None else config

        # Socket server; currently only available in headless mode.
        self.socket = None

        self.tags = tags or []

        self.sweep_id = sweep_id

        # Internal state, initially empty.
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None
コード例 #9
0
ファイル: wandb_run.py プロジェクト: gampx/client
class Run(object):
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 group=None,
                 job_type=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 args=None,
                 wandb_dir=None,
                 tags=None,
                 name=None,
                 notes=None,
                 api=None):
        """Create a Run.

        Arguments:
            run_id (str): Run ID, the GQL "name". Generated with
                ``util.generate_id()`` when falsy.
            mode (str): Run mode; ``'run'`` when falsy.
            dir (str): Run directory. When None, ``run_dir_path`` chooses it
                (``dry=True`` if the mode is ``'dryrun'``); otherwise its
                absolute path is used. Created via ``self._mkdir()``.
            group (str): Run group, stored as-is.
            job_type (str): Job type, stored as-is.
            config (Config): Run config; a fresh ``Config()`` when None.
            sweep_id (str): Sweep ID, stored as-is.
            storage_id (str): The run's GQL "id".
            description (str): This is the old, deprecated style of description: the run's
                name followed by a newline, followed by multiline notes.
                Passing it prints a deprecation warning. When falsy, the
                contents of ``self.description_path`` are used if that file
                exists.
            resume (str): Resume setting; ``'never'`` when falsy. With
                ``'auto'`` the run ID is written as JSON to ``RESUME_FNAME``
                inside ``wandb.wandb_dir()``.
            program (str): Program path. When falsy, ``__main__.__file__`` is
                used, or a placeholder string if that is unavailable.
            args (list): Program arguments; ``sys.argv[1:]`` when None.
            wandb_dir (str): Stored as-is on ``self.wandb_dir``.
            tags (list): Run tags; an empty list when falsy.
            name (str): Display name (GQL "display_name"); applied only when
                not None.
            notes (str): Run notes; applied only when not None.
            api (InternalApi): API object for this run, stored as ``self._api``
                and bound to this run's ID via ``set_current_run_id``. When
                omitted, the ``api`` property creates one on first access.

        Raises:
            RuntimeError: if ``api`` is already bound to a different run ID.
        """
        # Identifier mapping to GQL: storage_id -> "id", id -> "name",
        # _name -> "display_name".
        self.storage_id = storage_id
        self.id = run_id or util.generate_id()
        self._name = None
        self.notes = None

        self.resume = resume or 'never'
        self.mode = mode or 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        # Set when history is first accessed.
        self.resumed = False

        if api:
            bound_run_id = api.current_run_id
            if bound_run_id and bound_run_id != self.id:
                raise RuntimeError(
                    'Api object passed to run {} is already being used by run {}'
                    .format(self.id, bound_run_id))
            api.set_current_run_id(self.id)
        self._api = api

        if dir is not None:
            self._dir = os.path.abspath(dir)
        else:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        self._mkdir()

        # Name and notes used to share one field. When they have no values of
        # their own they fall back to _name_and_description. description.md is
        # deliberately not rewritten when they change, to discourage relying on
        # the combined field so it can be dropped later. This must be set before
        # name/notes, which take precedence over it.
        self._name_and_description = None
        if description:
            wandb.termwarn(
                'Run.description is deprecated. Please use wandb.init(notes="long notes") instead.'
            )
            self._name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as description_file:
                self._name_and_description = description_file.read()

        if name is not None:
            self.name = name
        if notes is not None:
            self.notes = notes

        if program:
            self.program = program
        else:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # Likely `python -c`, an embedded interpreter, or similar.
                self.program = '<python with no main file>'
        self.args = sys.argv[1:] if args is None else args
        self.wandb_dir = wandb_dir

        # Tag the configure_scope() scope with project, entity and URL.
        with configure_scope() as scope:
            self.project = self.api.settings("project")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            # TODO: Move this somewhere outside of init
            try:
                run_url = self.get_url(self.api, network=False)
            except CommError:
                pass
            else:
                scope.set_tag("url", run_url)

        if self.resume == "auto":
            # Persist this run's ID to RESUME_FNAME under wandb.wandb_dir().
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as resume_file:
                resume_file.write(json.dumps({"run_id": self.id}))

        self.config = Config() if config is None else config

        # Socket server; currently only available in headless mode.
        self.socket = None

        self.tags = tags or []

        self.sweep_id = sweep_id

        # Internal state, initially empty.
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None

    @property
    def config_static(self):
        """Return a ``ConfigStatic`` built from this run's ``config``."""
        static_view = ConfigStatic(self.config)
        return static_view

    @property
    def api(self):
        if self._api is None:
            self._api = InternalApi()
            self._api.set_current_run_id(self.id)
        return self._api

    @property
    def entity(self):
        """The entity name, read from the API's "entity" setting."""
        run_api = self.api
        return run_api.settings('entity')

    @entity.setter
    def entity(self, entity):
        """Store ``entity`` as the API's "entity" setting."""
        run_api = self.api
        run_api.set_setting("entity", entity)

    @property
    def path(self):
        """Return the run path as "entity/project/run_id"."""
        # TODO: theres an edge case where self.entity is None; str() then
        # yields the literal "None" segment.
        segments = [str(self.entity), self.project_name(), self.id]
        return "/".join(segments)

    def _init_jupyter_agent(self):
        """Create a ``JupyterAgent`` and keep it on ``self._jupyter_agent``."""
        # Imported here rather than at module level, so wandb.jupyter is only
        # loaded when this method is called.
        from wandb.jupyter import JupyterAgent as _Agent
        self._jupyter_agent = _Agent()

    def _stop_jupyter_agent(self):
        self._jupyter_agent.stop()

    def send_message(self, options):
        """ Sends a message to the wandb process changing the policy
        of saved files.  This is primarily used internally by wandb.save
        """
        if not options.get("save_policy") and not options.get("tensorboard"):
            raise ValueError(
                "Only configuring save_policy and tensorboard is supported")
        if self.socket:
            # In the user process
            self.socket.send(options)
        elif self._jupyter_agent:
            # Running in jupyter
            self._jupyter_agent.start()
            if options.get("save_policy"):
                self._jupyter_agent.rm.update_user_file_policy(
                    options["save_policy"])
            elif options.get("tensorboard"):
                self._jupyter_agent.rm.start_tensorboard_watcher(
                    options["tensorboard"]["logdir"],
                    options["tensorboard"]["save"])
        elif self._run_manager:
            # Running in the wandb process, used for tfevents saving
            if options.get("save_policy"):
                self._run_manager.update_user_file_policy(
                    options["save_policy"])
        else:
            wandb.termerror(
                "wandb.init hasn't been called, can't configure run")

    @classmethod
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        These values are read from ``environment`` (``os.environ`` when None)
        using the ``env`` constants:

        - run ID (``env.RUN_ID``),
        - resume setting (``env.RESUME``),
        - storage/GQL ID (``env.RUN_STORAGE_ID``),
        - mode (``env.MODE``),
        - group (``env.RUN_GROUP``),
        - job type (``env.JOB_TYPE``),
        - run directory (``env.RUN_DIR``),
        - sweep ID (``env.SWEEP_ID``),
        - program (``env.PROGRAM``),
        - description (``env.DESCRIPTION``),
        - name (``env.NAME``),
        - notes (``env.NOTES``).

        Args, the wandb directory and tags come from ``env.get_args``,
        ``env.get_dir`` and ``env.get_tags``, all given the same mapping.

        An ``InternalApi(environ=environment)`` is created and handed to the
        new Run. Mode handling: if the mode is unset and that API reports W&B
        as disabled, ``"dryrun"`` is used. Otherwise the value, possibly None,
        is passed on unchanged and a message is printed when W&B is disabled.

        Other unset values are passed to the constructor as None, so its
        defaults apply (e.g. a generated run ID and a derived run directory).
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get(env.RUN_ID)
        resume = environment.get(env.RESUME)
        storage_id = environment.get(env.RUN_STORAGE_ID)
        mode = environment.get(env.MODE)
        api = InternalApi(environ=environment)
        disabled = api.disabled()
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termwarn(
                "WANDB_MODE is set to run, but W&B was disabled.  Run `wandb on` to remove this message"
            )
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory.  Run `wandb on` to enable cloud syncing.'
            )

        group = environment.get(env.RUN_GROUP)
        job_type = environment.get(env.JOB_TYPE)
        run_dir = environment.get(env.RUN_DIR)
        sweep_id = environment.get(env.SWEEP_ID)
        program = environment.get(env.PROGRAM)
        description = environment.get(env.DESCRIPTION)
        name = environment.get(env.NAME)
        notes = environment.get(env.NOTES)
        args = env.get_args(env=environment)
        wandb_dir = env.get_dir(env=environment)
        tags = env.get_tags(env=environment)
        # TODO(adrian): should pass environment into here as well.
        config = Config.from_environment_or_defaults()
        # Pass everything by keyword so the call cannot silently misbind if
        # __init__'s positional parameters are reordered or extended.
        return cls(run_id=run_id,
                   mode=mode,
                   dir=run_dir,
                   group=group,
                   job_type=job_type,
                   config=config,
                   sweep_id=sweep_id,
                   storage_id=storage_id,
                   program=program,
                   description=description,
                   args=args,
                   wandb_dir=wandb_dir,
                   tags=tags,
                   name=name,
                   notes=notes,
                   resume=resume,
                   api=api)

    @classmethod
    def from_directory(cls,
                       directory,
                       project=None,
                       entity=None,
                       run_id=None,
                       api=None,
                       ignore_globs=None):
        """Upload the contents of a local run directory as a W&B run.

        Arguments:
            directory (str): Local directory holding the run's files.
            project (str): Target project. Falls back to the "project" in the
                directory's metadata file, then the API's "project" setting,
                then ``auto_project_name``.
            entity (str): Target entity; when given it is also stored as the
                API's "entity" setting.
            run_id (str): Run ID to upload under; generated when not given.
            api (InternalApi): API to use; a new one is created when not given.
            ignore_globs (list): fnmatch patterns of directory-relative paths
                that should not be uploaded.

        Returns:
            Run: the Run object created for ``directory``.

        Raises:
            ValueError: if no project can be determined.

        History is streamed from the history file when present. Otherwise any
        ``.tfevents.`` files are converted, and the ``media`` directory the
        conversion produces is removed once uploading has finished.
        """
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)

        run_name = None
        project_from_meta = None
        snap = DirectorySnapshot(directory)
        meta = next((p for p in snap.paths if METADATA_FNAME in p), None)
        if meta:
            # Read through a context manager so the file handle is closed.
            with open(meta) as meta_file:
                meta = json.load(meta_file)
            run_name = meta.get("name")
            project_from_meta = meta.get("project")

        project = project or project_from_meta or api.settings(
            "project") or run.auto_project_name(api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id,
                             project=project,
                             entity=entity,
                             display_name=run_name)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        try:
            wandb.termlog(res["displayName"] + " " + run.get_url(api))
        except CommError as e:
            wandb.termwarn(e.message)

        file_api = api.get_file_stream_api()
        file_api.start()
        # Paths relative to `directory`; they double as the upload names.
        paths = [
            os.path.relpath(abs_path, directory) for abs_path in snap.paths
            if os.path.isfile(abs_path)
        ]
        if ignore_globs:
            paths = set(paths)
            for g in ignore_globs:
                paths = paths - set(fnmatch.filter(paths, g))
            paths = list(paths)
        run_update = {"id": res["id"]}
        tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
        history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
        event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config = next((p for p in snap.paths if USER_CONFIG_FNAME in p),
                           None)
        summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
        converted_tfevents = False
        if history:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history)
            snap.paths.remove(history)
        elif tfevents:
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting...")
            converted_tfevents = True
            # The summary collected from the tfevents replaces any summary file.
            summary = {}
            for path in tfevents:
                filename = os.path.basename(path)
                namespace = path.replace(filename,
                                         "").replace(directory,
                                                     "").strip(os.sep)
                summary.update(
                    wbtf.stream_tfevents(path,
                                         file_api,
                                         run,
                                         namespace=namespace))
            for path in glob.glob(os.path.join(directory, "media/**/*"),
                                  recursive=True):
                if os.path.isfile(path):
                    # glob yields directory-prefixed paths. Keep them relative
                    # like the rest of `paths`, because each entry is used
                    # verbatim as the upload name and joined onto `directory`.
                    paths.append(os.path.relpath(path, directory))
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if event:
            file_api.stream_file(event)
            snap.paths.remove(event)
        if config:
            with open(config) as config_file:
                run_update["config"] = util.load_yaml(config_file)
        elif user_config:
            # TODO: half backed support for config.json
            # user_config is a file path. Load its JSON contents first, since
            # calling six.iteritems on the path string raises AttributeError.
            with open(user_config) as user_config_file:
                user_config_values = json.load(user_config_file)
            run_update["config"] = {
                k: {
                    "value": v
                }
                for k, v in six.iteritems(user_config_values)
            }
        if isinstance(summary, dict):
            #TODO: summary should already have data_types converted here...
            run_update["summary_metrics"] = util.json_dumps_safer(summary)
        elif summary:
            with open(summary) as summary_file:
                run_update["summary_metrics"] = summary_file.read()
        if meta:
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            if meta.get("host"):
                run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
            run_update["notes"] = meta.get("notes")
        else:
            run_update["host"] = run.host

        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        for k in paths:
            path = os.path.abspath(os.path.join(directory, k))
            pusher.update_file(k, path)
            pusher.file_changed(k, path)
        pusher.finish()
        pusher.print_status()
        file_api.finish(0)
        # Remove temporary media images generated from tfevents. Only do this
        # when a conversion actually ran, so a pre-existing media directory is
        # not deleted when there were neither history nor tfevents files.
        media_dir = os.path.join(directory, "media")
        if converted_tfevents and os.path.exists(media_dir):
            shutil.rmtree(media_dir)
        wandb.termlog("Finished!")
        return run

    def auto_project_name(self, api):
        """Derive a project name from this run's program via ``util.auto_project_name``."""
        program_path = self.program
        return util.auto_project_name(program_path, api)

    def save(self,
             id=None,
             program=None,
             summary_metrics=None,
             num_retries=None,
             api=None):
        """Create or update this run on the server via ``api.upsert_run``.

        Arguments:
            id (str): Storage (GQL) ID to upsert; defaults to ``self.storage_id``.
            program (str): Program path to record; defaults to ``self.program``.
            summary_metrics: Passed through to ``upsert_run``.
            num_retries (int): Passed through to ``upsert_run``.
            api: API to use; defaults to ``self.api``.

        The project comes from the API's "project" setting, falling back to
        ``auto_project_name``. Afterwards ``self.storage_id`` and
        ``self.name`` are refreshed from the response's "id" and
        "displayName".

        Returns:
            The result of ``api.upsert_run``.
        """
        api = api or self.api
        project = api.settings('project')
        if project is None:
            project = self.auto_project_name(api)
        fields = dict(
            id=id or self.storage_id,
            name=self.id,
            commit=api.git.last_commit,
            project=project,
            entity=self.entity,
            group=self.group,
            # Send None rather than an empty tag list.
            tags=self.tags or None,
            config=self.config.as_dict(),
            description=self._name_and_description,
            host=self.host,
            program_path=program or self.program,
            repo=api.git.remote_url,
            sweep_name=self.sweep_id,
            display_name=self._name,
            notes=self.notes,
            summary_metrics=summary_metrics,
            job_type=self.job_type,
            num_retries=num_retries,
        )
        result = api.upsert_run(**fields)
        self.storage_id = result['id']
        self.name = result.get('displayName')
        return result

    def set_environment(self, environment=None):
        """Export this run's settings as environment variables.

        Lets a user script rebuild an equivalent Run (e.g. inside
        ``wandb.init()``). Writes into ``os.environ`` unless a mapping is given.

        Returns:
            The mapping that was written to.
        """
        target = os.environ if environment is None else environment
        target[env.RUN_ID] = self.id
        target[env.RESUME] = self.resume
        if self.storage_id:
            target[env.RUN_STORAGE_ID] = self.storage_id
        target[env.MODE] = self.mode
        target[env.RUN_DIR] = self.dir

        # Exported only when truthy.
        for key, value in ((env.RUN_GROUP, self.group),
                           (env.JOB_TYPE, self.job_type),
                           (env.DIR, self.wandb_dir)):
            if value:
                target[key] = value
        # Exported whenever set (not None).
        for key, value in ((env.SWEEP_ID, self.sweep_id),
                           (env.PROGRAM, self.program)):
            if value is not None:
                target[key] = value
        if self.args is not None:
            target[env.ARGS] = json.dumps(self.args)
        for key, value in ((env.DESCRIPTION, self._name_and_description),
                           (env.NAME, self._name),
                           (env.NOTES, self.notes)):
            if value is not None:
                target[key] = value
        if len(self.tags) > 0:
            target[env.TAGS] = ",".join(self.tags)
        return target

    def _mkdir(self):
        """Create the run directory via ``util.mkdir_exists_ok``
        (presumably tolerant of an existing directory)."""
        run_dir = self._dir
        util.mkdir_exists_ok(run_dir)

    def project_name(self, api=None):
        """Return the project name to use for this run.

        The ``project`` setting is preferred, then the git-derived
        ``auto_project_name``, then the literal ``"uncategorized"``.
        """
        if not api:
            api = self.api
        configured = api.settings('project')
        if configured:
            return configured
        return self.auto_project_name(api) or "uncategorized"

    def _generate_query_string(self, api, params=None):
        """URL encodes dictionary of params"""

        params = params or {}

        if str(api.settings().get('anonymous', 'false')) == 'true':
            params['apiKey'] = api.api_key

        if not params:
            return ""
        return '?' + urllib.parse.urlencode(params)

    def _load_entity(self, api, network):
        if not api.api_key:
            raise CommError(
                "Can't find API key, run wandb login or set WANDB_API_KEY")

        entity = api.settings('entity')
        if network:
            if api.settings('entity') is None:
                viewer = api.viewer()
                if viewer.get('entity'):
                    api.set_setting('entity', viewer['entity'])

            entity = api.settings('entity')

        if not entity:
            # This can happen on network failure
            raise CommError(
                "Can't connect to network to query entity from API key")

        return entity

    def get_project_url(self, api=None, network=True, params=None):
        """Build the web URL of this run's project.

        Raises ``CommError`` (from ``_load_entity``) when the entity cannot be
        determined, e.g. ``network`` is false and no entity is configured.
        """
        query_params = params or {}
        api = api if api else self.api
        self._load_entity(api, network)

        quote = urllib.parse.quote_plus
        return "{base}/{entity}/{project}{query_string}".format(
            base=api.app_url,
            entity=quote(api.settings('entity')),
            project=quote(self.project_name(api)),
            query_string=self._generate_query_string(api, query_params))

    def get_sweep_url(self, api=None, network=True, params=None):
        """Build the web URL of the sweep this run belongs to.

        The entity lookup (which may raise ``CommError``) happens before the
        sweep check, so this can raise even for runs outside a sweep.

        Returns:
            The sweep URL, or None when ``self.sweep_id`` is None.
        """
        query_params = params or {}
        api = api if api else self.api
        self._load_entity(api, network)

        if self.sweep_id is None:
            return None

        quote = urllib.parse.quote_plus
        return "{base}/{entity}/{project}/sweeps/{sweepid}{query_string}".format(
            base=api.app_url,
            entity=quote(api.settings('entity')),
            project=quote(self.project_name(api)),
            sweepid=quote(self.sweep_id),
            query_string=self._generate_query_string(api, query_params))

    def get_url(self, api=None, network=True, params=None):
        """Build the web URL of this run.

        Raises ``CommError`` (from ``_load_entity``) when the entity cannot be
        determined, e.g. ``network`` is false and no entity is configured.
        """
        query_params = params or {}
        api = api if api else self.api
        self._load_entity(api, network)

        quote = urllib.parse.quote_plus
        return "{base}/{entity}/{project}/runs/{run}{query_string}".format(
            base=api.app_url,
            entity=quote(api.settings('entity')),
            project=quote(self.project_name(api)),
            run=quote(self.id),
            query_string=self._generate_query_string(api, query_params))

    def upload_debug(self):
        """Upload the debug log (``self.log_fname``) as ``wandb-debug.log``.

        Does nothing when the log file does not exist.
        """
        log_path = self.log_fname
        if not os.path.exists(log_path):
            return
        pusher = FilePusher(self.api)
        remote_name = "wandb-debug.log"
        pusher.update_file(remote_name, log_path)
        pusher.file_changed(remote_name, log_path)
        pusher.finish()

    def __repr__(self):
        try:
            return "W&B Run: %s" % self.get_url()
        except CommError as e:
            return "W&B Error: %s" % e.message

    @property
    def name(self):
        """Display name of the run.

        The explicit ``_name`` wins. Otherwise the first line of
        ``_name_and_description`` is used, or None if neither is set.
        """
        if self._name is not None:
            return self._name
        if self._name_and_description is None:
            return None
        first_line, _, _ = self._name_and_description.partition("\n")
        return first_line

    @name.setter
    def name(self, name):
        """Set the display name.

        If a name/description text is stored, its first line is replaced by
        ``name`` as well.
        """
        self._name = name
        # save() assigns upsert_result.get('displayName'), which may be None;
        # joining None into the text used to raise TypeError. Leave the
        # stored first line untouched in that case.
        if name is None or self._name_and_description is None:
            return
        parts = self._name_and_description.split("\n", 1)
        parts[0] = name
        self._name_and_description = "\n".join(parts)

    @property
    def description(self):
        """Deprecated: the text after the first line of the name/description.

        Returns ``""`` when there is no second line or nothing is stored.
        """
        wandb.termwarn(
            'Run.description is deprecated. Please use run.notes instead.')
        # Read-only on purpose. The old getter replaced a None
        # _name_and_description with '' as a side effect. That made `name`
        # return '' instead of None, and made set_environment export an
        # empty description.
        text = self._name_and_description or ""
        parts = text.split("\n", 1)
        if len(parts) > 1:
            return parts[1]
        return ""

    @description.setter
    def description(self, desc):
        """Deprecated: replace everything after the first line with ``desc``.

        The combined text is persisted to ``self.description_path``.
        """
        wandb.termwarn(
            'Run.description is deprecated. Please use wandb.init(notes="long notes") instead.'
        )
        if self._name_and_description is None:
            self._name_and_description = self._name or ""
        head = self._name_and_description.split("\n", 1)[0]
        self._name_and_description = "\n".join([head, desc])
        with open(self.description_path, 'w') as d_file:
            d_file.write(self._name_and_description)

    @property
    def host(self):
        """Hostname to report.

        Uses the ``env.HOST`` environment variable when set, otherwise
        ``socket.gethostname()``.
        """
        fallback = socket.gethostname()
        return os.environ.get(env.HOST, fallback)

    @property
    def dir(self):
        """Local run directory (``self._dir``)."""
        run_dir = self._dir
        return run_dir

    @property
    def log_fname(self):
        """Path of the debug log file.

        Every run currently shares ``util.get_log_file_path()``.
        """
        # TODO: per-run logging inside the run dir was started but had issues;
        # until that is revisited all logs go to the same shared file.
        shared_log = util.get_log_file_path()
        return shared_log

    def enable_logging(self):
        """Attach a handler writing INFO+ records to the shared debug log.

        Each record is tagged with this run's id so lines from several
        processes on the same machine can be told apart. The handler is
        added to the root logger and there is no way to remove it afterwards.
        """
        handler = logging.FileHandler(self.log_fname)
        handler.setLevel(logging.INFO)
        current_run_id = self.id

        class WBFilter(logging.Filter):
            def filter(self, record):
                # Supplies the %(run_id)s field used by the formatter below.
                record.run_id = current_run_id
                return True

        handler.setFormatter(logging.Formatter(
            '%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d [%(run_id)s:%(filename)s:%(funcName)s():%(lineno)s] %(message)s'
        ))
        handler.addFilter(WBFilter())
        logging.getLogger().addHandler(handler)

    @property
    def summary(self):
        """Lazily created ``summary.FileSummary`` for this run."""
        if self._summary is not None:
            return self._summary
        self._summary = summary.FileSummary(self)
        return self._summary

    @property
    def has_summary(self):
        """Truthy if a summary object is loaded or the summary file exists.

        Returns the summary object itself when it is truthy, otherwise a bool.
        """
        if self._summary:
            return self._summary
        summary_file = os.path.join(self._dir, summary.SUMMARY_FNAME)
        return os.path.exists(summary_file)

    def _history_added(self, row):
        self.summary.update(row, overwrite=False)

    def log(self,
            row=None,
            commit=None,
            step=None,
            sync=True,
            *args,
            **kwargs):
        """Log a row of metrics to this run's history.

        Args:
            row: mapping of string keys to values. ``Visualize`` values are
                registered via ``_add_viz`` and replaced by their ``.value``.
            commit: forwarded to ``history.add``. When it is ``False`` and no
                ``step`` is given, the row is merged via ``history.update``
                instead.
            step: explicit step forwarded to ``history.add``.
            sync: when equal to ``False``, the call is queued for the async
                log thread and returns immediately, without validation.
            *args, **kwargs: forwarded to ``history.add`` / ``history.update``.

        Raises:
            ValueError: if ``row`` is not a mapping or has non-string keys.
        """
        # collections.Mapping was removed in Python 3.10; fall back for Python 2.
        try:
            from collections.abc import Mapping
        except ImportError:
            from collections import Mapping

        if sync == False:  # noqa: E712 - keeps the original equality semantics
            wandb._ensure_async_log_thread_started()
            return wandb._async_log_queue.put({
                "row": row,
                "commit": commit,
                "step": step
            })

        if row is None:
            row = {}

        # Validate before touching values. Iterating a non-mapping below would
        # raise an unrelated error. A bad key should not leave a visualization
        # already persisted to the config.
        if not isinstance(row, Mapping):
            raise ValueError("wandb.log must be passed a dictionary")

        if any(not isinstance(key, six.string_types) for key in row.keys()):
            raise ValueError(
                "Key values passed to `wandb.log` must be strings.")

        for k in row:
            if isinstance(row[k], Visualize):
                self._add_viz(k, row[k].viz_id)
                row[k] = row[k].value

        if commit is not False or step is not None:
            self.history.add(row, *args, step=step, commit=commit, **kwargs)
        else:
            self.history.update(row, *args, **kwargs)

    def _add_viz(self, key, viz_id):
        if not 'viz' in self.config['_wandb']:
            self.config._set_wandb('viz', {})
        self.config['_wandb']['viz'][key] = {
            'id': viz_id,
            'historyFieldSettings': {
                'key': key,
                'x-axis': '_step'
            }
        }
        self.config.persist()

    @property
    def history(self):
        """Lazily created ``history.History`` for this run.

        On first access the run is marked as resumed if the loaded history
        already contains steps.
        """
        if self._history is not None:
            return self._history
        agent = self._jupyter_agent
        self._history = history.History(
            self,
            add_callback=self._history_added,
            jupyter_callback=agent.start if agent else None)
        if self._history._steps > 0:
            self.resumed = True
        return self._history

    @property
    def step(self):
        """Current step count of the run's history."""
        run_history = self.history
        return run_history._steps

    @property
    def has_history(self):
        """Truthy if a history object is loaded or the history file exists."""
        if self._history:
            return self._history
        history_file = os.path.join(self._dir, HISTORY_FNAME)
        return os.path.exists(history_file)

    @property
    def events(self):
        """Lazily opened ``JsonlEventsFile`` (EVENTS_FNAME) in the run dir."""
        if self._events is not None:
            return self._events
        self._events = jsonlfile.JsonlEventsFile(EVENTS_FNAME, self._dir)
        return self._events

    @property
    def has_events(self):
        """Truthy if an events file is open or EVENTS_FNAME exists in the run dir."""
        if self._events:
            return self._events
        events_file = os.path.join(self._dir, EVENTS_FNAME)
        return os.path.exists(events_file)

    @property
    def description_path(self):
        """Path of the description file (DESCRIPTION_FNAME) in the run dir."""
        run_dir = self.dir
        return os.path.join(run_dir, DESCRIPTION_FNAME)

    def close_files(self):
        """Close the events and history files, if open, and forget them.

        This avoids ResourceWarnings about unclosed ``wandb-history.jsonl``
        handles when the interpreter shuts down.
        """
        for attr in ('_events', '_history'):
            handle = getattr(self, attr)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finish the run via ``wandb.join``.

        The exit code is 0 for a clean exit and 1 if an exception escaped the
        block. Any exception is not suppressed (False is returned for it).
        """
        failed = exc_type is not None
        wandb.join(1 if failed else 0)
        return not failed
コード例 #10
0
ファイル: wandb_run.py プロジェクト: youtang1993/client
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 group=None,
                 job_type=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 args=None,
                 wandb_dir=None,
                 tags=None):
        """Set up the local state of a run.

        Args:
            run_id: run id (stored as the GQL "name"); generated with
                ``util.generate_id()`` when falsy.
            mode: ``'run'`` (default when falsy) or ``'dryrun'``; here it is
                only passed to ``run_dir_path`` as ``dry=``.
            dir: run directory; derived from the id via ``run_dir_path`` when
                None, otherwise made absolute. Created via ``_mkdir``.
            group: optional run group.
            job_type: optional job type.
            config: ``Config`` to use; a fresh ``Config()`` when None.
            sweep_id: id of the sweep this run belongs to, if any.
            storage_id: server-side (GQL) id of the run, if already known.
            description: combined text whose first line is the name and the
                rest the description; when None it is read from
                ``self.description_path`` if that file exists.
            resume: resume setting; ``'never'`` when falsy. With ``'auto'``
                the run id is written as JSON to RESUME_FNAME inside
                ``wandb.wandb_dir()``.
            program: program path; defaults to ``__main__.__file__`` or a
                placeholder when that is unavailable.
            args: program arguments; defaults to ``sys.argv[1:]``.
            wandb_dir: stored as ``self.wandb_dir``.
            tags: list of tags; ``[]`` when falsy.
        """
        # self.id is actually stored in the "name" attribute in GQL
        self.id = run_id if run_id else util.generate_id()
        self.display_name = self.id
        self.resume = resume if resume else 'never'
        self.mode = mode if mode else 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        self.resumed = False  # we set resume when history is first accessed

        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.args = args
        if self.args is None:
            self.args = sys.argv[1:]
        self.wandb_dir = wandb_dir

        # Tag the error-reporting scope with project/entity/url.
        # NOTE(review): configure_scope presumably comes from sentry_sdk; confirm.
        # get_url() reads self.id and (via project_name/auto_project_name)
        # self.program, so both must already be set here. It may also query
        # the API (viewer) when no entity is configured.
        with configure_scope() as scope:
            api = InternalApi()
            self.project = api.settings("project")
            self.entity = api.settings("entity")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            scope.set_tag("url", self.get_url(api))

        if dir is None:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        else:
            self._dir = os.path.abspath(dir)
        self._mkdir()

        # Remember this run id so a later "auto" resume can pick it up.
        if self.resume == "auto":
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as f:
                f.write(json.dumps({"run_id": self.id}))

        if config is None:
            self.config = Config()
        else:
            self.config = config

        # this is the GQL ID:
        self.storage_id = storage_id
        # socket server, currently only available in headless mode
        self.socket = None

        # description_path is built from the run directory, so this must run
        # after self._dir is assigned above.
        self.name_and_description = ""
        if description is not None:
            self.name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as d_file:
                self.name_and_description = d_file.read()

        self.tags = tags if tags else []

        self.sweep_id = sweep_id

        # Lazily created resources (see the corresponding properties/helpers).
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None
コード例 #11
0
ファイル: wandb_run.py プロジェクト: youtang1993/client
class Run(object):
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 group=None,
                 job_type=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 args=None,
                 wandb_dir=None,
                 tags=None):
        """Set up the local state of a run.

        Args:
            run_id: run id (stored as the GQL "name"); generated with
                ``util.generate_id()`` when falsy.
            mode: ``'run'`` (default when falsy) or ``'dryrun'``; here it is
                only passed to ``run_dir_path`` as ``dry=``.
            dir: run directory; derived from the id via ``run_dir_path`` when
                None, otherwise made absolute. Created via ``_mkdir``.
            group: optional run group.
            job_type: optional job type.
            config: ``Config`` to use; a fresh ``Config()`` when None.
            sweep_id: id of the sweep this run belongs to, if any.
            storage_id: server-side (GQL) id of the run, if already known.
            description: combined text whose first line is the name and the
                rest the description; when None it is read from
                ``self.description_path`` if that file exists.
            resume: resume setting; ``'never'`` when falsy. With ``'auto'``
                the run id is written as JSON to RESUME_FNAME inside
                ``wandb.wandb_dir()``.
            program: program path; defaults to ``__main__.__file__`` or a
                placeholder when that is unavailable.
            args: program arguments; defaults to ``sys.argv[1:]``.
            wandb_dir: stored as ``self.wandb_dir`` and exported by
                ``set_environment``.
            tags: list of tags; ``[]`` when falsy.
        """
        # self.id is actually stored in the "name" attribute in GQL
        self.id = run_id if run_id else util.generate_id()
        self.display_name = self.id
        self.resume = resume if resume else 'never'
        self.mode = mode if mode else 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        self.resumed = False  # we set resume when history is first accessed

        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.args = args
        if self.args is None:
            self.args = sys.argv[1:]
        self.wandb_dir = wandb_dir

        # Tag the error-reporting scope with project/entity/url.
        # NOTE(review): configure_scope presumably comes from sentry_sdk; confirm.
        # get_url() reads self.id and (via project_name -> auto_project_name)
        # self.program, so both must already be set here. When no entity is
        # configured and an API key exists, get_url() calls api.viewer(),
        # i.e. a network request during construction.
        with configure_scope() as scope:
            api = InternalApi()
            self.project = api.settings("project")
            self.entity = api.settings("entity")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            scope.set_tag("url", self.get_url(api))

        if dir is None:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        else:
            self._dir = os.path.abspath(dir)
        self._mkdir()

        # Remember this run id so a later "auto" resume can pick it up.
        if self.resume == "auto":
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as f:
                f.write(json.dumps({"run_id": self.id}))

        if config is None:
            self.config = Config()
        else:
            self.config = config

        # this is the GQL ID:
        self.storage_id = storage_id
        # socket server, currently only available in headless mode
        self.socket = None

        # description_path is built from the run directory, so this must run
        # after self._dir is assigned above.
        self.name_and_description = ""
        if description is not None:
            self.name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as d_file:
                self.name_and_description = d_file.read()

        self.tags = tags if tags else []

        self.sweep_id = sweep_id

        # Lazily created resources (see the corresponding properties/helpers).
        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None

    @property
    def path(self):
        # TODO: theres an edge case where self.entity is None
        return "/".join([str(self.entity), str(self.project), str(self.id)])

    def _init_jupyter_agent(self):
        """Create the ``JupyterAgent`` used when the run lives in a notebook."""
        from wandb.jupyter import JupyterAgent
        agent = JupyterAgent()
        self._jupyter_agent = agent

    def _stop_jupyter_agent(self):
        self._jupyter_agent.stop()

    def send_message(self, options):
        """ Sends a message to the wandb process changing the policy
        of saved files.  This is primarily used internally by wandb.save
        """
        if not options.get("save_policy") and not options.get("tensorboard"):
            raise ValueError(
                "Only configuring save_policy and tensorboard is supported")
        if self.socket:
            # In the user process
            self.socket.send(options)
        elif self._jupyter_agent:
            # Running in jupyter
            self._jupyter_agent.start()
            if options.get("save_policy"):
                self._jupyter_agent.rm.update_user_file_policy(
                    options["save_policy"])
            elif options.get("tensorboard"):
                self._jupyter_agent.rm.start_tensorboard_watcher(
                    options["tensorboard"]["logdir"],
                    options["tensorboard"]["save"])
        elif self._run_manager:
            # Running in the wandb process, used for tfevents saving
            if options.get("save_policy"):
                self._run_manager.update_user_file_policy(
                    options["save_policy"])
        else:
            wandb.termerror(
                "wandb.init hasn't been called, can't configure run")

    @classmethod
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        The run ID comes from WANDB_RUN_ID; if unset, ``__init__`` generates one.
        The run mode comes from WANDB_MODE. If it is unset and W&B is disabled
        in this directory, the mode becomes "dryrun"; otherwise an unset mode
        makes ``__init__`` default to "run".
        The run directory comes from WANDB_RUN_DIR; if unset, ``__init__``
        derives it from the run ID. Either way ``__init__`` creates it.
        Group, job type, sweep id, storage id, program, description and resume
        also come from their environment variables, as do args, wandb dir and
        tags (via the ``env`` helpers). The config comes from
        ``Config.from_environment_or_defaults()``.

        Args:
            environment: mapping to read from; defaults to ``os.environ``.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get(env.RUN_ID)
        resume = environment.get(env.RESUME)
        storage_id = environment.get(env.RUN_STORAGE_ID)
        mode = environment.get(env.MODE)
        disabled = InternalApi().disabled()
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termwarn(
                "WANDB_MODE is set to run, but W&B was disabled.  Run `wandb on` to remove this message"
            )
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory.  Run `wandb on` to enable cloud syncing.'
            )

        group = environment.get(env.RUN_GROUP)
        job_type = environment.get(env.JOB_TYPE)
        run_dir = environment.get(env.RUN_DIR)
        sweep_id = environment.get(env.SWEEP_ID)
        program = environment.get(env.PROGRAM)
        description = environment.get(env.DESCRIPTION)
        # NOTE(review): these helpers are not given `environment`, so they may
        # read os.environ even when a custom mapping is passed in. Confirm.
        args = env.get_args()
        wandb_dir = env.get_dir()
        tags = env.get_tags()
        config = Config.from_environment_or_defaults()
        run = cls(run_id,
                  mode,
                  run_dir,
                  group,
                  job_type,
                  config,
                  sweep_id,
                  storage_id,
                  program=program,
                  description=description,
                  args=args,
                  wandb_dir=wandb_dir,
                  tags=tags,
                  resume=resume)
        return run

    @classmethod
    def from_directory(cls,
                       directory,
                       project=None,
                       entity=None,
                       run_id=None,
                       api=None,
                       ignore_globs=None):
        """Sync an existing local run directory to W&B as a new run.

        The method:

        - upserts a run;
        - streams the history file through the file-stream API, or, when
          there is none, metrics converted from any ``.tfevents.`` files;
        - streams the events file the same way;
        - fills config, summary and metadata from the files found;
        - uploads the files in ``directory``.

        Args:
            directory: local run directory to sync.
            project: target project; falls back to the ``project`` setting,
                then to ``auto_project_name``.
            entity: target entity; stored as the ``entity`` setting when given.
            run_id: id for the new run; generated with ``util.generate_id()``
                when omitted.
            api: ``InternalApi`` to use; a new one is created when omitted.
            ignore_globs: ``fnmatch`` patterns, matched against paths relative
                to ``directory``, for files that should not be uploaded.

        Returns:
            The ``Run`` created for ``directory``.

        Raises:
            ValueError: if no project could be determined.
        """
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)
        project = project or api.settings("project") or run.auto_project_name(
            api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id, project=project, entity=entity)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        wandb.termlog(res["displayName"] + " " + run.get_url(api))

        file_api = api.get_file_stream_api()
        file_api.start()
        snap = DirectorySnapshot(directory)
        paths = [
            os.path.relpath(abs_path, directory) for abs_path in snap.paths
            if os.path.isfile(abs_path)
        ]
        if ignore_globs:
            paths = set(paths)
            for g in ignore_globs:
                paths = paths - set(fnmatch.filter(paths, g))
            paths = list(paths)
        run_update = {"id": res["id"]}

        # Locate the well-known run files (first path whose name contains each).
        tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
        history_path = next((p for p in snap.paths if HISTORY_FNAME in p),
                            None)
        events_path = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config_path = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config_path = next(
            (p for p in snap.paths if USER_CONFIG_FNAME in p), None)
        summary_path = next((p for p in snap.paths if SUMMARY_FNAME in p),
                            None)
        meta_path = next((p for p in snap.paths if METADATA_FNAME in p), None)
        # Summary built from tfevents; when set it takes precedence over the
        # summary file.
        tfevents_summary = None

        if history_path:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history_path)
            snap.paths.remove(history_path)
        elif len(tfevents) > 0:
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting...")
            tfevents_summary = {}
            for path in tfevents:
                filename = os.path.basename(path)
                namespace = path.replace(filename,
                                         "").replace(directory,
                                                     "").strip(os.sep)
                tfevents_summary.update(
                    wbtf.stream_tfevents(path, file_api, namespace=namespace))
            for path in glob.glob(os.path.join(tempfile.gettempdir(),
                                               "media/**/*"),
                                  recursive=True):
                if os.path.isfile(path):
                    paths.append(path)
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if events_path:
            file_api.stream_file(events_path)
            snap.paths.remove(events_path)
        if config_path:
            with open(config_path) as config_file:
                run_update["config"] = util.load_yaml(config_file)
        elif user_config_path:
            # TODO: half backed support for config.json
            # user_config_path is a file *path*; the old code passed it to
            # six.iteritems as if it were already a dict (AttributeError).
            # Load the JSON mapping first, then wrap each value.
            with open(user_config_path) as user_config_file:
                user_config = json.load(user_config_file)
            run_update["config"] = {
                k: {
                    "value": v
                }
                for k, v in six.iteritems(user_config)
            }
        if tfevents_summary is not None:
            #TODO: summary should already have data_types converted here...
            run_update["summary_metrics"] = util.json_dumps_safer(
                data_types.to_json(tfevents_summary, mode="summary"))
        elif summary_path:
            with open(summary_path) as summary_file:
                run_update["summary_metrics"] = summary_file.read()
        if meta_path:
            with open(meta_path) as meta_file:
                meta = json.load(meta_file)
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            # Don't abort the whole sync with a KeyError when the metadata
            # file has no "host" entry.
            if meta.get("host"):
                run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
        else:
            run_update["host"] = socket.gethostname()

        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        tmpdir = tempfile.gettempdir()
        for save_name in paths:
            if tmpdir in save_name:
                # Media converted from tfevents lives in the temp dir; upload
                # it under its path relative to that dir.
                path = save_name
                save_name = save_name.replace(tmpdir + os.sep, "")
            else:
                path = os.path.abspath(os.path.join(directory, save_name))
            pusher.update_file(save_name, path)
            pusher.file_changed(save_name, path)
        pusher.finish()
        pusher.print_status()
        file_api.finish(0)
        if os.path.exists(os.path.join(tmpdir, "media")):
            shutil.rmtree(os.path.join(tmpdir, "media"))
        wandb.termlog("Finished!")
        return run

    def auto_project_name(self, api):
        # if we're in git, set project name to git repo name + relative path within repo
        root_dir = api.git.root_dir
        if root_dir is None:
            return None
        repo_name = os.path.basename(root_dir)
        program = self.program
        if program is None:
            return repo_name
        if not os.path.isabs(program):
            program = os.path.join(os.curdir, program)
        prog_dir = os.path.dirname(os.path.abspath(program))
        if not prog_dir.startswith(root_dir):
            return repo_name
        project = repo_name
        sub_path = os.path.relpath(prog_dir, root_dir)
        if sub_path != '.':
            project += '-' + sub_path
        return project.replace(os.sep, '_')

    def save(self,
             id=None,
             program=None,
             summary_metrics=None,
             num_retries=None,
             api=None):
        api = api or InternalApi()
        project = api.settings('project')
        if project is None:
            project = self.auto_project_name(api)
        upsert_result = api.upsert_run(
            id=id or self.storage_id,
            name=self.id,
            commit=api.git.last_commit,
            project=project,
            entity=api.settings("entity"),
            group=self.group,
            tags=self.tags if len(self.tags) > 0 else None,
            config=self.config.as_dict(),
            description=self.name_and_description,
            host=socket.gethostname(),
            program_path=program or self.program,
            repo=api.git.remote_url,
            sweep_name=self.sweep_id,
            summary_metrics=summary_metrics,
            job_type=self.job_type,
            num_retries=num_retries)
        self.storage_id = upsert_result['id']
        self.display_name = upsert_result.get('displayName') or self.id
        return upsert_result

    def set_environment(self, environment=None):
        """Export this run's settings as environment variables.

        Lets a user script reconstruct the run (e.g. inside ``wandb.init()``).

        Args:
            environment: mapping to write into; defaults to ``os.environ``.

        Returns:
            The mapping that was written to, matching the other
            ``set_environment`` implementation in this file.
        """
        if environment is None:
            environment = os.environ
        environment[env.RUN_ID] = self.id
        environment[env.RESUME] = self.resume
        if self.storage_id:
            environment[env.RUN_STORAGE_ID] = self.storage_id
        environment[env.MODE] = self.mode
        environment[env.RUN_DIR] = self.dir

        if self.group:
            environment[env.RUN_GROUP] = self.group
        if self.job_type:
            environment[env.JOB_TYPE] = self.job_type
        if self.wandb_dir:
            environment[env.DIR] = self.wandb_dir
        if self.sweep_id is not None:
            environment[env.SWEEP_ID] = self.sweep_id
        if self.program is not None:
            environment[env.PROGRAM] = self.program
        if self.args is not None:
            environment[env.ARGS] = json.dumps(self.args)
        if self.name_and_description is not None:
            environment[env.DESCRIPTION] = self.name_and_description
        if len(self.tags) > 0:
            environment[env.TAGS] = ",".join(self.tags)
        return environment

    def _mkdir(self):
        """Create the run directory via ``util.mkdir_exists_ok``
        (presumably tolerant of an existing directory)."""
        run_dir = self._dir
        util.mkdir_exists_ok(run_dir)

    def project_name(self, api=None):
        if api is None:
            api = InternalApi()
        return api.settings('project') or self.auto_project_name(
            api) or "uncategorized"

    def get_url(self, api=None):
        api = api or InternalApi()
        if api.api_key:
            if api.settings('entity') is None:
                viewer = api.viewer()
                if viewer.get('entity'):
                    api.set_setting('entity', viewer['entity'])
            if api.settings('entity'):
                return "{base}/{entity}/{project}/runs/{run}".format(
                    base=api.app_url,
                    entity=urllib.parse.quote_plus(api.settings('entity')),
                    project=urllib.parse.quote_plus(self.project_name(api)),
                    run=urllib.parse.quote_plus(self.id))
            else:
                # TODO: I think this could only happen if the api key is invalid
                return "run pending creation, url not known"
        else:
            return "not logged in, run wandb login or set WANDB_API_KEY"

    def upload_debug(self):
        """Upload the debug log as ``wandb-debug.log`` for this run.

        Does nothing when the log file does not exist.
        """
        log_path = self.log_fname
        if not os.path.exists(log_path):
            return
        api = InternalApi()
        api.set_current_run_id(self.id)
        pusher = FilePusher(api)
        remote_name = "wandb-debug.log"
        pusher.update_file(remote_name, log_path)
        pusher.file_changed(remote_name, log_path)
        pusher.finish()

    def __repr__(self):
        """Show the run as "W&B Run: <url or status message from get_url()>"."""
        location = self.get_url()
        return "W&B Run: {}".format(location)

    @property
    def name(self):
        """Display name of the run: the first line of ``name_and_description``.

        We assume the first line of the description is the name users want to
        use; it is set to the run id when the user did not specify one.
        """
        first_line, _, _ = self.name_and_description.partition("\n")
        return first_line

    @name.setter
    def name(self, name):
        """Replace the first line of ``name_and_description``, keeping the rest."""
        _, separator, remainder = self.name_and_description.partition("\n")
        self.name_and_description = name + separator + remainder

    @property
    def description(self):
        """Everything after the first line of ``name_and_description`` ("" if none)."""
        _, _, body = self.name_and_description.partition("\n")
        return body

    @description.setter
    def description(self, desc):
        """Replace everything after the first line with ``desc`` and persist it.

        The combined "<name>\\n<desc>" text is written to ``self.description_path``,
        overwriting any previous contents.
        """
        title, _, _ = self.name_and_description.partition("\n")
        self.name_and_description = title + "\n" + desc
        with open(self.description_path, 'w') as out:
            out.write(self.name_and_description)

    @property
    def host(self):
        """Hostname of the machine this process runs on (socket.gethostname())."""
        hostname = socket.gethostname()
        return hostname

    @property
    def dir(self):
        """Read-only accessor for the run directory stored in ``self._dir``."""
        run_directory = self._dir
        return run_directory

    @property
    def log_fname(self):
        """Path of the debug log; shared by all runs (util.get_log_file_path())."""
        # TODO: we started work to log to a file in the run dir, but it had issues.
        # For now all logs go to the same place.
        log_path = util.get_log_file_path()
        return log_path

    def enable_logging(self):
        """Attach a FileHandler for the global debug log to the root logger.

        Every record passing through the handler is tagged with this run's id
        (``record.run_id``), which the format string prints, in case of
        multiple processes on the same machine. The handler level is INFO.

        Currently no way to disable logging after it's enabled; each call adds
        another handler.
        """
        handler = logging.FileHandler(self.log_fname)
        handler.setLevel(logging.INFO)
        tagged_id = self.id

        class WBFilter(logging.Filter):
            """Copy the run id onto each record; never drops a record."""

            def filter(self, record):
                record.run_id = tagged_id
                return True

        log_format = '%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d [%(run_id)s:%(filename)s:%(funcName)s():%(lineno)s] %(message)s'
        handler.setFormatter(logging.Formatter(log_format))
        handler.addFilter(WBFilter())
        logging.getLogger().addHandler(handler)

    @property
    def summary(self):
        """The run's FileSummary, created on first access and then cached."""
        if self._summary is not None:
            return self._summary
        self._summary = summary.FileSummary(self)
        return self._summary

    @property
    def has_summary(self):
        """Truthy if a summary exists: the loaded summary object if it is truthy,
        otherwise whether the summary file is present in the run directory."""
        if self._summary:
            return self._summary
        summary_file = os.path.join(self._dir, summary.SUMMARY_FNAME)
        return os.path.exists(summary_file)

    def _history_added(self, row):
        """History callback: merge ``row`` into the summary (created if needed).

        Passes ``overwrite=False`` to FileSummary.update; presumably this keeps
        values already present in the summary — confirm against FileSummary.
        """
        target = self._summary
        if target is None:
            target = self._summary = summary.FileSummary(self)
        target.update(row, overwrite=False)

    @property
    def history(self):
        """The run's History, opened lazily in the run directory.

        New rows are forwarded to ``_history_added``; the jupyter agent's
        ``start`` is passed as the jupyter callback when an agent is present.
        If the opened history already has steps, the run is flagged as resumed.
        """
        if self._history is not None:
            return self._history
        agent = self._jupyter_agent
        callback = agent.start if agent else None
        self._history = history.History(HISTORY_FNAME,
                                        self._dir,
                                        add_callback=self._history_added,
                                        jupyter_callback=callback)
        if self._history._steps > 0:
            self.resumed = True
        return self._history

    @property
    def step(self):
        """Number of steps recorded so far in the run's history."""
        run_history = self.history
        return run_history._steps

    @property
    def has_history(self):
        """Truthy if history exists: the loaded History if truthy, otherwise
        whether the history file is present in the run directory."""
        if self._history:
            return self._history
        history_file = os.path.join(self._dir, HISTORY_FNAME)
        return os.path.exists(history_file)

    @property
    def events(self):
        """The run's JsonlEventsFile, opened on first access and then cached."""
        if self._events is not None:
            return self._events
        self._events = jsonlfile.JsonlEventsFile(EVENTS_FNAME, self._dir)
        return self._events

    @property
    def has_events(self):
        """Truthy if events exist: the open events file if truthy, otherwise
        whether the events file is present in the run directory."""
        if self._events:
            return self._events
        events_file = os.path.join(self._dir, EVENTS_FNAME)
        return os.path.exists(events_file)

    @property
    def description_path(self):
        """Location of the description file inside the run directory."""
        run_directory = self.dir
        return os.path.join(run_directory, DESCRIPTION_FNAME)

    def close_files(self):
        """Close open files to avoid Python warnings on termination:

        Exception ignored in: <_io.FileIO name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='wb' closefd=True>
        ResourceWarning: unclosed file <_io.TextIOWrapper name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='w' encoding='UTF-8'>

        The events file is closed before the history file; each closed handle
        is reset to None so the lazy properties would reopen it.
        """
        for attr in ("_events", "_history"):
            handle = getattr(self, attr)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)
コード例 #12
0
ファイル: cli.py プロジェクト: mikewlange/client-1
def restore(ctx, run, no_git, branch, project, entity):
    """Restore a previous run's code, config and (optionally) docker environment.

    ``run`` may embed its entity and project: "[entity/]project:run" is split
    here, and a slug containing more than one "/" has its leading entity split
    off; whatever remains is resolved by ``api.parse_slug``.

    Unless ``no_git`` is set, the run's commit is checked out (on a
    ``wandb/<run>`` branch when ``branch`` is true, detached otherwise). If that
    commit is not found locally, the first ``upstream_diff_<commit>.patch`` file
    returned by ``api.download_urls`` whose commit resolves locally is used
    instead. The saved diff is then applied with ``git apply --reject``.
    The run config is written under ./wandb. If the run recorded a docker
    image, the ``docker`` command is invoked with it.

    Returns:
        (commit, json_config, patch_content, repo, metadata)

    Raises:
        ClickException: when the code cannot be restored in this checkout.
    """
    if ":" in run:
        if "/" in run:
            entity, rest = run.split("/", 1)
        else:
            rest = run
        project, run = rest.split(":", 1)
    elif run.count("/") > 1:
        entity, run = run.split("/", 1)

    project, run = api.parse_slug(run, project=project)
    commit, json_config, patch_content, metadata = api.run_config(
        project, run=run, entity=entity)
    # "git" may be absent or explicitly null in the run metadata.
    repo = (metadata.get("git") or {}).get("repo")
    image = metadata.get("docker")
    RESTORE_MESSAGE = """`wandb restore` needs to be run from the same git repository as the original run.
Run `git clone %s` and restore from there or pass the --no-git flag.""" % repo
    if no_git:
        commit = None
    elif not api.git.enabled:
        if repo:
            raise ClickException(RESTORE_MESSAGE)
        elif image:
            wandb.termlog(
                "Original run has no git history.  Just restoring config and docker"
            )

    if commit and api.git.enabled:
        subprocess.check_call(['git', 'fetch', '--all'])
        try:
            api.git.repo.commit(commit)
        except ValueError:
            wandb.termlog("Couldn't find original commit: {}".format(commit))
            commit = None
            patch_url = None
            files = api.download_urls(project, run=run, entity=entity)
            # Fall back to the first uploaded upstream diff whose base commit
            # resolves in the local repository.
            for filename in files:
                if not (filename.startswith('upstream_diff_')
                        and filename.endswith('.patch')):
                    continue
                candidate = filename[len('upstream_diff_'):-len('.patch')]
                try:
                    api.git.repo.commit(candidate)
                except ValueError:
                    continue
                commit = candidate
                patch_url = files[filename]
                break

            if commit:
                wandb.termlog(
                    "Falling back to upstream commit: {}".format(commit))
                patch_path, _ = api.download_write_file(patch_url)
            else:
                raise ClickException(RESTORE_MESSAGE)
        else:
            if patch_content:
                patch_path = os.path.join(wandb.wandb_dir(), 'diff.patch')
                with open(patch_path, "w") as f:
                    f.write(patch_content)
            else:
                patch_path = None

        branch_name = "wandb/%s" % run
        if branch and branch_name not in api.git.repo.branches:
            api.git.repo.git.checkout(commit, b=branch_name)
            wandb.termlog("Created branch %s" %
                          click.style(branch_name, bold=True))
        elif branch:
            wandb.termlog(
                "Using existing branch, run `git branch -D %s` from master for a clean checkout"
                % branch_name)
            api.git.repo.git.checkout(branch_name)
        else:
            wandb.termlog("Checking out %s in detached mode" % commit)
            api.git.repo.git.checkout(commit)

        if patch_path:
            # we apply the patch from the repository root so git doesn't exclude
            # things outside the current directory
            root = api.git.root
            patch_rel_path = os.path.relpath(patch_path, start=root)
            # --reject is necessary or else this fails any time a binary file
            # occurs in the diff
            # we use .call() instead of .check_call() for the same reason
            # TODO(adrian): this means there is no error checking here
            subprocess.call(['git', 'apply', '--reject', patch_rel_path],
                            cwd=root)
            wandb.termlog("Applied patch")

    # TODO: we should likely respect WANDB_DIR here.
    util.mkdir_exists_ok("wandb")
    config = Config(run_dir="wandb")
    config.load_json(json_config)
    config.persist()
    wandb.termlog("Restored config variables to %s" % config._config_path())
    if image:
        program = metadata.get("program")
        args = metadata.get("args")
        # Rebuild the command only when a real script path and its arguments
        # were recorded; placeholders such as "<python with no main file>"
        # start with "<". A missing "program" no longer raises KeyError.
        if program and not program.startswith("<") and args is not None:
            # TODO: we may not want to default to python here.
            runner = util.find_runner(program) or ["python"]
            command = runner + [program] + args
            cmd = " ".join(command)
        else:
            wandb.termlog(
                "Couldn't find original command, just restoring environment")
            cmd = None
        wandb.termlog("Docker image found, attempting to start")
        ctx.invoke(docker, docker_run_args=[image], cmd=cmd)

    return commit, json_config, patch_content, repo, metadata
コード例 #13
0
ファイル: wandb_run.py プロジェクト: rguerrettaz/client
class Run(object):
    """A W&B run: its identity, mode, local run directory, and the history,
    events, summary and examples files stored lazily in that directory.
    """

    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 wandb_dir=None):
        """Initialise the run and create its directory on disk.

        Args:
            run_id: run name; generated with generate_id() when falsy.
            mode: 'run' or 'dryrun'; 'run' when falsy.
            dir: run directory; derived from the id via run_dir_path() when
                None, otherwise made absolute.
            config: Config instance; a new Config() when None.
            sweep_id: id of the sweep this run belongs to, if any.
            storage_id: server-side (GQL) id of the run, if already known.
            description: written to the description file when not None.
            resume: resume policy; 'never' when falsy.
            program: path of the user's script; taken from __main__.__file__
                when falsy.
            wandb_dir: exported as WANDB_DIR by set_environment().
        """
        # self.id is actually stored in the "name" attribute in GQL
        self.id = run_id if run_id else generate_id()
        self.resume = resume if resume else 'never'
        self.mode = mode if mode else 'run'

        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.wandb_dir = wandb_dir

        if dir is None:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        else:
            self._dir = os.path.abspath(dir)
        self._mkdir()

        if config is None:
            self.config = Config()
        else:
            self.config = config

        # this is the GQL ID:
        self.storage_id = storage_id
        # socket server, currently only available in headless mode
        self.socket = None

        if description is not None:
            self.description = description
        # An empty description.md may have been created by RunManager() so it's
        # important that we overwrite empty strings here.
        if not self.description:
            self.description = self.id

        self.sweep_id = sweep_id

        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._user_accessed_summary = False
        self._examples = None

    @classmethod
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        ``environment`` defaults to ``os.environ``.

        The run ID comes from WANDB_RUN_ID or is generated.
        The run mode comes from WANDB_MODE. When WANDB_MODE is unset it is
        "dryrun" if W&B is disabled (``Api().disabled()``) and otherwise "run"
        (the constructor's default). A message is logged when W&B is disabled.
        The run directory comes from WANDB_RUN_DIR or is generated from the run ID.
        Resume policy, storage id, sweep id, program and WANDB_DIR are read
        from their WANDB_* variables; the config comes from
        ``Config.from_environment_or_defaults()``.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get('WANDB_RUN_ID')
        resume = environment.get('WANDB_RESUME')
        storage_id = environment.get('WANDB_RUN_STORAGE_ID')
        mode = environment.get('WANDB_MODE')
        disabled = Api().disabled()
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termlog(
                "WARNING: WANDB_MODE is set to run, but W&B was disabled.  Run `wandb on` to remove this message"
            )
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory.  Run `wandb on` to enable cloud syncing.'
            )

        run_dir = environment.get('WANDB_RUN_DIR')
        sweep_id = environment.get('WANDB_SWEEP_ID')
        program = environment.get('WANDB_PROGRAM')
        wandb_dir = environment.get('WANDB_DIR')
        config = Config.from_environment_or_defaults()
        run = cls(run_id,
                  mode,
                  run_dir,
                  config,
                  sweep_id,
                  storage_id,
                  program=program,
                  wandb_dir=wandb_dir,
                  resume=resume)
        return run

    def save(self,
             id=None,
             program=None,
             summary_metrics=None,
             num_retries=None,
             api=None,
             job_type="train"):
        """Create or update this run on the server via ``api.upsert_run``.

        The returned id is stored as ``self.storage_id``.

        Returns:
            The result of ``api.upsert_run``.

        Raises:
            ValueError: if no project is configured.
        """
        api = api or Api()
        if api.settings("project") is None:
            raise ValueError("Project must be configured.")
        upsert_result = api.upsert_run(id=id or self.storage_id,
                                       name=self.id,
                                       commit=api.git.last_commit,
                                       project=api.settings("project"),
                                       entity=api.settings("entity"),
                                       config=self.config.as_dict(),
                                       description=self.description,
                                       host=socket.gethostname(),
                                       program_path=program or self.program,
                                       repo=api.git.remote_url,
                                       sweep_name=self.sweep_id,
                                       summary_metrics=summary_metrics,
                                       job_type=job_type,
                                       num_retries=num_retries)
        self.storage_id = upsert_result['id']
        return upsert_result

    def set_environment(self, environment=None):
        """Set environment variables needed to reconstruct this object inside
        a user scripts (eg. in `wandb.init()`).

        ``environment`` defaults to ``os.environ``; optional values are only
        exported when set.
        """
        if environment is None:
            environment = os.environ
        environment['WANDB_RUN_ID'] = self.id
        environment['WANDB_RESUME'] = self.resume
        if self.storage_id:
            environment['WANDB_RUN_STORAGE_ID'] = self.storage_id
        environment['WANDB_MODE'] = self.mode
        environment['WANDB_RUN_DIR'] = self.dir

        if self.wandb_dir:
            environment['WANDB_DIR'] = self.wandb_dir
        if self.sweep_id is not None:
            environment['WANDB_SWEEP_ID'] = self.sweep_id
        if self.program is not None:
            environment['WANDB_PROGRAM'] = self.program

    def _mkdir(self):
        """Create the run directory via util.mkdir_exists_ok."""
        util.mkdir_exists_ok(self._dir)

    def get_url(self, api=None):
        """Web URL of this run built from the API's app URL, entity and project."""
        api = api or Api()
        return "{base}/{entity}/{project}/runs/{run}".format(
            base=api.app_url,
            entity=api.settings('entity'),
            project=api.settings('project'),
            run=self.id)

    def __repr__(self):
        return "W&B Run %s" % self.get_url()

    def _repr_html_(self):
        """Jupyter rich display: an iframe of the run page once it is stored
        on the server, otherwise a short notice."""
        if self.storage_id:
            url = self.get_url() + "/edit?jupyter=true"
            return '''<iframe src="%s" style="border:none;width:100%%;height:300px">
            </iframe>''' % url
        else:
            return '''Not logged in or configured, see https://docs.wandb.com'''

    @property
    def host(self):
        """Hostname of the machine this process runs on."""
        return socket.gethostname()

    @property
    def dir(self):
        """The run directory."""
        return self._dir

    @property
    def summary(self):
        """The run's Summary, created lazily in the run directory.

        Until the user first accesses this property, history rows are merged
        into the summary automatically (see _history_added). On that first
        access the summary's internal ``_summary`` dict is reset to empty, and
        afterwards history no longer updates it.
        """
        # If we added summary from history then manually reset before setting again
        if not self._user_accessed_summary and self._summary is not None:
            self._summary._summary = {}
        # We use this to track whether user has accessed summary
        self._user_accessed_summary = True
        if self._summary is None:
            self._summary = summary.Summary(self._dir)
        return self._summary

    @property
    def has_summary(self):
        """The summary object if loaded and truthy, else whether its file exists."""
        return self._summary or os.path.exists(
            os.path.join(self._dir, summary.SUMMARY_FNAME))

    def _history_added(self, row):
        """History callback: merge ``row`` into the summary until the user
        has accessed the summary themselves."""
        if self._summary is None:
            self._summary = summary.Summary(self._dir)
        if not self._user_accessed_summary:
            self._summary.update(row)

    @property
    def history(self):
        """The run's History, opened lazily; new rows feed _history_added."""
        if self._history is None:
            self._history = history.History(HISTORY_FNAME,
                                            self._dir,
                                            add_callback=self._history_added)
        return self._history

    @property
    def has_history(self):
        """The History if loaded and truthy, else whether its file exists."""
        return self._history or os.path.exists(
            os.path.join(self._dir, HISTORY_FNAME))

    @property
    def events(self):
        """The run's JSONL events file, opened lazily."""
        if self._events is None:
            self._events = jsonlfile.JsonlEventsFile(EVENTS_FNAME, self._dir)
        return self._events

    @property
    def has_events(self):
        """The events file if open and truthy, else whether it exists on disk."""
        return self._events or os.path.exists(
            os.path.join(self._dir, EVENTS_FNAME))

    @property
    def examples(self):
        """The run's examples table (a TypedTable over a JSONL file), opened lazily."""
        if self._examples is None:
            self._examples = typedtable.TypedTable(
                jsonlfile.JsonlFile(EXAMPLES_FNAME, self._dir))
        return self._examples

    @property
    def has_examples(self):
        """The examples table if open and truthy, else whether its file exists."""
        return self._examples or os.path.exists(
            os.path.join(self._dir, EXAMPLES_FNAME))

    @property
    def description_path(self):
        """Path of the description file inside the run directory."""
        return os.path.join(self.dir, DESCRIPTION_FNAME)

    @property
    def description(self):
        """Contents of the description file, or None if it cannot be read."""
        try:
            with open(self.description_path) as d_file:
                return d_file.read()
        except IOError:
            # TODO(adrian): should probably check specifically for a nonexistant file error
            return None

    @description.setter
    def description(self, description):
        """Overwrite the description file with ``description``."""
        # Property setters' return values are discarded, so nothing is returned.
        with open(self.description_path, 'w') as d_file:
            d_file.write(description)

    def close_files(self):
        """Close open files to avoid Python warnings on termination:

        Exception ignored in: <_io.FileIO name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='wb' closefd=True>
        ResourceWarning: unclosed file <_io.TextIOWrapper name='wandb/dryrun-20180130_144602-9vmqjhgy/wandb-history.jsonl' mode='w' encoding='UTF-8'>
        """
        if self._events is not None:
            self._events.close()
            self._events = None
        if self._history is not None:
            self._history.close()
            self._history = None
コード例 #14
0
def restore(run, branch, project, entity):
    """Restore a previous run's code and config into the current directory.

    ``run`` is resolved with ``api.parse_slug``. When the run recorded a
    commit, remotes are fetched and that commit is checked out (on a
    ``wandb/<run>`` branch when ``branch`` is true, detached otherwise). If the
    commit is not found locally, the first ``upstream_diff_<commit>.patch``
    returned by ``api.download_urls`` whose commit resolves locally is used
    instead. The saved diff is then applied with ``git apply --reject``.
    The run's config is always restored.

    Raises:
        ClickException: when neither the original commit nor an upstream
            diff commit can be found locally.
    """
    project, run = api.parse_slug(run, project=project)
    commit, json_config, patch_content = api.run_config(project,
                                                        run=run,
                                                        entity=entity)

    if commit:
        # Only touch git when there is code to restore, so a run without a
        # commit can still have its config restored outside a git checkout.
        subprocess.check_call(['git', 'fetch', '--all'])
        try:
            api.git.repo.commit(commit)
        except ValueError:
            click.echo("Couldn't find original commit: {}".format(commit))
            commit = None
            patch_url = None
            files = api.download_urls(project, run=run, entity=entity)
            # Fall back to the first uploaded upstream diff whose base commit
            # resolves in the local repository.
            for filename in files:
                if not (filename.startswith('upstream_diff_')
                        and filename.endswith('.patch')):
                    continue
                candidate = filename[len('upstream_diff_'):-len('.patch')]
                try:
                    api.git.repo.commit(candidate)
                except ValueError:
                    continue
                commit = candidate
                patch_url = files[filename]
                break

            if commit:
                click.echo(
                    "Falling back to upstream commit: {}".format(commit))
                patch_path, _ = api.download_write_file(patch_url)
            else:
                raise ClickException(
                    "Can't find commit from which to restore code")
        else:
            if patch_content:
                patch_path = os.path.join(wandb.wandb_dir(), 'diff.patch')
                with open(patch_path, "w") as f:
                    f.write(patch_content)
            else:
                patch_path = None

        branch_name = "wandb/%s" % run
        if branch and branch_name not in api.git.repo.branches:
            api.git.repo.git.checkout(commit, b=branch_name)
            click.echo("Created branch %s" %
                       click.style(branch_name, bold=True))
        elif branch:
            click.secho(
                "Using existing branch, run `git branch -D %s` from master for a clean checkout"
                % branch_name,
                fg="red")
            api.git.repo.git.checkout(branch_name)
        else:
            click.secho("Checking out %s in detached mode" % commit)
            api.git.repo.git.checkout(commit)

        if patch_path:
            # we apply the patch from the repository root so git doesn't exclude
            # things outside the current directory
            root = api.git.root
            patch_rel_path = os.path.relpath(patch_path, start=root)
            # --reject is necessary or else this fails any time a binary file
            # occurs in the diff
            # we use .call() instead of .check_call() for the same reason
            # TODO(adrian): this means there is no error checking here
            subprocess.call(['git', 'apply', '--reject', patch_rel_path],
                            cwd=root)
            click.echo("Applied patch")

    config = Config()
    config.load_json(json_config)
    config.persist()
    click.echo("Restored config variables")