示例#1
0
def stratified_sampler(train, test, target, text_field, label_field):
    """Pool ``train`` and ``test`` and re-split them 70/30, stratified by label.

    Args:
        train: Iterable of examples exposing a ``text`` attribute and the
            attribute named by ``target[0]``.
        test: Iterable of examples with the same attributes as ``train``.
        target: Sequence whose first item is the name of the label attribute.
        text_field: Field used for the ``text`` column of the new datasets.
        label_field: Field used for the label column of the new datasets.

    Returns:
        A ``(train, test)`` pair of freshly built ``Dataset`` objects.
    """
    label_name = target[0]
    fields = [('text', text_field), (label_name, label_field)]

    pooled = [*train, *test]
    texts = [getattr(ex, "text") for ex in pooled]
    labels = [getattr(ex, label_name) for ex in pooled]

    splitter = StratifiedShuffleSplit(n_splits=1,
                                      train_size=0.7,
                                      test_size=0.30)
    # n_splits=1, so the generator yields exactly one (train, test) index pair.
    train_idx, test_idx = next(iter(splitter.split(texts, labels)))

    def _subset(indices):
        # Rebuild Example objects so they are bound to the new fields.
        return Dataset(
            examples=[Example.fromlist([texts[i], labels[i]], fields)
                      for i in indices],
            fields=fields)

    return _subset(train_idx), _subset(test_idx)
示例#2
0
def load_data(data_dir, emb_file):
    """Load a labelled text corpus laid out as one sub-directory per label.

    Each sub-directory of ``data_dir`` is a label; inside it only files whose
    name starts with that label are read. The first two lines of every file
    are dropped and the remainder becomes the example text.

    Returns:
        ``(train_data, test_data, text_field, label_field)``. The data is split
        70/30, the text vocabulary is built from the training split only and
        gets the pretrained vectors from ``emb_file`` attached.
    """
    text_field = Field(sequential=True, tokenize=tokenize)
    label_field = Field(sequential=False, unk_token=None)
    fields = [('text', text_field), ('label', label_field)]

    examples = []
    for label_dir in os.scandir(data_dir):
        # Plain files at the top level are skipped; only directories carry labels.
        if label_dir.is_file():
            continue
        label = label_dir.name
        for doc in os.scandir(label_dir.path):
            if not doc.name.startswith(label):
                continue
            with open(doc.path) as handle:
                body = '\n'.join(handle.read().splitlines()[2:])
                examples.append(Example.fromlist([body, label], fields))

    data = Dataset(examples, fields)
    train_data, test_data = data.split(0.7)

    text_field.build_vocab(train_data)
    # Labels use the full dataset, presumably so a class missing from the
    # training split still gets an index.
    label_field.build_vocab(data)

    text_field.vocab.load_vectors(Vectors(emb_file))

    return train_data, test_data, text_field, label_field
def main():
    """Compute corpus BLEU-4 of a predictions file against a reference file.

    Both files are read line by line and tokenized with the source field
    stored in the pickled vocabulary (``-data_pkl``). The score is printed and
    written to ``bleu_score.txt``.
    """
    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-data_pkl',
                        required=True,
                        help='Pickle file with vocabulary.')
    parser.add_argument('-trg_data', default='PSLG-PC12/ENG-ASL_Test.en')
    parser.add_argument('-pred_data',
                        default='predictions.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    opt = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(opt.data_pkl, 'rb') as f:
        data = pickle.load(f)
    SRC = data['vocab']['src']

    fields = [('src', SRC)]

    def _tokenize_lines(path):
        # One example per line, preprocessed by the source field.
        with open(path, 'r') as f:
            dataset = Dataset(
                examples=[Example.fromlist([x], fields) for x in f],
                fields={'src': SRC})
        return [ex.src for ex in dataset]

    trg_txt = _tokenize_lines(opt.trg_data)
    pred_txt = _tokenize_lines(opt.pred_data)

    # bleu_score(candidate_corpus, references_corpus): the predictions are the
    # candidates, and each candidate gets a list holding its single reference.
    score = bleu_score(pred_txt, [[ref] for ref in trg_txt])
    print('Bleu 4 score is {}'.format(str(score)))

    with open('bleu_score.txt', 'w') as f:
        f.write('Bleu 4 score is {}'.format(str(score)))
示例#4
0
    def __init__(
            self,
            question_path,
            paragraph_path,
            ratio,
            batch_size,
            vocab: Vocab = Ref("model.vocab"),
            batch_first=Ref("model.batch_first", True),
    ):
        """Build an IR dataset pairing each question with ``ratio`` paragraphs.

        Questions are read one per line from ``question_path``. For each
        question the next ``ratio`` lines of ``paragraph_path`` become its
        paragraphs, and every example gets target ``0``. If the paragraph file
        runs out, ``readline`` yields ``''`` and those paragraphs are empty.
        """
        self.vocab = vocab
        pad = vocab.pad_token

        # Every text field shares the externally supplied vocabulary.
        question = Field(include_lengths=True,
                         batch_first=batch_first,
                         pad_token=pad)
        question.vocab = vocab
        paragraph = Field(batch_first=batch_first, pad_token=pad)
        paragraph.vocab = vocab
        paragraphs = NestedField(paragraph, include_lengths=True)
        paragraphs.vocab = vocab
        target = Field(sequential=False, use_vocab=False, is_target=True)

        fields = [("question", question),
                  ("paragraphs", paragraphs),
                  ("target", target)]

        examples = []
        with open(paragraph_path) as p_file, open(question_path) as q_file:
            for raw_question in q_file:
                text = raw_question.strip()
                group = [p_file.readline().strip() for _ in range(ratio)]
                examples.append(Example.fromlist([text, group, 0], fields))

        BaseIRDataset.__init__(self, ratio, batch_size, batch_first)
        TorchTextDataset.__init__(self, examples, fields)
def prepare_dataloaders(opt, device):
    """Load pickled train/valid examples and wrap them in bucket iterators.

    Side effect: sets ``opt.max_token_seq_len``, ``opt.src_pad_idx``,
    ``opt.trg_pad_idx``, ``opt.src_vocab_size`` and ``opt.trg_vocab_size``
    from the pickle at ``opt.data_pkl``.

    Returns:
        ``(train_iterator, val_iterator)``.

    Raises:
        AssertionError: if ``opt.embs_share_weight`` is set but the source
            and target word->index tables differ.
    """
    batch_size = opt.batch_size
    # NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
    with open(opt.data_pkl, 'rb') as f:
        data = pickle.load(f)

    src_field = data['vocab']['src']
    trg_field = data['vocab']['trg']

    opt.max_token_seq_len = data['settings'].max_len
    opt.src_pad_idx = src_field.vocab.stoi[Constants.PAD_WORD]
    opt.trg_pad_idx = trg_field.vocab.stoi[Constants.PAD_WORD]

    opt.src_vocab_size = len(src_field.vocab)
    opt.trg_vocab_size = len(trg_field.vocab)

    # Shared embeddings require identical word->index maps on both sides.
    if opt.embs_share_weight:
        assert src_field.vocab.stoi == trg_field.vocab.stoi, \
            'To sharing word embedding the src/trg word2idx table shall be the same.'

    fields = {'src': src_field, 'trg': trg_field}

    train = Dataset(examples=data['train'], fields=fields)
    val = Dataset(examples=data['valid'], fields=fields)

    train_iterator = BucketIterator(train,
                                    batch_size=batch_size,
                                    device=device,
                                    train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)

    return train_iterator, val_iterator
示例#6
0
def build_dataset(in_path, in_field, out_path=None, out_field=None):
    """Build a ``Dataset`` from a source file and, optionally, a target file.

    Without ``out_path`` the dataset has a single ``src`` field; with it, a
    ``trg`` field (``out_field``) is added.

    NOTE(review): ``examples`` here is ``[in_]`` / ``[in_, out_]``, i.e. the
    loaded objects themselves rather than one ``Example`` per line. That is
    only correct if ``load_in_text`` / ``load_out_text`` return example-like
    objects; confirm against those helpers.
    """
    in_ = load_in_text(in_path)
    if out_path is None:
        return Dataset(examples=[in_], fields=[('src', in_field)])

    out_ = load_out_text(out_path)
    return Dataset(examples=[in_, out_],
                   fields=[('src', in_field), ('trg', out_field)])
示例#7
0
    def __init__(self,
                 path,
                 batch_size,
                 vocab: Vocab = Ref("model.vocab"),
                 batch_first=Ref("model.batch_first", True)):
        """Open an HDF5 IR dataset of question/paragraph examples.

        The HDF5 file at ``path`` stays open in ``self.data``. Examples come
        from ``self.ExampleWrapper`` over ``self.data["examples"]``, and the
        question/paragraph ratio is read from that dataset's ``ratio``
        attribute.
        """
        self.vocab = vocab
        pad = vocab.pad_index

        # use_vocab=False with an integer pad id: the stored examples are
        # presumably already numericalized — confirm against ExampleWrapper.
        question = Field(include_lengths=True,
                         use_vocab=False,
                         pad_token=pad,
                         batch_first=batch_first)
        paragraphs = NestedField(Field(batch_first=batch_first,
                                       pad_token=pad,
                                       use_vocab=False),
                                 include_lengths=True)
        target = Field(sequential=False, use_vocab=False, is_target=True)
        fields = [("question", question),
                  ("paragraphs", paragraphs),
                  ("target", target)]

        import h5py
        self.data = h5py.File(path, "r")
        examples_ds = self.data["examples"]
        ratio = examples_ds.attrs["ratio"]

        TorchTextDataset.__init__(self,
                                  self.ExampleWrapper(examples_ds, ratio, fields),
                                  fields)
        BaseIRDataset.__init__(self, ratio, batch_size, batch_first)
示例#8
0
    def init_dataloaders(self):
        """Read one sentence per line and build train/val/test bucket iterators.

        The file ``config['data']`` (relative to the firelab project path) is
        split roughly 99.9% / 0.09% / 0.01% into train / val / test, and the
        vocabulary is built from the training split only.
        """
        cfg = self.config
        batch_size = cfg.get('batch_size', 8)
        data_path = os.path.join(cfg['firelab']['project_path'], cfg['data'])

        with open(data_path) as f:
            lines = f.read().splitlines()

        text = Field(init_token='<bos>', eos_token='<eos>', batch_first=True)
        fields = [('text', text)]

        dataset = Dataset([Example.fromlist([s], fields) for s in lines], fields)
        # torchtext interprets a 3-element split_ratio as [train, test, val]
        # but returns the splits as (train, val, test). The 0.0001 split thus
        # comes second and the 0.0009 split third, hence the test/val order in
        # the unpacking below (val ends up with 0.0009, test with 0.0001).
        # NOTE(review): if torchtext omits empty splits, a small corpus would
        # make this unpacking fail — confirm the corpus is large enough.
        self.train_ds, self.test_ds, self.val_ds = dataset.split(
            split_ratio=[0.999, 0.0009, 0.0001])
        text.build_vocab(self.train_ds)
        self.vocab = text.vocab

        self.train_dataloader = data.BucketIterator(self.train_ds,
                                                    batch_size,
                                                    repeat=False)
        eval_kwargs = dict(train=False, sort=False)
        self.val_dataloader = data.BucketIterator(self.val_ds, batch_size,
                                                  **eval_kwargs)
        self.test_dataloader = data.BucketIterator(self.test_ds, batch_size,
                                                   **eval_kwargs)
示例#9
0
def prepare_dataloaders(pkl, bs, device):
    """Load pickled train/valid examples and wrap them in bucket iterators.

    Args:
        pkl: Path to a pickle holding ``'vocab'`` (with ``'src'``/``'trg'``
            fields), ``'train'`` and ``'valid'`` example lists.
        bs: Batch size for both iterators.
        device: Device handed to the iterators.

    Returns:
        ``(train_iterator, val_iterator, vocab)`` where ``vocab`` is the
        source field's vocabulary.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
    with open(pkl, 'rb') as f:
        data = pickle.load(f)

    src_field = data['vocab']['src']
    vocab = src_field.vocab
    fields = {'src': src_field, 'trg': data['vocab']['trg']}

    train = Dataset(examples=data['train'], fields=fields)
    val = Dataset(examples=data['valid'], fields=fields)

    train_iterator = BucketIterator(train,
                                    batch_size=bs,
                                    device=device,
                                    train=True,
                                    shuffle=True)
    val_iterator = BucketIterator(val, batch_size=bs, device=device)

    return train_iterator, val_iterator, vocab
    def get_fold(self, fields, train_indexs, test_indexs):
        """Return ``(train, test)`` datasets selected by position.

        :param fields: fields for both new datasets.
        :param train_indexs: positions in ``self.examples`` for the train split.
        :param test_indexs: positions in ``self.examples`` for the test split.
        :return: tuple of two ``Dataset`` objects.
        """
        pool = np.asarray(self.examples)
        train_part = pool[list(train_indexs)]
        test_part = pool[list(test_indexs)]
        return (Dataset(fields=fields, examples=train_part),
                Dataset(fields=fields, examples=test_part))
    def splits(self, fields, dev_ratio=.2, shuffle=True, **kwargs):
        """Split ``self.examples`` into train and dev datasets.

        :param fields: fields for both new datasets.
        :param dev_ratio: fraction of examples (rounded down) placed in dev.
        :param shuffle: shuffle before splitting.
        :param kwargs: accepted for interface compatibility; unused.
        :return: ``(train, dev)`` tuple of ``Dataset`` objects.
        """
        # Work on a copy so shuffling does not reorder self.examples in place.
        examples = list(self.examples)
        if shuffle:
            random.shuffle(examples)

        # Count the dev start from the front. A negative index breaks when the
        # dev size rounds to 0: -0 == 0 would put every example in dev.
        dev_start = len(examples) - int(dev_ratio * len(examples))
        return (Dataset(fields=fields, examples=examples[:dev_start]),
                Dataset(fields=fields, examples=examples[dev_start:]))
示例#12
0
 def insert_index(dataset: data.Dataset):
     """Tag every example with its position and register an ``index`` field.

     Each example gets an ``index`` attribute equal to its position, and
     ``dataset.fields`` gains a non-sequential, vocab-free ``index`` field.
     The dataset is modified in place and also returned.
     """
     examples, fields = dataset.examples, dataset.fields
     for position, example in enumerate(examples):
         setattr(example, 'index', position)
     fields['index'] = data.Field(sequential=False, use_vocab=False)
     # Write the (same) objects back onto the dataset.
     dataset.examples, dataset.fields = examples, fields
     return dataset
示例#13
0
文件: dataset.py 项目: zhengxxn/NMT
def load_dataset_from_example(examples: Example, data_fields, max_len=None):
    """Wrap ``examples`` in a ``Dataset``, optionally dropping long pairs.

    When ``max_len`` is given, only examples whose ``src`` and ``trg`` both
    have length <= ``max_len`` are kept.

    NOTE(review): ``examples`` is annotated as a single ``Example`` but is used
    as a collection of them.
    """
    if max_len is None:
        return Dataset(examples, data_fields)

    def _within_limit(ex):
        return len(ex.src) <= max_len and len(ex.trg) <= max_len

    return Dataset(examples, data_fields, filter_pred=_within_limit)
def _load_dataset(field, path):
    """Read ``path`` as UTF-8 and build a dataset with one example per line.

    Each line is stored in the ``sentence`` field. The dataset's ``sort_key``
    returns the example's ``sentence`` value itself.
    NOTE(review): that orders examples by their content, not by length — confirm
    this is intended (``len(s.sentence)`` is the usual choice).
    """
    fields = [("sentence", field)]
    with open(path, 'r', encoding='utf8') as f:
        examples = [Example.fromlist([line], fields) for line in f]

    dataset = Dataset(examples, fields=fields)
    dataset.sort_key = lambda s: s.sentence
    return dataset
示例#15
0
def build_datasets(examples: List[Example], src_field: Field, dest_field: Field,
                   logger: Logger) -> Tuple[Dataset, Dataset, Dataset]:
    """Split ``examples`` 90/5/5 into train/valid/test datasets.

    The global ``random`` module is reseeded with ``GlobalConfig.SEED`` before
    splitting, and the size of each resulting split is logged.
    """
    random.seed(GlobalConfig.SEED)
    logger.info('BUILD DATASETS')
    full = Dataset(examples, fields={'src': src_field, 'dest': dest_field})
    parts = full.split(split_ratio=[0.9, 0.05, 0.05])
    train_data, valid_data, test_data = parts
    for name, part in zip(('train', 'valid', 'test'), parts):
        logger.info(f'{name} set size: {len(part.examples):,}')
    return train_data, valid_data, test_data
示例#16
0
 def load_dataset(self, mode, fold):
     """Return the examples whose ``id`` appears in ``{mode}_{fold}.txt``.

     The id file lives in ``self.data_dir`` and holds one id per line. The
     returned dataset sorts by ``len(x.text)``.
     """
     id_path = os.path.join(self.data_dir, '{}_{}.txt'.format(mode, fold))
     with open(id_path) as f:
         wanted = {line.strip() for line in f}
     dset = Dataset(self.examples,
                    self.fields,
                    filter_pred=lambda x: x.id in wanted)
     dset.sort_key = lambda x: len(x.text)
     return dset
示例#17
0
    def prepare_data(self):
        """Load the pickled train/valid examples into datasets.

        Sets ``self.train_dataset`` and ``self.val_dataset`` from the pickle at
        ``self.data_cfg.data_path``.

        Raises:
            AssertionError: if source/target embedding sharing is enabled but
                the two word->index tables differ.
        """
        # NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
        with open(self.data_cfg.data_path, "rb") as f:
            data = pickle.load(f)
        src_field = data['vocab']['src']
        trg_field = data['vocab']['trg']

        if self.model_cfg.emb_src_trg_weight_sharing:
            assert src_field.vocab.stoi == trg_field.vocab.stoi, \
                'To sharing word embedding the src/trg word2idx table shall be the same.'

        fields = {'src': src_field, 'trg': trg_field}
        self.train_dataset = Dataset(examples=data['train'], fields=fields)
        self.val_dataset = Dataset(examples=data['valid'], fields=fields)
示例#18
0
    def get_fold(self, fields, train_indexs, test_indexs, shuffle=True):
        """Return ``(train, test)`` datasets for one cross-validation fold.

        :param fields: fields for both new datasets.
        :param train_indexs: positions in ``self.examples`` for the train split.
        :param test_indexs: positions in ``self.examples`` for the test split.
        :param shuffle: shuffle the order of examples within each split.
        :return: tuple of two ``Dataset`` objects.
        """
        examples = np.asarray(self.examples)
        # Materialize once: an iterator argument must not be consumed twice.
        train_idx = list(train_indexs)
        test_idx = list(test_indexs)
        print(train_idx)

        # Select by position first, then shuffle within each split. Shuffling
        # the whole pool before indexing would make every call draw an
        # unrelated random subset, so test folds would stop partitioning the data.
        train_part = examples[train_idx]
        test_part = examples[test_idx]
        if shuffle:
            random.shuffle(train_part)
            random.shuffle(test_part)
        return (Dataset(fields=fields, examples=train_part),
                Dataset(fields=fields, examples=test_part))
示例#19
0
def load_dataset(path, binary=True, vocab_path=None):
    """Load the train and test datasets stored under ``path``.

    Reads ``path + '/vocab'`` (a pickle whose content is passed to
    ``make_fields``), then builds datasets from ``path + '/train'`` and
    ``path + '/test'`` via ``example_generator``.

    NOTE(review): ``binary`` and ``vocab_path`` are accepted but unused here.

    Returns:
        ``(train_data, test_data)``.
    """
    print('Loading data from path {}...'.format(path))
    # NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
    with open(path + '/vocab', 'rb') as f:
        vocab_count = pickle.load(f)
    print('Constructing Fields...')
    fields_dict = make_fields(vocab_count)
    fields = convert_fields(fields_dict)
    print('Loading Examples...')
    train_data = Dataset(example_generator(path + '/train', fields_dict), fields)
    test_data = Dataset(example_generator(path + '/test', fields_dict), fields)
    return train_data, test_data
示例#20
0
文件: train_lstm.py 项目: marcwww/rlt
def build_iters(args):
    """Build SST train/valid/test bucket iterators and their fields.

    Text and label vocabularies are built from the training split only (text
    with ``args.pretrained`` vectors). All three iterators share the same
    settings: length-sorted batches whose effective size is controlled by
    ``batch_size_fn``.

    Returns:
        ``(train_iter, valid_iter, test_iter, (TXT, TREE, LBL))``.
    """
    TXT = data.Field(lower=args.lower, include_lengths=True, batch_first=True)
    LBL = data.Field(sequential=False, unk_token=None)
    TREE = data.Field(sequential=True, use_vocab=False, pad_token=0)
    fields = [('txt', TXT), ('tree', TREE), ('lbl', LBL)]

    tree_dir = 'data/sst/sst/trees/'
    examples_train, len_ave = load_examples(tree_dir + 'train.txt', subtrees=True)
    examples_valid, _ = load_examples(tree_dir + 'dev.txt', subtrees=False)
    examples_test, _ = load_examples(tree_dir + 'test.txt', subtrees=False)

    train = Dataset(examples_train, fields=fields)
    TXT.build_vocab(train, vectors=args.pretrained)
    LBL.build_vocab(train)
    valid = Dataset(examples_valid, fields=fields)
    test = Dataset(examples_test, fields=fields)

    def batch_size_fn(new_example, current_count, ebsz):
        # An example of average training length adds 1; longer ones add more.
        return ebsz + (len(new_example.txt) / len_ave)**0.3

    device = torch.device(args.gpu if args.gpu != -1 else 'cpu')

    def make_iter(dataset):
        return basic.BucketIterator(dataset,
                                    batch_size=args.batch_size,
                                    sort=True,
                                    shuffle=True,
                                    repeat=False,
                                    sort_key=lambda x: len(x.txt),
                                    batch_size_fn=batch_size_fn,
                                    device=device)

    return make_iter(train), make_iter(valid), make_iter(test), (TXT, TREE, LBL)
示例#21
0
    def __init__(self, corpus_path=None, split: Tuple[int, int] = None):
        """Load the fields from ``VOCAB_PATH`` and optionally a corpus.

        :param corpus_path: corpus loaded into ``self.examples``; when omitted
            neither ``self.examples`` nor ``self.datasets`` is set.
        :param split: relative sizes ``(first, second)``; the examples are cut
            at ``int(total / sum(split) * split[0])`` into two datasets.
            Without it, ``self.datasets`` holds one dataset of all examples.
        """
        self.fields, self.max_vocab_indexes = self.load_fields(self.VOCAB_PATH)

        if not corpus_path:
            return

        self.examples = self.load_corpus(corpus_path)
        if not split:
            self.datasets = [Dataset(self.examples, fields=self.fields)]
            return

        pivot = int(len(self.examples) / sum(split) * split[0])
        head, tail = self.examples[:pivot], self.examples[pivot:]
        self.datasets = [Dataset(head, fields=self.fields),
                         Dataset(tail, fields=self.fields)]
示例#22
0
 def _make_dataset(self, unlabeled, which=None) -> Dataset:
     """Build a dataset for one split.

     With ``unlabeled`` true, uses the unlabeled sentences and fields and
     ignores ``which``. Otherwise ``which`` selects ``"dev"`` or ``"test"``
     sentences; any other value falls back to the training sentences.
     """
     if unlabeled:
         examples = [self._make_example_unlabeled(s)
                     for s in self.unlabeled_sentences]
         return Dataset(examples, self.unlabeled_fields)

     if which == "dev":
         chosen = self.dev_sentences
     elif which == "test":
         chosen = self.test_sentences
     else:
         chosen = self.train_sentences
     return Dataset([self._make_example(s) for s in chosen],
                    self.labeled_fields)
示例#23
0
def processing_data(data_path, split_ratio=0.7):
    """
    Load author-labelled text files and build train/validation iterators.

    Each non-directory entry in ``data_path`` is read line by line; every line
    becomes one example labelled by the file name minus its last four
    characters (e.g. ``LX.txt`` -> ``LX``). Known labels: LX, MY, QZS, WXB, ZAL.

    :data_path: directory containing the per-author text files
    :split_ratio: fraction of the data used for training; the rest is validation
    :return: train_iter, val_iter, TEXT.vocab (training iterator, validation
        iterator and vocabulary)
    """
    # ------ Data loading is implemented; adapt this function as needed ------
    sentences = []  # text fragments
    target = []  # authors (label ids)
    # Configuration
    batch_size = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Mapping from label name to integer id
    labels = {'LX': 0, 'MY': 1, 'QZS': 2, 'WXB': 3, 'ZAL': 4}

    for file in os.listdir(data_path):
        file_path = os.path.join(data_path, file)
        # Check the entry inside data_path, not relative to the working directory.
        if os.path.isdir(file_path):
            continue
        with open(file_path, 'r', encoding='UTF-8') as f:
            for line in f:
                sentences.append(line)
                target.append(labels[file[:-4]])

    mydata = list(zip(sentences, target))
    TEXT = Field(sequential=True,
                 tokenize=lambda x: jb.lcut(x),
                 lower=True,
                 use_vocab=True)
    LABEL = Field(sequential=False, use_vocab=False)
    FIELDS = [('text', TEXT), ('category', LABEL)]
    examples = [Example.fromlist(list(x), fields=FIELDS) for x in mydata]
    dataset = Dataset(examples, fields=FIELDS)
    # NOTE(review): the vocabulary is built on the full dataset before the
    # train/validation split, so validation-only words enter it.
    TEXT.build_vocab(dataset)
    # Save the vocabulary
    with open('vocab.pkl', 'wb') as vocab:
        pickle.dump(TEXT.vocab, vocab)
    # Split into training and validation sets
    train, val = dataset.split(split_ratio=split_ratio)
    # BucketIterator groups examples of similar length into batches
    train_iter, val_iter = BucketIterator.splits(
        (train, val),  # datasets
        batch_sizes=(batch_size, batch_size),
        device=device,  # chosen above: CUDA if available, else CPU
        sort_key=lambda x: len(x.text),
        sort_within_batch=False,
        repeat=False)
    # --------------------------------------------------------------------------------------------
    return train_iter, val_iter, TEXT.vocab
示例#24
0
    def init_dataloaders(self):
        """Load two text domains and build train/val bucket iterators.

        Lines from ``config.data.domain_x`` and ``config.data.domain_y``
        (relative to the firelab project path) are kept only if their length
        is within ``[hp.min_len, hp.max_len]``. The survivors are paired
        positionally; ``zip`` drops extra lines from the longer domain. The
        pairs are split with ``train_test_split`` and tokenized with
        ``char_tokenize`` through one shared field whose vocabulary comes from
        the training split.
        """
        cfg = self.config
        project_path = cfg.firelab.project_path

        def _read_lines(rel_path):
            with open(os.path.join(project_path, rel_path)) as f:
                return f.read().splitlines()

        domain_x = _read_lines(cfg.data.domain_x)
        domain_y = _read_lines(cfg.data.domain_y)

        print('Dataset sizes:', len(domain_x), len(domain_y))
        min_len, max_len = cfg.hp.min_len, cfg.hp.max_len
        domain_x = [s for s in domain_x if min_len <= len(s) <= max_len]
        domain_y = [s for s in domain_y if min_len <= len(s) <= max_len]
        print('Dataset sizes after filtering:', len(domain_x), len(domain_y))

        field = Field(init_token='<bos>',
                      eos_token='|',
                      batch_first=True,
                      tokenize=char_tokenize)
        fields = [('domain_x', field), ('domain_y', field)]

        # Exactly one value per field.
        examples = [
            Example.fromlist([x, y], fields)
            for x, y in zip(domain_x, domain_y)
        ]
        train_exs, val_exs = train_test_split(
            examples,
            test_size=cfg.val_set_size,
            random_state=cfg.random_seed)

        train_ds, val_ds = Dataset(train_exs, fields), Dataset(val_exs, fields)
        field.build_vocab(train_ds, max_size=cfg.hp.get('max_vocab_size'))

        self.vocab = field.vocab
        batch_size = cfg.hp.batch_size
        self.train_dataloader = data.BucketIterator(train_ds,
                                                    batch_size,
                                                    repeat=False)
        self.val_dataloader = data.BucketIterator(val_ds,
                                                  batch_size,
                                                  repeat=False,
                                                  shuffle=False)
示例#25
0
    def splits(self, fields, dev_ratio=.1, shuffle=True, **kwargs):
        """Create train and dev dataset objects from the MR examples.

        Arguments:
            fields: The fields used for both new datasets.
            dev_ratio: Fraction of examples (rounded down) put in the dev split.
            shuffle: Whether to shuffle the data before splitting.
            kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(train, dev)`` tuple of ``Dataset`` objects.
        """
        # Work on a copy so shuffling does not reorder self.examples in place.
        examples = list(self.examples)
        if shuffle:
            random.shuffle(examples)

        # Count the dev start from the front. A negative index breaks when the
        # dev size rounds to 0: -0 == 0 would put every example in dev.
        dev_start = len(examples) - int(dev_ratio * len(examples))
        return (Dataset(fields=fields, examples=examples[:dev_start]),
                Dataset(fields=fields, examples=examples[dev_start:]))
示例#26
0
def build_iters(args):
    """Build bucket iterators for the expression-evaluation data.

    Vocabularies are built from the full training set. The training examples
    are then grouped by ``split_examples`` and one iterator is made per group.

    Returns:
        ``(train_iters, valid_iter, EXPR, VAL)`` where ``train_iters`` maps
        each key from ``split_examples`` to an iterator over that group.
    """
    EXPR = torchtext.data.Field(sequential=True,
                                use_vocab=True,
                                batch_first=True,
                                include_lengths=True,
                                pad_token=PAD,
                                eos_token=None)
    VAL = torchtext.data.Field(sequential=False)
    fields = [('expr', EXPR), ('val', VAL)]

    examples_train, len_ave = load_examples('data/train_d20s.tsv')
    # NOTE(review): the validation set is read from the test file.
    examples_valid, _ = load_examples('data/test_d20s.tsv')

    train = Dataset(examples_train, fields=fields)
    EXPR.build_vocab(train)
    VAL.build_vocab(train)
    valid = Dataset(examples_valid, fields=fields)

    device = torch.device(args.gpu if args.gpu != -1 else 'cpu')

    def batch_size_fn(new_example, current_count, ebsz):
        # An example of average training length adds 1; longer ones add more.
        return ebsz + (len(new_example.expr) / len_ave) ** 0.3

    def make_iter(dataset):
        return basic.BucketIterator(dataset,
                                    batch_size=args.bsz,
                                    sort=True,
                                    shuffle=True,
                                    repeat=False,
                                    sort_key=lambda x: len(x.expr),
                                    batch_size_fn=batch_size_fn,
                                    device=device)

    groups, _ = split_examples(examples_train)
    train_iters = {key: make_iter(Dataset(group, fields=fields))
                   for key, group in groups.items()}

    return train_iters, make_iter(valid), EXPR, VAL
示例#27
0
def load_data(c):
    """
    Load datasets, return a dictionary of datasets and fields.

    For each split name in ``c['splits']``, examples are read from the cached
    ``c['root'] + split + '.pkl'`` if it exists. Otherwise they are built by
    ``c['load'](src_path, trg_path, src_field, trg_field)`` from the ``.src`` /
    ``.trg`` files with the same prefix, and then written to that cache.

    Returns:
        ``(datasets, src_field, trg_field)`` where ``datasets`` maps each split
        name to a ``Dataset``.
    """

    # TODO: add field for context

    spacy_src = spacy.load(c['src_lang'])
    spacy_trg = spacy.load(c['trg_lang'])

    def tokenize_src(text):
        return [tok.text for tok in spacy_src.tokenizer(text)]

    def tokenize_trg(text):
        return [tok.text for tok in spacy_trg.tokenizer(text)]

    src_field = Field(tokenize=tokenize_src,
                      include_lengths=True,
                      eos_token=EOS,
                      lower=True)
    trg_field = Field(tokenize=tokenize_trg,
                      include_lengths=True,
                      eos_token=EOS,
                      lower=True,
                      init_token=SOS)
    fields = {'src': src_field, 'trg': trg_field}

    datasets = {}
    for split in c['splits']:
        prefix = c['root'] + split
        cache_path = prefix + '.pkl'
        if os.path.isfile(cache_path):
            print('Loading {0}'.format(cache_path))
            # NOTE: pickle.load executes arbitrary code; only load caches you wrote.
            with open(cache_path, 'rb') as f:
                examples = pickle.load(f)
            datasets[split] = Dataset(examples=examples, fields=fields)
        else:
            examples = c['load'](prefix + '.src', prefix + '.trg',
                                 src_field, trg_field)
            datasets[split] = Dataset(examples=examples, fields=fields)
            print('Saving to {0}'.format(cache_path))
            # Close (and flush) the cache file deterministically.
            with open(cache_path, 'wb') as f:
                pickle.dump(examples, f)

    return datasets, src_field, trg_field
示例#28
0
    def __init__(
            self,
            path,
            batch_size,
            src_vocab: Vocab = Ref("model.src_vocab"),
            trg_vocab: Vocab = Ref("model.trg_vocab"),
            level=Ref("model.level"),
            sort_within_batch=False,
            batch_by_words=True,
            batch_first=Ref("model.batch_first", True),
            multiple: int = Ref("exp_global.multiple", 1),
    ):
        """Open an HDF5 translation dataset.

        The HDF5 file at ``path`` stays open in ``self.data``, and
        ``src_len`` (and ``trg_len`` when present) are loaded into memory.
        A file with a ``trg_len`` entry is treated as parallel data with a
        ``trg`` field; otherwise it is monolingual (``src`` only).
        """
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        import h5py
        self.data = h5py.File(path, "r")
        self.src_lengths = self.data["src_len"][:]

        # use_vocab=False with integer special tokens: the stored examples
        # are presumably already numericalized — confirm against ExampleWrapper.
        src = Field(batch_first=batch_first,
                    include_lengths=True,
                    postprocessing=self.postprocess_src,
                    use_vocab=False,
                    pad_token=src_vocab.pad_index)
        fields = [("src", src)]

        has_target = "trg_len" in self.data
        if has_target:
            logger.info(f"Loading {path}")
            self.trg_lengths = self.data["trg_len"][:]
            fields.append(("trg", Field(batch_first=batch_first,
                                        include_lengths=True,
                                        init_token=trg_vocab.bos_index,
                                        eos_token=trg_vocab.eos_index,
                                        pad_token=trg_vocab.pad_index,
                                        is_target=True,
                                        postprocessing=self.postprocess_trg,
                                        use_vocab=False)))
        else:
            logger.info(f"Loading monolingual {path}")

        TorchTextDataset.__init__(
            self, self.ExampleWrapper(self.data["examples"], fields), fields)
        BaseTranslationDataset.__init__(self, batch_size, level, False,
                                        sort_within_batch, batch_by_words,
                                        batch_first, multiple, has_target)
示例#29
0
    def __call__(self, docs, progress=True, parallel=True):
        """Embed each document into a float32 array with one row per doc.

        Each doc is reduced to the space-joined lemmas of its non-stop-word
        tokens and turned into an example by ``Examplifier``, optionally in a
        process pool. Examples are then batched longest-``context``-first and
        embedded.
        Rows are written at ``batch.index``, presumably each doc's position
        (from ``enumerate``), so row ``i`` corresponds to the ``i``-th doc.

        ``progress`` is accepted but not used here.
        """
        texts = [' '.join(tok.lemma_ for tok in doc if not tok.is_stop)
                 for doc in docs]
        fields = [('index', RawField()),
                  ('context', SpacyBertField(self.tokenizer))]

        if parallel:
            with mp.Pool() as pool:
                examples = pool.map(Examplifier(fields),
                                    enumerate(tqdm(texts)))
        else:
            make_example = Examplifier(fields)
            examples = [make_example(pair) for pair in enumerate(tqdm(texts))]

        # Negative length as sort key: longest contexts come first.
        batches = BucketIterator(dataset=Dataset(examples, fields),
                                 batch_size=24,
                                 device=self.device,
                                 shuffle=False,
                                 sort=True,
                                 sort_key=lambda ex: -len(ex.context))

        embeds = np.zeros((len(texts), REDUCTION_DIMS), dtype=np.float32)
        for batch in tqdm(batches):
            with torch.no_grad():
                hidden = self.model.bert.embeddings(batch.context)
                embeds[batch.index] = reduce_embeds(batch.context, hidden).cpu()

        return embeds
    def __iter__(self):
        """Yield fixed-window batches over the single ``text`` example.

        The token list is padded to a multiple of ``batch_size`` and
        numericalized once. It is then reshaped into ``batch_size`` parallel
        streams and sliced into windows of at most ``bptt_len`` steps.
        ``target`` is the same window shifted forward by one step. The loop
        repeats forever unless ``self.repeat`` is false.
        """
        text = self.dataset[0].text
        TEXT = self.dataset.fields["text"]
        # NOTE(review): this mutates the shared field, so every later user of
        # TEXT also loses its eos token.
        TEXT.eos_token = None
        # Pad so the token count is an exact multiple of batch_size.
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device)
        # Flatten, split into (batch, seqlen) streams, then order dims as
        # (seqlen, batch). Presumably a named tensor (namedtensor API) — confirm.
        data = (data.stack(
            ("seqlen", "batch"),
            "flat").split("flat", ("batch", "seqlen"),
                          batch=self.batch_size).transpose("seqlen", "batch"))

        # Batches expose both "text" and "target", each processed by TEXT.
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[("text", TEXT), ("target", TEXT)])
        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                # The -1 leaves room for the one-step-shifted target window.
                seq_len = min(self.bptt_len, len(data) - i - 1)
                yield Batch.fromvars(
                    dataset,
                    self.batch_size,
                    text=data.narrow("seqlen", i, seq_len),
                    target=data.narrow("seqlen", i + 1, seq_len),
                )

            if not self.repeat:
                return