import operator
import os
import random

import numpy as np
import torch

# NOTE: project-level helpers used below (get_iterator, convert_to_text,
# create_masked_batch, concat_batches, to_cuda, restore_segmentation,
# restore_segmentation_py, eval_moses_bleu, KEYS) are assumed to come from
# the surrounding XLM-style codebase and are not shown in these examples.


def create_reference_files(params, data):
        """
        Create reference files for BLEU evaluation.
        Writes one detokenized reference file per language direction and per
        split in data['para'], and records the paths in params.ref_paths.
        """
        params.ref_paths = {}

        for (lang1, lang2), v in data['para'].items():

            assert lang1 < lang2

            for data_set in ['valid', 'test']:

                # define data paths; the reference file for direction X-Y
                # holds the Y-side text
                lang1_path = os.path.join(
                    params.hyp_path,
                    'ref.{0}-{1}.{2}.txt'.format(lang2, lang1, data_set))
                lang2_path = os.path.join(
                    params.hyp_path,
                    'ref.{0}-{1}.{2}.txt'.format(lang1, lang2, data_set))

                # store data paths
                params.ref_paths[(lang2, lang1, data_set)] = lang1_path
                params.ref_paths[(lang1, lang2, data_set)] = lang2_path

                # text sentences
                lang1_txt = []
                lang2_txt = []

                # convert to text
                for (sent1, len1), (sent2, len2) in get_iterator(
                        params, data, data_set, lang1, lang2):
                    lang1_txt.extend(
                        convert_to_text(sent1, len1, data['dico'], params))
                    lang2_txt.extend(
                        convert_to_text(sent2, len2, data['dico'], params))

                # replace <unk> by <<unk>> as these tokens cannot be counted in BLEU
                lang1_txt = [x.replace('<unk>', '<<unk>>') for x in lang1_txt]
                lang2_txt = [x.replace('<unk>', '<<unk>>') for x in lang2_txt]

                # export references
                with open(lang1_path, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(lang1_txt) + '\n')
                with open(lang2_path, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(lang2_txt) + '\n')

                # restore original segmentation
                restore_segmentation(lang1_path)
                restore_segmentation(lang2_path)
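
# A self-contained toy sketch (not from the original source) of the
# reference-export step above: escape <unk> so the BLEU scorer cannot match
# it, then write one sentence per line. The file name and the sentences
# below are made up.
def _export_reference_sketch():
    import tempfile
    ref_path = os.path.join(tempfile.gettempdir(), 'ref.de-en.test.txt')
    sentences = ['ein <unk> beispiel', 'noch ein satz']
    escaped = [s.replace('<unk>', '<<unk>>') for s in sentences]
    with open(ref_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(escaped) + '\n')
    return ref_path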
def run(model,
        params,
        dico,
        data,
        split,
        src_lang,
        trg_lang,
        gen_type="src2trg",
        alpha=1.,
        beta=1.,
        gamma=0.,
        uniform=False,
        iter_mult=1,
        mask_schedule="constant",
        constant_k=1,
        batch_size=8,
        gpu_id=0):
    #n_batches = math.ceil(len(srcs) / batch_size)
    if gen_type == "src2trg":
        ref_path = params.ref_paths[(src_lang, trg_lang, split)]
    elif gen_type == "trg2src":
        ref_path = params.ref_paths[(trg_lang, src_lang, split)]

    with open(ref_path, encoding="utf-8") as f:
        refs = [s.strip() for s in f]
    hypothesis = []
    #hypothesis_selected_pos = []
    # NOTE: the iterator below is hardcoded to the (de, en) language pair
    for batch_n, batch in enumerate(
            get_iterator(params, data, split, "de", "en")):
        (src_x, src_lens), (trg_x, trg_lens) = batch

        batches, batches_src_lens, batches_trg_lens, total_scores = [], [], [], []
        #batches_selected_pos = []
        for i_topk_length in range(params.num_topk_lengths):

            # overwrite source/target lengths according to dataset stats if necessary
            if params.de2en_lengths is not None and params.en2de_lengths is not None:
                # indexing element 0 assumes batch_size == 1 for this overwrite
                src_lens_item = src_lens[0].item() - 2  # remove BOS, EOS
                trg_lens_item = trg_lens[0].item() - 2  # remove BOS, EOS
                if gen_type == "src2trg":
                    if len(params.de2en_lengths[src_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    data_trg_lens = sorted(
                        params.de2en_lengths[src_lens_item].items(),
                        key=operator.itemgetter(1))
                    data_trg_lens_item = data_trg_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite trg_lens
                    trg_lens = torch.ones_like(trg_lens) * data_trg_lens_item
                elif gen_type == "trg2src":
                    if len(params.en2de_lengths[trg_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    data_src_lens = sorted(
                        params.en2de_lengths[trg_lens_item].items(),
                        key=operator.itemgetter(1))
                    # take i_topk_length most likely length and add BOS, EOS
                    data_src_lens_item = data_src_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite src_lens
                    src_lens = torch.ones_like(src_lens) * data_src_lens_item

            if gen_type == "src2trg":
                sent1_input = src_x
                sent2_input = create_masked_batch(trg_lens, params, dico)
                dec_len = torch.max(trg_lens).item() - 2  # cut BOS, EOS
            elif gen_type == "trg2src":
                sent1_input = create_masked_batch(src_lens, params, dico)
                sent2_input = trg_x
                dec_len = torch.max(src_lens).item() - 2  # cut BOS, EOS

            # concatenate source and target segments into one sequence with
            # per-segment positions and language ids
            batch, lengths, positions, langs = concat_batches(
                sent1_input, src_lens, params.lang2id[src_lang],
                sent2_input, trg_lens, params.lang2id[trg_lang],
                params.pad_index, params.eos_index,
                reset_positions=True,
                assert_eos=True)  # original author's note: not sure about it

            if gpu_id >= 0:
                batch, lengths, positions, langs, src_lens, trg_lens = \
                    to_cuda(batch, lengths, positions, langs, src_lens, trg_lens)

            with torch.no_grad():
                batch, total_score_argmax_toks = \
                    _evaluate_batch(model, params, dico, batch,
                                    lengths, positions, langs, src_lens, trg_lens,
                                    gen_type, alpha, beta, gamma, uniform,
                                    dec_len, iter_mult, mask_schedule, constant_k)
            batches.append(batch.clone())
            batches_src_lens.append(src_lens.clone())
            batches_trg_lens.append(trg_lens.clone())
            total_scores.append(total_score_argmax_toks)
            #batches_selected_pos.append(selected_pos)

        # keep the length candidate whose decoding scored highest
        best_score_idx = np.array(total_scores).argmax()
        batch = batches[best_score_idx]
        src_lens = batches_src_lens[best_score_idx]
        trg_lens = batches_trg_lens[best_score_idx]
        #selected_pos = batches_selected_pos[best_score_idx]

        #if gen_type == "src2trg":
        #    hypothesis_selected_pos.append([selected_pos, trg_lens.item()-2])
        #elif gen_type == "trg2src":
        #    hypothesis_selected_pos.append([selected_pos, src_lens.item()-2])

        # iterate over the actual batch entries (the last batch may hold
        # fewer than batch_size sentences)
        for batch_idx in range(src_lens.size(0)):
            src_len = src_lens[batch_idx].item()
            tgt_len = trg_lens[batch_idx].item()
            if gen_type == "src2trg":
                generated = batch[src_len:src_len + tgt_len, batch_idx]
            else:
                generated = batch[:src_len, batch_idx]
            # if extra <eos> tokens were generated, keep only up to the second one
            eos_pos = (generated == params.eos_index).nonzero()
            if eos_pos.shape[0] > 2:
                generated = generated[:(eos_pos[1, 0].item() + 1)]
            hypothesis.extend(convert_to_text(
                generated.unsqueeze(1),
                torch.Tensor([generated.shape[0]]).int(),
                dico, params))

        print("Ex {0}\nRef: {1}\nHyp: {2}\n".format(
            batch_n, refs[batch_n].encode("utf-8"),
            hypothesis[-1].encode("utf-8")))

    hyp_path = os.path.join(params.hyp_path, 'decoding.txt')
    hyp_path_tok = os.path.join(params.hyp_path, 'decoding.tok.txt')
    #hyp_selected_pos_path = os.path.join(params.hyp_path, "selected_pos.pkl")

    # export sentences to the hypothesis files: hyp_path is detokenized below,
    # while hyp_path_tok keeps the BPE-tokenized output
    with open(hyp_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    with open(hyp_path_tok, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    #with open(hyp_selected_pos_path, 'wb') as f:
    #    pkl.dump(hypothesis_selected_pos, f)
    restore_segmentation(hyp_path)

    # evaluate BLEU score
    bleu = eval_moses_bleu(ref_path, hyp_path)
    print("BLEU %s-%s; %s %s : %f" %
          (src_lang, trg_lang, hyp_path, ref_path, bleu))
    # write BLEU score result to file
    result_path = os.path.join(params.hyp_path, "result.txt")
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write("BLEU %s-%s; %s %s : %f\n" %
                (src_lang, trg_lang, hyp_path, ref_path, bleu))
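
# Toy illustration (fabricated numbers, not from the original source) of two
# steps inside run(): (1) picking the i-th most frequent target length for a
# given source length from a {src_len: {trg_len: count}} table and re-adding
# BOS/EOS; (2) truncating a generated sequence after a second <eos>.
def _length_and_eos_sketch():
    # hypothetical length statistics: source length 7 -> target length counts
    de2en_lengths = {7: {6: 40, 7: 55, 8: 12}}
    counts = sorted(de2en_lengths[7].items(), key=operator.itemgetter(1))
    top1 = counts[-1 - 0][0] + 2  # most frequent length 7, +2 for BOS/EOS -> 9
    top2 = counts[-1 - 1][0] + 2  # second most frequent length 6, +2 -> 8

    eos_index = 1  # hypothetical <eos> id
    generated = torch.tensor([1, 5, 6, 1, 9, 1])  # <eos> at positions 0, 3, 5
    eos_pos = (generated == eos_index).nonzero()
    if eos_pos.shape[0] > 2:
        generated = generated[:(eos_pos[1, 0].item() + 1)]  # keep through 2nd <eos>
    return top1, top2, generated  # (9, 8, tensor([1, 5, 6, 1]))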
    # Method of a trainer/evaluator wrapper class; the enclosing class is
    # not shown in this example.
    def generate(self,
                 x,
                 lengths,
                 langs,
                 z,
                 z_prime=None,
                 log=True,
                 max_print=2):
        """
        Decode text from the latent z and from z_prime (computed with
        self.deb when not supplied), then log a few samples or return them.
        """
        input_sent = convert_to_text(x, lengths, self.evaluator.dico,
                                     self.pre_trainer.params)
        with torch.no_grad():
            if z_prime is None:
                z_prime = self.deb('fwd', x=z, lengths=lengths, causal=False)
                z_prime = z_prime.transpose(0, 1)
            """
            #lang1, lang2 = self.pre_trainer.params.mt_steps[0]
            lang1, lang2 = self.pre_trainer.params.langs
            #lang1_id = self.pre_trainer.params.lang2id[lang1]
            lang2_id = self.pre_trainer.params.lang2id[lang2]
            """
            lang2_id = 1  # hardcoded target-language id (see commented block above)
            max_len = int(1.5 * lengths.max().item() + 10)
            seq_len = lengths.max()
            if seq_len >= max_len:
                src_langs = langs[torch.arange(max_len)]
            else:
                #tgt_langs = torch.cat((langs, langs[torch.arange(max_len - seq_len)]), dim=0)
                src_langs = torch.cat(
                    (langs, langs[0].repeat(max_len - seq_len, 1)), dim=0)
            tgt_langs = 1 - src_langs  # the target lang id is the opposite of the source lang id

            # beam search is effectively disabled here: beam_size is forced
            # to 1, so the greedy branch below is always taken
            self.pre_trainer.params.beam_size = 1
            if self.pre_trainer.params.beam_size == 1:
                generated_1, lengths_1 = self.pre_trainer.decoder.generate(
                    z,
                    lengths,
                    lang2_id,
                    max_len=max_len,
                    sample_temperature=None,
                    langs=tgt_langs)
                generated_2, lengths_2 = self.pre_trainer.decoder.generate(
                    z_prime,
                    lengths,
                    lang2_id,
                    max_len=max_len,
                    sample_temperature=None,
                    langs=tgt_langs)
            else:
                pass
                """
                beam_size = self.pre_trainer.params.beam_size
                tgt_langs = tgt_langs.repeat(1, beam_size) # (max_len, bs * beam_size)
                generated_1, lengths_1 = self.pre_trainer.decoder.generate_beam(
                        z, lengths, lang2_id, beam_size = beam_size,
                        length_penalty = self.pre_trainer.params.length_penalty,
                        early_stopping = self.pre_trainer.params.early_stopping,
                        max_len = max_len, langs = tgt_langs)
                generated_2, lengths_2 = self.pre_trainer.decoder.generate_beam(
                        z_prime, lengths, lang2_id, beam_size = beam_size,
                        length_penalty = self.pre_trainer.params.length_penalty,
                        early_stopping = self.pre_trainer.params.early_stopping,
                        max_len = max_len, langs = tgt_langs)
                #"""
            gen_text = convert_to_text(generated_1, lengths_1,
                                       self.evaluator.dico,
                                       self.pre_trainer.params)
            deb_sent = convert_to_text(generated_2, lengths_2,
                                       self.evaluator.dico,
                                       self.pre_trainer.params)
        if log:
            i = random.randint(0, len(z) - 1)
            # cap at the number of sentences in the batch
            max_print = min(i + max_print, len(input_sent))
            self.logger.info("input : %s" %
                             restore_segmentation_py(input_sent[i:max_print]))
            self.logger.info("gen : %s" %
                             restore_segmentation_py(gen_text[i:max_print]))
            self.logger.info("deb : %s" %
                             restore_segmentation_py(deb_sent[i:max_print]))
        else:
            input_sent = restore_segmentation_py(input_sent)
            gen_text = restore_segmentation_py(gen_text)
            deb_sent = restore_segmentation_py(deb_sent)
            return {
                KEYS["input"]: input_sent,
                KEYS["gen"]: gen_text,
                KEYS["deb"]: deb_sent
            }
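
# Toy sketch (fabricated shapes, not from the original source) of the
# language-id handling in generate(): pad the source language-id tensor out
# to max_len, then flip 0 <-> 1 to obtain target language ids. Uses the
# (seq_len, batch) layout assumed by the method above.
def _lang_ids_sketch():
    lengths = torch.tensor([5, 3])
    langs = torch.zeros(int(lengths.max()), 2, dtype=torch.long)  # all lang id 0
    max_len = int(1.5 * lengths.max().item() + 10)  # 17
    seq_len = int(lengths.max())
    if seq_len >= max_len:
        src_langs = langs[torch.arange(max_len)]
    else:
        src_langs = torch.cat(
            (langs, langs[0].repeat(max_len - seq_len, 1)), dim=0)
    tgt_langs = 1 - src_langs  # every position now carries the other lang id
    return src_langs.shape, tgt_langs.unique()  # (torch.Size([17, 2]), tensor([1]))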