Example #1
import random
from pathlib import Path

import fastText
import lineflow as lf

# SWEM and sent_preprocess are project helpers; a sketch follows this example.


def get(ddir: str, ft_path: str, split: str):
    random.seed(1111)
    ddir = Path(ddir)

    # Load pretrained fastText vectors and wrap them in a SWEM encoder.
    ft_model = fastText.load_model(ft_path)
    swem = SWEM(ft_model)

    # Each split consists of a label file and two sentence files.
    quality = lf.TextDataset(str(ddir / f'quality.{split}.txt')).map(int)
    sent1 = lf.TextDataset(str(ddir / f'sent1.{split}.txt')).map(sent_preprocess(swem))
    sent2 = lf.TextDataset(str(ddir / f'sent2.{split}.txt')).map(sent_preprocess(swem))

    # Pair each quality label with its two sentence embeddings.
    ds = lf.zip(quality, sent1, sent2)
    return ds
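
SWEM and sent_preprocess are not shown on this page. Below is a minimal sketch of the contract they appear to satisfy, assuming SWEM does average pooling over fastText word vectors; the class body and method names are guesses, not the project's actual code:

import numpy as np


class SWEM:
    """Simple Word-Embedding-based Model: pools pretrained word vectors."""

    def __init__(self, ft_model):
        self._ft_model = ft_model

    def embed(self, tokens):
        # Average pooling is one common SWEM variant.
        vectors = [self._ft_model.get_word_vector(t) for t in tokens]
        return np.mean(vectors, axis=0)


def sent_preprocess(swem):
    # Returns a callable suitable for lineflow's .map(): raw line -> vector.
    def _embed_line(line):
        return swem.embed(line.strip().split())
    return _embed_line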
Example #2
from pathlib import Path

import fastText
import lineflow as lf
from torch.utils.data import DataLoader


def test_get(ddir: str, savedir: str, bsize: int, ft_path: str):
    ddir = Path(ddir)
    savedir = Path(savedir)

    ft_model = fastText.load_model(ft_path)
    swem = SWEM(ft_model)

    quality = lf.TextDataset(str(ddir / 'quality.test.txt')).map(int)
    sent1 = lf.TextDataset(str(ddir / 'sent1.test.txt')).map(sent_preprocess(swem))
    sent2 = lf.TextDataset(str(ddir / 'sent2.test.txt')).map(sent_preprocess(swem))

    ds = lf.zip(quality, sent1, sent2)

    # Cache the preprocessed dataset, then serve it unshuffled for evaluation.
    test_dataloader = DataLoader(
        ds.save(str(savedir / 'swem.test.cache')),
        batch_size=bsize,
        shuffle=False,
        num_workers=4,
        collate_fn=get_collate_fn(),
    )

    return test_dataloader
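
get_collate_fn is also defined elsewhere in the project. Since each item coming out of lf.zip above is a (quality, sent1, sent2) triple of a label and two fixed-size vectors, one plausible implementation (an assumption, not the project's code) simply stacks the triples into tensors. Note that Example #4 below uses a different get_collate_fn that takes a padding index:

import numpy as np
import torch


def get_collate_fn():
    def collate(batch):
        # batch: list of (quality, sent1_vector, sent2_vector) triples
        qualities, sents1, sents2 = zip(*batch)
        return (torch.tensor(qualities),
                torch.tensor(np.stack(sents1), dtype=torch.float32),
                torch.tensor(np.stack(sents2), dtype=torch.float32))
    return collate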
Example #3
import lineflow as lf
from pathlib import Path


def build(datapath='./data/example.txt', savedir='./'):
    datapath = Path(datapath)
    savedir = Path(savedir)

    # Pair every line of the corpus with a sequential document id.
    docs = lf.TextDataset(str(datapath))
    ids = lf.Dataset(range(len(docs)))
    docs = docs.map(preprocess)
    ds = lf.zip(ids, docs)

    # Stream every token lazily to build the vocabulary.
    tokens = lf.flat_map(lambda x: x[1], ds, lazy=True)
    t2i, _ = build_vocab(tokens, str(savedir / 'vocab.pkl'))

    unk_index = t2i[UNK_TOKEN]

    # Convert tokens to ids and cache the result.
    ds.map(postprocess(t2i, unk_index))\
        .save(str(savedir / 'dataset.token.pkl'))
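
preprocess, build_vocab and postprocess are project helpers again. The sketch below shows one way they could satisfy the contract used in build(); the tokenizer, frequency cutoff and UNK_TOKEN value are assumptions, though the pickled (t2i, words) pair does have to match the unpickling in run() below:

import pickle
from collections import Counter

UNK_TOKEN = '<unk>'


def preprocess(line):
    # Whitespace tokenization as a stand-in for the real tokenizer.
    return line.strip().split()


def build_vocab(tokens, cache_path, max_size=50000):
    counter = Counter(tokens)
    words = [UNK_TOKEN] + [w for w, _ in counter.most_common(max_size)]
    t2i = {w: i for i, w in enumerate(words)}
    with open(cache_path, 'wb') as f:
        pickle.dump((t2i, words), f)
    return t2i, words


def postprocess(t2i, unk_index):
    # (doc_id, tokens) -> (doc_id, token_ids)
    def _to_ids(item):
        doc_id, tokens = item
        return doc_id, [t2i.get(t, unk_index) for t in tokens]
    return _to_ids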
Example #4
import pickle
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader


def run(fpath: str,
        translator_path: str,
        src_vocab_path: str,
        tgt_vocab_path: str,
        bsize: int = 32,
        savedir: str = './test'):
    fpath = Path(fpath)
    fname = fpath.stem
    savedir = Path(savedir)

    dataset = lf.TextDataset(str(fpath)).map(preprocess)

    # Load the pickled (token->id map, words) vocab pairs and the trained model.
    src_t2i: Dict
    tgt_t2i: Dict
    with open(src_vocab_path, 'rb') as f:
        src_t2i, _ = pickle.load(f)
    with open(tgt_vocab_path, 'rb') as f:
        tgt_t2i, _ = pickle.load(f)
    translator: Translator = torch.load(translator_path)

    src_unk_idx = src_t2i[UNK_TOKEN]
    tgt_unk_idx = tgt_t2i[UNK_TOKEN]
    src_pad_idx = src_t2i[PAD_TOKEN]

    # Map tokens to ids, cache the result, and batch it for decoding.
    dataloader = DataLoader(
        dataset
        .map(postprocess(src_t2i, src_unk_idx, tgt_t2i, tgt_unk_idx))
        .save(str((savedir / fname).with_suffix('.pred.cache'))),
        batch_size=bsize,
        shuffle=False,
        num_workers=4,
        collate_fn=get_collate_fn(src_pad_idx),
    )

    # Invert the target vocab for id -> token lookup.
    tgt_i2t = {v: k for k, v in tgt_t2i.items()}

    # Decode each batch and convert predicted ids back to text.
    pred_sents = []
    for batch in dataloader:
        src, src_lens = batch
        pred_seqs = translator.translate(src, src_lens, max_target_len=30)
        for pids in pred_seqs:
            ptokens = [tgt_i2t.get(int(pid)) for pid in pids]
            ptokens = trim_special_tokens(ptokens)
            pred_sents.append(' '.join(ptokens))

    with open(savedir / 'pred.txt', 'w') as f:
        f.write('\n'.join(pred_sents))
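
trim_special_tokens is not shown either. A minimal guess at its behavior, with the special-token strings as assumptions: stop at the end-of-sequence marker (or at an id missing from the vocab) and drop any other special tokens:

def trim_special_tokens(tokens, eos='</s>', specials=('<s>', '<pad>')):
    trimmed = []
    for token in tokens:
        if token == eos or token is None:  # None: id missing from tgt_i2t
            break
        if token not in specials:
            trimmed.append(token)
    return trimmed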
Example #5
def test_dataloader(self):
    test = lf.TextDataset(self.config["data"]["test_file"]).map(self._preprocessor)
    return DataLoader(test, sampler=SequentialSampler(test),
                      batch_size=self.config["model"]["batch_size"], num_workers=32)
Example #6
def val_dataloader(self):
    # val = lf.Dataset(lf.TextDataset(self.config["data"]["val_file"]).take(100)).map(self._preprocessor).save(self.config["data"]["val_cache"])
    val = lf.TextDataset(self.config["data"]["val_file"]).map(self._preprocessor)
    return DataLoader(val, sampler=SequentialSampler(val),
                      batch_size=self.config["model"]["batch_size"], num_workers=32)
Example #7
def train_dataloader(self):
    train = lf.TextDataset(self.config["data"]["train_file"]).map(self._preprocessor)
    return DataLoader(train, sampler=RandomSampler(train),
                      batch_size=self.config["model"]["batch_size"], num_workers=32)
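
Examples #5 through #7 read like methods of a single PyTorch Lightning DataModule. Here is a minimal sketch of how they might be wired together; the class name and constructor are assumptions, and _preprocessor is left abstract since its definition is not shown:

import lineflow as lf
import pytorch_lightning as pl
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


class PlainTextDataModule(pl.LightningDataModule):
    # The three *_dataloader methods above would live in this class.

    def __init__(self, config, preprocessor):
        super().__init__()
        self.config = config               # nested dict: config["data"][...], config["model"][...]
        self._preprocessor = preprocessor  # callable: raw text line -> model inputs

With those methods attached, a pl.Trainer can consume the module directly, e.g. trainer.fit(model, datamodule=dm).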