Example #1
class Model:
    def __init__(self):
        """
        Init whatever you need here
        """
        vocab_file = 'data/vocab.txt'
        with codecs.open(vocab_file, 'r', 'utf-8') as f:
            vocab = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
        self.vocab = vocab
        self.freqs = dict(zip(self.vocab[::-1], range(len(self.vocab))))

        # Our code is as follows
        config = Config()
        torch.cuda.set_device(device=config.gpu)
        self.config = config

        # Data definition
        self.corpus = KnowledgeCorpus(data_dir=config.data_dir,
                                      data_prefix=config.data_prefix,
                                      min_freq=0,
                                      max_vocab_size=config.max_vocab_size,
                                      vocab_file=config.vocab_file,
                                      min_len=config.min_len,
                                      max_len=config.max_len,
                                      embed_file=config.embed_file,
                                      share_vocab=config.share_vocab)
        # Model definition
        self.model = Seq2Seq(src_vocab_size=self.corpus.SRC.vocab_size,
                             tgt_vocab_size=self.corpus.TGT.vocab_size,
                             embed_size=config.embed_size,
                             hidden_size=config.hidden_size,
                             padding_idx=self.corpus.padding_idx,
                             num_layers=config.num_layers,
                             bidirectional=config.bidirectional,
                             attn_mode=config.attn,
                             with_bridge=config.with_bridge,
                             tie_embedding=config.tie_embedding,
                             dropout=config.dropout,
                             use_gpu=config.use_gpu)
        print(self.model)
        self.model.load(config.ckpt)

        # Generator definition
        self.generator = TopKGenerator(model=self.model,
                                       src_field=self.corpus.SRC,
                                       tgt_field=self.corpus.TGT,
                                       cue_field=self.corpus.CUE,
                                       beam_size=config.beam_size,
                                       max_length=config.max_dec_len,
                                       ignore_unk=config.ignore_unk,
                                       length_average=config.length_average,
                                       use_gpu=config.use_gpu)
        self.BOS = self.generator.BOS
        self.EOS = self.generator.EOS
        self.stoi = self.corpus.SRC.stoi
        self.itos = self.corpus.SRC.itos
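
A note on the freqs mapping in __init__ above: zip(vocab[::-1], range(len(vocab))) pairs the last vocabulary entry with 0, so if vocab.txt is sorted most-frequent-first, a larger value means a more frequent word. A minimal self-contained sketch (the three-word vocab is made up for illustration):

vocab = ['the', 'cat', 'sat']    # hypothetical contents of data/vocab.txt
freqs = dict(zip(vocab[::-1], range(len(vocab))))
assert freqs == {'sat': 0, 'cat': 1, 'the': 2}
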
def main():
    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    device = config.gpu
    torch.cuda.set_device(device)
    # Data definition
    corpus = KnowledgeCorpus(data_dir=config.data_dir, data_prefix=config.data_prefix,
                             min_freq=3, max_vocab_size=config.max_vocab_size,
                             min_len=config.min_len, max_len=config.max_len,
                             embed_file=config.embed_file, with_label=config.with_label,
                             share_vocab=config.share_vocab)
    corpus.load()
    if config.test and config.ckpt:
        corpus.reload(data_type='test')
    train_iter = corpus.create_batches(
        config.batch_size, "train", shuffle=True, device=device)
    valid_iter = corpus.create_batches(
        config.batch_size, "valid", shuffle=False, device=device)
    test_iter = corpus.create_batches(
        config.batch_size, "test", shuffle=False, device=device)
    # Model definition
    model = KnowledgeSeq2Seq(src_vocab_size=corpus.SRC.vocab_size,
                             tgt_vocab_size=corpus.TGT.vocab_size,
                             embed_size=config.embed_size, hidden_size=config.hidden_size,
                             padding_idx=corpus.padding_idx,
                             num_layers=config.num_layers, bidirectional=config.bidirectional,
                             attn_mode=config.attn, with_bridge=config.with_bridge,
                             tie_embedding=config.tie_embedding, dropout=config.dropout,
                             use_gpu=config.use_gpu, 
                             use_bow=config.use_bow, use_dssm=config.use_dssm,
                             use_pg=config.use_pg, use_gs=config.use_gs,
                             pretrain_epoch=config.pretrain_epoch,
                             use_posterior=config.use_posterior,
                             weight_control=config.weight_control,
                             concat=config.decode_concat)
    model_name = model.__class__.__name__
    # Generator definition
    generator = TopKGenerator(model=model,
                              src_field=corpus.SRC, tgt_field=corpus.TGT, cue_field=corpus.CUE,
                              max_length=config.max_dec_len, ignore_unk=config.ignore_unk,
                              length_average=config.length_average, use_gpu=config.use_gpu)
    # Interactive generation testing
    if config.interact and config.ckpt:
        model.load(config.ckpt)
        return generator
    # Testing
    elif config.test and config.ckpt:
        print(model)
        model.load(config.ckpt)
        print("Testing ...")
        metrics, scores = evaluate(model, test_iter)
        print(metrics.report_cum())
        print("Generating ...")
        evaluate_generation(generator, test_iter, save_file=config.gen_file, verbos=True)
    else:
        # Load word embeddings
        if config.use_embed and config.embed_file is not None:
            model.encoder.embedder.load_embeddings(
                corpus.SRC.embeddings, scale=0.03)
            model.decoder.embedder.load_embeddings(
                corpus.TGT.embeddings, scale=0.03)
        # Optimizer definition
        optimizer = getattr(torch.optim, config.optimizer)(
            model.parameters(), lr=config.lr)
        # Learning rate scheduler
        if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, 
                            factor=config.lr_decay, patience=1, verbose=True, min_lr=1e-5)
        else:
            lr_scheduler = None
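
        # How the scheduler is typically driven (a sketch; the actual call
        # presumably lives inside Trainer's validation loop):
        #     val_loss = ...                      # hypothetical validation loss
        #     if lr_scheduler is not None:
        #         lr_scheduler.step(val_loss)     # decays lr by `factor` after `patience` stalled checks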
        # Save directory
        date_str, time_str = datetime.now().strftime("%Y%m%d-%H%M%S").split("-")
        result_str = "{}-{}".format(model_name, time_str)
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)
        # Logger definition
        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)
        # Save config
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        print("Saved params to '{}'".format(params_file))
        logger.info(model)
        # Train
        logger.info("Training starts ...")
        trainer = Trainer(model=model, optimizer=optimizer, train_iter=train_iter,
                          valid_iter=valid_iter, logger=logger, generator=generator,
                          valid_metric_name="-loss", num_epochs=config.num_epochs,
                          save_dir=config.save_dir, log_steps=config.log_steps,
                          valid_steps=config.valid_steps, grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler, save_summary=False)
        if config.ckpt is not None:
            trainer.load(file_prefix=config.ckpt)
        trainer.train()
        logger.info("Training done!")
        # Test
        logger.info("")
        trainer.load(os.path.join(config.save_dir, "best"))
        logger.info("Testing starts ...")
        metrics, scores = evaluate(model, test_iter)
        logger.info(metrics.report_cum())
        logger.info("Generation starts ...")
        test_gen_file = os.path.join(config.save_dir, "test.result")
        evaluate_generation(generator, test_iter, save_file=test_gen_file, verbos=True)
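
One detail of this main() worth isolating: the optimizer class is looked up dynamically by name with getattr(torch.optim, ...). A standalone sketch of the same pattern (the model and optimizer name below are made up for illustration):

import torch

net = torch.nn.Linear(4, 2)                  # hypothetical stand-in model
opt_name = "Adam"                            # e.g. the value of config.optimizer
optimizer = getattr(torch.optim, opt_name)(net.parameters(), lr=1e-3)
print(type(optimizer).__name__)              # -> Adam
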
Example #3
class Model:
    """
    This is an example model. It reads a predefined dictionary and predicts a fixed
    distribution. For a correct evaluation, each team should implement the following
    functions:
        next_word_probability
        gen_response
    """
    def __init__(self):
        """
        Init whatever you need here
        """
        vocab_file = 'data/vocab.txt'
        with codecs.open(vocab_file, 'r', 'utf-8') as f:
            vocab = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
        self.vocab = vocab
        self.freqs = dict(zip(self.vocab[::-1], range(len(self.vocab))))

        # Our code is as follows
        config = Config()
        torch.cuda.set_device(device=config.gpu)
        self.config = config

        # Data definition
        self.corpus = KnowledgeCorpus(data_dir=config.data_dir,
                                      data_prefix=config.data_prefix,
                                      min_freq=0,
                                      max_vocab_size=config.max_vocab_size,
                                      vocab_file=config.vocab_file,
                                      min_len=config.min_len,
                                      max_len=config.max_len,
                                      embed_file=config.embed_file,
                                      share_vocab=config.share_vocab)
        # Model definition
        self.model = Seq2Seq(src_vocab_size=self.corpus.SRC.vocab_size,
                             tgt_vocab_size=self.corpus.TGT.vocab_size,
                             embed_size=config.embed_size,
                             hidden_size=config.hidden_size,
                             padding_idx=self.corpus.padding_idx,
                             num_layers=config.num_layers,
                             bidirectional=config.bidirectional,
                             attn_mode=config.attn,
                             with_bridge=config.with_bridge,
                             tie_embedding=config.tie_embedding,
                             dropout=config.dropout,
                             use_gpu=config.use_gpu)
        print(self.model)
        self.model.load(config.ckpt)

        # Generator definition
        self.generator = TopKGenerator(model=self.model,
                                       src_field=self.corpus.SRC,
                                       tgt_field=self.corpus.TGT,
                                       cue_field=self.corpus.CUE,
                                       beam_size=config.beam_size,
                                       max_length=config.max_dec_len,
                                       ignore_unk=config.ignore_unk,
                                       length_average=config.length_average,
                                       use_gpu=config.use_gpu)
        self.BOS = self.generator.BOS
        self.EOS = self.generator.EOS
        self.stoi = self.corpus.SRC.stoi
        self.itos = self.corpus.SRC.itos

    def next_word_probability(self, context, partial_out):
        """
        Return a probability distribution over the next words given a partial true output.
        This is used to calculate the per-word perplexity.

        :param context: dict, the context containing the dialogue history and the
                        personal profile of each speaker.
                        This dict contains the following keys:

                        context['dialog']: a list of strings, the dialogue history (tokens in
                                           each utterance are separated by spaces).
                        context['uid']: a list of ints, indices into the speakers' profiles
                        context['profile']: a list of dicts, the personal profile of each speaker
                        context['responder_profile']: dict, the personal profile of the responder

        :param partial_out: list, the previous "true" output words
        :return: a tuple: the first element is a dict mapping each word to a probability
                 score (unset keys are assumed to have probability zero); the second
                 element is the probability of the EOS token.

        e.g.
        context:
        { "dialog": [ ["How are you ?"], ["I am fine , thank you . And you ?"] ],
          "uid": [0, 1],
          "profile":[ { "loc":"Beijing", "gender":"male", "tag":"" },
                      { "loc":"Shanghai", "gender":"female", "tag":"" } ],
          "responder_profile":{ "loc":"Beijing", "gender":"male", "tag":"" }
        }

        partial_out:
        ['I', 'am']

        ==>  {'fine': 0.9}, 0.1
        """
        test_raw = self.read_data(context)
        test_data = self.corpus.build_examples(test_raw, data_type='test')
        dataset = Dataset(test_data)
        data_iter = dataset.create_batches(batch_size=1,
                                           shuffle=False,
                                           device=self.config.gpu)
        # Take the single batch produced for this context
        inputs = next(iter(data_iter))

        partial_out_idx = [self.stoi.get(s, self.stoi['<unk>'])
                           for s in partial_out]

        # switch the model to evaluate mode
        self.model.eval()
        with torch.no_grad():
            enc_outputs, dec_init_state = self.model.encode(inputs)
            long_tensor_type = torch.cuda.LongTensor if self.config.use_gpu else torch.LongTensor

            # Initialize the input vector
            input_var = long_tensor_type([self.BOS])
            # Inflate the initial hidden states to be of size: (1, H)
            dec_state = dec_init_state.inflate(1)

            # Teacher-force the partial output: feed BOS first, then each true token
            for idx in partial_out_idx:
                # Run the RNN one step forward
                output, dec_state, attn = self.model.decode(
                    input_var, dec_state)
                input_var = long_tensor_type([idx])

            output, dec_state, attn = self.model.decode(input_var, dec_state)
            log_softmax_output = output.squeeze(1)
        log_softmax_output = log_softmax_output.cpu().numpy()
        prob_output = [math.exp(i) for i in log_softmax_output[0]]

        # The first 4 tokens are: '<pad>' '<unk>' '<bos>' '<eos>'
        freq_dict = {}
        for i in range(4, len(self.itos)):
            freq_dict[self.itos[i]] = prob_output[i]
        eos_prob = prob_output[3]
        return freq_dict, eos_prob

    def gen_response(self, contexts):
        """
        Return a list of responses to each context.

        :param contexts: list, a list of contexts; each context is a dict that contains
                         the dialogue history and the personal profile of each speaker.
                         This dict contains the following keys:

                         context['dialog']: a list of strings, the dialogue history (tokens in
                                            each utterance are separated by spaces).
                         context['uid']: a list of ints, indices into the speakers' profiles
                         context['profile']: a list of dicts, the personal profile of each speaker
                         context['responder_profile']: dict, the personal profile of the responder

        :return: list, one response per context; each response is a list of tokens.

        e.g.
        contexts:
        [{ "dialog": [ ["How are you ?"], ["I am fine , thank you . And you ?"] ],
          "uid": [0, 1],
          "profile":[ { "loc":"Beijing", "gender":"male", "tag":"" },
                      { "loc":"Shanghai", "gender":"female", "tag":"" } ],
          "responder_profile":{ "loc":"Beijing", "gender":"male", "tag":"" }
        }]

        ==>  [['I', 'am', 'fine', 'too', '!']]
        """
        test_raw = self.read_data(contexts[0])
        test_data = self.corpus.build_examples(test_raw, data_type='test')
        dataset = Dataset(test_data)
        data_iter = dataset.create_batches(batch_size=1,
                                           shuffle=False,
                                           device=self.config.gpu)
        results = self.generator.generate(batch_iter=data_iter)
        res = [result.preds[0].split(" ") for result in results]
        return res

    @staticmethod
    def read_data(dialog):
        history = dialog["dialog"]
        uid = [int(i) for i in dialog["uid"]]
        if "responder_profile" in dialog.keys():
            responder_profile = dialog["responder_profile"]
        elif "response_profile" in dialog.keys():
            responder_profile = dialog["response_profile"]
        else:
            raise ValueError("No responder_profile or response_profile!")

        src = ""
        for idx, sent in zip(uid, history):
            sent_content = sent[0]
            src += sent_content
            src += ' '

        src = src.strip()
        tgt = "NULL"
        filter_knowledge = []
        if type(responder_profile["tag"]) is list:
            filter_knowledge.append(' '.join(
                responder_profile["tag"][0].split(';')))
        else:
            filter_knowledge.append(' '.join(
                responder_profile["tag"].split(';')))
        filter_knowledge.append(responder_profile["loc"])
        data = [{'src': src, 'tgt': tgt, 'cue': filter_knowledge}]
        return data
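
A hypothetical driver tying the two required hooks together, reusing the example context from the docstrings (it assumes the data, vocab, and checkpoint files referenced by Config are in place; _demo is not part of the original snippet):

def _demo():
    context = {
        "dialog": [["How are you ?"], ["I am fine , thank you . And you ?"]],
        "uid": [0, 1],
        "profile": [{"loc": "Beijing", "gender": "male", "tag": ""},
                    {"loc": "Shanghai", "gender": "female", "tag": ""}],
        "responder_profile": {"loc": "Beijing", "gender": "male", "tag": ""}
    }
    model = Model()
    word_probs, eos_prob = model.next_word_probability(context, ["I", "am"])
    print(max(word_probs, key=word_probs.get), eos_prob)   # most likely next word
    print(model.gen_response([context]))                   # e.g. [['I', 'am', 'fine', 'too', '!']]
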
def main():
    """
    main
    """
    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    device = config.gpu
    torch.cuda.set_device(device)
    # Data definition
    if config.pos:
        corpus = Entity_Corpus_pos(data_dir=config.data_dir, data_prefix=config.data_prefix,
                                   entity_file=config.entity_file,
                                   min_freq=config.min_freq, max_vocab_size=config.max_vocab_size)
    else:
        corpus = Entity_Corpus(data_dir=config.data_dir, data_prefix=config.data_prefix,
                               entity_file=config.entity_file,
                               min_freq=config.min_freq, max_vocab_size=config.max_vocab_size)

    corpus.load()
    if config.test and config.ckpt:
        corpus.reload(data_type='test')
    train_iter = corpus.create_batches(
        config.batch_size, "train", shuffle=True, device=device)
    valid_iter = corpus.create_batches(
        config.batch_size, "valid", shuffle=False, device=device)
    if config.for_test:
        test_iter = corpus.create_batches(
            config.batch_size, "test", shuffle=False, device=device)
    else:
        test_iter = corpus.create_batches(
            config.batch_size, "valid", shuffle=False, device=device)
    if config.preprocess:
        print('Preprocessing finished')
        return

    if config.pos:
        if config.rnn_type == 'lstm':
            model = Entity_Seq2Seq_pos(src_vocab_size=corpus.SRC.vocab_size,
                                       pos_vocab_size=corpus.POS.vocab_size,
                                       embed_size=config.embed_size, hidden_size=config.hidden_size,
                                       padding_idx=corpus.padding_idx,
                                       num_layers=config.num_layers, bidirectional=config.bidirectional,
                                       attn_mode=config.attn, with_bridge=config.with_bridge,
                                       dropout=config.dropout,
                                       use_gpu=config.use_gpu,
                                       pretrain_epoch=config.pretrain_epoch)
        else:
            model = Entity_Seq2Seq_pos_gru(src_vocab_size=corpus.SRC.vocab_size,
                                           pos_vocab_size=corpus.POS.vocab_size,
                                           embed_size=config.embed_size, hidden_size=config.hidden_size,
                                           padding_idx=corpus.padding_idx,
                                           num_layers=config.num_layers, bidirectional=config.bidirectional,
                                           attn_mode=config.attn, with_bridge=config.with_bridge,
                                           dropout=config.dropout,
                                           use_gpu=config.use_gpu,
                                           pretrain_epoch=config.pretrain_epoch)
    else:
        if config.rnn_type == 'lstm':
            if config.elmo:
                model = Entity_Seq2Seq_elmo(src_vocab_size=corpus.SRC.vocab_size,
                                            embed_size=config.embed_size, hidden_size=config.hidden_size,
                                            padding_idx=corpus.padding_idx,
                                            num_layers=config.num_layers, bidirectional=config.bidirectional,
                                            attn_mode=config.attn, with_bridge=config.with_bridge,
                                            dropout=config.dropout,
                                            use_gpu=config.use_gpu,
                                            pretrain_epoch=config.pretrain_epoch,
                                            batch_size=config.batch_size)
            else:
                model = Entity_Seq2Seq(src_vocab_size=corpus.SRC.vocab_size,
                                       embed_size=config.embed_size, hidden_size=config.hidden_size,
                                       padding_idx=corpus.padding_idx,
                                       num_layers=config.num_layers, bidirectional=config.bidirectional,
                                       attn_mode=config.attn, with_bridge=config.with_bridge,
                                       dropout=config.dropout,
                                       use_gpu=config.use_gpu,
                                       pretrain_epoch=config.pretrain_epoch)
        else:  # GRU
            if config.elmo:
                model = Entity_Seq2Seq_elmo_gru(src_vocab_size=corpus.SRC.vocab_size,
                                                embed_size=config.embed_size, hidden_size=config.hidden_size,
                                                padding_idx=corpus.padding_idx,
                                                num_layers=config.num_layers, bidirectional=config.bidirectional,
                                                attn_mode=config.attn, with_bridge=config.with_bridge,
                                                dropout=config.dropout,
                                                use_gpu=config.use_gpu,
                                                pretrain_epoch=config.pretrain_epoch,
                                                batch_size=config.batch_size)
            else:
                # No plain-GRU (non-ELMo) variant is defined in this snippet; fail
                # fast instead of hitting an UnboundLocalError on `model` below.
                raise NotImplementedError("rnn_type='gru' without ELMo is not supported")


    model_name = model.__class__.__name__
    # Generator definition

    generator = TopKGenerator(model=model,
                              src_field=corpus.SRC,
                              beam_size=config.beam_size,
                              max_length=config.max_dec_len, ignore_unk=config.ignore_unk,
                              length_average=config.length_average, use_gpu=config.use_gpu)
    # Interactive generation testing
    if config.interact and config.ckpt:
        model.load(config.ckpt)
        return generator
    # Testing
    elif config.test and config.ckpt:
        print(model)
        model.load(config.ckpt)
        print("Testing ...")
        metrics = evaluate(model, valid_iter)
        print(metrics.report_cum())
        print("Generating ...")
        if config.for_test:
            evaluate_generation(generator, test_iter, save_file=config.gen_file, verbos=True, for_test=True)
        else:
            evaluate_generation(generator, test_iter, save_file=config.gen_file, verbos=True)
    else:
        # Load word embeddings
        if config.saved_embed is not None:
            model.encoder.embedder.load_embeddings(
                config.saved_embed, scale=0.03)
        # Optimizer definition
        # Optimize only the parameters that still require gradients
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = getattr(torch.optim, config.optimizer)(
            params, lr=config.lr, weight_decay=config.weight_decay)
        # Learning rate scheduler
        if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, 
                            factor=config.lr_decay, patience=1, verbose=True, min_lr=1e-5)
        else:
            lr_scheduler = None
        # Save directory
        date_str, time_str = datetime.now().strftime("%Y%m%d-%H%M%S").split("-")
        result_str = "{}-{}".format(model_name, time_str)
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)
        # Logger definition
        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)
        # Save config
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        print("Saved params to '{}'".format(params_file))
        logger.info(model)
        # Train
        logger.info("Training starts ...")
        trainer = Trainer(model=model, optimizer=optimizer, train_iter=train_iter,
                          valid_iter=valid_iter, logger=logger, generator=generator,
                          valid_metric_name="acc", num_epochs=config.num_epochs,
                          save_dir=config.save_dir, log_steps=config.log_steps,
                          valid_steps=config.valid_steps, grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler, save_summary=False)
        if config.ckpt is not None:
            trainer.load(file_prefix=config.ckpt)
        trainer.train()
        logger.info("Training done!")
        # Test
        logger.info("")
        trainer.load(os.path.join(config.save_dir, "best"))
        logger.info("Testing starts ...")
        metrics, scores = evaluate(model, test_iter)
        logger.info(metrics.report_cum())
        logger.info("Generation starts ...")
        test_gen_file = os.path.join(config.save_dir, "test.result")
        evaluate_generation(generator, test_iter, save_file=test_gen_file, verbos=True)