Example #1
def extract_sentiment_words():
    # create vocabulary using wikitext2
    train_txt, _, _ = torchtext.datasets.WikiText2.splits(TEXT)
    TEXT.build_vocab(train_txt)

    start = time.time()
    x_train, y_train, x_val, y_val, rtrain, rtest = preprocess()
    end = time.time()

    print("PREPROCESSING TIME: {}".format(end - start))
    
    ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
    
    # FIXME set up batched examples for better generality
    # batch_size = 20
    # eval_batch_size = 10

    # configs
    emsize = 200 # embedding dimension
    nhid = 200 # feedforward dimension
    nlayers = 2 # n encoders
    nhead = 2 # multiattention heads
    dropout = 0.2 # the dropout value

    # initialize main torch vars
    model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    lr = 0.05 # learning rate

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
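    # with step_size=1 (given positionally as 1.0) and gamma=0.95, the learning rate
    # is multiplied by 0.95 each time scheduler.step() runs, i.e. once per epoch below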

    best_val_loss = float("inf")
    epochs = 50
    best_model = None
    
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train_model(x_train, y_train, model, criterion, optimizer, scheduler, epoch)
        val_loss = evaluate(x_val, y_val, rtest, model, criterion)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
            'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                        val_loss, math.exp(val_loss)))
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)  # deep-copy so later epochs don't overwrite the best weights (requires `import copy`)

        scheduler.step()
    
    # test_loss = evaluate(best_model, criterion, test_data)

    # print('=' * 89)
    # print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    #     test_loss, math.exp(test_loss)))
    # print('=' * 89)
    return best_model
Example #2
def main():
    print("Generating data...", end="")
    voc_size = args.vocab_sz
    inp = np.arange(2, voc_size, 2)
    tgt = np.arange(3, voc_size, 2)
    data_x, data_y = get_numbers(inp, tgt)
    train_len = int(len(data_x) * 0.9)
    train_x, val_x = data_x[:train_len], data_x[train_len:]
    train_y, val_y = data_y[:train_len], data_y[train_len:]
    print("Done")

    print("Setting model...", end="")
    model = TransformerModel(
        input_sz=voc_size,
        output_sz=voc_size,
        d_model=args.d_model,
        nhead=args.n_head,
        num_encoder_layers=args.n_encoder_layers,
        num_decoder_layers=args.n_decoder_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
    )
    if args.load_dir != ".":
        model.load_state_dict(flow.load(args.load_dir))
    model = to_cuda(model)
    criterion = to_cuda(nn.CrossEntropyLoss())

    optimizer = flow.optim.Adam(model.parameters(), lr=args.lr)
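    # `flow` is presumably OneFlow, whose optim/save/load API mirrors PyTorch's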
    print("Done")

    print("Training...")

    min_loss = 100
    for i in range(1, args.n_epochs + 1):
        epoch_loss = train(model, criterion, optimizer, train_x, train_y)
        epoch_loss_val = validation(model, criterion, val_x, val_y)
        print("epoch: {} train loss: {}".format(i, epoch_loss))
        print("epoch: {} val loss: {}".format(i, epoch_loss_val))
        if epoch_loss < min_loss:
            min_loss = epoch_loss  # track the best training loss so far
            if not os.path.exists(args.save_dir):
                os.mkdir(args.save_dir)
            else:
                shutil.rmtree(args.save_dir)
                assert not os.path.exists(args.save_dir)
                os.mkdir(args.save_dir)
            flow.save(model.state_dict(), args.save_dir)
        if i % 3 == 2:
            print(test(model, test_times=10))
Example #3
def main(model_name=None, hidden=64, nlayers=1):
    voc_size = 10000
    inp = arange(2, voc_size, 2)
    tgt = arange(3, voc_size, 2)
    batch_size = 128
    epochs = 30
    dataset = NumberLoader(inp, tgt)
    train_len = int(len(dataset) * 0.9)
    val_len = len(dataset) - train_len
    train_set, val_set = random_split(dataset, [train_len, val_len])
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1)
    model = TransformerModel(voc_size,
                             voc_size,
                             hidden=hidden,
                             nlayers=nlayers)
    if model_name is not None:
        model.load_state_dict(load(model_name))
    model = model.cuda()
    # optimizer = optim.SGD(model.parameters(), lr=0.5)
    optimizer = optim.Adam(model.parameters())
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = nn.CrossEntropyLoss()
    best_loss = 100
    for i in range(epochs):
        epoch_loss = train(model, criterion, optimizer, train_loader)
        epoch_loss_val = validation(model, criterion, val_loader)
        # scheduler.step()
        print("epoch: {} train loss: {}".format(i, epoch_loss))
        print("epoch: {} val loss: {}".format(i, epoch_loss_val))
        if epoch_loss_val < best_loss:
            best_loss = epoch_loss_val
            model_name = "model/model_{0:.5f}.pt".format(epoch_loss_val)
            save(model.state_dict(), model_name)
    return model_name
Example #4
def main(args):
    random_seed(args.seed)
    device = torch.device("cuda" if args.cuda else "cpu")

    corpus = data.Corpus(args.data)
    train_data = batchify(corpus.train, args.batch_size)
    val_data = batchify(corpus.valid, args.batch_size)
    test_data = batchify(corpus.test, args.batch_size)
    print('loaded data')
    print(f'number of unique tokens: {len(corpus.dictionary)}')

    ntokens = len(corpus.dictionary)
    if args.model == 'Transformer':
        model = TransformerModel(
            ntokens,
            args.emsize,
            args.nhead,
            args.nhid,
            args.nlayers,
            args.dropout).to(device)
    else:
        model = RNNModel(
            args.model,
            ntokens,
            args.emsize,
            args.nhid,
            args.nlayers,
            args.dropout,
            args.tied).to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # one optimizer step is taken per BPTT window, so steps_per_epoch is the number of windows
    steps_per_epoch = len(range(0, train_data.size(0) - 1, args.bptt))
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                    max_lr=0.001,
                                                    steps_per_epoch=steps_per_epoch,
                                                    epochs=args.epochs,
                                                    anneal_strategy='linear')
    print('initialized model and optimizer')
    train(args, model, optimizer, train_data, val_data, scheduler)
Example #5
def main():
    parser = argparse.ArgumentParser(description="Train GPT2 Model")
    parser.add_argument("--batch_size",
                        type=int,
                        default=4,
                        help="Specify batch size")
    parser.add_argument("--num_epoch",
                        type=int,
                        default=3,
                        help="Specify number of epochs")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=5e-5,
                        help="Specify AdamW learning rate")

    args = parser.parse_args()

    setup = models.trav_trans.dataset.Setup("output", "output/train_dps.txt",
                                            "output/train_ids.txt")

    layers = [1, 3, 6, 9]

    for l in layers:
        model = TransformerModel(
            len(setup.vocab.idx2vocab),
            CrossEntropyLoss(ignore_index=setup.vocab.pad_idx), l, 300, 1000,
            6, 1e-05)

        training_args = TrainingArgs(batch_size=args.batch_size,
                                     num_epoch=args.num_epoch,
                                     output_dir="output",
                                     optimizer=AdamW(model.parameters(),
                                                     lr=args.learning_rate),
                                     save_model_on_epoch=False,
                                     suffix=f"{l}-layers")

        trainer = Trainer(model, setup, training_args)

        trainer.train()
Example #6
def main():
    parser = argparse.ArgumentParser(description="Train GPT2 Model")
    parser.add_argument("--batch_size", type=int, default=4, help="Specify batch size")
    parser.add_argument("--num_epoch", type=int, default=3, help="Specify number of epochs")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Specify AdamW learning rate")

    args = parser.parse_args()

    tokenizer = Tokenizer.from_file("output/tokenizer.json")
    dataset = Dataset("output/train_rq4_dps.txt")

    model = TransformerModel(
        tokenizer.get_vocab_size(),
        CrossEntropyLoss(ignore_index=tokenizer.encode("[PAD]").ids[0]),  # skip [PAD] tokens in the loss
        6,
        300,
        1000,
        6,
        1e-05
    )

    training_args = TrainingArgs(
        batch_size = args.batch_size,
        num_epoch = args.num_epoch,
        output_dir = "output",
        optimizer = AdamW(model.parameters(), lr=args.learning_rate),
        save_model_on_epoch = False
    )

    trainer = Trainer(
        model,
        dataset,
        tokenizer,
        training_args
    )

    trainer.train()
Example #7
                             n_meds=n_meds,
                             n_covs=n_covs,
                             sequence_len=sequence_len,
                             emsize=emsize,
                             nhead=nhead,
                             nhid=nhid,
                             nlayers=nlayers,
                             n_mc_smps=n_mc_smps,
                             dropout=dropout).to(globals.device)

    print("data fully setup!")

    ### Training parameters
    criterion = nn.BCEWithLogitsLoss(reduction='sum')
    lr = 0.03
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

    ### Training
    best_val_loss = float("inf")
    epochs = 100
    best_model = None

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train()
        val_loss, _, _ = evaluate(model)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} '.
              format(epoch, (time.time() - epoch_start_time), val_loss))
        print('-' * 89)
Example #8
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("using GPU numbers {}".format(CONFIG.hyperparam.misc.gpu_ids))
    else:
        device = torch.device("cpu")
        print("using CPU")
    model = TransformerModel(
        CONFIG,
        vocab_size=len(tokenizer),
        bos_idx=tokenizer.bos_idx,
        pad_idx=tokenizer.pad_idx,
    )
    model = model.to(device)
    if CONFIG.hyperparam.optimization.name == "Adam":
        optimizer = optim.Adam(
            model.parameters(),
            lr=CONFIG.hyperparam.optimization.lr,
            betas=(
                CONFIG.hyperparam.optimization.beta1,
                CONFIG.hyperparam.optimization.beta2,
            ),
            weight_decay=CONFIG.hyperparam.optimization.weight_decay,
        )
    else:
        raise NotImplementedError("only Adam implemented")
    #########################################################

    ################# evaluator, saver ######################
    print("loading evaluator and model saver...")
    evaluator = NLGEval(no_skipthoughts=True, no_glove=True)
    # evaluator = NLGEval(metrics_to_omit=["METEOR"])
Example #9

if args.restart:
    # Resume training from checkpoint
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        model = torch.load(f)
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    # Train from the start
    model = TransformerModel(ntokens, args.d_model, args.n_head, args.d_inner,
                             args.n_layer, args.dropout)

    for p in model.parameters():
        p.requires_grad_(True)
    model.train()

    model.apply(weights_init)
args.n_all_param = sum([p.nelement() for p in model.parameters()])

if args.fp16:
    model = model.half()

if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model,
                                          dim=1).to(device)
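        # gpu0_bsz is presumably the (smaller) batch share assigned to GPU 0, which also
        # gathers the outputs; batch_chunk splits each batch for gradient accumulation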
Example #10
def run_trainer(config):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    run_name_format = (
        f"data={data_name}-"
        f"range={range_name}-"
        "d_model={d_model}-"
        "layers_count={nlayers}-"
        "heads_count={nhead}-"
        "FC_size={nhid}-"
        "lr={lr}-"
        "{timestamp}"
    )
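    # the first two pieces are f-strings interpolated immediately; the remaining
    # {placeholders} are filled in by the .format(**config, ...) call below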

    run_name = run_name_format.format(**config, timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))

    logger = get_logger(run_name, save_log=config['save_log'])
    logger.info(f'Run name : {run_name}')
    logger.info(config)

    data_dir = config['data_dir'] + "-" + data_name + "-" + range_name
    logger.info(f'Constructing dictionaries from {data_dir}...')
    source_dictionary = IndexDictionary.load(data_dir, mode='source')
    target_dictionary = IndexDictionary.load(data_dir, mode='target')
    logger.info(f'Source dictionary vocabulary : {source_dictionary.vocabulary_size} tokens')
    logger.info(f'Target dictionary vocabulary : {target_dictionary.vocabulary_size} tokens')

    logger.info('Building model...')
    model = TransformerModel(source_dictionary.vocabulary_size, target_dictionary.vocabulary_size,
                             d_model=config['d_model'],
                             nhead=config['nhead'],
                             nhid=config['nhid'],
                             nlayers=config['nlayers'])
    logger.info(model)
    logger.info('Encoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.transformer_encoder.parameters()])))
    logger.info('Decoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.transformer_decoder.parameters()])))
    logger.info('Total : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.parameters()])))

    logger.info('Loading datasets...')
    train_dataset = IndexedInputTargetTranslationDataset(
        data_dir=data_dir,
        phase='train')

    val_dataset = IndexedInputTargetTranslationDataset(
        data_dir=data_dir,
        phase='val')

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=input_target_collate_fn,
        num_workers=5)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        collate_fn=input_target_collate_fn,
        num_workers=5)

    loss_function = TokenCrossEntropyLoss()
    accuracy_function = AccuracyMetric()
    optimizer = Adam(model.parameters(), lr=config['lr'])

    logger.info('Start training...')
    trainer = EpochSeq2SeqTrainer(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_function=loss_function,
        metric_function=accuracy_function,
        optimizer=optimizer,
        logger=logger,
        run_name=run_name,
        save_config=config['save_config'],
        save_checkpoint=config['save_checkpoint'],
        config=config,
        iter_num=args.iter_num
    )

    trainer.run(config['epochs'])

    return trainer
Example #11
def main():
    ### settings
    args = set_args()
    save_path = args.save_path
    if not os.path.isdir(save_path):
        os.makedirs(save_path)
    logger.info(args)

    ### prepare for data
    train_dataset = COCOMultiLabel(args,
                                   train=True,
                                   image_path=args.image_path)
    test_dataset = COCOMultiLabel(args,
                                  train=False,
                                  image_path=args.image_path)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=my_collate)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             pin_memory=True,
                             shuffle=False,
                             drop_last=False,
                             collate_fn=my_collate)

    ## prepare for models
    encoder = CNN_Encoder().cuda()
    decoder = TransformerModel(args).cuda()
    ## set different parameters for training vs. evaluation-only
    if args.use_eval:
        weights_dic = torch.load(args.use_model)
        encoder.load_state_dict(
            convert_weights(weights_dic['encoder_state_dict']))
        decoder.load_state_dict(
            convert_weights(weights_dic['decoder_state_dict']))
    else:
        encoder.load_state_dict(
            convert_weights(torch.load(args.encoder_weights)))
        encoder_optimizer = torch.optim.Adam(encoder.parameters(),
                                             lr=args.encoder_lr)
        decoder_optimizer = torch.optim.Adam(decoder.parameters(),
                                             lr=args.decoder_lr)

    ## whether to use DataParallel
    if torch.cuda.device_count() > 1:
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)

    ## set hinge loss function
    loss_hinge = torch.nn.HingeEmbeddingLoss(margin=args.C,
                                             size_average=None,
                                             reduce=None,
                                             reduction='mean')
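    # size_average and reduce are deprecated aliases; with both None, reduction='mean' takes effect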

    ## if evaluation only, just run the test and return
    if args.use_eval:
        f1 = test(args, encoder, decoder, test_loader, args.threshold, 1)
        return

    ##  training stage
    highest_f1 = 0
    epochs_without_improve = 0
    for epoch in range(args.epochs):
        ## train and test
        train(args, encoder, decoder, train_loader, encoder_optimizer,
              decoder_optimizer, epoch, loss_hinge)
        f1 = test(args, encoder, decoder, test_loader, args.threshold, epoch)

        ### save parameter
        save_dict = {
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'epoch': epoch,
            'f1': f1,
            'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
            'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
            'epochs_without_improve': epochs_without_improve
        }

        ### save models
        torch.save(save_dict,
                   args.save_path + "/checkpoint_" + timestr + '.pt.tar')
        if f1 > highest_f1:
            torch.save(
                save_dict,
                args.save_path + "/BEST_checkpoint_" + timestr + '.pt.tar')
            logger.info("Now the highest f1 is {}, it was {}".format(
                100 * f1, 100 * highest_f1))
            highest_f1 = f1
            epochs_without_improve = 0
        else:
            epochs_without_improve += 1
            if epochs_without_improve == 3:
                adjust_learning_rate(decoder_optimizer, args.coeff)
                adjust_learning_rate(encoder_optimizer, args.coeff)
                epochs_without_improve = 0
Example #12
    parser.add_argument('--pretrain_model_path', type=str, default=hp.pretrain_model_path)
    args = parser.parse_args()
    for k, v in vars(args).items():
        setattr(hp, k, v)

    pretrain_emb = align_word_embedding(hp.word_dict_pickle_path, hp.pretrain_emb_path, hp.ntoken,
                                        hp.nhid) if hp.load_pretrain_emb else None
    pretrain_cnn = torch.load(hp.pretrain_cnn_path) if hp.load_pretrain_cnn else None

    model = TransformerModel(hp.ntoken, hp.ninp, hp.nhead, hp.nhid, hp.nlayers, hp.batch_size, dropout=0.2,
                             pretrain_cnn=pretrain_cnn, pretrain_emb=pretrain_emb, freeze_cnn=hp.freeze_cnn).to(device)
    if hp.load_pretrain_model:
        model.load_state_dict(torch.load(hp.pretrain_model_path))

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=hp.lr, weight_decay=1e-6)
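    # only parameters with requires_grad=True are optimized, so weights frozen via
    # freeze_cnn (if set) are excluded from the Adam update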
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, hp.scheduler_decay)
    if hp.label_smoothing:
        criterion = LabelSmoothingLoss(hp.ntoken, smoothing=0.1)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=hp.ntoken - 1)
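        # ignore_index=hp.ntoken - 1: the last vocabulary index is presumably the padding token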

    now_time = str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(time.time())))
    log_dir = 'models/{name}'.format(name=hp.name)

    writer = SummaryWriter(log_dir=log_dir)

    log_path = os.path.join(log_dir, 'train.log')

    logging.basicConfig(level=logging.DEBUG,
                        format=
Example #13
class TrainLoop_Transformer():
    def __init__(self, opt):
        self.opt = opt

        self.dict = json.load(open(args.bpe2index, encoding='utf-8'))
        self.index2word = {self.dict[key]: key for key in self.dict}

        self.batch_size = self.opt['batch_size']
        self.epoch = self.opt['epoch']
        self.use_cuda = opt['use_cuda']
        print('self.use_cuda:', self.use_cuda)

        self.device = 'cuda:{}'.format(
            self.opt['gpu']) if self.use_cuda else 'cpu'
        self.opt['device'] = self.device

        self.movie_ids = pkl.load(open("data/movie_ids.pkl", "rb"))

        # self.metrics_gen = {
        #     "ppl": 0,
        #     "dist1": 0,
        #     "dist2": 0,
        #     "dist3": 0,
        #     "dist4": 0,
        #     "bleu1": 0,
        #     "bleu2": 0,
        #     "bleu3": 0,
        #     "bleu4": 0,
        #     "count": 0
        # }

        self.build_data()
        self.build_model()

        # self.init_optim(
        #     [p for p in self.model.parameters() if p.requires_grad],
        #     optim_states=states.get('optimizer'),
        #     saved_optim_type=states.get('optimizer_type')
        # )
        self.init_optim(
            [p for p in self.model.parameters() if p.requires_grad])

    def build_data(self):
        if self.opt['process_data']:
            self.train_dataset = dataset(
                "../../data/data1030/output/train_cut.pkl", self.opt, 'train')
            self.valid_dataset = dataset(
                "../../data/data1030/output/valid_cut.pkl", self.opt, 'valid')
            self.test_dataset = dataset(
                "../../data/data1030/output/test_cut.pkl", self.opt, 'test')

            self.train_processed_set = self.train_dataset.data_process(True)
            self.valid_processed_set = self.valid_dataset.data_process(True)
            self.test_processed_set = self.test_dataset.data_process(True)

            pickle.dump(self.train_processed_set,
                        open('data/train_processed_set.pkl', 'wb'))
            pickle.dump(self.valid_processed_set,
                        open('data/valid_processed_set.pkl', 'wb'))
            pickle.dump(self.test_processed_set,
                        open('data/test_processed_set.pkl', 'wb'))
            logger.info("[Save processed data]")
        else:
            try:
                self.train_processed_set = pickle.load(
                    open('data/train_processed_set.pkl', 'rb'))
                self.valid_processed_set = pickle.load(
                    open('data/valid_processed_set.pkl', 'rb'))
                self.test_processed_set = pickle.load(
                    open('data/test_processed_set.pkl', 'rb'))
            except Exception as e:
                raise RuntimeError("No processed data") from e
            logger.info("[Load processed data]")

    def build_model(self):
        self.model = TransformerModel(self.opt, self.dict)
        # todo
        if self.opt['embedding_type'] != 'random':
            pass

        if self.opt['load_dict'] is not None:
            logger.info('[ Loading existing model params from {} ]'
                        ''.format(self.opt['load_dict']))
            self.model.load_model(self.opt['load_dict'])

        if self.use_cuda:
            self.model.to(self.device)

    def train(self):
        losses = []
        best_val_gen = 1000
        gen_stop = False
        patience = 0
        max_patience = 5
        num = 0

        # file_temp = open('temp.txt', 'w')
        # train_output_file = open(f"output_train_tf.txt", 'w', encoding='utf-8')

        for i in range(self.epoch):
            train_set = CRSdataset(self.train_processed_set,
                                   self.opt['n_entity'], self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(
                dataset=train_set, batch_size=self.batch_size,
                shuffle=True)  # shuffle

            for context,c_lengths,response,r_length,mask_response, \
                    mask_r_length,entity,entity_vector,movie,\
                    concept_mask,dbpedia_mask,concept_vec, \
                    db_vec,rec in tqdm(train_dataset_loader):
                ####################################### sanity check of inputs/outputs (ok)
                # file_temp.writelines("[Context] ", self.vector2sentence(context))
                # file_temp.writelines("[Response] ", self.vector2sentence(response))
                # file_temp.writelines("\n")

                seed_sets = []
                batch_size = context.shape[0]
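                # for each sample, collect the indices of nonzero entries in its entity
                # vector; these serve as the seed entity set for that dialogue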
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)

                self.model.train()
                self.zero_grad()

                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss= \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, seed_sets, movie, \
                        concept_vec, db_vec, entity_vector.to(self.device), rec, test=False)

                ##########################################
                # train_output_file.writelines(
                #     ["Loss per batch = %f\n" % gen_loss.item()])
                # train_output_file.writelines(['[GroundTruth] ' + ' '.join(sen_gt)+'\n' \
                #     + '[Generated] ' + ' '.join(sen_gen)+'\n\n' \
                #     for sen_gt, sen_gen in zip(self.vector2sentence(response.cpu()), self.vector2sentence(preds.cpu()))])

                losses.append([gen_loss])
                self.backward(gen_loss)
                self.update_params()

                if num % 50 == 0:
                    loss = sum([l[0] for l in losses]) / len(losses)
                    ppl = exp(loss)
                    logger.info('gen loss is %f, ppl is %f' % (loss, ppl))
                    losses = []

                num += 1

            output_metrics_gen = self.val(epoch=i)
            _ = self.val(True, epoch=i)

            if best_val_gen < output_metrics_gen["ppl"]:
                patience += 1
                logger.info('Patience = %d', patience)
                if patience >= 5:
                    gen_stop = True
            else:
                patience = 0
                best_val_gen = output_metrics_gen["ppl"]
                self.model.save_model(self.opt['model_save_path'])
                logger.info(
                    f"[generator model saved in {self.opt['model_save_path']}"
                    "------------------------------------------------]")

            if gen_stop:
                break

        # train_output_file.close()
        # _ = self.val(is_test=True)

    def val(self, is_test=False, epoch=-1):
        # count is the number of responses
        self.model.eval()
        if is_test:
            valid_processed_set = self.test_processed_set
        else:
            valid_processed_set = self.valid_processed_set

        val_set = CRSdataset(valid_processed_set, self.opt['n_entity'],
                             self.opt['n_concept'])
        val_dataset_loader = torch.utils.data.DataLoader(
            dataset=val_set, batch_size=self.batch_size, shuffle=False)

        inference_sum = []
        tf_inference_sum = []
        golden_sum = []
        # context_sum = []
        losses = []
        recs = []

        for context, c_lengths, response, r_length, mask_response, mask_r_length, \
                entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec \
                in tqdm(val_dataset_loader):
            with torch.no_grad():
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)

                # response generation with teacher forcing
                _, tf_preds, _, _, gen_loss, mask_loss, info_db_loss, info_con_loss = \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, \
                        seed_sets, movie, concept_vec, db_vec, entity_vector.to(self.device), rec, test=False)

                # response generation with greedy decoding, capped at maxlen=20
                # todo
                scores, preds, rec_scores, rec_loss, _, mask_loss, info_db_loss, info_con_loss = \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, \
                        seed_sets, movie, concept_vec, db_vec, entity_vector.to(self.device), rec, test=True, maxlen=20, bsz=batch_size)

            golden_sum.extend(self.vector2sentence(response.cpu()))
            inference_sum.extend(self.vector2sentence(preds.cpu()))
            # tf_inference_sum.extend(self.vector2sentence(tf_preds.cpu()))
            # context_sum.extend(self.vector2sentence(context.cpu()))
            recs.extend(rec.cpu())
            losses.append(torch.mean(gen_loss))
            #logger.info(losses)
            #exit()

        subset = 'valid' if not is_test else 'test'

        # original version: gen_loss comes from teacher forcing, inference_sum from greedy decoding
        ppl = exp(sum(loss for loss in losses) / len(losses))
        output_dict_gen = {'ppl': ppl}
        logger.info(f"{subset} set metrics = {output_dict_gen}")
        # logger.info(f"{subset} set gt metrics = {self.metrics_gt}")

        # f=open('context_test.txt','w',encoding='utf-8')
        # f.writelines([' '.join(sen)+'\n' for sen in context_sum])
        # f.close()

        # write out the generated responses
        with open(f"output/output_{subset}_gen_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            f.writelines([
                '[Generated] ' + re.sub(r'@\d+', '__UNK__', ' '.join(sen)) +
                '\n' for sen in inference_sum
            ])

        # write out the ground truth
        with open(f"output/output_{subset}_gt_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            for sen in golden_sum:
                mask_sen = re.sub(r'@\d+', '__UNK__', ' '.join(sen))
                mask_sen = re.sub(' ([!,.?])', '\\1', mask_sen)
                f.writelines(['[GT] ' + mask_sen + '\n'])

        # write out the generated responses together with the ground truth
        with open(f"output/output_{subset}_both_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            f.writelines(['[GroundTruth] ' + re.sub(r'@\d+', '__UNK__', ' '.join(sen_gt)) + '\n' \
                + '[Generated] ' + re.sub(r'@\d+', '__UNK__', ' '.join(sen_gen)) + '\n\n' \
                for sen_gt, sen_gen in zip(golden_sum, inference_sum)])

        self.save_embedding()

        return output_dict_gen

    def save_embedding(self):
        with open('output/tf_bpe2index.json', 'w') as f:
            json.dump(self.dict, f)

    def vector2sentence(self, batch_sen):
        # convert a batch of sentences from token ids back to tokens
        sentences = []
        for sen in batch_sen.numpy().tolist():
            sentence = []
            for word in sen:
                if word > 3:
                    sentence.append(self.index2word[word])
                elif word == 3:
                    sentence.append('_UNK_')
            sentences.append(sentence)
        return sentences

    @classmethod
    def optim_opts(cls):
        """
        Fetch optimizer selection.

        By default, collects everything in torch.optim, as well as importing:
        - qhm / qhmadam if installed from github.com/facebookresearch/qhoptim

        Override this (and probably call super()) to add your own optimizers.
        """
        # first pull torch.optim in
        optims = {
            k.lower(): v
            for k, v in optim.__dict__.items()
            if not k.startswith('__') and k[0].isupper()
        }
        try:
            import apex.optimizers.fused_adam as fused_adam
            optims['fused_adam'] = fused_adam.FusedAdam
        except ImportError:
            pass

        try:
            # https://openreview.net/pdf?id=S1fUpoR5FQ
            from qhoptim.pyt import QHM, QHAdam
            optims['qhm'] = QHM
            optims['qhadam'] = QHAdam
        except ImportError:
            # no QHM installed
            pass
        logger.info(optims)
        return optims

    def init_optim(self, params, optim_states=None, saved_optim_type=None):
        """
        Initialize optimizer with model parameters.

        :param params:
            parameters from the model

        :param optim_states:
            optional argument providing states of optimizer to load

        :param saved_optim_type:
            type of optimizer being loaded, if changed will skip loading
            optimizer states
        """

        opt = self.opt

        # set up optimizer args
        lr = opt['learningrate']
        kwargs = {'lr': lr}
        # kwargs['amsgrad'] = True
        # kwargs['betas'] = (0.9, 0.999)

        optim_class = self.optim_opts()[opt['optimizer']]
        logger.info(f'optim_class = {optim_class}')
        self.optimizer = optim_class(params, **kwargs)

    def backward(self, loss):
        """
        Perform a backward pass. It is recommended you use this instead of
        loss.backward(), for integration with distributed training and FP16
        training.
        """
        loss.backward()

    def update_params(self):
        """
        Perform step of optimization, clipping gradients and adjusting LR
        schedule if needed. Gradient accumulation is also performed if agent
        is called with --update-freq.

        It is recommended (but not forced) that you call this in train_step.
        """
        update_freq = 1
        if update_freq > 1:
            # we're doing gradient accumulation, so we only want to step
            # every N updates
            self._number_grad_accum = (self._number_grad_accum +
                                       1) % update_freq
            if self._number_grad_accum != 0:
                return
        # is 0.1 too small? this follows the original implementation
        if self.opt['gradient_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.opt['gradient_clip'])

        self.optimizer.step()

    def zero_grad(self):
        """
        Zero out optimizer.

        It is recommended you call this in train_step. It automatically handles
        gradient accumulation if agent is called with --update-freq.
        """
        self.optimizer.zero_grad()