def analysis_evaluation(self):
    self.logger.info("Start Analyzing ...")
    start_time = time.time()
    test_batches = self.data.test
    self.logger.info("Total {} batches to analyze".format(len(test_batches)))
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    sample_bag = []
    try:
        for idx, batch in enumerate(test_batches):
            if idx % 10 == 0:
                print("Idx: {}".format(idx))
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                # The first row carries the condition bit; strip it off the batch.
                seq_len -= 1
                bit = batch[0, :]
                batch = batch[1:, :]
                bit = GVar(bit)
            else:
                bit = None
            feed = self.data.get_feed(batch)
            if self.args.swap > 0.00001:
                feed = swap_by_batch(feed, self.args.swap)
            if self.args.replace > 0.00001:
                feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)
            target = GVar(batch)
            # target: (seq_len, batch_sz); decoded: (seq_len, batch_sz, dict_sz)
            # tup holds 'mean' and 'logvar' for the Gaussian posterior, 'mu' for vMF.
            recon_loss, kld, aux_loss, tup, vecs, decoded = self.model(feed, target, bit)
            bag = self.analyze_batch(target, kld, tup, vecs, decoded)
            sample_bag += bag
            acc_loss += recon_loss.data * seq_len * batch_sz
            acc_kl_loss += torch.sum(kld).data
            acc_aux_loss += torch.sum(aux_loss).data
            acc_avg_cos += tup['avg_cos'].data
            acc_avg_norm += tup['avg_norm'].data
            cnt += 1
            batch_cnt += batch_sz
            all_cnt += batch_sz * seq_len
    except KeyboardInterrupt:
        print("early stop")
    self.write_samples(sample_bag)
    # .item() replaces 0-dim tensor indexing, which PyTorch >= 0.4 disallows.
    cur_loss = acc_loss.item() / all_cnt
    cur_kl = acc_kl_loss.item() / all_cnt
    cur_real_loss = cur_loss + cur_kl
    return cur_loss, cur_kl, cur_real_loss
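# A minimal usage sketch, not part of the original file: how the three values
# returned by `analysis_evaluation` are typically reported. `runner` is a
# hypothetical, fully constructed instance of this class, and `math` is assumed
# to be imported at module level alongside `time` and `torch`.
def report_analysis(runner):
    cur_loss, cur_kl, cur_real_loss = runner.analysis_evaluation()
    # Perplexity is the exponential of the per-token loss; cur_real_loss
    # already includes the KL term.
    print('loss {:5.2f} | KL {:5.2f} | ppl {:8.2f}'.format(
        cur_loss, cur_kl, math.exp(cur_real_loss)))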
def evaluate(self, args, model, dev_batches):
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    model.FLAG_train = False
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    start_time = time.time()
    for idx, batch in enumerate(dev_batches):
        seq_len, batch_sz = batch.size()
        if self.data.condition:
            seq_len -= 1
            bit = batch[0, :]
            batch = batch[1:, :]
            bit = GVar(bit)
        else:
            bit = None
        feed = self.data.get_feed(batch)
        if self.args.swap > 0.00001:
            feed = swap_by_batch(feed, self.args.swap)
        if self.args.replace > 0.00001:
            feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)
        target = GVar(batch)
        recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)
        acc_loss += recon_loss.data * seq_len * batch_sz
        acc_kl_loss += torch.sum(kld).data
        acc_aux_loss += torch.sum(aux_loss).data
        acc_avg_cos += tup['avg_cos'].data
        acc_avg_norm += tup['avg_norm'].data
        cnt += 1
        batch_cnt += batch_sz
        all_cnt += batch_sz * seq_len
    cur_loss = acc_loss.item() / all_cnt
    cur_kl = acc_kl_loss.item() / all_cnt
    cur_aux_loss = acc_aux_loss.item() / all_cnt
    cur_avg_cos = acc_avg_cos.item() / cnt
    cur_avg_norm = acc_avg_norm.item() / cnt
    cur_real_loss = cur_loss + cur_kl
    return cur_loss, cur_kl, cur_real_loss
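# Sketch of a plausible caller, not part of the original file: using `evaluate`
# between epochs for learning-rate decay on a dev-loss plateau. The halving
# policy is an assumption; `args.cur_lr` is the field `train_epo` below reads
# when it rebuilds the SGD optimizer.
def maybe_decay_lr(runner, args, model, dev_batches, best_loss):
    cur_loss, cur_kl, cur_real_loss = runner.evaluate(args, model, dev_batches)
    if cur_real_loss >= best_loss:
        args.cur_lr /= 2.0  # hypothetical plateau policy, not from the original code
    return min(best_loss, cur_real_loss)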
def analysis_eval_order(self, feed, batch, bit):
    # Swap ratios go up to 6x the base rate, so the base must stay below 1/3
    # for the largest corruption to remain a valid proportion.
    assert 0.0001 < self.args.swap < 0.33
    origin_feed = feed.clone()
    # Corrupt the input with increasing amounts of word-order swapping (1x..6x).
    feeds = [swap_by_batch(feed.clone(), self.args.swap * k) for k in range(1, 7)]
    target = GVar(batch)
    original_recon_loss, _, _, original_tup, original_vecs, _ = self.model(origin_feed, target, bit)
    # tup holds 'mean'/'logvar' for the Gaussian posterior and 'mu' for vMF.
    if 'Distnor' in self.instance_name:
        key_name = "mean"
    elif 'vmf' in self.instance_name:
        key_name = "mu"
    else:
        raise NotImplementedError
    original_mu = original_tup[key_name]
    recon_losses = [original_recon_loss.data]
    coses = []
    for corrupted_feed in feeds:
        recon_loss_k, _, _, tup_k, _, _ = self.model(corrupted_feed, target, bit)
        recon_losses.append(recon_loss_k.data)
        # Cosine between the clean posterior location and the corrupted one.
        coses.append(torch.mean(cos(original_mu, tup_k[key_name])).data)
    # [7 reconstruction losses (clean + 1x..6x), 6 cosine similarities (1x..6x)]
    return [recon_losses, coses]
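# `cos` is referenced above but defined elsewhere in the module. A minimal
# sketch of what it is assumed to be: PyTorch's cosine similarity taken over
# the latent dimension of the (batch_sz, lat_dim) posterior locations.
cos = torch.nn.CosineSimilarity(dim=1)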
def train_epo(self, args, model, train_batches, epo, epo_start_time, glob_iter):
    # Turn on training mode, which enables dropout.
    model.train()
    model.FLAG_train = True
    start_time = time.time()
    if self.args.optim == 'sgd':
        self.optim = torch.optim.SGD(model.parameters(), lr=self.args.cur_lr)
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    random.shuffle(train_batches)
    for idx, batch in enumerate(train_batches):
        self.optim.zero_grad()
        seq_len, batch_sz = batch.size()
        if self.data.condition:
            seq_len -= 1
            if self.model.input_cd_bit > 1:
                bit = batch[0, :]
                bit = GVar(bit)
            else:
                bit = None
            batch = batch[1:, :]
        else:
            bit = None
        feed = self.data.get_feed(batch)
        if self.args.swap > 0.00001:
            feed = swap_by_batch(feed, self.args.swap)
        if self.args.replace > 0.00001:
            feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)
        self.glob_iter += 1
        target = GVar(batch)
        recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)
        total_loss = recon_loss * seq_len \
                     + torch.mean(kld) * self.args.kl_weight \
                     + torch.mean(aux_loss) * args.aux_weight
        total_loss.backward()
        # `clip_grad_norm_` helps prevent the exploding gradient problem in
        # RNNs / LSTMs (renamed from `clip_grad_norm` in PyTorch 0.4.1).
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip, norm_type=2)
        self.optim.step()
        acc_loss += recon_loss.data * seq_len * batch_sz
        acc_kl_loss += torch.sum(kld).data
        acc_aux_loss += torch.sum(aux_loss).data
        acc_avg_cos += tup['avg_cos'].data
        acc_avg_norm += tup['avg_norm'].data
        cnt += 1
        batch_cnt += batch_sz
        all_cnt += batch_sz * seq_len
        if idx % args.log_interval == 0 and idx > 0:
            cur_loss = acc_loss.item() / all_cnt
            cur_kl = acc_kl_loss.item() / all_cnt
            cur_aux_loss = acc_aux_loss.item() / all_cnt
            cur_avg_cos = acc_avg_cos.item() / cnt
            cur_avg_norm = acc_avg_norm.item() / cnt
            cur_real_loss = cur_loss + cur_kl
            Runner.log_instant(self.writer, self.args, self.glob_iter, epo,
                               start_time, cur_avg_cos, cur_avg_norm,
                               cur_loss, cur_kl, cur_aux_loss, cur_real_loss)
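# Sketch of an epoch driver, not part of the original file, showing how
# `train_epo` and `evaluate` fit together; `runner` is a hypothetical instance
# of this class, and `math`/`time` are assumed to be module-level imports.
def run_training(runner, args, model, train_batches, dev_batches, num_epochs):
    glob_iter = 0
    for epo in range(1, num_epochs + 1):
        epo_start_time = time.time()
        runner.train_epo(args, model, train_batches, epo, epo_start_time, glob_iter)
        cur_loss, cur_kl, cur_real_loss = runner.evaluate(args, model, dev_batches)
        print('epoch {:d} | dev loss {:5.2f} | KL {:5.2f} | ppl {:8.2f}'.format(
            epo, cur_loss, cur_kl, math.exp(cur_real_loss)))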