def main():
    """
    Build the corpus, a KnowledgeSeq2Seq model and a TopKGenerator from the
    parsed configuration, then run one of three modes:

    * ``config.interact`` with ``config.ckpt``: load the checkpoint into the
      model and return the generator.
    * ``config.test`` with ``config.ckpt``: load the checkpoint, evaluate on
      the "test" batches and write generations to ``config.gen_file``.
    * otherwise: train (resuming from ``config.ckpt`` when it is set), load
      ``<save_dir>/best`` through the trainer, then evaluate and generate on
      the "test" batches.

    Returns:
        The TopKGenerator in interactive mode, otherwise None.

    NOTE(review): this file contains a second ``def main()`` further down;
    the later definition rebinds the name, so this one is not reachable as
    ``main`` once the module has been fully executed.
    """
    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0

    device = config.gpu
    if config.use_gpu:
        torch.cuda.set_device(device)
    elif device >= 0:
        # A GPU index was requested but CUDA is unavailable. Selecting the
        # device would fail, so fall back to a negative index, mirroring the
        # "gpu >= 0 means GPU" convention of the use_gpu check above.
        # NOTE(review): assumes create_batches treats a negative device as
        # CPU -- confirm against its implementation.
        device = -1

    # Data definition
    corpus = KnowledgeCorpus(data_dir=config.data_dir, data_prefix=config.data_prefix,
                             min_freq=3, max_vocab_size=config.max_vocab_size,
                             min_len=config.min_len, max_len=config.max_len,
                             embed_file=config.embed_file, with_label=config.with_label,
                             share_vocab=config.share_vocab)
    corpus.load()
    if config.test and config.ckpt:
        corpus.reload(data_type='test')
    train_iter = corpus.create_batches(
        config.batch_size, "train", shuffle=True, device=device)
    valid_iter = corpus.create_batches(
        config.batch_size, "valid", shuffle=False, device=device)
    test_iter = corpus.create_batches(
        config.batch_size, "test", shuffle=False, device=device)

    # Model definition
    model = KnowledgeSeq2Seq(src_vocab_size=corpus.SRC.vocab_size,
                             tgt_vocab_size=corpus.TGT.vocab_size,
                             embed_size=config.embed_size, hidden_size=config.hidden_size,
                             padding_idx=corpus.padding_idx,
                             num_layers=config.num_layers, bidirectional=config.bidirectional,
                             attn_mode=config.attn, with_bridge=config.with_bridge,
                             tie_embedding=config.tie_embedding, dropout=config.dropout,
                             use_gpu=config.use_gpu,
                             use_bow=config.use_bow, use_dssm=config.use_dssm,
                             use_pg=config.use_pg, use_gs=config.use_gs,
                             pretrain_epoch=config.pretrain_epoch,
                             use_posterior=config.use_posterior,
                             weight_control=config.weight_control,
                             concat=config.decode_concat)

    # Generator definition
    generator = TopKGenerator(model=model,
                              src_field=corpus.SRC, tgt_field=corpus.TGT, cue_field=corpus.CUE,
                              max_length=config.max_dec_len, ignore_unk=config.ignore_unk,
                              length_average=config.length_average, use_gpu=config.use_gpu)

    # Interactive generation testing
    if config.interact and config.ckpt:
        model.load(config.ckpt)
        return generator

    # Testing
    if config.test and config.ckpt:
        print(model)
        model.load(config.ckpt)
        print("Testing ...")
        metrics, _ = evaluate(model, test_iter)
        print(metrics.report_cum())
        print("Generating ...")
        evaluate_generation(generator, test_iter, save_file=config.gen_file, verbos=True)
        return None

    # Training
    # Load word embeddings
    if config.use_embed and config.embed_file is not None:
        model.encoder.embedder.load_embeddings(
            corpus.SRC.embeddings, scale=0.03)
        model.decoder.embedder.load_embeddings(
            corpus.TGT.embeddings, scale=0.03)

    # Optimizer definition: config.optimizer names a class in torch.optim
    optimizer = getattr(torch.optim, config.optimizer)(
        model.parameters(), lr=config.lr)

    # Learning rate scheduler, only for a decay factor strictly inside (0, 1)
    if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, factor=config.lr_decay,
            patience=1, verbose=True, min_lr=1e-5)
    else:
        lr_scheduler = None

    # Save directory (exist_ok avoids the check-then-create race)
    os.makedirs(config.save_dir, exist_ok=True)

    # Logger definition: messages go to the root handler and to train.log
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
    logger.addHandler(fh)

    # Save config
    params_file = os.path.join(config.save_dir, "params.json")
    with open(params_file, 'w') as fp:
        json.dump(config.__dict__, fp, indent=4, sort_keys=True)
    print("Saved params to '{}'".format(params_file))
    logger.info(model)

    # Train
    logger.info("Training starts ...")
    trainer = Trainer(model=model, optimizer=optimizer, train_iter=train_iter,
                      valid_iter=valid_iter, logger=logger, generator=generator,
                      valid_metric_name="-loss", num_epochs=config.num_epochs,
                      save_dir=config.save_dir, log_steps=config.log_steps,
                      valid_steps=config.valid_steps, grad_clip=config.grad_clip,
                      lr_scheduler=lr_scheduler, save_summary=False)
    if config.ckpt is not None:
        trainer.load(file_prefix=config.ckpt)
    trainer.train()
    logger.info("Training done!")

    # Test with the checkpoint stored under the "best" prefix in save_dir
    logger.info("")
    trainer.load(os.path.join(config.save_dir, "best"))
    logger.info("Testing starts ...")
    metrics, _ = evaluate(model, test_iter)
    logger.info(metrics.report_cum())
    logger.info("Generation starts ...")
    test_gen_file = os.path.join(config.save_dir, "test.result")
    evaluate_generation(generator, test_iter, save_file=test_gen_file, verbos=True)
    return None
def _build_entity_model(config, corpus):
    """
    Instantiate the entity seq2seq variant selected by ``config.pos``,
    ``config.rnn_type`` and ``config.elmo``.

    Raises:
        ValueError: for a non-'lstm' ``rnn_type`` with neither ``pos`` nor
            ``elmo`` enabled; no model class is wired up for that combination.
    """
    # Keyword arguments passed to every variant.
    shared_kwargs = dict(
        src_vocab_size=corpus.SRC.vocab_size,
        embed_size=config.embed_size,
        hidden_size=config.hidden_size,
        padding_idx=corpus.padding_idx,
        num_layers=config.num_layers,
        bidirectional=config.bidirectional,
        attn_mode=config.attn,
        with_bridge=config.with_bridge,
        dropout=config.dropout,
        use_gpu=config.use_gpu,
        pretrain_epoch=config.pretrain_epoch,
    )
    use_lstm = config.rnn_type == 'lstm'

    if config.pos:
        # The POS variants additionally need the POS vocabulary size;
        # config.elmo is not consulted on this path.
        model_cls = Entity_Seq2Seq_pos if use_lstm else Entity_Seq2Seq_pos_gru
        return model_cls(pos_vocab_size=corpus.POS.vocab_size, **shared_kwargs)

    if config.elmo:
        # The ELMo variants additionally receive the batch size.
        model_cls = Entity_Seq2Seq_elmo if use_lstm else Entity_Seq2Seq_elmo_gru
        return model_cls(batch_size=config.batch_size, **shared_kwargs)

    if use_lstm:
        return Entity_Seq2Seq(**shared_kwargs)

    # Previously this combination left the model unassigned and crashed
    # later with an UnboundLocalError.
    raise ValueError(
        "Unsupported model configuration: rnn_type={!r} requires "
        "config.pos or config.elmo to be enabled".format(config.rnn_type))


def _extract_metrics(eval_result):
    """
    Return the metrics object from the result of ``evaluate``.

    NOTE(review): the call sites in this file disagreed on whether
    ``evaluate`` returns the metrics object alone or a ``(metrics, scores)``
    pair; both shapes are accepted here. Confirm the actual return shape and
    simplify.
    """
    if isinstance(eval_result, (tuple, list)):
        return eval_result[0]
    return eval_result


def main():
    """
    Build the entity corpus, the configured entity seq2seq model and a
    TopKGenerator, then run one of the following modes:

    * ``config.preprocess``: return right after the corpus has been loaded
      and batched.
    * ``config.interact`` with ``config.ckpt``: load the checkpoint into the
      model and return the generator.
    * ``config.test`` with ``config.ckpt``: load the checkpoint, evaluate on
      the validation batches and write generations to ``config.gen_file``.
    * otherwise: train (resuming from ``config.ckpt`` when it is set), load
      ``<save_dir>/best`` through the trainer, then evaluate and generate.

    Returns:
        The TopKGenerator in interactive mode, otherwise None.

    Raises:
        ValueError: if the configuration selects a model variant that does
            not exist (see ``_build_entity_model``).
    """
    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0

    device = config.gpu
    if config.use_gpu:
        torch.cuda.set_device(device)
    elif device >= 0:
        # A GPU index was requested but CUDA is unavailable. Selecting the
        # device would fail, so fall back to a negative index, mirroring the
        # "gpu >= 0 means GPU" convention of the use_gpu check above.
        # NOTE(review): assumes create_batches treats a negative device as
        # CPU -- confirm against its implementation.
        device = -1

    # Data definition
    corpus_cls = Entity_Corpus_pos if config.pos else Entity_Corpus
    corpus = corpus_cls(data_dir=config.data_dir, data_prefix=config.data_prefix,
                        entity_file=config.entity_file,
                        min_freq=config.min_freq, max_vocab_size=config.max_vocab_size)
    corpus.load()
    if config.test and config.ckpt:
        corpus.reload(data_type='test')
    train_iter = corpus.create_batches(
        config.batch_size, "train", shuffle=True, device=device)
    valid_iter = corpus.create_batches(
        config.batch_size, "valid", shuffle=False, device=device)
    # Without for_test the "test" iterator is built from the validation data.
    test_split = "test" if config.for_test else "valid"
    test_iter = corpus.create_batches(
        config.batch_size, test_split, shuffle=False, device=device)

    if config.preprocess:
        # Preprocessing-only run: stop before any model is built.
        # The message reads "preprocessing finished".
        print('预处理完毕')
        return None

    # Model definition
    model = _build_entity_model(config, corpus)

    # Generator definition
    generator = TopKGenerator(model=model,
                              src_field=corpus.SRC,
                              max_length=config.max_dec_len, ignore_unk=config.ignore_unk,
                              length_average=config.length_average, use_gpu=config.use_gpu,
                              beam_size=config.beam_size)

    # Interactive generation testing
    if config.interact and config.ckpt:
        model.load(config.ckpt)
        return generator

    # Testing
    if config.test and config.ckpt:
        print(model)
        model.load(config.ckpt)
        print("Testing ...")
        # Metrics are computed on the validation batches even in test mode.
        # NOTE(review): presumably because the test data carries no labels
        # -- confirm.
        metrics = _extract_metrics(evaluate(model, valid_iter))
        print(metrics.report_cum())
        print("Generating ...")
        # for_test is only passed when enabled, so the callee's own default
        # applies otherwise.
        generation_kwargs = {"for_test": True} if config.for_test else {}
        evaluate_generation(generator, test_iter, save_file=config.gen_file,
                            verbos=True, **generation_kwargs)
        return None

    # Training
    # Load word embeddings (encoder only)
    if config.saved_embed is not None:
        model.encoder.embedder.load_embeddings(
            config.saved_embed, scale=0.03)

    # Optimizer definition: only parameters that require gradients are
    # handed to the optimizer; config.optimizer names a class in torch.optim.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = getattr(torch.optim, config.optimizer)(
        trainable_params, lr=config.lr, weight_decay=config.weight_decay)

    # Learning rate scheduler, only for a decay factor strictly inside (0, 1)
    if config.lr_decay is not None and 0 < config.lr_decay < 1.0:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, factor=config.lr_decay,
            patience=1, verbose=True, min_lr=1e-5)
    else:
        lr_scheduler = None

    # Save directory (exist_ok avoids the check-then-create race)
    os.makedirs(config.save_dir, exist_ok=True)

    # Logger definition: messages go to the root handler and to train.log
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
    logger.addHandler(fh)

    # Save config
    params_file = os.path.join(config.save_dir, "params.json")
    with open(params_file, 'w') as fp:
        json.dump(config.__dict__, fp, indent=4, sort_keys=True)
    print("Saved params to '{}'".format(params_file))
    logger.info(model)

    # Train
    logger.info("Training starts ...")
    trainer = Trainer(model=model, optimizer=optimizer, train_iter=train_iter,
                      valid_iter=valid_iter, logger=logger, generator=generator,
                      valid_metric_name="acc", num_epochs=config.num_epochs,
                      save_dir=config.save_dir, log_steps=config.log_steps,
                      valid_steps=config.valid_steps, grad_clip=config.grad_clip,
                      lr_scheduler=lr_scheduler, save_summary=False)
    if config.ckpt is not None:
        trainer.load(file_prefix=config.ckpt)
    trainer.train()
    logger.info("Training done!")

    # Test with the checkpoint stored under the "best" prefix in save_dir
    logger.info("")
    trainer.load(os.path.join(config.save_dir, "best"))
    logger.info("Testing starts ...")
    metrics = _extract_metrics(evaluate(model, test_iter))
    logger.info(metrics.report_cum())
    logger.info("Generation starts ...")
    test_gen_file = os.path.join(config.save_dir, "test.result")
    evaluate_generation(generator, test_iter, save_file=test_gen_file, verbos=True)
    return None