def num_trials_by_state(experiment_id: int, state: str) -> int:
    q = query()
    q.op.experiments_by_pk(id=experiment_id).trials_aggregate(
        where=gql.trials_bool_exp(state=gql.trial_state_comparison_exp(_eq=state))
    ).aggregate.count()

    r = q.send()
    return cast(int, r.experiments_by_pk.trials_aggregate.aggregate.count)
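
# Hedged usage sketch (not part of the original module): count the trials of one
# experiment in a given state. The experiment ID and the "COMPLETED" state string
# below are assumed, illustrative values.
#
#     n_done = num_trials_by_state(experiment_id=7, state="COMPLETED")
#     print("completed trials:", n_done)
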
def follow_experiment_logs(master_url: str, exp_id: int) -> None:
    # Get the ID of this experiment's first trial (i.e., the one with the lowest ID).
    q = api.GraphQLQuery(master_url)
    trials = q.op.trials(
        where=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(_eq=exp_id)),
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        limit=1,
    )
    trials.id()

    print("Waiting for first trial to begin...")
    while True:
        resp = q.send()
        if resp.trials:
            break
        else:
            time.sleep(0.1)

    first_trial_id = resp.trials[0].id
    print("Following first trial with ID {}".format(first_trial_id))

    # Call `logs --follow` on the new trial.
    logs_args = Namespace(trial_id=first_trial_id, follow=True, master=master_url, tail=None)
    logs(logs_args)
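
# Hedged usage sketch (not part of the original module): stream the logs of an
# experiment's first trial. The master URL and experiment ID are assumed,
# illustrative values.
#
#     follow_experiment_logs("http://localhost:8080", exp_id=3)
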
def list_trials(args: Namespace) -> None:
    q = api.GraphQLQuery(args.master)
    trials = q.op.trials(
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        where=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)),
    )
    trials.id()
    trials.state()
    trials.hparams()
    trials.start_time()
    trials.end_time()
    trials.steps_aggregate().aggregate.count()

    resp = q.send()

    headers = ["Trial ID", "State", "H-Params", "Start Time", "End Time", "# of Steps"]
    values = [
        [
            t.id,
            t.state,
            json.dumps(t.hparams, indent=4),
            render.format_time(t.start_time),
            render.format_time(t.end_time),
            t.steps_aggregate.aggregate.count,
        ]
        for t in resp.trials
    ]

    render.tabulate_or_csv(headers, values, args.csv)
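
# Hedged usage sketch (not part of the original module): the Namespace fields
# mirror the attributes this function reads (master, experiment_id, csv); all
# values are assumed, illustrative ones.
#
#     list_trials(Namespace(master="http://localhost:8080", experiment_id=3, csv=False))
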
def list(args: Namespace) -> None:
    q = api.GraphQLQuery(args.master)
    q.op.experiments_by_pk(id=args.experiment_id).config(path="checkpoint_storage")

    order_by = [
        gql.checkpoints_order_by(
            validation=gql.validations_order_by(
                metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
            )
        )
    ]

    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    checkpoints = q.op.checkpoints(
        where=gql.checkpoints_bool_exp(
            step=gql.steps_bool_exp(
                trial=gql.trials_bool_exp(
                    experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
                )
            )
        ),
        order_by=order_by,
        limit=limit,
    )
    checkpoints.end_time()
    checkpoints.labels()
    checkpoints.resources()
    checkpoints.start_time()
    checkpoints.state()
    checkpoints.step_id()
    checkpoints.trial_id()
    checkpoints.uuid()
    checkpoints.step.validation.metric_values.raw()

    resp = q.send()
    config = resp.experiments_by_pk.config

    headers = ["Trial ID", "Step ID", "State", "Validation Metric", "UUID", "Resources", "Size"]
    values = [
        [
            c.trial_id,
            c.step_id,
            c.state,
            c.step.validation.metric_values.raw
            if c.step.validation and c.step.validation.metric_values
            else None,
            c.uuid,
            render.format_resources(c.resources),
            render.format_resource_sizes(c.resources),
        ]
        for c in resp.checkpoints
    ]

    render.tabulate_or_csv(headers, values, args.csv)

    if args.download_dir is not None:
        manager = storage.build(config)
        if not (
            isinstance(manager, storage.S3StorageManager)
            or isinstance(manager, storage.GCSStorageManager)
        ):
            print(
                "Downloading from S3 or GCS requires the experiment to be configured with "
                "S3 or GCS checkpointing, {} found instead".format(config["type"])
            )
            sys.exit(1)

        for checkpoint in resp.checkpoints:
            metadata = storage.StorageMetadata.from_json(checkpoint.__to_json_value__())
            ckpt_dir = args.download_dir.joinpath(
                "exp-{}-trial-{}-step-{}".format(
                    args.experiment_id, checkpoint.trial_id, checkpoint.step_id
                )
            )
            print("Downloading checkpoint {} to {}".format(checkpoint.uuid, ckpt_dir))
            manager.download(metadata, ckpt_dir)
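
# Hedged usage sketch (not part of the original module): list the single best
# checkpoint of an experiment and download it. The Namespace fields mirror the
# attributes this function reads (master, experiment_id, best, csv,
# download_dir); all values are assumed, illustrative ones.
#
#     from pathlib import Path
#     list(Namespace(master="http://localhost:8080", experiment_id=3,
#                    best=1, csv=False, download_dir=Path("/tmp/checkpoints")))
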