예제 #1
0
def num_trials_by_state(experiment_id: int, state: str) -> int:
    """Count the trials of one experiment that are in a given state.

    Sends a single aggregate query for the experiment whose primary key is
    ``experiment_id``, restricted to trials whose state equals ``state``.

    :param experiment_id: ID of the experiment whose trials are counted.
    :param state: trial state to match exactly (``_eq`` comparison).
    :return: the aggregate count of matching trials.
    """
    state_filter = gql.trials_bool_exp(
        state=gql.trial_state_comparison_exp(_eq=state))

    query_obj = query()
    experiment = query_obj.op.experiments_by_pk(id=experiment_id)
    experiment.trials_aggregate(where=state_filter).aggregate.count()

    result = query_obj.send()
    total = result.experiments_by_pk.trials_aggregate.aggregate.count
    # `cast` only informs the type checker; the value is returned unchanged.
    return cast(int, total)
예제 #2
0
def follow_experiment_logs(master_url: str, exp_id: int) -> None:
    """Wait for an experiment's first trial to appear, then follow its logs.

    The "first" trial is the one with the lowest ID. The master is polled,
    sleeping 0.1 seconds between attempts, until at least one trial exists.
    There is no timeout. The trial's logs are then streamed through the
    ``logs`` command with ``follow=True``.

    :param master_url: address of the master to query.
    :param exp_id: ID of the experiment to follow.
    """
    query_obj = api.GraphQLQuery(master_url)
    of_experiment = gql.trials_bool_exp(
        experiment_id=gql.Int_comparison_exp(_eq=exp_id))
    lowest_id_first = [gql.trials_order_by(id=gql.order_by.asc)]
    query_obj.op.trials(
        where=of_experiment,
        order_by=lowest_id_first,
        limit=1,
    ).id()

    print("Waiting for first trial to begin...")
    response = query_obj.send()
    while not response.trials:
        time.sleep(0.1)
        response = query_obj.send()

    trial_id = response.trials[0].id
    print("Following first trial with ID {}".format(trial_id))

    # Reuse the `logs` command, emulating `logs --follow <trial_id>`.
    logs(Namespace(trial_id=trial_id,
                   follow=True,
                   master=master_url,
                   tail=None))
예제 #3
0
def list_trials(args: Namespace) -> None:
    """Print every trial of an experiment as a table (or CSV).

    Trials are requested in ascending ID order. Each row holds the trial ID,
    its state, its hyperparameters as indented JSON, its start and end times,
    and its step count.

    :param args: parsed CLI arguments; reads ``master``, ``experiment_id``
        and ``csv``.
    """
    query_obj = api.GraphQLQuery(args.master)
    selection = query_obj.op.trials(
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        where=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(
            _eq=args.experiment_id)),
    )
    # Select the scalar fields in the same order as the table columns.
    for field_name in ("id", "state", "hparams", "start_time", "end_time"):
        getattr(selection, field_name)()
    selection.steps_aggregate().aggregate.count()

    response = query_obj.send()

    def as_row(trial):
        return [
            trial.id,
            trial.state,
            json.dumps(trial.hparams, indent=4),
            render.format_time(trial.start_time),
            render.format_time(trial.end_time),
            trial.steps_aggregate.aggregate.count,
        ]

    rows = [as_row(trial) for trial in response.trials]
    headers = [
        "Trial ID", "State", "H-Params", "Start Time", "End Time", "# of Steps"
    ]
    render.tabulate_or_csv(headers, rows, args.csv)
예제 #4
0
def list(args: Namespace) -> None:
    """Print an experiment's checkpoints as a table (or CSV).

    Covers checkpoints belonging to any trial of ``args.experiment_id``. They
    are requested in ascending order of the ``signed`` field of their
    validation metrics. When ``args.best`` is given, it is used as the query
    limit.

    :param args: parsed CLI arguments; reads ``master``, ``experiment_id``,
        ``best`` and ``csv``.
    :raises AssertionError: if ``args.best`` is negative.
    """
    query_obj = api.GraphQLQuery(args.master)
    # NOTE: the storage config is selected here but never read back below.
    experiment = query_obj.op.experiments_by_pk(id=args.experiment_id)
    experiment.config(path="checkpoint_storage")

    metric_order = gql.validation_metrics_order_by(signed=gql.order_by.asc)
    by_metric = [
        gql.checkpoints_order_by(
            validation=gql.validations_order_by(metric_values=metric_order))
    ]

    if args.best is None:
        limit = None
    elif args.best < 0:
        raise AssertionError("--best must be a non-negative integer")
    else:
        limit = args.best

    experiment_filter = gql.Int_comparison_exp(_eq=args.experiment_id)
    in_experiment = gql.checkpoints_bool_exp(step=gql.steps_bool_exp(
        trial=gql.trials_bool_exp(experiment_id=experiment_filter)))

    selection = query_obj.op.checkpoints(
        where=in_experiment,
        order_by=by_metric,
        limit=limit,
    )
    for field_name in ("end_time", "labels", "resources", "start_time",
                       "state", "step_id", "trial_id", "uuid"):
        getattr(selection, field_name)()
    selection.step.validation.metric_values.raw()

    response = query_obj.send()

    def metric_of(ckpt):
        # Checkpoints with no validation (or no metric values) yield None.
        validation = ckpt.step.validation
        if validation and validation.metric_values:
            return validation.metric_values.raw
        return None

    rows = [[
        ckpt.trial_id,
        ckpt.step_id,
        ckpt.state,
        metric_of(ckpt),
        ckpt.uuid,
        render.format_resources(ckpt.resources),
        render.format_resource_sizes(ckpt.resources),
    ] for ckpt in response.checkpoints]
    headers = [
        "Trial ID", "Step ID", "State", "Validation Metric", "UUID",
        "Resources", "Size"
    ]
    render.tabulate_or_csv(headers, rows, args.csv)
예제 #5
0
def list(args: Namespace) -> None:
    """Print an experiment's checkpoints and optionally download them.

    Checkpoints of every trial in ``args.experiment_id`` are requested in ascending
    order of the ``signed`` field of their validation metrics. They are rendered as a
    table, or as CSV when ``args.csv`` is set. ``args.best``, if given, is used as the
    query limit.

    When ``args.download_dir`` is set, each listed checkpoint is downloaded into
    ``download_dir / "exp-<experiment>-trial-<trial>-step-<step>"``. Downloading is
    only supported for S3 and GCS checkpoint storage.

    :param args: parsed CLI arguments; reads ``master``, ``experiment_id``, ``best``,
        ``csv`` and ``download_dir``. ``download_dir`` must provide ``joinpath``
        (e.g. a ``pathlib.Path``) when set.
    :raises AssertionError: if ``args.best`` is negative.
    :raises SystemExit: with status 1 if downloading is requested but the experiment
        is not found or its checkpoint storage is neither S3 nor GCS.
    """
    q = api.GraphQLQuery(args.master)
    q.op.experiments_by_pk(id=args.experiment_id).config(path="checkpoint_storage")

    order_by = [
        gql.checkpoints_order_by(
            validation=gql.validations_order_by(
                metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
            )
        )
    ]

    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    checkpoints = q.op.checkpoints(
        where=gql.checkpoints_bool_exp(
            step=gql.steps_bool_exp(
                trial=gql.trials_bool_exp(
                    experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
                )
            )
        ),
        order_by=order_by,
        limit=limit,
    )
    checkpoints.end_time()
    checkpoints.labels()
    checkpoints.resources()
    checkpoints.start_time()
    checkpoints.state()
    checkpoints.step_id()
    checkpoints.trial_id()
    checkpoints.uuid()

    checkpoints.step.validation.metric_values.raw()

    resp = q.send()

    headers = ["Trial ID", "Step ID", "State", "Validation Metric", "UUID", "Resources", "Size"]
    values = [
        [
            c.trial_id,
            c.step_id,
            c.state,
            c.step.validation.metric_values.raw
            if c.step.validation and c.step.validation.metric_values
            else None,
            c.uuid,
            render.format_resources(c.resources),
            render.format_resource_sizes(c.resources),
        ]
        for c in resp.checkpoints
    ]

    render.tabulate_or_csv(headers, values, args.csv)

    if args.download_dir is None:
        return

    # The experiment record is only needed for its storage config, so dereference it
    # only on the download path. A `_by_pk` lookup may come back null (e.g. no
    # experiment with this ID). Plain listing then still prints its (empty) table
    # instead of crashing with an AttributeError.
    experiment = resp.experiments_by_pk
    if experiment is None:
        print("Experiment {} not found".format(args.experiment_id), file=sys.stderr)
        sys.exit(1)
    config = experiment.config

    manager = storage.build(config)
    if not isinstance(manager, (storage.S3StorageManager, storage.GCSStorageManager)):
        # Diagnostics go to stderr so they are not mixed into table/CSV output on stdout.
        print(
            "Downloading from S3 or GCS requires the experiment to be configured with "
            "S3 or GCS checkpointing, {} found instead".format(config["type"]),
            file=sys.stderr,
        )
        sys.exit(1)

    for checkpoint in resp.checkpoints:
        # The storage manager takes metadata built from the checkpoint's JSON form.
        metadata = storage.StorageMetadata.from_json(checkpoint.__to_json_value__())
        ckpt_dir = args.download_dir.joinpath(
            "exp-{}-trial-{}-step-{}".format(
                args.experiment_id, checkpoint.trial_id, checkpoint.step_id
            )
        )
        print("Downloading checkpoint {} to {}".format(checkpoint.uuid, ckpt_dir))
        manager.download(metadata, ckpt_dir)