def train_classifier(self, epochs):
    """Pretrain the category classifier used for the classification-accuracy metric.

    Note: the train and test data for the classifier are swapped relative to the
    generator, because the classifier must judge generated samples: it is fitted
    on ``self.clas_samples_list`` and evaluated on ``self.train_samples_list``.
    Synthetic (oracle) experiments have no test data, so they need no classifier.

    :param epochs: number of training epochs to run.
    """
    import copy

    # Prepare data for the classifier.
    clas_data = CatClasDataIter(self.clas_samples_list)
    eval_clas_data = CatClasDataIter(self.train_samples_list)

    max_acc = 0
    best_clas = None
    for epoch in range(epochs):
        c_loss, c_acc = self.train_dis_epoch(self.clas, clas_data.loader,
                                             self.clas_criterion, self.clas_opt)
        _, eval_acc = self.eval_dis(self.clas, eval_clas_data.loader, self.clas_criterion)
        if eval_acc > max_acc:
            best_clas = copy.deepcopy(self.clas.state_dict())  # save the best classifier
            max_acc = eval_acc
        self.log.info(
            '[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = %.4f, max_eval_acc = %.4f',
            epoch, c_loss, c_acc, eval_acc, max_acc)

    # Guard: with epochs == 0, or if eval_acc never rose above 0, no snapshot
    # was taken; calling load_state_dict(None) would crash.
    if best_clas is not None:
        self.clas.load_state_dict(copy.deepcopy(best_clas))  # reload the best classifier
def train_discriminator(self, d_step, d_epoch, phase='MLE'):
    """Train the discriminator on real samples (positive) and generated samples (negative).

    Samples are drawn ``d_step`` times; for each draw the discriminator is
    trained for ``d_epoch`` epochs.

    :param d_step: number of times to draw a fresh training set.
    :param d_epoch: number of training epochs per drawn set.
    :param phase: tag used in log lines; weights are saved only in the 'MLE' phase.
    """
    # Initialize locally (instead of the original `global` hack) so the log
    # line below is well-defined even when d_epoch == 0.
    d_loss, train_acc = 0.0, 0.0

    for step in range(d_step):
        # Prepare loader for training: one real sample set per label, plus a
        # pooled batch of fake samples drawn from every label's generator.
        real_samples = []
        fake_samples = []
        for i in range(cfg.k_label):
            real_samples.append(self.oracle_samples_list[i])
            fake_samples.append(self.gen_list[i].sample(cfg.samples_num // cfg.k_label,
                                                        8 * cfg.batch_size))
        dis_samples_list = [torch.cat(fake_samples, dim=0)] + real_samples
        dis_data = CatClasDataIter(dis_samples_list)

        for epoch in range(d_epoch):
            # ===Train===
            d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader,
                                                     self.dis_criterion, self.dis_opt)

        # ===Test===
        self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f' % (
            phase, step, d_loss, train_acc))

        if cfg.if_save and not cfg.if_test and phase == 'MLE':
            torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
def cal_metrics_with_label(self, label_i):
    """Calculate all evaluation metrics for generated samples of one category.

    :param label_i: int index of the category to evaluate.
    :return: list of metric scores, one per metric in ``self.all_metrics``.
    """
    # isinstance is the idiomatic type check; the original message
    # ('missing label') misdescribed what is actually being checked.
    assert isinstance(label_i, int), 'label_i must be an int'

    with torch.no_grad():
        # Prepare data for evaluation.
        eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i)
        gen_data = GenDataIter(eval_samples)
        gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)
        # A smaller, separately drawn sample used only for Self-BLEU.
        gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i),
                                        self.idx2word_dict)
        clas_data = CatClasDataIter([eval_samples], label_i)

        # Reset each metric with the freshly generated data.
        self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens)
        self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i)
        self.nll_div.reset(self.gen, gen_data.loader, label_i)
        self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
        self.clas_acc.reset(self.clas, clas_data.loader)
        self.ppl.reset(gen_tokens)

        return [metric.get_score() for metric in self.all_metrics]