Code Example #1
File: data_utils.py  Project: Vasyka/DeepGQuad
    def __init__(self,
                 data_path,
                 labels=2,
                 merge_size=1,
                 shift=20,
                 norm_dist=False,
                 fracs=[0.2, 0.1],
                 at_idx=None):
        self.data_path = data_path
        self.vocab = Vocab()
        self.labels = labels

        token_size = 4**merge_size
        self.Gen = Genome(data_path, shift=shift, merge_size=merge_size)
        if isinstance(at_idx, list) and len(at_idx) > 1:
            parts = self.Gen.slice_genome(fractions=None, at_idx=at_idx)
        else:
            parts = self.Gen.slice_genome(fracs, at_idx)

        self.Gen_train = parts[0]
        print(f'Train set size: {len(self.Gen_train.labels)}')
        if len(parts) > 1:
            self.Gen_test = parts[1]
            print(f'Test set size: {len(self.Gen_test.labels)}')
            if len(parts) == 3:
                self.Gen_valid = parts[2]
            else:
                self.Gen_valid = copy(self.Gen_test)

        self.T_type = torch.LongTensor

        self.vocab.create_tokens(token_size, self.Gen_train.DNA.ravel())
        self.vocab.build_vocab()

        self.train_lab, self.valid_lab, self.test_lab = None, None, None
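
The constructor above is from a genome-sequence corpus; the listing does not show the enclosing class name, so the sketch below uses the placeholder name GenomeCorpus and assumes it can be imported from the same data_utils module as Genome and Vocab. It only illustrates the two split modes the constructor supports: fraction-based splitting via fracs and explicit index-based slicing via at_idx.

# Hypothetical usage sketch; GenomeCorpus and the data path are placeholders,
# since the listing only shows the __init__ body.
from data_utils import GenomeCorpus

# Fraction-based split: fracs is forwarded to Genome.slice_genome, which
# returns the train/test(/valid) parts in that order.
corpus = GenomeCorpus('data/genome', labels=2, merge_size=1, shift=20,
                      fracs=[0.2, 0.1])

# Index-based split: a list of slice indices bypasses the fractions entirely.
corpus_idx = GenomeCorpus('data/genome', merge_size=1, at_idx=[0, 1000, 2000])

When at_idx is a list with more than one entry, the fractions are ignored (slice_genome is called with fractions=None), which is why the second call leaves fracs at its default.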
Code Example #2
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)
        # self.order = kwargs.get('order', True)

        # if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'bilingual_ted']:
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        #     self.vocab.count_file(os.path.join(path, 'valid.txt'))
        #     self.vocab.count_file(os.path.join(path, 'test.txt'))
        # elif self.dataset == 'wt103':
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        # elif self.dataset == 'lm1b':
        #     train_path_pattern = os.path.join(
        #         path, '1-billion-word-language-modeling-benchmark-r13output',
        #         'training-monolingual.tokenized.shuffled', 'news.en-*')
        #     train_paths = glob.glob(train_path_pattern)
        #     # the vocab will load from file when build_vocab() is called

        self.vocab.count_file(os.path.join(path, 'train.txt'))
        self.vocab.build_vocab()

        self.train = self.vocab.encode_file(
            os.path.join(path, 'train.txt'))
        self.valid = self.vocab.encode_file(
            os.path.join(path, 'valid.txt'))
        self.test = self.vocab.encode_file(
            os.path.join(path, 'test.txt'))
Code Example #3
File: data_utils.py  Project: RutujaTaware1998/txl_tf
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'smiles']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8', 'smiles']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
Code Example #4
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'wddev']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wdtrain' or self.dataset == 'wdtrain-morph':
            train_path_pattern = os.path.join(path, 'train.txt')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103', 'wddev']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
        elif self.dataset == 'wdtrain' or self.dataset == 'wdtrain-morph':
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
Code Example #5
    def __init__(self, path, dataset, trainfname, validfname, testfname, *args,
                 **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, trainfname))
            self.vocab.count_file(os.path.join(path, validfname))
            self.vocab.count_file(os.path.join(path, testfname))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, trainfname))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called
        else:
            self.vocab.count_file(os.path.join(path, trainfname))
            self.vocab.count_file(os.path.join(path, validfname))
            self.vocab.count_file(os.path.join(path, testfname))

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(path, trainfname),
                                                ordered=True,
                                                add_eos=True)
            self.valid = self.vocab.encode_file(os.path.join(path, validfname),
                                                ordered=True,
                                                add_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, testfname),
                                               ordered=False,
                                               add_eos=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(path, trainfname),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(path, validfname),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, testfname),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(path, validfname),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, testfname),
                                               ordered=False,
                                               add_double_eos=True)
        else:
            self.train = self.vocab.encode_file(os.path.join(path, trainfname),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(path, validfname),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, testfname),
                                               ordered=True)
Code Example #6
File: data_utils.py  Project: quanpn90/NMTBMajor
    def __init__(self, vocab=None, *args, **kwargs):

        if vocab is None:
            self.vocab = Vocab(*args, **kwargs)
        else:
            self.vocab = vocab

        self.train, self.valid = [], []
Code Example #7
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset  # the string storing the name of the dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            # Add words to the counter object of vocab
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))

        elif self.dataset == 'wt103':
            # Add words to the counter object of vocab
            self.vocab.count_file(os.path.join(path, 'train.txt'))

        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        # Add words to idx2sym and sym2idx of vocab
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            # Add LongTensors of the full corpus consisting only of words.
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)

        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)

        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
Code Example #8
class TimeSeries():
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, "train.txt"))
        self.vocab.count_file(os.path.join(path, "valid.txt"))
        self.vocab.count_file(os.path.join(path, "test.txt"))

        self.vocab.build_vocab()

        self.train = self.vocab.encode_file(os.path.join(path, "train.txt"),
                                            ordered=True,
                                            add_eos=False)
        self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"),
                                            ordered=True,
                                            add_eos=False)
        self.test = self.vocab.encode_file(os.path.join(path, "test.txt"),
                                           ordered=True,
                                           add_eos=False)

    def get_iterator(self, split, *args, **kwargs):
        if split == "train":
            data_iter = OrderedIterator(self.train, *args, **kwargs)

        elif split in ["valid", "test"]:
            data = self.valid if split == "valid" else self.test
            data_iter = OrderedIterator(data, *args, **kwargs)

        return data_iter
Code Example #9
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, 'train.txt'))
        self.vocab.count_file(os.path.join(path, 'valid.txt'))
        self.vocab.count_file(os.path.join(path, 'test.txt'))

        self.vocab.build_vocab()
        self.train = self.vocab.encode_file(os.path.join(path, 'train.txt'),
                                            ordered=True,
                                            add_eos=False)
        self.valid = self.vocab.encode_file(os.path.join(path, 'valid.txt'),
                                            ordered=True,
                                            add_eos=False)
        self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                           ordered=True,
                                           add_eos=False)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            data_iter = LMOrderedIterator(self.train, *args, **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            data_iter = LMOrderedIterator(data, *args, **kwargs)
        return data_iter
Code Example #10
    def __init__(self, path, dataset, *args, **kw):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kw)

        if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
            self.vocab.count_file(os.path.join(path, "train.txt"))
            self.vocab.count_file(os.path.join(path, "valid.txt"))
            self.vocab.count_file(os.path.join(path, "test.txt"))
        elif self.dataset == "wt103":
            self.vocab.count_file(os.path.join(path, "train.txt"))
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path,
                "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled",
                "news.en-*",
            )
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ["ptb", "wt2", "wt103"]:
            self.train = self.vocab.encode_file(os.path.join(
                path, "train.txt"),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, "valid.txt"),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"),
                                               ordered=True)
        elif self.dataset in ["enwik8", "text8"]:
            self.train = self.vocab.encode_file(os.path.join(
                path, "train.txt"),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, "valid.txt"),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == "lm1b":
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, "valid.txt"),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"),
                                               ordered=False,
                                               add_double_eos=True)
Code Example #11
    def __init__(self, path, *args, **kwargs):
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, 'train.txt'), verbose=True)
        self.vocab.count_file(os.path.join(path, 'valid.txt'), verbose=True)

        self.vocab.build_vocab()


        self.train = self.vocab.encode_file(
            os.path.join(path, 'train.txt'), ordered=True, verbose=True)
        self.valid = self.vocab.encode_file(
            os.path.join(path, 'valid.txt'), ordered=True, verbose=True)
Code Example #12
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':

            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            # print(self.train.size())
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
            # print(self.test.size())
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            valid_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'heldout-monolingual.tokenized.shuffled', 'news.en.heldout*')
            self.valid = glob.glob(valid_path_pattern)
            # print(valid_path_pattern)
            self.test = self.valid
Code Example #13
  def __init__(self, model, path_2_vocab, score_fn=score_fun_linear):
    self._model = model
    self._model.eval()
    self._model.crit.keep_order = True
    self._vocab = Vocab(vocab_file=path_2_vocab)
    self._vocab.build_vocab()
    self._score_fn = score_fn

    print('---->>> Testing Model.')
    self.test_model(candidates=['they had one night in which to prepare for deach',
                                'they had one night in which to prepare for death',
                                'i hate school', 'i love school',
                                'the fox jumps on a grass',
                                'the crox jump a la glass'])
    print('---->>> Done testing model')
Code Example #14
    def __init__(self,
                 train_src,
                 train_tgt,
                 valid_src,
                 valid_tgt,
                 order=True,
                 *args,
                 **kwargs):

        # if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'bilingual_ted']:
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        #     self.vocab.count_file(os.path.join(path, 'valid.txt'))
        #     self.vocab.count_file(os.path.join(path, 'test.txt'))
        # elif self.dataset == 'wt103':
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        # elif self.dataset == 'lm1b':
        #     train_path_pattern = os.path.join(
        #         path, '1-billion-word-language-modeling-benchmark-r13output',
        #         'training-monolingual.tokenized.shuffled', 'news.en-*')
        #     train_paths = glob.glob(train_path_pattern)
        #     # the vocab will load from file when build_vocab() is called

        if kwargs.get('share_vocab'):
            self.src_vocab = Vocab(*args, **kwargs)
            self.src_vocab.count_file(train_src)
            self.src_vocab.count_file(train_tgt)
            self.src_vocab.build_vocab()
            self.tgt_vocab = self.src_vocab
        else:
            print("Two vocabularies are not supported at the moment")
            raise NotImplementedError

        self.train = dict()
        self.valid = dict()

        self.train['src'] = self.src_vocab.encode_file(train_src)

        self.train['tgt'] = self.tgt_vocab.encode_file(train_tgt,
                                                       bos=True,
                                                       eos=True)

        self.valid['src'] = self.src_vocab.encode_file(valid_src)

        self.valid['tgt'] = self.tgt_vocab.encode_file(valid_tgt,
                                                       bos=True,
                                                       eos=True)
Code Example #15
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, 'train.txt'))
        self.vocab.count_file(os.path.join(path, 'valid.txt'))
        self.vocab.count_file(os.path.join(path, 'test.txt'))

        self.vocab.build_vocab()
        self.train = self.vocab.encode_file(os.path.join(path, 'train.txt'),
                                            ordered=True,
                                            add_eos=False)
        self.valid = self.vocab.encode_file(os.path.join(path, 'valid.txt'),
                                            ordered=True,
                                            add_eos=False)
        self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                           ordered=True,
                                           add_eos=False)
Code Example #16
    def __init__(self, path, *args, **kwargs):
        self.vocab = Vocab(*args, **kwargs)

        # Load the tokens from the vocabulary file
        self.vocab.build_vocab()

        # Training set
        self.train = self.vocab.encode_file(os.path.join(path, 'train.txt'),
                                            verbose=True)
        self.train_label = self.vocab.encode_file_only_for_lables(os.path.join(
            path, 'train.label'),
                                                                  verbose=True)

        # Validation set
        self.valid = self.vocab.encode_file(os.path.join(path, 'valid.txt'),
                                            verbose=True)
        self.valid_label = self.vocab.encode_file_only_for_lables(os.path.join(
            path, 'valid.label'),
                                                                  verbose=True)
Code Example #17
    def __init__(self, path, dataset, vocab, *args, **kwargs):
        self.dataset = dataset
        if vocab == 'word':
            self.vocab = Vocab(*args, **kwargs)
        elif vocab == 'bpe':
            self.vocab = OpenAIVocab()
        else:
            raise RuntimeError('Unsupported vocab')

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
Code Example #18
class Corpus(object):
    def __init__(self, path, *args, **kwargs):
        self.vocab = Vocab(*args, **kwargs)

        self.vocab.count_file(os.path.join(path, 'train.txt'), verbose=True)
        self.vocab.count_file(os.path.join(path, 'valid.txt'), verbose=True)

        self.vocab.build_vocab()


        self.train = self.vocab.encode_file(
            os.path.join(path, 'train.txt'), ordered=True, verbose=True)
        self.valid = self.vocab.encode_file(
            os.path.join(path, 'valid.txt'), ordered=True, verbose=True)
        # self.test = self.vocab.encode_file(
        #     os.path.join(path, 'test.txt'), ordered=True)
    # Author: 许海明
    def get_iterator(self, split, *args, **kwargs):
        '''
        Build a data iterator over the requested split.

        :param split: one of 'train', 'valid' or 'test'
        :param args: positional arguments forwarded to LMOrderedIterator
        :param kwargs: keyword arguments forwarded to LMOrderedIterator
        :return: an LMOrderedIterator over the requested data
        '''
        if split == 'train':
            data_iter = LMOrderedIterator(self.train, *args, **kwargs)

        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            data_iter = LMOrderedIterator(data, *args, **kwargs)

        return data_iter
Code Example #19
    def load(cls,
             model_path: Path,
             spm_path: Path,
             device: str = None) -> 'ModelWrapper':
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        with open(model_path, 'rb') as f:
            state = torch.load(f, map_location='cpu')
        model = MemTransformerLM(**state['model_params'])
        model.load_state_dict(state['state_dict'])
        vocab_params = state['vocab_params']
        vocab = Vocab.from_symbols(state['vocab'])
        sp_processor = spm.SentencePieceProcessor()
        sp_processor.Load(str(spm_path))
        return cls(model, vocab, sp_processor, device)
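
A plausible call site for the classmethod above (it is presumably decorated with @classmethod in the original file); the checkpoint and SentencePiece paths are placeholders, and the checkpoint is expected to contain the model_params, state_dict, vocab_params and vocab entries the method reads.

from pathlib import Path

# Placeholder paths; the device is chosen automatically when omitted.
wrapper = ModelWrapper.load(Path('checkpoints/model.pt'),
                            Path('checkpoints/sp.model'))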
Code Example #20
class Corpus(object):
    def __init__(self, path, *args, **kwargs):
        self.vocab = Vocab(*args, **kwargs)

        # Load the tokens from the vocabulary file
        self.vocab.build_vocab()

        # Training set
        self.train = self.vocab.encode_file(os.path.join(path, 'train.txt'),
                                            verbose=True)
        self.train_label = self.vocab.encode_file_only_for_lables(os.path.join(
            path, 'train.label'),
                                                                  verbose=True)

        # Validation set
        self.valid = self.vocab.encode_file(os.path.join(path, 'valid.txt'),
                                            verbose=True)
        self.valid_label = self.vocab.encode_file_only_for_lables(os.path.join(
            path, 'valid.label'),
                                                                  verbose=True)

        # self.test = self.vocab.encode_file(
        #     os.path.join(path, 'test.txt'), ordered=True)

    # Author: 许海明
    def get_batch_iterator(self, split, *args, **kwargs):
        '''
        Build a batch iterator over the requested split and its labels.

        :param split: either 'train' or 'valid'
        :param args: positional arguments forwarded to BatchIteratorHelper
        :param kwargs: keyword arguments forwarded to BatchIteratorHelper
        :return: a BatchIteratorHelper over the split's data and labels
        '''
        if split == 'train':
            # data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            batch_iter = BatchIteratorHelper(self.train, self.train_label,
                                             *args, **kwargs)

        elif split == 'valid':
            batch_iter = BatchIteratorHelper(self.valid, self.valid_label,
                                             *args, **kwargs)

        return batch_iter
Code Example #21
class Corpus:
    def __init__(self, path, dataset, use_bpe, *args, **kwargs):
        self.dataset = dataset
        if use_bpe:
            self.vocab = OpenAIVocab(kwargs['max_size'],
                                     kwargs.get('vocab_file'))
        else:
            self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103' or self.dataset == 'wt2':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wt103-normal':
            self.vocab.count_file(os.path.join(path, 'wiki.train.tokens'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
        elif self.dataset == 'wiki':
            file_path_pattern = os.path.join(path, '*/wiki_*.txt')
            file_paths = glob.glob(file_path_pattern)
            assert file_paths, f'Nothing found at {file_path_pattern}'

        # the vocab will load from file when build_vocab() is called
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
        elif self.dataset == 'wiki':
            # Take the first and second file of each alphabetical directory for validation and test.
            self.valid = [x for x in file_paths if x.endswith('00.txt')]
            self.test = [x for x in file_paths if x.endswith('01.txt')]
            self.train = list(
                set(file_paths) - set(self.valid) - set(self.test))
        elif self.dataset in ['wt103-normal']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'wiki.train.tokens'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'wiki.valid.tokens'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(
                path, 'wiki.test.tokens'),
                                               ordered=True,
                                               add_eos=False)

    def get_dist_iterator(self, split, rank, max_rank, *args, **kwargs):
        """Get an iterator that only operates on rank//max_rank independent subset of the data."""
        data = self.__getattribute__(split)
        subset = list(chunk(data, max_rank))[rank]
        if self.dataset in ['lm1b', 'wiki']:
            return LMMultiFileIterator(subset, self.vocab, *args, **kwargs)

        return LMOrderedIterator(subset, *args, **kwargs)

    def get_iterator(self, split, *args, **kwargs):
        """Get an iterator over the corpus.

        Each next() returns (data, target, seq_length).
        data and target have shape (bptt, bsz) and seq_length is a scalar.
        """
        data = self.__getattribute__(split)
        if self.dataset in [
                'ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'wt103-normal'
        ]:
            return LMOrderedIterator(data, *args, **kwargs)
        elif self.dataset == 'lm1b':
            if split in ['valid', 'test']:
                return LMShuffledIterator(data, *args, **kwargs)
            else:
                kwargs['shuffle'] = True
                return LMMultiFileIterator(data, self.vocab, *args, **kwargs)
        elif self.dataset == 'wiki':
            return LMMultiFileIterator(data, self.vocab, *args, **kwargs)
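
The get_iterator docstring above states that each step yields (data, target, seq_length), with data and target shaped (bptt, bsz). A minimal consumption loop, assuming the usual Transformer-XL iterator arguments (bsz, bptt, device) and a WikiText-103 layout under a placeholder path:

# Sketch only: the bsz/bptt/device keyword arguments are assumed to match the
# LMOrderedIterator signature used elsewhere in this code base.
corpus = Corpus('data/wikitext-103', 'wt103', use_bpe=False)
train_iter = corpus.get_iterator('train', bsz=32, bptt=150, device='cpu')

for data, target, seq_len in train_iter:
    # data and target are LongTensors of shape (bptt, bsz); seq_len is a scalar.
    pass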
Code Example #22
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'),
                                  verbose=True)
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args,
                                                **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter
Code Example #23
    def __init__(self, path, dataset, use_bpe, valid_custom=None, *args, **kwargs):
        self.dataset = dataset
        train_paths = None
        file_paths = None
        self.valid_custom = None

        if use_bpe:
            self.vocab = OpenAIVocab(kwargs['max_size'], kwargs.get('vocab_file'))
        else:
            self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103' or self.dataset == 'wt2':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wt103-normal':
            self.vocab.count_file(os.path.join(path, 'wiki.train.tokens'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
        elif self.dataset == 'wiki':
            file_path_pattern = os.path.join(path, '*/wiki_*.txt')
            file_paths = glob.glob(file_path_pattern)
            assert file_paths, f'Nothing found at {file_path_pattern}'
        elif self.dataset == 'git':
            file_path_pattern = os.path.join(path, 'git_*.txt')
            file_paths = glob.glob(file_path_pattern)
            valid_path = os.path.join(path, 'valid.txt')
            test_path = os.path.join(path, 'test.txt')
            assert file_paths, f'Nothing found at {file_path_pattern}'

        if file_paths:
            file_paths = natsort.natsorted(file_paths)

        # the vocab will load from file when build_vocab() is called
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = natsort.natsorted(train_paths)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
        elif self.dataset == 'wiki':
            if g.args.test:  # in testing mode we use smaller dataset
                valid_path = sorted(file_paths)[-1]
                test_path = sorted(file_paths)[-1]
            else:
                valid_path = sorted(file_paths)[42]
                test_path = sorted(file_paths)[1337]
            self.valid = self.vocab.encode_file(valid_path, ordered=True)
            self.test = self.vocab.encode_file(test_path, ordered=True)
            self.train = None
            self.train_files = list(set(file_paths) - {valid_path, test_path})
        elif self.dataset == 'git':
            if g.args.test:  # in testing mode we use smaller dataset
                valid_path = sorted(file_paths)[-1]
                test_path = sorted(file_paths)[-1]
            if valid_custom:
                g.logger.info(f"Using file {valid_custom} as additional validation file")
                self.valid_custom = self.vocab.encode_file(valid_custom, ordered=True)
            self.valid = self.vocab.encode_file(valid_path, ordered=True)
            self.test = self.vocab.encode_file(test_path, ordered=True)
            self.train = None
            self.train_files = file_paths
        elif self.dataset in ['wt103-normal']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'wiki.train.tokens'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'wiki.valid.tokens'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'wiki.test.tokens'), ordered=True, add_eos=False)

        if getattr(self, 'train_files', None):
            self.train_files = natsort.natsorted(self.train_files)
Code Example #24
class Corpus:
    train_files: List[str]
    vocab: Vocab
    train: Optional[torch.LongTensor]
    valid: Optional[torch.LongTensor]
    test: Optional[torch.LongTensor]

    def __init__(self, path, dataset, use_bpe, valid_custom=None, *args, **kwargs):
        self.dataset = dataset
        train_paths = None
        file_paths = None
        self.valid_custom = None

        if use_bpe:
            self.vocab = OpenAIVocab(kwargs['max_size'], kwargs.get('vocab_file'))
        else:
            self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103' or self.dataset == 'wt2':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wt103-normal':
            self.vocab.count_file(os.path.join(path, 'wiki.train.tokens'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
        elif self.dataset == 'wiki':
            file_path_pattern = os.path.join(path, '*/wiki_*.txt')
            file_paths = glob.glob(file_path_pattern)
            assert file_paths, f'Nothing found at {file_path_pattern}'
        elif self.dataset == 'git':
            file_path_pattern = os.path.join(path, 'git_*.txt')
            file_paths = glob.glob(file_path_pattern)
            valid_path = os.path.join(path, 'valid.txt')
            test_path = os.path.join(path, 'test.txt')
            assert file_paths, f'Nothing found at {file_path_pattern}'

        if file_paths:
            file_paths = natsort.natsorted(file_paths)

        # the vocab will load from file when build_vocab() is called
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = natsort.natsorted(train_paths)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
        elif self.dataset == 'wiki':
            if g.args.test:  # in testing mode we use smaller dataset
                valid_path = sorted(file_paths)[-1]
                test_path = sorted(file_paths)[-1]
            else:
                valid_path = sorted(file_paths)[42]
                test_path = sorted(file_paths)[1337]
            self.valid = self.vocab.encode_file(valid_path, ordered=True)
            self.test = self.vocab.encode_file(test_path, ordered=True)
            self.train = None
            self.train_files = list(set(file_paths) - {valid_path, test_path})
        elif self.dataset == 'git':
            if g.args.test:  # in testing mode we use smaller dataset
                valid_path = sorted(file_paths)[-1]
                test_path = sorted(file_paths)[-1]
            if valid_custom:
                g.logger.info(f"Using file {valid_custom} as additional validation file")
                self.valid_custom = self.vocab.encode_file(valid_custom, ordered=True)
            self.valid = self.vocab.encode_file(valid_path, ordered=True)
            self.test = self.vocab.encode_file(test_path, ordered=True)
            self.train = None
            self.train_files = file_paths
        elif self.dataset in ['wt103-normal']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'wiki.train.tokens'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'wiki.valid.tokens'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'wiki.test.tokens'), ordered=True, add_eos=False)

        if getattr(self, 'train_files', None):
            self.train_files = natsort.natsorted(self.train_files)

    def get_dist_iterator(self, split: str, *args, rank: int = 0, max_rank: int = 1, skip_files: float = .0, **kwargs):
        """Get an iterator that only operates on rank'th independent subset of the data."""
        if split == 'train':
            data = self.train
        elif split == 'valid':
            data = self.valid
        elif split == 'valid_custom':
            assert self.valid_custom is not None, "Custom validation file was not specified during Corpus initialization"
            data = self.valid_custom
        else:
            assert split == 'test'
            data = self.test

        # special handling for large datasets, don't load training set in memory
        if self.dataset in ['lm1b', 'wiki', 'git'] and split == 'train':
            file_subset = list(chunk(self.train_files, max_rank))[rank]
            return LMMultiFileIterator(file_subset, self.vocab, skip_files=skip_files, *args, **kwargs)

        # noinspection PyTypeChecker
        assert len(data), f"data attribute '{split}' empty for iterator.dataset={self.dataset}"
        # noinspection PyTypeChecker
        subset = list(chunk(data, max_rank))[rank]
        return LMOrderedIterator(subset, *args, **kwargs)

    def get_iterator(self, split: str, *args, **kwargs):
        """Get an iterator over the corpus.

        Each next() returns (data, target, seq_length).
        data and target have shape (bptt, bsz) and seq_length is a scalar.
        """
        data = self.__getattribute__(split)
        if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'wt103-normal']:
            return LMOrderedIterator(data, *args, **kwargs)
        if self.dataset == 'lm1b':
            if split in ['valid', 'test']:
                return LMShuffledIterator(data, *args, **kwargs)

            kwargs['shuffle'] = True
            return LMMultiFileIterator(data, self.vocab, *args, **kwargs)
        if self.dataset in ['wiki', 'git']:
            if split == 'train':
                return LMMultiFileIterator(data, self.vocab, *args, **kwargs)
            return LMOrderedIterator(data, *args, **kwargs)
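
get_dist_iterator above shards either the training files or an encoded tensor across workers using a chunk helper that is not shown in the listing; it is assumed to split its input into max_rank roughly equal parts. The sketch below also assumes the global g.args object referenced by the constructor has been initialized, that bsz/bptt match the iterator signatures, and that the paths and sizes are placeholders.

# Placeholder configuration for a 4-worker run over the multi-file 'git' dataset.
corpus = Corpus('data/git', 'git', use_bpe=True, max_size=50000)

rank, world_size = 0, 4  # this worker's index and the total number of workers
train_iter = corpus.get_dist_iterator('train', rank=rank, max_rank=world_size,
                                      bsz=16, bptt=512)

# Validation and test go through the in-memory tensor path of get_dist_iterator.
valid_iter = corpus.get_dist_iterator('valid', rank=rank, max_rank=world_size,
                                      bsz=16, bptt=512)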
Code Example #25
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)
        # self.order = kwargs.get('order', True)

        # if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'bilingual_ted']:
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        #     self.vocab.count_file(os.path.join(path, 'valid.txt'))
        #     self.vocab.count_file(os.path.join(path, 'test.txt'))
        # elif self.dataset == 'wt103':
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        # elif self.dataset == 'lm1b':
        #     train_path_pattern = os.path.join(
        #         path, '1-billion-word-language-modeling-benchmark-r13output',
        #         'training-monolingual.tokenized.shuffled', 'news.en-*')
        #     train_paths = glob.glob(train_path_pattern)
        #     # the vocab will load from file when build_vocab() is called

        self.vocab.count_file(os.path.join(path, 'train.txt'))
        self.vocab.build_vocab()

        self.train = self.vocab.encode_file(
            os.path.join(path, 'train.txt'))
        self.valid = self.vocab.encode_file(
            os.path.join(path, 'valid.txt'))
        self.test = self.vocab.encode_file(
            os.path.join(path, 'test.txt'))

        # if self.dataset in ['ptb', 'wt2', 'wt103']:
        #
        # elif self.dataset in ['enwik8', 'text8', 'bilingual_ted']:
        #     print("Creating %s dataset" % self.dataset)
        #     self.train = self.vocab.encode_file(
        #         os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
        #     self.valid = self.vocab.encode_file(
        #         os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
        #     self.test  = self.vocab.encode_file(
        #         os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        # elif self.dataset == 'lm1b':
        #     self.train = train_paths
        #     self.valid = self.vocab.encode_file(
        #         os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
        #     self.test  = self.vocab.encode_file(
        #         os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        # if split == 'train':
        #     if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'bilingual_ted']:
        #         data_iter = LMOrderedIterator(self.train, *args, **kwargs)
        #     elif self.dataset == 'lm1b':
        #         kwargs['shuffle'] = True
        #         data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        # elif split in ['valid', 'test']:
        #     data = self.valid if split == 'valid' else self.test
        #     if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'bilingual_ted']:
        #         data_iter = LMOrderedIterator(data, *args, **kwargs)
        #     elif self.dataset == 'lm1b':
        #         data_iter = LMShuffledIterator(data, *args, **kwargs)

        # if not hasattr(self, 'order'):
        #     self.order = True
        order = kwargs.get('order', True)

        if order:
            if split == 'train':
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif split in ['valid', 'test']:
                # note: self.test is built in __init__ but never used here; 'test' also iterates over self.valid
                data_iter = LMOrderedIterator(self.valid, *args, **kwargs)
        else:
            if split == 'train':
                data_iter = LMShuffledIterator(self.train, *args, **kwargs)
            elif split in ['valid', 'test']:
                data_iter = LMShuffledIterator(self.valid, *args, **kwargs)

        return data_iter
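The order flag above switches between ordered and shuffled iteration. A minimal sketch of both calls, with placeholder path, dataset name, and batch arguments:

# Hedged usage sketch -- all concrete values are placeholders.
corpus = Corpus('/path/to/corpus', 'my_dataset')
ordered_iter = corpus.get_iterator('train', bsz=32, bptt=128)                # order defaults to True
shuffled_iter = corpus.get_iterator('train', bsz=32, bptt=128, order=False)  # LMShuffledIterator
# Caveat: kwargs.get('order', ...) does not pop the key, so 'order' is also
# forwarded to the iterator constructor, which must tolerate it.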
Code example #26
File: data_utils.py Project: quanpn90/NMTBMajor
class Corpus(object):
    def __init__(self, vocab=None, *args, **kwargs):

        if vocab is None:
            self.vocab = Vocab(*args, **kwargs)
        else:
            self.vocab = vocab

        self.train, self.valid = [], []

    def generate_data(self, path, update_vocab=True):
        # self.order = kwargs.get('order', True)

        # if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'bilingual_ted']:
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        #     self.vocab.count_file(os.path.join(path, 'valid.txt'))
        #     self.vocab.count_file(os.path.join(path, 'test.txt'))
        # elif self.dataset == 'wt103':
        #     self.vocab.count_file(os.path.join(path, 'train.txt'))
        # elif self.dataset == 'lm1b':
        #     train_path_pattern = os.path.join(
        #         path, '1-billion-word-language-modeling-benchmark-r13output',
        #         'training-monolingual.tokenized.shuffled', 'news.en-*')
        #     train_paths = glob.glob(train_path_pattern)
        #     # the vocab will load from file when build_vocab() is called

        if update_vocab:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.build_vocab()

        self.train = self.vocab.encode_file(os.path.join(path, 'train.txt'))
        self.valid = self.vocab.encode_file(os.path.join(path, 'valid.txt'))
        # self.test = self.vocab.encode_file(
        #     os.path.join(path, 'test.txt'))

    def save(self, datadir):

        data = dict()

        data['train'] = self.train
        data['valid'] = self.valid
        data['vocab'] = self.vocab

        fn = os.path.join(datadir, 'cache.pt')
        torch.save(data, fn)

        vn = os.path.join(datadir, 'vocab.txt')
        self.vocab.write_to_file(vn)

    def load(self, datadir):

        fn = os.path.join(datadir, 'cache.pt')
        cache = torch.load(fn)

        self.train = cache['train']
        self.valid = cache['valid']
        self.vocab = cache['vocab']

    def get_iterator(self, split, *args, **kwargs):

        order = kwargs.get('order', True)

        if order:
            if split == 'train':
                data_iter = LMOrderedIterator(self.vocab, self.train, *args,
                                              **kwargs)
            elif split in ['valid', 'test']:
                data_iter = LMOrderedIterator(self.vocab, self.valid, *args,
                                              **kwargs)
        else:
            if split == 'train':
                data_iter = LMShuffledIterator(self.vocab, self.train, *args,
                                               **kwargs)
            elif split in ['valid', 'test']:
                data_iter = LMShuffledIterator(self.vocab, self.valid, *args,
                                               **kwargs)

        return data_iter
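This variant caches the encoded splits to disk. A minimal sketch of the generate/save/load round trip, assuming train.txt and valid.txt live under a placeholder directory and that the iterator accepts bsz/bptt keywords:

# Hedged usage sketch -- paths and batch settings are illustrative only.
corpus = Corpus()                      # fresh Vocab with its default settings
corpus.generate_data('/data/my_lm')    # counts train.txt, builds the vocab, encodes both splits
corpus.save('/data/my_lm')             # writes cache.pt and vocab.txt

restored = Corpus()
restored.load('/data/my_lm')           # restores train/valid tensors and the vocab, no re-tokenizing
valid_iter = restored.get_iterator('valid', bsz=16, bptt=256)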
Code example #27
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8', 'wddev']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wdtrain' or self.dataset == 'wdtrain-morph':
            train_path_pattern = os.path.join(path, 'train.txt')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103', 'wddev']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
        elif self.dataset == 'wdtrain' or self.dataset == 'wdtrain-morph':
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in [
                    'ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'wddev'
            ]:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'wdtrain' or self.dataset == 'wdtrain-morph':
                kwargs['shuffle'] = True
                #data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
                data_iter = LMShuffledIterator(self.train, *args, **kwargs)
                #data_iter = LMOrderedIterator(self.train, *args, **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in [
                    'ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'wddev'
            ]:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'wdtrain' or self.dataset == 'wdtrain-morph':
                data_iter = LMShuffledIterator(data, *args, **kwargs)
                #data_iter = RescoreIter(data, *args, **kwargs)
                #data_iter = LMOrderedIterator(data, *args, **kwargs)
        return data_iter
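Here the sentence-level 'wdtrain' corpora and the ordered 'wddev' corpus share one interface but receive different iterators. A minimal sketch with placeholder paths and batch arguments:

# Hedged usage sketch -- paths and batch settings are illustrative only.
train_corpus = Corpus('/data/wdtrain', 'wdtrain')                  # encoded sentence-wise (ordered=False)
train_iter = train_corpus.get_iterator('train', bsz=32, bptt=70)   # LMShuffledIterator with shuffle=True

dev_corpus = Corpus('/data/wddev', 'wddev')                        # encoded as a single ordered stream
dev_iter = dev_corpus.get_iterator('valid', bsz=10, bptt=70)       # LMOrderedIterator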
Code example #28
class Corpus(object):
    def __init__(self, path, dataset, use_bpe, *args, **kwargs):
        self.dataset = dataset
        if use_bpe:
            self.vocab = OpenAIVocab(kwargs['max_size'],
                                     kwargs.get('vocab_file'))
        else:
            self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103' or self.dataset == 'wt2':
            # note: 'wt2' is already matched by the branch above, so this extra check never fires
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'wt103-normal':
            self.vocab.count_file(os.path.join(path, 'wiki.train.tokens'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
        elif self.dataset == 'wiki':
            file_path_pattern = os.path.join(path, '*/wiki_*.txt')
            file_paths = glob.glob(file_path_pattern)
            assert file_paths, f'Nothing found at {file_path_pattern}'

        # the vocab will load from file when build_vocab() is called
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
        elif self.dataset == 'wiki':
            # Take the first and second file of each alphabetical directory for valid and test.
            self.valid = [x for x in file_paths if x.endswith('00.txt')]
            self.test = [x for x in file_paths if x.endswith('01.txt')]
            self.train = list(
                set(file_paths) - set(self.valid) - set(self.test))
        elif self.dataset in ['wt103-normal']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'wiki.train.tokens'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'wiki.valid.tokens'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(
                path, 'wiki.test.tokens'),
                                               ordered=True,
                                               add_eos=False)
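The use_bpe switch above swaps the word-level Vocab for OpenAIVocab; because kwargs['max_size'] is indexed unconditionally, max_size must be passed whenever use_bpe is true. A minimal construction sketch with placeholder values:

# Hedged usage sketch -- max_size, vocab_file, and the path are placeholders.
bpe_corpus = Corpus('/data/wikitext-103', 'wt103', use_bpe=True,
                    max_size=None, vocab_file=None)
word_corpus = Corpus('/data/wikitext-103', 'wt103', use_bpe=False)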
Code example #29
class Corpus(object):
    # not called with args and kwargs ever
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset  # the string storing the name of the dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            # Add words to the counter object of vocab
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))

        elif self.dataset == 'wt103':
            # Add words to the counter object of vocab
            self.vocab.count_file(os.path.join(path, 'train.txt'))

        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        # Add words to idx2sym and sym2idx of vocab
        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            # Add LongTensors of the full corpus consisting only of words.
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)

        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)

        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        # batch_size, bptt, device and extended context length are passed as args/kwargs
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)

            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args,
                                                **kwargs)

        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test

            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)

            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter
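For this variant the same call returns different iterator types depending on the dataset. A minimal sketch contrasting the ordered single-stream case with the multi-file lm1b case, assuming the iterators take (data, bsz, bptt, ...) positionally and using placeholder paths:

# Hedged usage sketch -- paths and batch settings are illustrative only.
ptb_corpus = Corpus('/data/ptb', 'ptb')
ptb_iter = ptb_corpus.get_iterator('train', 60, 70)      # LMOrderedIterator over one long LongTensor

lm1b_corpus = Corpus('/data/lm1b', 'lm1b')
lm1b_iter = lm1b_corpus.get_iterator('train', 60, 70)    # LMMultiFileIterator over news.en-* shards, shuffled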
Code example #30
class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called
        elif self.dataset in ['nesmdb', 'nesmdb_emb']:
            train_paths = glob.glob(os.path.join(path, 'train', '*.txt'))
            valid_paths = glob.glob(os.path.join(path, 'valid', '*.txt'))
            test_paths = glob.glob(os.path.join(path, 'test', '*.txt'))

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(os.path.join(
                path, 'train.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=True,
                                                add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=True,
                                               add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(
                path, 'valid.txt'),
                                                ordered=False,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, 'test.txt'),
                                               ordered=False,
                                               add_double_eos=True)
        elif self.dataset in ['nesmdb', 'nesmdb_emb']:
            self.train = train_paths
            self.valid = valid_paths
            self.test = test_paths

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset in ['lm1b', 'nesmdb']:
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args,
                                                **kwargs)
            elif self.dataset == 'nesmdb_emb':
                kwargs['shuffle'] = True
                data_iter = ConditionalLMMultiFileIterator(
                    self.train, self.vocab, *args, **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)
            elif self.dataset == 'nesmdb':
                kwargs['shuffle'] = False
                # I've decided to let these both always be true for evaluation
                kwargs['skip_short'] = True
                kwargs['trim_padding'] = True
                data_iter = LMMultiFileIterator(data, self.vocab, *args,
                                                **kwargs)
            elif self.dataset == 'nesmdb_emb':
                kwargs['shuffle'] = False
                kwargs['skip_short'] = True
                kwargs['trim_padding'] = True
                data_iter = ConditionalLMMultiFileIterator(
                    data, self.vocab, *args, **kwargs)

        return data_iter
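This last variant streams per-song token files. A minimal sketch of the directory layout implied by the glob patterns in __init__ and of an evaluation call, with a placeholder root:

# Hedged usage sketch -- the root path and batch settings are illustrative only.
# Expected layout:  /data/nesmdb/{train,valid,test}/*.txt
corpus = Corpus('/data/nesmdb', 'nesmdb')
test_iter = corpus.get_iterator('test', bsz=8, bptt=512)  # shuffle off; skip_short and trim_padding forced on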