Example #1
    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """
        Log the current result of a Trial upon each iteration.
        """
        if trial not in self._trial_experiments:
            self.log_trial_start(trial)
        experiment = self._trial_experiments[trial]
        step = result["training_iteration"]

        config_update = result.pop("config", {}).copy()
        config_update.pop("callbacks", None)  # Remove callbacks
        for k, v in config_update.items():
            if isinstance(v, dict):
                experiment.log_parameters(flatten_dict({k: v}, "/"), step=step)

            else:
                experiment.log_parameter(k, v, step=step)

        other_logs = {}
        metric_logs = {}
        system_logs = {}
        episode_logs = {}

        flat_result = flatten_dict(result, delimiter="/")
        for k, v in flat_result.items():
            if any(self._check_key_name(k, item) for item in self._to_exclude):
                continue

            if any(self._check_key_name(k, item) for item in self._to_other):
                other_logs[k] = v

            elif any(
                    self._check_key_name(k, item) for item in self._to_system):
                system_logs[k] = v

            elif any(
                    self._check_key_name(k, item)
                    for item in self._to_episodes):
                episode_logs[k] = v

            else:
                metric_logs[k] = v

        experiment.log_others(other_logs)
        experiment.log_metrics(metric_logs, step=step)

        for k, v in system_logs.items():
            experiment.log_system_info(k, v)

        for k, v in episode_logs.items():
            experiment.log_curve(k, x=range(len(v)), y=v, step=step)
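For reference: every example on this page relies on Ray Tune's flatten_dict helper (ray.tune.utils.flatten_dict). The sketch below is an illustrative reimplementation of its core behavior, not Ray's actual code, which additionally supports options such as prevent_delimiter (used in Examples #3 and #19).

def flatten_dict_sketch(dt, delimiter="/"):
    """Flatten nested dicts by joining keys with `delimiter`."""
    flat = {}
    for key, value in dt.items():
        if isinstance(value, dict):
            # Recurse into nested dicts and prefix their keys.
            for sub_key, sub_value in flatten_dict_sketch(value, delimiter).items():
                flat[key + delimiter + sub_key] = sub_value
        else:
            flat[key] = value
    return flat

assert flatten_dict_sketch({"a": {"b": 1}, "c": 2}) == {"a/b": 1, "c": 2}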
Example #2
def _get_trial_info(trial, parameters, metrics):
    """Returns the following information about a trial:

    name | status | loc | params... | metrics...

    Args:
        trial (Trial): Trial to get information for.
        parameters (list[str]): Names of trial parameters to include.
        metrics (list[str]): Names of metrics to include.
    """
    result = flatten_dict(trial.last_result)
    config = flatten_dict(trial.config)
    trial_info = [str(trial), trial.status, str(trial.location)]
    trial_info += [config.get(param) for param in parameters]
    trial_info += [result.get(metric) for metric in metrics]
    return trial_info
Example #3
    def convert_search_space(spec: Dict) -> Dict:
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to a BayesOpt search space.")

        def resolve_value(domain: Domain) -> Tuple[float, float]:
            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                logger.warning(
                    "BayesOpt search does not support quantization. "
                    "Dropped quantization.")
                sampler = sampler.get_sampler()

            if isinstance(domain, Float):
                if domain.sampler is not None:
                    logger.warning(
                        "BayesOpt does not support specific sampling methods. "
                        "The {} sampler will be dropped.".format(sampler))
                return (domain.lower, domain.upper)

            raise ValueError("BayesOpt does not support parameters of type "
                             "`{}`".format(type(domain).__name__))

        # Parameter name is e.g. "a/b/c" for nested dicts
        bounds = {
            "/".join(path): resolve_value(domain)
            for path, domain in domain_vars
        }

        return bounds
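A hypothetical usage sketch for the converter above, assuming Ray Tune's tune.uniform Float domains and that the function is callable as shown (in Ray it is a static method on the BayesOpt searcher):

from ray import tune

spec = {
    "lr": tune.uniform(1e-4, 1e-1),
    "model": {"dropout": tune.uniform(0.0, 0.5)},
}
bounds = convert_search_space(spec)
# -> {"lr": (0.0001, 0.1), "model/dropout": (0.0, 0.5)}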
Example #4
    def on_result(self, result):
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        tmp = result.copy()
        for k in [
                "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
        ]:
            if k in tmp:
                del tmp[k]  # not useful to log these

        flat_result = flatten_dict(tmp, delimiter="/")
        path = ["ray", "tune"]
        valid_result = {}
        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if type(value) in VALID_SUMMARY_TYPES:
                valid_result[full_attr] = value
                self._file_writer.add_scalar(
                    full_attr, value, global_step=step)
            elif type(value) is list and len(value) > 0:
                valid_result[full_attr] = value
                self._file_writer.add_histogram(
                    full_attr, value, global_step=step)

        self.last_result = valid_result
        self._file_writer.flush()
Example #5
 def update_last_result(self, result, terminate=False):
     result.update(trial_id=self.trial_id, done=terminate)
     if self.experiment_tag:
         result.update(experiment_tag=self.experiment_tag)
     if self.verbose and (terminate or time.time() - self.last_debug >
                          DEBUG_PRINT_INTERVAL):
         print("Result for {}:".format(self))
         print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
         self.last_debug = time.time()
     self.set_location(Location(result.get("node_ip"), result.get("pid")))
     self.last_result = result
     self.last_update_time = time.time()
     self.result_logger.on_result(self.last_result)
     for metric, value in flatten_dict(result).items():
         if isinstance(value, Number):
             if metric not in self.metric_analysis:
                 self.metric_analysis[metric] = {
                     "max": value,
                     "min": value,
                     "avg": value,
                     "last": value
                 }
             else:
                 step = result["training_iteration"] or 1
                 self.metric_analysis[metric]["max"] = max(
                     value, self.metric_analysis[metric]["max"])
                 self.metric_analysis[metric]["min"] = min(
                     value, self.metric_analysis[metric]["min"])
                 self.metric_analysis[metric]["avg"] = 1 / step * (
                     value +
                     (step - 1) * self.metric_analysis[metric]["avg"])
                 self.metric_analysis[metric]["last"] = value
Example #6
    def on_result(self, result):
        experiment_tag = result.get('experiment_tag', 'no_experiment_tag')
        experiment_id = result.get('experiment_id', 'no_experiment_id')

        if experiment_tag not in self.metrics_queue_dict:
            print("=" * 50)
            print("Setting up new w&b logger")
            print("Experiment tag:", experiment_tag)
            print("Experiment id:", experiment_id)
            config = result.get("config")
            queue = multiprocessing.Queue()
            p = multiprocessing.Process(target=wandb_process,
                                        args=(
                                            queue,
                                            config,
                                        ))
            p.start()
            self.metrics_queue_dict[experiment_tag] = queue
            print("=" * 50)

        queue = self.metrics_queue_dict[experiment_tag]

        tmp = result.copy()
        for k in ["done", "config", "pid", "timestamp"]:
            if k in tmp:
                del tmp[k]

        metrics = {}
        for key, value in flatten_dict(tmp, delimiter="/").items():
            if not isinstance(value, numbers.Number):
                continue
            metrics[key] = value

        queue.put(metrics)
Example #7
def _parse_configs(cfg_path):
    cfg_dict = {}
    try:
        with open(cfg_path) as f:
            cfg_dict = flatten_dict(json.load(f))
    except Exception:
        logger.exception("Config parsing failed.")
    return cfg_dict
Example #8
 def _log_hparams(self):
     if hasattr(self, 'trial'):
         if self.trial and self.trial.evaluated_params:
             ep = flatten_dict(self.trial.evaluated_params, '/')
             ep = {format_keys(p): v for p, v in ep.items()}
             for key, value in ep.items():
                 self.client.log_param(self._run_id, key, value)
Example #9
    def create_trial_if_possible(self, experiment_spec: Dict,
                                 output_path: str) -> Optional[Trial]:
        logger.debug("creating trial")
        trial_id = Trial.generate_id()
        suggested_config = self.searcher.suggest(trial_id)
        if suggested_config == Searcher.FINISHED:
            self._finished = True
            logger.debug("Searcher has finished.")
            return

        if suggested_config is None:
            return
        spec = copy.deepcopy(experiment_spec)
        spec["config"] = merge_dicts(spec["config"],
                                     copy.deepcopy(suggested_config))

        # Create a new trial_id if duplicate trial is created
        flattened_config = resolve_nested_dict(spec["config"])
        self._counter += 1
        tag = "{0}_{1}".format(str(self._counter),
                               format_vars(flattened_config))
        trial = create_trial_from_spec(
            spec,
            output_path,
            self._parser,
            evaluated_params=flatten_dict(suggested_config),
            experiment_tag=tag,
            trial_id=trial_id)
        return trial
Example #10
    def _generate_trials(self, experiment_spec, output_path=""):
        """Generates trials with configurations from `_suggest`.

        Creates a trial_id that is passed into `_suggest`.

        Yields:
            Trial objects constructed according to `spec`
        """
        if "run" not in experiment_spec:
            raise TuneError("Must specify `run` in {}".format(experiment_spec))
        for _ in range(experiment_spec.get("num_samples", 1)):
            trial_id = Trial.generate_id()
            while True:
                suggested_config = self._suggest(trial_id)
                if suggested_config is None:
                    yield None
                else:
                    break
            spec = copy.deepcopy(experiment_spec)
            spec["config"] = merge_dicts(spec["config"],
                                         copy.deepcopy(suggested_config))
            flattened_config = resolve_nested_dict(spec["config"])
            self._counter += 1
            tag = "{0}_{1}".format(str(self._counter),
                                   format_vars(flattened_config))
            yield create_trial_from_spec(
                spec,
                output_path,
                self._parser,
                evaluated_params=flatten_dict(suggested_config),
                experiment_tag=tag,
                trial_id=trial_id)
Example #11
        def on_result(self, result):
            step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
            tmp = result.copy()
            for k in [
                    "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
            ]:
                if k in tmp:
                    del tmp[k]  # not useful to log these

            # log system usage
            perf = result.get("perf", None)
            if FilteredTBXLogger.log_sys_usage and perf is not None:
                self.log_system_usage(step, perf)

            flat_result = flatten_dict(tmp, delimiter="/")
            path = ["scalars"]
            valid_result = {
                "/".join(path + [attr]): value
                for attr, value in flat_result.items()
                if type(value) in VALID_SUMMARY_TYPES 
                and attr in FilteredTBXLogger.keep_fields
            }

            # log scalars 
            for attr, value in valid_result.items():
                self._file_writer.add_scalar(attr, value, global_step=step)

            # log videos 
            videos = result.get("eval_frames", [])
            if FilteredTBXLogger.log_videos and len(videos) > 0:
                self.log_videos(step, videos, "rollout_frames")

            self.last_result = valid_result
            self._file_writer.flush()
Example #12
    def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
        """Returns a list of all configurations.

        Args:
            prefix: If True, flattens the config dict
                and prepends `config/`.

        Returns:
            Dict[str, Dict]: Dict of all configurations of trials, indexed by
                their trial dir.
        """
        fail_count = 0
        for path in self._get_trial_paths():
            try:
                with open(os.path.join(path, EXPR_PARAM_FILE)) as f:
                    config = json.load(f)
                if prefix:
                    self._configs[path] = flatten_dict({CONFIG_PREFIX: config})
                else:
                    self._configs[path] = config
            except Exception:
                fail_count += 1

        if fail_count:
            logger.warning(
                "Couldn't read config from {} paths".format(fail_count))
        return self._configs
Example #13
    def on_result(self, result):
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        # Log scalars
        logged_results = [
            'episode_reward_max', 'episode_reward_mean', 'episode_reward_min',
            'episode_len_mean', 'custom_metrics', 'sampler_perf', 'info', 'perf'
        ]
        result_copy = result.copy()
        for key, val in result.items():
            if key not in logged_results:
                del result_copy[key]
        flat_result = flatten_dict(result_copy, delimiter="/")
        self.wandb_run.log(flat_result, step=step, sync=False)

        # Log histograms
        for key, val in result['hist_stats'].items():
            try:
                if key != '_robot_coordinates':
                    self.wandb_run.log({"Histograms/"+key: wandb.Histogram(val)}, step=step, sync=False)
            except ValueError:
                logger.warning("Unable to log histogram for {}".format(key))

        # Log trajectories
        traj_fig = plot_trajectories(result['hist_stats']['_robot_coordinates'])
        traj_fig.savefig("Trajectory.png")
        self.wandb_run.log({'Episode Trajectories': wandb.Image(traj_fig)}, step=step, sync=False)
        plt.close(traj_fig)
Example #14
    def _try_log_hparams(self, trial: "Trial", result: Dict):
        # TBX currently errors if the hparams value is None.
        flat_params = flatten_dict(trial.evaluated_params)
        scrubbed_params = {
            k: v
            for k, v in flat_params.items()
            if isinstance(v, self.VALID_HPARAMS)
        }

        removed = {
            k: v
            for k, v in flat_params.items()
            if not isinstance(v, self.VALID_HPARAMS)
        }
        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to tensorboard: %s", str(removed))

        from tensorboardX.summary import hparams
        try:
            experiment_tag, session_start_tag, session_end_tag = hparams(
                hparam_dict=scrubbed_params, metric_dict=result)
            self._trial_writer[trial].file_writer.add_summary(experiment_tag)
            self._trial_writer[trial].file_writer.add_summary(
                session_start_tag)
            self._trial_writer[trial].file_writer.add_summary(session_end_tag)
        except Exception:
            logger.exception("TensorboardX failed to log hparams. "
                             "This may be due to an unsupported type "
                             "in the hyperparameter values.")
Example #15
def _dict_hash(config, precision):
    flatconfig = flatten_dict(config)
    for param, value in flatconfig.items():
        if isinstance(value, float):
            flatconfig[param] = "{:.{digits}f}".format(value, digits=precision)

    hashed = json.dumps(flatconfig, sort_keys=True, default=str)
    return hashed
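A hypothetical usage note, assuming the flatten_dict helper is in scope: two configs that differ only below the requested precision hash to the same string, so they can be deduplicated.

assert _dict_hash({"lr": 0.10004}, 3) == _dict_hash({"lr": 0.1}, 3)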
Example #16
def to_tf_values(result, path):
    flat_result = flatten_dict(result, delimiter="/")
    values = [
        tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
        for attr, value in flat_result.items()
        if type(value) in VALID_SUMMARY_TYPES
    ]
    return values
Example #17
def get_result(root, friction_list, num_episodes=20, select_key=None):
    results = []

    root = os.path.abspath(root)
    a_count = 0
    for p in os.listdir(root):
        trial = osp.join(root, p)
        if not osp.isdir(trial):
            continue

        if select_key is not None:
            if select_key not in trial:
                print(
                    "We filter out the trial: {} since it does not contain keyword: {}."
                    .format(trial, select_key))
                continue

        assert p.startswith("PPO")
        a_count += 1

        exps = [pp for pp in os.listdir(trial) if pp.startswith("checkpoint")]
        exps.sort(key=lambda v: int(v.split("_")[1]))
        if not exps:
            print("Empty!")
        assert exps

        # Largest checkpoint index
        ckpt = osp.join(trial, exps[-1], exps[-1].replace("_", "-"))

        config_file = osp.join(trial, "params.json")
        with open(config_file, "r") as f:
            config = json.load(f)
            config = flatten_dict(config)

        for friction in friction_list:
            trainer = get_trainer(friction, ckpt)

            print("\n===== Start Evaluating {} agent =====".format(a_count))
            print("friction: {}\nCheckpoint: {}\n".format(friction, ckpt))

            # evaluate
            ret = evaluate(trainer, num_episodes)

            results.append(
                dict(path=ckpt,
                     trial=trial,
                     friction=friction,
                     **ret,
                     **config))

            trainer.cleanup()
            del trainer

            print("Finish {} agents.".format(a_count))

    ret = pd.DataFrame(results)
    return ret
Example #18
 def results_df(self) -> DataFrame:
     if not pd:
         raise ValueError("`results_df` requires pandas. Install with "
                          "`pip install pandas`.")
     return pd.DataFrame.from_records([
         flatten_dict(trial.last_result, delimiter=".")
         for trial in self.trials
     ],
                                      index="trial_id")
Example #19
    def convert_search_space(spec: Dict, join: bool = False) -> Dict:
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to a SkOpt search space.")

        # Flatten and resolve again after checking for grid search.
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        def resolve_value(domain: Domain) -> Union[Tuple, List]:
            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                logger.warning("SkOpt search does not support quantization. "
                               "Dropped quantization.")
                sampler = sampler.get_sampler()

            if isinstance(domain, Float):
                if isinstance(domain.sampler, LogUniform):
                    return sko.space.Real(domain.lower,
                                          domain.upper,
                                          prior="log-uniform")
                return sko.space.Real(domain.lower,
                                      domain.upper,
                                      prior="uniform")

            elif isinstance(domain, Integer):
                if isinstance(domain.sampler, LogUniform):
                    return sko.space.Integer(domain.lower,
                                             domain.upper - 1,
                                             prior="log-uniform")
                return sko.space.Integer(domain.lower,
                                         domain.upper - 1,
                                         prior="uniform")

            elif isinstance(domain, Categorical):
                return sko.space.Categorical(domain.categories)

            raise ValueError("SkOpt does not support parameters of type "
                             "`{}` with samplers of type `{}`".format(
                                 type(domain).__name__,
                                 type(domain.sampler).__name__))

        # Parameter name is e.g. "a/b/c" for nested dicts
        space = {
            "/".join(path): resolve_value(domain)
            for path, domain in domain_vars
        }

        if join:
            spec.update(space)
            space = spec

        return space
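A hypothetical call against the converter above, assuming Ray Tune's domain API and scikit-optimize imported as sko. Note that tune.randint's exclusive upper bound becomes an inclusive skopt bound through the upper - 1 adjustment:

from ray import tune

spec = {"lr": tune.loguniform(1e-4, 1e-1), "layers": tune.randint(1, 5)}
space = convert_search_space(spec)
# -> {"lr": sko.space.Real(0.0001, 0.1, prior="log-uniform"),
#     "layers": sko.space.Integer(1, 4, prior="uniform")}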
Example #20
    def on_result(self, result):
        """
        The following is copied from the parent class; however, non-serializable
        config values are saved as their reprs so that they are all yaml
        serializable. See for details:
            - https://github.com/wandb/client/issues/586
        """

        config = deepcopy(result.get("config"))
        if config and self._config is None:
            for k in config.keys():
                if wandb.config.get(k) is None:
                    s = repr(config[k])
                    try:
                        ast.literal_eval(s)
                        wandb.config[k] = config[k]
                    except (ValueError, SyntaxError):
                        # Non-serializable
                        wandb.config[k] = s
            self._config = config

        tmp = result.copy()
        for k in ["done", "config", "pid", "timestamp"]:
            if k in tmp:
                del tmp[k]

        if self.result_to_time_series_fn is not None:
            assert self._config is not None
            time_series_dict = self.result_to_time_series_fn(tmp, self._config)
            for t, d in sorted(time_series_dict.items(), key=lambda x: x[0]):
                metrics = {}
                for key, value in flatten_dict(d, delimiter="/").items():
                    if not isinstance(value, self.accepted_types):
                        continue
                    metrics[key] = value
                wandb.log(metrics, step=t)
        else:
            metrics = {}
            for key, value in flatten_dict(tmp, delimiter="/").items():
                if not isinstance(value, self.accepted_types):
                    continue
                metrics[key] = value
            wandb.log(metrics)
Example #21
 def results_df(self) -> DataFrame:
     """Get all the last results as a pandas dataframe."""
     if not pd:
         raise ValueError("`results_df` requires pandas. Install with "
                          "`pip install pandas`.")
     return pd.DataFrame.from_records([
         flatten_dict(trial.last_result, delimiter=self._delimiter())
         for trial in self.trials
     ],
                                      index="trial_id")
Example #22
def _parse_results(res_path):
    res_dict = {}
    try:
        with open(res_path) as f:
            # Get last line in file
            for line in f:
                pass
        res_dict = flatten_dict(json.loads(line.strip()))
    except Exception:
        logger.exception("Importing %s failed...Perhaps empty?" % res_path)
    return res_dict
Example #23
 def close(self):
     if self._file_writer is not None:
         if self.trial and self.trial.evaluated_params and self.last_result:
             flat_result = flatten_dict(self.last_result, delimiter="/")
             scrubbed_result = {
                 k: value
                 for k, value in flat_result.items()
                 if isinstance(value, tuple(VALID_SUMMARY_TYPES))
             }
             self._try_log_hparams(scrubbed_result)
         self._file_writer.close()
Example #24
    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        if trial not in self._trial_writer:
            self.log_trial_start(trial)

        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        tmp = result.copy()
        for k in [
                "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
        ]:
            if k in tmp:
                del tmp[k]  # not useful to log these

        flat_result = flatten_dict(tmp, delimiter="/")
        path = ["ray", "tune"]
        valid_result = {}

        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if isinstance(value,
                          tuple(VALID_SUMMARY_TYPES)) and not np.isnan(value):
                valid_result[full_attr] = value
                self._trial_writer[trial].add_scalar(full_attr,
                                                     value,
                                                     global_step=step)
            elif (isinstance(value, list)
                  and len(value) > 0) or (isinstance(value, np.ndarray)
                                          and value.size > 0):
                valid_result[full_attr] = value

                # Must be video
                if isinstance(value, np.ndarray) and value.ndim == 5:
                    self._trial_writer[trial].add_video(full_attr,
                                                        value,
                                                        global_step=step,
                                                        fps=20)
                    continue

                try:
                    self._trial_writer[trial].add_histogram(full_attr,
                                                            value,
                                                            global_step=step)
                # In case TensorboardX still doesn't think it's a valid value
                # (e.g. `[[]]`), warn and move on.
                except (ValueError, TypeError):
                    if log_once("invalid_tbx_value"):
                        logger.warning(
                            "You are trying to log an invalid value ({}={}) "
                            "via {}!".format(full_attr, value,
                                             type(self).__name__))

        self._trial_result[trial] = valid_result
        self._trial_writer[trial].flush()
Example #25
 def log_trial_end(self, trial: "Trial", failed: bool = False):
     if trial in self._trial_writer:
         if trial and trial.evaluated_params and self._trial_result[trial]:
             flat_result = flatten_dict(self._trial_result[trial], delimiter="/")
             scrubbed_result = {
                 k: value
                 for k, value in flat_result.items()
                 if isinstance(value, tuple(VALID_SUMMARY_TYPES))
             }
             self._try_log_hparams(trial, scrubbed_result)
         self._trial_writer[trial].close()
         del self._trial_writer[trial]
         del self._trial_result[trial]
Example #26
 def on_result(self, result: Dict):
     tmp = result.copy()
     if "config" in tmp:
         del tmp["config"]
     result = flatten_dict(tmp, delimiter="/")
     if self._csv_out is None:
         self._csv_out = csv.DictWriter(self._file, result.keys())
         if not self._continuing:
             self._csv_out.writeheader()
     self._csv_out.writerow(
         {k: v
          for k, v in result.items() if k in self._csv_out.fieldnames})
     self._file.flush()
Example #27
 def log_result(result: dict):
     # Avoid logging the config every iteration
     # Only log Jsonable objects
     tmp = result.copy()
     for k in ["done", "config", "pid", "timestamp"]:
         if k in tmp:
             del tmp[k]
     metrics = {}
     for key, value in flatten_dict(tmp, delimiter="/").items():
         if not isinstance(value, numbers.Number):
             continue
         metrics[key] = value
     wandb.log(metrics)
Example #28
 def _try_log_hparams(self, result):
     # TBX currently errors if the hparams value is None.
     flat_params = flatten_dict(self.trial.evaluated_params)
     scrubbed_params = {
         k: v
         for k, v in flat_params.items() if v is not None
     }
     from tensorboardX.summary import hparams
     experiment_tag, session_start_tag, session_end_tag = hparams(
         hparam_dict=scrubbed_params, metric_dict=result)
     self._file_writer.file_writer.add_summary(experiment_tag)
     self._file_writer.file_writer.add_summary(session_start_tag)
     self._file_writer.file_writer.add_summary(session_end_tag)
Example #29
    def __init__(self, config, logdir, trial):
        super(WeightsAndBiasesLogger, self).__init__(config, logdir, trial)
        # logger.warning("WeightsAndBiasesLogger.__init__() called! Trial.experiment_tag: {}".format(trial.experiment_tag))

        self.trial = trial
        self.experiment_tag = trial.experiment_tag
        self.wandb_run = wandb.init(project=weights_and_biases_project,
                                    name=config['env_config']['experiment_name'] + '_' + trial.experiment_tag,
                                    reinit=True)
        valid_config = config.copy()
        del valid_config['callbacks']
        valid_config = flatten_dict(valid_config, delimiter="/")
        self.wandb_run.config.update(valid_config, allow_val_change=True)
Example #30
 def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict:
     result_dict = flatten_dict(evals_log, delimiter="-")
     if not self._metrics:
         report_dict = result_dict
     else:
         report_dict = {}
         for key in self._metrics:
             if isinstance(self._metrics, dict):
                 metric = self._metrics[key]
             else:
                 metric = key
             report_dict[key] = result_dict[metric]
     return report_dict