Example #1
0
def _get_trial_info(trial: Trial,
                    parameters: List[str],
                    metrics: List[str],
                    max_column_length: int = 20):
    """Collect one table row describing ``trial``.

    The row is laid out as::

        name | status | loc | params... | metrics...

    Parameter values are looked up in ``trial.config`` and metric values in
    ``trial.last_result``, both via ``unflattened_lookup(..., default=None)``.
    Every looked-up value is then passed through ``_max_len`` with
    ``max_len=max_column_length`` and ``add_addr=True``.

    Args:
        trial: Trial to describe.
        parameters: Names of trial parameters to include, in column order.
        metrics: Names of metrics to include, in column order.
        max_column_length: Maximum column length (in characters), forwarded
            to ``_max_len``.

    Returns:
        A list of column values: the trial's string form, its status and its
        location, followed by one entry per parameter and one per metric.
    """
    last_result = trial.last_result
    trial_config = trial.config
    where = _get_trial_location(trial, last_result)

    def _cell(key, source):
        # Resolve a (possibly nested) key, then hand it to the column clipper.
        raw = unflattened_lookup(key, source, default=None)
        return _max_len(raw, max_len=max_column_length, add_addr=True)

    row = [str(trial), trial.status, str(where)]
    row.extend(_cell(name, trial_config) for name in parameters)
    row.extend(_cell(name, last_result) for name in metrics)
    return row
Example #2
0
def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]):
    """
    Build one table row for a trial, laid out as:
    params... | metrics...
    @param trial: Trial whose config and last result are read.
    @param parameters: Names of trial parameters to include; each is looked up in ``trial.config``.
    @param metrics: Names of metrics to include; each is looked up in ``trial.last_result``.
    @return: List of column values (lookups use ``default=None``).
    """
    last_result, trial_config = trial.last_result, trial.config
    param_cells = [unflattened_lookup(name, trial_config, default=None) for name in parameters]
    metric_cells = [unflattened_lookup(name, last_result, default=None) for name in metrics]
    return param_cells + metric_cells
Example #3
0
def trial_progress_table(trials: List[Trial], metric: str, metric_columns: Union[List[str], Dict[str, str]],
                         parameter_columns: Union[None, List[str], Dict[str, str]] = None, fmt: str = "psql",
                         max_rows: Optional[int] = None):
    """
    Create table view for trials.
    @param trials: List of trials to get progress table string for.
    @param metric: Metric handed to ``_get_trials_by_order`` when ordering the trials.
    @param metric_columns: Names of metrics to include. If this is a dict, the keys are metric names and the values are
    the names to use in the message. If this is a list, the metric name is used in the message directly. Metrics that
    no trial has reported are dropped.
    @param parameter_columns: Names of parameters to include. If this is a dict, the keys are parameter names and the
    values are the names to use in the message. If this is a list, the parameter name is used in the message directly.
    If this is empty (None, an empty list or an empty dict), all evaluated parameters of the trials are used.
    @param fmt: Output format (see tablefmt in tabulate API).
    @param max_rows: Maximum number of rows in the trial table. None or 0 means unlimited.
    @return: List of messages: the rendered table, optionally followed by a "... N more trials not shown" line.
    """
    num_trials = len(trials)

    # A falsy max_rows (None or 0) means "no limit".
    max_rows = max_rows or float("inf")
    # The same ordering call is made whether or not the table overflows; only the overflow count differs.
    trials = _get_trials_by_order(trials, metric, max_rows)
    overflow = num_trials - max_rows if num_trials > max_rows else False

    metric_is_mapping = isinstance(metric_columns, Mapping)
    metric_keys = list(metric_columns.keys()) if metric_is_mapping else metric_columns
    # Keep only metric columns that at least one trial has actually reported.
    metric_keys = [k for k in metric_keys if
                   any(unflattened_lookup(k, t.last_result, default=None) is not None for t in trials)]

    # Only a *non-empty* mapping supplies display names. An empty dict falls back to all evaluated parameters (like
    # None or []), so it must not be used to format the headers below, or parameter_columns[k] raises KeyError.
    parameter_is_mapping = bool(parameter_columns) and isinstance(parameter_columns, Mapping)
    if not parameter_columns:
        parameter_keys = sorted(set().union(*[t.evaluated_params for t in trials]))
    elif parameter_is_mapping:
        parameter_keys = list(parameter_columns.keys())
    else:
        parameter_keys = parameter_columns

    trial_table = [_get_trial_info(trial, parameter_keys, metric_keys) for trial in trials]

    if metric_is_mapping:
        formatted_metric_columns = [metric_columns[k] for k in metric_keys]
    else:
        formatted_metric_columns = metric_keys
    if parameter_is_mapping:
        formatted_parameter_columns = [parameter_columns[k] for k in parameter_keys]
    else:
        formatted_parameter_columns = parameter_keys
    columns = formatted_parameter_columns + formatted_metric_columns

    messages = [tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False)]
    if overflow:
        messages.append("... {} more trials not shown".format(overflow))
    return messages
Example #4
0
def _get_trial_info(trial, parameters, metrics):
    """Returns the following information about a trial:

    name | status | loc | params... | metrics...

    Parameters are looked up in ``trial.config`` and metrics in
    ``trial.last_result`` via ``unflattened_lookup(..., default=None)``.
    Column sets are typically shared across many trials, and a given trial
    may not have every key. Such a trial gets ``None`` in that cell instead
    of aborting the whole table.

    Args:
        trial (Trial): Trial to get information for.
        parameters (list[str]): Names of trial parameters to include.
        metrics (list[str]): Names of metrics to include.

    Returns:
        list: ``[str(trial), trial.status, str(trial.location)]`` followed by
        one value per parameter and one value per metric.
    """
    result = trial.last_result
    config = trial.config
    trial_info = [str(trial), trial.status, str(trial.location)]
    # default=None matches the other row builders and the metric-column
    # filter, which only require *some* trial to have a key.
    trial_info += [
        unflattened_lookup(param, config, default=None)
        for param in parameters
    ]
    trial_info += [
        unflattened_lookup(metric, result, default=None) for metric in metrics
    ]
    return trial_info
Example #5
0
def best_trial_str(
        trial: Trial,
        metric: str,
        parameter_columns: Union[None, List[str], Dict[str, str]] = None):
    """Returns a readable message stating the current best trial.

    Args:
        trial: The trial to report.
        metric: Name of the metric to report. It is resolved with
            ``unflattened_lookup`` (like the metric columns of the progress
            table), so nested/flattened metric names work here too. ``None``
            is shown if the metric cannot be found.
        parameter_columns: Parameters to include. If this is a mapping, its
            keys are used. If empty or None, all keys of
            ``trial.last_result["config"]`` are used.

    Returns:
        A one-line message with the trial id, metric value and parameters.
    """
    val = unflattened_lookup(metric, trial.last_result, default=None)
    config = trial.last_result.get("config", {})
    parameter_columns = parameter_columns or list(config.keys())
    if isinstance(parameter_columns, Mapping):
        parameter_columns = parameter_columns.keys()
    params = {p: unflattened_lookup(p, config) for p in parameter_columns}
    return f"Current best trial: {trial.trial_id} with {metric}={val} and " \
           f"parameters={params}"
Example #6
0
def trial_progress_table(
        trials: List[Trial],
        metric_columns: Union[List[str], Dict[str, str]],
        parameter_columns: Union[None, List[str], Dict[str, str]] = None,
        fmt: str = "psql",
        max_rows: Optional[int] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False):
    """Returns a table of trials, grouped by state, for printing.

    Each row holds a trial's name, status, location, parameters and current
    metric values. Rows are ordered by state: running, paused, pending,
    terminated, error.

    Args:
        trials: Trials to include.
        metric_columns: Names of metrics to include. If this is a dict, the
            keys are metric names and the values are the column headers. If
            this is a list, the metric names are used as headers directly.
            Metrics that no trial has reported are dropped.
        parameter_columns: Names of parameters to include. If this is a
            dict, the keys are parameter names and the values are the column
            headers. If this is a list, the parameter names are used as
            headers directly. If this is empty (None, ``[]`` or ``{}``), all
            evaluated parameters of the trials are used.
        fmt: Output format (``tablefmt`` of ``tabulate``).
        max_rows: Maximum number of trial rows. None or 0 means unlimited.
        metric: Metric used to sort terminated trials when
            ``sort_by_metric`` is set.
        mode: If ``"max"``, terminated trials are sorted in descending order
            of ``metric``; otherwise in ascending order.
        sort_by_metric: Whether to sort terminated trials by ``metric``.

    Returns:
        A list of messages: the rendered table, optionally followed by a
        line saying how many trials (per state) were not shown.
    """
    num_trials = len(trials)
    trials_by_state = _get_trials_by_state(trials)

    # Sort terminated trials by metric and mode, descending if mode is "max".
    # The membership check guards against _get_trials_by_state omitting
    # states that have no trials. unflattened_lookup resolves the metric
    # name the same way the metric-column filter below does.
    if sort_by_metric and Trial.TERMINATED in trials_by_state:
        trials_by_state[Trial.TERMINATED] = sorted(
            trials_by_state[Trial.TERMINATED],
            reverse=(mode == "max"),
            key=lambda t: unflattened_lookup(metric, t.last_result))

    state_tbl_order = [
        Trial.RUNNING, Trial.PAUSED, Trial.PENDING, Trial.TERMINATED,
        Trial.ERROR
    ]
    # A falsy max_rows (None or 0) means "no limit".
    max_rows = max_rows or float("inf")
    if num_trials > max_rows:
        # TODO(ujvl): suggestion for users to view more rows.
        shown_by_state = _fair_filter_trials(trials_by_state, max_rows,
                                             sort_by_metric)
        overflow = num_trials - max_rows
    else:
        shown_by_state = trials_by_state
        overflow = False

    # Collect the rows to show in table order, counting hidden trials per
    # state. Without truncation nothing is hidden, so overflow_str is "".
    trials = []
    overflow_strs = []
    for state in state_tbl_order:
        if state not in trials_by_state:
            continue
        trials += shown_by_state[state]
        num_hidden = len(trials_by_state[state]) - len(shown_by_state[state])
        if num_hidden > 0:
            overflow_strs.append("{} {}".format(num_hidden, state))
    overflow_str = ", ".join(overflow_strs)

    # Pre-process trials to figure out what columns to show.
    metric_is_mapping = isinstance(metric_columns, Mapping)
    if metric_is_mapping:
        metric_keys = list(metric_columns.keys())
    else:
        metric_keys = metric_columns

    # Keep only metric columns that at least one shown trial has reported.
    metric_keys = [
        k for k in metric_keys if any(
            unflattened_lookup(k, t.last_result, default=None) is not None
            for t in trials)
    ]

    # Only a *non-empty* mapping supplies display names. An empty dict falls
    # back to all evaluated parameters (like None or []), so it must not be
    # used to format headers, or parameter_columns[k] raises KeyError.
    parameter_is_mapping = (bool(parameter_columns)
                            and isinstance(parameter_columns, Mapping))
    if not parameter_columns:
        parameter_keys = sorted(
            set().union(*[t.evaluated_params for t in trials]))
    elif parameter_is_mapping:
        parameter_keys = list(parameter_columns.keys())
    else:
        parameter_keys = parameter_columns

    # Build trial rows.
    trial_table = [
        _get_trial_info(trial, parameter_keys, metric_keys) for trial in trials
    ]
    # Format column headings
    if metric_is_mapping:
        formatted_metric_columns = [metric_columns[k] for k in metric_keys]
    else:
        formatted_metric_columns = metric_keys
    if parameter_is_mapping:
        formatted_parameter_columns = [
            parameter_columns[k] for k in parameter_keys
        ]
    else:
        formatted_parameter_columns = parameter_keys
    columns = (["Trial name", "status", "loc"] + formatted_parameter_columns +
               formatted_metric_columns)
    # Tabulate.
    messages = [
        tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False)
    ]
    if overflow:
        messages.append("... {} more trials not shown ({})".format(
            overflow, overflow_str))
    return messages