def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str],
                    max_column_length: int = 20):
    """Build one progress-table row for a trial.

    Row layout: name | status | loc | params... | metrics...

    Args:
        trial: Trial to get information for.
        parameters: Names of trial parameters to include (read from
            ``trial.config``).
        metrics: Names of metrics to include (read from
            ``trial.last_result``).
        max_column_length: Maximum column length (in characters) passed to
            ``_max_len`` for every parameter/metric cell.

    Returns:
        List of cell values for the row.
    """
    last_result = trial.last_result
    trial_config = trial.config
    trial_location = _get_trial_location(trial, last_result)

    def _cell(key, source):
        # Lookups fall back to ``default=None`` rather than raising.
        value = unflattened_lookup(key, source, default=None)
        return _max_len(value, max_len=max_column_length, add_addr=True)

    row = [str(trial), trial.status, str(trial_location)]
    row.extend(_cell(name, trial_config) for name in parameters)
    row.extend(_cell(name, last_result) for name in metrics)
    return row
def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]):
    """
    Collect the parameter and metric values of a trial as one table row.

    Row layout: params... | metrics...

    @param trial: Trial whose ``config`` and ``last_result`` are read.
    @param parameters: Names of trial parameters to include (looked up in
        ``trial.config``).
    @param metrics: Names of metrics to include (looked up in
        ``trial.last_result``).
    @return: List of column values; every lookup uses ``default=None``.
    """
    last_result = trial.last_result
    trial_config = trial.config

    def _lookup_all(names, source):
        return [unflattened_lookup(name, source, default=None)
                for name in names]

    # Parameters first, then metrics, matching the header order.
    return _lookup_all(parameters, trial_config) + _lookup_all(metrics,
                                                               last_result)
def trial_progress_table(trials: List[Trial], metric: str,
                         metric_columns: Union[List[str], Dict[str, str]],
                         parameter_columns: Union[None, List[str],
                                                  Dict[str, str]] = None,
                         fmt: str = "psql",
                         max_rows: Optional[int] = None):
    """
    Create table view for trials.

    The header row lists the parameter columns followed by the metric
    columns.

    @param trials: List of trials to get progress table string for.
    @param metric: Metric to use; passed to ``_get_trials_by_order`` to
        select/order the rows.
    @param metric_columns: Names of metrics to include.
        If this is a dict, the keys are metric names and the values are
        the names to use in the message. If this is a list, the metric
        name is used in the message directly. Metrics that no shown trial
        has reported are dropped.
    @param parameter_columns: Names of parameters to include.
        If this is a dict, the keys are parameter names and the values are
        the names to use in the message. If this is a list, the parameter
        name is used in the message directly. If this is empty (including
        an empty dict) or None, all evaluated parameters are used.
    @param fmt: Output format (see tablefmt in tabulate API).
    @param max_rows: Maximum number of rows in the trial table. Defaults to
        unlimited (any falsy value means unlimited).
    @return: List of messages: the rendered table, followed by a
        "... N more trials not shown" note when ``len(trials) > max_rows``.
    """
    messages = []
    num_trials = len(trials)
    max_rows = max_rows or float("inf")

    # Both the truncated and the untruncated case use the same call; only
    # the overflow count differs.
    # NOTE(review): presumably _get_trials_by_order orders by ``metric`` and
    # keeps at most ``max_rows`` trials -- confirm against its definition.
    trials = _get_trials_by_order(trials, metric, max_rows)
    overflow = num_trials - max_rows if num_trials > max_rows else False

    # Metric keys and their display headers.
    if isinstance(metric_columns, Mapping):
        metric_keys = list(metric_columns.keys())
    else:
        metric_keys = metric_columns
    # Drop metrics that none of the shown trials has reported.
    metric_keys = [
        k for k in metric_keys
        if any(unflattened_lookup(k, t.last_result, default=None) is not None
               for t in trials)
    ]
    if isinstance(metric_columns, Mapping):
        formatted_metric_columns = [metric_columns[k] for k in metric_keys]
    else:
        formatted_metric_columns = metric_keys

    # Parameter keys and headers are chosen together: an empty mapping falls
    # back to all evaluated params and must not then be indexed as a
    # name->header dict (that previously raised KeyError).
    if not parameter_columns:
        parameter_keys = sorted(
            set().union(*[t.evaluated_params for t in trials]))
        formatted_parameter_columns = parameter_keys
    elif isinstance(parameter_columns, Mapping):
        parameter_keys = list(parameter_columns.keys())
        formatted_parameter_columns = [
            parameter_columns[k] for k in parameter_keys
        ]
    else:
        parameter_keys = parameter_columns
        formatted_parameter_columns = parameter_keys

    trial_table = [
        _get_trial_info(trial, parameter_keys, metric_keys)
        for trial in trials
    ]
    columns = formatted_parameter_columns + formatted_metric_columns

    messages.append(
        tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
    if overflow:
        messages.append("... {} more trials not shown".format(overflow))
    return messages
def _get_trial_info(trial, parameters, metrics):
    """Returns the following information about a trial:

    name | status | loc | params... | metrics...

    Args:
        trial (Trial): Trial to get information for.
        parameters (list[str]): Names of trial parameters to include.
        metrics (list[str]): Names of metrics to include.

    Returns:
        list: Cell values for the row. Parameter and metric lookups use
        ``default=None``, so a key missing for this particular trial yields
        ``None`` instead of aborting the table.
    """
    result = trial.last_result
    config = trial.config
    trial_info = [str(trial), trial.status, str(trial.location)]
    # trial_progress_table passes the union of all trials' evaluated params
    # and keeps a metric if *any* trial reported it, so an individual trial
    # may lack some keys. Pass default=None like the other row builders do.
    trial_info += [
        unflattened_lookup(param, config, default=None) for param in parameters
    ]
    trial_info += [
        unflattened_lookup(metric, result, default=None) for metric in metrics
    ]
    return trial_info
def best_trial_str(
        trial: Trial,
        metric: str,
        parameter_columns: Union[None, List[str], Dict[str, str]] = None):
    """Returns a readable message stating the current best trial.

    Args:
        trial: Trial to describe.
        metric: Key in ``trial.last_result`` whose value is reported.
        parameter_columns: Parameters to report. A mapping contributes only
            its keys. If empty or None, every key of the config stored under
            ``trial.last_result["config"]`` is reported.

    Returns:
        One-line summary with the trial id, metric value and parameters.
    """
    metric_value = trial.last_result[metric]
    result_config = trial.last_result.get("config", {})

    if not parameter_columns:
        names = list(result_config.keys())
    elif isinstance(parameter_columns, Mapping):
        names = parameter_columns.keys()
    else:
        names = parameter_columns

    params = {}
    for name in names:
        params[name] = unflattened_lookup(name, result_config)

    return (f"Current best trial: {trial.trial_id} with {metric}={metric_value} and "
            f"parameters={params}")
def trial_progress_table(
        trials: List[Trial],
        metric_columns: Union[List[str], Dict[str, str]],
        parameter_columns: Union[None, List[str], Dict[str, str]] = None,
        fmt: str = "psql",
        max_rows: Optional[int] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False):
    """
    Create table view for trials, grouped by trial state.

    Rows are emitted in state order RUNNING, PAUSED, PENDING, TERMINATED,
    ERROR; trials in any other state are not shown. Columns are
    ``Trial name | status | loc | params... | metrics...``.

    @param trials: List of trials to get progress table string for.
    @param metric_columns: Names of metrics to include. If this is a dict,
        the keys are metric names and the values are the column headers.
        If this is a list, the metric name is used as header directly.
        Metrics that no shown trial has reported are dropped.
    @param parameter_columns: Names of parameters to include, as a dict
        (name -> header) or a list. If empty (including an empty dict) or
        None, the union of all trials' evaluated parameters is used.
    @param fmt: Output format (see tablefmt in tabulate API).
    @param max_rows: Maximum number of rows in the trial table. Defaults to
        unlimited (any falsy value means unlimited).
    @param metric: Metric used to sort terminated trials when
        ``sort_by_metric`` is True.
    @param mode: If "max", terminated trials are sorted descending by
        ``metric``; otherwise ascending.
    @param sort_by_metric: Whether to sort terminated trials by ``metric``.
    @return: List of messages: the rendered table, followed by a
        "... N more trials not shown (...)" note if rows were truncated.
    """
    messages = []
    num_trials = len(trials)
    trials_by_state = _get_trials_by_state(trials)

    # Sort terminated trials by metric and mode, descending if mode is "max".
    # NOTE(review): raises KeyError if a terminated trial's last_result lacks
    # ``metric``.
    if sort_by_metric:
        trials_by_state[Trial.TERMINATED] = sorted(
            trials_by_state[Trial.TERMINATED],
            reverse=(mode == "max"),
            key=lambda t: t.last_result[metric])

    state_tbl_order = [
        Trial.RUNNING, Trial.PAUSED, Trial.PENDING, Trial.TERMINATED,
        Trial.ERROR
    ]
    max_rows = max_rows or float("inf")
    if num_trials > max_rows:
        # TODO(ujvl): suggestion for users to view more rows.
        trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows,
                                                    sort_by_metric)
        trials = []
        overflow_strs = []
        for state in state_tbl_order:
            if state not in trials_by_state:
                continue
            trials += trials_by_state_trunc[state]
            # Count how many trials of this state were cut.
            num = len(trials_by_state[state]) - len(
                trials_by_state_trunc[state])
            if num > 0:
                overflow_strs.append("{} {}".format(num, state))
        # Build overflow string.
        overflow = num_trials - max_rows
        overflow_str = ", ".join(overflow_strs)
    else:
        overflow = False
        overflow_str = ""
        trials = []
        for state in state_tbl_order:
            if state not in trials_by_state:
                continue
            trials += trials_by_state[state]

    # Pre-process trials to figure out what columns to show.
    if isinstance(metric_columns, Mapping):
        metric_keys = list(metric_columns.keys())
    else:
        metric_keys = metric_columns
    # Drop metrics that none of the shown trials has reported.
    metric_keys = [
        k for k in metric_keys if any(
            unflattened_lookup(k, t.last_result, default=None) is not None
            for t in trials)
    ]
    if isinstance(metric_columns, Mapping):
        formatted_metric_columns = [metric_columns[k] for k in metric_keys]
    else:
        formatted_metric_columns = metric_keys

    # Parameter keys and headers are chosen together: an empty mapping falls
    # back to all evaluated params and must not then be indexed as a
    # name->header dict (that previously raised KeyError).
    if not parameter_columns:
        parameter_keys = sorted(
            set().union(*[t.evaluated_params for t in trials]))
        formatted_parameter_columns = parameter_keys
    elif isinstance(parameter_columns, Mapping):
        parameter_keys = list(parameter_columns.keys())
        formatted_parameter_columns = [
            parameter_columns[k] for k in parameter_keys
        ]
    else:
        parameter_keys = parameter_columns
        formatted_parameter_columns = parameter_keys

    # Build trial rows.
    trial_table = [
        _get_trial_info(trial, parameter_keys, metric_keys)
        for trial in trials
    ]

    columns = (["Trial name", "status", "loc"] + formatted_parameter_columns +
               formatted_metric_columns)

    # Tabulate.
    messages.append(
        tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
    if overflow:
        messages.append("... {} more trials not shown ({})".format(
            overflow, overflow_str))
    return messages