Example #1
0
    def add(self, row):
        """Add a row to the table.

        Args:
            row: A dict whose keys match the keys added in set_columns, and whose
                values can be cast to the types added in set_columns.

        Raises:
            wandb.Error: If set_columns has not been called yet, if a key was
                not provided to set_columns, or if a value cannot be converted
                (and, when it has an ``encode`` method, encoded).
        """
        if not self._types:
            raise wandb.Error(
                'TypedTable.set_columns must be called before add.')
        mapped_row = {}
        for key, val in row.items():
            # Look the converter up on its own so that a KeyError raised while
            # converting a value is not misreported as an unknown column.
            try:
                type_ = self._types[key]
            except KeyError:
                raise wandb.Error(
                    'TypedTable.add received key ("%s") which wasn\'t provided to set_columns'
                    % (key,))
            try:
                typed_val = type_(val)
                if hasattr(typed_val, 'encode'):
                    typed_val = typed_val.encode()
            except Exception:
                # Not a bare except: KeyboardInterrupt and SystemExit must
                # propagate instead of being turned into a wandb.Error.
                raise wandb.Error(
                    'TypedTable.add couldn\'t convert and encode ("{}") provided for key ("{}") to type ({})'
                    .format(val, key, type_))
            mapped_row[key] = typed_val
        self._output.add(mapped_row)
        self._count += 1
Example #2
0
    def set_columns(self, types):
        """Set the column types.

        Args:
            types: iterable of (column_name, type) pairs.

        Raises:
            wandb.Error: If called more than once, if a type is not a key of
                TYPE_TO_TYPESTRING, or if iterating ``types`` raises TypeError.
        """
        if self._types:
            raise wandb.Error('TypedTable.set_columns called more than once.')
        try:
            # Materialize once: the pairs are read several times below, and a
            # one-shot iterable (e.g. a generator) would otherwise be exhausted
            # by the validation loop and leave the table with no columns.
            types = list(types)
            for key, type_ in types:
                if type_ not in TYPE_TO_TYPESTRING:
                    raise wandb.Error(
                        'TypedTable.set_columns received invalid type ({}) for key "{}".\n  Valid types: {}'
                        .format(type_, key,
                                '[%s]' % ', '.join(VALID_TYPE_NAMES)))
        except TypeError:
            raise wandb.Error(
                'TypedTable.set_columns requires iterable of (column_name, type) pairs.'
            )
        self._types = dict(types)
        self._output.add({
            'typemap': {k: TYPE_TO_TYPESTRING[type_]
                        for k, type_ in types},
            'columns': [t[0] for t in types]
        })
Example #3
0
    def add(self, row={}, step=None, timestamp=None):
        """Adds or updates a history step.

        If row isn't specified, will write the current state of row.

        If step is specified, the row will be written only when add() is called with
        a different step value.

        run.history.row["duration"] = 1.0
        run.history.add({"loss": 1})
        => {"duration": 1.0, "loss": 1}

        Args:
            row: Mapping of keys to merge into the current row.
            step: Optional numeric step. Non-integer values are rounded (with a
                warning); a row for a step lower than the current one is
                dropped with a warning.
            timestamp: Optional timestamp; the current time is used when it is
                not given.

        Raises:
            wandb.Error: If row is not a mapping, if step is not a number, or
                if a later step is requested while in batched mode.
        """
        # collections.Mapping was removed in Python 3.10; prefer
        # collections.abc and fall back for interpreters that predate it.
        try:
            from collections.abc import Mapping
        except ImportError:
            from collections import Mapping

        if not isinstance(row, Mapping):
            raise wandb.Error('history.add expects dict-like object')
        if timestamp and self._current_timestamp and timestamp < self._current_timestamp:
            wandb.termwarn(
                "When passing timestamp, it must be increasing.  Current timestamp is {} but was passed {}"
                .format(self._current_timestamp, timestamp))
        self._current_timestamp = timestamp or time.time()
        # Importing data, reset start time to the first timestamp passed in
        if self._start_time > self._current_timestamp:
            # Use the resolved timestamp: `timestamp` itself may be None or 0,
            # which would otherwise be stored as the start time.
            self._start_time = self._current_timestamp

        if step is None:
            self.update(row)
            if not self.batched:
                self._write()
        else:
            if not isinstance(step, numbers.Number):
                raise wandb.Error("Step must be a number, not {}".format(step))
            else:
                if step != round(step):
                    # tensorflow just applies `int()`. seems a little crazy.
                    wandb.termwarn(
                        'Non-integer history step: {}; rounding.'.format(step))

                # the backend actually handles floats right now. seems a bit weird to let those through though.
                step = int(round(step))

                if step < self._steps:
                    wandb.termwarn(
                        "Adding to old History rows isn't currently supported.  Step {} < {}; dropping {}."
                        .format(step, self._steps, row))
                    return
                elif step == self._steps:
                    pass
                elif self.batched:
                    raise wandb.Error(
                        "Can't log to a particular History step ({}) while in batched mode."
                        .format(step))
                else:  # step > self._steps
                    self._write()
                    self._steps = step

            self.update(row)
Example #4
0
def log_summary(model: Booster,
                feature_importance: bool = True,
                save_model_checkpoint: bool = False) -> None:
    """Logs useful metrics about lightgbm model after training is done.

    Arguments:
        model: (Booster) is an instance of lightgbm.basic.Booster.
        feature_importance: (boolean) if True (default), logs the feature importance plot.
        save_model_checkpoint: (boolean) if True saves the best model and upload as W&B artifacts.

    Raises:
        wandb.Error: if no run is active (`wandb.run` is None) or `model` is
            not a `Booster`.

    Using this along with `wandb_callback` will:

    - log `best_iteration` and `best_score` as `wandb.summary`.
    - log feature importance plot.
    - save and upload your best trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`)

    Example:
        ```python
        params = {
            'boosting_type': 'gbdt',
            'objective': 'regression',
            # ...
        }
        gbm = lgb.train(params,
                        lgb_train,
                        num_boost_round=10,
                        valid_sets=lgb_eval,
                        valid_names=('validation',),
                        callbacks=[wandb_callback()])

        log_summary(gbm)
        ```
    """
    if wandb.run is None:
        # Name this function in the message, not the callback.
        raise wandb.Error("You must call wandb.init() before log_summary()")

    if not isinstance(model, Booster):
        raise wandb.Error(
            "Model should be an instance of lightgbm.basic.Booster")

    wandb.run.summary["best_iteration"] = model.best_iteration
    wandb.run.summary["best_score"] = model.best_score

    # Log feature importance
    if feature_importance:
        _log_feature_importance(model)

    if save_model_checkpoint:
        _checkpoint_artifact(model, model.best_iteration, aliases=["best"])

    with wb_telemetry.context() as tel:
        tel.feature.lightgbm_log_summary = True
Example #5
0
 def __setattr__(self, key, value):
     if not key.startswith("_"):
         raise wandb.Error(
             "You must call wandb.init() before {}.{}".format(self._name, key)
         )
     else:
         return object.__setattr__(self, key, value)
Example #6
0
 def __getattr__(self, key):
     if not key.startswith("_"):
         raise wandb.Error(
             "You must call wandb.init() before {}.{}".format(self._name, key)
         )
     else:
         raise AttributeError()
Example #7
0
 def update(self, key_vals):
     """Merge `key_vals` into the summary, then write it.

     Entries whose value is a dict with a truthy "_type" item are left out.
     Raises wandb.Error when `key_vals` is not a dict.
     """
     if not isinstance(key_vals, dict):
         raise wandb.Error('summary.update expects dict')
     # TODO: This removes media from the summary, but will silently remove a user provided dict with _type
     kept = {}
     for name, value in six.iteritems(key_vals):
         if isinstance(value, dict) and value.get("_type"):
             continue
         kept[name] = value
     self._summary.update(kept)
     self._write()
Example #8
0
    def __init__(self, metric_period: int = 1):
        """Flag the feature in telemetry and store `metric_period`.

        Raises wandb.Error when `wandb.run` is None.
        """
        if wandb.run is None:
            raise wandb.Error("You must call `wandb.init()` before `WandbCallback()`")

        with wb_telemetry.context() as telemetry:
            telemetry.feature.catboost_wandb_callback = True

        self.metric_period: int = metric_period
Example #9
0
 def update(self, key_vals):
     """Transform `key_vals`, merge them into the summary, then write it.

     Entries whose value is a dict with "_type" equal to "image" are skipped;
     every other value goes through `self._transform`.
     Raises wandb.Error when `key_vals` is not a dict.
     """
     if not isinstance(key_vals, dict):
         raise wandb.Error('summary.update expects dict')

     def _is_image(value):
         return isinstance(value, dict) and value.get("_type") == "image"

     transformed = {
         name: self._transform(value)
         for name, value in six.iteritems(key_vals)
         if not _is_image(value)
     }
     self._summary.update(transformed)
     self._write()
Example #10
0
    def _communicate_exit(self, exit_data, timeout=None):
        req = self._make_record(exit=exit_data)

        result = self._communicate(req, timeout=timeout)
        if result is None:
            # TODO: friendlier error message here
            raise wandb.Error(
                "Couldn't communicate with backend after %s seconds" % timeout)
        assert result.exit_result
        return result.exit_result
Example #11
0
 def communicate_login(self, api_key=None, anonymous=None, timeout=15):
     """Send a login request and return the `login_response` of the reply.

     Raises wandb.Error when `self._communicate` returns None.
     """
     request = self._make_request(login=self._make_login(api_key, anonymous))
     reply = self._communicate(request, timeout=timeout)
     if reply is not None:
         login_response = reply.response.login_response
         assert login_response
         return login_response
     # TODO: friendlier error message here
     raise wandb.Error(
         "Couldn't communicate with backend after %s seconds" % timeout)
Example #12
0
    def add(self, row={}, step=None):
        """Adds or updates a history step.

        If row isn't specified, will write the current state of row.

        If step is specified, the row will be written only when add() is called with
        a different step value.

        run.history.row["duration"] = 1.0
        run.history.add({"loss": 1})
        => {"duration": 1.0, "loss": 1}

        Args:
            row: Mapping of keys to merge into the current row.
            step: Optional integer step. A row for a step lower than the
                current one is dropped with a warning.

        Raises:
            wandb.Error: If row is not a mapping, if step is not an integer,
                or if a later step is requested while in batched mode.
        """
        # collections.Mapping was removed in Python 3.10; prefer
        # collections.abc and fall back for interpreters that predate it.
        try:
            from collections.abc import Mapping
        except ImportError:
            from collections import Mapping

        if not isinstance(row, Mapping):
            raise wandb.Error('history.add expects dict-like object')

        if step is None:
            self.update(row)
            if not self.batched:
                self._write()
        else:
            if not isinstance(step, numbers.Integral):
                raise wandb.Error(
                    "Step must be an integer, not {}".format(step))
            elif step < self._steps:
                warnings.warn(
                    "Adding to old History rows isn't currently supported. Dropping.",
                    wandb.WandbWarning)
                return
            elif step == self._steps:
                pass
            elif self.batched:
                raise wandb.Error(
                    "Can't log to a particular History step ({}) while in batched mode."
                    .format(step))
            else:  # step > self._steps
                self._write()
                self._steps = step

            self.update(row)
Example #13
0
 def communicate_login(self,
                       api_key: str = None,
                       timeout: Optional[int] = 15) -> pb.LoginResponse:
     """Send a login request and return the `login_response` of the reply.

     Raises wandb.Error when `self._communicate` returns None.
     """
     request = self._make_request(login=self._make_login(api_key))
     reply = self._communicate(request, timeout=timeout)
     if reply is not None:
         login_response = reply.response.login_response
         assert login_response
         return login_response
     # TODO: friendlier error message here
     raise wandb.Error(
         "Couldn't communicate with backend after %s seconds" % timeout)
Example #14
0
    def add(self, row={}):
        """Adds keys to history and writes the row.  If row isn't specified, will write
        the current state of row.

        run.history.row["duration"] = 1.0
        run.history.add({"loss": 1})
        => {"duration": 1.0, "loss": 1}

        Args:
            row: Mapping of keys to merge into the current row.

        Raises:
            wandb.Error: If row is not a mapping.
        """
        # collections.Mapping was removed in Python 3.10; prefer
        # collections.abc and fall back for interpreters that predate it.
        try:
            from collections.abc import Mapping
        except ImportError:
            from collections import Mapping

        if not isinstance(row, Mapping):
            raise wandb.Error('history.add expects dict-like object')
        self.row.update(row)
        # In batched mode the write is left to whoever controls the batch.
        if not self.batched:
            self._write()
Example #15
0
    def log_stats(self, variable_or_module, name=None, prefix='_', values=True, gradients=True):
        """Log distribution statistics for a torch Variable or Module and its next
        gradient in History.

        For a Variable, logs statistics on its current data and gradient whenever
        its backward() method is next called. For a module, logs the same on all
        its Parameters (including those of submodules).

        Does nothing when there is no history or its `compute` flag is falsy.

        Here's how you might use this function to instrument the hidden units of a
        network:

            def forward(self, x):
                x = F.relu(F.max_pool2d(self.conv1(x), 2))
                run.history.torch.log_stats(x, 'conv1.out')
                x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
                run.history.torch.log_stats(x, 'conv2.out')
                x = x.view(-1, args.n2 * 16)
                x = F.relu(self.fc1(x))
                run.history.torch.log_stats(x, 'fc1.out')
                x = F.dropout(x, training=self.training)
                x = self.fc2(x)
                x = F.log_softmax(x, dim=0)
                run.history.torch.log_stats(x, 'fc2.out')
                return x

        Args:
            variable_or_module: a torch.autograd.Variable or torch.nn.Module.
            name: key suffix; required for a Variable, optional for a Module
                (where it is appended to `prefix`).
            prefix: string prepended to the logged keys.
            values: if True, log statistics of the Variable's data.
            gradients: if True, hook statistics of the Variable's gradient.

        Raises:
            wandb.Error: if a Variable is passed without a name.
            TypeError: if the argument is neither a Variable nor a Module.
        """
        history = self._history()
        if history is None or not history.compute:
            return

        if isinstance(variable_or_module, torch.autograd.Variable):
            if name is None:
                raise wandb.Error('Need a name to log stats for a Variable.')
            var = variable_or_module
            if values:
                self.log_tensor_stats(var.data, prefix + name)
            if gradients:
                self._hook_variable_gradient_stats(
                    var, prefix + name + ':grad')
        elif isinstance(variable_or_module, torch.nn.Module):
            module = variable_or_module
            if name is not None:
                prefix = prefix + name
            for param_name, parameter in module.named_parameters():
                self.log_stats(parameter, name=param_name, prefix=prefix,
                               values=values, gradients=gradients)
        else:
            # `var` is only bound in the Variable branch; inspect the argument
            # itself so the intended TypeError is what gets raised.
            cls = type(variable_or_module)
            raise TypeError('Expected torch.autograd.Variable or torch.nn.Module, not {}.{}'.format(
                cls.__module__, cls.__name__))
Example #16
0
def apply_patch(patch_string: str, dst_dir: str) -> None:
    """Write `patch_string` to `diff.patch` in `dst_dir` and apply it there.

    The patch is applied with the external `patch` command (`-s -p1`).
    Raises wandb.Error when that command exits with a non-zero status.
    """
    _logger.info("Applying diff.patch")
    patch_path = os.path.join(dst_dir, "diff.patch")
    with open(patch_path, "w") as patch_file:
        patch_file.write(patch_string)
    command = [
        "patch",
        "-s",
        "--directory={}".format(dst_dir),
        "-p1",
        "-i",
        "diff.patch",
    ]
    try:
        subprocess.check_call(command)
    except subprocess.CalledProcessError:
        raise wandb.Error("Failed to apply diff.patch associated with run.")
Example #17
0
    def __init__(
        self,
        log_model: bool = False,
        log_feature_importance: bool = True,
        importance_type: str = "gain",
        define_metric: bool = True,
    ):
        """Store the callback options on the instance.

        Raises wandb.Error when `wandb.run` is None.
        """
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")

        # Each option is kept under the name it was passed as.
        self.log_model, self.log_feature_importance = (
            log_model, log_feature_importance)
        self.importance_type, self.define_metric = (
            importance_type, define_metric)
Example #18
0
def get_module(name, required=None):
    """
    Return module or None. Absolute import is required.
    :param (str) name: Dot-separated module path. E.g., 'scipy.stats'.
    :param (str) required: Message of the wandb.Error raised if the module
        cannot be imported. When falsy, None is returned instead.
    :return: (module|None) If import succeeds, the module will be returned.
    :raises wandb.Error: If the import fails and `required` is given.
    """
    if name not in _not_importable:
        try:
            return import_module(name)
        except Exception:
            # Remember the failure so the import is not retried on later calls.
            _not_importable.add(name)
            if required:
                logger.exception(
                    "Error importing optional module {}".format(name))
    if required and name in _not_importable:
        raise wandb.Error(required)
    return None
Example #19
0
 def track(self, event, properties, timestamp=None, _wandb=False):
     """Write one event as a JSON line to `self._file`.

     The row maps `event` to `properties`, is passed through
     `self.flatten`, and gets `_timestamp` and `_runtime` entries (plus
     `_wandb` when that argument is truthy). The write happens while
     holding `self.lock`.

     Raises:
         wandb.Error: If `properties` is not a mapping.
     """
     # collections.Mapping was removed in Python 3.10; prefer
     # collections.abc and fall back for interpreters that predate it.
     try:
         from collections.abc import Mapping
     except ImportError:
         from collections import Mapping

     if not isinstance(properties, Mapping):
         raise wandb.Error('event.track expects dict-like object')
     self.lock.acquire()
     try:
         row = {}
         row[event] = properties
         self.flatten(row)
         if _wandb:
             row["_wandb"] = _wandb
         row["_timestamp"] = int(timestamp or time.time())
         row['_runtime'] = int(time.time() - self._start_time)
         self._file.write(util.json_dumps_safer(row))
         self._file.write('\n')
         self._file.flush()
     finally:
         self.lock.release()
Example #20
0
    def step(self, compute=True):
        """Context manager to gradually build a history row, then commit it at the end.

        To reduce the number of conditionals needed, code can check run.history.compute:

        with run.history.step(batch_idx % log_interval == 0):
            run.history.add({"nice": "ok"})
            if run.history.compute:
                # Something expensive here

        Args:
            compute: stored as `self.compute` for the duration of the context;
                the row is written on exit only when it is truthy.

        Raises:
            wandb.Error: if a step context is already active.
        """
        if self.batched:  # we're already in a context manager
            raise wandb.Error("Nested History step contexts aren't supported")
        self.batched = True
        self.compute = compute
        try:
            yield self
            if compute:
                self._write()
        finally:
            # Restore the instance flags (not a local) so that later step()
            # contexts are not rejected as nested and `compute` is back to its
            # default, even when the body raised.
            self.batched = False
            self.compute = True
Example #21
0
def _checkpoint_artifact(model: Union[CatBoostClassifier, CatBoostRegressor],
                         aliases: List[str]) -> None:
    """
    Save the model into the run directory and log it as a "model" artifact
    with the given aliases. Raises wandb.Error when `wandb.run` is None.
    """
    if wandb.run is None:
        raise wandb.Error("You must call `wandb.init()` before `_checkpoint_artifact()`")

    artifact_name = f"model_{wandb.run.id}"
    checkpoint_path = Path(wandb.run.dir) / "model"

    # save the model in the default `cbm` format
    model.save_model(checkpoint_path)

    artifact = wandb.Artifact(name=artifact_name, type="model")
    artifact.add_file(str(checkpoint_path))
    wandb.log_artifact(artifact, aliases=aliases)
Example #22
0
 def __init__(
     self,
     verbose: int = 0,
     model_save_path: str = None,
     model_save_freq: int = 0,
     gradient_save_freq: int = 0,
 ):
     """Store the save options and prepare the model save location.

     When `model_save_path` is given, the directory is created and
     `self.path` points at "model.zip" inside it. Otherwise
     `model_save_freq` must be 0.
     Raises wandb.Error when `wandb.run` is None.
     """
     super(WandbCallback, self).__init__(verbose)
     if wandb.run is None:
         raise wandb.Error("You must call wandb.init() before WandbCallback()")
     with wb_telemetry.context() as telemetry:
         telemetry.feature.sb3 = True
     self.model_save_freq = model_save_freq
     self.model_save_path = model_save_path
     self.gradient_save_freq = gradient_save_freq
     if self.model_save_path is None:
         assert (
             self.model_save_freq == 0
         ), "to use the `model_save_freq` you have to set the `model_save_path` parameter"
     else:
         # Create folder if needed
         os.makedirs(self.model_save_path, exist_ok=True)
         self.path = os.path.join(self.model_save_path, "model.zip")
Example #23
0
def _log_feature_importance(
        model: Union[CatBoostClassifier, CatBoostRegressor]) -> None:
    """
    Log feature importance with default settings.

    Builds a ("Feature", "Importance") table from
    `model.get_feature_importance(prettified=True)` and logs it as a bar chart
    under the "Feature Importance" key with `commit=False`.

    Raises wandb.Error when `wandb.run` is None.
    """
    if wandb.run is None:
        # Name this function in the message (it previously named
        # `_checkpoint_artifact()`).
        raise wandb.Error(
            "You must call `wandb.init()` before `_log_feature_importance()`")

    feat_df = model.get_feature_importance(prettified=True)

    fi_data = [[
        feat, feat_imp
    ] for feat, feat_imp in zip(feat_df["Feature Id"], feat_df["Importances"])]
    table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
    # todo: replace with wandb.run._log once available
    wandb.log(
        {
            "Feature Importance":
            wandb.plot.bar(
                table, "Feature", "Importance", title="Feature Importance")
        },
        commit=False,
    )
Example #24
0
def data_frame_to_json(df, run, key, step):
    """Store a Pandas DataFrame as Parquet and return its JSON description.

    A copy of the DataFrame is written to a file under the run directory; the
    returned dictionary is what we use to represent it in `Summary`'s.

    Arguments:
        df (pandas.DataFrame): The DataFrame. Must not have columns named
            "wandb_run_id" or "wandb_data_frame_id"; both are added to the
            copy here. The caller's instance is not modified.
        run (wandb_run.Run): The Run the DataFrame is associated with. The
            information stored with the DataFrame is derived from it.
        key (str): Name of the DataFrame, ie. the summary key path in which
            it's stored. It becomes part of the Parquet file name so people
            exploring the directory tree have some idea of what is in it.
        step: History step or "summary".

    Returns:
        A dict representing the DataFrame that we can store in summaries or
        histories, with these keys:
        {
            'id': 'asdf',
                # ID generated for this data frame.
            '_type': 'data-frame',
                # Magic field that marks this object as a data frame as
                # opposed to a normal dictionary or anything else.
            'format': 'parquet',
                # The file format in which the data frame is stored.
            'project': 'wfeas',
                # (Current) name of the project that this Run is in; the
                # project ID is not available here.
            'entity': ...,
                # run.entity
            'run': ...,
                # run.id
            'path': '.../sdlk.parquet',
                # Path of the Parquet file that was written.
        }

    Raises:
        wandb.Error: if pandas or fastparquet cannot be imported.
    """
    pandas = util.get_module("pandas")
    fastparquet = util.get_module("fastparquet")
    if not (pandas and fastparquet):
        raise wandb.Error(
            "Failed to save data frame: unable to import either pandas or fastparquet."
        )

    data_frame_id = util.generate_id()

    # Work on a copy so the user's DataFrame instance is left untouched.
    df = df.copy()

    # Replace every WBValue cell with its JSON text.
    for _column, column_values in df.items():
        for position, cell in enumerate(column_values):
            if not isinstance(cell, WBValue):
                continue
            column_values.iat[position] = six.text_type(
                json.dumps(val_to_json(run, key, cell, step)))

    def _constant_column(value):
        """Return a Series repeating the text of `value` for every row."""
        return pandas.Series([six.text_type(value)] * len(df.index),
                             index=df.index)

    # We have to call this wandb_run_id because that name is treated specially by
    # our filtering code
    df['wandb_run_id'] = _constant_column(run.name)
    df['wandb_data_frame_id'] = _constant_column(data_frame_id)

    frames_dir = os.path.join(run.dir, DATA_FRAMES_SUBDIR)
    util.mkdir_exists_ok(frames_dir)
    file_name = '{}-{}.parquet'.format(key, data_frame_id)
    path = os.path.join(frames_dir, file_name)
    fastparquet.write(path, df)

    description = {
        'id': data_frame_id,
        '_type': 'data-frame',
        'format': 'parquet',
        'project': run.project_name(),  # we don't have the project ID here
        'entity': run.entity,
        'run': run.id,
        'path': path,
    }
    return description
    def __init__(self, sweep_id_or_config=None, entity=None, project=None):
        """Initialise controller state and, when possible, load the sweep.

        `sweep_id_or_config` may be a sweep id (str), a configuration (dict,
        which is configured and created right away) or None (creation is
        deferred). Raises wandb.Error when `wandb.sweeps` cannot be imported
        and ControllerError for an unsupported argument type or when the sweep
        cannot be read from the backend.
        """
        global wandb_sweeps
        try:
            from wandb.sweeps import sweeps as wandb_sweeps
        except ImportError as e:
            raise wandb.Error("Module load error: " + str(e))

        # Sweep id configured in the constructor.
        self._sweep_id = None

        # --- Configured parameters ---
        self._create = {}  # configuration to be created
        self._custom_search = None  # custom search
        self._custom_stopping = None  # custom stopping
        self._program_function = None  # program function (future jupyter support)

        # --- Updated on every sweep step ---
        self._sweep_obj = None  # raw sweep object (dict of strings)
        self._sweep_config = None  # parsed sweep config (dict)
        self._sweep_metric = None  # sweep metric used to optimize (str or None)
        self._sweep_runs = None  # list of _Run objects
        self._sweep_runs_map = None  # maps run name to run object
        # Scheduler dict (read only from controller): feedback from the server.
        self._scheduler = None
        # Controller dict (write only from controller): commands to the server.
        self._controller = None
        # Controller dict from the previous step.
        self._controller_prev_step = None

        # --- Internal ---
        self._started = False  # whether the sweep has been started
        self._done_scheduling = False  # whether there is more to schedule
        self._defer_sweep_creation = False  # whether the sweep needs to be created
        self._logged = 0  # count of logged lines since last status
        self._laststatus = ""  # last status line printed
        self._log_actions = []  # logged actions for print_actions()
        self._log_debug = []  # logged debug for print_debug()

        # All backend commands use the internal api.
        environ = os.environ
        if entity:
            env.set_entity(entity, env=environ)
        if project:
            env.set_project(project, env=environ)
        self._api = InternalApi(environ=environ)

        # Nothing to load yet: the sweep is created later.
        if sweep_id_or_config is None:
            self._defer_sweep_creation = True
            return

        if isinstance(sweep_id_or_config, str):
            self._sweep_id = sweep_id_or_config
        elif isinstance(sweep_id_or_config, dict):
            self.configure(sweep_id_or_config)
            self._sweep_id = self.create()
        else:
            raise ControllerError("Unhandled sweep controller type")

        fetched = self._sweep_object_read_from_backend()
        if fetched is None:
            raise ControllerError("Can not find sweep")
        self._sweep_obj = fetched
Example #26
0
    def __init__(self,
                 monitor='val_loss',
                 verbose=0,
                 mode='auto',
                 save_weights_only=False,
                 log_weights=False,
                 log_gradients=False,
                 save_model=True,
                 training_data=None,
                 validation_data=None,
                 labels=None,
                 data_type=None,
                 predictions=36,
                 generator=None,
                 input_type=None,
                 output_type=None,
                 log_evaluation=False,
                 validation_steps=None,
                 class_colors=None,
                 log_batch_frequency=None,
                 log_best_prefix="best_",
                 save_graph=True):
        """Store the callback options and pick the metric comparison.

        Most arguments are kept unchanged on the instance under the same
        name. The exceptions visible here:

        - `validation_data` is used as `generator` when it is generator-like
          (per `is_generator_like`), otherwise stored as
          `self.validation_data`.
        - `labels` defaults to a fresh empty list per instance.
        - `predictions` is capped at 100.
        - `input_type` falls back to `data_type`.
        - `class_colors` is converted with `np.array` when given.
        - `mode` must be 'auto', 'min' or 'max'; anything else prints a
          message and falls back to 'auto'. In 'auto' mode a monitor that
          contains 'acc' or starts with 'fmeasure' is maximised, any other
          monitor is minimised.

        Raises:
            wandb.Error: if `wandb.run` is None.
            ValueError: if `training_data` is truthy and not of length two.
        """
        if wandb.run is None:
            raise wandb.Error(
                'You must call wandb.init() before WandbCallback()')

        self.validation_data = None
        # This is kept around for legacy reasons
        if validation_data is not None:
            if is_generator_like(validation_data):
                generator = validation_data
            else:
                self.validation_data = validation_data

        # A list default in the signature would be one object shared by every
        # instance that stores it; build a new list per instance instead.
        self.labels = [] if labels is None else labels
        self.predictions = min(predictions, 100)

        self.monitor = monitor
        self.verbose = verbose
        self.save_weights_only = save_weights_only
        self.save_graph = save_graph

        wandb.save('model-best.h5')
        self.filepath = os.path.join(wandb.run.dir, 'model-best.h5')
        self.save_model = save_model
        self.log_weights = log_weights
        self.log_gradients = log_gradients
        self.training_data = training_data
        self.generator = generator
        self._graph_rendered = False

        self.input_type = input_type or data_type
        self.output_type = output_type
        self.log_evaluation = log_evaluation
        self.validation_steps = validation_steps
        self.class_colors = np.array(
            class_colors) if class_colors is not None else None
        self.log_batch_frequency = log_batch_frequency
        self.log_best_prefix = log_best_prefix

        if self.training_data:
            if len(self.training_data) != 2:
                raise ValueError("training data must be a tuple of length two")

        # From Keras
        if mode not in ['auto', 'min', 'max']:
            print('WandbCallback mode %s is unknown, '
                  'fallback to auto mode.' % (mode))
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = operator.lt
            self.best = float('inf')
        elif mode == 'max':
            self.monitor_op = operator.gt
            self.best = float('-inf')
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = operator.gt
                self.best = float('-inf')
            else:
                self.monitor_op = operator.lt
                self.best = float('inf')
Example #27
0
 def __call__(self, *args, **kwargs):
     """Always raise wandb.Error asking the user to call wandb.init() first."""
     template = 'You must call wandb.init() before {}()'
     raise wandb.Error(template.format(self._name))
Example #28
0
def log_summary(
    model: Union[CatBoostClassifier, CatBoostRegressor],
    log_all_params: bool = True,
    save_model_checkpoint: bool = False,
    log_feature_importance: bool = True,
) -> None:
    """`log_summary` logs useful metrics about a catboost model once training is done.

    Arguments:
        model: a CatBoostClassifier or CatBoostRegressor.
        log_all_params: (boolean) if True (default), log the model hyperparameters as W&B config.
        save_model_checkpoint: (boolean) if True, save the model and upload it as a W&B artifact.
        log_feature_importance: (boolean) if True (default), log feature importance as a W&B bar
            chart using the default setting of `get_feature_importance`.

    Raises:
        wandb.Error: if no run is active or `model` is not one of the two
            supported classes.

    Using this along with `wandb_callback` will:

    - save the hyperparameters as W&B config,
    - log `best_iteration` and `best_score` as `wandb.summary`,
    - save and upload your trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`)
    - log feature importance plot.

    Example:
        ```python
        train_pool = Pool(train[features], label=train['label'], cat_features=cat_features)
        test_pool = Pool(test[features], label=test['label'], cat_features=cat_features)

        model = CatBoostRegressor(
            iterations=100,
            loss_function='Cox',
            eval_metric='Cox',
        )

        model.fit(
            train_pool,
            eval_set=test_pool,
            callbacks=[WandbCallback()],
        )

        log_summary(model)
        ```
    """
    if wandb.run is None:
        raise wandb.Error("You must call `wandb.init()` before `log_summary()`")

    supported_models = (CatBoostClassifier, CatBoostRegressor)
    if not isinstance(model, supported_models):
        raise wandb.Error(
            "Model should be an instance of CatBoostClassifier or CatBoostRegressor"
        )

    with wb_telemetry.context() as telemetry:
        telemetry.feature.catboost_log_summary = True

    # Hyperparameters: needed for the config and for the artifact alias.
    all_params = model.get_all_params()
    if log_all_params:
        wandb.config.update(all_params)

    # Best score and iteration go to the run summary.
    summary = wandb.run.summary
    summary["best_iteration"] = model.get_best_iteration()
    summary["best_score"] = model.get_best_score()

    # Model checkpoint, aliased by whether the best model was kept.
    if save_model_checkpoint:
        if all_params["use_best_model"]:
            checkpoint_aliases = ["best"]
        else:
            checkpoint_aliases = ["last"]
        _checkpoint_artifact(model, aliases=checkpoint_aliases)

    # Feature importance
    if log_feature_importance:
        _log_feature_importance(model)
Example #29
0
 def __setitem__(self, key, value):
     """Always raise wandb.Error asking the user to call wandb.init() first."""
     template = 'You must call wandb.init() before {}["{}"]'
     raise wandb.Error(template.format(self._name, key))
Example #30
0
 def preinit_wrapper(*args, **kwargs):
     """Always raise wandb.Error asking the user to call wandb.init() first."""
     template = "You must call wandb.init() before {}()"
     raise wandb.Error(template.format(name))