Example #1
def list_trials(args: Namespace) -> None:
    q = api.GraphQLQuery(args.master)
    trials = q.op.trials(
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        where=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(
            _eq=args.experiment_id)),
    )
    trials.id()
    trials.state()
    trials.hparams()
    trials.start_time()
    trials.end_time()
    trials.steps_aggregate().aggregate.count()

    resp = q.send()

    headers = [
        "Trial ID", "State", "H-Params", "Start Time", "End Time", "# of Steps"
    ]
    values = [[
        t.id,
        t.state,
        json.dumps(t.hparams, indent=4),
        render.format_time(t.start_time),
        render.format_time(t.end_time),
        t.steps_aggregate.aggregate.count,
    ] for t in resp.trials]

    render.tabulate_or_csv(headers, values, args.csv)
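
A minimal invocation sketch, not part of the original source: the function only reads `master`, `experiment_id`, and `csv` from `args`, so an argparse `Namespace` with those three fields is enough to drive it (the URL and experiment id below are placeholders).

from argparse import Namespace

# Hypothetical standalone call; assumes the surrounding modules (api, gql,
# render) are importable and a master is reachable at the given address.
args = Namespace(master="http://localhost:8080", experiment_id=4, csv=False)
list_trials(args)  # prints one row per trial of the experiment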
Example #2
    def logs_query(tail: Optional[int] = None, greater_than_id: Optional[int] = None) -> Any:
        q = api.GraphQLQuery(args.master)
        limit = None
        order_by = [gql.trial_logs_order_by(id=gql.order_by.asc)]
        where = gql.trial_logs_bool_exp(trial_id=gql.Int_comparison_exp(_eq=args.trial_id))
        if greater_than_id is not None:
            # Only return log entries newer than the given id (used when following).
            where.id = gql.Int_comparison_exp(_gt=greater_than_id)
        if tail is not None:
            # Tail mode: order descending and limit to the newest `tail` entries;
            # the caller is expected to restore chronological order.
            order_by = [gql.trial_logs_order_by(id=gql.order_by.desc)]
            limit = tail
        logs = q.op.trial_logs(where=where, order_by=order_by, limit=limit)
        logs.id()
        logs.message()

        # Also select the trial's state, e.g. so a caller that is following logs
        # can tell when the trial has finished.
        q.op.trials_by_pk(id=args.trial_id).state()
        return q
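
An illustrative driver for `logs_query`; this is an assumption about how the enclosing `logs` command might use it, not a copy of that code: print the current logs, then poll with `greater_than_id` while following.

# Hypothetical usage sketch (not the original CLI body). Assumes `args` also has
# `tail` and `follow` fields and that `api.decode_bytes` and `time` are in scope.
resp = logs_query(tail=args.tail).send()
entries = resp.trial_logs
if args.tail is not None:
    # Tail queries come back newest-first, so restore chronological order.
    entries = list(reversed(entries))
last_id = entries[-1].id if entries else None
for log in entries:
    print(api.decode_bytes(log.message), end="")
while args.follow:
    resp = logs_query(greater_than_id=last_id).send()
    for log in resp.trial_logs:
        print(api.decode_bytes(log.message), end="")
        last_id = log.id
    # State names assumed; "COMPLETED" matches the pattern used elsewhere in these examples.
    if resp.trials_by_pk.state in ("COMPLETED", "CANCELED", "ERROR"):
        break
    time.sleep(0.5)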
Example #3
def follow_experiment_logs(master_url: str, exp_id: int) -> None:
    # Get the ID of this experiment's first trial (i.e., the one with the lowest ID).
    q = api.GraphQLQuery(master_url)
    trials = q.op.trials(
        where=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(
            _eq=exp_id)),
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        limit=1,
    )
    trials.id()

    print("Waiting for first trial to begin...")
    while True:
        resp = q.send()
        if resp.trials:
            break
        else:
            time.sleep(0.1)

    first_trial_id = resp.trials[0].id
    print("Following first trial with ID {}".format(first_trial_id))

    # Call `logs --follow` on the new trial.
    logs_args = Namespace(trial_id=first_trial_id,
                          follow=True,
                          master=master_url,
                          tail=None)
    logs(logs_args)
Example #4
def trial_logs(trial_id: int) -> List[str]:
    q = query()
    q.op.trial_logs(
        where=gql.trial_logs_bool_exp(trial_id=gql.Int_comparison_exp(
            _eq=trial_id)),
        order_by=[gql.trial_logs_order_by(id=gql.order_by.asc)],
    ).message()
    r = q.send()
    return [api.decode_bytes(t.message) for t in r.trial_logs]
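
A small usage sketch, assuming a master is reachable behind the `query()` helper; the trial id is a placeholder.

# Hypothetical usage: fetch every log line of trial 1, in id order.
lines = trial_logs(1)
print("{} log lines".format(len(lines)))
print("".join(lines))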
Example #5
def list(args: Namespace) -> None:
    q = api.GraphQLQuery(args.master)
    q.op.experiments_by_pk(id=args.experiment_id).config(
        path="checkpoint_storage")

    order_by = [
        gql.checkpoints_order_by(validation=gql.validations_order_by(
            metric_values=gql.validation_metrics_order_by(
                signed=gql.order_by.asc)))
    ]

    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    checkpoints = q.op.checkpoints(
        where=gql.checkpoints_bool_exp(step=gql.steps_bool_exp(
            trial=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(
                _eq=args.experiment_id)))),
        order_by=order_by,
        limit=limit,
    )
    checkpoints.end_time()
    checkpoints.labels()
    checkpoints.resources()
    checkpoints.start_time()
    checkpoints.state()
    checkpoints.step_id()
    checkpoints.trial_id()
    checkpoints.uuid()

    checkpoints.step.validation.metric_values.raw()

    resp = q.send()

    headers = [
        "Trial ID", "Step ID", "State", "Validation Metric", "UUID",
        "Resources", "Size"
    ]
    values = [[
        c.trial_id,
        c.step_id,
        c.state,
        c.step.validation.metric_values.raw
        if c.step.validation and c.step.validation.metric_values else None,
        c.uuid,
        render.format_resources(c.resources),
        render.format_resource_sizes(c.resources),
    ] for c in resp.checkpoints]

    render.tabulate_or_csv(headers, values, args.csv)
Example #6
def describe(args: Namespace) -> None:
    ids = [int(x) for x in args.experiment_ids.split(",")]

    q = api.GraphQLQuery(args.master)
    exps = q.op.experiments(where=gql.experiments_bool_exp(
        id=gql.Int_comparison_exp(_in=ids)))
    exps.archived()
    exps.config()
    exps.end_time()
    exps.id()
    exps.progress()
    exps.start_time()
    exps.state()

    trials = exps.trials(order_by=[gql.trials_order_by(id=gql.order_by.asc)])
    trials.end_time()
    trials.hparams()
    trials.id()
    trials.start_time()
    trials.state()

    steps = trials.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.end_time()
    steps.id()
    steps.start_time()
    steps.state()
    steps.trial_id()

    steps.checkpoint.end_time()
    steps.checkpoint.start_time()
    steps.checkpoint.state()

    steps.validation.end_time()
    steps.validation.start_time()
    steps.validation.state()

    if args.metrics:
        steps.metrics(path="avg_metrics")
        steps.validation.metrics()

    resp = q.send()

    # Re-sort the experiment objects to match the original order.
    exps_by_id = {e.id: e for e in resp.experiments}
    experiments = [exps_by_id[id] for id in ids]

    if args.json:
        print(json.dumps(resp.__to_json_value__()["experiments"], indent=4))
        return

    # Display overall experiment information.
    headers = [
        "Experiment ID",
        "State",
        "Progress",
        "Start Time",
        "End Time",
        "Description",
        "Archived",
        "Labels",
    ]
    values = [[
        e.id,
        e.state,
        render.format_percent(e.progress),
        render.format_time(e.start_time),
        render.format_time(e.end_time),
        e.config.get("description"),
        e.archived,
        ", ".join(sorted(e.config.get("labels", []))),
    ] for e in experiments]
    if not args.outdir:
        outfile = None
        print("Experiment:")
    else:
        outfile = args.outdir.joinpath("experiments.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display trial-related information.
    headers = [
        "Trial ID", "Experiment ID", "State", "Start Time", "End Time",
        "H-Params"
    ]
    values = [[
        t.id,
        e.id,
        t.state,
        render.format_time(t.start_time),
        render.format_time(t.end_time),
        json.dumps(t.hparams, indent=4),
    ] for e in experiments for t in e.trials]
    if not args.outdir:
        outfile = None
        print("\nTrials:")
    else:
        outfile = args.outdir.joinpath("trials.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display step-related information.
    if args.metrics:
        # Accumulate the scalar training and validation metric names from all provided experiments.
        t_metrics_names = sorted(
            {n
             for e in experiments for n in scalar_training_metrics_names(e)})
        t_metrics_headers = [
            "Training Metric: {}".format(name) for name in t_metrics_names
        ]

        v_metrics_names = sorted({
            n
            for e in experiments for n in scalar_validation_metrics_names(e)
        })
        v_metrics_headers = [
            "Validation Metric: {}".format(name) for name in v_metrics_names
        ]
    else:
        t_metrics_headers = []
        v_metrics_headers = []

    headers = (["Trial ID", "Step ID", "State", "Start Time", "End Time"] +
               t_metrics_headers + [
                   "Checkpoint State",
                   "Checkpoint Start Time",
                   "Checkpoint End Time",
                   "Validation State",
                   "Validation Start Time",
                   "Validation End Time",
               ] + v_metrics_headers)

    values = []
    for e in experiments:
        for t in e.trials:
            for step in t.steps:
                t_metrics_fields = []
                if hasattr(step, "metrics"):
                    avg_metrics = step.metrics
                    for name in t_metrics_names:
                        if name in avg_metrics:
                            t_metrics_fields.append(avg_metrics[name])
                        else:
                            t_metrics_fields.append(None)

                checkpoint = step.checkpoint
                if checkpoint:
                    checkpoint_state = checkpoint.state
                    checkpoint_start_time = checkpoint.start_time
                    checkpoint_end_time = checkpoint.end_time
                else:
                    checkpoint_state = None
                    checkpoint_start_time = None
                    checkpoint_end_time = None

                validation = step.validation
                if validation:
                    validation_state = validation.state
                    validation_start_time = validation.start_time
                    validation_end_time = validation.end_time

                else:
                    validation_state = None
                    validation_start_time = None
                    validation_end_time = None

                if args.metrics:
                    v_metrics_fields = [
                        api.metric.get_validation_metric(name, validation)
                        for name in v_metrics_names
                    ]
                else:
                    v_metrics_fields = []

                row = ([
                    step.trial_id,
                    step.id,
                    step.state,
                    render.format_time(step.start_time),
                    render.format_time(step.end_time),
                ] + t_metrics_fields + [
                    checkpoint_state,
                    render.format_time(checkpoint_start_time),
                    render.format_time(checkpoint_end_time),
                    validation_state,
                    render.format_time(validation_start_time),
                    render.format_time(validation_end_time),
                ] + v_metrics_fields)
                values.append(row)

    if not args.outdir:
        outfile = None
        print("\nSteps:")
    else:
        outfile = args.outdir.joinpath("steps.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)
Example #7
def list(args: Namespace) -> None:
    q = api.GraphQLQuery(args.master)
    q.op.experiments_by_pk(id=args.experiment_id).config(path="checkpoint_storage")

    # Order checkpoints by the signed validation metric so that, combined with
    # the --best limit below, the best checkpoints are returned first.
    order_by = [
        gql.checkpoints_order_by(
            validation=gql.validations_order_by(
                metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
            )
        )
    ]

    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    checkpoints = q.op.checkpoints(
        where=gql.checkpoints_bool_exp(
            step=gql.steps_bool_exp(
                trial=gql.trials_bool_exp(
                    experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
                )
            )
        ),
        order_by=order_by,
        limit=limit,
    )
    checkpoints.end_time()
    checkpoints.labels()
    checkpoints.resources()
    checkpoints.start_time()
    checkpoints.state()
    checkpoints.step_id()
    checkpoints.trial_id()
    checkpoints.uuid()

    checkpoints.step.validation.metric_values.raw()

    resp = q.send()
    config = resp.experiments_by_pk.config

    headers = ["Trial ID", "Step ID", "State", "Validation Metric", "UUID", "Resources", "Size"]
    values = [
        [
            c.trial_id,
            c.step_id,
            c.state,
            c.step.validation.metric_values.raw
            if c.step.validation and c.step.validation.metric_values
            else None,
            c.uuid,
            render.format_resources(c.resources),
            render.format_resource_sizes(c.resources),
        ]
        for c in resp.checkpoints
    ]

    render.tabulate_or_csv(headers, values, args.csv)

    if args.download_dir is not None:
        manager = storage.build(config)
        if not (
            isinstance(manager, storage.S3StorageManager)
            or isinstance(manager, storage.GCSStorageManager)
        ):
            print(
                "Downloading from S3 or GCS requires the experiment to be configured with "
                "S3 or GCS checkpointing, {} found instead".format(config["type"])
            )
            sys.exit(1)

        for checkpoint in resp.checkpoints:
            metadata = storage.StorageMetadata.from_json(checkpoint.__to_json_value__())
            ckpt_dir = args.download_dir.joinpath(
                "exp-{}-trial-{}-step-{}".format(
                    args.experiment_id, checkpoint.trial_id, checkpoint.step_id
                )
            )
            print("Downloading checkpoint {} to {}".format(checkpoint.uuid, ckpt_dir))
            manager.download(metadata, ckpt_dir)
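
A hypothetical invocation of this fuller `list` variant; the field names mirror what the snippet reads from `args`, and the values are placeholders.

from argparse import Namespace
from pathlib import Path

# download_dir may be None to skip downloading; a Path is assumed because the
# snippet calls args.download_dir.joinpath(...).
args = Namespace(
    master="http://localhost:8080",
    experiment_id=4,
    best=3,            # keep only the 3 best checkpoints by validation metric
    csv=False,
    download_dir=Path("./checkpoints"),
)
list(args)  # calls the `list` command defined above (which shadows the builtin)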
Example #8
    def select_checkpoint(
        self,
        latest: bool = False,
        best: bool = False,
        uuid: Optional[str] = None,
        sort_by: Optional[str] = None,
        smaller_is_better: Optional[bool] = None,
    ) -> checkpoint.Checkpoint:
        """
        Return the :py:class:`det.experimental.Checkpoint` instance selected by
        the `latest`, `best`, or `uuid` argument. When `best` is used, the
        ordering is defined by the `sort_by` and `smaller_is_better` arguments.

        Exactly one of the best, latest, or uuid parameters must be set.

        Arguments:
            latest (bool, optional): return the most recent checkpoint.

            best (bool, optional): return the checkpoint with the best validation
                metric as defined by the `sort_by` and `smaller_is_better`
                arguments. If `sort_by` and `smaller_is_better` are not
                specified, the values from the associated experiment
                configuration will be used.

            uuid (string, optional): return the checkpoint for the specified uuid.

            sort_by (string, optional): the name of the validation metric to
                order checkpoints by. If this parameter is unset the metric defined
                in the related experiment configuration searcher field will be
                used.

            smaller_is_better (bool, optional): specifies whether to sort the
                metric above in ascending or descending order. Must be set
                together with sort_by. By default the smaller_is_better value
                in the related experiment configuration is used.
        """
        check.eq(
            sum([int(latest), int(best), int(uuid is not None)]),
            1,
            "Exactly one of latest, best, or uuid must be set",
        )

        check.eq(
            sort_by is None,
            smaller_is_better is None,
            "sort_by and smaller_is_better must be set together",
        )

        if sort_by is not None and not best:
            raise AssertionError(
                "sort_by and smaller_is_better parameters can only be used with --best"
            )

        q = api.GraphQLQuery(self._master)

        if sort_by is not None:
            checkpoint_gql = q.op.best_checkpoint_by_metric(
                args={"tid": self.id, "metric": sort_by, "smaller_is_better": smaller_is_better},
            )
        else:
            where = gql.checkpoints_bool_exp(
                state=gql.checkpoint_state_comparison_exp(_eq="COMPLETED"),
                trial_id=gql.Int_comparison_exp(_eq=self.id),
            )

            order_by = []  # type: List[gql.checkpoints_order_by]
            if uuid is not None:
                where.uuid = gql.uuid_comparison_exp(_eq=uuid)
            elif latest:
                order_by = [gql.checkpoints_order_by(end_time=gql.order_by.desc)]
            elif best:
                where.validation = gql.validations_bool_exp(
                    state=gql.validation_state_comparison_exp(_eq="COMPLETED")
                )
                order_by = [
                    gql.checkpoints_order_by(
                        validation=gql.validations_order_by(
                            metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
                        )
                    )
                ]

            checkpoint_gql = q.op.checkpoints(where=where, order_by=order_by, limit=1)

        checkpoint_gql.state()
        checkpoint_gql.uuid()
        checkpoint_gql.resources()

        validation = checkpoint_gql.validation()
        validation.metrics()
        validation.state()

        step = checkpoint_gql.step()
        step.id()
        step.start_time()
        step.end_time()
        step.trial.experiment.config()

        resp = q.send()

        result = resp.best_checkpoint_by_metric if sort_by is not None else resp.checkpoints

        if not result:
            raise AssertionError("No checkpoint found for trial {}".format(self.id))

        ckpt_gql = result[0]
        batch_number = ckpt_gql.step.trial.experiment.config["batches_per_step"] * ckpt_gql.step.id
        return checkpoint.Checkpoint(
            ckpt_gql.uuid,
            ckpt_gql.step.trial.experiment.config["checkpoint_storage"],
            batch_number,
            ckpt_gql.step.start_time,
            ckpt_gql.step.end_time,
            ckpt_gql.resources,
            ckpt_gql.validation,
        )
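
A usage sketch based on the docstring and the checks above: exactly one of `latest`, `best`, or `uuid` is set per call, and `sort_by`/`smaller_is_better` only make sense together with `best` (the metric name and the `trial` object below are hypothetical).

# Hypothetical calls on a trial reference that exposes select_checkpoint().
latest_ckpt = trial.select_checkpoint(latest=True)
best_ckpt = trial.select_checkpoint(best=True)  # uses the experiment's searcher settings
loss_ckpt = trial.select_checkpoint(
    best=True, sort_by="validation_loss", smaller_is_better=True
)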