Exemplo n.º 1
0
def config(args: Namespace) -> None:
    """Fetch the configuration of one experiment and dump it as YAML to stdout."""
    query = api.GraphQLQuery(args.master)
    query.op.experiments_by_pk(id=args.experiment_id).config()
    response = query.send()
    yaml.safe_dump(
        response.experiments_by_pk.config,
        stream=sys.stdout,
        default_flow_style=False,
    )
Exemplo n.º 2
0
def download(master: str, trial_id: int, step_id: int, output_dir: str) -> None:
    """Download the checkpoint of a given trial step into ``output_dir``.

    Arguments:
        master: master URL to query.
        trial_id: ID of the trial owning the step.
        step_id: ID of the step whose checkpoint to download.
        output_dir: local directory the checkpoint is downloaded to.

    Raises:
        ValueError: if the step or its checkpoint does not exist.
        AssertionError: if the experiment's checkpoint storage is not S3 or GCS.
    """
    q = api.GraphQLQuery(master)

    # Select the checkpoint fields plus the experiment's checkpoint_storage
    # config, which is needed below to build the storage manager.
    step = q.op.steps_by_pk(trial_id=trial_id, id=step_id)
    step.checkpoint.labels()
    step.checkpoint.resources()
    step.checkpoint.uuid()
    step.trial.experiment.config(path="checkpoint_storage")
    step.trial.experiment_id()

    resp = q.send()

    step = resp.steps_by_pk
    if not step:
        raise ValueError("Trial {} step {} not found".format(trial_id, step_id))

    if not step.checkpoint:
        raise ValueError("Trial {} step {} has no checkpoint".format(trial_id, step_id))

    storage_config = step.trial.experiment.config
    manager = storage.build(storage_config)
    # isinstance accepts a tuple of types; clearer than two chained checks.
    if not isinstance(manager, (storage.S3StorageManager, storage.GCSStorageManager)):
        raise AssertionError(
            "Downloading from S3 or GCS requires the experiment to be configured with "
            "S3 or GCS checkpointing, {} found instead".format(storage_config["type"])
        )
    metadata = storage.StorageMetadata.from_json(step.checkpoint.__to_json_value__())
    manager.download(metadata, output_dir)
Exemplo n.º 3
0
def follow_experiment_logs(master_url: str, exp_id: int) -> None:
    """Wait for the experiment's first trial to appear, then follow its logs."""
    # The "first" trial is the one with the lowest ID among the experiment's
    # trials, hence the ascending order with limit 1.
    query = api.GraphQLQuery(master_url)
    trials = query.op.trials(
        where=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(_eq=exp_id)),
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        limit=1,
    )
    trials.id()

    print("Waiting for first trial to begin...")
    resp = query.send()
    while not resp.trials:
        # Poll every 100 ms until the trial exists.
        time.sleep(0.1)
        resp = query.send()

    first_trial_id = resp.trials[0].id
    print("Following first trial with ID {}".format(first_trial_id))

    # Delegate to `logs --follow` on the trial we just found.
    logs_args = Namespace(
        trial_id=first_trial_id,
        follow=True,
        master=master_url,
        tail=None,
    )
    logs(logs_args)
Exemplo n.º 4
0
def list_trials(args: Namespace) -> None:
    """Print a table (or CSV) of all trials belonging to one experiment."""
    query = api.GraphQLQuery(args.master)
    trials_op = query.op.trials(
        order_by=[gql.trials_order_by(id=gql.order_by.asc)],
        where=gql.trials_bool_exp(
            experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
        ),
    )
    # Select the scalar fields, then the aggregated step count.
    for field in ("id", "state", "hparams", "start_time", "end_time"):
        getattr(trials_op, field)()
    trials_op.steps_aggregate().aggregate.count()

    resp = query.send()

    headers = ["Trial ID", "State", "H-Params", "Start Time", "End Time", "# of Steps"]
    values = [
        [
            trial.id,
            trial.state,
            json.dumps(trial.hparams, indent=4),
            render.format_time(trial.start_time),
            render.format_time(trial.end_time),
            trial.steps_aggregate.aggregate.count,
        ]
        for trial in resp.trials
    ]

    render.tabulate_or_csv(headers, values, args.csv)
Exemplo n.º 5
0
def experiment_id_completer(prefix: str, parsed_args: Namespace,
                            **kwargs: Any) -> List[str]:
    """Shell-completion hook: return every known experiment ID as a string."""
    auth.initialize_session(parsed_args.master, parsed_args.user, try_reauth=True)
    query = api.GraphQLQuery(parsed_args.master)
    query.op.experiments().id()
    response = query.send()
    return [str(experiment["id"]) for experiment in response.experiments]
Exemplo n.º 6
0
def list(args: Namespace) -> None:
    """List an experiment's checkpoints, best validation metric first.

    NOTE(review): the name shadows the builtin ``list`` — presumably kept
    because it is registered as a CLI subcommand handler; confirm before
    renaming.
    """
    q = api.GraphQLQuery(args.master)
    # Also fetch the experiment's checkpoint_storage config.
    # NOTE(review): the config is never read from the response in this
    # function — possibly left over from a variant that downloads
    # checkpoints; verify before removing.
    q.op.experiments_by_pk(id=args.experiment_id).config(
        path="checkpoint_storage")

    # Order by the signed validation metric so the "best" checkpoints sort
    # first regardless of whether smaller or larger metric values are better.
    order_by = [
        gql.checkpoints_order_by(validation=gql.validations_order_by(
            metric_values=gql.validation_metrics_order_by(
                signed=gql.order_by.asc)))
    ]

    # --best N limits output to the top N checkpoints.
    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    checkpoints = q.op.checkpoints(
        where=gql.checkpoints_bool_exp(step=gql.steps_bool_exp(
            trial=gql.trials_bool_exp(experiment_id=gql.Int_comparison_exp(
                _eq=args.experiment_id)))),
        order_by=order_by,
        limit=limit,
    )
    checkpoints.end_time()
    checkpoints.labels()
    checkpoints.resources()
    checkpoints.start_time()
    checkpoints.state()
    checkpoints.step_id()
    checkpoints.trial_id()
    checkpoints.uuid()

    # Raw validation metric value, used for the "Validation Metric" column.
    checkpoints.step.validation.metric_values.raw()

    resp = q.send()

    headers = [
        "Trial ID", "Step ID", "State", "Validation Metric", "UUID",
        "Resources", "Size"
    ]
    values = [[
        c.trial_id,
        c.step_id,
        c.state,
        # A checkpoint may predate any validation; show None in that case.
        c.step.validation.metric_values.raw
        if c.step.validation and c.step.validation.metric_values else None,
        c.uuid,
        render.format_resources(c.resources),
        render.format_resource_sizes(c.resources),
    ] for c in resp.checkpoints]

    render.tabulate_or_csv(headers, values, args.csv)
Exemplo n.º 7
0
def wait(args: Namespace) -> None:
    """Block until the experiment reaches a terminal state, then exit.

    Exits with status 0 when the experiment COMPLETED, 1 for any other
    terminal state; polls every ``args.polling_interval`` seconds otherwise.
    """
    while True:
        query = api.GraphQLQuery(args.master)
        query.op.experiments_by_pk(id=args.experiment_id).state()
        state = query.send().experiments_by_pk.state

        if state in constants.TERMINAL_STATES:
            print("Experiment {} terminated with state {}".format(args.experiment_id, state))
            sys.exit(0 if state == constants.COMPLETED else 1)

        time.sleep(args.polling_interval)
Exemplo n.º 8
0
    def logs_query(tail: Optional[int] = None, greater_than_id: Optional[int] = None) -> Any:
        """Build (but do not send) a query for a trial's logs and its state.

        Closes over ``args`` (``args.master``, ``args.trial_id``) from the
        enclosing scope, which is not visible in this snippet.

        Arguments:
            tail: if set, fetch only the last ``tail`` log entries (the
                query orders descending; the caller must reverse the result).
            greater_than_id: if set, fetch only logs newer than this ID,
                used for incremental polling in follow mode.
        """
        q = api.GraphQLQuery(args.master)
        limit = None
        order_by = [gql.trial_logs_order_by(id=gql.order_by.asc)]
        where = gql.trial_logs_bool_exp(trial_id=gql.Int_comparison_exp(_eq=args.trial_id))
        if greater_than_id is not None:
            where.id = gql.Int_comparison_exp(_gt=greater_than_id)
        if tail is not None:
            # SQL-style LIMIT only takes the first rows, so a tail requires
            # descending order plus a reversal by the caller.
            order_by = [gql.trial_logs_order_by(id=gql.order_by.desc)]
            limit = tail
        logs = q.op.trial_logs(where=where, order_by=order_by, limit=limit)
        logs.id()
        logs.message()

        # Also fetch the trial state so follow mode can detect termination.
        q.op.trials_by_pk(id=args.trial_id).state()
        return q
Exemplo n.º 9
0
def list_experiments(args: Namespace) -> None:
    """Print a table (or CSV) of experiments, newest first.

    Without --all, only the current user's unarchived experiments are shown;
    with --all, every experiment is shown and an Archived column is added.
    """
    if args.all:
        where = None
    else:
        user = api.Authentication.instance().get_session_user()
        where = gql.experiments_bool_exp(
            archived=gql.Boolean_comparison_exp(_eq=False),
            owner=gql.users_bool_exp(username=gql.String_comparison_exp(_eq=user)),
        )

    query = api.GraphQLQuery(args.master)
    exps = query.op.experiments(
        order_by=[gql.experiments_order_by(id=gql.order_by.desc)], where=where
    )
    for field in ("archived", "config", "end_time", "id"):
        getattr(exps, field)()
    exps.owner.username()
    for field in ("progress", "start_time", "state"):
        getattr(exps, field)()

    resp = query.send()

    def format_experiment(e: Any) -> List[Any]:
        # One table row per experiment; the archived flag only with --all.
        row = [
            e.id,
            e.owner.username,
            e.config["description"],
            e.state,
            render.format_percent(e.progress),
            render.format_time(e.start_time),
            render.format_time(e.end_time),
        ]
        if args.all:
            row.append(e.archived)
        return row

    headers = [
        "ID", "Owner", "Description", "State", "Progress", "Start Time",
        "End Time"
    ]
    if args.all:
        headers.append("Archived")

    rows = [format_experiment(e) for e in resp.experiments]
    render.tabulate_or_csv(headers, rows, args.csv)
Exemplo n.º 10
0
def get_checkpoint(uuid: str, master: str) -> Checkpoint:
    """Look up a COMPLETED checkpoint by UUID and wrap it in a ``Checkpoint``.

    Raises:
        AssertionError: if no completed checkpoint with that UUID exists.
    """
    q = api.GraphQLQuery(master)

    # Only consider checkpoints whose upload finished.
    where = gql.checkpoints_bool_exp(
        state=gql.checkpoint_state_comparison_exp(_eq="COMPLETED"),
        uuid=gql.uuid_comparison_exp(_eq=uuid),
    )

    checkpoint_gql = q.op.checkpoints(where=where)

    checkpoint_gql.state()
    checkpoint_gql.uuid()
    checkpoint_gql.resources()

    validation = checkpoint_gql.validation()
    validation.metrics()
    validation.state()

    # The step and its experiment config are needed below to compute the
    # batch number and to find the checkpoint_storage settings.
    step = checkpoint_gql.step()
    step.id()
    step.start_time()
    step.end_time()
    step.trial.experiment.config()

    resp = q.send()
    result = resp.checkpoints

    if not result:
        raise AssertionError("No checkpoint found for uuid {}".format(uuid))

    # UUID lookups match at most one checkpoint; take the single result.
    ckpt_gql = result[0]
    # Translate the step ID into a global batch count via the experiment's
    # batches_per_step setting.
    batch_number = ckpt_gql.step.trial.experiment.config["batches_per_step"] * ckpt_gql.step.id
    return Checkpoint(
        ckpt_gql.uuid,
        ckpt_gql.step.trial.experiment.config["checkpoint_storage"],
        batch_number,
        ckpt_gql.step.start_time,
        ckpt_gql.step.end_time,
        ckpt_gql.resources,
        ckpt_gql.validation,
    )
Exemplo n.º 11
0
def query() -> api.GraphQLQuery:
    """Authenticate against the master and return a fresh GraphQL query object."""
    # Resolve the master URL once rather than calling conf.make_master_url()
    # twice — keeps the session and the query guaranteed to target the same URL.
    master_url = conf.make_master_url()
    auth.initialize_session(master_url, try_reauth=True)
    return api.GraphQLQuery(master_url)
Exemplo n.º 12
0
    def top_n_checkpoints(
        self,
        limit: int,
        sort_by: Optional[str] = None,
        smaller_is_better: Optional[bool] = None
    ) -> List[checkpoint.Checkpoint]:
        """
        Return the n :py:class:`det.experimental.Checkpoint` instances with the best
        validation metric values as defined by the `sort_by` and `smaller_is_better`
        arguments.

        Arguments:
            sort_by (string, optional): The name of the validation metric to
                order checkpoints by. If this parameter is unset the metric defined
                in the experiment configuration searcher field will be
                used.

            smaller_is_better (bool, optional): Specifies whether to sort the
                metric above in ascending or descending order. If sort_by is unset,
                this parameter is ignored. By default the smaller_is_better value
                in the experiment configuration is used.

        Returns an empty list if the experiment has no qualifying checkpoints.
        """
        q = api.GraphQLQuery(self._master)
        exp = q.op.experiments_by_pk(id=self.id)
        # Server-side function that pre-selects the best `lim` checkpoints.
        checkpoints = exp.best_checkpoints_by_metric(
            args={
                "lim": limit,
                "metric": sort_by,
                "smaller_is_better": smaller_is_better
            })

        checkpoints.state()
        checkpoints.uuid()
        checkpoints.resources()

        validation = checkpoints.validation()
        validation.metrics()
        validation.state()

        step = checkpoints.step()
        step.id()
        step.start_time()
        step.end_time()
        step.trial.experiment.config()
        step.trial.id()

        resp = q.send()

        checkpoints_resp = resp.experiments_by_pk.best_checkpoints_by_metric

        if not checkpoints_resp:
            return []

        # Fall back to the experiment configuration for any unset argument.
        experiment_conf = checkpoints_resp[0].step.trial.experiment.config
        sib = (smaller_is_better if smaller_is_better is not None else
               experiment_conf["searcher"]["smaller_is_better"])

        sort_metric = sort_by if sort_by is not None else experiment_conf[
            "searcher"]["metric"]
        # Re-sort client-side: best first, so reverse when larger is better.
        ordered_checkpoints = sorted(
            checkpoints_resp,
            key=lambda c: c.validation.metrics["validation_metrics"][
                sort_metric],
            reverse=not sib,
        )

        # Batch number is derived from the step ID and the experiment's
        # fixed batches_per_step setting.
        return [
            checkpoint.Checkpoint(
                ckpt.uuid,
                ckpt.step.trial.experiment.config["checkpoint_storage"],
                ckpt.step.trial.experiment.config["batches_per_step"] *
                ckpt.step.id,
                ckpt.step.start_time,
                ckpt.step.end_time,
                ckpt.resources,
                ckpt.validation,
            ) for ckpt in ordered_checkpoints
        ]
Exemplo n.º 13
0
def describe_trial(args: Namespace) -> None:
    """Print detailed information about one trial and each of its steps.

    With --json, the raw GraphQL response is dumped instead of tables.
    With --metrics, a per-step training-metrics column is added.
    """
    q = api.GraphQLQuery(args.master)
    trial = q.op.trials_by_pk(id=args.trial_id)
    trial.end_time()
    trial.experiment_id()
    trial.hparams()
    trial.start_time()
    trial.state()

    # Steps in ascending ID order, with their checkpoint and validation.
    steps = trial.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.metrics()
    steps.id()
    steps.state()
    steps.start_time()
    steps.end_time()

    checkpoint_gql = steps.checkpoint()
    checkpoint_gql.state()
    checkpoint_gql.uuid()

    validation = steps.validation()
    validation.state()
    validation.metrics()

    resp = q.send()

    if args.json:
        print(json.dumps(resp.trials_by_pk.__to_json_value__(), indent=4))
        return

    trial = resp.trials_by_pk

    # Print information about the trial itself.
    headers = ["Experiment ID", "State", "H-Params", "Start Time", "End Time"]
    values = [
        [
            trial.experiment_id,
            trial.state,
            json.dumps(trial.hparams, indent=4),
            render.format_time(trial.start_time),
            render.format_time(trial.end_time),
        ]
    ]
    render.tabulate_or_csv(headers, values, args.csv)

    # Print information about individual steps.
    headers = [
        "Step #",
        "State",
        "Start Time",
        "End Time",
        "Checkpoint",
        "Checkpoint UUID",
        "Validation",
        "Validation Metrics",
    ]
    if args.metrics:
        headers.append("Step Metrics")

    # format_checkpoint/format_validation each return multiple columns,
    # hence the unpacking into the row.
    values = [
        [
            s.id,
            s.state,
            render.format_time(s.start_time),
            render.format_time(s.end_time),
            *format_checkpoint(s.checkpoint),
            *format_validation(s.validation),
            *([json.dumps(s.metrics, indent=4)] if args.metrics else []),
        ]
        for s in trial.steps
    ]

    print()
    print("Steps:")
    render.tabulate_or_csv(headers, values, args.csv)
Exemplo n.º 14
0
def logs(args: Namespace) -> None:
    """Print a trial's logs; with --follow, stream new logs until the trial ends."""

    def logs_query(tail: Optional[int] = None, greater_than_id: Optional[int] = None) -> Any:
        # Build (but do not send) a query for the trial's logs and state.
        # `tail` fetches only the newest N entries (descending order, so the
        # caller must reverse); `greater_than_id` fetches incrementally.
        q = api.GraphQLQuery(args.master)
        limit = None
        order_by = [gql.trial_logs_order_by(id=gql.order_by.asc)]
        where = gql.trial_logs_bool_exp(trial_id=gql.Int_comparison_exp(_eq=args.trial_id))
        if greater_than_id is not None:
            where.id = gql.Int_comparison_exp(_gt=greater_than_id)
        if tail is not None:
            order_by = [gql.trial_logs_order_by(id=gql.order_by.desc)]
            limit = tail
        logs = q.op.trial_logs(where=where, order_by=order_by, limit=limit)
        logs.id()
        logs.message()

        # Also fetch the trial state so follow mode can detect termination.
        q.op.trials_by_pk(id=args.trial_id).state()
        return q

    def process_response(logs: Any, latest_log_id: int) -> Tuple[int, bool]:
        # Print each log message and track the highest ID seen; `changes`
        # reports whether anything new arrived in this batch.
        changes = False
        for log in logs:
            check_gt(log.id, latest_log_id)
            latest_log_id = log.id
            msg = api.decode_bytes(log.message)
            print(msg, end="")
            changes = True

        return latest_log_id, changes

    resp = logs_query(args.tail).send()
    logs = resp.trial_logs
    # Due to limitations of the GraphQL API, which mimics SQL, requesting a tail means we have to
    # get the results in descending ID order and reverse them afterward.
    if args.tail is not None:
        logs = reversed(logs)
    latest_log_id, _ = process_response(logs, -1)

    # "Follow" mode is implemented as a loop in the CLI. We assume that
    # newer log messages have a numerically larger ID than older log
    # messages, so we keep track of the max ID seen so far.
    if args.follow:
        change_time = time.time()
        try:
            while True:
                # Poll for new logs at most every 100 ms.
                time.sleep(0.1)

                # The `tail` parameter only makes sense the first time we fetch logs.
                resp = logs_query(greater_than_id=latest_log_id).send()
                latest_log_id, changes = process_response(resp.trial_logs, latest_log_id)

                # Exit once the trial has, for 1 second, been in a terminal state and sent no logs.
                if changes or resp.trials_by_pk.state not in constants.TERMINAL_STATES:
                    change_time = time.time()
                elif time.time() - change_time > 1:
                    # Reuse the Ctrl-C path below to print the final state.
                    raise KeyboardInterrupt()
        except KeyboardInterrupt:
            state_query = api.GraphQLQuery(args.master)
            state_query.op.trials_by_pk(id=args.trial_id).state()
            resp = state_query.send()

            print(
                colored(
                    "Trial is in the {} state. To reopen log stream, run: "
                    "det trial logs -f {}".format(resp.trials_by_pk.state, args.trial_id),
                    "green",
                )
            )
Exemplo n.º 15
0
def list(args: Namespace) -> None:
    """List an experiment's checkpoints and optionally download them.

    NOTE(review): the name shadows the builtin ``list`` — presumably kept
    because it is registered as a CLI subcommand handler.
    """
    q = api.GraphQLQuery(args.master)
    # The experiment's checkpoint_storage config is needed for --download-dir.
    q.op.experiments_by_pk(id=args.experiment_id).config(path="checkpoint_storage")

    # Order by the signed validation metric so the best checkpoints sort
    # first regardless of whether smaller or larger values are better.
    order_by = [
        gql.checkpoints_order_by(
            validation=gql.validations_order_by(
                metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
            )
        )
    ]

    # --best N limits output to the top N checkpoints.
    limit = None
    if args.best is not None:
        if args.best < 0:
            raise AssertionError("--best must be a non-negative integer")
        limit = args.best

    checkpoints = q.op.checkpoints(
        where=gql.checkpoints_bool_exp(
            step=gql.steps_bool_exp(
                trial=gql.trials_bool_exp(
                    experiment_id=gql.Int_comparison_exp(_eq=args.experiment_id)
                )
            )
        ),
        order_by=order_by,
        limit=limit,
    )
    checkpoints.end_time()
    checkpoints.labels()
    checkpoints.resources()
    checkpoints.start_time()
    checkpoints.state()
    checkpoints.step_id()
    checkpoints.trial_id()
    checkpoints.uuid()

    # Raw validation metric value, used for the "Validation Metric" column.
    checkpoints.step.validation.metric_values.raw()

    resp = q.send()
    config = resp.experiments_by_pk.config

    headers = ["Trial ID", "Step ID", "State", "Validation Metric", "UUID", "Resources", "Size"]
    values = [
        [
            c.trial_id,
            c.step_id,
            c.state,
            # A checkpoint may predate any validation; show None in that case.
            c.step.validation.metric_values.raw
            if c.step.validation and c.step.validation.metric_values
            else None,
            c.uuid,
            render.format_resources(c.resources),
            render.format_resource_sizes(c.resources),
        ]
        for c in resp.checkpoints
    ]

    render.tabulate_or_csv(headers, values, args.csv)

    if args.download_dir is not None:
        manager = storage.build(config)
        # Only cloud-backed storage can be downloaded directly by the CLI.
        if not (
            isinstance(manager, storage.S3StorageManager)
            or isinstance(manager, storage.GCSStorageManager)
        ):
            print(
                "Downloading from S3 or GCS requires the experiment to be configured with "
                "S3 or GCS checkpointing, {} found instead".format(config["type"])
            )
            sys.exit(1)

        # Each checkpoint goes to its own exp-/trial-/step-named subdirectory.
        for checkpoint in resp.checkpoints:
            metadata = storage.StorageMetadata.from_json(checkpoint.__to_json_value__())
            ckpt_dir = args.download_dir.joinpath(
                "exp-{}-trial-{}-step-{}".format(
                    args.experiment_id, checkpoint.trial_id, checkpoint.step_id
                )
            )
            print("Downloading checkpoint {} to {}".format(checkpoint.uuid, ckpt_dir))
            manager.download(metadata, ckpt_dir)
Exemplo n.º 16
0
def follow_test_experiment_logs(master_url: str, exp_id: int) -> None:
    """Poll a model-definition test experiment and render a progress checklist.

    Repeatedly queries the experiment's state and its first trial's first
    step, redrawing a four-stage checklist in place until the experiment
    reaches a terminal state.  Exits 0 on success and 1 on cancel/error.
    """
    def print_progress(active_stage: int, ended: bool) -> None:
        # There are four sequential stages of verification. Track the
        # current stage with an index into this list.
        stages = [
            "Scheduling task",
            "Testing training",
            "Testing validation",
            "Testing checkpointing",
        ]

        for idx, stage in enumerate(stages):
            if active_stage > idx:
                color = "green"
                checkbox = "✔"
            elif active_stage == idx:
                color = "red" if ended else "yellow"
                checkbox = "✗" if ended else " "
            else:
                color = "white"
                checkbox = " "
            print(colored(stage + (25 - len(stage)) * ".", color), end="")
            print(colored(" [" + checkbox + "]", color), end="")

            if idx == len(stages) - 1:
                # Overwrite the line while running; keep it once finished.
                print("\n" if ended else "\r", end="")
            else:
                print(", ", end="")

    q = api.GraphQLQuery(master_url)
    exp = q.op.experiments_by_pk(id=exp_id)
    exp.state()
    steps = exp.trials.steps(
        order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.checkpoint().id()
    steps.validation().id()

    while True:
        exp = q.send().experiments_by_pk

        # Wait for experiment to start and initialize a trial and step.
        step = None
        if exp.trials and exp.trials[0].steps:
            step = exp.trials[0].steps[0]

        # Update the active stage by examining the status of the experiment. The way the GraphQL
        # library works is that the checkpoint and validation attributes of a step are always
        # present and non-None, but they don't have any attributes of their own when the
        # corresponding database object doesn't exist.
        if exp.state == constants.COMPLETED:
            active_stage = 4
        elif step and hasattr(step.checkpoint, "id"):
            active_stage = 3
        elif step and hasattr(step.validation, "id"):
            active_stage = 2
        elif step:
            active_stage = 1
        else:
            active_stage = 0

        # If the experiment is in a terminal state, output the appropriate
        # message and exit. Otherwise, sleep and repeat.
        # Fix: previously compared against the literal "COMPLETED" here while
        # every other state check in this function uses the constants module.
        if exp.state == constants.COMPLETED:
            print_progress(active_stage, ended=True)
            print(colored("Model definition test succeeded! 🎉", "green"))
            return
        elif exp.state == constants.CANCELED:
            print_progress(active_stage, ended=True)
            print(
                colored(
                    "Model definition test (ID: {}) canceled before "
                    "model test could complete. Please re-run the "
                    "command.".format(exp_id),
                    "yellow",
                ))
            sys.exit(1)
        elif exp.state == constants.ERROR:
            print_progress(active_stage, ended=True)
            # Surface the failing trial's logs before exiting.
            trial_id = exp.trials[0].id
            logs_args = Namespace(trial_id=trial_id,
                                  master=master_url,
                                  tail=None,
                                  follow=False)
            logs(logs_args)
            sys.exit(1)
        else:
            print_progress(active_stage, ended=False)
            time.sleep(0.2)
Exemplo n.º 17
0
    def select_checkpoint(
        self,
        latest: bool = False,
        best: bool = False,
        uuid: Optional[str] = None,
        sort_by: Optional[str] = None,
        smaller_is_better: Optional[bool] = None,
    ) -> checkpoint.Checkpoint:
        """
        Return the :py:class:`det.experimental.Checkpoint` instance with the best
        validation metric as defined by the `sort_by` and `smaller_is_better`
        arguments.

        Exactly one of the best, latest, or uuid parameters must be set.

        Arguments:
            latest (bool, optional): return the most recent checkpoint.

            best (bool, optional): return the checkpoint with the best validation
                metric as defined by the `sort_by` and `smaller_is_better`
                arguments. If `sort_by` and `smaller_is_better` are not
                specified, the values from the associated experiment
                configuration will be used.

            uuid (string, optional): return the checkpoint for the specified uuid.

            sort_by (string, optional): the name of the validation metric to
                order checkpoints by. If this parameter is unset the metric defined
                in the related experiment configuration searcher field will be
                used.

            smaller_is_better (bool, optional): specifies whether to sort the
                metric above in ascending or descending order. If sort_by is unset,
                this parameter is ignored. By default the smaller_is_better value
                in the related experiment configuration is used.

        Raises:
            AssertionError: if the argument combination is invalid or no
                matching checkpoint exists.
        """
        # Validate the mutually exclusive selector arguments up front.
        check.eq(
            sum([int(latest), int(best), int(uuid is not None)]),
            1,
            "Exactly one of latest, best, or uuid must be set",
        )

        check.eq(
            sort_by is None,
            smaller_is_better is None,
            "sort_by and smaller_is_better must be set together",
        )

        if sort_by is not None and not best:
            raise AssertionError(
                "sort_by and smaller_is_better parameters can only be used with --best"
            )

        q = api.GraphQLQuery(self._master)

        if sort_by is not None:
            # An explicit metric uses a server-side function that picks the
            # single best checkpoint for this trial.
            checkpoint_gql = q.op.best_checkpoint_by_metric(
                args={"tid": self.id, "metric": sort_by, "smaller_is_better": smaller_is_better},
            )
        else:
            # Otherwise filter this trial's COMPLETED checkpoints and pick
            # one according to the latest/best/uuid selector.
            where = gql.checkpoints_bool_exp(
                state=gql.checkpoint_state_comparison_exp(_eq="COMPLETED"),
                trial_id=gql.Int_comparison_exp(_eq=self.id),
            )

            order_by = []  # type: List[gql.checkpoints_order_by]
            if uuid is not None:
                where.uuid = gql.uuid_comparison_exp(_eq=uuid)
            elif latest:
                order_by = [gql.checkpoints_order_by(end_time=gql.order_by.desc)]
            elif best:
                # Only checkpoints with a completed validation can be ranked;
                # signed metric ordering puts the best one first.
                where.validation = gql.validations_bool_exp(
                    state=gql.validation_state_comparison_exp(_eq="COMPLETED")
                )
                order_by = [
                    gql.checkpoints_order_by(
                        validation=gql.validations_order_by(
                            metric_values=gql.validation_metrics_order_by(signed=gql.order_by.asc)
                        )
                    )
                ]

            checkpoint_gql = q.op.checkpoints(where=where, order_by=order_by, limit=1)

        checkpoint_gql.state()
        checkpoint_gql.uuid()
        checkpoint_gql.resources()

        validation = checkpoint_gql.validation()
        validation.metrics()
        validation.state()

        # Step and experiment config are needed to compute the batch number
        # and locate checkpoint storage.
        step = checkpoint_gql.step()
        step.id()
        step.start_time()
        step.end_time()
        step.trial.experiment.config()

        resp = q.send()

        # The two query shapes return their results under different fields.
        result = resp.best_checkpoint_by_metric if sort_by is not None else resp.checkpoints

        if not result:
            raise AssertionError("No checkpoint found for trial {}".format(self.id))

        ckpt_gql = result[0]
        batch_number = ckpt_gql.step.trial.experiment.config["batches_per_step"] * ckpt_gql.step.id
        return checkpoint.Checkpoint(
            ckpt_gql.uuid,
            ckpt_gql.step.trial.experiment.config["checkpoint_storage"],
            batch_number,
            ckpt_gql.step.start_time,
            ckpt_gql.step.end_time,
            ckpt_gql.resources,
            ckpt_gql.validation,
        )
Exemplo n.º 18
0
def describe(args: Namespace) -> None:
    """Print (or export to CSV files) details for one or more experiments.

    Renders three tables — experiments, trials, and steps — to stdout, or to
    ``experiments.csv``/``trials.csv``/``steps.csv`` under --outdir.  With
    --json, dumps the raw GraphQL response instead.
    """
    ids = [int(x) for x in args.experiment_ids.split(",")]

    q = api.GraphQLQuery(args.master)
    exps = q.op.experiments(where=gql.experiments_bool_exp(
        id=gql.Int_comparison_exp(_in=ids)))
    exps.archived()
    exps.config()
    exps.end_time()
    exps.id()
    exps.progress()
    exps.start_time()
    exps.state()

    trials = exps.trials(order_by=[gql.trials_order_by(id=gql.order_by.asc)])
    trials.end_time()
    trials.hparams()
    trials.id()
    trials.start_time()
    trials.state()

    steps = trials.steps(order_by=[gql.steps_order_by(id=gql.order_by.asc)])
    steps.end_time()
    steps.id()
    steps.start_time()
    steps.state()
    steps.trial_id()

    steps.checkpoint.end_time()
    steps.checkpoint.start_time()
    steps.checkpoint.state()

    steps.validation.end_time()
    steps.validation.start_time()
    steps.validation.state()

    # Metric payloads are comparatively large, so fetch them only on demand.
    if args.metrics:
        steps.metrics(path="avg_metrics")
        steps.validation.metrics()

    resp = q.send()

    # Re-sort the experiment objects to match the original order.
    exps_by_id = {e.id: e for e in resp.experiments}
    experiments = [exps_by_id[id] for id in ids]

    if args.json:
        print(json.dumps(resp.__to_json_value__()["experiments"], indent=4))
        return

    # Display overall experiment information.
    headers = [
        "Experiment ID",
        "State",
        "Progress",
        "Start Time",
        "End Time",
        "Description",
        "Archived",
        "Labels",
    ]
    values = [[
        e.id,
        e.state,
        render.format_percent(e.progress),
        render.format_time(e.start_time),
        render.format_time(e.end_time),
        e.config.get("description"),
        e.archived,
        ", ".join(sorted(e.config.get("labels", []))),
    ] for e in experiments]
    if not args.outdir:
        outfile = None
        print("Experiment:")
    else:
        outfile = args.outdir.joinpath("experiments.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display trial-related information.
    headers = [
        "Trial ID", "Experiment ID", "State", "Start Time", "End Time",
        "H-Params"
    ]
    values = [[
        t.id,
        e.id,
        t.state,
        render.format_time(t.start_time),
        render.format_time(t.end_time),
        json.dumps(t.hparams, indent=4),
    ] for e in experiments for t in e.trials]
    if not args.outdir:
        outfile = None
        print("\nTrials:")
    else:
        outfile = args.outdir.joinpath("trials.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display step-related information.
    if args.metrics:
        # Accumulate the scalar training and validation metric names from all provided experiments.
        t_metrics_names = sorted(
            {n
             for e in experiments for n in scalar_training_metrics_names(e)})
        t_metrics_headers = [
            "Training Metric: {}".format(name) for name in t_metrics_names
        ]

        v_metrics_names = sorted({
            n
            for e in experiments for n in scalar_validation_metrics_names(e)
        })
        v_metrics_headers = [
            "Validation Metric: {}".format(name) for name in v_metrics_names
        ]
    else:
        t_metrics_headers = []
        v_metrics_headers = []

    headers = (["Trial ID", "Step ID", "State", "Start Time", "End Time"] +
               t_metrics_headers + [
                   "Checkpoint State",
                   "Checkpoint Start Time",
                   "Checkpoint End Time",
                   "Validation State",
                   "Validation Start Time",
                   "Validation End Time",
               ] + v_metrics_headers)

    values = []
    for e in experiments:
        for t in e.trials:
            for step in t.steps:
                # NOTE(review): if a step lacks a metrics attribute while
                # --metrics is set, t_metrics_fields stays empty and the row
                # would be shorter than the headers — verify whether that
                # case can occur in practice.
                t_metrics_fields = []
                if hasattr(step, "metrics"):
                    avg_metrics = step.metrics
                    for name in t_metrics_names:
                        if name in avg_metrics:
                            t_metrics_fields.append(avg_metrics[name])
                        else:
                            t_metrics_fields.append(None)

                # A step may have no checkpoint yet; emit blank columns then.
                checkpoint = step.checkpoint
                if checkpoint:
                    checkpoint_state = checkpoint.state
                    checkpoint_start_time = checkpoint.start_time
                    checkpoint_end_time = checkpoint.end_time
                else:
                    checkpoint_state = None
                    checkpoint_start_time = None
                    checkpoint_end_time = None

                # Likewise for validation.
                validation = step.validation
                if validation:
                    validation_state = validation.state
                    validation_start_time = validation.start_time
                    validation_end_time = validation.end_time

                else:
                    validation_state = None
                    validation_start_time = None
                    validation_end_time = None

                if args.metrics:
                    v_metrics_fields = [
                        api.metric.get_validation_metric(name, validation)
                        for name in v_metrics_names
                    ]
                else:
                    v_metrics_fields = []

                row = ([
                    step.trial_id,
                    step.id,
                    step.state,
                    render.format_time(step.start_time),
                    render.format_time(step.end_time),
                ] + t_metrics_fields + [
                    checkpoint_state,
                    render.format_time(checkpoint_start_time),
                    render.format_time(checkpoint_end_time),
                    validation_state,
                    render.format_time(validation_start_time),
                    render.format_time(validation_end_time),
                ] + v_metrics_fields)
                values.append(row)

    if not args.outdir:
        outfile = None
        print("\nSteps:")
    else:
        outfile = args.outdir.joinpath("steps.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)