def __init__(self, opt):
    """Set up logging, vocabulary, data iterators, criteria and metrics.

    :param opt: parsed argument object; stored on ``self.opt`` and dumped
        to the log by ``show_config``.
    """
    # In test mode log to a single file, otherwise also mirror into save_root.
    self.log = create_logger(__name__, silent=False, to_disk=True,
                             log_file=cfg.log_filename if cfg.if_test
                             else [cfg.log_filename, cfg.save_root + 'log.txt'])
    self.sig = Signal(cfg.signal_file)
    self.opt = opt
    self.show_config()

    self.clas = None

    # load dictionary
    self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset)

    # Dataloader
    # Loading is best-effort: a run may only have the plain data or only the
    # per-category data. Only ordinary errors are tolerated (a bare except
    # would also hide KeyboardInterrupt/SystemExit) and each skip is logged.
    # NOTE(review): PPL below still reads self.train_data/self.test_data, so
    # a failure here surfaces as an AttributeError there - confirm intent.
    try:
        self.train_data = GenDataIter(cfg.train_data)
        self.test_data = GenDataIter(cfg.test_data, if_test_data=True)
    except Exception as err:
        self.log.info('Skip loading train/test data: {!r}'.format(err))

    try:
        self.train_data_list = [GenDataIter(cfg.cat_train_data.format(i))
                                for i in range(cfg.k_label)]
        self.test_data_list = [GenDataIter(cfg.cat_test_data.format(i), if_test_data=True)
                               for i in range(cfg.k_label)]
        self.clas_data_list = [GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True)
                               for i in range(cfg.k_label)]

        self.train_samples_list = [self.train_data_list[i].target for i in range(cfg.k_label)]
        self.clas_samples_list = [self.clas_data_list[i].target for i in range(cfg.k_label)]
    except Exception as err:
        self.log.info('Skip loading category data: {!r}'.format(err))

    # Criterion
    self.mle_criterion = nn.NLLLoss()
    self.dis_criterion = nn.CrossEntropyLoss()
    self.clas_criterion = nn.CrossEntropyLoss()

    # Optimizer
    self.clas_opt = None

    # Metrics
    self.bleu = BLEU('BLEU', gram=[2, 3, 4, 5], if_use=cfg.use_bleu)
    self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
    self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA)
    self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu)
    self.clas_acc = ACC(if_use=cfg.use_clas_acc)
    self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl)
    # clas_acc is kept separate and is not part of the aggregated metric list.
    self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ppl]
def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Sample random oracles until one has a ground-truth NLL in range.

    For each of the ``num`` rounds, fresh oracles are created until the NLL
    of the oracle on its own samples lies in ``[from_a, to_b]``; that oracle's
    state dict and its two sample sets are then saved under ``save_path``.

    Note: the file names do not depend on the round, so with ``num > 1`` each
    round writes to the same paths as the previous one.

    :param from_a: lower bound (inclusive) of the accepted NLL
    :param to_b: upper bound (inclusive) of the accepted NLL
    :param num: number of rounds
    :param save_path: path prefix the file names are appended to
    """
    for _ in range(num):
        accepted = False
        while not accepted:
            candidate = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                               cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
            if cfg.CUDA:
                candidate = candidate.cuda()

            sample_bs = 8 * cfg.batch_size
            full_set = candidate.sample(cfg.samples_num, sample_bs)
            half_set = candidate.sample(cfg.samples_num // 2, sample_bs)

            loader = GenDataIter(full_set).loader
            nll_value = NLL.cal_nll(candidate, loader, nn.NLLLoss())

            accepted = from_a <= nll_value <= to_b
            if not accepted:
                continue  # out of range: draw another oracle

            print('save ground truth: ', nll_value)
            prefix = 'oracle_lstm'
            torch.save(candidate.state_dict(), save_path + '{}.pt'.format(prefix))
            torch.save(full_set,
                       save_path + '{}_samples_{}.pt'.format(prefix, cfg.samples_num))
            torch.save(half_set,
                       save_path + '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2))
def create_multi_oracle(number):
    """Create ``number`` oracle models and store each one with its samples.

    For oracle ``i`` the state dict goes to
    ``cfg.multi_oracle_state_dict_path.format(i)`` and the two sample sets
    (``cfg.samples_num`` and half of it) go to
    ``cfg.multi_oracle_samples_path.format(i, size)``. The NLL of each oracle
    on its own large sample set is printed.

    :param number: how many oracles to create
    """
    for idx in range(number):
        print('Creating Oracle %d...' % idx)
        model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                       cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        if cfg.CUDA:
            model = model.cuda()

        # Draw both sample sets before anything is written to disk.
        sample_bs = 4 * cfg.batch_size
        full_set = model.sample(cfg.samples_num, sample_bs)
        half_set = model.sample(cfg.samples_num // 2, sample_bs)

        torch.save(model.state_dict(), cfg.multi_oracle_state_dict_path.format(idx))
        full_path = cfg.multi_oracle_samples_path.format(idx, cfg.samples_num)
        torch.save(full_set, full_path)
        half_path = cfg.multi_oracle_samples_path.format(idx, cfg.samples_num // 2)
        torch.save(half_set, half_path)

        loader = GenDataIter(full_set).loader
        criterion = nn.NLLLoss()
        nll_value = NLL.cal_nll(model, loader, criterion)
        print('Oracle %d Groud Truth: %.4f' % (idx, nll_value))
def create_oracle():
    """Build a fresh Oracle, save its weights and two sample sets.

    The state dict is written to ``cfg.oracle_state_dict_path``; sample sets
    of ``cfg.samples_num`` and ``cfg.samples_num // 2`` sequences are written
    to ``cfg.oracle_samples_path`` formatted with their size. Finally the NLL
    of the oracle on its own large sample set is printed.
    """
    import config as cfg
    from models.Oracle import Oracle

    print('Creating Oracle...')
    model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                   cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
    if cfg.CUDA:
        model = model.cuda()

    torch.save(model.state_dict(), cfg.oracle_state_dict_path)

    sample_bs = 4 * cfg.batch_size

    # large
    full_set = model.sample(cfg.samples_num, sample_bs)
    torch.save(full_set, cfg.oracle_samples_path.format(cfg.samples_num))

    # small
    half_size = cfg.samples_num // 2
    half_set = model.sample(half_size, sample_bs)
    torch.save(half_set, cfg.oracle_samples_path.format(cfg.samples_num // 2))

    loader = GenDataIter(full_set).loader
    nll_value = NLL.cal_nll(model, loader, nn.NLLLoss())
    print('NLL_Oracle Groud Truth: %.4f' % nll_value)
def __init__(self, opt):
    """Prepare logger, oracle models, oracle data, criteria and metrics.

    :param opt: parsed argument object; kept on ``self.opt`` and printed by
        ``show_config``.
    """
    if cfg.if_test:
        log_target = cfg.log_filename
    else:
        log_target = [cfg.log_filename, cfg.save_root + 'log.txt']
    self.log = create_logger(__name__, silent=False, to_disk=True, log_file=log_target)
    self.sig = Signal(cfg.signal_file)
    self.opt = opt

    def new_oracle():
        # Every oracle (general and per-label) shares the same hyper-parameters.
        return Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                      cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)

    # oracle, generator, discriminator
    self.oracle = new_oracle()
    self.oracle_list = [new_oracle() for _ in range(cfg.k_label)]

    self.dis = None
    self.clas = None

    self.show_config()
    self.check_oracle()  # Create Oracle models if not exist

    # DataLoader
    self.oracle_samples = torch.load(cfg.oracle_samples_path.format(cfg.samples_num))
    self.oracle_samples_list = []
    for label in range(cfg.k_label):
        label_path = cfg.multi_oracle_samples_path.format(label, cfg.samples_num)
        self.oracle_samples_list.append(torch.load(label_path))

    self.oracle_data = GenDataIter(self.oracle_samples)
    self.oracle_data_list = [GenDataIter(self.oracle_samples_list[label])
                             for label in range(cfg.k_label)]

    # Criterion
    self.mle_criterion = nn.NLLLoss()
    self.dis_criterion = nn.CrossEntropyLoss()

    # Metrics
    self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA)
    self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
    self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA)
    self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div]
def evaluation(self, eval_type):
    """Score the current generator: quality term Fq, diversity term Fd.

    Note that the eval data should be the same for all children.

    :param eval_type: one of 'standard', 'rsgan', 'nll' or 'Ra'
    :return: tuple ``(Fq, Fd, score)`` with
        ``score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd``
    :raises NotImplementedError: for any other ``eval_type``
    """
    n_samples = cfg.eval_b_num * cfg.batch_size
    sample_bs = cfg.max_bn * cfg.batch_size
    eval_samples = self.gen.sample(n_samples, sample_bs)
    gen_data = GenDataIter(eval_samples)

    # Fd (NLL_div); skipped entirely when its weight is zero
    Fd = 0
    if cfg.lambda_fd != 0:
        Fd = NLL.cal_nll(self.gen, gen_data.loader, self.mle_criterion)

    # Fq
    if eval_type == 'standard':
        Fq = self.eval_d_out_fake.mean().cpu().item()
    elif eval_type == 'rsgan':
        _, d_loss = get_losses(self.eval_d_out_real, self.eval_d_out_fake, 'rsgan')
        Fq = d_loss.item()
    elif eval_type == 'nll':
        # NLL_Oracle, negated
        Fq = 0
        if cfg.lambda_fq != 0:
            Fq = -NLL.cal_nll(self.oracle, gen_data.loader, self.mle_criterion)
    elif eval_type == 'Ra':
        real_mean = torch.mean(self.eval_d_out_real)
        Fq = torch.sigmoid(self.eval_d_out_fake - real_mean).sum().item()
    else:
        raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type)

    score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd
    return Fq, Fd, score
def evaluation(self, eval_type):
    """Score the current generator over all labels.

    Note that the eval data should be the same for all children.

    :param eval_type: a name containing 'bleu' (its last character is read
        as the n-gram order) or containing 'Ra'
    :return: tuple ``(Fq, Fd, score)`` with
        ``score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd``
    :raises NotImplementedError: for any other ``eval_type``
    """
    n_samples = cfg.eval_b_num * cfg.batch_size
    sample_bs = cfg.max_bn * cfg.batch_size
    eval_samples = [self.gen.sample(n_samples, sample_bs, label_i=label)
                    for label in range(cfg.k_label)]

    # Fd: per-label NLL_div, summed
    Fd = 0
    if cfg.lambda_fd != 0:
        Fd = sum(
            NLL.cal_nll_with_label(self.gen, GenDataIter(eval_samples[label]).loader,
                                   label, self.mle_criterion)
            for label in range(cfg.k_label))

    # Fq
    if 'bleu' in eval_type:
        Fq = sum(self.bleu[label].get_score(given_gram=int(eval_type[-1]))
                 for label in range(cfg.k_label))
    elif 'Ra' in eval_type:
        total = sum(
            torch.sigmoid(self.eval_d_out_fake[label]
                          - torch.mean(self.eval_d_out_real[label])).sum()
            for label in range(cfg.k_label))
        Fq = total.item()
    else:
        raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type)

    score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd
    return Fq, Fd, score
def evaluation(self, eval_type):
    """Score the current generator: quality term Fq, diversity term Fd.

    Note that the eval data should be the same for all children.

    :param eval_type: 'standard', 'rsgan', a name containing 'bleu' (its
        last character is read as the n-gram order) or containing 'Ra'
    :return: tuple ``(Fq, Fd, score)`` with
        ``score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd``
    :raises NotImplementedError: for any other ``eval_type``
    """
    n_samples = cfg.eval_b_num * cfg.batch_size
    sample_bs = cfg.max_bn * cfg.batch_size
    eval_samples = self.gen.sample(n_samples, sample_bs)
    gen_data = GenDataIter(eval_samples)

    # Fd (NLL_div); skipped entirely when its weight is zero
    Fd = 0
    if cfg.lambda_fd != 0:
        Fd = NLL.cal_nll(self.gen, gen_data.loader, self.mle_criterion)

    # Fq
    if eval_type == 'standard':
        Fq = self.eval_d_out_fake.mean().cpu().item()
    elif eval_type == 'rsgan':
        _, d_loss = get_losses(self.eval_d_out_real, self.eval_d_out_fake, 'rsgan')
        Fq = d_loss.item()
    elif 'bleu' in eval_type:
        # The metric is re-seeded with the new samples even if Fq ends up unused.
        generated_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)
        self.bleu.reset(test_text=generated_tokens)
        Fq = 0
        if cfg.lambda_fq != 0:
            Fq = self.bleu.get_score(given_gram=int(eval_type[-1]))
    elif 'Ra' in eval_type:
        real_mean = torch.mean(self.eval_d_out_real)
        Fq = torch.sigmoid(self.eval_d_out_fake - real_mean).sum().item()
    else:
        raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type)

    score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd
    return Fq, Fd, score
def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Sample random oracles until one has a ground-truth NLL in range.

    For each of the ``num`` rounds, fresh oracles are created until the NLL
    of the oracle on its own samples lies in ``[from_a, to_b]``. The accepted
    oracle's state dict and its two sample sets are saved in a new directory
    ``<save_path>oracle_data_gt<NLL>_<timestamp>``.

    :param from_a: lower bound (inclusive) of the accepted NLL
    :param to_b: upper bound (inclusive) of the accepted NLL
    :param num: number of oracles to create
    :param save_path: path prefix the output directory name is appended to
    """
    for _ in range(num):
        while True:
            oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                            cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
            if cfg.CUDA:
                oracle = oracle.cuda()

            big_samples = oracle.sample(cfg.samples_num, 8 * cfg.batch_size)
            small_samples = oracle.sample(cfg.samples_num // 2, 8 * cfg.batch_size)

            oracle_data = GenDataIter(big_samples)
            mle_criterion = nn.NLLLoss()
            groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion)

            if from_a <= groud_truth <= to_b:
                dir_path = save_path + 'oracle_data_gt{:.2f}_{}'.format(
                    groud_truth, strftime("%m%d_%H%M%S", localtime()))
                # makedirs(exist_ok=True) avoids the exists()/mkdir() race and
                # also works when save_path itself has not been created yet.
                os.makedirs(dir_path, exist_ok=True)
                print('save ground truth: ', groud_truth)
                prefix = dir_path + '/oracle_lstm'
                torch.save(oracle.state_dict(), '{}.pt'.format(prefix))
                torch.save(big_samples,
                           '{}_samples_{}.pt'.format(prefix, cfg.samples_num))
                torch.save(small_samples,
                           '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2))
                break
class BasicInstructor:
    """Base instructor for synthetic (oracle) data experiments.

    Holds the oracle models, oracle sample iterators, loss criteria and NLL
    metrics, and provides the shared training, evaluation and saving helpers.
    ``self.gen`` and ``self.dis`` are used by several methods but are not
    created here (``self.dis`` starts as ``None``); presumably subclasses
    assign them - confirm against the concrete instructors.
    """

    def __init__(self, opt):
        """Prepare logger, oracle models, oracle data, criteria and metrics.

        :param opt: parsed argument object; kept on ``self.opt`` and printed
            by ``show_config``.
        """
        # In test mode log to a single file, otherwise also mirror into save_root.
        self.log = create_logger(__name__, silent=False, to_disk=True,
                                 log_file=cfg.log_filename if cfg.if_test
                                 else [cfg.log_filename, cfg.save_root + 'log.txt'])
        self.sig = Signal(cfg.signal_file)
        self.opt = opt

        # oracle, generator, discriminator
        self.oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                             cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.oracle_list = [Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                                   cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
                            for _ in range(cfg.k_label)]

        self.dis = None
        self.clas = None

        self.show_config()
        self.check_oracle()  # Create Oracle models if not exist

        # DataLoader
        self.oracle_samples = torch.load(cfg.oracle_samples_path.format(cfg.samples_num))
        self.oracle_samples_list = [
            torch.load(cfg.multi_oracle_samples_path.format(i, cfg.samples_num))
            for i in range(cfg.k_label)]

        self.oracle_data = GenDataIter(self.oracle_samples)
        self.oracle_data_list = [GenDataIter(self.oracle_samples_list[i])
                                 for i in range(cfg.k_label)]

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_criterion = nn.CrossEntropyLoss()

        # Metrics
        self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA)
        self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
        self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA)
        self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div]

    def _run(self):
        """Placeholder training entry point; only prints a notice."""
        print('Nothing to run in Basic Instructor!')

    def _test(self):
        """Placeholder test entry point; does nothing."""
        pass

    def init_model(self):
        """Load pre-trained weights as configured and move models to the GPU."""
        if cfg.oracle_pretrain:
            if not os.path.exists(cfg.oracle_state_dict_path):
                create_oracle()
            self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path))

        if cfg.dis_pretrain:
            self.log.info(
                'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path))
            self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path))
        if cfg.gen_pretrain:
            self.log.info(
                'Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path))
            # Only map onto a CUDA device when CUDA is enabled; a hard-coded
            # 'cuda:N' target makes torch.load fail on CPU-only runs.
            gen_location = 'cuda:{}'.format(cfg.device) if cfg.CUDA else 'cpu'
            self.gen.load_state_dict(
                torch.load(cfg.pretrained_gen_path, map_location=gen_location))

        if cfg.CUDA:
            self.oracle = self.oracle.cuda()
            self.gen = self.gen.cuda()
            self.dis = self.dis.cuda()

    def train_gen_epoch(self, model, data_loader, criterion, optimizer):
        """Train ``model`` for one pass over ``data_loader``.

        :param model: generator exposing ``init_hidden`` and ``forward``
        :param data_loader: iterable of dicts with 'input' and 'target'
        :param criterion: loss applied to ``(pred, target.view(-1))``
        :param optimizer: optimizer stepping ``model``'s parameters
        :return: mean loss per batch
        """
        total_loss = 0
        for i, data in enumerate(data_loader):
            inp, target = data['input'], data['target']
            # Original author's debug note (translated): inp.shape was seen as
            # [64, 20] and inp[0][1:] matched target[0][:-1].
            if cfg.CUDA:
                inp, target = inp.cuda(), target.cuda()

            hidden = model.init_hidden(data_loader.batch_size)
            # per the original note, for seqGAN pred is
            # (batch_size * seq_len) x vocab_size - TODO confirm
            pred = model.forward(inp, hidden)
            loss = criterion(pred, target.view(-1))
            self.optimize(optimizer, loss, model)
            total_loss += loss.item()
        return total_loss / len(data_loader)

    def train_dis_epoch(self, model, data_loader, criterion, optimizer):
        """Train a discriminator-style ``model`` for one pass over the data.

        :return: tuple ``(mean loss per batch, accuracy over all samples)``
        """
        total_loss = 0
        total_acc = 0
        total_num = 0
        for i, data in enumerate(data_loader):
            inp, target = data['input'], data['target']
            if cfg.CUDA:
                inp, target = inp.cuda(), target.cuda()

            pred = model.forward(inp)
            loss = criterion(pred, target)
            self.optimize(optimizer, loss, model)

            total_loss += loss.item()
            total_acc += torch.sum((pred.argmax(dim=-1) == target)).item()
            total_num += inp.size(0)

        total_loss /= len(data_loader)
        total_acc /= total_num
        return total_loss, total_acc

    @staticmethod
    def eval_dis(model, data_loader, criterion):
        """Evaluate ``model`` on ``data_loader`` without tracking gradients.

        :return: tuple ``(mean loss per batch, accuracy over all samples)``
        """
        total_loss = 0
        total_acc = 0
        total_num = 0
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                inp, target = data['input'], data['target']
                if cfg.CUDA:
                    inp, target = inp.cuda(), target.cuda()

                pred = model.forward(inp)
                loss = criterion(pred, target)
                total_loss += loss.item()
                total_acc += torch.sum((pred.argmax(dim=-1) == target)).item()
                total_num += inp.size(0)
            total_loss /= len(data_loader)
            total_acc /= total_num
        return total_loss, total_acc

    @staticmethod
    def optimize_multi(opts, losses):
        """Step each optimizer with its paired loss, in order.

        The graph is retained for every backward pass except the one paired
        with the last optimizer.
        """
        for i, (opt, loss) in enumerate(zip(opts, losses)):
            opt.zero_grad()
            loss.backward(retain_graph=i < len(opts) - 1)
            opt.step()

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        """Run one optimisation step.

        :param opt: optimizer to step
        :param loss: loss tensor to back-propagate
        :param model: if given, its gradients are clipped to ``cfg.clip_norm``
        :param retain_graph: forwarded to ``loss.backward``
        """
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()

    def show_config(self):
        """Show parser parameters settings"""
        self.log.info(100 * '=')
        self.log.info('> training arguments:')
        for arg in vars(self.opt):
            self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))
        self.log.info(100 * '=')

    def cal_metrics(self, fmt_str=False):
        """
        Calculate metrics
        :param fmt_str: if return format string for logging
        :return: list of metric scores, or a 'name = score' string
        """
        with torch.no_grad():
            # Prepare data for evaluation
            gen_data = GenDataIter(self.gen.sample(cfg.samples_num, 4 * cfg.batch_size))

            # Reset metrics
            self.nll_oracle.reset(self.oracle, gen_data.loader)
            self.nll_gen.reset(self.gen, self.oracle_data.loader)
            self.nll_div.reset(self.gen, gen_data.loader)

            if fmt_str:
                return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score())
                                  for metric in self.all_metrics])
            return [metric.get_score() for metric in self.all_metrics]

    def cal_metrics_with_label(self, label_i):
        """Calculate all metrics for the single category ``label_i``.

        :param label_i: integer category index
        :return: list of metric scores in ``self.all_metrics`` order
        """
        assert isinstance(label_i, int), 'missing label'

        with torch.no_grad():
            # Prepare data for evaluation
            eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size,
                                           label_i=label_i)
            gen_data = GenDataIter(eval_samples)

            # Reset metrics
            self.nll_oracle.reset(self.oracle_list[label_i], gen_data.loader, label_i)
            self.nll_gen.reset(self.gen, self.oracle_data_list[label_i].loader, label_i)
            self.nll_div.reset(self.gen, gen_data.loader, label_i)

            return [metric.get_score() for metric in self.all_metrics]

    def comb_metrics(self, fmt_str=False):
        """Collect per-label metrics, grouped by metric.

        :param fmt_str: if return format string for logging
        :return: list with one row per metric (one entry per label), or a
            'name = scores' string
        """
        all_scores = [self.cal_metrics_with_label(label_i)
                      for label_i in range(cfg.k_label)]
        all_scores = np.array(all_scores).T.tolist()  # each row for each metric

        if fmt_str:
            return ', '.join(['%s = %s' % (metric.get_name(), score)
                              for (metric, score) in zip(self.all_metrics, all_scores)])
        return all_scores

    def _save(self, phase, epoch):
        """Save model state dict and generator's samples"""
        # The generator weights are not saved during the 'ADV' phase.
        if phase != 'ADV':
            torch.save(self.gen.state_dict(),
                       cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch))
        save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch)
        samples = self.gen.sample(cfg.batch_size, cfg.batch_size)
        write_tensor(save_sample_path, samples)

    def update_temperature(self, i, N):
        """Set the generator temperature for step ``i`` of ``N``."""
        self.gen.temperature.data = torch.Tensor(
            [get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)])
        if cfg.CUDA:
            self.gen.temperature.data = self.gen.temperature.data.cuda()

    def check_oracle(self):
        """Create missing oracle files, then load all oracle state dicts."""
        if not cfg.oracle_pretrain:
            # Fresh oracles were requested: build each exactly once. A second
            # create_oracle() call would only overwrite the files just written.
            create_oracle()
            create_multi_oracle(cfg.k_label)
        else:
            # General text generation Oracle model
            if not os.path.exists(cfg.oracle_samples_path.format(cfg.samples_num)):
                create_oracle()

            # Category text generation Oracle models
            if any(not os.path.exists(cfg.multi_oracle_samples_path.format(i, cfg.samples_num))
                   for i in range(cfg.k_label)):
                create_multi_oracle(cfg.k_label)

        # Load Oracle state dict
        self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path))
        for i in range(cfg.k_label):
            oracle_path = cfg.multi_oracle_state_dict_path.format(i)
            self.oracle_list[i].load_state_dict(torch.load(oracle_path))
class BasicInstructor:
    """Base instructor for real-data experiments.

    Holds the vocabulary, data iterators, loss criteria and text metrics, and
    provides the shared training, evaluation and saving helpers.
    ``self.gen`` and ``self.dis`` are used by several methods but are not
    created here; presumably subclasses assign them - confirm against the
    concrete instructors.
    """

    def __init__(self, opt):
        """Set up logging, vocabulary, data iterators, criteria and metrics.

        :param opt: parsed argument object; stored on ``self.opt`` and dumped
            to the log by ``show_config``.
        """
        # In test mode log to a single file, otherwise also mirror into save_root.
        self.log = create_logger(__name__, silent=False, to_disk=True,
                                 log_file=cfg.log_filename if cfg.if_test
                                 else [cfg.log_filename, cfg.save_root + 'log.txt'])
        self.sig = Signal(cfg.signal_file)
        self.opt = opt
        self.show_config()

        self.clas = None

        # load dictionary
        self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset)

        # Dataloader
        # Loading is best-effort: a run may only have the plain data or only
        # the per-category data. Only ordinary errors are tolerated (a bare
        # except would also hide KeyboardInterrupt/SystemExit) and each skip
        # is logged.
        # NOTE(review): PPL below still reads self.train_data/self.test_data,
        # so a failure here surfaces as an AttributeError there - confirm.
        try:
            self.train_data = GenDataIter(cfg.train_data)
            self.test_data = GenDataIter(cfg.test_data, if_test_data=True)
        except Exception as err:
            self.log.info('Skip loading train/test data: {!r}'.format(err))

        try:
            self.train_data_list = [
                GenDataIter(cfg.cat_train_data.format(i))
                for i in range(cfg.k_label)
            ]
            self.test_data_list = [
                GenDataIter(cfg.cat_test_data.format(i), if_test_data=True)
                for i in range(cfg.k_label)
            ]
            self.clas_data_list = [
                GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True)
                for i in range(cfg.k_label)
            ]

            self.train_samples_list = [
                self.train_data_list[i].target for i in range(cfg.k_label)
            ]
            self.clas_samples_list = [
                self.clas_data_list[i].target for i in range(cfg.k_label)
            ]
        except Exception as err:
            self.log.info('Skip loading category data: {!r}'.format(err))

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_criterion = nn.CrossEntropyLoss()
        self.clas_criterion = nn.CrossEntropyLoss()

        # Optimizer
        self.clas_opt = None

        # Metrics
        self.bleu = BLEU('BLEU', gram=[2, 3, 4, 5], if_use=cfg.use_bleu)
        self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
        self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA)
        self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu)
        self.clas_acc = ACC(if_use=cfg.use_clas_acc)
        self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl)
        # clas_acc is kept separate and is not part of the aggregated list.
        self.all_metrics = [
            self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ppl
        ]

    def _run(self):
        """Placeholder training entry point; only prints a notice."""
        print('Nothing to run in Basic Instructor!')

    def _test(self):
        """Placeholder test entry point; does nothing."""
        pass

    def init_model(self):
        """Load pre-trained weights as configured and move models to the GPU."""
        if cfg.dis_pretrain:
            self.log.info('Load pre-trained discriminator: {}'.format(
                cfg.pretrained_dis_path))
            self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path))
        if cfg.gen_pretrain:
            self.log.info('Load MLE pre-trained generator: {}'.format(
                cfg.pretrained_gen_path))
            self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path))

        if cfg.CUDA:
            self.gen = self.gen.cuda()
            self.dis = self.dis.cuda()

    def train_gen_epoch(self, model, data_loader, criterion, optimizer):
        """Train ``model`` for one pass over ``data_loader``.

        :param model: generator exposing ``init_hidden`` and ``forward``
        :param data_loader: iterable of dicts with 'input' and 'target'
        :param criterion: loss applied to ``(pred, target.view(-1))``
        :param optimizer: optimizer stepping ``model``'s parameters
        :return: mean loss per batch
        """
        total_loss = 0
        for i, data in enumerate(data_loader):
            inp, target = data['input'], data['target']
            if cfg.CUDA:
                inp, target = inp.cuda(), target.cuda()

            hidden = model.init_hidden(data_loader.batch_size)
            pred = model.forward(inp, hidden)
            loss = criterion(pred, target.view(-1))
            self.optimize(optimizer, loss, model)
            total_loss += loss.item()
        return total_loss / len(data_loader)

    def train_dis_epoch(self, model, data_loader, criterion, optimizer):
        """Train a discriminator-style ``model`` for one pass over the data.

        :return: tuple ``(mean loss per batch, accuracy over all samples)``
        """
        total_loss = 0
        total_acc = 0
        total_num = 0
        for i, data in enumerate(data_loader):
            inp, target = data['input'], data['target']
            if cfg.CUDA:
                inp, target = inp.cuda(), target.cuda()

            pred = model.forward(inp)
            loss = criterion(pred, target)
            self.optimize(optimizer, loss, model)

            total_loss += loss.item()
            total_acc += torch.sum((pred.argmax(dim=-1) == target)).item()
            total_num += inp.size(0)

        total_loss /= len(data_loader)
        total_acc /= total_num
        return total_loss, total_acc

    def train_classifier(self, epochs):
        """
        Classifier for calculating the classification accuracy metric of category text generation.

        Note: the train and test data for the classifier is opposite to the generator.
        Because the classifier is to calculate the classification accuracy of the generated samples
        where are trained on self.train_samples_list.

        Since there's no test data in synthetic data (oracle data), the synthetic data experiments
        doesn't need a classifier.

        :param epochs: number of training epochs
        """
        import copy

        # Prepare data for Classifier
        clas_data = CatClasDataIter(self.clas_samples_list)
        eval_clas_data = CatClasDataIter(self.train_samples_list)

        max_acc = 0
        best_clas = None
        for epoch in range(epochs):
            c_loss, c_acc = self.train_dis_epoch(self.clas, clas_data.loader,
                                                 self.clas_criterion, self.clas_opt)
            _, eval_acc = self.eval_dis(self.clas, eval_clas_data.loader,
                                        self.clas_criterion)
            if eval_acc > max_acc:
                best_clas = copy.deepcopy(
                    self.clas.state_dict())  # save the best classifier
                max_acc = eval_acc
            self.log.info(
                '[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = %.4f, max_eval_acc = %.4f',
                epoch, c_loss, c_acc, eval_acc, max_acc)

        # No snapshot exists when epochs == 0 or accuracy never rose above 0;
        # in that case keep the current weights instead of loading None.
        if best_clas is not None:
            self.clas.load_state_dict(best_clas)  # Reload the best classifier

    @staticmethod
    def eval_dis(model, data_loader, criterion):
        """Evaluate ``model`` on ``data_loader`` without tracking gradients.

        :return: tuple ``(mean loss per batch, accuracy over all samples)``
        """
        total_loss = 0
        total_acc = 0
        total_num = 0
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                inp, target = data['input'], data['target']
                if cfg.CUDA:
                    inp, target = inp.cuda(), target.cuda()

                pred = model.forward(inp)
                loss = criterion(pred, target)
                total_loss += loss.item()
                total_acc += torch.sum((pred.argmax(dim=-1) == target)).item()
                total_num += inp.size(0)
            total_loss /= len(data_loader)
            total_acc /= total_num
        return total_loss, total_acc

    @staticmethod
    def optimize_multi(opts, losses):
        """Step each optimizer with its paired loss, in order.

        The graph is retained for every backward pass except the one paired
        with the last optimizer.
        """
        for i, (opt, loss) in enumerate(zip(opts, losses)):
            opt.zero_grad()
            loss.backward(retain_graph=i < len(opts) - 1)
            opt.step()

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        """Run one optimisation step.

        :param opt: optimizer to step
        :param loss: loss tensor to back-propagate
        :param model: if given, its gradients are clipped to ``cfg.clip_norm``
        :param retain_graph: forwarded to ``loss.backward``
        """
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()

    def show_config(self):
        """Log every attribute of ``self.opt`` between two separator lines."""
        self.log.info(100 * '=')
        self.log.info('> training arguments:')
        for arg in vars(self.opt):
            self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))
        self.log.info(100 * '=')

    def cal_metrics(self, fmt_str=False):
        """
        Calculate metrics
        :param fmt_str: if return format string for logging
        :return: list of metric scores, or a 'name = score' string
        """
        with torch.no_grad():
            # Prepare data for evaluation
            eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            gen_data = GenDataIter(eval_samples)
            gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)
            gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200),
                                            self.idx2word_dict)

            # Reset metrics
            self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens)
            self.nll_gen.reset(self.gen, self.train_data.loader)
            self.nll_div.reset(self.gen, gen_data.loader)
            self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
            self.ppl.reset(gen_tokens)

            if fmt_str:
                return ', '.join([
                    '%s = %s' % (metric.get_name(), metric.get_score())
                    for metric in self.all_metrics
                ])
            return [metric.get_score() for metric in self.all_metrics]

    def cal_metrics_with_label(self, label_i):
        """Calculate all metrics for the single category ``label_i``.

        :param label_i: integer category index
        :return: list of metric scores in ``self.all_metrics`` order
        """
        assert isinstance(label_i, int), 'missing label'

        with torch.no_grad():
            # Prepare data for evaluation
            eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size,
                                           label_i=label_i)
            gen_data = GenDataIter(eval_samples)
            gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)
            gen_tokens_s = tensor_to_tokens(
                self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict)
            clas_data = CatClasDataIter([eval_samples], label_i)

            # Reset metrics
            self.bleu.reset(test_text=gen_tokens,
                            real_text=self.test_data_list[label_i].tokens)
            self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i)
            self.nll_div.reset(self.gen, gen_data.loader, label_i)
            self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
            # clas_acc is reset here but is not in all_metrics, so its score
            # is not part of the returned list.
            self.clas_acc.reset(self.clas, clas_data.loader)
            self.ppl.reset(gen_tokens)

            return [metric.get_score() for metric in self.all_metrics]

    def comb_metrics(self, fmt_str=False):
        """Collect per-label metrics, grouped by metric.

        :param fmt_str: if return format string for logging
        :return: list with one row per metric (one entry per label), or a
            'name = scores' string
        """
        all_scores = [
            self.cal_metrics_with_label(label_i)
            for label_i in range(cfg.k_label)
        ]
        all_scores = np.array(all_scores).T.tolist()  # each row for each metric

        if fmt_str:
            return ', '.join([
                '%s = %s' % (metric.get_name(), score)
                for (metric, score) in zip(self.all_metrics, all_scores)
            ])
        return all_scores

    def _save(self, phase, epoch):
        """Save model state dict and generator's samples"""
        # The generator weights are not saved during the 'ADV' phase.
        if phase != 'ADV':
            torch.save(
                self.gen.state_dict(),
                cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch))
        # NOTE(review): the file name carries cfg.samples_num while 5000
        # samples are actually written - confirm which is intended.
        save_sample_path = cfg.save_samples_root + 'samples_{}_{}_{:05d}.txt'.format(
            phase, cfg.samples_num, epoch)
        samples = self.gen.sample(5000, cfg.batch_size)
        write_tokens(save_sample_path,
                     tensor_to_tokens(samples, self.idx2word_dict))

    def update_temperature(self, i, N):
        """Set the generator temperature for step ``i`` of ``N``."""
        self.gen.temperature.data = torch.Tensor(
            [get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)])
        if cfg.CUDA:
            self.gen.temperature.data = self.gen.temperature.data.cuda()
def evaluation(self, eval_type):
    """Score the current generator over all labels.

    Note that the eval data should be the same for all children.

    :param eval_type: a name containing 'nll', or exactly 'Ra'; if it also
        contains 'f1' the per-label values are combined as
        ``a * b / (a + b)`` (two labels) instead of being summed
    :return: tuple ``(Fq, Fd, score)`` with
        ``score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd``
    :raises NotImplementedError: for an unknown ``eval_type``, or for the
        'f1' combination with ``cfg.k_label`` other than 1 or 2
    """

    def f1_combine(values):
        # Shared by Fd and Fq; an empty list (term disabled) yields 0.
        if cfg.k_label == 1:
            return values[0] if len(values) > 0 else 0
        if cfg.k_label == 2:
            if len(values) > 0:
                return values[0] * values[1] / (values[0] + values[1])
            return 0
        raise NotImplementedError("k_label = %d is not supported" % cfg.k_label)

    use_f1 = 'f1' in eval_type
    n_samples = cfg.eval_b_num * cfg.batch_size
    sample_bs = cfg.max_bn * cfg.batch_size
    eval_samples = [self.gen.sample(n_samples, sample_bs, label_i=label)
                    for label in range(cfg.k_label)]

    # Fd
    Fd = 0
    if cfg.lambda_fd != 0:
        nll_div = [
            NLL.cal_nll_with_label(self.gen, GenDataIter(eval_samples[label]).loader,
                                   label, self.mle_criterion)
            for label in range(cfg.k_label)
        ]
        Fd = f1_combine(nll_div) if use_f1 else sum(nll_div)

    # Fq
    if 'nll' in eval_type:
        nll_oracle = []
        for label in range(cfg.k_label):
            loader = GenDataIter(eval_samples[label]).loader
            if cfg.lambda_fq == 0:
                continue
            nll_oracle.append(-NLL.cal_nll_with_label(
                self.oracle_list[label], loader, label, self.mle_criterion))
        Fq = f1_combine(nll_oracle) if use_f1 else sum(nll_oracle)  # sum
    elif eval_type == 'Ra':
        total = sum(
            torch.sigmoid(self.eval_d_out_fake[label]
                          - torch.mean(self.eval_d_out_real[label])).sum()
            for label in range(cfg.k_label))
        Fq = total.item()
    else:
        raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type)

    score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd
    return Fq, Fd, score