예제 #1
0
    def end_eval(self, text_z_prime, references, hypothesis, hypothesis2):
        """Dump the evaluation outputs to disk and log BLEU scores.

        ``text_z_prime`` is always handed to ``write_text_z_in_file`` together
        with the first unused ``output<N>.txt`` path under
        ``self.params.dump_path``; ``restore_segmentation`` is then applied to
        that file.

        When ``references`` is non-empty, the three sentence lists are written
        (one sentence per line) to ``ref<N>.txt``, ``hyp<N>.txt`` and
        ``hyp_deb<N>.txt``, ``restore_segmentation`` is applied to each, and
        three ``eval_moses_bleu`` scores are logged:

        * ref vs. hyp      (also stored in ``self.bleu``)
        * ref vs. hyp_deb
        * hyp vs. hyp_deb  (``hyp`` is passed in the reference position)

        Args:
            text_z_prime: payload forwarded unchanged to
                ``write_text_z_in_file``.
            references: sequence of reference sentences (strings); when it is
                empty or None no BLEU files are written and ``self.bleu`` is
                left untouched.
            hypothesis: sequence of sentences written to ``hyp<N>.txt``.
            hypothesis2: sequence of sentences written to ``hyp_deb<N>.txt``.
        """
        dump_dir = self.params.dump_path

        def numbered_path(stem, index):
            # <dump_dir>/<stem><index>.txt
            return os.path.join(dump_dir, stem + str(index) + '.txt')

        def first_free_index(probe_stem):
            # smallest index >= 1 for which <probe_stem><index>.txt is not
            # an existing file
            index = 1
            while os.path.isfile(numbered_path(probe_stem, index)):
                index += 1
            return index

        output_file = numbered_path("output", first_free_index("output"))
        write_text_z_in_file(output_file, text_z_prime)
        restore_segmentation(output_file)

        # compute BLEU only when references are available
        if not references:
            return

        # hypothesis / reference paths: the index is probed on hyp<N>.txt
        # only and then shared by the ref and hyp_deb files
        run_index = first_free_index("hyp")
        ref_path = numbered_path("ref", run_index)
        hyp_path = numbered_path("hyp", run_index)
        hyp_path2 = numbered_path("hyp_deb", run_index)

        # export sentences to reference and hypothesis files
        for path, sentences in ((ref_path, references),
                                (hyp_path, hypothesis),
                                (hyp_path2, hypothesis2)):
            with open(path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(sentences) + '\n')

        restore_segmentation(ref_path)
        restore_segmentation(hyp_path)
        restore_segmentation(hyp_path2)

        # evaluate BLEU scores
        self.bleu = eval_moses_bleu(ref_path, hyp_path)
        self.logger.info("BLEU input-gen : %f (%s, %s)" %
                         (self.bleu, hyp_path, ref_path))
        bleu = eval_moses_bleu(ref_path, hyp_path2)
        self.logger.info("BLEU input-deb : %f (%s, %s)" %
                         (bleu, hyp_path2, ref_path))
        bleu = eval_moses_bleu(hyp_path, hyp_path2)
        self.logger.info("BLEU gen-deb : %f (%s, %s)" %
                         (bleu, hyp_path, hyp_path2))
예제 #2
0
def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``'module.'`` removed from its keys.

    The prefix is stripped only when *every* key carries it; otherwise the
    mapping is returned unchanged.
    """
    prefix = 'module.'
    if all(k.startswith(prefix) for k in state_dict):
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def main(params):
    """Translate the sentences read from stdin and export the hypotheses.

    Steps, in order:

    1. Initialise the experiment logger, then parse the command line.
    2. Reload the checkpoint at ``params.model_path`` and build the
       dictionary, encoder and decoder from it (on CUDA, in eval mode).
    3. Read whitespace-tokenised sentences from stdin (one per line; empty
       lines trip an assertion) and translate them in batches of
       ``params.batch_size``, either with ``decoder.generate`` (beam size 1,
       optionally sampling) or ``decoder.generate_beam``.
    4. Write one hypothesis file per beam rank next to
       ``params.output_path``, apply ``restore_segmentation`` to it and, if
       ``params.ref_path`` is set, log the ``eval_moses_bleu`` score.

    Args:
        params: experiment parameters handed to ``initialize_exp``.
    """
    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): `params` is rebound here, so the argument passed in is
    # only seen by initialize_exp() above -- confirm this is intended.
    parser = get_parser()
    params = parser.parse_args()
    # Set random seed. NB: Multi-GPU also needs torch.cuda.manual_seed_all(params.seed)
    torch.manual_seed(params.seed)
    assert (params.sample_temperature
            == 0) or (params.beam_size == 1), 'Cannot sample with beam search.'
    assert params.amp <= 1, f'params.amp == {params.amp} not yet supported.'
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    encoder.load_state_dict(_strip_module_prefix(reloaded['encoder']))
    decoder.load_state_dict(_strip_module_prefix(reloaded['decoder']))

    if params.amp != 0:
        models = apex.amp.initialize([encoder, decoder],
                                     opt_level=('O%i' % params.amp))
        encoder, decoder = models

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    # hypothesis[r] collects, for every source sentence, the hypothesis of
    # beam rank r (a single list when beam_size == 1)
    hypothesis = [[] for _ in range(params.beam_size)]
    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch: each column is <eos> w_1 ... w_n <eos>, padded
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it; this is inference only, so
        # autograd bookkeeping is disabled
        with torch.no_grad():
            encoded = encoder('fwd',
                              x=batch.cuda(),
                              lengths=lengths.cuda(),
                              langs=langs.cuda(),
                              causal=False)
            encoded = encoded.transpose(0, 1)
            max_len = int(1.5 * lengths.max().item() + 10)
            if params.beam_size == 1:
                decoded, dec_lengths = decoder.generate(
                    encoded,
                    lengths.cuda(),
                    params.tgt_id,
                    max_len=max_len,
                    sample_temperature=(None
                                        if params.sample_temperature == 0 else
                                        params.sample_temperature))
            else:
                decoded, dec_lengths, all_hyp_strs = decoder.generate_beam(
                    encoded,
                    lengths.cuda(),
                    params.tgt_id,
                    beam_size=params.beam_size,
                    length_penalty=params.length_penalty,
                    early_stopping=params.early_stopping,
                    max_len=max_len,
                    output_all_hyps=True)

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters: drop the leading <eos> and everything from
            # the second <eos> on (if there is one)
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip().replace('<unk>', '<<unk>>')
            target = " ".join(dico[idx] for idx in sent.tolist()).replace(
                '<unk>', '<<unk>>')
            if params.beam_size == 1:
                hypothesis[0].append(target)
            else:
                sent_hyps = all_hyp_strs[j]
                for hyp_rank in range(params.beam_size):
                    # when fewer hypotheses than beam_size were returned,
                    # reuse the last one for the missing ranks
                    hyp = sent_hyps[hyp_rank
                                    if hyp_rank < len(sent_hyps) else -1]
                    print(hyp)
                    hypothesis[hyp_rank].append(hyp)

            sys.stderr.write("%i / %i: %s -> %s\n" %
                             (i + j, len(src_sent), source.replace(
                                 '@@ ', ''), target.replace('@@ ', '')))

    # export sentences to reference and hypothesis files / restore BPE segmentation
    # os.path.split also copes with an output_path that has no directory
    # part (the former rsplit('/', 1) raised ValueError in that case).
    save_dir, split = os.path.split(params.output_path)
    for hyp_rank, rank_hyps in enumerate(hypothesis):
        # single-hypothesis runs are tagged with the seed, beam runs with
        # the beam rank
        seed_tag = params.seed if len(hypothesis) == 1 else str(hyp_rank)
        hyp_name = (
            f'hyp.st={params.sample_temperature}.bs={params.beam_size}'
            f'.lp={params.length_penalty}.es={params.early_stopping}'
            f'.seed={seed_tag}.{params.src_lang}-{params.tgt_lang}'
            f'.{split}.txt')
        hyp_path = os.path.join(save_dir, hyp_name)
        with open(hyp_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(rank_hyps) + '\n')
        restore_segmentation(hyp_path)

        # evaluate BLEU score
        if params.ref_path:
            bleu = eval_moses_bleu(params.ref_path, hyp_path)
            logger.info("BLEU %s %s : %f" % (hyp_path, params.ref_path, bleu))
def run(model,
        params,
        dico,
        data,
        split,
        src_lang,
        trg_lang,
        gen_type="src2trg",
        alpha=1.,
        beta=1.,
        gamma=0.,
        uniform=False,
        iter_mult=1,
        mask_schedule="constant",
        constant_k=1,
        batch_size=8,
        gpu_id=0):
    """Generate one side of a sentence pair for every batch and score it.

    For each batch, up to ``params.num_topk_lengths`` candidate lengths are
    tried; every candidate is decoded with ``_evaluate_batch`` and the one
    with the highest returned score is kept. The kept hypotheses are written
    to ``decoding.txt`` and ``decoding.tok.txt`` under ``params.hyp_path``
    (``restore_segmentation`` is applied to the former only), and the
    ``eval_moses_bleu`` score against the reference file is printed and
    written to ``result.txt``.

    Args:
        model: forwarded to ``_evaluate_batch``.
        params: experiment parameters (``ref_paths``, ``hyp_path``,
            ``num_topk_lengths``, ``de2en_lengths``, ``en2de_lengths``,
            ``lang2id``, ``pad_index``, ``eos_index`` are read here).
        dico: dictionary used for masking and for converting ids to text.
        data: dataset handed to ``get_iterator``.
        split: data split name; also part of the ``ref_paths`` key.
        src_lang: source language name.
        trg_lang: target language name.
        gen_type: ``"src2trg"`` to generate the target side or
            ``"trg2src"`` to generate the source side.
        alpha, beta, gamma, uniform, iter_mult, mask_schedule, constant_k:
            forwarded unchanged to ``_evaluate_batch``.
        batch_size: number of columns of each decoded batch converted to
            text.
        gpu_id: tensors are moved with ``to_cuda`` when this is >= 0.

    Raises:
        ValueError: if ``gen_type`` is not one of the two supported values.
    """
    if gen_type == "src2trg":
        ref_path = params.ref_paths[(src_lang, trg_lang, split)]
    elif gen_type == "trg2src":
        ref_path = params.ref_paths[(trg_lang, src_lang, split)]
    else:
        # previously this fell through to an UnboundLocalError on ref_path
        raise ValueError(
            "gen_type must be 'src2trg' or 'trg2src', got %r" % (gen_type,))

    with open(ref_path, encoding="utf-8") as f:
        refs = [s.strip() for s in f]
    hypothesis = []
    # NOTE(review): the iterator is always built for ("de", "en"), whatever
    # src_lang / trg_lang are -- confirm this is intended.
    for batch_n, data_batch in enumerate(
            get_iterator(params, data, split, "de", "en")):
        (src_x, src_lens), (trg_x, trg_lens) = data_batch

        # one entry per candidate length that was actually decoded
        batches, batches_src_lens, batches_trg_lens, total_scores = [], [], [], []
        for i_topk_length in range(params.num_topk_lengths):

            # overwrite source/target lengths according to dataset stats if necessary
            if params.de2en_lengths is not None and params.en2de_lengths is not None:
                # NOTE(review): only element 0 of the length tensors is
                # consulted -- presumably every sentence in the batch has the
                # same length (or the batch holds one sentence); confirm.
                src_lens_item = src_lens[0].item() - 2  # remove BOS, EOS
                trg_lens_item = trg_lens[0].item() - 2  # remove BOS, EOS
                if gen_type == "src2trg":
                    # stop once there are no more candidate lengths
                    if len(params.de2en_lengths[src_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    # sorted ascending by the mapped value, so indexing from
                    # the end picks the entry with the (i+1)-th largest value
                    data_trg_lens = sorted(
                        params.de2en_lengths[src_lens_item].items(),
                        key=operator.itemgetter(1))
                    # take that length and add BOS, EOS
                    data_trg_lens_item = data_trg_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite trg_lens
                    trg_lens = torch.ones_like(trg_lens) * data_trg_lens_item
                elif gen_type == "trg2src":
                    if len(params.en2de_lengths[trg_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    data_src_lens = sorted(
                        params.en2de_lengths[trg_lens_item].items(),
                        key=operator.itemgetter(1))
                    # take i_topk_length most likely length and add BOS, EOS
                    data_src_lens_item = data_src_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite src_lens
                    src_lens = torch.ones_like(src_lens) * data_src_lens_item

            # the side being generated starts out as a masked batch
            if gen_type == "src2trg":
                sent1_input = src_x
                sent2_input = create_masked_batch(trg_lens, params, dico)
                dec_len = torch.max(trg_lens).item() - 2  # cut BOS, EOS
            elif gen_type == "trg2src":
                sent1_input = create_masked_batch(src_lens, params, dico)
                sent2_input = trg_x
                dec_len = torch.max(src_lens).item() - 2  # cut BOS, EOS

            batch, lengths, positions, langs = concat_batches(
                sent1_input, src_lens, params.lang2id[src_lang],
                sent2_input, trg_lens, params.lang2id[trg_lang],
                params.pad_index, params.eos_index,
                reset_positions=True,
                assert_eos=True)  # not sure about it

            if gpu_id >= 0:
                batch, lengths, positions, langs, src_lens, trg_lens = \
                    to_cuda(batch, lengths, positions, langs, src_lens, trg_lens)

            with torch.no_grad():
                batch, total_score_argmax_toks = \
                    _evaluate_batch(model, params, dico, batch,
                                    lengths, positions, langs, src_lens, trg_lens,
                                    gen_type, alpha, beta, gamma, uniform,
                                    dec_len, iter_mult, mask_schedule, constant_k)
            batches.append(batch.clone())
            batches_src_lens.append(src_lens.clone())
            batches_trg_lens.append(trg_lens.clone())
            total_scores.append(total_score_argmax_toks)

        # keep the candidate length with the highest score
        best_score_idx = np.array(total_scores).argmax()
        batch, src_lens, trg_lens = batches[best_score_idx], batches_src_lens[
            best_score_idx], batches_trg_lens[best_score_idx]

        # NOTE(review): iterates over the `batch_size` argument rather than
        # the actual number of columns in `batch` -- a shorter batch would
        # raise an IndexError here; confirm against get_iterator.
        for batch_idx in range(batch_size):
            src_len = src_lens[batch_idx].item()
            tgt_len = trg_lens[batch_idx].item()
            # rows [0, src_len) hold the source, the next tgt_len rows the
            # target
            if gen_type == "src2trg":
                generated = batch[src_len:src_len + tgt_len, batch_idx]
            else:
                generated = batch[:src_len, batch_idx]
            # extra <eos>: with more than two <eos> tokens, cut right after
            # the second one
            eos_pos = (generated == params.eos_index).nonzero()
            if eos_pos.shape[0] > 2:
                generated = generated[:(eos_pos[1, 0].item() + 1)]
            hypothesis.extend(
                convert_to_text(generated.unsqueeze(1),
                                torch.Tensor([generated.shape[0]]).int(),
                                dico, params))

        # NOTE(review): .encode() makes Python 3 print the bytes repr
        # (b'...'); left as-is to keep the output unchanged. refs is indexed
        # by batch number and only the last hypothesis is shown.
        print("Ex {0}\nRef: {1}\nHyp: {2}\n".format(
            batch_n, refs[batch_n].encode("utf-8"),
            hypothesis[-1].encode("utf-8")))

    hyp_path = os.path.join(params.hyp_path, 'decoding.txt')
    hyp_path_tok = os.path.join(params.hyp_path, 'decoding.tok.txt')

    # export sentences to hypothesis file / restore BPE segmentation
    # (the .tok copy is written identically but is not post-processed)
    with open(hyp_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    with open(hyp_path_tok, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    restore_segmentation(hyp_path)

    # evaluate BLEU score
    bleu = eval_moses_bleu(ref_path, hyp_path)
    print("BLEU %s-%s; %s %s : %f" %
          (src_lang, trg_lang, hyp_path, ref_path, bleu))
    # write BLEU score result to file
    result_path = os.path.join(params.hyp_path, "result.txt")
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write("BLEU %s-%s; %s %s : %f\n" %
                (src_lang, trg_lang, hyp_path, ref_path, bleu))