Example #1
    def __init__(self, data_or_path, **kwargs):

        if hasattr(data_or_path, 'name'):
            # if the file has a path, we just detect the type and copy it from there
            data_or_path = data_or_path.name

        if hasattr(data_or_path, 'read'):
            if hasattr(data_or_path, 'seek'):
                data_or_path.seek(0)
            object3D = data_or_path.read()

            extension = kwargs.pop("file_type", None)
            if extension is None:
                raise ValueError(
                    "Must pass file type keyword argument when using io objects.")
            if extension not in Object3D.SUPPORTED_TYPES:
                raise ValueError("Object 3D only supports numpy arrays or files of the type: " +
                                 ", ".join(Object3D.SUPPORTED_TYPES))

            tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + '.' + extension)
            with open(tmp_path, "w") as f:
                f.write(object3D)

            super(Object3D, self).__init__(tmp_path, is_tmp=True)
        elif isinstance(data_or_path, six.string_types):
            path = data_or_path
            # os.path.splitext never raises, so validate the extension directly
            extension = os.path.splitext(data_or_path)[1][1:]
            if not extension:
                raise ValueError("File type must have an extension")
            if extension not in Object3D.SUPPORTED_TYPES:
                raise ValueError("Object 3D only supports numpy arrays or files of the type: " +
                                 ", ".join(Object3D.SUPPORTED_TYPES))

            super(Object3D, self).__init__(data_or_path, is_tmp=False)
        elif is_numpy_array(data_or_path):
            data = data_or_path

            if len(data.shape) != 2 or data.shape[1] not in {3, 4, 6}:
                raise ValueError("""The shape of the numpy array must be one of either
                                    [[x y z],       ...] nx3
                                     [x y z c],     ...] nx4 where c is a category with supported range [1, 14]
                                     [x y z r g b], ...] nx4 where is rgb is color""")

            data = data.tolist()
            tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + '.pts.json')
            json.dump(data, codecs.open(tmp_path, 'w', encoding='utf-8'),
                      separators=(',', ':'), sort_keys=True, indent=4)
            super(Object3D, self).__init__(tmp_path, is_tmp=True, extension='.pts.json')
        else:
            raise ValueError("data must be a numpy or a file object")
Example #2
    def bind_to_run(self, *args, **kwargs):
        data = self._to_table_json()
        tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + ".table.json")
        data = _numpy_arrays_to_lists(data)
        util.json_dump_safer(data, codecs.open(tmp_path, "w", encoding="utf-8"))
        self._set_file(tmp_path, is_tmp=True, extension=".table.json")
        super(Table, self).bind_to_run(*args, **kwargs)
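bind_to_run is called internally when a Table is logged; a hedged sketch of the public path that triggers it (project name is hypothetical):

import wandb

table = wandb.Table(columns=["run_id", "loss"], data=[["1a2b3c4d", 0.25]])
run = wandb.init(project="table-demo")  # hypothetical project name
run.log({"results": table})  # serializes the table to <generate_id()>.table.json
run.finish()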
Example #3
    def __init__(self, data_or_path):
        super(Bokeh, self).__init__()
        bokeh = util.get_module("bokeh", required=True)
        if isinstance(data_or_path, str) and os.path.exists(data_or_path):
            with open(data_or_path, "r") as file:
                b_json = json.load(file)
            self.b_obj = bokeh.document.Document.from_json(b_json)
            self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
        elif isinstance(data_or_path, bokeh.model.Model):
            _data = bokeh.document.Document()
            _data.add_root(data_or_path)
            # serialize/deserialize pairing followed by sorting attributes ensures
            # that the file's shas are equivalent in subsequent calls
            self.b_obj = bokeh.document.Document.from_json(_data.to_json())
            b_json = self.b_obj.to_json()
            if "references" in b_json["roots"]:
                b_json["roots"]["references"].sort(key=lambda x: x["id"])

            tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + ".bokeh.json")
            util.json_dump_safer(b_json, codecs.open(tmp_path, "w", encoding="utf-8"))
            self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
        elif not isinstance(data_or_path, bokeh.document.Document):
            raise TypeError(
                "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
            )
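A minimal sketch of the model branch above, assuming the class is importable as wandb.data_types.Bokeh (its exact location varies by wandb version) and using a hypothetical project name:

import wandb
from bokeh.plotting import figure

p = figure(title="demo")
p.line([1, 2, 3], [4, 6, 5])

run = wandb.init(project="bokeh-demo")  # hypothetical project name
run.log({"plot": wandb.data_types.Bokeh(p)})  # model branch: wrapped in a Document and serialized
run.finish()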
Example #4
    def __init__(self, data_or_path, caption=None, fps=4, format=None):
        self._fps = fps
        self._format = format or "gif"
        self._width = None
        self._height = None
        self._channels = None
        self._caption = caption
        if self._format not in Video.EXTS:
            raise ValueError("wandb.Video accepts %s formats" %
                             ", ".join(Video.EXTS))

        if isinstance(data_or_path, six.BytesIO):
            filename = os.path.join(MEDIA_TMP.name,
                                    util.generate_id() + '.' + self._format)
            with open(filename, "wb") as f:
                f.write(data_or_path.read())
            super(Video, self).__init__(filename, is_tmp=True)
        elif isinstance(data_or_path, six.string_types):
            _, ext = os.path.splitext(data_or_path)
            ext = ext[1:].lower()
            if ext not in Video.EXTS:
                raise ValueError("wandb.Video accepts %s formats" %
                                 ", ".join(Video.EXTS))
            super(Video, self).__init__(data_or_path, is_tmp=False)
            #ffprobe -v error -select_streams v:0 -show_entries stream=width,height -of csv=p=0 data_or_path
        else:
            if hasattr(data_or_path, "numpy"):  # TF data eager tensors
                self.data = data_or_path.numpy()
            elif is_numpy_array(data_or_path):
                self.data = data_or_path
            else:
                raise ValueError(
                    "wandb.Video accepts a file path or numpy like data as input"
                )
            self.encode()
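A minimal sketch of the numpy branch, assuming a (time, channels, height, width) tensor as consumed by the encode() path, with a hypothetical project name:

import numpy as np
import wandb

# 16 frames of 3-channel 64 x 64 video; values must fit in uint8.
frames = np.random.randint(0, 255, size=(16, 3, 64, 64), dtype=np.uint8)

run = wandb.init(project="video-demo")  # hypothetical project name
run.log({"clip": wandb.Video(frames, fps=4, format="gif")})
run.finish()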
Example #5
    def __init__(self, data_or_path, sample_rate=None, caption=None):
        """Accepts a path to an audio file or a numpy array of audio data. 
        """
        self._duration = None
        self._sample_rate = sample_rate
        self._caption = caption

        if isinstance(data_or_path, six.string_types):
            super(Audio, self).__init__(data_or_path, is_tmp=False)
        else:
            if sample_rate is None:
                raise ValueError(
                    'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
                )

            soundfile = util.get_module(
                "soundfile",
                required=
                'Raw audio requires the soundfile package. To get it, run "pip install soundfile"'
            )

            tmp_path = os.path.join(MEDIA_TMP.name,
                                    util.generate_id() + '.wav')
            soundfile.write(tmp_path, data_or_path, sample_rate)
            self._duration = len(data_or_path) / float(sample_rate)

            super(Audio, self).__init__(tmp_path, is_tmp=True)
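A minimal sketch of the raw-data branch, which requires sample_rate and the soundfile package (project name hypothetical):

import numpy as np
import wandb

sample_rate = 16000
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
tone = 0.5 * np.sin(2 * np.pi * 440 * t)  # one second of a 440 Hz sine

run = wandb.init(project="audio-demo")  # hypothetical project name
run.log({"tone": wandb.Audio(tone, sample_rate=sample_rate, caption="440 Hz")})
run.finish()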
Example #6
def process_args(args, config):
    """
    Processes parsed arguments to modify config appropriately.
    :return: modified config or None to exit without running
    """
    if "wandb" in args:
        # Add ray-wandb logger to loggers.
        config.setdefault("loggers", [])
        config["loggers"].extend(
            list(DEFAULT_LOGGERS) + [ray_wandb.WandbLogger])

        # One may specify `wandb_args` or `env_config["wandb"]`
        name = config.get("name", "unknown_name")
        wandb_args = config.get("wandb_args", {})
        wandb_args.setdefault("name", name)
        config.setdefault("env_config", {})
        config["env_config"].setdefault("wandb", wandb_args)

        # Either restore from a run-id or generate a new one.
        resume = wandb_args.get("resume", False)
        if ("wandb_resume" in args and args.wandb_resume) or resume:
            wandb_args.setdefault("resume", True)
            ray_wandb.enable_run_resume(wandb_args)
        else:
            wandb_id = util.generate_id()
            wandb_args["id"] = wandb_id

        # Enable logging on workers.
        insert_experiment_mixin(config,
                                ray_wandb.PrepPlotForWandb,
                                prepend_name=False)

    return config
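The else branch is the interesting use of generate_id(): when no resume is requested, every invocation gets a fresh run id. A reduced sketch of just that decision, with a hypothetical wandb_args dict standing in for the parsed config:

from wandb import util

wandb_args = {"name": "my_experiment"}  # hypothetical; no "resume" key set

if wandb_args.get("resume", False):
    wandb_args.setdefault("resume", True)  # reuse the stored run id
else:
    wandb_args["id"] = util.generate_id()  # fresh short random id

print(wandb_args)  # e.g. {'name': 'my_experiment', 'id': '1a2b3c4d'}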
Example #7
    def __init__(self, data_or_path, sample_rate=None, caption=None):
        """Accepts a path to an audio file or a numpy array of audio data."""
        super(Audio, self).__init__()
        self._duration = None
        self._sample_rate = sample_rate
        self._caption = caption

        if isinstance(data_or_path, six.string_types):
            if Audio.path_is_reference(data_or_path):
                self._path = data_or_path
                self._sha256 = hashlib.sha256(
                    data_or_path.encode("utf-8")).hexdigest()
                self._is_tmp = False
            else:
                self._set_file(data_or_path, is_tmp=False)
        else:
            if sample_rate is None:
                raise ValueError(
                    'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
                )

            soundfile = util.get_module(
                "soundfile",
                required=
                'Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
            )

            tmp_path = os.path.join(MEDIA_TMP.name,
                                    util.generate_id() + ".wav")
            soundfile.write(tmp_path, data_or_path, sample_rate)
            self._duration = len(data_or_path) / float(sample_rate)

            self._set_file(tmp_path, is_tmp=True)
Example #8
    def create(cls, api, run_id=None, project=None, username=None):
        """Create a run for the given project"""
        run_id = run_id or util.generate_id()
        project = project or api.settings.get("project")
        mutation = gql('''
        mutation upsertRun($project: String, $entity: String, $name: String!) {
            upsertBucket(input: {modelName: $project, entityName: $entity, name: $name}) {
                bucket {
                    project {
                        name
                        entity { name }
                    }
                    id
                    name
                }
                inserted
            }
        }
        ''')
        variables = {'entity': username, 'project': project, 'name': run_id}
        res = api.client.execute(mutation, variable_values=variables)
        res = res['upsertBucket']['bucket']
        return Run(
            api.client, res["project"]["entity"]["name"],
            res["project"]["name"], res["name"], {
                "id": res["id"],
                "config": "{}",
                "systemMetrics": "{}",
                "summaryMetrics": "{}",
                "tags": [],
                "description": None,
                "state": "running"
            })
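A hedged sketch of calling this classmethod; it assumes a configured internal API client (the import location varies across wandb versions) and hypothetical entity/project names:

from wandb.apis import InternalApi  # location varies across wandb versions

api = InternalApi()
run = Run.create(api, project="my-project", username="my-entity")  # hypothetical names
print(run.id)  # defaults to util.generate_id() when run_id is not supplied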
Example #9
    def encode(self):
        mpy = util.get_module(
            "moviepy.editor",
            required=
            'wandb.Video requires moviepy and imageio when passing raw data.  Install with "pip install moviepy imageio"'
        )
        tensor = self._prepare_video(self.data)
        _, self._height, self._width, self._channels = tensor.shape

        # encode sequence of images into gif string
        clip = mpy.ImageSequenceClip(list(tensor), fps=self._fps)

        filename = os.path.join(MEDIA_TMP.name,
                                util.generate_id() + '.' + self._format)
        try:  # older version of moviepy does not support progress_bar argument.
            if self._format == "gif":
                clip.write_gif(filename, verbose=False, progress_bar=False)
            else:
                clip.write_videofile(filename,
                                     verbose=False,
                                     progress_bar=False)
        except TypeError:
            if self._format == "gif":
                clip.write_gif(filename, verbose=False)
            else:
                clip.write_videofile(filename, verbose=False)
        super(Video, self).__init__(filename, is_tmp=True)
Example #10
    def __init__(self, data_or_path, mode=None, caption=None, grouping=None):
        """
        Accepts numpy array of image data, or a PIL image. The class attempts to infer
        the data format and converts it.

        If grouping is set to a number N, the interface combines N images.
        """

        self._grouping = grouping
        self._caption = caption
        self._width = None
        self._height = None
        self._image = None

        if isinstance(data_or_path, six.string_types):
            super(Image, self).__init__(data_or_path, is_tmp=False)
        else:
            data = data_or_path

            PILImage = util.get_module(
                "PIL.Image",
                required=
                'wandb.Image needs the PIL package. To get it, run "pip install pillow".'
            )
            if util.is_matplotlib_typename(util.get_full_typename(data)):
                buf = six.BytesIO()
                util.ensure_matplotlib_figure(data).savefig(buf)
                self._image = PILImage.open(buf)
            elif isinstance(data, PILImage.Image):
                self._image = data
            elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
                vis_util = util.get_module(
                    "torchvision.utils",
                    "torchvision is required to render images")
                if hasattr(data, "requires_grad") and data.requires_grad:
                    data = data.detach()
                data = vis_util.make_grid(data, normalize=True)
                self._image = PILImage.fromarray(
                    data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy())
            else:
                if hasattr(data, "numpy"):  # TF data eager tensors
                    data = data.numpy()
                if data.ndim > 2:
                    # get rid of trivial dimensions as a convenience
                    data = data.squeeze()
                self._image = PILImage.fromarray(
                    self.to_uint8(data), mode=mode or self.guess_mode(data))

            self._width, self._height = self._image.size

            tmp_path = os.path.join(MEDIA_TMP.name,
                                    util.generate_id() + '.png')
            self._image.save(tmp_path, transparency=None)
            super(Image, self).__init__(tmp_path, is_tmp=True)
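A minimal sketch of the numpy branch, letting the class guess the PIL mode from the array shape (project name hypothetical):

import numpy as np
import wandb

# Grayscale 28 x 28 array; guess_mode() infers mode "L" from the 2-D shape.
pixels = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8)

run = wandb.init(project="image-demo")  # hypothetical project name
run.log({"sample": wandb.Image(pixels, caption="random noise")})
run.finish()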
Example #11
def test_launch_full_build_new_image(live_mock_server, test_settings,
                                     mocked_fetchable_git_repo):
    api = wandb.sdk.internal.internal_api.Api(default_settings=test_settings,
                                              load_settings=False)
    random_id = util.generate_id()
    run = launch.run(
        "https://wandb.ai/mock_server_entity/test/runs/1",
        api,
        project=f"new-test-{random_id}",
    )
    assert str(run.get_status()) == "finished"
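Here generate_id() only guarantees a unique project name per test invocation; the same pattern in isolation:

from wandb import util

random_id = util.generate_id()
project_name = f"new-test-{random_id}"  # e.g. "new-test-1a2b3c4d", unique per run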
Example #12
    def __init__(self, data, inject=True):

        if isinstance(data, str):
            self.html = data
        elif hasattr(data, 'read'):
            if hasattr(data, 'seek'):
                data.seek(0)
            self.html = data.read()
        else:
            raise ValueError("data must be a string or an io object")
        if inject:
            self.inject_head()

        tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + '.html')
        with open(tmp_path, 'w') as out:
            print(self.html, file=out)

        super(Html, self).__init__(tmp_path, is_tmp=True)
Example #13
    def __init__(self, data, inject=True):
        """Accepts a string or file object containing valid html

        By default we inject a style reset into the doc to make it
        look resonable, passing inject=False will disable it.
        """
        if isinstance(data, str):
            self.html = data
        elif hasattr(data, 'read'):
            if hasattr(data, 'seek'):
                data.seek(0)
            self.html = data.read()
        else:
            raise ValueError("data must be a string or an io object")
        if inject:
            self.inject_head()

        tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + '.html')
        with open(tmp_path, 'w') as out:
            print(self.html, file=out)

        super(Html, self).__init__(tmp_path, is_tmp=True)
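A minimal sketch of the string input, using a hypothetical project name:

import wandb

fragment = "<h1>Report</h1><p>Inline HTML fragment.</p>"

run = wandb.init(project="html-demo")  # hypothetical project name
run.log({"report": wandb.Html(fragment)})  # string branch; a file object would hit the read() branch
run.finish()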
Example #14
def data_frame_to_json(df, run, key, step):
    """Encode a Pandas DataFrame into the JSON/backend format.

    Writes the data to a file and returns a dictionary that we use to represent
    it in `Summary` objects.

    Arguments:
        df (pandas.DataFrame): The DataFrame. Must not have columns named
            "wandb_run_id" or "wandb_data_frame_id". They will be added to the
            DataFrame here.
        run (wandb_run.Run): The Run the DataFrame is associated with. We need
            this because the information we store on the DataFrame is derived
            from the Run it's in.
        key (str): Name of the DataFrame, i.e. the summary key path in which it's
            stored. This is for convenience, so people exploring the
            directory tree can have some idea of what is in the Parquet files.
        step: History step or "summary".

    Returns:
        A dict representing the DataFrame that we can store in summaries or
        histories. This is the format:
        {
            '_type': 'data-frame',
                # Magic field that indicates that this object is a data frame as
                # opposed to a normal dictionary or anything else.
            'id': 'asdf',
                # ID for the data frame that is unique to this Run.
            'format': 'parquet',
                # The file format in which the data frame is stored. Currently can
                # only be Parquet.
            'project': 'wfeas',
                # (Current) name of the project that this Run is in. It'd be
                # better to store the project's ID because we know it'll never
                # change but we don't have that here. We store this just in
                # case because we use the project name in identifiers on the
                # back end.
            'path': 'media/data_frames/sdlk.parquet',
                # Path to the Parquet file in the Run directory.
        }
    """
    pandas = util.get_module("pandas")
    fastparquet = util.get_module("fastparquet")
    if not pandas or not fastparquet:
        raise wandb.Error(
            "Failed to save data frame: unable to import either pandas or fastparquet."
        )

    data_frame_id = util.generate_id()

    df = df.copy()  # we don't want to modify the user's DataFrame instance.

    for col_name, series in df.items():
        for i, val in enumerate(series):
            if isinstance(val, WBValue):
                series.iat[i] = six.text_type(
                    json.dumps(val_to_json(run, key, val, step)))

    # We have to call this wandb_run_id because that name is treated specially by
    # our filtering code
    df['wandb_run_id'] = pandas.Series([six.text_type(run.name)] *
                                       len(df.index),
                                       index=df.index)

    df['wandb_data_frame_id'] = pandas.Series([six.text_type(data_frame_id)] *
                                              len(df.index),
                                              index=df.index)
    frames_dir = os.path.join(run.dir, DATA_FRAMES_SUBDIR)
    util.mkdir_exists_ok(frames_dir)
    path = os.path.join(frames_dir, '{}-{}.parquet'.format(key, data_frame_id))
    fastparquet.write(path, df)

    return {
        'id': data_frame_id,
        '_type': 'data-frame',
        'format': 'parquet',
        'project': run.project_name(),  # we don't have the project ID here
        'entity': run.entity,
        'run': run.id,
        'path': path,
    }
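A hedged sketch of the call site, assuming an active legacy wandb_run.Run bound to `run` (this helper predates wandb.Table and is not part of the public API):

import pandas as pd

df = pd.DataFrame({"epoch": [1, 2], "loss": [0.9, 0.5]})
payload = data_frame_to_json(df, run, key="metrics", step="summary")
# payload["id"] is a fresh util.generate_id() value and payload["path"]
# points at the Parquet file written under the run directory.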
Example #15
File: wandb_run.py Project: gampx/client
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 group=None,
                 job_type=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 args=None,
                 wandb_dir=None,
                 tags=None,
                 name=None,
                 notes=None,
                 api=None):
        """Create a Run.

        Arguments:
            description (str): This is the old, deprecated style of description: the run's
                name followed by a newline, followed by multiline notes.
        """
        # self.storage_id is "id" in GQL.
        self.storage_id = storage_id
        # self.id is "name" in GQL.
        self.id = run_id if run_id else util.generate_id()
        # self._name is "display_name" in GQL.
        self._name = None
        self.notes = None

        self.resume = resume if resume else 'never'
        self.mode = mode if mode else 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        self.resumed = False  # we set resume when history is first accessed
        if api:
            if api.current_run_id and api.current_run_id != self.id:
                raise RuntimeError(
                    'Api object passed to run {} is already being used by run {}'
                    .format(self.id, api.current_run_id))
            else:
                api.set_current_run_id(self.id)
        self._api = api

        if dir is None:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        else:
            self._dir = os.path.abspath(dir)
        self._mkdir()

        # self.name and self.notes used to be combined into a single field.
        # Now if name and notes don't have their own values, we get them from
        # self._name_and_description, but we don't update description.md
        # if they're changed. This is to discourage relying on self.description
        # and self._name_and_description so that we can drop them later.
        #
        # This needs to be set before name and notes because name and notes may
        # influence it. They have higher precedence.
        self._name_and_description = None
        if description:
            wandb.termwarn(
                'Run.description is deprecated. Please use wandb.init(notes="long notes") instead.'
            )
            self._name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as d_file:
                self._name_and_description = d_file.read()

        if name is not None:
            self.name = name
        if notes is not None:
            self.notes = notes

        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.args = args
        if self.args is None:
            self.args = sys.argv[1:]
        self.wandb_dir = wandb_dir

        with configure_scope() as scope:
            self.project = self.api.settings("project")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            try:
                # TODO: Move this somewhere outside of init
                scope.set_tag("url", self.get_url(self.api, network=False))
            except CommError:
                pass

        if self.resume == "auto":
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as f:
                f.write(json.dumps({"run_id": self.id}))

        if config is None:
            self.config = Config()
        else:
            self.config = config

        # socket server, currently only available in headless mode
        self.socket = None

        self.tags = tags if tags else []

        self.sweep_id = sweep_id

        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None
Example #16
File: wandb_run.py Project: gampx/client
    def from_directory(cls,
                       directory,
                       project=None,
                       entity=None,
                       run_id=None,
                       api=None,
                       ignore_globs=None):
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)

        run_name = None
        project_from_meta = None
        snap = DirectorySnapshot(directory)
        meta = next((p for p in snap.paths if METADATA_FNAME in p), None)
        if meta:
            meta = json.load(open(meta))
            run_name = meta.get("name")
            project_from_meta = meta.get("project")

        project = project or project_from_meta or api.settings(
            "project") or run.auto_project_name(api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id,
                             project=project,
                             entity=entity,
                             display_name=run_name)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        try:
            wandb.termlog(res["displayName"] + " " + run.get_url(api))
        except CommError as e:
            wandb.termwarn(e.message)

        file_api = api.get_file_stream_api()
        file_api.start()
        paths = [
            os.path.relpath(abs_path, directory) for abs_path in snap.paths
            if os.path.isfile(abs_path)
        ]
        if ignore_globs:
            paths = set(paths)
            for g in ignore_globs:
                paths = paths - set(fnmatch.filter(paths, g))
            paths = list(paths)
        run_update = {"id": res["id"]}
        tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
        history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
        event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config = next((p for p in snap.paths if USER_CONFIG_FNAME in p),
                           None)
        summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
        if history:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history)
            snap.paths.remove(history)
        elif len(tfevents) > 0:
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting...")
            summary = {}
            for path in tfevents:
                filename = os.path.basename(path)
                namespace = path.replace(filename, "").replace(directory, "").strip(os.sep)
                summary.update(
                    wbtf.stream_tfevents(path,
                                         file_api,
                                         run,
                                         namespace=namespace))
            for path in glob.glob(os.path.join(directory, "media/**/*"),
                                  recursive=True):
                if os.path.isfile(path):
                    paths.append(path)
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if event:
            file_api.stream_file(event)
            snap.paths.remove(event)
        if config:
            run_update["config"] = util.load_yaml(open(config))
        elif user_config:
            # TODO: half-baked support for config.json
            run_update["config"] = {
                k: {
                    "value": v
                }
                for k, v in six.iteritems(user_config)
            }
        if isinstance(summary, dict):
            # TODO: summary should already have data_types converted here...
            run_update["summary_metrics"] = util.json_dumps_safer(summary)
        elif summary:
            run_update["summary_metrics"] = open(summary).read()
        if meta:
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            if meta.get("host"):
                run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
            run_update["notes"] = meta.get("notes")
        else:
            run_update["host"] = run.host

        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        for k in paths:
            path = os.path.abspath(os.path.join(directory, k))
            pusher.update_file(k, path)
            pusher.file_changed(k, path)
        pusher.finish()
        pusher.print_status()
        file_api.finish(0)
        # Remove temporary media images generated from tfevents
        if history is None and os.path.exists(os.path.join(directory, "media")):
            shutil.rmtree(os.path.join(directory, "media"))
        wandb.termlog("Finished!")
        return run
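A hedged sketch of invoking this sync path (legacy client API; the directory and project names are hypothetical):

synced = Run.from_directory(
    "wandb/run-20200101_000000-abc123",  # hypothetical offline run directory
    project="my-project",
    ignore_globs=["*.tmp"],
)
print(synced.id)  # falls back to util.generate_id() when run_id is not given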
Example #17
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--notes", help="Notes for the runs")
    return parser.parse_args()


################ Submit code ################

if __name__ == "__main__":

    args = parse_args()

    NB_JOBS = 3

    # Create job ids
    job_ids = [generate_id() for _ in range(NB_JOBS)]

    # Create random seeds
    random.seed(42)
    job_seeds = [str(random.randint(0, 1000)) for _ in range(NB_JOBS)]

    # Dump the job ids in jobs.yaml for future parsing
    with open('jobs.yaml', 'w') as buf:
        buf.write(yaml.safe_dump({'jobs_ids': job_ids}, default_flow_style=False))

    # Create the process list and execute the command
    process_list = []

    # Progress bar
    t = tqdm(range(NB_JOBS), desc='Sending job', leave=True)
    for index_run in t:
Example #18
    def __init__(self,
                 run_id=None,
                 mode=None,
                 dir=None,
                 group=None,
                 job_type=None,
                 config=None,
                 sweep_id=None,
                 storage_id=None,
                 description=None,
                 resume=None,
                 program=None,
                 args=None,
                 wandb_dir=None,
                 tags=None):
        # self.id is actually stored in the "name" attribute in GQL
        self.id = run_id if run_id else util.generate_id()
        self.display_name = self.id
        self.resume = resume if resume else 'never'
        self.mode = mode if mode else 'run'
        self.group = group
        self.job_type = job_type
        self.pid = os.getpid()
        self.resumed = False  # we set resume when history is first accessed

        self.program = program
        if not self.program:
            try:
                import __main__
                self.program = __main__.__file__
            except (ImportError, AttributeError):
                # probably `python -c`, an embedded interpreter or something
                self.program = '<python with no main file>'
        self.args = args
        if self.args is None:
            self.args = sys.argv[1:]
        self.wandb_dir = wandb_dir

        with configure_scope() as scope:
            api = InternalApi()
            self.project = api.settings("project")
            self.entity = api.settings("entity")
            scope.set_tag("project", self.project)
            scope.set_tag("entity", self.entity)
            scope.set_tag("url", self.get_url(api))

        if dir is None:
            self._dir = run_dir_path(self.id, dry=self.mode == 'dryrun')
        else:
            self._dir = os.path.abspath(dir)
        self._mkdir()

        if self.resume == "auto":
            util.mkdir_exists_ok(wandb.wandb_dir())
            resume_path = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
            with open(resume_path, "w") as f:
                f.write(json.dumps({"run_id": self.id}))

        if config is None:
            self.config = Config()
        else:
            self.config = config

        # this is the GQL ID:
        self.storage_id = storage_id
        # socket server, currently only available in headless mode
        self.socket = None

        self.name_and_description = ""
        if description is not None:
            self.name_and_description = description
        elif os.path.exists(self.description_path):
            with open(self.description_path) as d_file:
                self.name_and_description = d_file.read()

        self.tags = tags if tags else []

        self.sweep_id = sweep_id

        self._history = None
        self._events = None
        self._summary = None
        self._meta = None
        self._run_manager = None
        self._jupyter_agent = None
Example #19
def process_args(args, config):
    """
    Processes parsed arguments to modify config appropriately.

    This returns None when `create_sigopt` is included in the args,
    signifying there is nothing to run.

    :return: modified config or None
    """

    if "profile" in args and args.profile:
        insert_experiment_mixin(config, mixins.Profile)

    if "profile_autograd" in args and args.profile_autograd:
        insert_experiment_mixin(config, mixins.ProfileAutograd)

    if "copy_checkpoint_dir" in args:
        config["copy_checkpoint_dir"] = args.copy_checkpoint_dir
        insert_experiment_mixin(config,
                                mixins.SaveFinalCheckpoint,
                                prepend_name=False)

    if "wandb" in args and args.wandb:

        from wandb import util
        from nupic.research.frameworks.wandb import ray_wandb
        from nupic.research.frameworks.vernon.ray_custom_loggers import (
            DEFAULT_LOGGERS)

        # Add ray-wandb logger to loggers.
        config.setdefault("loggers", [])
        config["loggers"].extend(
            list(DEFAULT_LOGGERS) + [ray_wandb.WandbLogger])

        # One may specify `wandb_args` or `env_config["wandb"]`
        name = config.get("name", "unknown_name")
        wandb_args = config.get("wandb_args", {})
        wandb_args.setdefault("name", name)
        config.setdefault("env_config", {})
        config["env_config"].setdefault("wandb", wandb_args)

        # Either restore from a run-id or generate a new one.
        resume = wandb_args.get("resume", False)
        if ("wandb_resume" in args and args.wandb_resume) or resume:
            wandb_args.setdefault("resume", True)
            ray_wandb.enable_run_resume(wandb_args)
        else:
            wandb_id = util.generate_id()
            wandb_args["id"] = wandb_id

        # Enable logging on workers.
        insert_experiment_mixin(config,
                                ray_wandb.WorkerLogger,
                                prepend_name=False)

    if "create_sigopt" in args:

        from nupic.research.frameworks.sigopt import SigOptExperiment

        s = SigOptExperiment()
        s.create_experiment(config["sigopt_config"])
        print("Created experiment: https://app.sigopt.com/experiment/",
              s.experiment_id)
        return

    return config
Example #20
    def from_directory(cls, directory, project=None, entity=None, run_id=None, api=None):
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)
        project = project or api.settings(
            "project") or run.auto_project_name(api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id, project=project, entity=entity)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        wandb.termlog(run.get_url(api))

        file_api = api.get_file_stream_api()
        snap = DirectorySnapshot(directory)
        paths = [os.path.relpath(abs_path, directory)
                 for abs_path in snap.paths if os.path.isfile(abs_path)]
        run_update = {"id": res["id"]}
        tfevents = sorted([p for p in snap.paths if ".tfevents." in p])
        history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
        event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config = next(
            (p for p in snap.paths if USER_CONFIG_FNAME in p), None)
        summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
        meta = next((p for p in snap.paths if METADATA_FNAME in p), None)
        if history:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history)
            snap.paths.remove(history)
        elif len(tfevents) > 0:
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting.")
            for file in tfevents:
                summary = wbtf.stream_tfevents(file, file_api)
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if event:
            file_api.stream_file(event)
            snap.paths.remove(event)
        if config:
            run_update["config"] = util.load_yaml(
                open(config))
        elif user_config:
            # TODO: half-baked support for config.json
            run_update["config"] = {k: {"value": v}
                                    for k, v in six.iteritems(user_config)}
        if summary:
            run_update["summary_metrics"] = open(summary).read()
        if meta:
            meta = json.load(open(meta))
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
        else:
            run_update["host"] = socket.gethostname()

        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        for k in paths:
            path = os.path.abspath(os.path.join(directory, k))
            pusher.update_file(k, path)
            pusher.file_changed(k, path)
        pusher.finish()
        pusher.print_status()
        wandb.termlog("Finished!")
        return run
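All of these examples lean on the same helper. A minimal sketch of what it returns (the exact length and alphabet are wandb implementation details):

from wandb import util

run_id = util.generate_id()
print(run_id)  # short random lowercase-alphanumeric string, e.g. "1a2b3c4d"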