Пример #1
0
def download_model_def(args: Namespace) -> None:
    """Fetch an experiment's model definition and save it to disk.

    The response's ``b64Tgz`` field is base64-decoded and written, in binary
    mode, to a file named after ``args.experiment_id`` inside
    ``args.output_dir``.
    """
    session = setup_session(args)
    response = bindings.get_GetModelDef(session,
                                        experimentId=args.experiment_id)
    target = args.output_dir.joinpath(str(args.experiment_id))
    with target.open("wb") as out:
        out.write(base64.b64decode(response.b64Tgz))
Пример #2
0
def archive(args: Namespace) -> None:
    """Archive the experiment given by ``args.experiment_id`` and report it."""
    session = setup_session(args)
    bindings.post_ArchiveExperiment(session, id=args.experiment_id)
    message = "Archived experiment {}".format(args.experiment_id)
    print(message)
Пример #3
0
def _describe_metric_fields(metrics: Any, names: List[str]) -> List[Any]:
    """Look up every name of ``names`` in ``metrics``; None marks a missing one.

    ``metrics`` is whatever a workload exposes as ``.metrics``; it only has to
    support truthiness, ``in`` and ``[]``. A falsy value yields all None.
    """
    if not metrics:
        return [None for _ in names]
    return [metrics[name] if name in metrics else None for name in names]


def _describe_latest_time(current: Any, candidate: Any) -> Any:
    """Return the greater of two formatted timestamps, ignoring None values."""
    known = [stamp for stamp in (current, candidate) if stamp is not None]
    return max(known) if known else None


def describe(args: Namespace) -> None:
    """Describe one or more experiments, their trials and their workloads.

    Fields read from ``args``:
        experiment_ids: comma-separated experiment IDs.
        json: when truthy, the raw API responses are printed as JSON and
            nothing else is produced.
        outdir: when falsy, a heading is printed before each table and no
            output file is used; otherwise each table is handed to
            ``render.tabulate_or_csv`` together with ``experiments.csv``,
            ``trials.csv`` or ``workloads.csv`` inside that directory.
        csv: forwarded unchanged to ``render.tabulate_or_csv``.
        metrics: when truthy, one column per scalar training / validation
            metric is added to the workloads table.

    Workload rows are condensed so that the training, checkpoint and
    validation workloads of one trial at the same batch count share a single
    row. Rows are grouped by trial and ordered by batch count within a trial.
    """
    session = setup_session(args)
    exps = []
    for experiment_id in args.experiment_ids.split(","):
        r = bindings.get_GetExperiment(session, experimentId=experiment_id)
        if args.json:
            exps.append(r.to_json())
        else:
            exps.append(r.experiment)

    if args.json:
        print(json.dumps(exps, indent=4))
        return

    # Display overall experiment information.
    headers = [
        "Experiment ID",
        "State",
        "Progress",
        "Start Time",
        "End Time",
        "Name",
        "Description",
        "Archived",
        "Resource Pool",
        "Labels",
    ]
    values = [[
        exp.id,
        exp.state.value.replace("STATE_", ""),
        render.format_percent(exp.progress),
        render.format_time(exp.startTime),
        render.format_time(exp.endTime),
        exp.name,
        exp.description,
        exp.archived,
        exp.resourcePool,
        ", ".join(sorted(exp.labels or [])),
    ] for exp in exps]
    if not args.outdir:
        outfile = None
        print("Experiment:")
    else:
        outfile = args.outdir.joinpath("experiments.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display trial-related information.
    trials_for_experiment: Dict[str, Sequence[bindings.trialv1Trial]] = {}
    for exp in exps:
        trials_for_experiment[exp.id] = bindings.get_GetExperimentTrials(
            session, experimentId=exp.id).trials

    headers = [
        "Trial ID", "Experiment ID", "State", "Start Time", "End Time",
        "H-Params"
    ]
    values = [[
        trial.id,
        exp.id,
        trial.state.value.replace("STATE_", ""),
        render.format_time(trial.startTime),
        render.format_time(trial.endTime),
        json.dumps(trial.hparams, indent=4),
    ] for exp in exps for trial in trials_for_experiment[exp.id]]
    if not args.outdir:
        outfile = None
        print("\nTrials:")
    else:
        outfile = args.outdir.joinpath("trials.csv")
    render.tabulate_or_csv(headers, values, args.csv, outfile)

    # Display step-related information.
    t_metrics_headers: List[str] = []
    t_metrics_names: List[str] = []
    v_metrics_headers: List[str] = []
    v_metrics_names: List[str] = []
    # Workloads already fetched while sampling metric names, keyed by trial
    # ID, so the same trial is not requested a second time below.
    sampled_workloads: Dict[int, Any] = {}
    if args.metrics:
        # Accumulate the scalar training and validation metric names from all
        # provided experiments, using the first trial of each as a sample.
        for exp in exps:
            trials = trials_for_experiment[exp.id]
            if not trials:
                # No trials means there is nothing to sample; indexing [0]
                # here would raise IndexError.
                continue
            sample_trial = trials[0]
            sample_workloads = bindings.get_GetTrial(
                session, trialId=sample_trial.id).workloads
            sampled_workloads[sample_trial.id] = sample_workloads
            t_metrics_names += scalar_training_metrics_names(sample_workloads)
            v_metrics_names += scalar_validation_metrics_names(
                sample_workloads)
        t_metrics_names = sorted(set(t_metrics_names))
        t_metrics_headers = [
            "Training Metric: {}".format(name) for name in t_metrics_names
        ]
        v_metrics_names = sorted(set(v_metrics_names))
        v_metrics_headers = [
            "Validation Metric: {}".format(name) for name in v_metrics_names
        ]

    headers = (["Trial ID", "# of Batches", "State", "Report Time"] +
               t_metrics_headers + [
                   "Checkpoint State",
                   "Checkpoint Report Time",
                   "Validation State",
                   "Validation Report Time",
               ] + v_metrics_headers)

    # trial ID -> (total batches -> row). Keying by trial first keeps workloads
    # of different trials that share a batch count from being merged into one
    # row.
    wl_output: Dict[int, Dict[int, List[Any]]] = {}
    for exp in exps:
        for trial in trials_for_experiment[exp.id]:
            workloads = sampled_workloads.pop(trial.id, None)
            if workloads is None:
                workloads = bindings.get_GetTrial(session,
                                                  trialId=trial.id).workloads
            trial_rows = wl_output.setdefault(trial.id, {})
            for workload in workloads:
                # wl_detail ends up as the last present part, in the order
                # training -> checkpoint -> validation.
                wl_detail: Optional[
                    Union[bindings.v1MetricsWorkload,
                          bindings.v1CheckpointWorkload]] = None

                if workload.training:
                    wl_detail = workload.training
                    t_metrics_fields = _describe_metric_fields(
                        wl_detail.metrics, t_metrics_names)
                else:
                    t_metrics_fields = [None for _ in t_metrics_names]

                if workload.checkpoint:
                    wl_detail = workload.checkpoint
                    checkpoint_state = wl_detail.state.value
                    checkpoint_end_time = wl_detail.endTime
                else:
                    checkpoint_state = ""
                    checkpoint_end_time = None

                if workload.validation:
                    wl_detail = workload.validation
                    validation_state = wl_detail.state.value
                    validation_end_time = wl_detail.endTime
                    v_metrics_fields = _describe_metric_fields(
                        wl_detail.metrics, v_metrics_names)
                else:
                    validation_state = ""
                    validation_end_time = None
                    v_metrics_fields = [None for _ in v_metrics_names]

                if not wl_detail:
                    continue

                total_batches = wl_detail.totalBatches
                merge_row = trial_rows.get(total_batches)
                if merge_row is None:
                    trial_rows[total_batches] = ([
                        trial.id,
                        total_batches,
                        wl_detail.state.value.replace("STATE_", ""),
                        render.format_time(wl_detail.endTime),
                    ] + t_metrics_fields + [
                        checkpoint_state.replace("STATE_", ""),
                        render.format_time(checkpoint_end_time),
                        validation_state.replace("STATE_", ""),
                        render.format_time(validation_end_time),
                    ] + v_metrics_fields)
                    continue

                # condense training, checkpoints, validation workloads into
                # one step-like row for compatibility with previous versions
                # of describe
                # NOTE(review): render.format_time presumably returns None for
                # a missing end time -- confirm. The helper skips None so
                # max() never has to compare None with a string.
                merge_row[3] = _describe_latest_time(
                    merge_row[3], render.format_time(wl_detail.endTime))
                for idx, tfield in enumerate(t_metrics_fields):
                    # "is not None" rather than truthiness: a metric whose
                    # value is 0 must still be merged in.
                    if tfield is not None and merge_row[4 + idx] is None:
                        merge_row[4 + idx] = tfield
                start_checkpoint = 4 + len(t_metrics_fields)
                if checkpoint_state:
                    merge_row[start_checkpoint] = checkpoint_state.replace(
                        "STATE_", "")
                    merge_row[start_checkpoint + 1] = render.format_time(
                        checkpoint_end_time)
                if validation_state:
                    merge_row[start_checkpoint + 2] = validation_state.replace(
                        "STATE_", "")
                if validation_end_time:
                    merge_row[start_checkpoint + 3] = render.format_time(
                        validation_end_time)
                for idx, vfield in enumerate(v_metrics_fields):
                    slot = start_checkpoint + 4 + idx
                    if vfield is not None and merge_row[slot] is None:
                        merge_row[slot] = vfield

    if not args.outdir:
        outfile = None
        print("\nWorkloads:")
    else:
        outfile = args.outdir.joinpath("workloads.csv")
    # Trials keep the order in which they were visited; rows within a trial
    # are ordered by batch count.
    values = [
        row for trial_rows in wl_output.values()
        for row in sorted(trial_rows.values(), key=lambda a: int(a[1]))
    ]
    render.tabulate_or_csv(headers, values, args.csv, outfile)
Пример #4
0
def config(args: Namespace) -> None:
    """Dump the configuration of ``args.experiment_id`` to stdout as YAML.

    The ``config`` attribute of the GetExperiment response is serialized with
    ``yaml.safe_dump`` using ``default_flow_style=False``.
    """
    session = setup_session(args)
    response = bindings.get_GetExperiment(session,
                                          experimentId=args.experiment_id)
    yaml.safe_dump(response.config,
                   default_flow_style=False,
                   stream=sys.stdout)
Пример #5
0
def pause(args: Namespace) -> None:
    """Pause the experiment given by ``args.experiment_id`` and report it."""
    session = setup_session(args)
    bindings.post_PauseExperiment(session, id=args.experiment_id)
    message = "Paused experiment {}".format(args.experiment_id)
    print(message)
Пример #6
0
def cancel(args: Namespace) -> None:
    """Cancel the experiment given by ``args.experiment_id`` and report it."""
    session = setup_session(args)
    bindings.post_CancelExperiment(session, id=args.experiment_id)
    message = "Canceled experiment {}".format(args.experiment_id)
    print(message)
Пример #7
0
def kill_experiment(args: Namespace) -> None:
    """Kill the experiment given by ``args.experiment_id`` and report it."""
    session = setup_session(args)
    bindings.post_KillExperiment(session, id=args.experiment_id)
    message = "Killed experiment {}".format(args.experiment_id)
    print(message)
Пример #8
0
def process_updates(args: Namespace) -> None:
    """Apply every operation listed in ``args.operation``, in order.

    Each entry is passed through ``validate_operation_args``; the resulting
    mapping is expanded into keyword arguments for ``_single_update``, which
    also receives the shared session.
    """
    sess = setup_session(args)
    for operation in args.operation:
        update_kwargs = validate_operation_args(operation)
        _single_update(session=sess, **update_kwargs)
Пример #9
0
def unarchive_workspace(args: Namespace) -> None:
    """Un-archive the workspace named ``args.workspace_name``.

    The workspace is resolved by name via ``workspace_by_name`` and then
    un-archived by its ID; a confirmation line is printed afterwards.
    """
    session = setup_session(args)
    workspace = workspace_by_name(session, args.workspace_name)
    bindings.post_UnarchiveWorkspace(session, id=workspace.id)
    print(f"Successfully un-archived workspace {args.workspace_name}.")