Example #1
def trial_metrics(trial_id: int) -> gql.trials:
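    """Fetch one trial's steps, with training and validation metrics,
    ordered by step id."""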
    q = query()
    steps = q.op.trials_by_pk(id=trial_id).steps(
        order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.id()
    steps.metrics()
    steps.validation.metrics()
    r = q.send()
    return cast(gql.trials, r.trials_by_pk)
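
# Usage sketch (assumptions: query() is already wired to a reachable master,
# and the trial id used here is hypothetical):
trial = trial_metrics(1)
for step in trial.steps:
    # A step's validation object may have no fields of its own when no
    # validation exists, so read its metrics defensively.
    val_metrics = getattr(step.validation, "metrics", None)
    print(step.id, step.metrics, val_metrics)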
Example #2
def experiment_trials(experiment_id: int) -> List[gql.trials]:
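    """Fetch an experiment's trials and, for each step, its checkpoint
    and validation fields."""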
    q = query()
    trials = q.op.experiments_by_pk(id=experiment_id).trials(
        order_by=[gql.trials_order_by(id=gql.order_by.asc)])
    trials.id()
    trials.state()
    trials.warm_start_checkpoint_id()

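    # For each trial, also select its steps with checkpoint and validation fields.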
    steps = trials.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.id()
    steps.state()
    steps.checkpoint.id()
    steps.checkpoint.state()
    steps.checkpoint.step_id()
    steps.checkpoint.uuid()
    steps.validation.metrics()
    steps.validation.state()
    r = q.send()
    return cast(List[gql.trials], r.experiments_by_pk.trials)
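
# Usage sketch (assumption: query() is configured against a running master;
# the experiment id is hypothetical):
for trial in experiment_trials(1):
    print(trial.id, trial.state, len(trial.steps))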
Example #3
def follow_test_experiment_logs(master_url: str, exp_id: int) -> None:
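    """Poll a test experiment and print per-stage progress until it
    reaches a terminal state."""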
    def print_progress(active_stage: int, ended: bool) -> None:
        # There are four sequential stages of verification. Track the
        # current stage with an index into this list.
        stages = [
            "Scheduling task",
            "Testing training",
            "Testing validation",
            "Testing checkpointing",
        ]

        for idx, stage in enumerate(stages):
            if active_stage > idx:
                color = "green"
                checkbox = "✔"
            elif active_stage == idx:
                color = "red" if ended else "yellow"
                checkbox = "✗" if ended else " "
            else:
                color = "white"
                checkbox = " "
            print(colored(stage + (25 - len(stage)) * ".", color), end="")
            print(colored(" [" + checkbox + "]", color), end="")

            if idx == len(stages) - 1:
                print("\n" if ended else "\r", end="")
            else:
                print(", ", end="")

    q = api.GraphQLQuery(master_url)
    exp = q.op.experiments_by_pk(id=exp_id)
    exp.state()
    steps = exp.trials.steps(
        order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.checkpoint().id()
    steps.validation().id()

    while True:
        exp = q.send().experiments_by_pk

        # Wait for experiment to start and initialize a trial and step.
        step = None
        if exp.trials and exp.trials[0].steps:
            step = exp.trials[0].steps[0]

        # Update the active stage by examining the status of the experiment.
        # The way the GraphQL library works is that the checkpoint and
        # validation attributes of a step are always present and non-None,
        # but they don't have any attributes of their own when the
        # corresponding database object doesn't exist.
        if exp.state == constants.COMPLETED:
            active_stage = 4
        elif step and hasattr(step.checkpoint, "id"):
            active_stage = 3
        elif step and hasattr(step.validation, "id"):
            active_stage = 2
        elif step:
            active_stage = 1
        else:
            active_stage = 0

        # If the experiment is in a terminal state, output the appropriate
        # message and exit. Otherwise, sleep and repeat.
        if exp.state == constants.COMPLETED:
            print_progress(active_stage, ended=True)
            print(colored("Model definition test succeeded! 🎉", "green"))
            return
        elif exp.state == constants.CANCELED:
            print_progress(active_stage, ended=True)
            print(
                colored(
                    "Model definition test (ID: {}) canceled before "
                    "model test could complete. Please re-run the "
                    "command.".format(exp_id),
                    "yellow",
                ))
            sys.exit(1)
        elif exp.state == constants.ERROR:
            print_progress(active_stage, ended=True)
            trial_id = exp.trials[0].id
            logs_args = Namespace(trial_id=trial_id,
                                  master=master_url,
                                  tail=None,
                                  follow=False)
            logs(logs_args)
            sys.exit(1)
        else:
            print_progress(active_stage, ended=False)
            time.sleep(0.2)
Example #4
def describe_trial(args: Namespace) -> None:
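    """Print details for a single trial and a table of its steps, or
    dump the raw response as JSON."""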
    q = api.GraphQLQuery(args.master)
    trial = q.op.trials_by_pk(id=args.trial_id)
    trial.end_time()
    trial.experiment_id()
    trial.hparams()
    trial.start_time()
    trial.state()

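    # Select per-step fields, ordered by step id.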
    steps = trial.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.metrics()
    steps.id()
    steps.state()
    steps.start_time()
    steps.end_time()

    checkpoint_gql = steps.checkpoint()
    checkpoint_gql.state()
    checkpoint_gql.uuid()

    validation = steps.validation()
    validation.state()
    validation.metrics()

    resp = q.send()

    if args.json:
        print(json.dumps(resp.trials_by_pk.__to_json_value__(), indent=4))
        return

    trial = resp.trials_by_pk

    # Print information about the trial itself.
    headers = ["Experiment ID", "State", "H-Params", "Start Time", "End Time"]
    values = [
        [
            trial.experiment_id,
            trial.state,
            json.dumps(trial.hparams, indent=4),
            render.format_time(trial.start_time),
            render.format_time(trial.end_time),
        ]
    ]
    render.tabulate_or_csv(headers, values, args.csv)

    # Print information about individual steps.
    headers = [
        "Step #",
        "State",
        "Start Time",
        "End Time",
        "Checkpoint",
        "Checkpoint UUID",
        "Validation",
        "Validation Metrics",
    ]
    if args.metrics:
        headers.append("Step Metrics")

    values = [
        [
            s.id,
            s.state,
            render.format_time(s.start_time),
            render.format_time(s.end_time),
            *format_checkpoint(s.checkpoint),
            *format_validation(s.validation),
            *([json.dumps(s.metrics, indent=4)] if args.metrics else []),
        ]
        for s in trial.steps
    ]

    print()
    print("Steps:")
    render.tabulate_or_csv(headers, values, args.csv)
Example #5
def describe(args: Namespace) -> None:
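    """Show the given experiments, their trials, and their steps as
    tables, CSV, or JSON."""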
    ids = [int(x) for x in args.experiment_ids.split(",")]

    q = api.GraphQLQuery(args.master)
    exps = q.op.experiments(where=gql.experiments_bool_exp(
        id=gql.Int_comparison_exp(_in=ids)))
    exps.archived()
    exps.config()
    exps.end_time()
    exps.id()
    exps.progress()
    exps.start_time()
    exps.state()

    trials = exps.trials(order_by=[gql.trials_order_by(id=gql.order_by.asc)])
    trials.end_time()
    trials.hparams()
    trials.id()
    trials.start_time()
    trials.state()

    steps = trials.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.end_time()
    steps.id()
    steps.start_time()
    steps.state()
    steps.trial_id()

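    # Per-step checkpoint and validation timing and state.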
    steps.checkpoint.end_time()
    steps.checkpoint.start_time()
    steps.checkpoint.state()

    steps.validation.end_time()
    steps.validation.start_time()
    steps.validation.state()

    if args.metrics:
        steps.metrics(path="avg_metrics")
        steps.validation.metrics()

    resp = q.send()

    # Re-sort the experiment objects to match the original order.
    exps_by_id = {e.id: e for e in resp.experiments}
    experiments = [exps_by_id[id] for id in ids]

    if args.json:
        print(json.dumps(resp.__to_json_value__()["experiments"], indent=4))
        return

    # Display overall experiment information.
    headers = [
        "Experiment ID",
        "State",
        "Progress",
        "Start Time",
        "End Time",
        "Description",
        "Archived",
        "Labels",
    ]
    values = [[
        e.id,
        e.state,
        render.format_percent(e.progress),
        render.format_time(e.start_time),
        render.format_time(e.end_time),
        e.config.get("description"),
        e.archived,
        ", ".join(sorted(e.config.get("labels", []))),
    ] for e in experiments]
    if not args.outdir:
        outfile = None
        print("Experiment:")
    else:
        outfile = args.outdir.joinpath("experiments.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display trial-related information.
    headers = [
        "Trial ID", "Experiment ID", "State", "Start Time", "End Time",
        "H-Params"
    ]
    values = [[
        t.id,
        e.id,
        t.state,
        render.format_time(t.start_time),
        render.format_time(t.end_time),
        json.dumps(t.hparams, indent=4),
    ] for e in experiments for t in e.trials]
    if not args.outdir:
        outfile = None
        print("\nTrials:")
    else:
        outfile = args.outdir.joinpath("trials.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display step-related information.
    if args.metrics:
        # Accumulate the scalar training and validation metric names from all
        # provided experiments.
        t_metrics_names = sorted({
            n for e in experiments for n in scalar_training_metrics_names(e)
        })
        t_metrics_headers = [
            "Training Metric: {}".format(name) for name in t_metrics_names
        ]

        v_metrics_names = sorted({
            n for e in experiments for n in scalar_validation_metrics_names(e)
        })
        v_metrics_headers = [
            "Validation Metric: {}".format(name) for name in v_metrics_names
        ]
    else:
        t_metrics_headers = []
        v_metrics_headers = []

    headers = (["Trial ID", "Step ID", "State", "Start Time", "End Time"] +
               t_metrics_headers + [
                   "Checkpoint State",
                   "Checkpoint Start Time",
                   "Checkpoint End Time",
                   "Validation State",
                   "Validation Start Time",
                   "Validation End Time",
               ] + v_metrics_headers)

    values = []
    for e in experiments:
        for t in e.trials:
            for step in t.steps:
                t_metrics_fields = []
                if hasattr(step, "metrics"):
                    avg_metrics = step.metrics
                    for name in t_metrics_names:
                        if name in avg_metrics:
                            t_metrics_fields.append(avg_metrics[name])
                        else:
                            t_metrics_fields.append(None)

                checkpoint = step.checkpoint
                if checkpoint:
                    checkpoint_state = checkpoint.state
                    checkpoint_start_time = checkpoint.start_time
                    checkpoint_end_time = checkpoint.end_time
                else:
                    checkpoint_state = None
                    checkpoint_start_time = None
                    checkpoint_end_time = None

                validation = step.validation
                if validation:
                    validation_state = validation.state
                    validation_start_time = validation.start_time
                    validation_end_time = validation.end_time
                else:
                    validation_state = None
                    validation_start_time = None
                    validation_end_time = None

                if args.metrics:
                    v_metrics_fields = [
                        api.metric.get_validation_metric(name, validation)
                        for name in v_metrics_names
                    ]
                else:
                    v_metrics_fields = []

                row = ([
                    step.trial_id,
                    step.id,
                    step.state,
                    render.format_time(step.start_time),
                    render.format_time(step.end_time),
                ] + t_metrics_fields + [
                    checkpoint_state,
                    render.format_time(checkpoint_start_time),
                    render.format_time(checkpoint_end_time),
                    validation_state,
                    render.format_time(validation_start_time),
                    render.format_time(validation_end_time),
                ] + v_metrics_fields)
                values.append(row)

    if not args.outdir:
        outfile = None
        print("\nSteps:")
    else:
        outfile = args.outdir.joinpath("steps.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)