Example #1
    def val(self,is_test=False):
        self.metrics_gen={"ppl":0,"dist1":0,"dist2":0,"dist3":0,"dist4":0,"bleu1":0,"bleu2":0,"bleu3":0,"bleu4":0,"count":0}
        self.metrics_rec={"recall@1":0,"recall@10":0,"recall@50":0,"loss":0,"gate":0,"count":0,'gate_count':0}
        self.model.eval()
        if is_test:
            val_dataset = dataset('data/test_data.jsonl', self.opt)
        else:
            val_dataset = dataset('data/valid_data.jsonl', self.opt)
        val_set=CRSdataset(val_dataset.data_process(),self.opt['n_entity'],self.opt['n_concept'])
        val_dataset_loader = torch.utils.data.DataLoader(dataset=val_set,
                                                           batch_size=self.batch_size,
                                                           shuffle=False)
        recs=[]
        for context, c_lengths, response, r_length, mask_response, mask_r_length, entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec in tqdm(val_dataset_loader):
            with torch.no_grad():
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)
                scores, preds, rec_scores, rec_loss, _, mask_loss, info_db_loss, info_con_loss = self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec, test=True, maxlen=20, bsz=batch_size)

            recs.extend(rec.cpu())
            #print(losses)
            #exit()
            self.metrics_cal_rec(rec_loss, rec_scores, movie)

        output_dict_rec={key: self.metrics_rec[key] / self.metrics_rec['count'] for key in self.metrics_rec}
        print(output_dict_rec)

        return output_dict_rec
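Example #1 accumulates recommendation metrics through a `metrics_cal_rec` helper that is not shown. A minimal sketch of what such a recall@k accumulator could look like, assuming `torch` is imported as in the examples, `rec_scores` is a `(batch, n_entity)` score tensor, `movie` holds the gold item ids, and id 0 marks turns without a recommendation (all assumptions):

    def metrics_cal_rec(self, rec_loss, rec_scores, movie):
        # Hypothetical sketch: rank all items per turn and check whether the gold
        # movie id appears in the top-1/10/50 predictions.
        _, rec_ranks = torch.topk(rec_scores.cpu(), k=50, dim=-1)
        for rank_list, gold in zip(rec_ranks.tolist(), movie.tolist()):
            if gold == 0:  # assumed padding id for turns without a recommendation
                continue
            self.metrics_rec["recall@1"] += int(gold in rank_list[:1])
            self.metrics_rec["recall@10"] += int(gold in rank_list[:10])
            self.metrics_rec["recall@50"] += int(gold in rank_list[:50])
            self.metrics_rec["count"] += 1
        self.metrics_rec["loss"] += rec_loss.item()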
Example #2
    def train(self):
        self.model.load_model()
        losses = []
        best_val_gen = 1000
        gen_stop = False
        for i in range(self.epoch * 3):
            train_set = CRSdataset(self.train_dataset.data_process(True),
                                   self.opt['n_entity'], self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(
                dataset=train_set, batch_size=self.batch_size, shuffle=False)
            num = 0
            for context, c_lengths, response, r_length, mask_response, mask_r_length, entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec in tqdm(
                    train_dataset_loader):
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)
                self.model.train()
                self.zero_grad()

                # Original call with tensors moved to .cuda(), kept here commented out:
                # scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss=self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec, test=False)
                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss = self.model(
                    context,
                    response,
                    mask_response,
                    concept_mask,
                    dbpedia_mask,
                    seed_sets,
                    movie,
                    concept_vec,
                    db_vec,
                    entity_vector,
                    rec,
                    test=False)

                joint_loss = gen_loss

                losses.append([gen_loss])
                self.backward(joint_loss)
                self.update_params()
                if num % 50 == 0:
                    print('gen loss is %f' %
                          (sum([l[0] for l in losses]) / len(losses)))
                    losses = []
                num += 1

            output_metrics_gen = self.val(True)
            if best_val_gen < output_metrics_gen["dist4"]:
                pass
            else:
                best_val_gen = output_metrics_gen["dist4"]
                self.model.save_model()
                print(
                    "generator model saved once------------------------------------------------"
                )

        _ = self.val(is_test=True)
Example #3
    def val(self,is_test=False):
        self.metrics_gen={"ppl":0,"dist1":0,"dist2":0,"dist3":0,"dist4":0,"bleu1":0,"bleu2":0,"bleu3":0,"bleu4":0,"count":0}
        self.metrics_rec={"recall@1":0,"recall@10":0,"recall@50":0,"loss":0,"gate":0,"count":0,'gate_count':0}
        self.model.eval()
        if is_test:
            val_dataset = dataset('data/test_data.jsonl', self.opt)
        else:
            val_dataset = dataset('data/valid_data.jsonl', self.opt)
        val_set=CRSdataset(val_dataset.data_process(True),self.opt['n_entity'],self.opt['n_concept'])
        val_dataset_loader = torch.utils.data.DataLoader(dataset=val_set,
                                                           batch_size=self.batch_size,
                                                           shuffle=False)
        inference_sum=[]
        golden_sum=[]
        context_sum=[]
        losses=[]
        recs=[]
        for context, c_lengths, response, r_length, mask_response, mask_r_length, entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec in tqdm(val_dataset_loader):
            with torch.no_grad():
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)
                _, _, _, _, gen_loss, mask_loss, info_db_loss, info_con_loss = self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec, test=False)
                scores, preds, rec_scores, rec_loss, _, mask_loss, info_db_loss, info_con_loss = self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec, test=True, maxlen=20, bsz=batch_size)

            golden_sum.extend(self.vector2sentence(response.cpu()))
            inference_sum.extend(self.vector2sentence(preds.cpu()))
            context_sum.extend(self.vector2sentence(context.cpu()))
            recs.extend(rec.cpu())
            losses.append(torch.mean(gen_loss))
            #print(losses)
            #exit()

        self.metrics_cal_gen(losses,inference_sum,golden_sum,recs)

        output_dict_gen={}
        for key in self.metrics_gen:
            if 'bleu' in key:
                output_dict_gen[key]=self.metrics_gen[key]/self.metrics_gen['count']
            else:
                output_dict_gen[key]=self.metrics_gen[key]
        print(output_dict_gen)

        with open('context_test.txt', 'w', encoding='utf-8') as f:
            f.writelines([' '.join(sen) + '\n' for sen in context_sum])

        with open('output_test.txt', 'w', encoding='utf-8') as f:
            f.writelines([' '.join(sen) + '\n' for sen in inference_sum])
        return output_dict_gen
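Example #3 turns id tensors back into tokens with `vector2sentence`, which is not defined here. A rough sketch of such a helper, assuming a `self.index2word` mapping and that ids 0-3 are reserved special tokens (both assumptions):

    def vector2sentence(self, batch_sen):
        # Hypothetical sketch: map each row of token ids back to words, skipping
        # reserved special ids; self.index2word and the id > 3 cutoff are assumptions.
        sentences = []
        for sen in batch_sen.numpy().tolist():
            sentences.append([self.index2word.get(word_id, '__UNK__')
                              for word_id in sen if word_id > 3])
        return sentences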
Example #4
    def build_data(self):
        # Initialize the tokenizer
        # self.tokenizer = BertTokenizer(vocab_file=self.opt['vocab_path'])  # Initialize the tokenizer
        # build and save self.dataset
        self.dataset = {'train': None, 'valid': None, 'test': None}
        self.dataset_loader = {'train': None, 'valid': None, 'test': None}
        for subset in self.dataset:
            self.dataset[subset] = CRSdataset(subset, self.opt[f'{subset}_data_file'], self.opt, self.args, \
                self.opt['save_build_data'], self.opt['load_builded_data'])
            self.dataset_loader[subset] = torch.utils.data.DataLoader(dataset=self.dataset[subset],
                                                                      batch_size=self.batch_size,
                                                                      shuffle=True,
                                                                      collate_fn=collate_fn)  # todo
            self.args.topic_class_num = self.dataset[subset].topic_class_num
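Example #4 passes a `collate_fn` to the DataLoader but leaves it as a todo. A minimal padding collate function could look like the sketch below, assuming each dataset item is a dict with `context` and `response` id lists (the field names are assumptions):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def collate_fn(batch):
        # Hypothetical sketch: pad variable-length id sequences to the batch maximum.
        contexts = [torch.tensor(item['context'], dtype=torch.long) for item in batch]
        responses = [torch.tensor(item['response'], dtype=torch.long) for item in batch]
        return {
            'context': pad_sequence(contexts, batch_first=True, padding_value=0),
            'response': pad_sequence(responses, batch_first=True, padding_value=0),
        }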
Example #5
    def build_data(self):
        # Initialize the tokenizer
        self.tokenizer = BertTokenizer(vocab_file=self.opt['vocab_path'])  # Initialize the tokenizer
        # build and save self.dataset
        self.dataset = {'train': None, 'valid': None, 'test': None}
        self.dataset_loader = {'train': None, 'valid': None, 'test': None}
        for subset in self.dataset:
            self.dataset[subset] = CRSdataset(subset, self.opt[f'{subset}_data_file'], self.opt, self.args, self.tokenizer, self.opt['use_size'])
            self.dataset_loader[subset] = torch.utils.data.DataLoader(dataset=self.dataset[subset],
                                                                      batch_size=self.batch_size,
                                                                      shuffle=True)  # todo
        # self.args.item_size += 1
        self.movie_num = self.dataset['train'].movie_num + 1
        self.args.item_size = self.dataset['train'].movie_num + 1
        self.item_size = self.dataset['train'].movie_num + 1
Example #6
    def build_data(self):
        # Initialize the tokenizer
        self.tokenizer = BertTokenizer(
            vocab_file=self.args.vocab_path)  # Initialize the tokenizer
        # build and save self.dataset
        self.dataset = {'train': None, 'valid': None, 'test': None}
        self.dataset_loader = {'train': None, 'valid': None, 'test': None}
        for subset in self.dataset:
            self.dataset[subset] = CRSdataset(logger, subset,
                                              self.opt[f'{subset}_data_file'],
                                              self.args, self.tokenizer)
            self.dataset_loader[subset] = torch.utils.data.DataLoader(
                dataset=self.dataset[subset],
                batch_size=self.batch_size,
                shuffle=True)

        # self.dataset['train'].movie_num already includes the unk entry; the +1 shifts ids up by one so index 0 stays free
        self.item_size = self.dataset['train'].movie_num + 1
        self.args.item_size = self.item_size
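The comment in Example #6 says that `movie_num` already includes an unk entry and that the `+1` keeps index 0 free. One plausible use for that spare slot, sketched here purely as an assumption, is a padding index in the item embedding table:

    import torch.nn as nn

    movie_num = 6924           # placeholder value; the real count comes from the dataset
    item_size = movie_num + 1  # the +1 keeps index 0 free
    # Hypothetical sketch: the spare index 0 can then serve as the padding id of an
    # item embedding table; the embedding dimension is an arbitrary choice.
    item_embedding = nn.Embedding(num_embeddings=item_size, embedding_dim=64, padding_idx=0)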
Example #7
    def build_data(self):
        # Initialize the tokenizer (character-level split)
        self.tokenizer = lambda x: [y for y in x]
        # vocab
        self.vocab = pickle.load(open(self.args.vocab_path, 'rb'))
        self.id2token = {id: token for token, id in self.vocab.items()}
        self.id2token['<SENT>'] = len(self.id2token)

        # build and save self.dataset
        self.dataset = {'train': None, 'valid': None, 'test': None}
        self.dataset_loader = {'train': None, 'valid': None, 'test': None}
        for subset in self.dataset:
            self.dataset[subset] = CRSdataset(subset, self.opt[f'{subset}_data_file'], self.opt, \
                                                self.args, self.tokenizer, self.vocab, \
                                                self.opt['save_build_data'], self.opt['load_builded_data'], self.opt['use_size'])
            self.dataset_loader[subset] = torch.utils.data.DataLoader(
                dataset=self.dataset[subset],
                batch_size=self.batch_size,
                shuffle=True)  # todo
            self.movie_num = self.dataset[subset].movie_num
Example #8
    def train(self):
        #self.model.load_model()
        losses=[]
        best_val_rec=0
        rec_stop=False
        for i in range(3):
            train_set=CRSdataset(self.train_dataset.data_process(),self.opt['n_entity'],self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                            batch_size=self.batch_size,
                                                            shuffle=False)
            num=0
            for context,c_lengths,response,r_length,mask_response,mask_r_length,entity,entity_vector,movie,concept_mask,dbpedia_mask,concept_vec, db_vec,rec in tqdm(train_dataset_loader):
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)
                self.model.train()
                self.zero_grad()

                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, _=self.model(context.cuda(), response.cuda(), mask_response.cuda(),
                                                                                                                            concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec, test=False)

                joint_loss=info_db_loss#+info_con_loss

                losses.append([info_db_loss])
                self.backward(joint_loss)
                self.update_params()
                if num%50==0:
                    print('info db loss is %f'%(sum([l[0] for l in losses])/len(losses)))
                    #print('info con loss is %f'%(sum([l[1] for l in losses])/len(losses)))
                    losses=[]
                num+=1

        print("masked loss pre-trained")
        losses=[]

        for i in range(self.epoch):
            train_set=CRSdataset(self.train_dataset.data_process(),self.opt['n_entity'],self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                            batch_size=self.batch_size,
                                                            shuffle=False)
            num=0
            for context,c_lengths,response,r_length,mask_response,mask_r_length,entity,entity_vector,movie,concept_mask,dbpedia_mask,concept_vec, db_vec,rec in tqdm(train_dataset_loader):
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)
                self.model.train()
                self.zero_grad()

                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, _=self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie,concept_vec, db_vec, entity_vector.cuda(), rec, test=False)

                joint_loss=rec_loss+0.025*info_db_loss#+0.0*info_con_loss#+mask_loss*0.05

                losses.append([rec_loss,info_db_loss])
                self.backward(joint_loss)
                self.update_params()
                if num%50==0:
                    print('rec loss is %f'%(sum([l[0] for l in losses])/len(losses)))
                    print('info db loss is %f'%(sum([l[1] for l in losses])/len(losses)))
                    losses=[]
                num+=1

            output_metrics_rec = self.val()

            if best_val_rec > output_metrics_rec["recall@50"]+output_metrics_rec["recall@1"]:
                rec_stop=True
            else:
                best_val_rec = output_metrics_rec["recall@50"]+output_metrics_rec["recall@1"]
                self.model.save_model()
                print("recommendation model saved once------------------------------------------------")

            if rec_stop==True:
                break

        _=self.val(is_test=True)
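Examples #2 and #8 rely on trainer helpers (`zero_grad`, `backward`, `update_params`) that are not shown. A minimal sketch of what they might wrap, assuming a single `self.optimizer` and an optional `'gradient_clip'` entry in `self.opt` (both assumptions):

    def zero_grad(self):
        self.optimizer.zero_grad()

    def backward(self, loss):
        loss.backward()

    def update_params(self):
        # Optional gradient clipping before the optimizer step; the
        # 'gradient_clip' key in self.opt is an assumption.
        clip = self.opt.get('gradient_clip', 0)
        if clip > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
        self.optimizer.step()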
Example #9
    def train(self):
        self.model.load_model()
        best_val_gen = 0
        step = 0
        gen_stop = False
        for i in range(self.epoch * 3):
            # get dataloader
            train_set = CRSdataset(self.train_dataset.data_process(True),
                                   self.opt['n_entity'], self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(
                dataset=train_set, batch_size=self.batch_size, shuffle=False)

            # train
            num, loss_epoch = 0, 0
            for context, c_lengths, response, r_length, mask_response, mask_r_length, entity, entity_vector, \
                    movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec in tqdm(train_dataset_loader):

                # get movies appeared in context
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)

                # set mode
                self.model.train()
                self.zero_grad()

                # forward
                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss = self.model(
                    context.cuda(), concept_mask, dbpedia_mask, concept_vec,
                    db_vec, seed_sets, entity_vector.cuda(), TrainType.TRAIN,
                    response.cuda(), mask_response.cuda(), movie, rec)

                # get loss and update model
                joint_loss = gen_loss
                self.backward(joint_loss)
                self.update_params()
                loss_epoch += gen_loss.item()

                # monitor loss on training set
                num += 1
                if num % 100 == 1:
                    self.writer.add_scalar('Gen/Loss/Gen/Train',
                                           loss_epoch / num, step)
                    step += 1

            # validate
            output_metrics_gen = self.val()
            if best_val_gen > output_metrics_gen["dist4"]:
                pass
            else:
                best_val_gen = output_metrics_gen["dist4"]
                self.model.save_model()
                print(
                    "generator model saved once------------------------------------------------"
                )

            # monitor performance on validation set
            self.writer.add_scalar('Gen/Dist/1/Valid',
                                   output_metrics_gen['dist1'], i)
            self.writer.add_scalar('Gen/Dist/2/Valid',
                                   output_metrics_gen['dist2'], i)
            self.writer.add_scalar('Gen/Dist/3/Valid',
                                   output_metrics_gen['dist3'], i)
            self.writer.add_scalar('Gen/Dist/4/Valid',
                                   output_metrics_gen['dist4'], i)
            self.writer.add_scalar('Gen/Loss/Gen/Valid',
                                   output_metrics_gen['loss'], i)

        # testing
        _ = self.val(is_test=True)
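Example #9 above and Example #10 below pass a `TrainType.TRAIN` flag to the model. Its definition is not shown; a plausible minimal version (the member names other than `TRAIN` are guesses):

    from enum import Enum

    class TrainType(Enum):
        # Hypothetical definition of the mode flag passed to the model.
        TRAIN = 0
        VALID = 1
        TEST = 2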
Example #10
    def train(self):
        losses = []
        best_val_rec, step = 0, 0
        rec_stop = False

        # train for 3 epochs with MIM loss
        for i in range(3):
            train_set = CRSdataset(self.train_dataset.data_process(),
                                   self.opt['n_entity'], self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(
                dataset=train_set, batch_size=self.batch_size, shuffle=False)

            for j, (context, c_lengths, response, r_length, mask_response,
                    mask_r_length, entity, entity_vector, movie, concept_mask,
                    dbpedia_mask, concept_vec, db_vec,
                    rec) in enumerate(tqdm(train_dataset_loader)):

                seed_sets = [en.nonzero().view(-1).tolist() for en in entity]
                # batch_size = context.shape[0]
                # for b in range(batch_size):
                #     seed_set = entity[b].nonzero().view(-1).tolist()
                #     seed_sets.append(seed_set)

                self.model.train()
                self.zero_grad()

                # forward
                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, _ = self.model(
                    context.cuda(), concept_mask, dbpedia_mask, concept_vec,
                    db_vec, seed_sets, entity_vector.cuda(), TrainType.TRAIN,
                    response.cuda(), mask_response.cuda(), movie, rec)

                # update model
                joint_loss = info_db_loss
                losses.append([info_db_loss])
                self.backward(joint_loss)
                self.update_params()

                # monitor
                if j % 100 == 0:
                    self.writer.add_scalar(
                        'Pre/Loss/MIM/Train',
                        sum([l[0] for l in losses]) / len(losses), step)
                    step += 1
                    losses = []

        print("masked loss pre-trained")
        losses = []
        step = 0

        # train with recommendation task
        for i in range(self.epoch):
            train_set = CRSdataset(self.train_dataset.data_process(),
                                   self.opt['n_entity'], self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(
                dataset=train_set, batch_size=self.batch_size, shuffle=False)

            for j, (context, c_lengths, response, r_length, mask_response,
                    mask_r_length, entity, entity_vector, movie, concept_mask,
                    dbpedia_mask, concept_vec, db_vec,
                    rec) in enumerate(tqdm(train_dataset_loader)):

                seed_sets = [en.nonzero().view(-1).tolist() for en in entity]

                self.model.train()
                self.zero_grad()

                # forward
                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss = self.model(
                    context.cuda(), concept_mask, dbpedia_mask, concept_vec,
                    db_vec, seed_sets, entity_vector.cuda(), TrainType.TRAIN,
                    response.cuda(), mask_response.cuda(), movie, rec)

                # update model
                joint_loss = rec_loss + 0.025 * info_db_loss
                losses.append([rec_loss, info_db_loss])
                self.backward(joint_loss)
                self.update_params()

                # monitor
                if j % 100 == 0:
                    self.writer.add_scalar(
                        'Rec/Loss/Rec/Train',
                        sum([l[0] for l in losses]) / len(losses), step)
                    self.writer.add_scalar(
                        'Rec/Loss/MIM/Train',
                        sum([l[1] for l in losses]) / len(losses), step)
                    step += 1
                    losses = []

            # validation
            output_metrics_rec = self.val()

            if best_val_rec > output_metrics_rec[
                    "recall@50"] + output_metrics_rec["recall@1"]:
                rec_stop = True
            else:
                best_val_rec = output_metrics_rec[
                    "recall@50"] + output_metrics_rec["recall@1"]
                self.model.save_model()
                print(
                    "recommendation model saved once------------------------------------------------"
                )

            # monitor recall and loss
            self.writer.add_scalar('Rec/Recall/50/Valid',
                                   output_metrics_rec["recall@50"], i)
            self.writer.add_scalar('Rec/Recall/10/Valid',
                                   output_metrics_rec["recall@10"], i)
            self.writer.add_scalar('Rec/Recall/1/Valid',
                                   output_metrics_rec["recall@1"], i)
            self.writer.add_scalar('Rec/Loss/Rec/Valid',
                                   output_metrics_rec['loss'], i)

            if rec_stop == True:
                break

        _ = self.val(is_test=True)
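Examples #9 and #10 log scalars through `self.writer`, which is presumably a TensorBoard `SummaryWriter` created once when the trainer is built. A minimal sketch, with the class name and log directory as assumptions:

    from torch.utils.tensorboard import SummaryWriter

    class TrainLoop:  # hypothetical trainer skeleton, name assumed
        def __init__(self, opt):
            self.opt = opt
            # Writer behind the add_scalar calls in train()/val()
            self.writer = SummaryWriter(log_dir=opt.get('tb_log_dir', 'runs/crs'))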
Example #11
    def val(self, is_test=False, epoch=-1):
        # count is the number of responses
        self.model.eval()
        if is_test:
            valid_processed_set = self.test_processed_set
        else:
            valid_processed_set = self.valid_processed_set

        val_set = CRSdataset(valid_processed_set, self.opt['n_entity'],
                             self.opt['n_concept'])
        val_dataset_loader = torch.utils.data.DataLoader(
            dataset=val_set, batch_size=self.batch_size, shuffle=False)

        inference_sum = []
        tf_inference_sum = []
        golden_sum = []
        # context_sum = []
        losses = []
        recs = []

        for context, c_lengths, response, r_length, mask_response, mask_r_length, \
                entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec \
                in tqdm(val_dataset_loader):
            with torch.no_grad():
                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)

                # Response generation with teacher forcing
                _, tf_preds, _, _, gen_loss, mask_loss, info_db_loss, info_con_loss = \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, \
                        seed_sets, movie, concept_vec, db_vec, entity_vector.to(self.device), rec, test=False)

                # Response generation with greedy decoding, capped at maxlen=20?
                # todo
                scores, preds, rec_scores, rec_loss, _, mask_loss, info_db_loss, info_con_loss = \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, \
                        seed_sets, movie, concept_vec, db_vec, entity_vector.to(self.device), rec, test=True, maxlen=20, bsz=batch_size)

            golden_sum.extend(self.vector2sentence(response.cpu()))
            inference_sum.extend(self.vector2sentence(preds.cpu()))
            # tf_inference_sum.extend(self.vector2sentence(tf_preds.cpu()))
            # context_sum.extend(self.vector2sentence(context.cpu()))
            recs.extend(rec.cpu())
            losses.append(torch.mean(gen_loss))
            #logger.info(losses)
            #exit()

        subset = 'valid' if not is_test else 'test'

        # Original version: gen_loss comes from teacher forcing, inference_sum from greedy decoding
        ppl = exp(sum(losses) / len(losses))
        output_dict_gen = {'ppl': ppl}
        logger.info(f"{subset} set metrics = {output_dict_gen}")
        # logger.info(f"{subset} set gt metrics = {self.metrics_gt}")

        # f=open('context_test.txt','w',encoding='utf-8')
        # f.writelines([' '.join(sen)+'\n' for sen in context_sum])
        # f.close()

        # Write out the generated responses
        with open(f"output/output_{subset}_gen_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            f.writelines([
                '[Generated] ' + re.sub(r'@\d+', '__UNK__', ' '.join(sen)) +
                '\n' for sen in inference_sum
            ])

        # Write out the ground truth
        with open(f"output/output_{subset}_gt_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            for sen in golden_sum:
                mask_sen = re.sub(r'@\d+', '__UNK__', ' '.join(sen))
                mask_sen = re.sub(' ([!,.?])', '\\1', mask_sen)
                f.writelines(['[GT] ' + mask_sen + '\n'])

        # Write out generated responses alongside the ground truth
        with open(f"output/output_{subset}_both_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            f.writelines(['[GroundTruth] ' + re.sub(r'@\d+', '__UNK__', ' '.join(sen_gt)) + '\n' \
                + '[Generated] ' + re.sub(r'@\d+', '__UNK__', ' '.join(sen_gen)) + '\n\n' \
                for sen_gt, sen_gen in zip(golden_sum, inference_sum)])

        self.save_embedding()

        return output_dict_gen
Example #12
    def train(self):
        losses = []
        best_val_gen = 1000
        gen_stop = False
        patience = 0
        max_patience = 5
        num = 0

        # file_temp = open('temp.txt', 'w')
        # train_output_file = open(f"output_train_tf.txt", 'w', encoding='utf-8')

        for i in range(self.epoch):
            train_set = CRSdataset(self.train_processed_set,
                                   self.opt['n_entity'], self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(
                dataset=train_set, batch_size=self.batch_size,
                shuffle=True)  # shuffle

            for context,c_lengths,response,r_length,mask_response, \
                    mask_r_length,entity,entity_vector,movie,\
                    concept_mask,dbpedia_mask,concept_vec, \
                    db_vec,rec in tqdm(train_dataset_loader):
                ####################################### input/output sanity check ok
                # file_temp.writelines("[Context] ", self.vector2sentence(context))
                # file_temp.writelines("[Response] ", self.vector2sentence(response))
                # file_temp.writelines("\n")

                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)

                self.model.train()
                self.zero_grad()

                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss= \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, seed_sets, movie, \
                        concept_vec, db_vec, entity_vector.to(self.device), rec, test=False)

                ##########################################
                # train_output_file.writelines(
                #     ["Loss per batch = %f\n" % gen_loss.item()])
                # train_output_file.writelines(['[GroundTruth] ' + ' '.join(sen_gt)+'\n' \
                #     + '[Generated] ' + ' '.join(sen_gen)+'\n\n' \
                #     for sen_gt, sen_gen in zip(self.vector2sentence(response.cpu()), self.vector2sentence(preds.cpu()))])

                losses.append([gen_loss])
                self.backward(gen_loss)
                self.update_params()

                if num % 50 == 0:
                    loss = sum([l[0] for l in losses]) / len(losses)
                    ppl = exp(loss)
                    logger.info('gen loss is %f, ppl is %f' % (loss, ppl))
                    losses = []

                num += 1

            output_metrics_gen = self.val(epoch=i)
            _ = self.val(True, epoch=i)

            if best_val_gen < output_metrics_gen["ppl"]:
                patience += 1
                logger.info('Patience = %d', patience)
                if patience >= max_patience:
                    gen_stop = True
            else:
                patience = 0
                best_val_gen = output_metrics_gen["ppl"]
                self.model.save_model(self.opt['model_save_path'])
                logger.info(
                    f"[generator model saved in {self.opt['model_save_path']}"
                    "------------------------------------------------]")

            if gen_stop:
                break