예제 #1
0
def load_bundle(path: Path, url: Optional[str] = None) -> Tuple[str, int]:
    """
    Extract a run bundle (a gzipped tar archive) into the experiments and data folders.

    The archive must contain an `info.json` with `uuid` and `checkpoint` keys.
    `run.yaml` and `configs.yaml` go to the run folder, members under `checkpoint/`
    go to the checkpoint folder of that run, and members under `data/` go to the
    lab data folder. Any other member is ignored.

    If `get_run_by_uuid` finds the run and `_get_run_checkpoint` reports the same
    checkpoint, nothing is extracted.

    :param path: location of the bundle archive (also the download destination)
    :param url: if given, `download_file(url, path)` is called before opening the archive
    :return: `(run_uuid, checkpoint)` as read from `info.json`
    :raises FileNotFoundError: if `path` does not exist (after the optional download)
    :raises RuntimeError: if `info.json` is missing or unreadable, or a member name
        would be extracted outside its target folder
    :raises KeyError: if `info.json` lacks `uuid` or `checkpoint`
    """
    if url:
        download_file(url, path)

    if not path.exists():
        raise FileNotFoundError(f'Bundle archive missing: {path}')

    with monit.section('Extract bundle'):
        with tarfile.open(str(path), 'r:gz') as tar:
            members = tar.getmembers()

            # Keep the last `info.json` entry if the archive holds more than one
            info_member = None
            for member in members:
                if member.name == 'info.json':
                    info_member = member

            if info_member is None:
                raise RuntimeError("Corrupted bundle. Missing info.json")

            # `extractfile` returns `None` for members that are not regular files or links
            info_file = tar.extractfile(info_member)
            if info_file is None:
                raise RuntimeError("Corrupted bundle. info.json is not a regular file")
            with info_file:
                info = json.load(info_file)

            run_uuid, checkpoint = info['uuid'], info['checkpoint']
            run_path = get_run_by_uuid(run_uuid)

            if run_path is not None:
                logger.log(f"Run {run_uuid} exists", Text.meta)
                current_checkpoint = _get_run_checkpoint(run_path, checkpoint)
                if checkpoint == current_checkpoint:
                    logger.log(f"Checkpoint {checkpoint} exists", Text.meta)
                    return run_uuid, checkpoint

            run_path = lab.get_experiments_path() / 'bundled' / run_uuid
            checkpoint_path = run_path / "checkpoints" / str(checkpoint)
            data_path = lab.get_data_path()

            # Resolve every destination before touching the disk, so a bundle with
            # an unsafe member name is rejected without a partial extraction
            plan = []
            for member in members:
                name = member.name
                if name in ('run.yaml', 'configs.yaml'):
                    plan.append((member, run_path / name))
                elif name.startswith('checkpoint/'):
                    relative = name[len('checkpoint/'):]
                    plan.append((member, _bundle_member_target(checkpoint_path, relative)))
                elif name.startswith('data/'):
                    relative = name[len('data/'):]
                    plan.append((member, _bundle_member_target(data_path, relative)))

            # `exist_ok` avoids the check-then-create race of `exists()` + `mkdir()`
            checkpoint_path.mkdir(parents=True, exist_ok=True)
            data_path.mkdir(parents=True, exist_ok=True)

            for member, target in plan:
                target.parent.mkdir(parents=True, exist_ok=True)
                _extract_tar_file(tar, member, target)

            return run_uuid, checkpoint


def _bundle_member_target(base: Path, relative: str) -> Path:
    """
    Return `base / relative`, rejecting names that would land outside `base`.

    :raises RuntimeError: if `relative` is absolute or contains a `..` component
    """
    rel = Path(relative)
    # An anchored path would replace `base` when joined; `..` would climb out of it
    if rel.anchor or '..' in rel.parts:
        raise RuntimeError(f"Corrupted bundle. Unsafe member path: {relative}")
    return base / rel
예제 #2
0
 def _download():
     """
     Fetch and unpack the Cora archive, unless a `cora` folder is already in the data path
     """
     data_dir = lab.get_data_path()
     # Nothing to do when the extracted folder is already there
     if (data_dir / 'cora').exists():
         return

     archive = data_dir / 'cora.tgz'
     download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz', archive)
     download.extract_tar(archive, data_dir)
예제 #3
0
    def __init__(self, seq_len: int):
        """
        Download the Tiny Shakespeare text and build a character-level vocabulary.

        Sets `stoi` (character to id), `itos` (id to character), `seq_len`, and
        `data` (the whole text passed through `text_to_i`).

        :param seq_len: kept on the instance as `seq_len`
        """
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        download_file(
            'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt',
            path)
        # Decode explicitly rather than with the platform default encoding
        # (assumes the downloaded file is UTF-8/ASCII text - TODO confirm)
        with open(str(path), 'r', encoding='utf-8') as f:
            text = f.read()

        # Sort the distinct characters: iteration order of a set of strings can
        # differ between interpreter runs (hash randomisation), which would give
        # every run a different character-to-id mapping
        chars = sorted(set(text))
        self.stoi = {c: i for i, c in enumerate(chars)}
        self.itos = dict(enumerate(chars))
        self.seq_len = seq_len
        self.data = self.text_to_i(text)
예제 #4
0
파일: cycle_gan.py 프로젝트: Sandy4321/nn-1
 def download(dataset_name: str):
     """
     #### Download dataset and extract data

     Downloads `{dataset_name}.zip` into the `cycle_gan` folder under the lab data
     path and extracts the archive into that same folder.

     :param dataset_name: name of the dataset; used for both the URL and the archive file name
     """
     # URL
     url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
     # Download folder; `exist_ok` avoids the race between checking for the
     # folder and creating it
     root = lab.get_data_path() / 'cycle_gan'
     root.mkdir(parents=True, exist_ok=True)
     # Download destination
     archive = root / f'{dataset_name}.zip'
     # Download file (generally ~100MB)
     download_file(url, archive)
     # Extract the archive
     with zipfile.ZipFile(archive, 'r') as zf:
         zf.extractall(root)
예제 #5
0
    def __init__(self, seq_len: int):
        """
        Download the Tiny Shakespeare text and build a character-level vocabulary.

        :param seq_len: length of a training sample
        """
        # Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download the file
        download_file(
            'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt',
            path)
        # Read the downloaded file; decode explicitly rather than with the platform
        # default encoding (assumes the file is UTF-8/ASCII text - TODO confirm)
        with open(str(path), 'r', encoding='utf-8') as f:
            text = f.read()

        # Extract the characters. They are sorted because iteration order of a set
        # of strings can differ between interpreter runs (hash randomisation),
        # which would give every run a different character-to-id mapping
        chars = sorted(set(text))
        # Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
        # Id to character map
        self.itos = dict(enumerate(chars))
        # Length of a training sample
        self.seq_len = seq_len
        # Data in the form of a tensor of ids
        self.data = self.text_to_i(text)
예제 #6
0
    def __init__(self,
                 path: PurePath,
                 tokenizer: Callable,
                 *,
                 url: Optional[str] = None,
                 filter_subset: Optional[int] = None):
        """
        Load a text file and split it 90/10 into training and validation parts.

        :param path: location of the text file
        :param tokenizer: forwarded unchanged to the parent constructor
        :param url: where to download the file from when `path` does not exist
        :param filter_subset: if truthy, only the first `filter_subset` items of
            the loaded text are kept before splitting
        :raises FileNotFoundError: if `path` does not exist and no `url` is given
        """
        path = Path(path)
        if not path.exists():
            if url:
                download_file(url, path)
            else:
                raise FileNotFoundError(str(path))

        with monit.section("Load data"):
            contents = self.load(path)
            if filter_subset:
                contents = contents[:filter_subset]
            # First 90% is training data, the remainder is validation data
            cut = int(len(contents) * .9)
            train, valid = contents[:cut], contents[cut:]

        super().__init__(path, tokenizer, train, valid, '')