Example #1
def translate(args, net, src_vocab, tgt_vocab):
    "done"
    sentences = [l.split() for l in args.text]
    translated = []

    infer_dataset = ParallelDataset(args.text, args.ref_text, src_vocab,
                                    tgt_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch

    infer_dataiter = iter(infer_dataset.get_iterator(True, True))

    for raw_batch in infer_dataiter:
        src_mask = (raw_batch.src != src_vocab.stoi[config.PAD]).unsqueeze(-2)
        if args.use_cuda:
            src, src_mask = raw_batch.src.cuda(), src_mask.cuda()
        else:
            src = raw_batch.src
        if args.greedy:
            generated, gen_len = greedy(args, net, src, src_mask, src_vocab,
                                        tgt_vocab)
        else:
            generated, gen_len = generate_beam(args, net, src, src_mask,
                                               src_vocab, tgt_vocab)
        new_translations = gen_batch2str(src, raw_batch.tgt, generated,
                                         gen_len, src_vocab, tgt_vocab)
        for res_sent in new_translations:
            print(res_sent)
        translated.extend(new_translations)

    return translated
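A minimal usage sketch for this version of translate(), assuming net, src_vocab, and tgt_vocab have already been built by the surrounding project (e.g. restored from a checkpoint); the file names are placeholders and only the args attributes the function actually reads are filled in:

import argparse

# net, src_vocab and tgt_vocab are assumed to exist already (loaded elsewhere).
args = argparse.Namespace(
    text=open("test.src", encoding="utf-8").read().splitlines(),
    ref_text=open("test.ref", encoding="utf-8").read().splitlines(),
    batch_size=32,           # forwarded to infer_dataset.BATCH_SIZE
    max_batch_size=None,
    tokens_per_batch=None,
    use_cuda=True,
    greedy=True,             # False would route through generate_beam instead
)

hypotheses = translate(args, net, src_vocab, tgt_vocab)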
Example #2
def translate(args, net, src_vocab, tgt_vocab, active_out=None):
    "done"
    sentences = [l.split() for l in args.text]
    translated = []

    infer_dataset = ParallelDataset(args.text, args.ref_text, src_vocab,
                                    tgt_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch

    infer_dataiter = iter(
        infer_dataset.get_iterator(shuffle=True,
                                   group_by_size=True,
                                   include_indices=True))

    for (raw_batch, indices) in infer_dataiter:
        src_mask = (raw_batch.src != src_vocab.stoi[config.PAD]).unsqueeze(-2)
        if args.use_cuda:
            src, src_mask = raw_batch.src.cuda(), src_mask.cuda()
        else:
            src = raw_batch.src
        generated, gen_len = greedy(args, net, src, src_mask, src_vocab,
                                    tgt_vocab)
        new_translations = gen_batch2str(src, raw_batch.tgt, generated,
                                         gen_len, src_vocab, tgt_vocab,
                                         indices, active_out)
        translated.extend(new_translations)

    return translated
Example #3
def translate(args, net, src_vocab, tgt_vocab):
    "done"
    sentences = [l.split() for l in args.text]
    translated = []

    if args.greedy:
        infer_dataset = ParallelDataset(args.text, args.ref_text, src_vocab,
                                        tgt_vocab)
        if args.batch_size is not None:
            infer_dataset.BATCH_SIZE = args.batch_size
        if args.max_batch_size is not None:
            infer_dataset.max_batch_size = args.max_batch_size
        if args.tokens_per_batch is not None:
            infer_dataset.tokens_per_batch = args.tokens_per_batch

        infer_dataiter = iter(infer_dataset.get_iterator(True, True))
        num_sents = 0
        for raw_batch in infer_dataiter:
            src_mask = (raw_batch.src !=
                        src_vocab.stoi[config.PAD]).unsqueeze(-2)
            if args.use_cuda:
                src, src_mask = raw_batch.src.cuda(), src_mask.cuda()
            else:
                src = raw_batch.src
            generated, gen_len = greedy(args, net, src, src_mask, src_vocab,
                                        tgt_vocab)
            new_translations = gen_batch2str(src, raw_batch.tgt, generated,
                                             gen_len, src_vocab, tgt_vocab)
            print('src size : {}'.format(src.size()))
            '''
            for res_sent in new_translations:
                print(res_sent)
            translated.extend(new_translations)
            '''
    else:
        for i_s, sentence in enumerate(sentences):
            s_trans = translate_sentence(sentence, net, args, src_vocab,
                                         tgt_vocab)
            s_trans = remove_special_tok(remove_bpe(s_trans))
            translated.append(s_trans)
            print(translated[-1])

    return translated
Example #4
def load_test_data(data_path, vocab_path, batch_size, use_cuda=False):
    # Note: use_vocab=False, since the inputs are already numericalized (preprocessed).
    src_field = Field(
        sequential=True,
        use_vocab=False,
        include_lengths=True,
        batch_first=True,
        pad_token=PAD,
        unk_token=UNK,
        init_token=None,
        eos_token=None,
    )
    fields = (src_field, None)
    device = None if use_cuda else -1

    vocab = torch.load(vocab_path)
    _, src_word2idx, _ = vocab['src_dict']
    lower_case = vocab['lower_case']

    test_src = convert_text2idx(read_corpus(data_path, None, lower_case),
                                src_word2idx)
    test_data = ParallelDataset(
        test_src,
        None,
        fields=fields,
    )
    test_iter = Iterator(
        dataset=test_data,
        batch_size=batch_size,
        train=False,  # Variable(volatile=True)
        repeat=False,
        device=device,
        shuffle=False,
        sort=False)

    return src_field, test_iter
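A short sketch of how load_test_data() might be consumed; the paths are placeholders, and since the source field is built with include_lengths=True, each batch's src attribute is assumed to be a (token_ids, lengths) pair:

# Placeholder paths pointing at the preprocessed test set and the saved vocabulary.
src_field, test_iter = load_test_data("data/test.pt", "data/vocab.pt",
                                      batch_size=64, use_cuda=False)

for batch in test_iter:
    src_ids, src_lengths = batch.src  # include_lengths=True yields (tensor, lengths)
    print(src_ids.size(), src_lengths)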
Example #5
from torch.utils.data import DataLoader
from dataset import ParallelDataset, BertTokenizer

print("Running unittests for XNLI dataset...")
batch_size = 12
seq_len = 128

tokenizer = BertTokenizer("data/bert-base-multilingual-uncased-vocab.txt")
dataset = ParallelDataset("data/xnli.15way.orig.tsv", tokenizer, seq_len)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print("passed initialization and dataloader tests")

for i, batch in enumerate(data_loader):
    assert type(batch) is dict
    assert len(batch.keys()) == 15
    assert batch['en'].shape == (batch_size, seq_len)
    if i > 10:
        break

print("passed batch sampling tests")

languages = ('vi', 'en')
tokenizer = BertTokenizer("data/bert-base-multilingual-uncased-vocab.txt")
dataset = ParallelDataset("data/xnli.15way.orig.tsv",
                          tokenizer,
                          seq_len,
                          languages=languages)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print("passed initialization of subset of columns")
Example #6
def load_para_data(params, data):
    """
    Load parallel data.
    """
    data['para'] = {}

    required_para_train = set(params.clm_steps + params.mlm_steps +
                              params.pc_steps + params.mt_steps)

    for src, tgt in params.para_dataset.keys():

        logger.info('============ Parallel data (%s-%s)' % (src, tgt))

        assert (src, tgt) not in data['para']
        data['para'][(src, tgt)] = {}

        for splt in ['train', 'valid', 'test']:

            # no need to load training data for evaluation
            if splt == 'train' and params.eval_only:
                continue

            # for back-translation, we can't load training data
            if splt == 'train' and (src, tgt) not in required_para_train and (
                    tgt, src) not in required_para_train:
                continue

            # load binarized datasets
            src_path, tgt_path = params.para_dataset[(src, tgt)][splt]
            src_data = load_binarized(src_path, params)
            tgt_data = load_binarized(tgt_path, params)

            # update dictionary parameters
            set_dico_parameters(params, data, src_data['dico'])
            set_dico_parameters(params, data, tgt_data['dico'])

            # create ParallelDataset
            dataset = ParallelDataset(src_data['sentences'],
                                      src_data['positions'],
                                      tgt_data['sentences'],
                                      tgt_data['positions'], params)

            # remove empty and too long sentences
            if splt == 'train':
                dataset.remove_empty_sentences()
                dataset.remove_long_sentences(params.max_len)

            # for validation and test set, enumerate sentence per sentence
            if splt != 'train':
                dataset.tokens_per_batch = -1

            # if there are several processes on the same machine, we can split the dataset
            if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data:
                n_sent = len(dataset) // params.n_gpu_per_node
                a = n_sent * params.local_rank
                b = n_sent * params.local_rank + n_sent
                dataset.select_data(a, b)

            data['para'][(src, tgt)][splt] = dataset
            logger.info("")

    logger.info("")
Example #7
def load_train_data(data_path,
                    batch_size,
                    max_src_len,
                    max_trg_len,
                    use_cuda=False):
    # Note: use_vocab=False, since the inputs are already numericalized (preprocessed).
    src_field = Field(
        sequential=True,
        use_vocab=False,
        include_lengths=True,
        batch_first=True,
        pad_token=PAD,
        unk_token=UNK,
        init_token=None,
        eos_token=None,
    )
    trg_field = Field(
        sequential=True,
        use_vocab=False,
        include_lengths=True,
        batch_first=True,
        pad_token=PAD,
        unk_token=UNK,
        init_token=BOS,
        eos_token=EOS,
    )
    fields = (src_field, trg_field)
    device = None if use_cuda else -1

    def filter_pred(example):
        # Keep only examples whose source and target lengths fit the limits.
        return (len(example.src) <= max_src_len
                and len(example.trg) <= max_trg_len)

    dataset = torch.load(data_path)
    train_src, train_tgt = dataset['train_src'], dataset['train_tgt']
    dev_src, dev_tgt = dataset['dev_src'], dataset['dev_tgt']

    train_data = ParallelDataset(
        train_src,
        train_tgt,
        fields=fields,
        filter_pred=filter_pred,
    )
    train_iter = Iterator(
        dataset=train_data,
        batch_size=batch_size,
        train=True,  # Variable(volatile=False)
        sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg)),
        repeat=False,
        shuffle=True,
        device=device)
    dev_data = ParallelDataset(
        dev_src,
        dev_tgt,
        fields=fields,
    )
    dev_iter = Iterator(
        dataset=dev_data,
        batch_size=batch_size,
        train=False,  # Variable(volatile=True)
        repeat=False,
        device=device,
        shuffle=False,
        sort=False,
    )

    return src_field, trg_field, train_iter, dev_iter
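A minimal training-loop sketch around load_train_data(); the data path and length limits are placeholders, and unpacking batch.src / batch.trg as (tensor, lengths) pairs is an assumption that follows from include_lengths=True in the field definitions above:

src_field, trg_field, train_iter, dev_iter = load_train_data(
    "data/train_dev.pt", batch_size=64, max_src_len=100, max_trg_len=100,
    use_cuda=False)

for epoch in range(10):
    for batch in train_iter:
        src_ids, src_lengths = batch.src
        trg_ids, trg_lengths = batch.trg
        # the forward/backward pass of the model would go here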
Example #8
    def _create_datasets(self) -> tuple:
        """
        Creates the following dataset from the data paths in the config file.

        - a train generator that generates batches of src and tgt data
        - a dev generator that generates batches of src dev data
        - tgt_dev that denotes the raw target dev data
        """
        # add task prefix and EOS token as required by model
        src_train = [
            self.model.config.prefix + text.strip() + " </s>"
            for text in list(open(self.data_config["src_train"]))
        ]
        src_dev = [
            self.model.config.prefix + text.strip() + " </s>"
            for text in list(open(self.data_config["src_dev"]))
        ]

        tgt_train = [
            text.strip() + " </s>"
            for text in list(open(self.data_config["tgt_train"]))
        ]
        tgt_dev = [
            text.strip() for text in list(open(self.data_config["tgt_dev"]))
        ]

        # tokenize src and target data
        src_train_dict = self.tokenizer.batch_encode_plus(
            src_train,
            max_length=self.train_config["max_output_length"],
            return_tensors="pt",
            pad_to_max_length=True,
        )
        src_dev_dict = self.tokenizer.batch_encode_plus(
            src_dev,
            max_length=self.train_config["max_output_length"],
            return_tensors="pt",
            pad_to_max_length=True,
        )
        tgt_train_dict = self.tokenizer.batch_encode_plus(
            tgt_train,
            max_length=self.train_config["max_output_length"],
            return_tensors="pt",
            pad_to_max_length=True,
        )

        # obtain input tensors
        input_ids = src_train_dict["input_ids"]
        input_dev_ids = src_dev_dict["input_ids"]
        output_ids = tgt_train_dict["input_ids"]

        # specify data loader params and create train generator
        params = {
            "batch_size": self.train_config["batch_size"],
            "shuffle": self.train_config["shuffle_data"],
            "num_workers": self.train_config["num_workers_data_gen"],
        }
        train_generator = DataLoader(ParallelDataset(input_ids, output_ids),
                                     **params)
        self.logger.info(
            f"Created training dataset of {len(input_ids)} parallel sentences")

        dev_params = dict(params)  # copy so the dev settings do not mutate the train loader's params
        dev_params["shuffle"] = False
        dev_generator = DataLoader(MonoDataset(input_dev_ids), **dev_params)

        all_data = (train_generator, dev_generator, tgt_dev)

        return all_data
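A hedged sketch of how the returned tuple might be consumed inside the same trainer class; that ParallelDataset yields (source_ids, target_ids) pairs and MonoDataset yields source ids only is an assumption, not something shown above, and _train is a hypothetical method name:

def _train(self):
    train_generator, dev_generator, tgt_dev = self._create_datasets()
    for source_ids, target_ids in train_generator:
        pass  # one training step on a batch of parallel sentences
    for source_ids in dev_generator:
        pass  # decode and compare against the raw references in tgt_dev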