Example #1
def predict_bleu(weights,
                 model,
                 pattern,
                 seq_length,
                 device,
                 int_to_char,
                 fics,
                 character_level=False):
    """
    Generate text and compute BLEU
    """
    nb_classes = len(int_to_char)
    with torch.no_grad():
        generated_text = []
        for i in range(1000):
            x = np.reshape(pattern, (-1, seq_length, 1))
            x = torch.as_tensor(x, dtype=torch.int64).to(device=device)
            out = model(x).view(nb_classes)
            # sample the next character from the softmax distribution
            # (rather than a greedy argmax) to add diversity to the output
            probs = torch.nn.functional.softmax(out, dim=0)
            index = np.random.choice(np.arange(0, nb_classes),
                                     p=probs.to(device='cpu').numpy())
            result = int_to_char[index]
            generated_text.append(result)
            seq_in = [int_to_char[value.item()] for value in pattern]
            pattern.append(index)
            pattern = pattern[1:len(pattern)]
        generated_text = ''.join(generated_text)
        return generated_text, compute_bleu(weights, fics, generated_text,
                                            character_level)
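
The sampling step above draws the next character from the softmax distribution instead of taking the argmax, which adds diversity to the generated text. A minimal, self-contained sketch of just that step (the logits below are made up for illustration):

import numpy as np
import torch

logits = torch.randn(5)  # pretend model output over 5 character classes
probs = torch.nn.functional.softmax(logits, dim=0)
next_index = np.random.choice(np.arange(0, 5), p=probs.numpy())
print(next_index)  # an index drawn in proportion to its probability
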
Example #2
def compute_bleu_from_tensors(scores_raw, decode_lengths, sort_ind, allcaps,
                              vocab):
    """
    Re-kludged from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/train.py
    """
    references = []
    # References
    allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
    for j in range(allcaps.shape[0]):
        img_caps = allcaps[j].tolist()
        img_captions = list(
            map(
                lambda c: [
                    w for w in c
                    if w not in {vocab["start_idx"], vocab["pad_idx"]}
                ],
                img_caps,
            ))  # remove <start> and pads
        references.append(img_captions)

    # Hypotheses
    _, preds = torch.max(scores_raw, dim=2)
    preds = preds.tolist()
    hypotheses = []
    for j, p in enumerate(preds):
        hypotheses.append(preds[j][:decode_lengths[j]])  # remove pads

    # Actually compute bleu
    return compute_bleu(references, hypotheses)[0] * 100
Example #3
def _bleu(ref_file, trans_file, subword_option=None):
    """Compute BLEU scores, handling BPE."""
    max_order = 4
    smooth = False

    # The file-reading logic of the original utility (see Example #6) is
    # bypassed here: the references and translations are passed in directly
    # rather than read from files.
    # bleu_score, precisions, bp, ratio, translation_length, reference_length
    per_segment_references = [ref_file]
    translations = [trans_file]
    bleu_score, _, _, _, _, _ = bleu.compute_bleu(per_segment_references,
                                                  translations, max_order,
                                                  smooth)
    return 100 * bleu_score
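
Most of the examples on this page follow the same calling convention: each hypothesis is a list of tokens, and the references sit one nesting level deeper (a list of reference token lists per segment), with the six-element return tuple noted in the comment inside _bleu. A small hedged sketch, assuming the same bleu helper module that the function above uses:

import bleu  # assumption: the same local bleu helper imported by the examples above

per_segment_references = [[['the', 'cat', 'sat', 'on', 'the', 'mat']],
                          [['he', 'read', 'the', 'book']]]
translations = [['the', 'cat', 'sat', 'on', 'a', 'mat'],
                ['he', 'read', 'a', 'book']]
# bleu_score, precisions, bp, ratio, translation_length, reference_length
bleu_score, _, _, _, _, _ = bleu.compute_bleu(per_segment_references,
                                              translations, 4, False)
print(100 * bleu_score)
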
Example #4
 def eval(self, train_step):
     with self.eval_graph.as_default():
         self.eval_saver.restore(self.eval_session, self.model_file)
         bleu_score = 0
         target_results = []
         output_results = []
         for step in range(0, self.eval_reader.data_size):
             data = next(self.eval_data)
             in_seq = data['in_seq']
             in_seq_len = data['in_seq_len']
             target_seq = data['target_seq']
             target_seq_len = data['target_seq_len']
             outputs = self.eval_session.run(
                     self.eval_output,
                     feed_dict={
                         self.eval_in_seq: in_seq,
                         self.eval_in_seq_len: in_seq_len})
             for i in range(len(outputs)):
                 output = outputs[i]
                 target = target_seq[i]
                 output_text = reader.decode_text(output,
                         self.eval_reader.vocabs).split(' ')
                 target_text = reader.decode_text(target[1:],
                         self.eval_reader.vocabs).split(' ')
                 prob = int(self.eval_reader.data_size * self.batch_size / 10)
                 target_results.append([target_text])
                 output_results.append(output_text)
                 if random.randint(1, prob) == 1:
                     print('====================')
                     input_text = reader.decode_text(in_seq[i],
                             self.eval_reader.vocabs)
                     print('src:' + input_text)
                     print('output: ' + ' '.join(output_text))
                     print('target: ' + ' '.join(target_text))
         return bleu.compute_bleu(target_results, output_results)[0] * 100
Example #5
    def eval(self):
        target_results = []
        output_results = []
        self.sess.run(self.transfer)
        for _ in range(
                int(self.valid_data_flow.data_number / dataflow.batch_size)):
            batch = self.valid_data_flow.prepare_data()
            list_dict = [{
                self.train_in_seq: batch[0],
                self.train_in_seq_len: batch[1]
            }]
            target_seq = batch[2]
            in_seq = batch[0]
            feed = {}
            for d in list_dict:
                feed.update(d)
            outputs = self.sess.run(self.model.infer_output, feed_dict=feed)
            for i in range(len(outputs)):
                output = outputs[i]
                target = target_seq[i]
                output_text = decode_text(
                    output, self.valid_data_flow.vocabs).split(' ')
                target_text = decode_text(
                    target[1:], self.valid_data_flow.vocabs).split(' ')

                target_results.append([target_text])
                output_results.append(output_text)
                if _ % 100 == 0 and i == 0:
                    print('====================')
                    input_text = decode_text(in_seq[i],
                                             self.valid_data_flow.vocabs)
                    print('src:' + input_text)
                    print('output: ' + ' '.join(output_text))
                    print('target: ' + ' '.join(target_text))
        return bleu.compute_bleu(target_results, output_results)[0] * 100
Example #6
def compute_bleu(ref_file, trans_file):
    """Compute BLEU scores and handling BPE."""
    max_order = 4
    smooth = False

    ref_files = [ref_file]
    reference_text = []
    for reference_filename in ref_files:
        with codecs.getreader("utf-8")(tf.gfile.GFile(reference_filename,
                                                      "rb")) as fh:
            reference_text.append(fh.readlines())

    per_segment_references = []
    for references in zip(*reference_text):
        reference_list = []
        for reference in references:
            reference = reference.strip()
            reference_list.append(reference.split(" "))
        per_segment_references.append(reference_list)

    translations = []
    with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh:
        for line in fh:
            line = line.strip()
            translations.append(line.split(" "))

    # bleu_score, precisions, bp, ratio, translation_length, reference_length
    bleu_score, _, _, _, _, _ = bleu.compute_bleu(per_segment_references,
                                                  translations, max_order,
                                                  smooth)
    return 100 * bleu_score
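
A hedged usage sketch for the file-based variant above, assuming line-aligned, whitespace-tokenized reference and translation files (the file names are made up):

with open('refs.txt', 'w', encoding='utf-8') as f:
    f.write('the cat sat on the mat\nhe read the book\n')
with open('hyps.txt', 'w', encoding='utf-8') as f:
    f.write('the cat sat on a mat\nhe read a book\n')

print(compute_bleu('refs.txt', 'hyps.txt'))  # corpus BLEU on a 0-100 scale
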
Example #7
def evaluate(model, loss_function, data_iter, max_len=MAX_LEN, epsilon=0.0005):
    """Computes loss on validation data.

    Returns:
        The loss on the dataset, a list of losses for each batch.
    """

    loss = 0
    loss_log = []
    num_instances = 0

    captions = []
    references = []

    with torch.no_grad():
        # set the network to evaluation mode
        model.eval()

        for batch in data_iter:
            i, t, batch_size = batch

            # process images
            i = i.to(DEVICE)
            y, att_w = model(i, None, max_len=max_len)
            y = y.permute(1, 2, 0)

            # compute loss
            if not isinstance(t, list):
                t = [t]
            tl = 0
            for j in range(len(t)):
                t[j] = t[j].to(DEVICE)
                t[j] = t[j].squeeze(2).permute(1, 0)

                l, _ = loss_func(loss_function, y, t[j], att_w, epsilon)
                tl += l.item()
                t[j] = t[j].detach()
                num_instances += batch_size

            loss_log.append(tl / (batch_size * len(t)))
            loss += tl

            # decode
            y = y.permute(0, 2, 1)
            _, topi = y.topk(1, dim=2)
            topi = topi.detach().squeeze(2)

            for j in range(batch_size):
                captions.append(data_iter.vocab.tensor_to_sentence(topi[j]))
                references.append([])
                for k in t:
                    references[-1].append(
                        data_iter.vocab.tensor_to_sentence(k[j]))

    bleu = compute_bleu(reference_corpus=references,
                        translation_corpus=captions)[0]
    return (loss / num_instances), loss_log, bleu
Example #8
    def infer_and_eval(self, batches):
        # pick a random (batch, item) position whose sample will be kept for logging
        a = np.random.randint(0, len(batches))
        b = np.random.randint(0, self.batch_size)
        # inference
        reference_corpus = []
        generation_corpus = []

        recon_loss_l = []
        kl_loss_l = []
        bow_loss_l = []
        word_count = 0

        for index, batch in enumerate(batches):
            feed_dict = dict(zip(self.placeholders, batch))
            feed_dict[self.dropout] = 0.

            gen_digits, gen_len, recon_loss, kl_loss, bow_loss = self.sess.run(
                [
                    self.result, self.result_lengths, self.recon_loss,
                    self.kl_loss, self.bow_loss
                ],
                feed_dict=feed_dict)
            recon_loss_l.append(recon_loss)
            kl_loss_l.append(kl_loss)
            bow_loss_l.append(bow_loss)

            rep_m = batch[3]
            rep_len = batch[4]
            word_count += np.sum(rep_len)
            for i, leng in enumerate(rep_len):
                ref = rep_m[0:leng, i]
                reference_corpus.append([ref])
                if self.beam_width > 0:
                    out = gen_digits[:gen_len[i] - 1, i, 0]
                else:
                    out = gen_digits[:gen_len[i] - 1, i]
                generation_corpus.append(out)
                if index == a and i == b:
                    ori_m = batch[1]
                    ori_len = batch[2]
                    ori = ori_m[:ori_len[i], i]
                    self.ori_sample = ori
                    self.rep_sample = ref
                    self.out_sample = out

        total_recon_loss = np.mean(recon_loss_l)
        total_kl_loss = np.mean(kl_loss_l)
        total_bow_loss_l = np.mean(bow_loss_l)
        perplexity = safe_exp(
            np.sum(recon_loss_l) * self.batch_size / word_count)

        bleu_score, precisions, bp, ratio, translation_length, reference_length = compute_bleu(
            reference_corpus, generation_corpus)
        for i in range(len(precisions)):
            precisions[i] *= 100

        return (total_recon_loss, total_kl_loss, total_bow_loss_l, perplexity,
                bleu_score * 100, precisions, generation_corpus)
Example #9
 def get_evaluation(self, sess, batch):
     # decoder_logits_train is B x W x V; take the max over V
     pred_dist = tf.nn.softmax(model.decoder_logits_train)
     # take the argmax over the vocabulary to get the predicted words
     translation_corpus = tf.argmax(pred_dist, axis=2)
     # reference corpus is B x W
     gold_dist = model.decoder_target
     reference_corpus = gold_dist
     e = bleu.compute_bleu(reference_corpus, translation_corpus)
     return e
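
Unlike the other examples, this snippet hands TensorFlow tensors straight to bleu.compute_bleu. If the bleu module here is the list-based one used elsewhere on this page, the tensors would presumably need to be evaluated first; a hedged sketch of that variant, assuming `batch` is a feed dict for the model's placeholders:

translations_np, references_np = sess.run(
    [translation_corpus, reference_corpus], feed_dict=batch)
hypotheses = translations_np.tolist()
references = [[ref] for ref in references_np.tolist()]  # one reference per segment
score = bleu.compute_bleu(references, hypotheses)
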
Example #10
    def test_compute_bleu_large(self):
        paranmt_path = 'data/paranmt.txt'
        reflists = []
        hyps = []
        with open(paranmt_path) as f:
            for line in f:
                p1, p2 = line.split('\t')
                reflists.append([p1.split()])
                hyps.append(p2.split())

        bleu = compute_bleu(reflists, hyps)
        gold = nltk.translate.bleu_score.corpus_bleu(reflists, hyps)
        self.assertAlmostEqual(bleu, gold, places=10)
Example #11
 def __call__(self, trainer):
     with chainer.no_backprop_mode():
         refs = []
         hyps = []
         for i in range(0, len(self.test_data), self.batch):
             srcs, tgts = zip(*self.test_data[i:i + self.batch])
             refs.extend([[t.tolist()] for t in tgts])
             srcs = [
                 chainer.dataset.to_device(self.device, x) for x in srcs
             ]
             oys = self.model.translate(srcs, self.maxlen)
             hyps.extend(oys)
     sbleu = compute_bleu(refs, hyps, smooth=True)[0]
     chainer.report({self.key: sbleu})
Example #12
def calc_bleu_scores(gold_sentences_file_name, rec_sentences_file_name, n=2):
    original_s = []
    restored_s = []
    with open(gold_sentences_file_name, 'r') as g_f, open(rec_sentences_file_name, 'r') as rec_f:
        for gold_rec_sent in zip(g_f, rec_f):
            gold_sent, rec_sent = gold_rec_sent
            gold_sent = gold_sent.strip().split(' ')
            rec_sent = rec_sent.strip().split(' ')
            
            original_s.append([gold_sent])
            restored_s.append(rec_sent)
    bleu_score = bleu.compute_bleu(original_s, restored_s, max_order=n, smooth=False)
    #b_score = nltk.translate.bleu_score.corpus_bleu(original_s, restored_s)
    #print(b_score)
    print('BLEU:', bleu_score[0]*100)
Example #13
    def test_compute_bleu_large(self):
        paranmt_path = 'data/paranmt.txt'
        reflists = []
        hyps = []
        with open(paranmt_path, encoding="utf8") as f:
            for line in f:
                p1, p2 = line.split('\t')
                reflists.append([p1.split()])
                hyps.append(p2.split())

        print("Large test")
        bleu = compute_bleu(reflists, hyps)
        print("BLEU", bleu)
        gold = nltk.translate.bleu_score.corpus_bleu(reflists, hyps)
        print("Gold", gold)
        self.assertAlmostEqual(bleu, gold, places=10)
Example #14
    def compute_internal_bleu_score(self, path_id):
        from bleu import compute_bleu
        insts = self.gt[str(path_id)]['instructions']
        num_insts = len(insts)
        bleus = []
        for i in range(0, num_insts):
            insts = self.gt[str(path_id)]['instructions'].copy()
            candidate = insts[i]
            insts.remove(insts[i])
            refs = [self.tok.split_sentence(inst) for inst in insts]
            tup = compute_bleu([refs],
                               self.tok.split_sentence(candidate),
                               smooth=False)

            bleus.append(tup[0])
        return np.mean(bleus)
Example #15
 def test_bleu_multi_reference(self):
     hypothesis = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
                   'ensures', 'that', 'the', 'military', 'always',
                   'obeys', 'the', 'commands', 'of', 'the', 'party']
     refa = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
             'ensures', 'that', 'the', 'military', 'will', 'forever',
             'heed', 'Party', 'commands']
     refb = ['It', 'is', 'the', 'guiding', 'principle', 'which',
             'guarantees', 'the', 'military', 'forces', 'always',
             'being', 'under', 'the', 'command', 'of', 'the', 'Party']
     refc = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
             'army', 'always', 'to', 'heed', 'the', 'directions',
             'of', 'the', 'party']
     references = [refa, refb, refc]
     score = bleu.compute_bleu([references], [hypothesis])
     self.assertAlmostEqual(score, 0.50456667)
Example #16
    def bleu_score(self, path2inst):
        from bleu import compute_bleu
        refs = []
        candidates = []
        for path_id, inst in path2inst.items():
            path_id = str(path_id)
            assert path_id in self.gt
            # There are three references
            refs.append([self.tok.split_sentence(sent) for sent in self.gt[path_id]['instructions']])
            candidates.append([self.tok.index_to_word[word_id] for word_id in inst])

        tuple = compute_bleu(refs, candidates, smooth=False)
        bleu_score = tuple[0]
        precisions = tuple[1]

        return bleu_score, precisions
Example #17
def main():
    """ Computes the BLEU score for input and output corpora
    Args:
      sys.argv[1]: the reference corpus.
      sys.argv[2]: the translation output.
    Returns:
      prints the output of the `bleu.py` script.
    """
    ref_file = sys.argv[1]
    tra_file = sys.argv[2]
    print('Reference corpus: ', ref_file)
    print('Translation: ', tra_file)
    with open(ref_file, encoding='utf-8') as ref_fis, open(tra_file, encoding='utf-8') as tra_fis:
        ref_tokenized = list(map(lambda s: [list(word_tokenize(s))], ref_fis.readlines()))
        tra_tokenized = list(map(word_tokenize, tra_fis.readlines()))

        print(compute_bleu(ref_tokenized, tra_tokenized))
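
The script above is meant to be run from the command line with the two file paths as arguments; a minimal entry-point sketch (the guard is not part of the original, and the script name is hypothetical):

if __name__ == '__main__':
    # e.g.  python eval_bleu.py reference.txt translation.txt
    main()
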
Example #18
 def _computeBleu(self, generatedSequences, generatedSequenceLengths,
                  targetOutput, targetOutputLengths):
     init = tf.global_variables_initializer()
     translations = []
     with tf.Session() as sess:
         sess.run(init)
         try:
             while True:
                 translations.append(
                     sess.run([
                         targetOutput, targetOutputLengths,
                         generatedSequences, generatedSequenceLengths
                     ]))
         except tf.errors.OutOfRangeError:
             #this is how tensorFlow detects the end of file
             pass
     from bleu import compute_bleu
     print(compute_bleu(translations))
Example #19
def evaluate(model, loss_function, data_iter, max_len=MAX_LEN, epsilon=0.0005):
    """Computes loss on validation data.

    Returns:
        The loss on the dataset, a list of losses for each batch.
    """

    loss = 0
    loss_log = []
    num_instances = 0

    captions = []
    references = []

    with torch.no_grad():
        # set the network to evaluation mode
        model.eval()
    
        for batch in data_iter:
            i, f, t, batch_size = batch
            i, f, t = i.to(DEVICE), f.to(DEVICE), t.to(DEVICE)
            y, att_w = model(i, f, None, max_len=max_len)
            y = y.permute(1, 2, 0)
            t = t.squeeze(2).permute(1, 0)
        
            l, _ = loss_func(loss_function, y, t, att_w, epsilon)

            loss += l.item()
            loss_log.append(l.item() / batch_size)
            num_instances += batch_size

            # decode
            y = y.permute(0, 2, 1)
            _, topi = y.topk(1, dim=2)
            topi = topi.detach().squeeze(2)
            t = t.detach()

            for j in range(batch_size):
                captions.append(data_iter.vocab_tgt.tensor_to_sentence(topi[j]))
                references.append([data_iter.vocab_tgt.tensor_to_sentence(t[j])])

    bleu = compute_bleu(reference_corpus=references, translation_corpus=captions)[0]
    return (loss / num_instances), loss_log, bleu
Example #20
def compute_bleu_for_model(model,
                           sess,
                           inp_voc,
                           out_voc,
                           src_val,
                           dst_val,
                           model_name,
                           config,
                           max_len=200):
    src_val_ix = inp_voc.tokenize_many(src_val)

    inp = tf.placeholder(tf.int32, [None, None])
    translations = []

    if model_name == 'gnmt':
        sy_translations = model.symbolic_translate(inp, greedy=True)[0]
    elif model_name == 'transformer':
        sy_translations = model.symbolic_translate(inp,
                                                   mode='greedy',
                                                   max_len=max_len,
                                                   back_prop=False,
                                                   swap_memory=True).best_out
    else:
        raise NotImplemented("Unknown model")

    for batch in iterate_minibatches(src_val_ix,
                                     batchsize=config.get(
                                         'batch_size_for_inference', 64)):

        translations += sess.run([sy_translations],
                                 feed_dict={inp: batch[0][:, :max_len]
                                            })[0].tolist()

    outputs = out_voc.detokenize_many(translations, unbpe=True, deprocess=True)
    outputs = [out.split() for out in outputs]

    targets = out_voc.remove_bpe_many(dst_val)
    targets = [[t.split()] for t in targets]

    bleu = compute_bleu(targets, outputs)[0]

    return bleu
Example #21
    def test_compute_bleu_small(self):
        h1 = ('It is a guide to action which ensures that the military always '
              'obeys the commands of the party').split()
        r1a = ('It is a guide to action that ensures that the military will '
               'forever heed Party commands').split()
        r1b = ('It is the guiding principle which guarantees the military '
               'forces always being under the command of the Party').split()
        r1c = ('It is the practical guide for the army always to heed the '
               'directions of the party').split()

        h2 = ('he read the book because he was interested in world '
              'history').split()
        r2a = ('he was interested in world history because he read the '
               'book').split()

        hyps = [h1, h2]
        reflists = [[r1a, r1b, r1c], [r2a]]

        bleu = compute_bleu(reflists, hyps)
        gold = nltk.translate.bleu_score.corpus_bleu(reflists, hyps)
        self.assertAlmostEqual(bleu, gold, places=10)
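
This cross-check only lines up if compute_bleu's defaults match NLTK's (uniform 4-gram weights, no smoothing), which is what the test relies on; NLTK's default call written out explicitly for illustration:

gold = nltk.translate.bleu_score.corpus_bleu(
    reflists, hyps, weights=(0.25, 0.25, 0.25, 0.25))  # NLTK's documented default
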
Example #22
def main():
    opts = vars(parse_args())

    # Load model's parameters and a vocabulary
    prefix = os.path.join(os.path.dirname(opts['model']), 
                          os.path.basename(opts['model']).split('-')[0])
    params = pickle.load(open('{}.opts'.format(prefix), 'br'))
    svocab = [w.rstrip() for w in open('{}.svocab'.format(prefix), 'r')]
    tvocab = [w.rstrip() for w in open('{}.tvocab'.format(prefix), 'r')]

    # Setup models
    svocab_size = len(svocab) + 3 # 3 means number of special tags such as
    tvocab_size = len(tvocab) + 3 # "UNK", "BOS", and "EOS" 
    model = Seq2seq(svocab, tvocab, params)
    serializers.load_npz(opts['model'], model)
    if opts['gpuid'] >= 0:
        cuda.get_device(opts['gpuid']).use()
        model.to_gpu(opts['gpuid'])

    # Setup a data
    test_src = word2id(opts['src'], svocab)
    test_tgt = word2id(opts['tgt'], tvocab)
    test_data = [(s, t) for s, t in zip(test_src, test_tgt)]

    # Translating
    id2word = ['UNK', 'BOS', 'EOS'] + tvocab
    references = []
    translations = []
    if opts['beamsize'] < 2:
        for i in range(0, len(test_data), opts['batchsize']):
            srcs, _, refs = encdec_convert(test_data[i:i+opts['batchsize']], opts['gpuid'])
            hyps = model.translate(srcs, opts['maxlen'])
            for hyp in hyps:
                out = ' '.join([id2word[i] for i in hyp])
                print(out)
            references += [[[i for i in ref if i != IGNORE]] for ref in refs.tolist()]
            translations += [hyp for hyp in hyps]

    bleu = compute_bleu(references, translations, smooth=True)[0]
    print(bleu)
Example #23
File: utils.py Project: zouning68/RNN-NMT
def _bleu(ref_file, trans_file):
    max_order = 4
    smooth = False

    ref_files = [ref_file]
    reference_text = []
    for reference_filename in ref_files:
        with codecs.getreader("utf-8")(tf.gfile.GFile(reference_filename, "rb")) as fh:
            reference_text.append(fh.readlines())

    per_segment_references = []
    for references in zip(*reference_text):
        reference_list = []
        for reference in references:
            reference_list.append(reference.split(" "))
        per_segment_references.append(reference_list)

    translations = []
    with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh:
        for line in fh:
            translations.append(line.split(" "))

    bleu_score, _, _, _, _, _ = bleu.compute_bleu(per_segment_references, translations, max_order, smooth)
    return 100 * bleu_score
Example #24
def inference():
    """inference function."""
    logging.info('Inference on test_dataset!')

    # data prepare
    test_data_loader = dataprocessor.get_dataloader(data_test, args,
                                                    dataset_type='test',
                                                    use_average_length=True)

    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError

    translation_out = []
    all_inst_ids = []
    total_wc = 0
    total_time = 0
    batch_total_bleu = 0

    for batch_id, (src_seq, tgt_seq, src_test_length, tgt_test_length, inst_ids) \
            in enumerate(test_data_loader):

        total_wc += src_test_length.sum().asscalar() + tgt_test_length.sum().asscalar()

        src_seq = src_seq.as_in_context(ctx[0])
        tgt_seq = tgt_seq.as_in_context(ctx[0])
        src_test_length = src_test_length.as_in_context(ctx[0])
        tgt_test_length = tgt_test_length.as_in_context(ctx[0])
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())

        start = time.time()
        # Translate to get a bleu score
        samples, _, sample_test_length = \
            translator.translate(src_seq=src_seq, src_valid_length=src_test_length)
        total_time += (time.time() - start)

        # collect the translation result for each batch
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_test_length = sample_test_length[:, 0].asnumpy()
        translation_tmp = []
        translation_tmp_sentences = []
        for i in range(max_score_sample.shape[0]):
            translation_tmp.append([tgt_vocab.idx_to_token[ele] for ele in \
                                    max_score_sample[i][1:(sample_test_length[i] - 1)]])

        # detokenize each translation result
        for _, sentence in enumerate(translation_tmp):
            if args.bleu == 'tweaked':
                translation_tmp_sentences.append(sentence)
                translation_out.append(sentence)
            elif args.bleu == '13a' or args.bleu == 'intl':
                translation_tmp_sentences.append(detokenizer(_bpe_to_words(sentence)))
                translation_out.append(detokenizer(_bpe_to_words(sentence)))
            else:
                raise NotImplementedError

        # generate tgt_sentence for bleu calculation of each batch
        tgt_sen_tmp = [test_tgt_sentences[index] for \
                         _, index in enumerate(inst_ids.asnumpy().astype(np.int32).tolist())]
        batch_test_bleu_score, _, _, _, _ = compute_bleu([tgt_sen_tmp], translation_tmp_sentences,
                                                         tokenized=tokenized, tokenizer=args.bleu,
                                                         split_compound_word=split_compound_word,
                                                         bpe=bpe)
        batch_total_bleu += batch_test_bleu_score

        # log every ten batches
        if batch_id % 10 == 0 and batch_id != 0:
            batch_ave_bleu = batch_total_bleu / 10
            batch_total_bleu = 0
            logging.info('batch id={:d}, batch_bleu={:.4f}'
                         .format(batch_id, batch_ave_bleu * 100))

    # reorg translation sentences by inst_ids
    real_translation_out = [None for _ in range(len(all_inst_ids))]
    for ind, sentence in zip(all_inst_ids, translation_out):
        real_translation_out[ind] = sentence

    # get bleu score, n-gram precisions, brevity penalty,  reference length, and translation length
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], real_translation_out,
                                               tokenized=tokenized, tokenizer=args.bleu,
                                               split_compound_word=split_compound_word,
                                               bpe=bpe)

    logging.info('Inference at test dataset. \
                 inference bleu={:.4f}, throughput={:.4f}K wps'
                 .format(test_bleu_score * 100, total_wc / total_time / 1000))
Example #25
def main(_):
  vocab = load_vocabulary(FLAGS.data_dir)
  if FLAGS.generating:
    data_reader = DataReader(FLAGS.data_dir, n_reviews=5, generating=True)
  else:
    data_reader = DataReader(FLAGS.data_dir)
  model = Model(total_users=data_reader.total_users, total_items=data_reader.total_items,
                global_rating=data_reader.global_rating, num_factors=FLAGS.num_factors,
                img_dims=[196, 512], vocab_size=len(vocab), word_dim=FLAGS.word_dim,
                lstm_dim=FLAGS.lstm_dim, max_length=FLAGS.max_length, dropout_rate=FLAGS.dropout_rate)

  saver = tf.compat.v1.train.Saver(max_to_keep=10)

  log_file = open('log.txt', 'w')
  test_step = 0

  config = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement)
  config.gpu_options.allow_growth = True

  with tf.Session(config=config) as sess:
      saver.restore(sess, FLAGS.ckpt_dir)
      print('Model successfully restored')
      # Testing
      review_gen_corpus = defaultdict(list)
      review_ref_corpus = defaultdict(list)

      photo_bleu_scores = defaultdict(list)
      photo_rouge_scores = defaultdict(list)

      review_bleu_scores = defaultdict(list)
      review_rouge_scores = defaultdict(list)

      sess.run(model.init_metrics)
      for users, items, ratings in data_reader.read_real_test_set(FLAGS.batch_size, rating_only=True):
        test_step += 1

        fd = model.feed_dict(users, items, ratings)
        sess.run(model.update_metrics, feed_dict=fd)

        review_users, review_items, review_ratings, photo_ids, reviews = get_review_data(users, items, ratings,
                                                                                         data_reader.real_test_review)
        img_idx = [data_reader.real_test_id2idx[photo_id] for photo_id in photo_ids]
        images = data_reader.real_test_img_features[img_idx]

        fd = model.feed_dict(users=review_users, items=review_items, images=images)
        _reviews, _alphas, _betas = sess.run([model.sampled_reviews, model.alphas, model.betas], feed_dict=fd)

        gen_reviews = decode_reviews(_reviews, vocab)
        ref_reviews = [decode_reviews(batch_review_normalize(ref), vocab) for ref in reviews]

        if FLAGS.generating:
          for gen, ref in zip(gen_reviews, ref_reviews):
            gen_str = "GENERATED:\n"+" ".join(gen)
            ref_str = "REFERENCE:\n"+" ".join([" ".join(sentence) for sentence in ref])+"\n"
            log_info(log_file,gen_str)
            log_info(log_file,ref_str)

        for user, item, gen, refs in zip(review_users, review_items, gen_reviews, ref_reviews):
          review_gen_corpus[(user, item)].append(gen)
          review_ref_corpus[(user, item)] += refs

          bleu_scores = compute_bleu([refs], [gen], max_order=4, smooth=True)
          for order, score in bleu_scores.items():
            photo_bleu_scores[order].append(score)

          rouge_scores = rouge([gen], refs)
          for metric, score in rouge_scores.items():
            photo_rouge_scores[metric].append(score)

      _mae, _rmse = sess.run([model.mae, model.rmse])
      log_info(log_file, '\nRating prediction results: MAE={:.3f}, RMSE={:.3f}'.format(_mae, _rmse))

      log_info(log_file, '\nReview generation results:')
      log_info(log_file, '- Photo level: BLEU-scores = {:.2f}, {:.2f}, {:.2f}, {:.2f}'.format(
        np.array(photo_bleu_scores[1]).mean() * 100, np.array(photo_bleu_scores[2]).mean() * 100,
        np.array(photo_bleu_scores[3]).mean() * 100, np.array(photo_bleu_scores[4]).mean() * 100))

      for user_item, gen_reviews in review_gen_corpus.items():
        references = [list(ref) for ref in set(tuple(ref) for ref in review_ref_corpus[user_item])]

        user_item_bleu_scores = defaultdict(list)
        for gen in gen_reviews:
          bleu_scores = compute_bleu([references], [gen], max_order=4, smooth=True)
          for order, score in bleu_scores.items():
            user_item_bleu_scores[order].append(score)
        for order, scores in user_item_bleu_scores.items():
          review_bleu_scores[order].append(np.array(scores).mean())

        user_item_rouge_scores = defaultdict(list)
        for gen in gen_reviews:
          rouge_scores = rouge([gen], references)
          for metric, score in rouge_scores.items():
            user_item_rouge_scores[metric].append(score)
        for metric, scores in user_item_rouge_scores.items():
          review_rouge_scores[metric].append(np.array(scores).mean())

      log_info(log_file, '- Review level: BLEU-scores = {:.2f}, {:.2f}, {:.2f}, {:.2f}'.format(
        np.array(review_bleu_scores[1]).mean() * 100, np.array(review_bleu_scores[2]).mean() * 100,
        np.array(review_bleu_scores[3]).mean() * 100, np.array(review_bleu_scores[4]).mean() * 100))

      for metric in ['rouge_1', 'rouge_2', 'rouge_l']:
        log_info(log_file, '- Photo level: {} = {:.2f}, {:.2f}, {:.2f}'.format(
          metric,
          np.array(photo_rouge_scores['{}/p_score'.format(metric)]).mean() * 100,
          np.array(photo_rouge_scores['{}/r_score'.format(metric)]).mean() * 100,
          np.array(photo_rouge_scores['{}/f_score'.format(metric)]).mean() * 100))
        log_info(log_file, '- Review level: {} = {:.2f}, {:.2f}, {:.2f}'.format(
          metric,
          np.array(review_rouge_scores['{}/p_score'.format(metric)]).mean() * 100,
          np.array(review_rouge_scores['{}/r_score'.format(metric)]).mean() * 100,
          np.array(review_rouge_scores['{}/f_score'.format(metric)]).mean() * 100))
Example #26
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr})

    train_data_loader, val_data_loader, test_data_loader \
        = dataprocessor.make_dataloader(data_train, data_val, data_test, args)

    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_loss = 0
        log_denom = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
                in enumerate(train_data_loader):
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                               tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:],
                                     tgt_valid_length - 1).mean()
                loss = loss * (tgt_seq.shape[1] - 1)
                log_loss += loss * tgt_seq.shape[0]
                log_denom += (tgt_valid_length - 1).sum()
                loss = loss / (tgt_valid_length - 1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            log_loss = log_loss.asscalar()
            log_denom = log_denom.asscalar()
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info(
                    '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                    'throughput={:.2f}K wps, wc={:.2f}K'.format(
                        epoch_id, batch_id + 1, len(train_data_loader),
                        log_loss / log_denom, np.exp(log_loss / log_denom),
                        log_avg_gnorm / args.log_interval, wps / 1000,
                        log_wc / 1000))
                log_start_time = time.time()
                log_loss = 0
                log_denom = 0
                log_avg_gnorm = 0
                log_wc = 0
        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        dataprocessor.write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        if args.validate_on_test_data:
            test_loss, test_translation_out = evaluate(test_data_loader)
            test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                       test_translation_out)
            logging.info(
                '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                .format(epoch_id, test_loss, np.exp(test_loss),
                        test_bleu_score * 100))

            dataprocessor.write_sentences(
                test_translation_out,
                os.path.join(args.save_dir,
                             'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        if epoch_id + 1 >= (args.epochs * 2) // 3:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)
    if os.path.exists(os.path.join(args.save_dir, 'valid_best.params')):
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'))
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    dataprocessor.write_sentences(
        valid_translation_out, os.path.join(args.save_dir,
                                            'best_valid_out.txt'))
    dataprocessor.write_sentences(
        test_translation_out, os.path.join(args.save_dir, 'best_test_out.txt'))
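
In the GluonNLP-style training functions above, the references handed to compute_bleu are wrapped once more (a list of reference corpora, hence the extra brackets around val_tgt_sentences), and five values come back instead of six. A hedged sketch of that calling convention, assuming the compute_bleu shipped with the GluonNLP machine translation scripts:

val_tgt_sentences = [['the', 'cat', 'sat'], ['he', 'read', 'the', 'book']]
valid_translation_out = [['the', 'cat', 'sat'], ['he', 'read', 'a', 'book']]
valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                            valid_translation_out)
print(valid_bleu_score * 100)
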
Example #27
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {'learning_rate': args.lr})

    train_data_loader, val_data_loader, test_data_loader \
        = dataprocessor.make_dataloader(data_train, data_val, data_test, args)

    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
                in enumerate(train_data_loader):
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
                loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            step_loss = loss.asscalar()
            log_avg_loss += step_loss
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     log_avg_gnorm / args.log_interval,
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_avg_gnorm = 0
                log_wc = 0
        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader)
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        dataprocessor.write_sentences(valid_translation_out,
                                      os.path.join(args.save_dir,
                                                   'epoch{:d}_valid_out.txt').format(epoch_id))
        dataprocessor.write_sentences(test_translation_out,
                                      os.path.join(args.save_dir,
                                                   'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        if epoch_id + 1 >= (args.epochs * 2) // 3:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)
    if os.path.exists(os.path.join(args.save_dir, 'valid_best.params')):
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'))
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    dataprocessor.write_sentences(valid_translation_out,
                                  os.path.join(args.save_dir, 'best_valid_out.txt'))
    dataprocessor.write_sentences(test_translation_out,
                                  os.path.join(args.save_dir, 'best_test_out.txt'))
Example #28
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9})

    train_data_loader, val_data_loader, test_data_loader \
        = dataprocessor.make_dataloader(data_train, data_val, data_test, args,
                                        use_average_length=True, num_shards=len(ctx))

    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError

    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    model.collect_params().setattr('grad_req', 'add')
    average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, seqs \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            src_wc, tgt_wc, bs = np.sum([(shard[2].sum(), shard[3].sum(), shard[0].shape[0])
                                         for shard in seqs], axis=0)
            src_wc = src_wc.asscalar()
            tgt_wc = tgt_wc.asscalar()
            loss_denom += tgt_wc - bs
            seqs = [[seq.as_in_context(context) for seq in shard]
                    for context, shard in zip(ctx, seqs)]
            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length in seqs:
                    out, _ = model(src_seq, tgt_seq[:, :-1],
                                   src_valid_length, tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label, tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size / 100.0)
            for L in Ls:
                L.backward()
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                if average_param_dict is None:
                    average_param_dict = {k: v.data(ctx[0]).copy() for k, v in
                                          model.collect_params().items()}
                trainer.step(float(loss_denom) / args.batch_size / 100.0)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out,
                                                    tokenized=tokenized, tokenizer=args.bleu,
                                                    split_compound_word=split_compound_word,
                                                    bpe=bpe)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out,
                                                   tokenized=tokenized, tokenizer=args.bleu,
                                                   split_compound_word=split_compound_word,
                                                   bpe=bpe)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        dataprocessor.write_sentences(valid_translation_out,
                                      os.path.join(args.save_dir,
                                                   'epoch{:d}_valid_out.txt').format(epoch_id))
        dataprocessor.write_sentences(test_translation_out,
                                      os.path.join(args.save_dir,
                                                   'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id))
        model.save_parameters(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        for j in range(args.num_averages):
            params = mx.nd.load(os.path.join(args.save_dir,
                                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c))
        save_path = os.path.join(args.save_dir,
                                 'average_checkpoint_{}.params'.format(args.num_averages))
        model.save_parameters(save_path)
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
        save_path = os.path.join(args.save_dir, 'average.params')
        model.save_parameters(save_path)
    else:
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out,
                                                tokenized=tokenized, tokenizer=args.bleu, bpe=bpe,
                                                split_compound_word=split_compound_word)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out,
                                               tokenized=tokenized, tokenizer=args.bleu, bpe=bpe,
                                               split_compound_word=split_compound_word)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    dataprocessor.write_sentences(valid_translation_out,
                                  os.path.join(args.save_dir, 'best_valid_out.txt'))
    dataprocessor.write_sentences(test_translation_out,
                                  os.path.join(args.save_dir, 'best_test_out.txt'))
Example #29
def train(config, sample_validation_batches):
    source_language = config.get('src_language')
    target_language = config.get('trg_language')
    EOS_token = config.get('EOS_token')
    PAD_token = config.get('PAD_token')
    SOS_token = config.get('SOS_token')
    train_iter = config.get('train_iter')
    val_iter = config.get('val_iter')
    writer_path = config.get('writer_path')
    writer_train_path = get_or_create_dir(writer_path, 'train')
    writer_val_path = get_or_create_dir(writer_path, 'val')
    writer_train = SummaryWriter(log_dir=writer_train_path)
    writer_val = SummaryWriter(log_dir=writer_val_path)
    epochs = config.get('epochs')
    training = config.get('training')
    eval_every = training.get('eval_every')
    sample_every = training.get('sample_every')
    use_attention = config.get('use_attention')
    step = 1
    for epoch in range(epochs):
        print(f'Epoch: {epoch+1}/{epochs}')
        save_weights(config)
        for i, training_batch in enumerate(train_iter):
            loss = train_batch(config, training_batch)
            writer_train.add_scalar('loss', loss, step)

            if step == 1 or step % eval_every == 0:
                val_lengths = 0
                val_losses = 0
                reference_corpus = []
                translation_corpus = []
                for val_batch in val_iter:
                    val_loss, translations = evaluate_batch(config, val_batch)
                    val_lengths += 1
                    val_losses += val_loss
                    val_batch_trg, _ = val_batch.trg
                    _, batch_size = val_batch_trg.shape
                    references = map(
                        lambda i: torch2words(target_language, val_batch_trg[:, i]),
                        range(batch_size))
                    references = map(
                        lambda words: [list(filter_words(words, SOS_token, EOS_token, PAD_token))],
                        references)
                    reference_corpus.extend(references)
                    translations = map(
                        lambda translation: list2words(target_language, translation),
                        translations)
                    translations = map(
                        lambda words: list(filter_words(words, SOS_token, EOS_token, PAD_token)),
                        translations)
                    translation_corpus.extend(translations)
                bleu = compute_bleu(reference_corpus, translation_corpus)
                val_loss = val_losses / val_lengths
                writer_val.add_scalar('bleu', bleu, step)
                writer_val.add_scalar('loss', val_loss, step)

            if step % sample_every == 0:
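                # Every sample_every steps: decode one validation batch and log the
                # source/target/translation text (and the attention heatmap, if used) to TensorBoard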
                val_batch = sample_validation_batches(1)
                val_batch_src, val_lengths_src = val_batch.src
                val_batch_trg, _ = val_batch.trg
                s0 = val_lengths_src[0].item()
                _, translations, attention_weights = evaluate_batch(
                    config, val_batch, True)
                source_words = torch2words(source_language, val_batch_src[:, 0])
                target_words = torch2words(target_language, val_batch_trg[:, 0])
                translation_words = list(
                    filter(lambda word: word != PAD_token,
                           list2words(target_language, translations[0])))
                if use_attention and sum(attention_weights.shape) != 0:
                    attention_figure = visualize_attention(
                        source_words[:s0], translation_words,
                        with_cpu(attention_weights))
                    writer_val.add_figure('attention', attention_figure, step)
                text = get_text(source_words, target_words, translation_words,
                                SOS_token, EOS_token, PAD_token)
                writer_val.add_text('translation', text, step)

            step += 1

    save_weights(config)
Example #30
# TODO: allowing GPU memory growth is needed on Windows to avoid CUBLAS_STATUS_ALLOC_FAILED
# https://stackoverflow.com/questions/41117740/tensorflow-crashes-with-cublas-status-alloc-failed
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    tf.contrib.framework.assign_from_checkpoint_fn(
        args.textgan_filepath, tf.trainable_variables("generator"))(sess)

    z_value = np.random.normal(0, 1, size=(args.num_to_generate,
                                           z_prior_size)).astype(np.float32)

    out_sentence = sess.run(x_generated_ids, feed_dict={z_prior: z_value})

    sentences_as_ids, sentences_as_strs = [], []
    for sentence in out_sentence:
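        # Truncate each sampled sequence at the first end-of-sentence id;
        # an IndexError means no EOS was generated, so keep the full sequence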
        try:
            take_len = 1 + np.where(sentence == end_of_sentence_id)[0][0]
        except IndexError:
            take_len = len(sentence)
        sentences_as_ids.append(sentence[:take_len].tolist())
        sentences_as_strs.append(" ".join(
            [reversed_dictionary[word_id] for word_id in sentence[:take_len]]))

for sentence in sentences_as_strs:
    print(sentence)
if args.compute_bleu:
    bleu_out = compute_bleu([test_data] * len(sentences_as_ids),
                            sentences_as_ids)
    print(bleu_out)
Example #31
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {
        'learning_rate': args.lr,
        'beta2': 0.98,
        'epsilon': 1e-9
    })

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'),
                                  btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'),
                                 btf.Stack(dtype='float32'), btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    if args.bucket_scheme == 'constant':
        bucket_scheme = ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True,
                                             num_shards=len(ctx),
                                             bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = ShardedDataLoader(data_train,
                                          batch_sampler=train_batch_sampler,
                                          batchify_fn=train_batchify_fn,
                                          num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True,
                                           bucket_scheme=bucket_scheme)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True,
                                            bucket_scheme=bucket_scheme)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)

    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError

    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
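    # grad_req='add' accumulates gradients over grad_interval (num_accumulated) batches,
    # so trainer.step is only called once per accumulation window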
    model.collect_params().setattr('grad_req', 'add')
    average_start = (len(train_data_loader) //
                     grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, seqs \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
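                # Transformer (Noam) warmup schedule:
                # lr = base_lr / sqrt(num_units) * min(1/sqrt(step), step * warmup_steps^-1.5)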
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            src_wc, tgt_wc, bs = np.sum(
                [(shard[2].sum(), shard[3].sum(), shard[0].shape[0])
                 for shard in seqs],
                axis=0)
            src_wc = src_wc.asscalar()
            tgt_wc = tgt_wc.asscalar()
            loss_denom += tgt_wc - bs
            seqs = [[seq.as_in_context(context) for seq in shard]
                    for context, shard in zip(ctx, seqs)]
            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length in seqs:
                    out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                                   tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label,
                                       tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size /
                              100.0)
            for L in Ls:
                L.backward()
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                if average_param_dict is None:
                    average_param_dict = {
                        k: v.data(ctx[0]).copy()
                        for k, v in model.collect_params().items()
                    }
                trainer.step(float(loss_denom) / args.batch_size / 100.0)
                param_dict = model.collect_params()
                param_dict.zero_grad()
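                # After average_start updates, keep a uniform running average of the parameters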
                if step_num > average_start:
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (
                            param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch_id, batch_id + 1,
                                 len(train_data_loader),
                                 log_avg_loss / args.log_interval,
                                 np.exp(log_avg_loss / args.log_interval),
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu(
            [val_tgt_sentences],
            valid_translation_out,
            tokenized=tokenized,
            tokenizer=args.bleu,
            split_compound_word=split_compound_word,
            bpe=bpe)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu(
            [test_tgt_sentences],
            test_translation_out,
            tokenized=tokenized,
            tokenizer=args.bleu,
            split_compound_word=split_compound_word,
            bpe=bpe)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        save_path = os.path.join(args.save_dir,
                                 'epoch{:d}.params'.format(epoch_id))
        model.save_parameters(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
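    # Final weights: either average the last num_averages epoch checkpoints,
    # use the running parameter average, or reload the best-BLEU checkpoint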
    if args.average_checkpoint:
        for j in range(args.num_averages):
            params = mx.nd.load(
                os.path.join(args.save_dir,
                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) -
                                             v.data(c))
        save_path = os.path.join(
            args.save_dir,
            'average_checkpoint_{}.params'.format(args.num_averages))
        model.save_parameters(save_path)
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
        save_path = os.path.join(args.save_dir, 'average.params')
        model.save_parameters(save_path)
    else:
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'),
                              ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu(
        [val_tgt_sentences],
        valid_translation_out,
        tokenized=tokenized,
        tokenizer=args.bleu,
        bpe=bpe,
        split_compound_word=split_compound_word)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu(
        [test_tgt_sentences],
        test_translation_out,
        tokenized=tokenized,
        tokenizer=args.bleu,
        bpe=bpe,
        split_compound_word=split_compound_word)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
Example #32
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {
        'learning_rate': args.lr,
        'beta2': 0.98,
        'epsilon': 1e-9
    })

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(),
                                  btf.Stack())
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(),
                                 btf.Stack(), btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    model.collect_params().setattr('grad_req', 'add')
    average_start = (len(train_data_loader) //
                     grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            # logging.info(src_seq.context)  # (debug) the context unexpectedly switches to GPU here
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = tgt_valid_length.sum().asscalar()
            loss_denom += tgt_wc - tgt_valid_length.shape[0]
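            # Split the batch across all devices when it is large enough;
            # otherwise run it on the first context only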
            if src_seq.shape[0] > len(ctx):
                src_seq_list, tgt_seq_list, src_valid_length_list, tgt_valid_length_list \
                    = [gluon.utils.split_and_load(seq, ctx, batch_axis=0, even_split=False)
                       for seq in [src_seq, tgt_seq, src_valid_length, tgt_valid_length]]
            else:
                src_seq_list = [src_seq.as_in_context(ctx[0])]
                tgt_seq_list = [tgt_seq.as_in_context(ctx[0])]
                src_valid_length_list = [
                    src_valid_length.as_in_context(ctx[0])
                ]
                tgt_valid_length_list = [
                    tgt_valid_length.as_in_context(ctx[0])
                ]

            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length \
                        in zip(src_seq_list, tgt_seq_list,
                               src_valid_length_list, tgt_valid_length_list):
                    out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                                   tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label,
                                       tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size)
            for L in Ls:
                L.backward()
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                if average_param_dict is None:
                    average_param_dict = {
                        k: v.data(ctx[0]).copy()
                        for k, v in model.collect_params().items()
                    }
                trainer.step(float(loss_denom) / args.batch_size)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (
                            param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                log_avg_loss += step_loss / loss_denom * args.batch_size
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch_id, batch_id + 1,
                                 len(train_data_loader),
                                 log_avg_loss / args.log_interval,
                                 np.exp(log_avg_loss / args.log_interval),
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
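        # bpe=True merges BPE subword markers and split_compound_word=True splits hyphenated
        # compounds before BLEU is computed (the "tweaked" BLEU variant)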
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out,
                                                    bpe=True,
                                                    split_compound_word=True)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                   test_translation_out,
                                                   bpe=True,
                                                   split_compound_word=True)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        save_path = os.path.join(args.save_dir,
                                 'epoch{:d}.params'.format(epoch_id))
        model.save_params(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        for j in range(args.num_averages):
            params = mx.nd.load(
                os.path.join(args.save_dir,
                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) -
                                             v.data(c))
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
    else:
        model.load_params(os.path.join(args.save_dir, 'valid_best.params'),
                          ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out,
                                                bpe=True,
                                                split_compound_word=True)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out,
                                               bpe=True,
                                               split_compound_word=True)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))