Example #1
 def update(self, value: float, label: str) -> None:
     if self._disabled:
         return
     try:
         self._progress.value = value
         self._label.value = label
         if not self._displayed:
             self._displayed = True
             display_widget(self._widget)
     except Exception as e:
         self._disabled = True
         logger.exception(e)
         wandb.termwarn(
             "Unable to render progress bar, see the user log for details"
         )
Example #2
def _magic_init(**kwargs):
    magic_arg = kwargs.get("magic", None)
    if magic_arg is not None and magic_arg is not False:
        global _magic_init_seen
        if _magic_init_seen and magic_arg is not True:
            wandb.termwarn(
                "wandb.init() magic argument ignored because wandb magic has already been initialized",
                repeat=False,
            )
        _magic_init_seen = True
    else:
        wandb.termwarn(
            "wandb.init() arguments ignored because wandb magic has already been initialized",
            repeat=False,
        )
Example #3
def check_environ():
    """Warn about WANDB_ environment variables the user has set

    Sometimes it's useful to set things like WANDB_DEBUG intentionally, or
    set other things for hacky debugging, but we want to make sure the user
    knows about it.
    """
    # we ignore WANDB_DESCRIPTION because we set it intentionally in
    # pytest_runtest_setup()
    wandb_keys = [
        key for key in os.environ
        if key.startswith('WANDB_') and key not in ('WANDB_TEST', 'WANDB_DESCRIPTION')
    ]
    if wandb_keys:
        wandb.termwarn('You have WANDB_ environment variable(s) set. These may interfere with tests:')
        for key in wandb_keys:
            wandb.termwarn('    {} = {}'.format(key, repr(os.environ[key])))
Example #4
 def _parse_settings(settings):
     """settings could be json or comma seperated assignments."""
     ret = {}
     # TODO(jhr): merge with magic_impl:_parse_magic
     if settings.find('=') > 0:
         for item in settings.split(","):
             kv = item.split("=")
             if len(kv) != 2:
                 wandb.termwarn(
                     "Unable to parse sweep settings key value pair",
                     repeat=False)
                 continue  # skip the malformed pair; dict([kv]) would raise below
             ret.update(dict([kv]))
         return ret
     wandb.termwarn("Unable to parse settings parameter", repeat=False)
     return ret
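A brief usage sketch (hypothetical input; values stay strings, no type coercion is attempted):
# Well-formed comma separated assignments parse to a flat dict.
assert _parse_settings("batch_size=64,lr=0.01") == {"batch_size": "64", "lr": "0.01"}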
Example #5
 def learning_curve_table(train, test, trainsize):
     data = []
     for i in range(len(train)):
         # each iteration appends two rows, so cap at half of wandb.Table.MAX_ROWS
         if i >= wandb.Table.MAX_ROWS / 2:
             wandb.termwarn("wandb uses only the first %d datapoints to create the plots." % wandb.Table.MAX_ROWS)
             break
         train_set = ["train", round(train[i], 2), trainsize[i]]
         test_set = ["test", round(test[i], 2), trainsize[i]]
         data.append(train_set)
         data.append(test_set)
     return wandb.visualize(
         'wandb/learning_curve/v1',
         wandb.Table(columns=['dataset', 'score', 'train_size'], data=data))
Example #6
def get_wandb_dir(root_dir):
    # We use the hidden version if it already exists, otherwise non-hidden.
    if os.path.exists(os.path.join(root_dir, ".wandb")):
        __stage_dir__ = ".wandb" + os.sep
    else:
        __stage_dir__ = "wandb" + os.sep

    path = os.path.join(root_dir, __stage_dir__)
    if not os.access(root_dir or ".", os.W_OK):
        wandb.termwarn("Path %s wasn't writable, using system temp directory" %
                       path)
        path = os.path.join(tempfile.gettempdir(), __stage_dir__
                            or ("wandb" + os.sep))

    return path
Example #7
def _login(
    anonymous=None,
    key=None,
    relogin=None,
    host=None,
    force=None,
    _backend=None,
    _silent=None,
):
    kwargs = dict(locals())

    if wandb.run is not None:
        wandb.termwarn("Calling wandb.login() after wandb.init() has no effect.")
        return True

    wlogin = _WandbLogin()

    _backend = kwargs.pop("_backend", None)
    if _backend:
        wlogin.set_backend(_backend)

    _silent = kwargs.pop("_silent", None)
    if _silent:
        wlogin.set_silent(_silent)

    # configure login object
    wlogin.setup(kwargs)

    if wlogin._settings._offline:
        return False

    # perform a login
    logged_in = wlogin.login()

    key = kwargs.get("key")
    if key:
        wlogin.configure_api_key(key)

    if logged_in:
        return logged_in

    if not key:
        wlogin.prompt_api_key()

    # make sure login credentials get to the backend
    wlogin.propogate_login()

    return wlogin._key or False
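_login() is internal; a minimal sketch of the public wrapper that backs it (wandb.login() accepts the same anonymous/key/relogin/host arguments):
import wandb

# Force a fresh login prompt even if credentials are cached.
wandb.login(relogin=True)
# Or pass a key directly (placeholder value shown).
wandb.login(key="YOUR_API_KEY", host="https://api.wandb.ai")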
Example #8
    def from_environment_or_defaults(cls, environment=None):
        """Create a Run object taking values from the local environment where possible.

        The run ID comes from WANDB_RUN_ID or is randomly generated.
        The run mode ("dryrun", or "run") comes from WANDB_MODE or defaults to "dryrun".
        The run directory comes from WANDB_RUN_DIR or is generated from the run ID.

        The Run will have a .config attribute but its run directory won't be set by
        default.
        """
        if environment is None:
            environment = os.environ
        run_id = environment.get(env.RUN_ID)
        resume = environment.get(env.RESUME)
        storage_id = environment.get(env.RUN_STORAGE_ID)
        mode = environment.get(env.MODE)
        api = InternalApi(environ=environment)
        disabled = api.disabled()
        if not mode and disabled:
            mode = "dryrun"
        elif disabled and mode != "dryrun":
            wandb.termwarn(
                "WANDB_MODE is set to run, but W&B was disabled.  Run `wandb on` to remove this message")
        elif disabled:
            wandb.termlog(
                'W&B is disabled in this directory.  Run `wandb on` to enable cloud syncing.')

        group = environment.get(env.RUN_GROUP)
        job_type = environment.get(env.JOB_TYPE)
        run_dir = environment.get(env.RUN_DIR)
        sweep_id = environment.get(env.SWEEP_ID)
        program = environment.get(env.PROGRAM)
        description = environment.get(env.DESCRIPTION)
        name = environment.get(env.NAME)
        notes = environment.get(env.NOTES)
        args = env.get_args(env=environment)
        wandb_dir = env.get_dir(env=environment)
        tags = env.get_tags(env=environment)
        # TODO(adrian): should pass environment into here as well.
        config = Config.from_environment_or_defaults()
        run = cls(run_id, mode, run_dir,
                  group, job_type, config,
                  sweep_id, storage_id, program=program, description=description,
                  args=args, wandb_dir=wandb_dir, tags=tags,
                  name=name, notes=notes,
                  resume=resume, api=api)

        return run
Example #9
def calibration_curve(clf=None, X=None, y=None, clf_name="Classifier"):
    """Logs a plot depicting how well-calibrated the predicted probabilities of a classifier are.

    Also suggests how to calibrate an uncalibrated classifier. Compares estimated predicted
    probabilities by a baseline logistic regression model, the model passed as
    an argument, and by both its isotonic calibration and sigmoid calibrations.
    The closer the calibration curves are to a diagonal the better.
    A sine-wave-like curve represents an overfitted classifier, while a
    cosine-wave-like curve represents an underfitted classifier.
    By training isotonic and sigmoid calibrations of the model and comparing
    their curves we can figure out whether the model is over or underfitting and
    if so which calibration (sigmoid or isotonic) might help fix this.
    For more details, see https://scikit-learn.org/stable/auto_examples/calibration/plot_calibration_curve.html.

    Should only be called with a fitted classifier (otherwise an error is thrown).

    Please note this function fits variations of the model on the training set when called.

    Arguments:
        clf: (clf) Takes in a fitted classifier.
        X: (arr) Training set features.
        y: (arr) Training set labels.
        clf_name: (str) Model name. Defaults to 'Classifier'.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
              under 'auto visualizations'.

    Example:
    ```python
        wandb.sklearn.plot_calibration_curve(clf, X, y, 'RandomForestClassifier')
    ```
    """
    not_missing = utils.test_missing(clf=clf, X=X, y=y)
    correct_types = utils.test_types(clf=clf, X=X, y=y)
    is_fitted = utils.test_fitted(clf)
    if not_missing and correct_types and is_fitted:
        y = np.asarray(y)
        if y.dtype.char == "U" or not ((y == 0) | (y == 1)).all():
            wandb.termwarn(
                "This function only supports binary classification at the moment and therefore expects labels to be binary. Skipping calibration curve."
            )
            return

        calibration_curve_chart = calculate.calibration_curves(
            clf, X, y, clf_name)

        wandb.log({"calibration_curve": calibration_curve_chart})
Example #10
    def run(self, launch_project: LaunchProject) -> Optional[AbstractRun]:
        _logger.info("Validating docker installation")
        validate_docker_installation()
        synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
        docker_args: Dict[str, Any] = self.backend_config[PROJECT_DOCKER_ARGS]
        entry_point = launch_project.get_single_entry_point()

        if launch_project.docker_image:
            # user has provided their own docker image
            _logger.info("Pulling user provided docker image")
            pull_docker_image(launch_project.docker_image)

            wandb.termwarn(
                "Using supplied docker image: {}. Artifact swapping and launch metadata disabled"
                .format(launch_project.docker_image))
            image_uri = launch_project.docker_image

        else:
            # build our own image
            image_uri = construct_local_image_uri(launch_project)
            generate_docker_image(
                self._api,
                launch_project,
                image_uri,
                entry_point,
                docker_args,
                runner_type="local",
            )

        if self.backend_config.get("runQueueItemId"):
            try:
                _logger.info("Acking run queue item...")
                self._api.ack_run_queue_item(
                    self.backend_config["runQueueItemId"],
                    launch_project.run_id)
            except CommError:
                wandb.termerror(
                    "Error acking run queue item. Item lease may have ended or another process may have acked it."
                )
                return None

        command_str = " ".join(get_docker_command(image_uri, docker_args))
        wandb.termlog("Launching run in docker with command: {}".format(
            sanitize_wandb_api_key(command_str)))
        run = _run_entry_point(command_str, launch_project.project_dir)
        if synchronous:
            run.wait()
        return run
Example #11
    def add(self, row=None, step=None):
        """Adds or updates a history step.

        If row isn't specified, will write the current state of row.

        If step is specified, the row will be written only when add() is called with
        a different step value.

        run.history.row["duration"] = 1.0
        run.history.add({"loss": 1})
        => {"duration": 1.0, "loss": 1}

        """
        row = row if row is not None else {}  # avoid a mutable default argument
        if not isinstance(row, collections.abc.Mapping):
            raise wandb.Error('history.add expects dict-like object')

        if step is None:
            self.update(row)
            if not self.batched:
                self._write()
        else:
            if not isinstance(step, numbers.Number):
                raise wandb.Error("Step must be a number, not {}".format(step))
            else:
                if step != round(step):
                    # tensorflow just applies `int()`. seems a little crazy.
                    wandb.termwarn(
                        'Non-integer history step: {}; rounding.'.format(step))

                # the backend actually handles floats right now. seems a bit weird to let those through though.
                step = int(round(step))

                if step < self._steps:
                    wandb.termwarn(
                        "Adding to old History rows isn't currently supported.  Step {} < {}; dropping {}."
                        .format(step, self._steps, row))
                    return
                elif step == self._steps:
                    pass
                elif self.batched:
                    raise wandb.Error(
                        "Can't log to a particular History step ({}) while in batched mode."
                        .format(step))
                else:  # step > self._steps
                    self._write()
                    self._steps = step

            self.update(row)
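A short sketch of the step semantics described in the docstring (assuming run is an initialized legacy Run):
run.history.add({"loss": 0.9}, step=0)    # first write at step 0
run.history.add({"loss": 0.7}, step=1.2)  # warns: non-integer step, rounded to 1
run.history.add({"loss": 0.8}, step=0)    # warns: step 0 < 1, row is dropped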
Example #12
    def _log_gradients(self):
        if not self.training_data:
            raise ValueError(
                "Need to pass in training data if logging gradients")

        X_train = self.training_data[0]
        y_train = self.training_data[1]
        metrics = {}
        weights = self.model.trainable_weights  # weight tensors
        # filter down weights tensors to only ones which are trainable
        weights = [
            weight for weight in weights
            if self.model.get_layer(weight.name.split('/')[0]).trainable
        ]

        gradients = self.model.optimizer.get_gradients(
            self.model.total_loss, weights)  # gradient tensors
        if hasattr(self.model, "targets"):
            # TF < 1.14
            target = self.model.targets[0]
            sample_weight = self.model.sample_weights[0]
        elif (hasattr(self.model, "_training_endpoints")
              and len(self.model._training_endpoints) > 0):
            # TF > 1.14 TODO: not sure if we're handling sample_weight properly here...
            endpoint = self.model._training_endpoints[0]
            target = endpoint.training_target.target
            sample_weight = endpoint.sample_weight or K.variable(1)
        else:
            wandb.termwarn(
                "Couldn't extract gradients from your model, this could be an unsupported version of keras.  File an issue here: https://github.com/wandb/client",
                repeat=False)
            return metrics
        input_tensors = [
            self.model.inputs[0],  # input data
            # how much to weight each sample by
            sample_weight,
            target,  # labels
            K.learning_phase(),  # train or test mode
        ]

        get_gradients = K.function(inputs=input_tensors, outputs=gradients)
        grads = get_gradients([X_train, np.ones(len(y_train)), y_train])

        for (weight, grad) in zip(weights, grads):
            metrics["gradients/" + weight.name.split(':')[0] +
                    ".gradient"] = wandb.Histogram(grad)

        return metrics
Example #13
 def _attempt_evaluation_log(self, commit=True):
     if self.log_evaluation and self._validation_data_logger:
         try:
             if not self.model:
                 wandb.termwarn(
                     "WandbCallback unable to read model from trainer")
             else:
                 self._validation_data_logger.log_predictions(
                     predictions=self._validation_data_logger.
                     make_predictions(self.model.predict),
                     commit=commit,
                 )
                 self._model_trained_since_last_eval = False
         except Exception as e:
             wandb.termwarn("Error durring prediction logging for epoch: " +
                            str(e))
Example #14
 def _validate_and_fix_spec_project_entity(
     self, launch_spec: Dict[str, Any]
 ) -> None:
     """Checks if launch spec target project/entity differs from agent. Forces these values to agent's if they are set."""
     if (
         launch_spec.get("project") is not None
         and launch_spec.get("project") != self._project
     ) or (
         launch_spec.get("entity") is not None
         and launch_spec.get("entity") != self._entity
     ):
         wandb.termwarn(
             f"Launch agents only support sending runs to their own project and entity. This run will be sent to {self._entity}/{self._project}"
         )
         launch_spec["entity"] = self._entity
         launch_spec["project"] = self._project
Example #15
def json_friendly(obj):
    """Convert an object into something that's more becoming of JSON"""
    converted = True
    typename = get_full_typename(obj)

    if is_tf_eager_tensor_typename(typename):
        obj = obj.numpy()
    elif is_tf_tensor_typename(typename):
        try:
            obj = obj.eval()
        except RuntimeError:
            obj = obj.numpy()
    elif is_pytorch_tensor_typename(typename):
        try:
            if obj.requires_grad:
                obj = obj.detach()
        except AttributeError:
            pass  # before 0.4 is only present on variables

        try:
            obj = obj.data
        except RuntimeError:
            pass  # happens for Tensors before 0.4

        if obj.size():
            obj = obj.numpy()
        else:
            return obj.item(), True

    if is_numpy_array(obj):
        if obj.size == 1:
            obj = obj.flatten()[0]
        elif obj.size <= 32:
            obj = obj.tolist()
    elif np and isinstance(obj, np.generic):
        obj = obj.item()
    elif isinstance(obj, bytes):
        obj = obj.decode('utf-8')
    elif isinstance(obj, (datetime, date)):
        obj = obj.isoformat()
    else:
        converted = False
    if getsizeof(obj) > VALUE_BYTES_LIMIT:
        wandb.termwarn("Serializing object of type {} that is {} bytes".format(
            type(obj).__name__, getsizeof(obj)))

    return obj, converted
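A few illustrative conversions (assuming numpy is installed; each call returns the converted object plus a flag):
import numpy as np
from datetime import date

json_friendly(np.float64(1.5))   # -> (1.5, True)          np.generic -> .item()
json_friendly(np.arange(4))      # -> ([0, 1, 2, 3], True)  small arrays -> list
json_friendly(b"bytes")          # -> ('bytes', True)       bytes -> utf-8 str
json_friendly(date(2020, 1, 1))  # -> ('2020-01-01', True)  ISO format
json_friendly(object())          # -> (<object ...>, False) left unconverted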
Example #16
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # avoid a mutable default argument
        if self.log_weights:
            wandb.log(self._log_weights(), commit=False)

        if self.log_gradients:
            wandb.log(self._log_gradients(), commit=False)

        if self.input_type in (
                "image",
                "images",
                "segmentation_mask",
        ) or self.output_type in ("image", "images", "segmentation_mask"):
            if self.generator:
                self.validation_data = next(self.generator)
            if self.validation_data is None:
                wandb.termwarn(
                    "No validation_data set, pass a generator to the callback."
                )
            elif self.validation_data and len(self.validation_data) > 0:
                wandb.log(
                    {
                        "examples":
                        self._log_images(num_images=self.predictions)
                    },
                    commit=False,
                )

        if (self._log_evaluation_frequency > 0
                and epoch % self._log_evaluation_frequency == 0):
            self._attempt_evaluation_log(commit=False)

        wandb.log({"epoch": epoch}, commit=False)
        wandb.log(logs, commit=True)

        self.current = logs.get(self.monitor)
        if self.current and self.monitor_op(self.current, self.best):
            if self.log_best_prefix:
                wandb.run.summary["%s%s" % (self.log_best_prefix,
                                            self.monitor)] = self.current
                wandb.run.summary["%s%s" %
                                  (self.log_best_prefix, "epoch")] = epoch
                if self.verbose and not self.save_model:
                    print("Epoch %05d: %s improved from %0.5f to %0.5f" %
                          (epoch, self.monitor, self.best, self.current))
            if self.save_model:
                self._save_model(epoch)
            self.best = self.current
Example #17
    def _log_dataframe(self):
        x, y_true, y_pred = None, None, None

        if self.validation_data:
            x, y_true = self.validation_data[0], self.validation_data[1]
            y_pred = self.model.predict(x)
        elif self.generator:
            if not self.validation_steps:
                wandb.termwarn(
                    "when using a generator for validation data with dataframes, you must pass validation_steps. skipping"
                )
                return None

            for i in range(self.validation_steps):
                bx, by_true = next(self.generator)
                by_pred = self.model.predict(bx)
                if x is None:
                    x, y_true, y_pred = bx, by_true, by_pred
                else:
                    x, y_true, y_pred = (
                        np.append(x, bx, axis=0),
                        np.append(y_true, by_true, axis=0),
                        np.append(y_pred, by_pred, axis=0),
                    )

        if self.input_type in ("image", "images") and self.output_type == "label":
            return wandb.image_categorizer_dataframe(
                x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
            )
        elif (
            self.input_type in ("image", "images")
            and self.output_type == "segmentation_mask"
        ):
            return wandb.image_segmentation_dataframe(
                x=x,
                y_true=y_true,
                y_pred=y_pred,
                labels=self.labels,
                class_colors=self.class_colors,
            )
        else:
            wandb.termwarn(
                "unknown dataframe type for input_type=%s and output_type=%s"
                % (self.input_type, self.output_type)
            )
            return None
Example #18
    def apply(self) -> None:
        """Call require_* method for supported features."""
        last_message: str = ""
        for feature_item in self._features:
            full_feature = feature_item.split("@", 2)[0]
            feature = full_feature.split(":", 2)[0]
            func_str = "require_{}".format(feature.replace("-", "_"))
            func = getattr(self, func_str, None)
            if not func:
                last_message = "require() unsupported requirement: {}".format(
                    feature)
                wandb.termwarn(last_message)
                continue
            func()

        if last_message:
            raise RequireError(last_message)
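The feature strings follow a "feature[:extra][@version]" convention; a sketch of how apply() resolves one (the feature name here is illustrative):
# "report-editing:v0@beta" -> "report-editing:v0" -> "report-editing"
full_feature = "report-editing:v0@beta".split("@", 2)[0]
feature = full_feature.split(":", 2)[0]
func_str = "require_{}".format(feature.replace("-", "_"))  # "require_report_editing"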
Example #19
 def use_artifact(self, artifact, type=None, aliases=None):
     if self.mode == "dryrun":
         wandb.termwarn(
             "Using artifacts in dryrun mode is currently unsupported.")
         return artifact
     self.history.ensure_jupyter_started()
     if isinstance(artifact, str):
         if type is None:
             raise ValueError('type required')
         public_api = PublicApi()
         artifact = public_api.artifact(type=type, name=artifact)
         self.api.use_artifact(artifact.id)
         return artifact
     else:
         if type is not None:
             raise ValueError(
                 'cannot specify type when passing Artifact object')
         if isinstance(aliases, str):
             aliases = [aliases]
         if isinstance(artifact, wandb.Artifact):
             artifact.finalize()
             self.send_message({
                 'use_artifact': {
                     'type': artifact.type,
                     'name': artifact.name,
                     'server_manifest_entries': artifact.server_manifest.entries,
                     'manifest': artifact.manifest.to_manifest_json(include_local=True),
                     'digest': artifact.digest,
                     'metadata': artifact.metadata,
                     'aliases': aliases,
                 }
             })
         elif isinstance(artifact, ApiArtifact):
             self.api.use_artifact(artifact.id)
             return artifact
         else:
             raise ValueError(
                 'You must pass an artifact name (e.g. "pedestrian-dataset:v1"), an instance of wandb.Artifact, or wandb.Api().artifact() to use_artifact'
             )
Example #20
def is_wandb_installed_and_logged_in() -> bool:
    """Checks if wandb is installed and if a login is detected.

    Returns
    -------
    bool
        True if wandb is installed and a login is detected, otherwise False.
    """
    if not _HAS_WANDB:
        return False
    if wandb.api.api_key is None:
        wandb.termwarn(
            "W&B installed but not logged in. "
            "Run `wandb login` or `import wandb; wandb.login()` or set the WANDB_API_KEY env variable."
        )
        return False
    return True
Example #21
def unwatch(models=None):
    """Remove pytorch gradient and parameter hooks.

    Args:
        models: (list) Optional list of pytorch models that have had watch called on them
    """
    if models:
        if not isinstance(models, (tuple, list)):
            models = (models, )
        for model in models:
            if not hasattr(model, "_wandb_hook_names"):
                wandb.termwarn("%s model has not been watched" % model)
            else:
                for name in model._wandb_hook_names:
                    wandb.run.history.torch.unhook(name)
    else:
        wandb.run.history.torch.unhook_all()
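A usage sketch pairing unwatch() with wandb.watch(), which installs the hooks it removes (the model and training loop are hypothetical):
model = MyTorchModel()   # hypothetical torch.nn.Module
wandb.watch(model)       # installs gradient/parameter hooks
# ... training loop elided ...
unwatch(model)           # removes the hooks installed by watch()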
Example #22
 def _setup_tensorboard(self, tb_root, tb_logdirs, tb_event_files,
                        sync_item):
     """Returns true if this sync item can be synced as tensorboard"""
     if tb_root is not None:
         if tb_event_files > 0 and sync_item.endswith(WANDB_SUFFIX):
             wandb.termwarn(
                 "Found .wandb file, not streaming tensorboard metrics.")
         else:
             print("Found {} tfevent files in {}".format(
                 tb_event_files, tb_root))
             if len(tb_logdirs) > 3:
                 wandb.termwarn(
                     "Found {} directories containing tfevent files. "
                     "If these represent multiple experiments, sync them "
                     "individually or pass a list of paths.".format(len(tb_logdirs)))
             return True
     return False
Example #23
 def confusion_matrix_table(cm, label):
     data = []
     count = 0
     pred_classes, true_classes = [label, 'Rest'], [label, 'Rest']
     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
         pred_dict = pred_classes[i]
         true_dict = true_classes[j]
         data.append([pred_dict, true_dict, cm[i, j]])
         count += 1
         if count >= wandb.Table.MAX_ROWS:
             wandb.termwarn(
                 f"wandb uses only the first {wandb.Table.MAX_ROWS} datapoints to create plots."
             )
             break
     return wandb.visualize(
         'wandb/confusion_matrix/v1',
         wandb.Table(columns=['Predicted', 'Actual', 'Count'], data=data))
Example #24
def test_missing(**kwargs):
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas",
                         required="Logging dataframes requires pandas")
    scipy = util.get_module("scipy",
                            required="Logging scipy matrices requires scipy")
    scikit = util.get_module(
        "sklearn", required="Logging plots matrices requires scikit-learn")
    test_passed = True
    for k, v in kwargs.items():
        # Missing/empty params/datapoint arrays
        if v is None:
            wandb.termerror("%s is None. Please try again." % (k))
            test_passed = False
        if k in ('X', 'X_test'):
            if isinstance(v, scipy.sparse.csr.csr_matrix):
                v = v.toarray()
            elif isinstance(v, (pd.DataFrame, pd.Series)):
                v = v.to_numpy()
            elif isinstance(v, list):
                v = np.asarray(v)

            # Warn the user about missing values
            missing = np.count_nonzero(pd.isnull(v))
            if missing > 0:
                wandb.termwarn("%s contains %d missing values. " %
                               (k, missing))
                test_passed = False
            # Ensure the dataset contains only numbers
            if v.ndim == 1:
                non_nums = sum(1 for val in v
                               if (not isinstance(val, (int, float, complex))
                                   and not isinstance(val, np.number)))
            else:
                non_nums = sum(1 for sl in v for val in sl
                               if (not isinstance(val, (int, float, complex))
                                   and not isinstance(val, np.number)))
            if non_nums > 0:
                wandb.termerror(
                    "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
                    % (k, k))
                test_passed = False
    return test_passed
Example #25
 def set_wandb_attrs(cbk, val_data):
     if isinstance(cbk, WandbCallback):
         if is_generator_like(val_data):
             cbk.generator = val_data
         elif is_dataset(val_data):
             if context.executing_eagerly():
                 cbk.generator = iter(val_data)
             else:
                 wandb.termwarn(
                     "Found a validation dataset in graph mode, can't patch Keras.")
         elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
             # Graph mode dataset generator
             def gen():
                 while True:
                     yield K.get_session().run(val_data)
             cbk.generator = gen()
         else:
             cbk.validation_data = val_data
Example #26
def _parse_magic(val):
    # attempt to treat string as a json
    not_set = {}
    if val is None:
        return _magic_defaults, not_set
    if val.startswith("{"):
        try:
            val = json.loads(val)
        except ValueError:
            wandb.termwarn("Unable to parse magic json", repeat=False)
            return _magic_defaults, not_set
        conf = _merge_dicts(_magic_defaults, {})
        return _merge_dicts(val, conf), val
    if os.path.isfile(val):
        try:
            with open(val, 'r') as stream:
                val = yaml.safe_load(stream)
        except IOError as e:
            wandb.termwarn("Unable to read magic config file", repeat=False)
            return _magic_defaults, not_set
        except yaml.YAMLError as e:
            wandb.termwarn("Unable to parse magic yaml file", repeat=False)
            return _magic_defaults, not_set
        conf = _merge_dicts(_magic_defaults, {})
        return _merge_dicts(val, conf), val
    # parse as a list of key value pairs
    if val.find('=') > 0:
        # split on commas but ignore commas inside quotes
        # Using this re allows env variable parsing like:
        # WANDB_MAGIC=key1='"["cat","dog","pizza"]"',key2=true
        items = re.findall(r'(?:[^\s,"]|"(?:\\.|[^"])*")+', val)
        conf_set = {}
        for kv in items:
            kv = kv.split('=')
            if len(kv) != 2:
                wandb.termwarn("Unable to parse magic key value pair",
                               repeat=False)
                continue
            d = _dict_from_keyval(*kv)
            _merge_dicts(d, conf_set)
        conf = _merge_dicts(_magic_defaults, {})
        return _merge_dicts(conf_set, conf), conf_set
    wandb.termwarn("Unable to parse magic parameter", repeat=False)
    return _magic_defaults, not_set
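Illustrative inputs for the parse paths above (keys are hypothetical; each call returns the merged config plus the user-set subset):
_parse_magic(None)                           # -> (_magic_defaults, {})
_parse_magic('{"enable": true}')             # JSON path, merged over the defaults
_parse_magic("enable=true,args.absl=false")  # key=value path via _dict_from_keyval
_parse_magic("no assignments here")          # warns: "Unable to parse magic parameter"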
Example #27
 def apply_init(self, args):
     # strip out items where value is None
     param_map = dict(
         name="run_name",
         id="run_id",
         tags="run_tags",
         group="run_group",
         job_type="run_job_type",
         notes="run_notes",
         dir="root_dir",
     )
     args = {
         param_map.get(k, k): v
         for k, v in six.iteritems(args) if v is not None
     }
     # fun logic to convert the resume init arg
     if args.get("resume") is not None:
         if isinstance(args["resume"], six.string_types):
             if args["resume"] not in ("allow", "must", "never", "auto"):
                 if args.get("run_id") is None:
                     #  TODO: deprecate or don't support
                     args["run_id"] = args["resume"]
                 args["resume"] = "allow"
         elif args["resume"] is True:
             args["resume"] = "auto"
     self.update(args)
     # handle auto resume logic
     if self.resume == "auto":
         if os.path.exists(self.resume_fname):
             with open(self.resume_fname) as f:
                 resume_run_id = json.load(f)["run_id"]
             if self.run_id is None:
                 self.run_id = resume_run_id
             else:
                 wandb.termwarn(
                     "Tried to auto resume run with id %s but id %s is set."
                     % (resume_run_id, self.run_id))
     self.run_id = self.run_id or generate_id()
     # persist our run id in case of failure
     if self.resume == "auto":
         wandb.util.mkdir_exists_ok(self.wandb_dir)
         with open(self.resume_fname, "w") as f:
             f.write(json.dumps({"run_id": self.run_id}))
Example #28
def feature_importances(model, feature_names):
    attributes_to_check = [
        "feature_importances_", "feature_log_prob_", "coef_"
    ]
    found_attribute = check_for_attribute_on(model, attributes_to_check)
    if found_attribute is None:
        wandb.termwarn(
            f"could not find any of attributes {', '.join(attributes_to_check)} on classifier. Cannot plot feature importances."
        )
        return
    elif found_attribute == "feature_importances_":
        importances = model.feature_importances_
    elif found_attribute == "coef_":  # ElasticNet-like models
        importances = model.coef_
    elif found_attribute == "feature_log_prob_":
        # coef_ was deprecated in sklearn 0.24, replaced with
        # feature_log_prob_
        importances = model.feature_log_prob_

    if len(importances.shape) > 1:
        n_significant_dims = sum([i > 1 for i in importances.shape])
        if n_significant_dims > 1:
            nd = len(importances.shape)
            wandb.termwarn(
                f"{nd}-dimensional feature importances array passed to plot_feature_importances. "
                f"{nd}-dimensional and higher feature importances arrays are not currently supported. "
                f"These importances will not be plotted.")
            return
        else:
            importances = np.squeeze(importances)

    indices = np.argsort(importances)[::-1]
    importances = importances[indices]

    if feature_names is None:
        feature_names = indices
    else:
        feature_names = np.array(feature_names)[indices]

    table = make_table(feature_names, importances)
    chart = wandb.visualize("wandb/feature_importances/v1", table)

    return chart
Example #29
 def _init_validation_gen(self):
     """
     Helper method for initializing Validation data table
     """
     if self.log_evaluation:
         try:
             validation_data = None
             if self.validation_data:
                 validation_data = self.validation_data
                 self.validation_data_logger = ValidationDataLogger(
                     inputs=validation_data[0],
                     targets=validation_data[1],
                     indexes=None,
                     validation_row_processor=None,
                     prediction_row_processor=lambda ndx, row: {"output": np.argmax(row["output"])},
                     class_labels=self.labels,
                     infer_missing_processors=self.infer_missing_processors)
         except Exception as e:
             wandb.termwarn(
                 "Error initializing ValidationDataLogger in WandbCallback. "
                 "Skipping logging validation data. Error: " + str(e))
Example #30
        def __init__(self, logdir, *args, **kwargs):
            logdir_hist.append(logdir)
            root_logdir_arg = root_logdir
            if len(set(logdir_hist)) > 1:
                wandb.termwarn(
                    'When using several event log directories, please call wandb.tensorboard.patch(root_logdir="...") before wandb.init'
                )

            if root_logdir is not None and not os.path.abspath(logdir).startswith(
                os.path.abspath(root_logdir)
            ):
                wandb.termwarn(
                    "Found logdirectory outside of given root_logdir, dropping given root_logdir for eventfile in {}".format(
                        logdir
                    )
                )
                root_logdir_arg = None

            _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)

            super(TBXEventFileWriter, self).__init__(logdir, *args, **kwargs)