import json
import sys
import time
from argparse import Namespace
from typing import List, cast

from termcolor import colored  # assumed source of `colored`; swap in the repo's helper if it differs

# NOTE: `api`, `gql`, `constants`, `render`, and the helpers `query`,
# `format_checkpoint`, `format_validation`, `scalar_training_metrics_names`,
# `scalar_validation_metrics_names`, and `logs` are defined elsewhere in this
# package; their import paths are repo-specific and not spelled out here.


def trial_metrics(trial_id: int) -> gql.trials:
    # Query a single trial's steps in step-ID order, including training
    # metrics and any validation metrics.
    q = query()
    steps = q.op.trials_by_pk(id=trial_id).steps(
        order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.id()
    steps.metrics()
    steps.validation.metrics()
    r = q.send()
    return cast(gql.trials, r.trials_by_pk)
def experiment_trials(experiment_id: int) -> List[gql.trials]:
    # Fetch all trials of an experiment, with per-step checkpoint and
    # validation status, ordered by trial ID and step ID.
    q = query()
    trials = q.op.experiments_by_pk(id=experiment_id).trials(
        order_by=[gql.trials_order_by(id=gql.order_by.asc)])
    trials.id()
    trials.state()
    trials.warm_start_checkpoint_id()
    steps = trials.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.id()
    steps.state()
    steps.checkpoint.id()
    steps.checkpoint.state()
    steps.checkpoint.step_id()
    steps.checkpoint.uuid()
    steps.validation.metrics()
    steps.validation.state()
    r = q.send()
    return cast(List[gql.trials], r.experiments_by_pk.trials)
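# A minimal usage sketch for the two query helpers above. The IDs are
# placeholders and a reachable master with existing records is assumed.
# Wrapped in a function (never called here) so importing this module stays
# side-effect free.
def _example_query_helpers() -> None:
    trial = trial_metrics(trial_id=1)  # placeholder trial ID
    for step in trial.steps:
        # The validation object is always present, but it only carries a
        # `metrics` attribute once a validation has actually run.
        print(step.id, getattr(step.validation, "metrics", None))

    for t in experiment_trials(experiment_id=1):  # placeholder experiment ID
        print(t.id, t.state, len(t.steps))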
def follow_test_experiment_logs(master_url: str, exp_id: int) -> None:
    def print_progress(active_stage: int, ended: bool) -> None:
        # There are four sequential stages of verification. Track the
        # current stage with an index into this list.
        stages = [
            "Scheduling task",
            "Testing training",
            "Testing validation",
            "Testing checkpointing",
        ]

        for idx, stage in enumerate(stages):
            if active_stage > idx:
                color = "green"
                checkbox = "✔"
            elif active_stage == idx:
                color = "red" if ended else "yellow"
                checkbox = "✗" if ended else " "
            else:
                color = "white"
                checkbox = " "

            print(colored(stage + (25 - len(stage)) * ".", color), end="")
            print(colored(" [" + checkbox + "]", color), end="")
            if idx == len(stages) - 1:
                print("\n" if ended else "\r", end="")
            else:
                print(", ", end="")

    q = api.GraphQLQuery(master_url)
    exp = q.op.experiments_by_pk(id=exp_id)
    exp.state()
    steps = exp.trials.steps(
        order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.checkpoint().id()
    steps.validation().id()

    while True:
        exp = q.send().experiments_by_pk

        # Wait for the experiment to start and initialize a trial and step.
        step = None
        if exp.trials and exp.trials[0].steps:
            step = exp.trials[0].steps[0]

        # Update the active stage by examining the status of the experiment.
        # The way the GraphQL library works, the checkpoint and validation
        # attributes of a step are always present and non-None, but they have
        # no attributes of their own when the corresponding database object
        # doesn't exist.
        if exp.state == constants.COMPLETED:
            active_stage = 4
        elif step and hasattr(step.checkpoint, "id"):
            active_stage = 3
        elif step and hasattr(step.validation, "id"):
            active_stage = 2
        elif step:
            active_stage = 1
        else:
            active_stage = 0

        # If the experiment is in a terminal state, output the appropriate
        # message and exit. Otherwise, sleep and repeat.
        if exp.state == constants.COMPLETED:
            print_progress(active_stage, ended=True)
            print(colored("Model definition test succeeded! 🎉", "green"))
            return
        elif exp.state == constants.CANCELED:
            print_progress(active_stage, ended=True)
            print(
                colored(
                    "Model definition test (ID: {}) canceled before "
                    "model test could complete. Please re-run the "
                    "command.".format(exp_id),
                    "yellow",
                ))
            sys.exit(1)
        elif exp.state == constants.ERROR:
            print_progress(active_stage, ended=True)
            trial_id = exp.trials[0].id
            logs_args = Namespace(
                trial_id=trial_id, master=master_url, tail=None, follow=False)
            logs(logs_args)
            sys.exit(1)
        else:
            print_progress(active_stage, ended=False)
            time.sleep(0.2)
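# Usage sketch for the follower above: it blocks, polling the master every
# 0.2 s, and returns (or exits the process) once the test experiment reaches a
# terminal state. The URL and ID are placeholders; never called on import.
def _example_follow_test_experiment() -> None:
    follow_test_experiment_logs("http://localhost:8080", exp_id=1)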
def describe_trial(args: Namespace) -> None:
    q = api.GraphQLQuery(args.master)
    trial = q.op.trials_by_pk(id=args.trial_id)
    trial.end_time()
    trial.experiment_id()
    trial.hparams()
    trial.start_time()
    trial.state()
    steps = trial.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.metrics()
    steps.id()
    steps.state()
    steps.start_time()
    steps.end_time()
    checkpoint_gql = steps.checkpoint()
    checkpoint_gql.state()
    checkpoint_gql.uuid()
    validation = steps.validation()
    validation.state()
    validation.metrics()
    resp = q.send()

    if args.json:
        print(json.dumps(resp.trials_by_pk.__to_json_value__(), indent=4))
        return

    trial = resp.trials_by_pk

    # Print information about the trial itself.
    headers = ["Experiment ID", "State", "H-Params", "Start Time", "End Time"]
    values = [
        [
            trial.experiment_id,
            trial.state,
            json.dumps(trial.hparams, indent=4),
            render.format_time(trial.start_time),
            render.format_time(trial.end_time),
        ]
    ]
    render.tabulate_or_csv(headers, values, args.csv)

    # Print information about individual steps.
    headers = [
        "Step #",
        "State",
        "Start Time",
        "End Time",
        "Checkpoint",
        "Checkpoint UUID",
        "Validation",
        "Validation Metrics",
    ]
    if args.metrics:
        headers.append("Step Metrics")

    values = [
        [
            s.id,
            s.state,
            render.format_time(s.start_time),
            render.format_time(s.end_time),
            *format_checkpoint(s.checkpoint),
            *format_validation(s.validation),
            *([json.dumps(s.metrics, indent=4)] if args.metrics else []),
        ]
        for s in trial.steps
    ]

    print()
    print("Steps:")
    render.tabulate_or_csv(headers, values, args.csv)
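# Hypothetical invocation sketch for describe_trial. In the CLI this Namespace
# is built from parsed command-line arguments; the field names mirror exactly
# what the function reads above, and the master address and trial ID are
# placeholders. Never called on import.
def _example_describe_trial() -> None:
    describe_trial(
        Namespace(
            master="http://localhost:8080",
            trial_id=1,
            json=False,
            csv=False,
            metrics=True,
        ))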
def describe(args: Namespace) -> None:
    ids = [int(x) for x in args.experiment_ids.split(",")]

    q = api.GraphQLQuery(args.master)
    exps = q.op.experiments(where=gql.experiments_bool_exp(
        id=gql.Int_comparison_exp(_in=ids)))
    exps.archived()
    exps.config()
    exps.end_time()
    exps.id()
    exps.progress()
    exps.start_time()
    exps.state()
    trials = exps.trials(order_by=[gql.trials_order_by(id=gql.order_by.asc)])
    trials.end_time()
    trials.hparams()
    trials.id()
    trials.start_time()
    trials.state()
    steps = trials.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.end_time()
    steps.id()
    steps.start_time()
    steps.state()
    steps.trial_id()
    steps.checkpoint.end_time()
    steps.checkpoint.start_time()
    steps.checkpoint.state()
    steps.validation.end_time()
    steps.validation.start_time()
    steps.validation.state()
    if args.metrics:
        steps.metrics(path="avg_metrics")
        steps.validation.metrics()

    resp = q.send()

    # Re-sort the experiment objects to match the original order.
    exps_by_id = {e.id: e for e in resp.experiments}
    experiments = [exps_by_id[id] for id in ids]

    if args.json:
        print(json.dumps(resp.__to_json_value__()["experiments"], indent=4))
        return

    # Display overall experiment information.
    headers = [
        "Experiment ID",
        "State",
        "Progress",
        "Start Time",
        "End Time",
        "Description",
        "Archived",
        "Labels",
    ]
    values = [
        [
            e.id,
            e.state,
            render.format_percent(e.progress),
            render.format_time(e.start_time),
            render.format_time(e.end_time),
            e.config.get("description"),
            e.archived,
            ", ".join(sorted(e.config.get("labels", []))),
        ]
        for e in experiments
    ]
    if not args.outdir:
        outfile = None
        print("Experiment:")
    else:
        outfile = args.outdir.joinpath("experiments.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display trial-related information.
    headers = [
        "Trial ID", "Experiment ID", "State", "Start Time", "End Time",
        "H-Params"
    ]
    values = [
        [
            t.id,
            e.id,
            t.state,
            render.format_time(t.start_time),
            render.format_time(t.end_time),
            json.dumps(t.hparams, indent=4),
        ]
        for e in experiments for t in e.trials
    ]
    if not args.outdir:
        outfile = None
        print("\nTrials:")
    else:
        outfile = args.outdir.joinpath("trials.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display step-related information.
    if args.metrics:
        # Accumulate the scalar training and validation metric names from
        # all of the provided experiments.
        t_metrics_names = sorted(
            {n for e in experiments for n in scalar_training_metrics_names(e)})
        t_metrics_headers = [
            "Training Metric: {}".format(name) for name in t_metrics_names
        ]
        v_metrics_names = sorted({
            n for e in experiments for n in scalar_validation_metrics_names(e)
        })
        v_metrics_headers = [
            "Validation Metric: {}".format(name) for name in v_metrics_names
        ]
    else:
        # Without --metrics no metric columns are emitted; initialize the
        # name list as well so the loop below never sees an unbound variable.
        t_metrics_names = []
        t_metrics_headers = []
        v_metrics_headers = []

    headers = (["Trial ID", "Step ID", "State", "Start Time", "End Time"] +
               t_metrics_headers + [
                   "Checkpoint State",
                   "Checkpoint Start Time",
                   "Checkpoint End Time",
                   "Validation State",
                   "Validation Start Time",
                   "Validation End Time",
               ] + v_metrics_headers)

    values = []
    for e in experiments:
        for t in e.trials:
            for step in t.steps:
                t_metrics_fields = []
                if hasattr(step, "metrics"):
                    avg_metrics = step.metrics
                    for name in t_metrics_names:
                        if name in avg_metrics:
                            t_metrics_fields.append(avg_metrics[name])
                        else:
                            t_metrics_fields.append(None)

                checkpoint = step.checkpoint
                if checkpoint:
                    checkpoint_state = checkpoint.state
                    checkpoint_start_time = checkpoint.start_time
                    checkpoint_end_time = checkpoint.end_time
                else:
                    checkpoint_state = None
                    checkpoint_start_time = None
                    checkpoint_end_time = None

                validation = step.validation
                if validation:
                    validation_state = validation.state
                    validation_start_time = validation.start_time
                    validation_end_time = validation.end_time
                else:
                    validation_state = None
                    validation_start_time = None
                    validation_end_time = None

                if args.metrics:
                    v_metrics_fields = [
                        api.metric.get_validation_metric(name, validation)
                        for name in v_metrics_names
                    ]
                else:
                    v_metrics_fields = []

                row = ([
                    step.trial_id,
                    step.id,
                    step.state,
                    render.format_time(step.start_time),
                    render.format_time(step.end_time),
                ] + t_metrics_fields + [
                    checkpoint_state,
                    render.format_time(checkpoint_start_time),
                    render.format_time(checkpoint_end_time),
                    validation_state,
                    render.format_time(validation_start_time),
                    render.format_time(validation_end_time),
                ] + v_metrics_fields)
                values.append(row)

    if not args.outdir:
        outfile = None
        print("\nSteps:")
    else:
        outfile = args.outdir.joinpath("steps.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)
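# Hypothetical invocation sketch for describe: comma-separated experiment IDs,
# CSV output to stdout, no per-step metrics. Field names mirror exactly what
# describe reads above; the master address and IDs are placeholders. Never
# called on import.
def _example_describe() -> None:
    describe(
        Namespace(
            master="http://localhost:8080",
            experiment_ids="1,2",
            json=False,
            csv=True,
            metrics=False,
            outdir=None,  # or a pathlib.Path to write the three CSV files into
        ))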