示例#1
0
def _bundle_target(base: Path, relative: str) -> Path:
    """Join ``relative`` onto ``base``, refusing paths that could leave ``base``.

    The check is purely lexical: a path with a root/drive (absolute) or with a
    ``..`` component is rejected. Symlinks are deliberately not resolved so
    that destination directories reached through symlinks keep working.

    Raises:
        RuntimeError: if ``relative`` is absolute or contains ``..``.
    """
    candidate = Path(relative)
    if candidate.anchor or '..' in candidate.parts:
        raise RuntimeError(f"Corrupted bundle. Unsafe path: {relative}")
    return base / candidate


def load_bundle(path: Path, url: Optional[str] = None) -> Tuple[str, int]:
    """Extract a run bundle (a gzipped tar archive) into the lab directories.

    If ``url`` is given, ``download_file(url, path)`` is called first. The
    archive must contain an ``info.json`` member with ``uuid`` and
    ``checkpoint`` keys. When a run with that uuid already exists and
    ``_get_run_checkpoint`` reports the same checkpoint, nothing is extracted.
    Otherwise ``run.yaml``, ``configs.yaml``, ``checkpoint/*`` and ``data/*``
    members are extracted under ``<experiments>/bundled/<uuid>`` and the lab
    data path respectively.

    Args:
        path: Location of the bundle archive on disk.
        url: Optional URL to download the archive from before extracting.

    Returns:
        The ``(run_uuid, checkpoint)`` pair read from ``info.json``.

    Raises:
        FileNotFoundError: if ``path`` does not exist (after any download).
        RuntimeError: if ``info.json`` is missing or unreadable, or if the
            bundle names a path that would escape its destination directory.
    """
    if url:
        download_file(url, path)

    if not path.exists():
        raise FileNotFoundError(f'Bundle archive missing: {path}')

    with monit.section('Extract bundle'):
        with tarfile.open(str(path), 'r:gz') as tar:
            files = tar.getmembers()
            info_member = None
            # Keep the last `info.json` entry if the archive has duplicates.
            for f in files:
                if f.name == 'info.json':
                    info_member = f

            if info_member is None:
                raise RuntimeError("Corrupted bundle. Missing info.json")

            # `extractfile` returns None for members that are not regular
            # files (e.g. a directory named `info.json`).
            info_file = tar.extractfile(info_member)
            if info_file is None:
                raise RuntimeError("Corrupted bundle. Unreadable info.json")
            with info_file as ef:
                info = json.load(ef)

            run_uuid, checkpoint = info['uuid'], info['checkpoint']
            run_path = get_run_by_uuid(lab.get_experiments_path(), run_uuid)

            if run_path is not None:
                logger.log(f"Run {run_uuid} exists", Text.meta)
                current_checkpoint = _get_run_checkpoint(run_path, checkpoint)
                if checkpoint == current_checkpoint:
                    logger.log(f"Checkpoint {checkpoint} exists", Text.meta)
                    return run_uuid, checkpoint

            # The bundle may come from a download, so every path component
            # taken from it is validated before it is joined to a directory.
            run_path = _bundle_target(
                lab.get_experiments_path() / 'bundled', run_uuid)

            checkpoint_path = _bundle_target(
                run_path / "checkpoints", str(checkpoint))
            checkpoint_path.mkdir(parents=True, exist_ok=True)

            data_path = lab.get_data_path()
            data_path.mkdir(parents=True, exist_ok=True)

            for f in files:
                if f.name == 'run.yaml':
                    _extract_tar_file(tar, f, run_path / 'run.yaml')
                elif f.name == 'configs.yaml':
                    _extract_tar_file(tar, f, run_path / 'configs.yaml')
                elif f.name.startswith('checkpoint/'):
                    p = _bundle_target(checkpoint_path,
                                       f.name[len('checkpoint/'):])
                    p.parent.mkdir(parents=True, exist_ok=True)
                    _extract_tar_file(tar, f, p)
                elif f.name.startswith('data/'):
                    p = _bundle_target(data_path, f.name[len('data/'):])
                    p.parent.mkdir(parents=True, exist_ok=True)
                    _extract_tar_file(tar, f, p)

            return run_uuid, checkpoint
示例#2
0
文件: runs.py 项目: skiedra/labml
def get_experiments() -> Generator[Path, None, None]:
    """Yield each entry of the lab experiments directory.

    Entries whose names start with ``_`` or ``.`` are skipped. Nothing is
    yielded when the experiments directory does not exist.
    """
    if lab.get_experiments_path().exists():
        for entry in lab.get_experiments_path().iterdir():
            if not entry.name.startswith(('_', '.')):
                yield entry
示例#3
0
def save_bundle(path: Path, run_uuid: str, checkpoint: int = -1, *,
                data_files: List[str]):
    """Pack a run's checkpoint, configs and data files into a gzipped tar.

    The archive written to ``path`` contains the checkpoint directory as
    ``checkpoint``, the run's ``run.yaml`` and ``configs.yaml``, an
    ``info.json`` with the run uuid and checkpoint, and each entry of
    ``data_files`` (relative to the lab data path) under ``data/``.

    Args:
        path: Destination of the bundle archive.
        run_uuid: UUID of the run to bundle.
        checkpoint: Checkpoint to bundle; passed to ``_get_run_checkpoint``
            to obtain the checkpoint actually used.
        data_files: Paths, relative to the lab data path, to include.

    Raises:
        RuntimeError: if the run or the checkpoint cannot be found.
    """
    run_path = get_run_by_uuid(lab.get_experiments_path(), run_uuid)
    if run_path is None:
        raise RuntimeError(f"Couldn't find run {run_uuid}")

    # Keep the caller's value: after resolution `checkpoint` may be None, and
    # the error message should name what was asked for.
    requested_checkpoint = checkpoint
    checkpoint = _get_run_checkpoint(run_path, requested_checkpoint)

    if checkpoint is None:
        raise RuntimeError(
            f"Couldn't find checkpoint {run_uuid}:{requested_checkpoint}")

    # `info.json` is staged next to the archive and removed afterwards.
    info_path = path.parent / f'{path.stem}.info.json'
    info = {'uuid': run_uuid, 'checkpoint': checkpoint}
    checkpoint_path = run_path / "checkpoints" / str(checkpoint)

    try:
        with open(str(info_path), 'w') as f:
            f.write(json.dumps(info))

        with monit.section('Create bundle'):
            with tarfile.open(str(path), 'w:gz') as tar:
                tar.add(str(checkpoint_path), 'checkpoint')
                tar.add(str(run_path / 'run.yaml'), 'run.yaml')
                tar.add(str(run_path / 'configs.yaml'), 'configs.yaml')
                tar.add(str(info_path), 'info.json')
                for f in data_files:
                    tar.add(str(lab.get_data_path() / f), f'data/{f}')
    finally:
        # Remove the staging file even when creating the archive fails.
        if info_path.exists():
            info_path.unlink()
示例#4
0
def get_run_checkpoint(run_uuid: str, checkpoint: int = -1):
    """Locate a checkpoint directory of a run and log the selection.

    Returns:
        ``(checkpoint_path, checkpoint)`` where ``checkpoint`` is the value
        returned by ``_get_run_checkpoint``, or ``(None, None)`` when either
        the run or the checkpoint cannot be found.
    """
    run_dir = get_run_by_uuid(lab.get_experiments_path(), run_uuid)
    if run_dir is None:
        logger.log("Couldn't find a previous run")
        return None, None

    resolved = _get_run_checkpoint(run_dir, checkpoint)
    if resolved is None:
        logger.log("Couldn't find checkpoints")
        return None, None

    # Message parts: "experiment = <name> run = <uuid> checkpoint = <n>".
    selection = [
        "Selected ",
        ("experiment", Text.key), " = ", (run_dir.parent.name, Text.value),
        " ",
        ("run", Text.key), " = ", (run_uuid, Text.value),
        " ",
        ("checkpoint", Text.key), " = ", (str(resolved), Text.value),
    ]
    logger.log(selection)

    return run_dir / "checkpoints" / str(resolved), resolved
示例#5
0
    def __init__(self):
        """Index every run directory found under the lab experiments path.

        ``self._runs`` maps a run directory's name to a
        ``(run_path, experiment_name)`` tuple, where ``experiment_name`` is
        the name of the directory that contains the run.
        """
        experiment_path = Path(lab.get_experiments_path())
        runs = {}
        # A missing experiments directory simply means there are no runs yet.
        if experiment_path.exists():
            for exp_path in experiment_path.iterdir():
                # Stray files (e.g. `.DS_Store`) cannot be iterated.
                if not exp_path.is_dir():
                    continue
                for run_path in exp_path.iterdir():
                    # Record the run's own experiment, not the root directory
                    # name, which would be the same for every run.
                    runs[run_path.name] = (run_path, exp_path.name)

        self._runs = runs
示例#6
0
def get_configs(run_uuid: str):
    """Load ``configs.yaml`` of the run identified by ``run_uuid``.

    Returns the result of ``load_configs`` on that file, or ``None`` (after
    emitting a danger notice) when the run cannot be found.
    """
    run_dir = get_run_by_uuid(lab.get_experiments_path(), run_uuid)
    if run_dir is not None:
        return load_configs(run_dir / "configs.yaml")

    labml_notice(["Couldn't find a previous run to load configurations: ",
                  (run_uuid, Text.value)], is_danger=True)
    return None
示例#7
0
 def save_dir(self) -> Optional[str]:
     """Return the lab experiments path converted to a string."""
     experiments_path = lab.get_experiments_path()
     return str(experiments_path)
示例#8
0
def _test():
    """Collect the runs under the lab experiments path and inspect them."""
    from labml import lab

    experiments_path = lab.get_experiments_path()
    run_paths = list(get_runs(experiments_path))
    inspect(run_paths)
示例#9
0
文件: runs.py 项目: skiedra/labml
def get_runs() -> Generator[Path, None, None]:
    """Yield every run path found inside each experiment directory.

    Entries of the experiments directory whose names start with ``_`` are
    skipped. Nothing is yielded when the experiments directory does not
    exist, and entries that are not directories are ignored instead of
    raising ``NotADirectoryError``.
    """
    experiments_path = lab.get_experiments_path()
    if not experiments_path.exists():
        return

    for exp_path in experiments_path.iterdir():
        if exp_path.name.startswith('_'):
            continue
        # Stray files (e.g. `.DS_Store`) have no runs to iterate.
        if not exp_path.is_dir():
            continue
        yield from exp_path.iterdir()