コード例 #1
0
def list_trials(
    experiment_path,
    sort=None,
    output=None,
    filter_op=None,
    info_keys=None,
    limit=None,
    desc=False,
):
    """Lists trials in the directory subtree starting at the given path.

    Loads the last reported result of every trial via ``ExperimentAnalysis``,
    optionally narrows, filters, sorts and truncates the resulting table,
    prints it, and optionally saves it to a pickle or CSV file.

    Args:
        experiment_path (str): Directory where trials are located.
            Like Experiment.local_dir/Experiment.name/experiment*.json.
        sort (list): Keys to sort by.
        output (str): Name of file where output is saved. The extension
            selects the format: ".p"/".pkl"/".pickle" or ".csv".
        filter_op (str): Filter operation in the format
            "<column> <operator> <value>". Repeated whitespace between the
            parts is tolerated; everything after the operator is the value.
        info_keys (list): Keys that are displayed. If omitted, columns in
            ``DEFAULT_CLI_KEYS`` and columns starting with ``CONFIG_PREFIX``
            are shown.
        limit (int): Number of rows to display.
        desc (bool): Sort ascending vs. descending.

    Raises:
        click.ClickException: If no trial data is found, a requested key,
            sort key or filter is invalid, there are no columns to output,
            or the output file type is unsupported.
    """
    _check_tabulate()

    try:
        checkpoints_df = ExperimentAnalysis(
            experiment_path).dataframe()  # last result
    except TuneError as e:
        raise click.ClickException("No trial data found!") from e

    def key_filter(k):
        return k in DEFAULT_CLI_KEYS or k.startswith(CONFIG_PREFIX)

    col_keys = [k for k in checkpoints_df.columns if key_filter(k)]

    if info_keys:
        for k in info_keys:
            if k not in checkpoints_df.columns:
                raise click.ClickException(
                    "Provided key invalid: {}. Available keys: {}.".format(
                        k, list(checkpoints_df.columns)))
        col_keys = [k for k in checkpoints_df.columns if k in info_keys]

    if not col_keys:
        raise click.ClickException("No columns to output.")

    # Copy so the column rewrites below act on an independent frame instead
    # of a slice of the analysis dataframe (avoids SettingWithCopyWarning).
    checkpoints_df = checkpoints_df[col_keys].copy()
    if "last_update_time" in checkpoints_df:
        # Drop missing and infinite timestamps before formatting. Rows removed
        # here come back as NaN when the series is assigned to the column.
        # This replaces the "mode.use_inf_as_null" option, which newer pandas
        # releases no longer provide.
        datetime_series = checkpoints_df["last_update_time"].replace(
            [float("inf"), float("-inf")], float("nan")).dropna()

        datetime_series = datetime_series.apply(
            lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT))
        checkpoints_df["last_update_time"] = datetime_series

    if "logdir" in checkpoints_df:
        # logdir often too long to view in table, so drop experiment_path.
        # regex=False: the path is a literal, not a pattern (before pandas
        # 2.0, str.replace interpreted `pat` as a regex by default).
        checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
            experiment_path, "", regex=False)

    if filter_op:
        # maxsplit=2 tolerates repeated whitespace between the parts and keeps
        # any spaces inside the value itself.
        parts = filter_op.strip().split(None, 2)
        if len(parts) != 3:
            raise click.ClickException(
                "Invalid filter: {!r}. Expected format "
                "'<column> <operator> <value>'.".format(filter_op))
        col, op, val = parts
        if col not in checkpoints_df:
            raise click.ClickException("{} not in: {}".format(
                col, list(checkpoints_df)))
        if op not in OPERATORS:
            raise click.ClickException(
                "Unsupported operator: {}. Available operators: {}.".format(
                    op, list(OPERATORS)))
        col_type = checkpoints_df[col].dtype
        if is_numeric_dtype(col_type):
            try:
                val = float(val)
            except ValueError as e:
                raise click.ClickException(
                    "Cannot compare {} ({}) with non-numeric value {!r}".format(
                        col, col_type, val)) from e
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise click.ClickException("Unsupported dtype for {}: {}".format(
                col, col_type))
        filtered_index = OPERATORS[op](checkpoints_df[col], val)
        checkpoints_df = checkpoints_df[filtered_index]

    if sort:
        for key in sort:
            if key not in checkpoints_df:
                raise click.ClickException("{} not in: {}".format(
                    key, list(checkpoints_df)))
        ascending = not desc
        checkpoints_df = checkpoints_df.sort_values(by=sort,
                                                    ascending=ascending)

    if limit:
        checkpoints_df = checkpoints_df[:limit]

    print_format_output(checkpoints_df)

    if output:
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            checkpoints_df.to_pickle(output)
        elif file_extension == ".csv":
            checkpoints_df.to_csv(output, index=False)
        else:
            raise click.ClickException(
                "Unsupported filetype: {}".format(output))
        click.secho("Output saved at {}".format(output), fg="green")
コード例 #2
0
ファイル: commands.py プロジェクト: quantumahesh/Project-Ray
def list_trials(experiment_path,
                sort=None,
                output=None,
                filter_op=None,
                info_keys=None,
                limit=None,
                desc=False):
    """Lists trials in the directory subtree starting at the given path.

    Loads the trial table via ``ExperimentAnalysis``, keeps the requested
    columns, optionally filters, sorts and truncates it, prints it, and
    optionally saves it to a pickle or CSV file.

    Args:
        experiment_path (str): Directory where trials are located.
            Corresponds to Experiment.local_dir/Experiment.name.
        sort (list|str): Key, or list of keys, to sort by.
        output (str): Name of file where output is saved. The extension
            selects the format: ".p"/".pkl"/".pickle" or ".csv".
        filter_op (str): Filter operation in the format
            "<column> <operator> <value>". Repeated whitespace between the
            parts is tolerated; everything after the operator is the value.
        info_keys (list): Keys that are displayed. Defaults to
            ``DEFAULT_EXPERIMENT_INFO_KEYS``. Keys missing from the data are
            skipped silently.
        limit (int): Number of rows to display.
        desc (bool): Sort ascending vs. descending.

    Raises:
        KeyError: If a sort key, filter column or filter operator is unknown.
        ValueError: If the filter is malformed, its value does not match the
            column dtype, the dtype is unsupported, or the output file type
            is unsupported.
    """
    _check_tabulate()

    try:
        checkpoints_df = ExperimentAnalysis(experiment_path).dataframe()
    except TuneError:
        # Report and exit with status 0 rather than raising.
        print("No experiment state found!")
        sys.exit(0)

    if not info_keys:
        info_keys = DEFAULT_EXPERIMENT_INFO_KEYS
    # Keep only the requested keys that exist, in the requested order.
    col_keys = [k for k in list(info_keys) if k in checkpoints_df]
    # Copy so the column rewrites below act on an independent frame instead
    # of a slice (avoids SettingWithCopyWarning).
    checkpoints_df = checkpoints_df[col_keys].copy()

    if "last_update_time" in checkpoints_df:
        # Drop missing and infinite timestamps before formatting. Rows removed
        # here come back as NaN when the series is assigned to the column.
        # This replaces the "mode.use_inf_as_null" option, which newer pandas
        # releases no longer provide.
        datetime_series = checkpoints_df["last_update_time"].replace(
            [float("inf"), float("-inf")], float("nan")).dropna()

        datetime_series = datetime_series.apply(
            lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT))
        checkpoints_df["last_update_time"] = datetime_series

    if "logdir" in checkpoints_df:
        # logdir often too verbose to view in table, so drop experiment_path.
        # regex=False: the path is a literal, not a pattern (before pandas
        # 2.0, str.replace interpreted `pat` as a regex by default).
        checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
            experiment_path, "", regex=False)

    if filter_op:
        # maxsplit=2 tolerates repeated whitespace between the parts and keeps
        # any spaces inside the value itself.
        parts = filter_op.strip().split(None, 2)
        if len(parts) != 3:
            raise ValueError(
                "Invalid filter: {!r}. Expected format "
                "'<column> <operator> <value>'.".format(filter_op))
        col, op, val = parts
        if col not in checkpoints_df:
            raise KeyError("{} not in: {}".format(col, list(checkpoints_df)))
        if op not in OPERATORS:
            raise KeyError(
                "Unsupported operator: {}. Available operators: {}.".format(
                    op, list(OPERATORS)))
        col_type = checkpoints_df[col].dtype
        if is_numeric_dtype(col_type):
            try:
                val = float(val)
            except ValueError as e:
                raise ValueError(
                    "Cannot compare {} ({}) with non-numeric value {!r}".format(
                        col, col_type, val)) from e
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise ValueError("Unsupported dtype for {}: {}".format(
                col, col_type))
        filtered_index = OPERATORS[op](checkpoints_df[col], val)
        checkpoints_df = checkpoints_df[filtered_index]

    if sort:
        # Accept a single key as well as the documented list of keys; a list
        # cannot be tested with `in` directly (lists are unhashable).
        sort_keys = [sort] if isinstance(sort, str) else list(sort)
        for key in sort_keys:
            if key not in checkpoints_df:
                raise KeyError("{} not in: {}".format(
                    key, list(checkpoints_df)))
        ascending = not desc
        checkpoints_df = checkpoints_df.sort_values(by=sort_keys,
                                                    ascending=ascending)

    if limit:
        checkpoints_df = checkpoints_df[:limit]

    print_format_output(checkpoints_df)

    if output:
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            checkpoints_df.to_pickle(output)
        elif file_extension == ".csv":
            checkpoints_df.to_csv(output, index=False)
        else:
            raise ValueError("Unsupported filetype: {}".format(output))
        print("Output saved at:", output)