Example #1
def open_shell(args: Namespace) -> None:
    shell = render.unmarshal(
        Command,
        api.get(args.master, "shells/{}".format(args.shell_id)).json())
    check_eq(shell.state, "RUNNING", "Shell must be in a running state")
    agent_user = get_agent_user(args.master)
    _open_shell(shell, agent_user, args.ssh_opts)
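All of these examples assert runtime invariants with a check_eq (and occasionally check_not_eq) helper. The real helpers live in the project's own check module; what follows is only a minimal sketch consistent with these call sites, where the exception class and the optional message argument are assumptions:

from typing import Any, Optional

class CheckFailedError(AssertionError):
    """Assumed failure type; the real module may raise something else."""

def check_eq(actual: Any, expected: Any, reason: Optional[str] = None) -> None:
    # Fail fast when two values differ, optionally with a human-readable reason.
    if actual != expected:
        raise CheckFailedError(
            "{} != {}{}".format(actual, expected, ": " + reason if reason else ""))

def check_not_eq(actual: Any, unexpected: Any, reason: Optional[str] = None) -> None:
    # The inverse guard: fail when two values are equal.
    if actual == unexpected:
        raise CheckFailedError(
            "{} == {}{}".format(actual, unexpected, ": " + reason if reason else ""))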
Example #2
def open_tensorboard(args: Namespace) -> None:
    resp = api.get(args.master,
                   "tensorboard/{}".format(args.tensorboard_id)).json()
    tensorboard = render.unmarshal(Command, resp)
    check_eq(tensorboard.state, "RUNNING",
             "TensorBoard must be in a running state")
    api.open(args.master, resp["service_address"])
Example #3
 def step(self, *args: typing.Any, **kwargs: typing.Any) -> None:
     """Call step() on the wrapped LRScheduler instance."""
     check.check_eq(
         self.step_mode,
         LRScheduler.StepMode.MANUAL_STEP,
         "Please use the MANUAL_STEP step mode to call step() on the scheduler.",
     )
     return self.scheduler.step(*args, **kwargs)
Example #4
 def step(self) -> None:
     check.check_eq(
         self._lr_scheduler.scheduler._step_count,  # type: ignore
         self._lr_scheduler_count,
         "You cannot call `scheduler.step()` if you have configured "
         "Determined to manage the learning rate scheduler.",
     )
     self._lr_scheduler.scheduler.step()  # type: ignore
     self._lr_scheduler_count += 1
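Examples #3 and #4 guard the same invariant from opposite directions: the wrapper only advances the scheduler itself when the user has not, and it detects a stray user call by comparing the scheduler's internal _step_count against its own counter. A self-contained sketch of that counting trick, using a stand-in scheduler and a plain exception in place of the real check helper (PyTorch's actual schedulers initialize _step_count differently, so this is illustrative only):

class _FakeScheduler:
    """Stand-in for a PyTorch-style LR scheduler with an internal step counter."""

    def __init__(self) -> None:
        self._step_count = 0

    def step(self) -> None:
        self._step_count += 1

class ManagedLRScheduler:
    def __init__(self, scheduler: _FakeScheduler) -> None:
        self.scheduler = scheduler
        self._count = 0  # steps taken through this wrapper

    def step(self) -> None:
        # If user code stepped the raw scheduler directly, the two counters
        # diverge and we fail fast, mirroring the check_eq call above.
        if self.scheduler._step_count != self._count:
            raise AssertionError(
                "You cannot call `scheduler.step()` if the wrapper manages "
                "the learning rate scheduler.")
        self.scheduler.step()
        self._count += 1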
Example #5
def _reduce_metrics(
    reducer: Reducer, metrics: np.ndarray, num_batches: Optional[List[int]] = None
) -> np.float64:
    if reducer == Reducer.AVG:
        if num_batches:
            check.check_eq(len(metrics), len(num_batches))
        return np.average(metrics, weights=num_batches)
    elif reducer == Reducer.SUM:
        return np.sum(metrics)
    elif reducer == Reducer.MAX:
        return np.max(metrics)
    elif reducer == Reducer.MIN:
        return np.min(metrics)
    else:
        raise NotImplementedError
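One subtlety in the AVG branch: when num_batches is supplied, the per-batch metrics are combined as a weighted average, with each batch's size as its weight, rather than uniformly. A small demo (the metric values and batch sizes are made up):

import numpy as np

# Each batch reports a mean metric; larger batches should count for more.
metrics = np.array([0.5, 0.7])  # per-batch metric values
num_batches = [100, 300]        # records per batch (assumed sizes)
weighted = np.average(metrics, weights=num_batches)
print(weighted)  # 0.65 = (0.5 * 100 + 0.7 * 300) / 400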
Example #6
 def check_sane_workload(self, new_workload: workload.Workload) -> None:
     # If this is the initial workload, we don't expect to start with
     # a checkpoint operation. All other workloads are reasonable.
     if self.workload is None:
         check_not_eq(new_workload.kind, workload.Workload.Kind.CHECKPOINT_MODEL)
         return

     # If this is not the initial workload, it should be compatible
     # with the previous workload that ran in this container.
     check_eq(self.workload.trial_id, new_workload.trial_id)

     if new_workload.kind == workload.Workload.Kind.RUN_STEP:
         check_eq(self.workload.step_id + 1, new_workload.step_id)
     else:
         check_eq(self.workload.step_id, new_workload.step_id)
Example #7
def open_notebook(args: Namespace) -> None:
    resp = api.get(args.master, "notebooks/{}".format(args.notebook_id)).json()
    notebook = render.unmarshal(Command, resp)
    check_eq(notebook.state, "RUNNING", "Notebook must be in a running state")
    api.open(args.master, resp["service_address"])
Example #8
def _scan_checkpoint_directory(checkpoint_dir: str) -> List[Checkpoint]:
    """
    Construct checkpoint metadata directly from a directory.

    State files are sometimes out of sync with the directory contents, so we
    insert orphaned checkpoint files that the state is missing and prune
    entries whose files no longer exist. To be conservative, we prefer data
    from the checkpoint state files, when correct, over data gathered by
    scanning the directory.
    """

    # Phase 1: Scan directory.

    # `checkpoint_state_files` is a list of (cname, full path) tuples for each
    # checkpoint state file in the directory.
    checkpoint_state_files = []
    scanned_basenames = set()
    checkpoint_paths = defaultdict(
        lambda: defaultdict(list))  # type: Dict[str, Dict[str, List[str]]]
    with os.scandir(checkpoint_dir) as it:
        for f in it:
            if not f.is_file():
                continue

            if f.name.startswith("checkpoint_"):
                cname = f.name[len("checkpoint_"):]
                checkpoint_state_files.append((cname, f.path))
                continue
            elif f.name == "checkpoint":
                cname = "model"
                checkpoint_state_files.append((cname, f.path))
                continue

            cname, basename = split_checkpoint_filename(f.name)
            if not cname:
                continue

            scanned_basenames.add(basename)
            checkpoint_paths[cname][basename].append(f.path)

    # Phase 2: Read data from state files.

    checkpoints = {}
    for cname, path in checkpoint_state_files:
        latest_filename = os.path.basename(path)
        state = tf.train.get_checkpoint_state(checkpoint_dir,
                                              latest_filename=latest_filename)
        checkpoints[cname] = Checkpoint(state_file=path,
                                        name=cname,
                                        state=state,
                                        paths=checkpoint_paths[cname])

    # Phase 3: Merge scanned data with state data, preferring state data.

    for cname, checkpoint in checkpoints.items():
        old_ts = checkpoint.state.all_model_checkpoint_timestamps
        old_paths = checkpoint.state.all_model_checkpoint_paths
        # Use 0.0 as the default timestamp if none exists previously.
        if not old_ts:
            old_ts = [0.0] * len(old_paths)

        items = [(os.path.join(checkpoint_dir, b), 0.0)
                 for b in checkpoint_paths[cname]]
        check.check_eq(len(old_paths), len(old_ts))
        items.extend(zip(old_paths, old_ts))

        seen = set()  # type: Set[str]
        new_items = []
        for path, ts in reversed(items):
            basename = os.path.basename(path)
            if basename not in scanned_basenames:
                continue
            elif basename in seen:
                continue
            seen.add(basename)
            new_items.append((path, ts))

        if not new_items:
            raise Exception(
                "No checkpoint files found for {} checkpoint in directory {}".
                format(cname, checkpoint_dir))

        new_paths, new_ts = zip(*reversed(new_items))

        all_model_checkpoint_timestamps = None
        last_preserved_timestamp = None
        if checkpoint.state.all_model_checkpoint_timestamps is not None:
            all_model_checkpoint_timestamps = new_ts
            last_preserved_timestamp = new_ts[-1]

        check.check_eq(
            new_paths[-1],
            checkpoint.state.model_checkpoint_path,
            "Most recent checkpoint path should not change",
        )
        checkpoint.state = tf.compat.v1.train.generate_checkpoint_state_proto(
            checkpoint_dir,
            new_paths[-1],
            all_model_checkpoint_paths=new_paths,
            all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
            last_preserved_timestamp=last_preserved_timestamp,
        )

    return list(checkpoints.values())
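Example #8 leans on a split_checkpoint_filename helper that is not shown. A hypothetical sketch, assuming TensorFlow-style shard names such as model.ckpt-100.index and model.ckpt-100.data-00000-of-00001, where the text before .ckpt- is the checkpoint name and the basename drops the trailing suffix so all shards of one checkpoint group together:

import re
from typing import Tuple

# Assumed naming convention; the real helper may differ.
_CKPT_RE = re.compile(r"^(?P<cname>.+)\.ckpt-\d+")

def split_checkpoint_filename(filename: str) -> Tuple[str, str]:
    """Return (checkpoint name, basename), or ("", "") for non-checkpoint files."""
    m = _CKPT_RE.match(filename)
    if m is None:
        return "", ""
    # basename omits the shard suffix, so .index and .data-* files of the
    # same checkpoint map to the same key.
    return m.group("cname"), m.group(0)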
Example #9
 def load(self, storage_dir: str) -> None:
     with open(os.path.join(storage_dir, "VALIDATE.txt"), "r") as fp:
         check_eq(fp.read(), self.uuid,
                  "Unable to properly load from storage")