Example #1
    def _eval_epoch(sess, epoch, mode):
        if mode == 'eval':
            eval_data = dev_data
        elif mode == 'test':
            eval_data = test_data
        else:
            raise ValueError('`mode` should be either "eval" or "test".')

        references, hypotheses = [], []
        bsize = config_data.test_batch_size
        for i in range(0, len(eval_data), bsize):
            #print("eval {}/{}".format(i, len(eval_data)))
            sources, targets = zip(*eval_data[i:i + bsize])
            x_block = data_utils.source_pad_concat_convert(sources)
            feed_dict = {
                encoder_input: x_block,
                # Put Texar modules (e.g., dropout) into eval behavior
                tx.global_mode(): tf.estimator.ModeKeys.EVAL,
            }
            fetches = {
                'inferred_ids': inferred_ids,
            }
            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
            references.extend(r.tolist() for r in targets)

        # Strip everything after the first EOS token once, after the loop
        hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
        references = utils.list_strip_eos(references, eos_token_id)

        if mode == 'eval':
            # Writes results to files to evaluate BLEU
            # For 'eval' mode, the BLEU is based on token ids (rather than
            # text tokens) and serves only as a surrogate metric to monitor
            # the training process
            fname = os.path.join(FLAGS.model_dir, 'tmp.eval')
            hypotheses = tx.utils.str_join(hypotheses)
            references = tx.utils.str_join(references)
            hyp_fn, ref_fn = tx.utils.write_paired_text(hypotheses,
                                                        references,
                                                        fname,
                                                        mode='s')
            eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
            eval_bleu = 100. * eval_bleu
            logger.info('epoch: %d, eval_bleu %.4f', epoch, eval_bleu)
            print('epoch: %d, eval_bleu %.4f' % (epoch, eval_bleu))

            if eval_bleu > best_results['score']:
                logger.info('epoch: %d, best bleu: %.4f', epoch, eval_bleu)
                best_results['score'] = eval_bleu
                best_results['epoch'] = epoch
                model_path = os.path.join(FLAGS.model_dir, 'best-model.ckpt')
                logger.info('saving model to %s', model_path)
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)

        elif mode == 'test':
            # For 'test' mode, together with the cmds in README.md, BLEU
            # is evaluated based on text tokens, which is the standard metric.
            fname = os.path.join(FLAGS.model_dir, 'test.output')
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append([id2w[y] for y in hyp])
                rwords.append([id2w[y] for y in ref])
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_fn, ref_fn = tx.utils.write_paired_text(hwords,
                                                        rwords,
                                                        fname,
                                                        mode='s')
            logger.info('Test output written to file: %s', hyp_fn)
            print('Test output written to file: %s' % hyp_fn)
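Both the TensorFlow function above and the PyTorch port below rely on
utils.list_strip_eos to drop everything from the first end-of-sequence
token onward before the ids are scored or written out. A minimal sketch of
that helper, assuming it truncates each id sequence at the first occurrence
of the EOS id (the name comes from the listings; the body here is an
illustrative reimplementation, not the project's code):

    def list_strip_eos(sequences, eos_id):
        # Truncate each id sequence at the first EOS token, if present.
        stripped = []
        for seq in sequences:
            if eos_id in seq:
                seq = seq[:seq.index(eos_id)]
            stripped.append(seq)
        return stripped

    # With EOS id 2, trailing padding after EOS is discarded:
    # list_strip_eos([[5, 7, 2, 0, 0], [3, 4]], eos_id=2) -> [[5, 7], [3, 4]]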
    def _eval_epoch(epoch, mode):
        torch.cuda.empty_cache()
        if mode == 'eval':
            eval_data = dev_data
        elif mode == 'test':
            eval_data = test_data
        else:
            raise ValueError("`mode` should be either \"eval\" or \"test\".")

        references, hypotheses = [], []
        bsize = config_data.test_batch_size
        for i in tqdm(range(0, len(eval_data), bsize)):
            sources, targets = zip(*eval_data[i:i + bsize])
            with torch.no_grad():
                x_block = data_utils.source_pad_concat_convert(
                    sources, device=device)
                predictions = model(
                    encoder_input=x_block,
                    is_train_mode=False,
                    beam_width=beam_width)
                if beam_width == 1:
                    # Greedy decoding returns a (decoder output, sequence
                    # lengths) tuple; take the sampled token ids
                    decoded_ids = predictions[0].sample_id
                else:
                    # Beam search returns sample_id with a trailing beam
                    # dimension; keep the best (first) beam
                    decoded_ids = predictions["sample_id"][:, :, 0]

                hypotheses.extend(h.tolist() for h in decoded_ids)
                references.extend(r.tolist() for r in targets)

        # Strip everything after the first EOS token once, after the loop
        hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
        references = utils.list_strip_eos(references, eos_token_id)

        if mode == 'eval':
            # Writes results to files to evaluate BLEU
            # For 'eval' mode, the BLEU is based on token ids (rather than
            # text tokens) and serves only as a surrogate metric to monitor
            # the training process
            # TODO: Use texar.evals.bleu
            fname = os.path.join(args.model_dir, 'tmp.eval')
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append([str(y) for y in hyp])
                rwords.append([str(y) for y in ref])
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_fn, ref_fn = tx.utils.write_paired_text(
                hwords, rwords, fname, mode='s',
                src_fname_suffix='hyp', tgt_fname_suffix='ref')
            eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
            eval_bleu = 100. * eval_bleu
            logger.info("epoch: %d, eval_bleu %.4f", epoch, eval_bleu)
            print(f"epoch: {epoch:d}, eval_bleu {eval_bleu:.4f}")

            if eval_bleu > best_results['score']:
                logger.info("epoch: %d, best bleu: %.4f", epoch, eval_bleu)
                best_results['score'] = eval_bleu
                best_results['epoch'] = epoch
                model_path = os.path.join(args.model_dir, args.model_fn)
                logger.info("Saving model to %s", model_path)
                print(f"Saving model to {model_path}")

                states = {
                    'model': model.state_dict(),
                    'optimizer': optim.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }
                torch.save(states, model_path)

        elif mode == 'test':
            # For 'test' mode, together with the cmds in README.md, BLEU
            # is evaluated based on text tokens, which is the standard metric.
            fname = os.path.join(args.model_dir, 'test.output')
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append([id2w[y] for y in hyp])
                rwords.append([id2w[y] for y in ref])
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_fn, ref_fn = tx.utils.write_paired_text(
                hwords, rwords, fname, mode='s',
                src_fname_suffix='hyp', tgt_fname_suffix='ref')
            logger.info("Test output written to file: %s", hyp_fn)
            print(f"Test output written to file: {hyp_fn}")