Example #1
    def __init__(self, path, dataset, vocab, *args, **kwargs):
        self.dataset = dataset
        if vocab == 'word':
            self.vocab = Vocab(*args, **kwargs)
        elif vocab == 'bpe':
            self.vocab = OpenAIVocab()
        else:
            raise RuntimeError('Unsupported vocab')

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
Example #2
class Corpus(object):
    def __init__(self, path, dataset, vocab, *args, **kwargs):
        self.dataset = dataset
        if vocab == 'word':
            self.vocab = Vocab(*args, **kwargs)
        elif vocab == 'bpe':
            self.vocab = OpenAIVocab()
        else:
            raise RuntimeError('Unsupported vocab')

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args,
                                                **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter
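
For reference, a minimal usage sketch of the class above. It assumes the upstream Transformer-XL data utilities (Vocab, OpenAIVocab, LMOrderedIterator, LMShuffledIterator, LMMultiFileIterator) are importable and that ./data/wikitext-2 contains train.txt, valid.txt and test.txt; the path and keyword values are illustrative, not taken from the example.

# Hypothetical usage sketch (not part of the original example): build a
# word-level corpus for WikiText-2 and iterate over fixed-length batches.
corpus = Corpus('./data/wikitext-2', 'wt2', vocab='word',
                special=['<eos>'], lower_case=False)
# bsz/bptt/device are forwarded to LMOrderedIterator in the upstream code.
train_iter = corpus.get_iterator('train', bsz=32, bptt=70, device='cpu')
for data, target, seq_len in train_iter:
    pass  # feed `data` and `target` to the language model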
Example #3
    def __init__(self, path, dataset, use_bpe, *args, **kwargs):
        self.dataset = dataset
        if use_bpe:
            self.vocab = OpenAIVocab(kwargs['max_size'],
                                     kwargs.get('vocab_file'))
        else:
            self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103' or self.dataset == 'wt2':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wt103-normal':
            self.vocab.count_file(os.path.join(path, 'wiki.train.tokens'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
        elif self.dataset == 'wiki':
            file_path_pattern = os.path.join(path, '*/wiki_*.txt')
            file_paths = glob.glob(file_path_pattern)
            assert file_paths, f'Nothing found at {file_path_pattern}'

        # the vocab will load from file when build_vocab() is called
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
        elif self.dataset == 'wiki':
            # Take the first and second file of each alphabetical directory
            # for validation and test; the remaining files form the training set.
            self.valid = [x for x in file_paths if x.endswith('00.txt')]
            self.test = [x for x in file_paths if x.endswith('01.txt')]
            self.train = list(
                set(file_paths) - set(self.valid) - set(self.test))
        elif self.dataset in ['wt103-normal']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'wiki.train.tokens'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'wiki.valid.tokens'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'wiki.test.tokens'), ordered=True, add_eos=False)
Example #4
    def __init__(self, path, dataset, use_bpe, valid_custom=None, *args, **kwargs):
        self.dataset = dataset
        train_paths = None
        file_paths = None
        self.valid_custom = None

        if use_bpe:
            self.vocab = OpenAIVocab(kwargs['max_size'], kwargs.get('vocab_file'))
        else:
            self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103' or self.dataset == 'wt2':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wt103-normal':
            self.vocab.count_file(os.path.join(path, 'wiki.train.tokens'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
        elif self.dataset == 'wiki':
            file_path_pattern = os.path.join(path, '*/wiki_*.txt')
            file_paths = glob.glob(file_path_pattern)
            assert file_paths, f'Nothing found at {file_path_pattern}'
        elif self.dataset == 'git':
            file_path_pattern = os.path.join(path, 'git_*.txt')
            file_paths = glob.glob(file_path_pattern)
            valid_path = os.path.join(path, 'valid.txt')
            test_path = os.path.join(path, 'test.txt')
            assert file_paths, f'Nothing found at {file_path_pattern}'

        # Guard against datasets that never populate file_paths (e.g. 'ptb').
        if file_paths is not None:
            file_paths = natsort.natsorted(file_paths)

        # the vocab will load from file when build_vocab() is called
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = natsort.natsorted(train_paths)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
        elif self.dataset == 'wiki':
            if g.args.test:  # in testing mode we use smaller dataset
                valid_path = sorted(file_paths)[-1]
                test_path = sorted(file_paths)[-1]
            else:
                valid_path = sorted(file_paths)[42]
                test_path = sorted(file_paths)[1337]
            self.valid = self.vocab.encode_file(valid_path, ordered=True)
            self.test = self.vocab.encode_file(test_path, ordered=True)
            self.train = None
            self.train_files = list(set(file_paths) - {valid_path, test_path})
        elif self.dataset == 'git':
            if g.args.test:  # in testing mode we use smaller dataset
                valid_path = sorted(file_paths)[-1]
                test_path = sorted(file_paths)[-1]
            if valid_custom:
                g.logger.info(f"Using file {valid_custom} as additional validation file")
                self.valid_custom = self.vocab.encode_file(valid_custom, ordered=True)
            self.valid = self.vocab.encode_file(valid_path, ordered=True)
            self.test = self.vocab.encode_file(test_path, ordered=True)
            self.train = None
            self.train_files = file_paths
        elif self.dataset in ['wt103-normal']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'wiki.train.tokens'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'wiki.valid.tokens'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'wiki.test.tokens'), ordered=True, add_eos=False)

        # Only the 'wiki' and 'git' branches define train_files.
        if getattr(self, 'train_files', None):
            self.train_files = natsort.natsorted(self.train_files)
Example #5
    def __init__(self, path, dataset, use_bpe, *args, **kwargs):
        self.dataset = dataset
        if use_bpe:
            self.vocab = OpenAIVocab(kwargs['max_size'],
                                     kwargs.get('vocab_file'))
        else:
            self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'ger-wiki']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103' or self.dataset == 'wt2':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wt103-normal':
            self.vocab.count_file(os.path.join(path, 'wiki.train.tokens'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
        elif self.dataset == 'wiki':
            file_path_pattern = os.path.join(path, '*/wiki_*.txt')
            file_paths = glob.glob(file_path_pattern)
            assert file_paths, f'Nothing found at {file_path_pattern}'

        # the vocab will load from file when build_vocab() is called
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103', 'ger-wiki']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
        elif self.dataset == 'wiki':
            valid_path = sorted(file_paths)[42]
            test_path = sorted(file_paths)[1337]
            self.valid = self.vocab.encode_file(valid_path, ordered=True)
            self.test = self.vocab.encode_file(test_path, ordered=True)
            self.train = list(set(file_paths) - set((valid_path, test_path)))
        elif self.dataset in ['wt103-normal']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'wiki.train.tokens'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'wiki.valid.tokens'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'wiki.test.tokens'), ordered=True, add_eos=False)