def config(args: Namespace) -> None:
    """Fetch an experiment's configuration and dump it as YAML to stdout."""
    query = api.GraphQLQuery(args.master)
    query.op.experiments_by_pk(id=args.experiment_id).config()
    result = query.send()
    yaml.safe_dump(
        result.experiments_by_pk.config,
        stream=sys.stdout,
        default_flow_style=False,
    )
def download(master: str, trial_id: int, step_id: int, output_dir: str) -> None:
    """Download the checkpoint of one trial step into ``output_dir``.

    Raises:
        ValueError: if the step does not exist or has no checkpoint.
        AssertionError: if the experiment is not configured with S3 or GCS
            checkpoint storage (only those backends support direct download).
    """
    q = api.GraphQLQuery(master)
    # Build the query on a separate name so the response below doesn't clobber
    # the query node (the original reused `step` for both).
    step_query = q.op.steps_by_pk(trial_id=trial_id, id=step_id)
    step_query.checkpoint.labels()
    step_query.checkpoint.resources()
    step_query.checkpoint.uuid()
    step_query.trial.experiment.config(path="checkpoint_storage")
    step_query.trial.experiment_id()

    resp = q.send()
    step = resp.steps_by_pk
    if not step:
        raise ValueError("Trial {} step {} not found".format(trial_id, step_id))
    if not step.checkpoint:
        raise ValueError("Trial {} step {} has no checkpoint".format(trial_id, step_id))

    storage_config = step.trial.experiment.config
    manager = storage.build(storage_config)
    # Only cloud backends can be downloaded from directly by this client.
    # isinstance with a tuple replaces the original or-chain of isinstance calls.
    if not isinstance(manager, (storage.S3StorageManager, storage.GCSStorageManager)):
        raise AssertionError(
            "Downloading from S3 or GCS requires the experiment to be configured with "
            "S3 or GCS checkpointing, {} found instead".format(storage_config["type"])
        )

    metadata = storage.StorageMetadata.from_json(step.checkpoint.__to_json_value__())
    manager.download(metadata, output_dir)
def follow_experiment_logs(master_url: str, exp_id: int) -> None:
    """Wait for the experiment's first trial, then follow that trial's logs."""
    # The first trial is the one with the lowest ID.
    query = api.GraphQLQuery(master_url)
    trial_node = query.op.trials(
        where=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(_eq=exp_id)),
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        limit=1,
    )
    trial_node.id()

    print("Waiting for first trial to begin...")
    # Poll until the experiment has created at least one trial.
    resp = query.send()
    while not resp.trials:
        time.sleep(0.1)
        resp = query.send()

    first_trial_id = resp.trials[0].id
    print("Following first trial with ID {}".format(first_trial_id))

    # Delegate the actual streaming to `logs --follow` on the new trial.
    logs(Namespace(trial_id=first_trial_id, follow=True, master=master_url, tail=None))
def list_trials(args: Namespace) -> None:
    """Render a table (or CSV) of every trial belonging to one experiment."""
    query = api.GraphQLQuery(args.master)
    trials = query.op.trials(
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        where=gql.trials_bool_exp(
            experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
        ),
    )
    trials.id()
    trials.state()
    trials.hparams()
    trials.start_time()
    trials.end_time()
    trials.steps_aggregate().aggregate.count()

    resp = query.send()

    headers = ["Trial ID", "State", "H-Params", "Start Time", "End Time", "# of Steps"]

    def make_row(trial: Any) -> List[Any]:
        # One output row per trial, in `headers` order.
        return [
            trial.id,
            trial.state,
            json.dumps(trial.hparams, indent=4),
            render.format_time(trial.start_time),
            render.format_time(trial.end_time),
            trial.steps_aggregate.aggregate.count,
        ]

    render.tabulate_or_csv(headers, [make_row(t) for t in resp.trials], args.csv)
def experiment_id_completer(prefix: str, parsed_args: Namespace, **kwargs: Any) -> List[str]:
    """Shell-completion helper: return every experiment ID known to the master."""
    auth.initialize_session(parsed_args.master, parsed_args.user, try_reauth=True)
    query = api.GraphQLQuery(parsed_args.master)
    query.op.experiments().id()
    response = query.send()
    ids = []
    for experiment in response.experiments:
        ids.append(str(experiment["id"]))
    return ids
def list(args: Namespace) -> None:
    """List the checkpoints of an experiment, optionally only the best N."""
    query = api.GraphQLQuery(args.master)
    # NOTE(review): this checkpoint_storage sub-query is issued but its result
    # is never read below — confirm whether it is still needed.
    query.op.experiments_by_pk(id=args.experiment_id).config(path="checkpoint_storage")

    # Sort by the signed validation metric so the best checkpoints come first.
    order_by = [
        gql.checkpoints_order_by(
            validation=gql.validations_order_by(
                metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
            )
        )
    ]

    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    ckpts = query.op.checkpoints(
        where=gql.checkpoints_bool_exp(
            step=gql.steps_bool_exp(
                trial=gql.trials_bool_exp(
                    experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
                )
            )
        ),
        order_by=order_by,
        limit=limit,
    )
    ckpts.end_time()
    ckpts.labels()
    ckpts.resources()
    ckpts.start_time()
    ckpts.state()
    ckpts.step_id()
    ckpts.trial_id()
    ckpts.uuid()
    ckpts.step.validation.metric_values.raw()

    resp = query.send()

    headers = ["Trial ID", "Step ID", "State", "Validation Metric", "UUID", "Resources", "Size"]
    rows = []
    for c in resp.checkpoints:
        has_metric = c.step.validation and c.step.validation.metric_values
        rows.append(
            [
                c.trial_id,
                c.step_id,
                c.state,
                c.step.validation.metric_values.raw if has_metric else None,
                c.uuid,
                render.format_resources(c.resources),
                render.format_resource_sizes(c.resources),
            ]
        )
    render.tabulate_or_csv(headers, rows, args.csv)
def wait(args: Namespace) -> None:
    """Poll until the experiment reaches a terminal state, then exit.

    Exits 0 when the experiment completed, 1 for any other terminal state.
    """
    while True:
        query = api.GraphQLQuery(args.master)
        query.op.experiments_by_pk(id=args.experiment_id).state()
        state = query.send().experiments_by_pk.state

        if state in constants.TERMINAL_STATES:
            print("Experiment {} terminated with state {}".format(args.experiment_id, state))
            sys.exit(0 if state == constants.COMPLETED else 1)

        time.sleep(args.polling_interval)
def logs_query(tail: Optional[int] = None, greater_than_id: Optional[int] = None) -> Any:
    # Build (but do not send) a GraphQL query for a trial's log lines plus the
    # trial's current state.
    #
    # NOTE(review): `args` is a free variable here — this function looks like a
    # copy of the helper nested inside `logs()` and depends on `args.master`
    # and `args.trial_id` existing in an enclosing/global scope. Confirm intent.
    q = api.GraphQLQuery(args.master)
    limit = None
    order_by = [gql.trial_logs_order_by(id=gql.order_by.asc)]
    where = gql.trial_logs_bool_exp(trial_id=gql.Int_comparison_exp(_eq=args.trial_id))
    if greater_than_id is not None:
        # Only fetch logs newer than the newest one already seen.
        where.id = gql.Int_comparison_exp(_gt=greater_than_id)
    if tail is not None:
        # A tail is fetched newest-first with a limit; the caller must reverse.
        order_by = [gql.trial_logs_order_by(id=gql.order_by.desc)]
        limit = tail
    logs = q.op.trial_logs(where=where, order_by=order_by, limit=limit)
    logs.id()
    logs.message()
    q.op.trials_by_pk(id=args.trial_id).state()
    return q
def list_experiments(args: Namespace) -> None:
    """Render a table of experiments: all of them with --all, otherwise the
    calling user's unarchived experiments."""
    where = None
    if not args.all:
        # Restrict to the current session user's unarchived experiments.
        user = api.Authentication.instance().get_session_user()
        where = gql.experiments_bool_exp(
            archived=gql.Boolean_comparison_exp(_eq=False),
            owner=gql.users_bool_exp(username=gql.String_comparison_exp(_eq=user)),
        )

    query = api.GraphQLQuery(args.master)
    exps = query.op.experiments(
        order_by=[gql.experiments_order_by(id=gql.order_by.desc)], where=where
    )
    exps.archived()
    exps.config()
    exps.end_time()
    exps.id()
    exps.owner.username()
    exps.progress()
    exps.start_time()
    exps.state()

    resp = query.send()

    def to_row(e: Any) -> List[Any]:
        # One output row per experiment; the Archived column only with --all.
        row = [
            e.id,
            e.owner.username,
            e.config["description"],
            e.state,
            render.format_percent(e.progress),
            render.format_time(e.start_time),
            render.format_time(e.end_time),
        ]
        if args.all:
            row.append(e.archived)
        return row

    headers = ["ID", "Owner", "Description", "State", "Progress", "Start Time", "End Time"]
    if args.all:
        headers.append("Archived")

    render.tabulate_or_csv(headers, [to_row(e) for e in resp.experiments], args.csv)
def get_checkpoint(uuid: str, master: str) -> Checkpoint:
    """Look up a COMPLETED checkpoint by UUID and wrap it as a Checkpoint.

    Raises AssertionError when no matching checkpoint exists.
    """
    query = api.GraphQLQuery(master)
    where = gql.checkpoints_bool_exp(
        state=gql.checkpoint_state_comparison_exp(_eq="COMPLETED"),
        uuid=gql.uuid_comparison_exp(_eq=uuid),
    )
    ckpt_node = query.op.checkpoints(where=where)
    ckpt_node.state()
    ckpt_node.uuid()
    ckpt_node.resources()

    validation_node = ckpt_node.validation()
    validation_node.metrics()
    validation_node.state()

    step_node = ckpt_node.step()
    step_node.id()
    step_node.start_time()
    step_node.end_time()
    step_node.trial.experiment.config()

    matches = query.send().checkpoints
    if not matches:
        raise AssertionError("No checkpoint found for uuid {}".format(uuid))

    found = matches[0]
    exp_config = found.step.trial.experiment.config
    # Convert the step ID into a global batch count.
    batch_number = exp_config["batches_per_step"] * found.step.id
    return Checkpoint(
        found.uuid,
        exp_config["checkpoint_storage"],
        batch_number,
        found.step.start_time,
        found.step.end_time,
        found.resources,
        found.validation,
    )
def query() -> api.GraphQLQuery:
    """Authenticate against the master and return a fresh GraphQL query object."""
    # Compute the master URL once instead of calling conf.make_master_url()
    # twice as the original did.
    master_url = conf.make_master_url()
    auth.initialize_session(master_url, try_reauth=True)
    return api.GraphQLQuery(master_url)
def top_n_checkpoints(
    self, limit: int, sort_by: Optional[str] = None, smaller_is_better: Optional[bool] = None
) -> List[checkpoint.Checkpoint]:
    """
    Return the n :py:class:`det.experimental.Checkpoint` instances with the
    best validation metric values as defined by the `sort_by` and
    `smaller_is_better` arguments.

    Arguments:
        sort_by (string, optional): The name of the validation metric to order
            checkpoints by. If this parameter is unset the metric defined in
            the experiment configuration searcher field will be used.

        smaller_is_better (bool, optional): Specifies whether to sort the
            metric above in ascending or descending order. If sort_by is
            unset, this parameter is ignored. By default the
            smaller_is_better value in the experiment configuration is used.
    """
    q = api.GraphQLQuery(self._master)
    exp = q.op.experiments_by_pk(id=self.id)
    # Selection of the top `limit` checkpoints happens server-side via the
    # best_checkpoints_by_metric resolver.
    checkpoints = exp.best_checkpoints_by_metric(
        args={
            "lim": limit,
            "metric": sort_by,
            "smaller_is_better": smaller_is_better
        })
    checkpoints.state()
    checkpoints.uuid()
    checkpoints.resources()
    validation = checkpoints.validation()
    validation.metrics()
    validation.state()
    step = checkpoints.step()
    step.id()
    step.start_time()
    step.end_time()
    step.trial.experiment.config()
    step.trial.id()
    resp = q.send()
    checkpoints_resp = resp.experiments_by_pk.best_checkpoints_by_metric
    if not checkpoints_resp:
        return []
    experiment_conf = checkpoints_resp[0].step.trial.experiment.config
    # Fall back to the experiment configuration for any unset sort parameters.
    sib = (smaller_is_better
           if smaller_is_better is not None
           else experiment_conf["searcher"]["smaller_is_better"])
    sort_metric = sort_by if sort_by is not None else experiment_conf[
        "searcher"]["metric"]
    # Re-sort client-side: reverse=not sib yields ascending order when smaller
    # is better, descending otherwise — i.e. best checkpoint first.
    ordered_checkpoints = sorted(
        checkpoints_resp,
        key=lambda c: c.validation.metrics["validation_metrics"][
            sort_metric],
        reverse=not sib,
    )
    return [
        checkpoint.Checkpoint(
            ckpt.uuid,
            ckpt.step.trial.experiment.config["checkpoint_storage"],
            # batches_per_step * step id converts a step ID to a batch count.
            ckpt.step.trial.experiment.config["batches_per_step"] * ckpt.step.id,
            ckpt.step.start_time,
            ckpt.step.end_time,
            ckpt.resources,
            ckpt.validation,
        ) for ckpt in ordered_checkpoints
    ]
def describe_trial(args: Namespace) -> None:
    """Print a trial summary followed by a per-step table (or JSON with --json)."""
    query = api.GraphQLQuery(args.master)
    trial_node = query.op.trials_by_pk(id=args.trial_id)
    trial_node.end_time()
    trial_node.experiment_id()
    trial_node.hparams()
    trial_node.start_time()
    trial_node.state()

    step_node = trial_node.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    step_node.metrics()
    step_node.id()
    step_node.state()
    step_node.start_time()
    step_node.end_time()

    ckpt_node = step_node.checkpoint()
    ckpt_node.state()
    ckpt_node.uuid()

    val_node = step_node.validation()
    val_node.state()
    val_node.metrics()

    resp = query.send()

    if args.json:
        print(json.dumps(resp.trials_by_pk.__to_json_value__(), indent=4))
        return

    trial = resp.trials_by_pk

    # Trial-level summary table.
    render.tabulate_or_csv(
        ["Experiment ID", "State", "H-Params", "Start Time", "End Time"],
        [
            [
                trial.experiment_id,
                trial.state,
                json.dumps(trial.hparams, indent=4),
                render.format_time(trial.start_time),
                render.format_time(trial.end_time),
            ]
        ],
        args.csv,
    )

    # Step-level details table.
    headers = [
        "Step #",
        "State",
        "Start Time",
        "End Time",
        "Checkpoint",
        "Checkpoint UUID",
        "Validation",
        "Validation Metrics",
    ]
    if args.metrics:
        headers.append("Step Metrics")

    def step_row(s: Any) -> List[Any]:
        row = [
            s.id,
            s.state,
            render.format_time(s.start_time),
            render.format_time(s.end_time),
        ]
        row.extend(format_checkpoint(s.checkpoint))
        row.extend(format_validation(s.validation))
        if args.metrics:
            row.append(json.dumps(s.metrics, indent=4))
        return row

    print()
    print("Steps:")
    render.tabulate_or_csv(headers, [step_row(s) for s in trial.steps], args.csv)
def logs(args: Namespace) -> None:
    """Print a trial's logs, optionally tailing the newest N and/or following
    the stream live until the trial terminates."""

    def logs_query(tail: Optional[int] = None, greater_than_id: Optional[int] = None) -> Any:
        # Build a query for log lines plus the trial's state (the state is
        # used below to decide when to stop following).
        q = api.GraphQLQuery(args.master)
        limit = None
        order_by = [gql.trial_logs_order_by(id=gql.order_by.asc)]
        where = gql.trial_logs_bool_exp(trial_id=gql.Int_comparison_exp(_eq=args.trial_id))
        if greater_than_id is not None:
            # Only fetch logs newer than the newest one already printed.
            where.id = gql.Int_comparison_exp(_gt=greater_than_id)
        if tail is not None:
            # A tail is fetched newest-first with a limit; the caller reverses.
            order_by = [gql.trial_logs_order_by(id=gql.order_by.desc)]
            limit = tail
        logs = q.op.trial_logs(where=where, order_by=order_by, limit=limit)
        logs.id()
        logs.message()
        q.op.trials_by_pk(id=args.trial_id).state()
        return q

    def process_response(logs: Any, latest_log_id: int) -> Tuple[int, bool]:
        # Print each message and track the highest log ID seen; the bool
        # indicates whether anything new was printed this round.
        changes = False
        for log in logs:
            check_gt(log.id, latest_log_id)
            latest_log_id = log.id
            msg = api.decode_bytes(log.message)
            print(msg, end="")
            changes = True
        return latest_log_id, changes

    resp = logs_query(args.tail).send()
    logs = resp.trial_logs
    # Due to limitations of the GraphQL API, which mimics SQL, requesting a
    # tail means we have to get the results in descending ID order and reverse
    # them afterward.
    if args.tail is not None:
        logs = reversed(logs)
    latest_log_id, _ = process_response(logs, -1)

    # "Follow" mode is implemented as a loop in the CLI. We assume that
    # newer log messages have a numerically larger ID than older log
    # messages, so we keep track of the max ID seen so far.
    if args.follow:
        change_time = time.time()
        try:
            while True:
                # Poll for new logs at most every 100 ms.
                time.sleep(0.1)

                # The `tail` parameter only makes sense the first time we fetch logs.
                resp = logs_query(greater_than_id=latest_log_id).send()
                latest_log_id, changes = process_response(resp.trial_logs, latest_log_id)

                # Exit once the trial has, for 1 second, been in a terminal
                # state and sent no logs. KeyboardInterrupt doubles as the
                # exit path for both Ctrl-C and natural termination.
                if changes or resp.trials_by_pk.state not in constants.TERMINAL_STATES:
                    change_time = time.time()
                elif time.time() - change_time > 1:
                    raise KeyboardInterrupt()
        except KeyboardInterrupt:
            # Report the final state and how to reattach to the stream.
            state_query = api.GraphQLQuery(args.master)
            state_query.op.trials_by_pk(id=args.trial_id).state()
            resp = state_query.send()
            print(
                colored(
                    "Trial is in the {} state. To reopen log stream, run: "
                    "det trial logs -f {}".format(resp.trials_by_pk.state, args.trial_id),
                    "green",
                )
            )
def list(args: Namespace) -> None:
    """List an experiment's checkpoints and optionally download them.

    With --best N only the N best checkpoints (ordered by the signed
    validation metric) are shown. With --download-dir each listed checkpoint
    is downloaded from S3/GCS into a per-checkpoint subdirectory.
    """
    q = api.GraphQLQuery(args.master)
    # The checkpoint_storage config is needed later to build the storage
    # manager when --download-dir is given.
    q.op.experiments_by_pk(id=args.experiment_id).config(path="checkpoint_storage")

    # Order by the signed validation metric so the best checkpoints come first.
    order_by = [
        gql.checkpoints_order_by(
            validation=gql.validations_order_by(
                metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
            )
        )
    ]

    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    checkpoints = q.op.checkpoints(
        where=gql.checkpoints_bool_exp(
            step=gql.steps_bool_exp(
                trial=gql.trials_bool_exp(
                    experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
                )
            )
        ),
        order_by=order_by,
        limit=limit,
    )
    checkpoints.end_time()
    checkpoints.labels()
    checkpoints.resources()
    checkpoints.start_time()
    checkpoints.state()
    checkpoints.step_id()
    checkpoints.trial_id()
    checkpoints.uuid()
    checkpoints.step.validation.metric_values.raw()

    resp = q.send()
    config = resp.experiments_by_pk.config

    headers = ["Trial ID", "Step ID", "State", "Validation Metric", "UUID", "Resources", "Size"]
    values = [
        [
            c.trial_id,
            c.step_id,
            c.state,
            c.step.validation.metric_values.raw
            if c.step.validation and c.step.validation.metric_values
            else None,
            c.uuid,
            render.format_resources(c.resources),
            render.format_resource_sizes(c.resources),
        ]
        for c in resp.checkpoints
    ]

    render.tabulate_or_csv(headers, values, args.csv)

    if args.download_dir is not None:
        manager = storage.build(config)
        # Direct downloads are only supported for cloud storage backends.
        if not (
            isinstance(manager, storage.S3StorageManager)
            or isinstance(manager, storage.GCSStorageManager)
        ):
            print(
                "Downloading from S3 or GCS requires the experiment to be configured with "
                "S3 or GCS checkpointing, {} found instead".format(config["type"])
            )
            sys.exit(1)

        for checkpoint in resp.checkpoints:
            metadata = storage.StorageMetadata.from_json(checkpoint.__to_json_value__())
            # One subdirectory per checkpoint, named after its coordinates.
            ckpt_dir = args.download_dir.joinpath(
                "exp-{}-trial-{}-step-{}".format(
                    args.experiment_id, checkpoint.trial_id, checkpoint.step_id
                )
            )
            print("Downloading checkpoint {} to {}".format(checkpoint.uuid, ckpt_dir))
            manager.download(metadata, ckpt_dir)
def follow_test_experiment_logs(master_url: str, exp_id: int) -> None:
    """Poll a model-definition test experiment and render a four-stage
    progress line until it reaches a terminal state; exits non-zero on
    cancellation or error."""

    def print_progress(active_stage: int, ended: bool) -> None:
        # There are four sequential stages of verification. Track the
        # current stage with an index into this list.
        stages = [
            "Scheduling task",
            "Testing training",
            "Testing validation",
            "Testing checkpointing",
        ]
        for idx, stage in enumerate(stages):
            if active_stage > idx:
                # Stage already passed.
                color = "green"
                checkbox = "✔"
            elif active_stage == idx:
                # Current stage: red ✗ if we ended here, yellow blank if live.
                color = "red" if ended else "yellow"
                checkbox = "✗" if ended else " "
            else:
                # Stage not reached yet.
                color = "white"
                checkbox = " "
            print(colored(stage + (25 - len(stage)) * ".", color), end="")
            print(colored(" [" + checkbox + "]", color), end="")
            if idx == len(stages) - 1:
                # \r overwrites the line while running; \n keeps it when done.
                print("\n" if ended else "\r", end="")
            else:
                print(", ", end="")

    q = api.GraphQLQuery(master_url)
    exp = q.op.experiments_by_pk(id=exp_id)
    exp.state()
    steps = exp.trials.steps(
        order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.checkpoint().id()
    steps.validation().id()

    while True:
        exp = q.send().experiments_by_pk

        # Wait for experiment to start and initialize a trial and step.
        step = None
        if exp.trials and exp.trials[0].steps:
            step = exp.trials[0].steps[0]

        # Update the active stage by examining the status of the experiment.
        # The way the GraphQL library works is that the checkpoint and
        # validation attributes of a step are always present and non-None, but
        # they don't have any attributes of their own when the corresponding
        # database object doesn't exist.
        if exp.state == constants.COMPLETED:
            active_stage = 4
        elif step and hasattr(step.checkpoint, "id"):
            active_stage = 3
        elif step and hasattr(step.validation, "id"):
            active_stage = 2
        elif step:
            active_stage = 1
        else:
            active_stage = 0

        # If the experiment is in a terminal state, output the appropriate
        # message and exit. Otherwise, sleep and repeat.
        if exp.state == "COMPLETED":
            print_progress(active_stage, ended=True)
            print(colored("Model definition test succeeded! 🎉", "green"))
            return
        elif exp.state == constants.CANCELED:
            print_progress(active_stage, ended=True)
            print(
                colored(
                    "Model definition test (ID: {}) canceled before "
                    "model test could complete. Please re-run the "
                    "command.".format(exp_id),
                    "yellow",
                ))
            sys.exit(1)
        elif exp.state == constants.ERROR:
            print_progress(active_stage, ended=True)
            # On error, dump the first trial's logs to help debugging.
            trial_id = exp.trials[0].id
            logs_args = Namespace(trial_id=trial_id, master=master_url, tail=None, follow=False)
            logs(logs_args)
            sys.exit(1)
        else:
            print_progress(active_stage, ended=False)
            time.sleep(0.2)
def select_checkpoint(
    self,
    latest: bool = False,
    best: bool = False,
    uuid: Optional[str] = None,
    sort_by: Optional[str] = None,
    smaller_is_better: Optional[bool] = None,
) -> checkpoint.Checkpoint:
    """
    Return the :py:class:`det.experimental.Checkpoint` instance with the best
    validation metric as defined by the `sort_by` and `smaller_is_better`
    arguments.

    Exactly one of the best, latest, or uuid parameters must be set.

    Arguments:
        latest (bool, optional): return the most recent checkpoint.

        best (bool, optional): return the checkpoint with the best validation
            metric as defined by the `sort_by` and `smaller_is_better`
            arguments. If `sort_by` and `smaller_is_better` are not
            specified, the values from the associated experiment
            configuration will be used.

        uuid (string, optional): return the checkpoint for the specified uuid.

        sort_by (string, optional): the name of the validation metric to
            order checkpoints by. If this parameter is unset the metric
            defined in the related experiment configuration searcher field
            will be used.

        smaller_is_better (bool, optional): specifies whether to sort the
            metric above in ascending or descending order. If sort_by is
            unset, this parameter is ignored. By default the
            smaller_is_better value in the related experiment configuration
            is used.
    """
    # Validate the mutually-exclusive selector arguments up front.
    check.eq(
        sum([int(latest), int(best), int(uuid is not None)]),
        1,
        "Exactly one of latest, best, or uuid must be set",
    )
    check.eq(
        sort_by is None,
        smaller_is_better is None,
        "sort_by and smaller_is_better must be set together",
    )
    if sort_by is not None and not best:
        raise AssertionError(
            "sort_by and smaller_is_better parameters can only be used with --best"
        )

    q = api.GraphQLQuery(self._master)

    if sort_by is not None:
        # Custom metric: let the master pick the best checkpoint server-side.
        checkpoint_gql = q.op.best_checkpoint_by_metric(
            args={"tid": self.id, "metric": sort_by, "smaller_is_better": smaller_is_better},
        )
    else:
        where = gql.checkpoints_bool_exp(
            state=gql.checkpoint_state_comparison_exp(_eq="COMPLETED"),
            trial_id=gql.Int_comparison_exp(_eq=self.id),
        )

        order_by = []  # type: List[gql.checkpoints_order_by]
        if uuid is not None:
            where.uuid = gql.uuid_comparison_exp(_eq=uuid)
        elif latest:
            # Most recently finished checkpoint first.
            order_by = [gql.checkpoints_order_by(end_time=gql.order_by.desc)]
        elif best:
            # Best first, ordered by the signed validation metric; only
            # completed validations are considered.
            where.validation = gql.validations_bool_exp(
                state=gql.validation_state_comparison_exp(_eq="COMPLETED")
            )
            order_by = [
                gql.checkpoints_order_by(
                    validation=gql.validations_order_by(
                        metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
                    )
                )
            ]

        checkpoint_gql = q.op.checkpoints(where=where, order_by=order_by, limit=1)

    checkpoint_gql.state()
    checkpoint_gql.uuid()
    checkpoint_gql.resources()

    validation = checkpoint_gql.validation()
    validation.metrics()
    validation.state()

    step = checkpoint_gql.step()
    step.id()
    step.start_time()
    step.end_time()
    step.trial.experiment.config()

    resp = q.send()

    # The two query shapes above return their result under different keys.
    result = resp.best_checkpoint_by_metric if sort_by is not None else resp.checkpoints
    if not result:
        raise AssertionError("No checkpoint found for trial {}".format(self.id))

    ckpt_gql = result[0]
    # Convert the step ID into a global batch count.
    batch_number = ckpt_gql.step.trial.experiment.config["batches_per_step"] * ckpt_gql.step.id
    return checkpoint.Checkpoint(
        ckpt_gql.uuid,
        ckpt_gql.step.trial.experiment.config["checkpoint_storage"],
        batch_number,
        ckpt_gql.step.start_time,
        ckpt_gql.step.end_time,
        ckpt_gql.resources,
        ckpt_gql.validation,
    )
def describe(args: Namespace) -> None:
    """Print detailed information — experiment, trial, and step tables — for a
    comma-separated list of experiment IDs; with --outdir the tables are
    written as CSV files, with --metrics per-step metrics are included."""
    ids = [int(x) for x in args.experiment_ids.split(",")]

    q = api.GraphQLQuery(args.master)
    exps = q.op.experiments(where=gql.experiments_bool_exp(
        id=gql.Int_comparison_exp(_in=ids)))
    exps.archived()
    exps.config()
    exps.end_time()
    exps.id()
    exps.progress()
    exps.start_time()
    exps.state()

    trials = exps.trials(order_by=[gql.trials_order_by(id=gql.order_by.asc)])
    trials.end_time()
    trials.hparams()
    trials.id()
    trials.start_time()
    trials.state()

    steps = trials.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.end_time()
    steps.id()
    steps.start_time()
    steps.state()
    steps.trial_id()
    steps.checkpoint.end_time()
    steps.checkpoint.start_time()
    steps.checkpoint.state()
    steps.validation.end_time()
    steps.validation.start_time()
    steps.validation.state()
    if args.metrics:
        # Metrics are only requested when asked for; they can be large.
        steps.metrics(path="avg_metrics")
        steps.validation.metrics()

    resp = q.send()

    # Re-sort the experiment objects to match the original order.
    exps_by_id = {e.id: e for e in resp.experiments}
    experiments = [exps_by_id[id] for id in ids]

    if args.json:
        print(json.dumps(resp.__to_json_value__()["experiments"], indent=4))
        return

    # Display overall experiment information.
    headers = [
        "Experiment ID",
        "State",
        "Progress",
        "Start Time",
        "End Time",
        "Description",
        "Archived",
        "Labels",
    ]
    values = [[
        e.id,
        e.state,
        render.format_percent(e.progress),
        render.format_time(e.start_time),
        render.format_time(e.end_time),
        e.config.get("description"),
        e.archived,
        ", ".join(sorted(e.config.get("labels", []))),
    ] for e in experiments]
    if not args.outdir:
        outfile = None
        print("Experiment:")
    else:
        outfile = args.outdir.joinpath("experiments.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display trial-related information.
    headers = ["Trial ID", "Experiment ID", "State", "Start Time", "End Time", "H-Params"]
    values = [[
        t.id,
        e.id,
        t.state,
        render.format_time(t.start_time),
        render.format_time(t.end_time),
        json.dumps(t.hparams, indent=4),
    ] for e in experiments for t in e.trials]
    if not args.outdir:
        outfile = None
        print("\nTrials:")
    else:
        outfile = args.outdir.joinpath("trials.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display step-related information.
    if args.metrics:
        # Accumulate the scalar training and validation metric names from all
        # provided experiments.
        t_metrics_names = sorted(
            {n for e in experiments for n in scalar_training_metrics_names(e)})
        t_metrics_headers = [
            "Training Metric: {}".format(name) for name in t_metrics_names
        ]
        v_metrics_names = sorted({
            n for e in experiments for n in scalar_validation_metrics_names(e)
        })
        v_metrics_headers = [
            "Validation Metric: {}".format(name) for name in v_metrics_names
        ]
    else:
        t_metrics_headers = []
        v_metrics_headers = []

    headers = (["Trial ID", "Step ID", "State", "Start Time", "End Time"] +
               t_metrics_headers + [
                   "Checkpoint State",
                   "Checkpoint Start Time",
                   "Checkpoint End Time",
                   "Validation State",
                   "Validation Start Time",
                   "Validation End Time",
               ] + v_metrics_headers)

    values = []
    for e in experiments:
        for t in e.trials:
            for step in t.steps:
                # Training metrics only exist on the response when --metrics
                # was given (steps.metrics was selected above).
                t_metrics_fields = []
                if hasattr(step, "metrics"):
                    avg_metrics = step.metrics
                    for name in t_metrics_names:
                        if name in avg_metrics:
                            t_metrics_fields.append(avg_metrics[name])
                        else:
                            t_metrics_fields.append(None)
                checkpoint = step.checkpoint
                if checkpoint:
                    checkpoint_state = checkpoint.state
                    checkpoint_start_time = checkpoint.start_time
                    checkpoint_end_time = checkpoint.end_time
                else:
                    checkpoint_state = None
                    checkpoint_start_time = None
                    checkpoint_end_time = None
                validation = step.validation
                if validation:
                    validation_state = validation.state
                    validation_start_time = validation.start_time
                    validation_end_time = validation.end_time
                else:
                    validation_state = None
                    validation_start_time = None
                    validation_end_time = None
                if args.metrics:
                    v_metrics_fields = [
                        api.metric.get_validation_metric(name, validation)
                        for name in v_metrics_names
                    ]
                else:
                    v_metrics_fields = []
                row = ([
                    step.trial_id,
                    step.id,
                    step.state,
                    render.format_time(step.start_time),
                    render.format_time(step.end_time),
                ] + t_metrics_fields + [
                    checkpoint_state,
                    render.format_time(checkpoint_start_time),
                    render.format_time(checkpoint_end_time),
                    validation_state,
                    render.format_time(validation_start_time),
                    render.format_time(validation_end_time),
                ] + v_metrics_fields)
                values.append(row)
    if not args.outdir:
        outfile = None
        print("\nSteps:")
    else:
        outfile = args.outdir.joinpath("steps.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)