def main():

    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    device = config.gpu
    if config.use_gpu:
        # Only bind a CUDA device when one is actually available.
        torch.cuda.set_device(device)
    # Data definition
    corpus = KnowledgeCorpus(data_dir=config.data_dir, data_prefix=config.data_prefix,
                             min_freq=3, max_vocab_size=config.max_vocab_size,
                             min_len=config.min_len, max_len=config.max_len,
                             embed_file=config.embed_file, with_label=config.with_label,
                             share_vocab=config.share_vocab)
    corpus.load()
    if config.test and config.ckpt:
        corpus.reload(data_type='test')
    train_iter = corpus.create_batches(
        config.batch_size, "train", shuffle=True, device=device)
    valid_iter = corpus.create_batches(
        config.batch_size, "valid", shuffle=False, device=device)
    test_iter = corpus.create_batches(
        config.batch_size, "test", shuffle=False, device=device)
    # Model definition
    model = KnowledgeSeq2Seq(src_vocab_size=corpus.SRC.vocab_size,
                             tgt_vocab_size=corpus.TGT.vocab_size,
                             embed_size=config.embed_size, hidden_size=config.hidden_size,
                             padding_idx=corpus.padding_idx,
                             num_layers=config.num_layers, bidirectional=config.bidirectional,
                             attn_mode=config.attn, with_bridge=config.with_bridge,
                             tie_embedding=config.tie_embedding, dropout=config.dropout,
                             use_gpu=config.use_gpu, 
                             use_bow=config.use_bow, use_dssm=config.use_dssm,
                             use_pg=config.use_pg, use_gs=config.use_gs,
                             pretrain_epoch=config.pretrain_epoch,
                             use_posterior=config.use_posterior,
                             weight_control=config.weight_control,
                             concat=config.decode_concat)
    model_name = model.__class__.__name__
    # Generator definition
    generator = TopKGenerator(model=model,
                              src_field=corpus.SRC, tgt_field=corpus.TGT, cue_field=corpus.CUE,
                              max_length=config.max_dec_len, ignore_unk=config.ignore_unk,
                              length_average=config.length_average, use_gpu=config.use_gpu)
    # Interactive generation testing
    if config.interact and config.ckpt:
        model.load(config.ckpt)
        return generator
    # Testing
    elif config.test and config.ckpt:
        print(model)
        model.load(config.ckpt)
        print("Testing ...")
        metrics, scores = evaluate(model, test_iter)
        print(metrics.report_cum())
        print("Generating ...")
        evaluate_generation(generator, test_iter, save_file=config.gen_file, verbos=True)
    else:
        # Load word embeddings
        if config.use_embed and config.embed_file is not None:
            model.encoder.embedder.load_embeddings(
                corpus.SRC.embeddings, scale=0.03)
            model.decoder.embedder.load_embeddings(
                corpus.TGT.embeddings, scale=0.03)
        # Optimizer definition
        optimizer = getattr(torch.optim, config.optimizer)(
            model.parameters(), lr=config.lr)
        # Learning rate scheduler
        if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer, factor=config.lr_decay, patience=1,
                verbose=True, min_lr=1e-5)
        else:
            lr_scheduler = None
        # Save directory
        date_str, time_str = datetime.now().strftime("%Y%m%d-%H%M%S").split("-")
        result_str = "{}-{}".format(model_name, time_str)
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)
        # Logger definition
        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)
        # Save config
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        print("Saved params to '{}'".format(params_file))
        logger.info(model)
        # Train
        logger.info("Training starts ...")
        trainer = Trainer(model=model, optimizer=optimizer, train_iter=train_iter,
                          valid_iter=valid_iter, logger=logger, generator=generator,
                          valid_metric_name="-loss", num_epochs=config.num_epochs,
                          save_dir=config.save_dir, log_steps=config.log_steps,
                          valid_steps=config.valid_steps, grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler, save_summary=False)
        if config.ckpt is not None:
            trainer.load(file_prefix=config.ckpt)
        trainer.train()
        logger.info("Training done!")
        # Test
        logger.info("")
        trainer.load(os.path.join(config.save_dir, "best"))
        logger.info("Testing starts ...")
        metrics, scores = evaluate(model, test_iter)
        logger.info(metrics.report_cum())
        logger.info("Generation starts ...")
        test_gen_file = os.path.join(config.save_dir, "test.result")
        evaluate_generation(generator, test_iter, save_file=test_gen_file, verbos=True)
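
Note on the valid_metric_name="-loss" argument used above: the leading minus
sign presumably marks a lower-is-better metric. A minimal sketch of how a
trainer might interpret it (hypothetical helper, not taken from the project):

def is_better(metric_name, current, best):
    # A leading '-' (e.g. "-loss") means lower is better; otherwise higher wins.
    if metric_name.startswith("-"):
        return current < best
    return current > best

assert is_better("-loss", 1.8, 2.1)      # loss went down: improved
assert not is_better("acc", 0.71, 0.74)  # accuracy went down: no improvement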
Example #2
File: main.py  Project: siat-nlp/TTOS
def main():
    """
    main
    """
    config = model_config()

    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    device = config.gpu
    if config.use_gpu:
        torch.cuda.set_device(device)
        # Quick sanity check that a tensor can actually be moved to the GPU.
        a = torch.tensor(1)
        a = a.cuda()
        print(a)

    # Special tokens definition
    special_tokens = ["<ENT>", "<NEN>"]

    # Data definition
    corpus = KnowledgeCorpus(data_dir=config.data_dir,
                             min_freq=0,
                             max_vocab_size=config.max_vocab_size,
                             min_len=config.min_len,
                             max_len=config.max_len,
                             embed_file=config.embed_file,
                             share_vocab=config.share_vocab,
                             special_tokens=special_tokens)

    corpus.load()

    # Model definition
    model_S = Seq2Seq(src_field=corpus.SRC,
                      tgt_field=corpus.TGT,
                      kb_field=corpus.KB,
                      embed_size=config.embed_size,
                      hidden_size=config.hidden_size,
                      padding_idx=corpus.padding_idx,
                      num_layers=config.num_layers,
                      bidirectional=config.bidirectional,
                      attn_mode=config.attn,
                      with_bridge=config.with_bridge,
                      tie_embedding=config.tie_embedding,
                      dropout=config.dropout,
                      max_hop=config.max_hop,
                      use_gpu=config.use_gpu)

    model_TB = Seq2Seq(src_field=corpus.SRC,
                       tgt_field=corpus.TGT,
                       kb_field=corpus.KB,
                       embed_size=config.embed_size,
                       hidden_size=config.hidden_size,
                       padding_idx=corpus.padding_idx,
                       num_layers=config.num_layers,
                       bidirectional=config.bidirectional,
                       attn_mode=config.attn,
                       with_bridge=config.with_bridge,
                       tie_embedding=config.tie_embedding,
                       dropout=config.dropout,
                       max_hop=config.max_hop,
                       use_gpu=config.use_gpu)

    model_TE = None if config.method == "1-1" else \
        Seq2Seq(src_field=corpus.SRC, tgt_field=corpus.TGT,
                kb_field=corpus.KB, embed_size=config.embed_size,
                hidden_size=config.hidden_size, padding_idx=corpus.padding_idx,
                num_layers=config.num_layers, bidirectional=config.bidirectional,
                attn_mode=config.attn, with_bridge=config.with_bridge,
                tie_embedding=config.tie_embedding, dropout=config.dropout,
                max_hop=config.max_hop, use_gpu=config.use_gpu)

    # Generator definition (note: each generator decodes with a single model here;
    # TODO: consider ensembling later)
    generator_S = BeamGenerator(model=model_S,
                                src_field=corpus.SRC,
                                tgt_field=corpus.TGT,
                                kb_field=corpus.KB,
                                beam_size=config.beam_size,
                                max_length=config.max_dec_len,
                                ignore_unk=config.ignore_unk,
                                length_average=config.length_average,
                                use_gpu=config.use_gpu)

    generator_TB = BeamGenerator(model=model_TB,
                                 src_field=corpus.SRC,
                                 tgt_field=corpus.TGT,
                                 kb_field=corpus.KB,
                                 beam_size=config.beam_size,
                                 max_length=config.max_dec_len,
                                 ignore_unk=config.ignore_unk,
                                 length_average=config.length_average,
                                 use_gpu=config.use_gpu)

    generator_TE = None if config.method == "1-1" else \
                   BeamGenerator(model=model_TE, src_field=corpus.SRC, tgt_field=corpus.TGT,
                                 kb_field=corpus.KB, beam_size=config.beam_size, max_length=config.max_dec_len,
                                 ignore_unk=config.ignore_unk, length_average=config.length_average,
                                 use_gpu=config.use_gpu)

    # Discriminator definition
    discriminator_B = Discriminator(input_size=corpus.TGT.vocab_size,
                                    hidden_size=config.hidden_size,
                                    use_gpu=config.use_gpu)
    discriminator_E = Discriminator(input_size=corpus.TGT.vocab_size,
                                    hidden_size=config.hidden_size,
                                    use_gpu=config.use_gpu)

    # Multi-agent definition
    muti_agent = Muti_Agent(data_name=config.data_name,
                            ent_idx=corpus.ent_idx,
                            nen_idx=corpus.nen_idx,
                            model_S=model_S,
                            model_TB=model_TB,
                            model_TE=model_TE,
                            lambda_g=config.lambda_g,
                            lambda_s=config.lambda_s,
                            lambda_tb=config.lambda_tb,
                            lambda_te=config.lambda_te,
                            generator_S=generator_S,
                            generator_TB=generator_TB,
                            generator_TE=generator_TE,
                            discriminator_B=discriminator_B,
                            discriminator_E=discriminator_E,
                            use_gpu=config.use_gpu)

    # Testing (by default only model_S is tested)
    if config.test and config.ckpt:
        test_iter = corpus.create_batches(config.batch_size,
                                          data_type="test",
                                          shuffle=False)

        model_path = os.path.join(config.save_dir, config.ckpt)
        muti_agent.load(model_path)
        print("Testing ...")
        if config.test_model == "S":
            test_model, generator = muti_agent.model_S, generator_S
        elif config.test_model == "TB":
            test_model, generator = muti_agent.model_TB, generator_TB
        elif config.test_model == "TE":
            test_model, generator = muti_agent.model_TE, generator_TE
        else:
            print("Invaild test model and generator!")
            sys.exit(0)
        metrics = Trainer.evaluate(test_model, test_iter)
        print(metrics.report_cum())
        print("Generating ...")
        generator.generate(data_iter=test_iter,
                           save_file=config.save_file,
                           verbos=True)

    else:
        train_iter = corpus.create_batches(config.batch_size,
                                           data_type="train",
                                           shuffle=True)
        valid_iter = corpus.create_batches(config.batch_size,
                                           data_type="valid",
                                           shuffle=False)

        # Optimizer definition
        optimizerG = getattr(torch.optim,
                             config.optimizer)(model_S.parameters(),
                                               lr=config.lr)
        optimizerDB = getattr(torch.optim,
                              config.optimizer)(discriminator_B.parameters(),
                                                lr=config.lr)
        optimizerDE = getattr(torch.optim,
                              config.optimizer)(discriminator_E.parameters(),
                                                lr=config.lr)

        if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizerG,
                mode='min',
                factor=config.lr_decay,
                patience=config.patience,
                verbose=True,
                min_lr=1e-6)
        else:
            lr_scheduler = None

        # Save directory
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)

        # Logger definition
        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        logger.info("Saved params to '{}'".format(params_file))
        logger.info(muti_agent)

        # Training
        logger.info("Training starts ...")
        logger.info("Learning Approach: " + config.method)
        trainer = Trainer(model=muti_agent,
                          optimizer=(optimizerG, optimizerDB, optimizerDE),
                          train_iter=train_iter,
                          valid_iter=valid_iter,
                          logger=logger,
                          method=config.method,
                          valid_metric_name="-loss",
                          num_epochs=config.num_epochs,
                          pre_epochs=config.pre_epochs,
                          save_dir=config.save_dir,
                          pre_train_dir=config.pre_train_dir,
                          log_steps=config.log_steps,
                          valid_steps=config.valid_steps,
                          grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler)

        if config.ckpt:
            trainer.load(file_ckpt=config.ckpt)
        else:
            # No checkpoint of the whole model exists, i.e. we train from scratch,
            # so load the three individually pre-trained agents into the multi-agent model.
            trainer.load_per_agent(S_ckpt=config.s_ckpt,
                                   TE_ckpt=config.te_ckpt,
                                   TB_ckpt=config.tb_ckpt)
            # close_train((model_TE, model_TB))

        trainer.train()

        logger.info("Training done!")
Example #3
def main():
    """
    main
    """
    config = model_config()

    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    device = config.gpu
    if config.use_gpu:
        torch.cuda.set_device(device)

    # Data definition
    corpus = KnowledgeCorpus(data_dir=config.data_dir,
                             min_freq=0,
                             max_vocab_size=config.max_vocab_size,
                             min_len=config.min_len,
                             max_len=config.max_len,
                             embed_file=config.embed_file,
                             share_vocab=config.share_vocab)
    corpus.load()

    # Model definition
    model = Seq2Seq(src_field=corpus.SRC,
                    tgt_field=corpus.TGT,
                    kb_field=corpus.KB,
                    embed_size=config.embed_size,
                    hidden_size=config.hidden_size,
                    padding_idx=corpus.padding_idx,
                    num_layers=config.num_layers,
                    bidirectional=config.bidirectional,
                    attn_mode=config.attn,
                    with_bridge=config.with_bridge,
                    tie_embedding=config.tie_embedding,
                    dropout=config.dropout,
                    max_hop=config.max_hop,
                    use_gpu=config.use_gpu)

    # Generator definition
    generator = BeamGenerator(model=model,
                              src_field=corpus.SRC,
                              tgt_field=corpus.TGT,
                              kb_field=corpus.KB,
                              beam_size=config.beam_size,
                              max_length=config.max_dec_len,
                              ignore_unk=config.ignore_unk,
                              length_average=config.length_average,
                              use_gpu=config.use_gpu,
                              mode=config.mode)

    # Testing
    if config.test and config.ckpt:
        data_iter = corpus.create_batches(
            batch_size=1 if config.mode == 'test' else config.batch_size,
            data_type=config.mode,
            shuffle=False)
        model_path = os.path.join(config.save_dir, config.ckpt)
        model.load(model_path)
        print("Testing ...")
        if config.mode != 'test':
            metrics = Trainer.evaluate(model, data_iter)
            print(metrics.report_cum())
        print("Generating ...")
        generator.generate(data_iter=data_iter,
                           save_file=config.save_file,
                           verbos=True)

    else:
        train_iter = corpus.create_batches(config.batch_size,
                                           data_type="train",
                                           shuffle=True)
        valid_iter = corpus.create_batches(config.batch_size,
                                           data_type="valid",
                                           shuffle=False)

        # Load word embeddings if possible
        if config.use_embed and config.embed_file is not None:
            model.encoder.embedder.load_embeddings(
                corpus.SRC.embeddings,
                scale=0.03,
                trainable=config.train_embed)
            model.decoder.embedder.load_embeddings(
                corpus.TGT.embeddings,
                scale=0.03,
                trainable=config.train_embed)

        # Optimizer definition
        optimizer = getattr(torch.optim, config.optimizer)(model.parameters(),
                                                           lr=config.lr)

        if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer,
                mode='min',
                factor=config.lr_decay,
                patience=config.patience,
                verbose=True,
                min_lr=1e-6)
        else:
            lr_scheduler = None

        # Save directory
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)

        # Logger definition
        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        logger.info("Saved params to '{}'".format(params_file))
        logger.info(model)

        # Training
        logger.info("Training starts ...")
        trainer = Trainer(model=model,
                          optimizer=optimizer,
                          train_iter=train_iter,
                          valid_iter=valid_iter,
                          logger=logger,
                          valid_metric_name="-loss",
                          num_epochs=config.num_epochs,
                          pre_epochs=config.pre_epochs,
                          save_dir=config.save_dir,
                          log_steps=config.log_steps,
                          valid_steps=config.valid_steps,
                          grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler)
        if config.ckpt is not None:
            trainer.load(file_ckpt=config.ckpt)
        trainer.train()

        logger.info("Training done!")
def main():
    """
    main
    """
    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    device = config.gpu
    if config.use_gpu:
        torch.cuda.set_device(device)
    # Data definition
    if config.pos:
        corpus = Entity_Corpus_pos(data_dir=config.data_dir, data_prefix=config.data_prefix,
                                   entity_file=config.entity_file,
                                   min_freq=config.min_freq, max_vocab_size=config.max_vocab_size)
    else:
        corpus = Entity_Corpus(data_dir=config.data_dir, data_prefix=config.data_prefix,
                               entity_file=config.entity_file,
                               min_freq=config.min_freq, max_vocab_size=config.max_vocab_size)

    corpus.load()
    if config.test and config.ckpt:
        corpus.reload(data_type='test')
    train_iter = corpus.create_batches(
        config.batch_size, "train", shuffle=True, device=device)
    valid_iter = corpus.create_batches(
        config.batch_size, "valid", shuffle=False, device=device)
    if config.for_test:
        test_iter = corpus.create_batches(
            config.batch_size, "test", shuffle=False, device=device)
    else:
        test_iter = corpus.create_batches(
            config.batch_size, "valid", shuffle=False, device=device)
    if config.preprocess:
        print("Preprocessing done")
        return

    if config.pos:
        if config.rnn_type == 'lstm':
            model = Entity_Seq2Seq_pos(src_vocab_size=corpus.SRC.vocab_size,
                                       pos_vocab_size=corpus.POS.vocab_size,
                                       embed_size=config.embed_size, hidden_size=config.hidden_size,
                                       padding_idx=corpus.padding_idx,
                                       num_layers=config.num_layers, bidirectional=config.bidirectional,
                                       attn_mode=config.attn, with_bridge=config.with_bridge,
                                       dropout=config.dropout,
                                       use_gpu=config.use_gpu,
                                       pretrain_epoch=config.pretrain_epoch)
        else:
            model = Entity_Seq2Seq_pos_gru(src_vocab_size=corpus.SRC.vocab_size,
                                           pos_vocab_size=corpus.POS.vocab_size,
                                           embed_size=config.embed_size, hidden_size=config.hidden_size,
                                           padding_idx=corpus.padding_idx,
                                           num_layers=config.num_layers, bidirectional=config.bidirectional,
                                           attn_mode=config.attn, with_bridge=config.with_bridge,
                                           dropout=config.dropout,
                                           use_gpu=config.use_gpu,
                                           pretrain_epoch=config.pretrain_epoch)
    else:
        if config.rnn_type == 'lstm':
            if config.elmo:
                model = Entity_Seq2Seq_elmo(src_vocab_size=corpus.SRC.vocab_size,
                                            embed_size=config.embed_size, hidden_size=config.hidden_size,
                                            padding_idx=corpus.padding_idx,
                                            num_layers=config.num_layers, bidirectional=config.bidirectional,
                                            attn_mode=config.attn, with_bridge=config.with_bridge,
                                            dropout=config.dropout,
                                            use_gpu=config.use_gpu,
                                            pretrain_epoch=config.pretrain_epoch,
                                            batch_size=config.batch_size)
            else:
                model = Entity_Seq2Seq(src_vocab_size=corpus.SRC.vocab_size,
                                       embed_size=config.embed_size, hidden_size=config.hidden_size,
                                       padding_idx=corpus.padding_idx,
                                       num_layers=config.num_layers, bidirectional=config.bidirectional,
                                       attn_mode=config.attn, with_bridge=config.with_bridge,
                                       dropout=config.dropout,
                                       use_gpu=config.use_gpu,
                                       pretrain_epoch=config.pretrain_epoch)
        else:  # GRU
            if config.elmo:
                model = Entity_Seq2Seq_elmo_gru(src_vocab_size=corpus.SRC.vocab_size,
                                                embed_size=config.embed_size, hidden_size=config.hidden_size,
                                                padding_idx=corpus.padding_idx,
                                                num_layers=config.num_layers, bidirectional=config.bidirectional,
                                                attn_mode=config.attn, with_bridge=config.with_bridge,
                                                dropout=config.dropout,
                                                use_gpu=config.use_gpu,
                                                pretrain_epoch=config.pretrain_epoch,
                                                batch_size=config.batch_size)
            else:
                # No GRU model without ELMo is defined, so fail fast instead of
                # hitting a NameError on the unbound `model` below.
                raise NotImplementedError("rnn_type='gru' requires config.elmo")


    model_name = model.__class__.__name__
    # Generator definition

    generator = TopKGenerator(model=model,
                              src_field=corpus.SRC,
                              max_length=config.max_dec_len, ignore_unk=config.ignore_unk,
                              length_average=config.length_average, use_gpu=config.use_gpu,
                              beam_size=config.beam_size)
    # Interactive generation testing
    if config.interact and config.ckpt:
        model.load(config.ckpt)
        return generator
    # Testing
    elif config.test and config.ckpt:
        print(model)
        model.load(config.ckpt)
        print("Testing ...")
        metrics = evaluate(model, valid_iter)
        print(metrics.report_cum())
        print("Generating ...")
        if config.for_test:
            evaluate_generation(generator, test_iter, save_file=config.gen_file, verbos=True, for_test=True)
        else:
            evaluate_generation(generator, test_iter, save_file=config.gen_file, verbos=True)
    else:
        # Load word embeddings
        if config.saved_embed is not None:
            model.encoder.embedder.load_embeddings(
                config.saved_embed, scale=0.03)
        # Optimizer definition (a variant giving the embedder a halved learning
        # rate was tried here and left disabled)
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = getattr(torch.optim, config.optimizer)(
            params, lr=config.lr, weight_decay=config.weight_decay)
        # Learning rate scheduler
        if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer, factor=config.lr_decay, patience=1,
                verbose=True, min_lr=1e-5)
        else:
            lr_scheduler = None
        # Save directory
        date_str, time_str = datetime.now().strftime("%Y%m%d-%H%M%S").split("-")
        result_str = "{}-{}".format(model_name, time_str)
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)
        # Logger definition
        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)
        # Save config
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        print("Saved params to '{}'".format(params_file))
        logger.info(model)
        # Train
        logger.info("Training starts ...")
        trainer = Trainer(model=model, optimizer=optimizer, train_iter=train_iter,
                          valid_iter=valid_iter, logger=logger, generator=generator,
                          valid_metric_name="acc", num_epochs=config.num_epochs,
                          save_dir=config.save_dir, log_steps=config.log_steps,
                          valid_steps=config.valid_steps, grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler, save_summary=False)
        if config.ckpt is not None:
            trainer.load(file_prefix=config.ckpt)
        trainer.train()
        logger.info("Training done!")
        # Test
        logger.info("")
        trainer.load(os.path.join(config.save_dir, "best"))
        logger.info("Testing starts ...")
        metrics, scores = evaluate(model, test_iter)
        logger.info(metrics.report_cum())
        logger.info("Generation starts ...")
        test_gen_file = os.path.join(config.save_dir, "test.result")
        evaluate_generation(generator, test_iter, save_file=test_gen_file, verbos=True)
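
A condensed sketch of the optimizer-construction pattern shared by all four
mains: resolve the optimizer class by name via getattr(torch.optim, ...) and,
as in the last example, train only parameters with requires_grad set. The
helper and its arguments are illustrative, not taken from the projects:

import torch

def build_optimizer(model, opt_name="Adam", lr=1e-3, weight_decay=0.0):
    # Resolve e.g. "Adam", "SGD", "RMSprop" to the torch.optim class.
    opt_cls = getattr(torch.optim, opt_name)
    params = [p for p in model.parameters() if p.requires_grad]
    return opt_cls(params, lr=lr, weight_decay=weight_decay)

model = torch.nn.Linear(4, 2)
model.weight.requires_grad_(False)  # e.g. a frozen, pre-loaded embedding
optimizer = build_optimizer(model, "SGD", lr=0.01)
print(len(optimizer.param_groups[0]['params']))  # 1 -> only the bias trains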