def sample_cell(self, mu, norm, kappa):
    """
    vMF sampler in pytorch, following
    http://stats.stackexchange.com/questions/156729/sampling-from-von-mises-fisher-distribution-in-python

    :param mu: z_dir (batch_sz, lat_dim). ALREADY normed.
    :param norm: z_norm (batch_sz, lat_dim).
    :param kappa: scalar. Controls dispersion; kappa of zero is no dispersion.
    :return: sampled vectors (1, batch_sz, lat_dim)
    """
    batch_sz, lat_dim = mu.size()
    # Unif VMF: add uniform noise to the norm
    norm_with_noise = self.add_norm_noise_batch(norm, self.norm_eps)
    w = self._sample_weight_batch(kappa, lat_dim, batch_sz)
    w = w.unsqueeze(1)
    w_var = GVar(w * torch.ones(batch_sz, lat_dim))
    v = self._sample_ortho_batch(mu, lat_dim)
    scale_factr = torch.sqrt(GVar(torch.ones(batch_sz, lat_dim)) - torch.pow(w_var, 2))
    orth_term = v * scale_factr
    muscale = mu * w_var
    sampled_vec = (orth_term + muscale) * norm_with_noise
    return sampled_vec.unsqueeze(0)
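# For reference, the sampler above follows the usual tangent-normal decomposition of a
# vMF draw: with scalar weight w (from _sample_weight_batch) and v a unit vector
# orthogonal to mu (from _sample_ortho_batch), the direction is
#     z_dir = w * mu + sqrt(1 - w^2) * v,
# and this unif-vMF variant then rescales it by the noisy norm produced by
# add_norm_noise_batch:  z = z_dir * (norm + u),  u ~ Uniform[0, norm_eps).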
def init_hidden(self, bsz):
    weight = next(self.parameters()).data
    if self.rnn_type == 'LSTM':
        return (GVar(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                GVar(weight.new(self.nlayers, bsz, self.nhid).zero_()))
    else:
        return GVar(weight.new(self.nlayers, bsz, self.nhid).zero_())
def evaluate(self, dev_batches):
    self.learner.eval()
    print("Test start")
    acc_loss = 0
    acc_accuracy = 0
    all_cnt = 0
    cnt = 0
    random.shuffle(dev_batches)
    for idx, batch in enumerate(dev_batches):
        bit, vec = batch
        bit = GVar(bit)
        vec = GVar(vec)
        loss, pred = self.learner(vec, bit)
        _, argmax = torch.max(pred, dim=1)
        # Evaluation only: no backward pass or optimizer step here.
        argmax = argmax.data
        bit = bit.data
        for jdx, num in enumerate(argmax):
            gt = bit[jdx]
            all_cnt += 1
            if gt == num:
                acc_accuracy += 1
        acc_loss += loss.item()
        cnt += 1
    print("Loss {} \tAccuracy {}".format(acc_loss / cnt, acc_accuracy / all_cnt))
    return float(acc_accuracy / all_cnt)
def analysis_evaluation(self):
    self.logger.info("Start Analyzing ...")
    start_time = time.time()
    test_batches = self.data.test
    self.logger.info("Total {} batches to analyze".format(len(test_batches)))
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    sample_bag = []
    try:
        for idx, batch in enumerate(test_batches):
            if idx % 10 == 0:
                print("Idx: {}".format(idx))
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1
                bit = batch[0, :]
                batch = batch[1:, :]
                bit = GVar(bit)
            else:
                bit = None
            feed = self.data.get_feed(batch)
            if self.args.swap > 0.00001:
                feed = swap_by_batch(feed, self.args.swap)
            if self.args.replace > 0.00001:
                feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)
            target = GVar(batch)
            # target: seq_len, batch_sz
            # decoded: seq_len, batch_sz, dict_sz
            # tup: 'mean' and 'logvar' for Gaussian; 'mu' for vMF
            recon_loss, kld, aux_loss, tup, vecs, decoded = self.model(feed, target, bit)
            bag = self.analyze_batch(target, kld, tup, vecs, decoded)
            sample_bag += bag
            acc_loss += recon_loss.data * seq_len * batch_sz
            acc_kl_loss += torch.sum(kld).data
            acc_aux_loss += torch.sum(aux_loss).data
            acc_avg_cos += tup['avg_cos'].data
            acc_avg_norm += tup['avg_norm'].data
            cnt += 1
            batch_cnt += batch_sz
            all_cnt += batch_sz * seq_len
    except KeyboardInterrupt:
        print("early stop")
    self.write_samples(sample_bag)
    cur_loss = acc_loss.item() / all_cnt
    cur_kl = acc_kl_loss.item() / all_cnt
    cur_real_loss = cur_loss + cur_kl
    return cur_loss, cur_kl, cur_real_loss
def evaluate(self, args, model, dev_batches):
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    model.FLAG_train = False
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    start_time = time.time()
    for idx, batch in enumerate(dev_batches):
        seq_len, batch_sz = batch.size()
        if self.data.condition:
            seq_len -= 1
            bit = batch[0, :]
            batch = batch[1:, :]
            bit = GVar(bit)
        else:
            bit = None
        feed = self.data.get_feed(batch)
        if self.args.swap > 0.00001:
            feed = swap_by_batch(feed, self.args.swap)
        if self.args.replace > 0.00001:
            feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)
        target = GVar(batch)
        recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)
        acc_loss += recon_loss.data * seq_len * batch_sz
        acc_kl_loss += torch.sum(kld).data
        acc_aux_loss += torch.sum(aux_loss).data
        acc_avg_cos += tup['avg_cos'].data
        acc_avg_norm += tup['avg_norm'].data
        cnt += 1
        batch_cnt += batch_sz
        all_cnt += batch_sz * seq_len
    cur_loss = acc_loss.item() / all_cnt
    cur_kl = acc_kl_loss.item() / all_cnt
    cur_aux_loss = acc_aux_loss.item() / all_cnt
    cur_avg_cos = acc_avg_cos.item() / cnt
    cur_avg_norm = acc_avg_norm.item() / cnt
    cur_real_loss = cur_loss + cur_kl
    return cur_loss, cur_kl, cur_real_loss
def __init__(self, hid_dim, lat_dim, kappa=1):
    super().__init__()
    self.hid_dim = hid_dim
    self.lat_dim = lat_dim
    self.kappa = kappa
    # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
    self.func_mu = torch.nn.Linear(hid_dim, lat_dim)
    self.kld = GVar(torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float())
    print('KLD: {}'.format(self.kld.data[0]))
def evaluate(self, args, model, corpus_dev, corpus_dev_cnt, dev_batches):
    """
    Standard evaluation function on dev or test set.
    :param args:
    :param model:
    :param dev_batches:
    :return:
    """
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    acc_loss = 0
    acc_kl_loss = 0
    acc_real_loss = 0
    word_cnt = 0
    doc_cnt = 0
    start_time = time.time()
    ntokens = self.data.vocab_size
    for idx, batch in enumerate(dev_batches):
        data_batch, count_batch = self.data.fetch_data(
            corpus_dev, corpus_dev_cnt, batch, ntokens)
        data_batch = GVar(torch.FloatTensor(data_batch))
        recon_loss, kld, aux_loss, tup, vecs = model(data_batch)
        count_batch = GVar(torch.FloatTensor(count_batch))
        doc_num = len(count_batch)
        acc_loss += torch.sum(recon_loss).item()
        acc_kl_loss += torch.sum(kld).item()
        count_batch = count_batch + 1e-12
        word_cnt += torch.sum(count_batch)
        doc_cnt += doc_num
    # word-level perplexity components
    cur_loss = acc_loss / word_cnt  # word loss
    cur_kl = acc_kl_loss / word_cnt
    cur_real_loss = cur_loss + cur_kl
    elapsed = time.time() - start_time
    return cur_loss, cur_kl, cur_real_loss
def dropword(self, emb, drop_rate=0.3):
    """
    Mix the ground truth word embeddings with the UNK embedding.
    If drop_rate = 1, no ground truth info is used (Fly mode).
    :param emb: word embeddings, (seq_len, batch_sz, emb_dim)
    :param drop_rate: 0 - no drop; 1 - full drop, all UNK
    :return: mixed embedding
    """
    UNKs = GVar(torch.ones(emb.size()[0], emb.size()[1]).long() * 2)  # 2 is the UNK id
    UNKs = self.emb(UNKs)
    masks = numpy.random.binomial(1, drop_rate, size=(emb.size()[0], emb.size()[1]))
    masks = GVar(torch.FloatTensor(masks)).unsqueeze(2).expand_as(UNKs)
    emb = emb * (1 - masks) + UNKs * masks
    return emb
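# A minimal sketch of the word-drop mask used above (nothing assumed beyond numpy):
# each (time, batch) position is independently replaced by UNK with probability
# drop_rate, and the 0/1 decision is broadcast across the embedding dimension.
import numpy
example_mask = numpy.random.binomial(1, 0.3, size=(4, 2))  # seq_len=4, batch_sz=2
print(example_mask)  # e.g. [[0 1], [0 0], [1 0], [0 1]] -- 1 means "use UNK here"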
def __init__(self, hid_dim, lat_dim, kappa=1):
    """
    von Mises-Fisher distribution class with batch support and a manually tuned
    kappa value. Implementation follows the description in our paper and Guu et al.'s.
    """
    super().__init__()
    self.hid_dim = hid_dim
    self.lat_dim = lat_dim
    self.kappa = kappa
    # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
    self.func_mu = torch.nn.Linear(hid_dim, lat_dim)
    self.kld = GVar(torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float())
    print('KLD: {}'.format(self.kld.data[0]))
def forward(self, x):
    batch_sz = x.size()[0]
    linear_x = self.enc_vec(x)
    linear_x = self.dropout(linear_x)
    active_x = self.active(linear_x)
    linear_x_2 = self.enc_vec_2(active_x)
    tup, kld, vecs = self.dist.build_bow_rep(linear_x_2, self.n_sample)
    # vecs: n_sample, batch_sz, lat_dim
    if 'redundant_norm' in tup:
        aux_loss = tup['redundant_norm'].view(batch_sz)
    else:
        aux_loss = GVar(torch.zeros(batch_sz))
    # stats
    avg_cos = BowVAE.check_dispersion(vecs)
    avg_norm = torch.mean(tup['norm'])
    tup['avg_cos'] = avg_cos
    tup['avg_norm'] = avg_norm
    flatten_vecs = vecs.view(self.n_sample * batch_sz, self.n_lat)
    flatten_vecs = self.dec_act(self.dec_linear(flatten_vecs))
    logit = self.dropout(self.out(flatten_vecs))
    logit = torch.nn.functional.log_softmax(logit, dim=1)
    logit = logit.view(self.n_sample, batch_sz, self.vocab_size)
    flatten_x = x.unsqueeze(0).expand(self.n_sample, batch_sz, self.vocab_size)
    error = torch.mul(flatten_x, logit)
    error = torch.mean(error, dim=0)
    recon_loss = -torch.sum(error, dim=1, keepdim=False)
    return recon_loss, kld, aux_loss, tup, vecs
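# For reference, the reconstruction term above is the negative multinomial log-likelihood
# of the bag of words, averaged over the n_sample latent draws:
#     recon_loss_i = - (1 / n_sample) * sum_s sum_w  x_{i,w} * log softmax(decoder(z_s))_{i,w}
# where x_{i,w} is the count of word w in document i.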
def sample_cell(self, mu, norm, kappa):
    batch_sz, lat_dim = mu.size()
    mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True)
    w = self._sample_weight_batch(kappa, lat_dim, batch_sz)
    w = w.unsqueeze(1)
    # batch version
    w_var = GVar(w * torch.ones(batch_sz, lat_dim).to(device))
    v = self._sample_ortho_batch(mu, lat_dim)
    scale_factr = torch.sqrt(
        GVar(torch.ones(batch_sz, lat_dim)) - torch.pow(w_var, 2))
    orth_term = v * scale_factr
    muscale = mu * w_var
    sampled_vec = orth_term + muscale
    return sampled_vec.unsqueeze(0).to(device)
def _sample_orthonormal_to(self, mu, dim):
    """Sample a unit vector on the sphere orthogonal to mu (mu is assumed unit-norm)."""
    v = GVar(torch.randn(dim))
    rescale_value = mu.dot(v) / mu.norm()
    proj_mu_v = mu * rescale_value.expand(dim)
    ortho = v - proj_mu_v
    ortho_norm = torch.norm(ortho)
    return ortho / ortho_norm.expand_as(ortho)
def add_norm_noise(self, munorm, eps):
    """
    The norm is cut at (norm_max - eps) elsewhere; here Uniform[0, eps) noise is added to it.
    The corresponding KL term is log(norm_max / eps).
    """
    trand = torch.rand(1).expand(munorm.size()) * eps
    return munorm + GVar(trand)
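# Why the KL term for the norm is log(norm_max / eps): after clipping, the noisy norm is
# uniform on an interval of width eps inside [0, norm_max], so
#     KL( U[a, a + eps] || U[0, norm_max] ) = integral (1/eps) * log( (1/eps) / (1/norm_max) ) dx
#                                           = log(norm_max / eps),
# which is presumably the quantity unif_vMF._uniform_kld(0., norm_eps, 0., norm_max)
# evaluates to.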
def forward(self, inp, target, bit=None):
    """
    Forward with ground truth (maybe mixed with UNK) as input.
    :param inp: seq_len, batch_sz
    :param target: seq_len, batch_sz
    :param bit: 1, batch_sz
    :return:
    """
    seq_len, batch_sz = inp.size()
    emb = self.drop(self.emb(inp))
    if self.input_cd_bow > 1:
        bow = self.enc_bow(emb)
    else:
        bow = None
    if self.input_cd_bit > 1:
        bit = self.enc_bit(bit)
    else:
        bit = None
    h = self.forward_enc(emb, bit)
    tup, kld, vecs = self.forward_build_lat(h, self.args.nsample)  # batch_sz, lat_dim
    if 'redundant_norm' in tup:
        aux_loss = tup['redundant_norm'].view(batch_sz)
    else:
        aux_loss = GVar(torch.zeros(batch_sz))
    if 'norm' not in tup:
        tup['norm'] = GVar(torch.zeros(batch_sz))
    # stats
    avg_cos = check_dispersion(vecs)
    tup['avg_cos'] = avg_cos
    avg_norm = torch.mean(tup['norm'])
    tup['avg_norm'] = avg_norm
    vec = torch.mean(vecs, dim=0)
    decoded = self.forward_decode_ground(emb, vec, bit, bow)  # (seq_len, batch, dict_sz)
    flatten_decoded = decoded.view(-1, self.ntoken)
    flatten_target = target.view(-1)
    loss = self.criterion(flatten_decoded, flatten_target)
    return loss, kld, aux_loss, tup, vecs, decoded
def play_eval(self, args, model, train_batches, epo, epo_start_time, glob_iter):
    # Reveal the relation between the latent space, sequence length, and loss,
    # and the distribution of the latent space.
    model.eval()
    model.FLAG_train = False
    start_time = time.time()
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    random.shuffle(train_batches)
    if self.args.dist == 'nor':
        vs = visual_gauss(self.data.dictionary)
    elif self.args.dist == 'vmf':
        vs = visual_vmf(self.data.dictionary)
    for idx, batch in enumerate(train_batches):
        seq_len, batch_sz = batch.size()
        feed = self.data.get_feed(batch)
        glob_iter += 1
        target = GVar(batch)
        recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target)
        acc_loss += recon_loss.data * seq_len * batch_sz
        acc_kl_loss += torch.sum(kld).data
        acc_aux_loss += torch.sum(aux_loss).data
        acc_avg_cos += tup['avg_cos'].data
        acc_avg_norm += tup['avg_norm'].data
        cnt += 1
        batch_cnt += batch_sz
        all_cnt += batch_sz * seq_len
        vs.add_batch(target.data, tup, kld.data)
    cur_loss = acc_loss.item() / all_cnt
    cur_kl = acc_kl_loss.item() / all_cnt
    cur_aux_loss = acc_aux_loss.item() / all_cnt
    cur_avg_cos = acc_avg_cos.item() / cnt
    cur_avg_norm = acc_avg_norm.item() / cnt
    cur_real_loss = cur_loss + cur_kl
    Runner.log_instant(None, self.args, glob_iter, epo, start_time,
                       cur_avg_cos, cur_avg_norm, cur_loss,
                       cur_kl, cur_aux_loss, cur_real_loss)
    vs.write_log()
def play_eval(self, args, model, train_batches, epo, epo_start_time, glob_iter):
    # Reveal the relation between the latent space, document length, and loss,
    # and the distribution of the latent space.
    model.eval()
    start_time = time.time()
    acc_loss = 0
    acc_kl_loss = 0
    acc_real_loss = 0
    word_cnt = 0
    doc_cnt = 0
    random.shuffle(train_batches)
    if self.args.dist == 'nor':
        vs = visual_gauss()
    elif self.args.dist == 'vmf':
        vs = visual_vmf()
    for idx, batch in enumerate(train_batches):
        data_batch, count_batch = DataNg.fetch_data(
            self.data.test[0], self.data.test[1], batch)
        data_batch = GVar(torch.FloatTensor(data_batch))
        recon_loss, kld, total_loss, tup, vecs = model(data_batch)
        vs.add_batch(data_batch, tup, kld.data, vecs)
        count_batch = torch.FloatTensor(count_batch).cuda()
        real_loss = torch.div((recon_loss + kld).data, count_batch)
        doc_num = len(count_batch)
        # remove nan
        for n in real_loss:
            if n == n:
                acc_real_loss += n
        acc_loss += torch.sum(recon_loss).item()
        acc_kl_loss += torch.sum(kld).item()
        count_batch = count_batch + 1e-12
        word_cnt += torch.sum(count_batch)
        doc_cnt += doc_num
    cur_loss = acc_loss / word_cnt  # word loss
    cur_kl = acc_kl_loss / word_cnt
    cur_real_loss = cur_loss + cur_kl
    Runner.log_instant(None, self.args, glob_iter, epo, start_time,
                       cur_loss, cur_kl, cur_real_loss)
    vs.write_log()
def evaluate(self, dev_batches):
    self.learner.eval()
    print("Test start")
    acc_loss = 0
    cnt = 0
    random.shuffle(dev_batches)
    for idx, batch in enumerate(dev_batches):
        seq_len, batch_sz = batch.size()
        if self.data.condition:
            seq_len -= 1
            if self.model.input_cd_bit > 1:
                bit = batch[0, :]
                bit = GVar(bit)
            else:
                bit = None
            batch = batch[1:, :]
        else:
            bit = None
        feed = self.data.get_feed(batch)
        seq_len, batch_sz = feed.size()
        emb = self.model.drop(self.model.emb(feed))
        if self.model.input_cd_bit > 1:
            bit = self.model.enc_bit(bit)
        else:
            bit = None
        h = self.model.forward_enc(emb, bit)
        tup, kld, vecs = self.model.forward_build_lat(h)  # batch_sz, lat_dim
        if self.model.dist_type == 'vmf':
            code = tup['mu']
        elif self.model.dist_type == 'nor':
            code = tup['mean']
        else:
            raise NotImplementedError
        emb = torch.mean(emb, dim=0)
        if self.c2b:
            loss = self.learner(code, emb)
        else:
            loss = self.learner(code, emb)
        acc_loss += loss.item()
        cnt += 1
    print(acc_loss / cnt)
    return float(acc_loss / cnt)
def train_epo(self, train_batches):
    self.learner.train()
    print("Epo start")
    acc_loss = 0
    cnt = 0
    random.shuffle(train_batches)
    for idx, batch in enumerate(train_batches):
        self.optim.zero_grad()
        seq_len, batch_sz = batch.size()
        if self.data.condition:
            seq_len -= 1
            if self.model.input_cd_bit > 1:
                bit = batch[0, :]
                bit = GVar(bit)
            else:
                bit = None
            batch = batch[1:, :]
        else:
            bit = None
        feed = self.data.get_feed(batch)
        seq_len, batch_sz = feed.size()
        emb = self.model.drop(self.model.emb(feed))
        if self.model.input_cd_bit > 1:
            bit = self.model.enc_bit(bit)
        else:
            bit = None
        h = self.model.forward_enc(emb, bit)
        tup, kld, vecs = self.model.forward_build_lat(h)  # batch_sz, lat_dim
        if self.model.dist_type == 'vmf':
            code = tup['mu']
        elif self.model.dist_type == 'nor':
            code = tup['mean']
        else:
            raise NotImplementedError
        emb = torch.mean(emb, dim=0)
        if self.c2b:
            loss = self.learner(code, emb)
        else:
            loss = self.learner(code, emb)
        loss.backward()
        self.optim.step()
        acc_loss += loss.item()
        cnt += 1
        if idx % 400 == 0 and (idx > 0):
            print("Training {}".format(acc_loss / cnt))
            acc_loss = 0
            cnt = 0
def get_feed(data_patch):
    """
    Given a data patch, build the corresponding decoder input.
    Given: [A, B, C, D]
    Return: [SOS, A, B, C]
    :param data_patch: (seq_len, batch_sz)
    :return: input sequence shifted right by one, starting with SOS
    """
    bsz = data_patch.size()[1]
    sos = torch.LongTensor(1, bsz).fill_(1)  # 1 is the SOS id
    input_data = GVar(torch.cat((sos, data_patch[:-1])))
    return input_data
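# A minimal usage sketch for get_feed (the ids here are hypothetical; it only assumes
# that index 1 is the SOS token, as in the function above):
#   data = torch.LongTensor([[5, 7],
#                            [6, 8],
#                            [9, 3]])          # seq_len=3, batch_sz=2
#   get_feed(data)  ->  [[1, 1],
#                        [5, 7],
#                        [6, 8]]               # shifted right by one, prefixed with SOS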
def __init__(self, hid_dim, lat_dim, kappa=1, norm_max=2, norm_func=True):
    super().__init__()
    self.hid_dim = hid_dim
    self.lat_dim = lat_dim
    self.kappa = kappa
    # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
    self.func_mu = torch.nn.Linear(hid_dim, lat_dim)
    self.func_norm = torch.nn.Linear(hid_dim, 1)
    # self.noise_scaler = kappa
    self.norm_eps = 1
    self.norm_max = norm_max
    self.norm_clip = torch.nn.Hardtanh(0.00001, self.norm_max - self.norm_eps)
    self.norm_func = norm_func
    # KLD accounts for both the vMF and the uniform parts
    kld_value = unif_vMF._vmf_kld(kappa, lat_dim) \
                + unif_vMF._uniform_kld(0., self.norm_eps, 0., self.norm_max)
    self.kld = GVar(torch.from_numpy(np.array([kld_value])).float())
    print('KLD: {}'.format(self.kld.data[0]))
def _sample_ortho_batch(self, mu, dim):
    """
    Sample unit vectors orthogonal to each row of mu (mu is assumed unit-norm).
    :param mu: Variable, [batch_sz, lat_dim]
    :param dim: scalar, = lat_dim
    :return: y, [batch_sz, lat_dim]
    """
    _batch_sz, _lat_dim = mu.size()
    assert _lat_dim == dim
    squeezed_mu = mu.unsqueeze(1)
    v = GVar(torch.randn(_batch_sz, dim, 1))  # TODO random
    # v = GVar(torch.linspace(-1, 1, steps=dim))
    # v = v.expand(_batch_sz, dim).unsqueeze(2)
    rescale_val = torch.bmm(squeezed_mu, v).squeeze(2)
    proj_mu_v = mu * rescale_val
    ortho = v.squeeze() - proj_mu_v
    ortho_norm = torch.norm(ortho, p=2, dim=1, keepdim=True)
    y = ortho / ortho_norm
    return y
def sample_cell(self, mu, norm, kappa):
    batch_sz, lat_dim = mu.size()
    sampled_vecs = GVar(torch.FloatTensor(batch_sz, lat_dim))
    for b in range(batch_sz):
        this_mu = mu[b]
        # kappa = np.linalg.norm(this_theta)
        this_mu = this_mu / torch.norm(this_mu, p=2)
        w = self._sample_weight(kappa, lat_dim)
        w_var = GVar(w * torch.ones(lat_dim))
        v = self._sample_orthonormal_to(this_mu, lat_dim)
        scale_factr = torch.sqrt(GVar(torch.ones(lat_dim)) - torch.pow(w_var, 2))
        orth_term = v * scale_factr
        muscale = this_mu * w_var
        sampled_vecs[b] = orth_term + muscale
    return sampled_vecs.unsqueeze(0)
def analysis_evaluation_order_and_importance(self):
    """
    Measure the change of cosine similarity given different encoding sequences.
    :return:
    """
    self.logger.info("Start Analyzing ... Picking up 100 batches to analyze")
    start_time = time.time()
    test_batches = self.data.test
    random.shuffle(test_batches)
    test_batches = test_batches[:100]
    self.logger.info("Total {} batches to analyze".format(len(test_batches)))
    acc_loss = 0
    acc_kl_loss = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    sample_bag = []
    try:
        for idx, batch in enumerate(test_batches):
            if idx % 10 == 0:
                print("Now Idx: {}".format(idx))
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1
                bit = batch[0, :]
                batch = batch[1:, :]
                bit = GVar(bit)
            else:
                bit = None
            feed = self.data.get_feed(batch)
            if self.args.swap > 0.0001:
                bag = self.analysis_eval_order(feed, batch, bit)
            elif self.args.replace > 0.0001:
                bag = self.analysis_eval_word_importance(feed, batch, bit)
            else:
                print("Maybe wrong mode?")
                raise NotImplementedError
            sample_bag.append(bag)
    except KeyboardInterrupt:
        print("early stop")
    if self.args.swap > 0.0001:
        return self.unpack_bag_order(sample_bag)
    elif self.args.replace > 0.0001:
        return self.unpack_bag_word_importance(sample_bag)
    else:
        raise NotImplementedError
def forward_build_lat(self, hidden, nsample=3):
    """
    :param hidden: batch_sz, nhid
    :return: tup, kld [batch_sz], out [nsample, batch_sz, lat_dim]
    """
    if self.args.dist in ('nor', 'vmf', 'unifvmf', 'vmf_diff', 'sph'):
        # 2 for bidirect, 2 for h and c
        tup, kld, out = self.dist.build_bow_rep(hidden, nsample)
    elif self.args.dist == 'zero':
        out = GVar(torch.zeros(1, hidden.size()[0], self.lat_dim))
        tup = {}
        kld = GVar(torch.zeros(1))
    else:
        raise NotImplementedError
    return tup, kld, out
def train_epo(self, train_batches):
    self.learner.train()
    print("Epo start")
    acc_loss = 0
    acc_accuracy = 0
    all_cnt = 0
    cnt = 0
    random.shuffle(train_batches)
    for idx, batch in enumerate(train_batches):
        self.optim.zero_grad()
        bit, vec = batch
        bit = GVar(bit)
        vec = GVar(vec)
        loss, pred = self.learner(vec, bit)
        _, argmax = torch.max(pred, dim=1)
        loss.backward()
        self.optim.step()
        argmax = argmax.data
        bit = bit.data
        for jdx, num in enumerate(argmax):
            gt = bit[jdx]
            all_cnt += 1
            if gt == num:
                acc_accuracy += 1
        acc_loss += loss.item()
        cnt += 1
        if idx % 400 == 0:
            print("Loss {} \tAccuracy {}".format(acc_loss / cnt, acc_accuracy / all_cnt))
            acc_loss = 0
            cnt = 0
def analysis_eval_order(self, feed, batch, bit):
    assert 0.33 > self.args.swap > 0.0001
    origin_feed = feed.clone()
    feed_1x = swap_by_batch(feed.clone(), self.args.swap)
    feed_2x = swap_by_batch(feed.clone(), self.args.swap * 2)
    feed_3x = swap_by_batch(feed.clone(), self.args.swap * 3)
    feed_4x = swap_by_batch(feed.clone(), self.args.swap * 4)
    feed_5x = swap_by_batch(feed.clone(), self.args.swap * 5)
    feed_6x = swap_by_batch(feed.clone(), self.args.swap * 6)
    target = GVar(batch)
    original_recon_loss, kld, _, original_tup, original_vecs, _ = self.model(origin_feed, target, bit)
    if 'Distnor' in self.instance_name:
        key_name = "mean"
    elif 'vmf' in self.instance_name:
        key_name = "mu"
    else:
        raise NotImplementedError
    original_mu = original_tup[key_name]
    recon_loss_1x, _, _, tup_1x, vecs_1x, _ = self.model(feed_1x, target, bit)
    recon_loss_2x, _, _, tup_2x, vecs_2x, _ = self.model(feed_2x, target, bit)
    recon_loss_3x, _, _, tup_3x, vecs_3x, _ = self.model(feed_3x, target, bit)
    recon_loss_4x, _, _, tup_4x, vecs_4x, _ = self.model(feed_4x, target, bit)
    recon_loss_5x, _, _, tup_5x, vecs_5x, _ = self.model(feed_5x, target, bit)
    recon_loss_6x, _, _, tup_6x, vecs_6x, _ = self.model(feed_6x, target, bit)
    # Compare the latent "direction" (mean for Gaussian, mu for vMF) of the original
    # sequence with that of each progressively more shuffled version.
    cos_1x = torch.mean(cos(original_mu, tup_1x[key_name])).data
    cos_2x = torch.mean(cos(original_mu, tup_2x[key_name])).data
    cos_3x = torch.mean(cos(original_mu, tup_3x[key_name])).data
    cos_4x = torch.mean(cos(original_mu, tup_4x[key_name])).data
    cos_5x = torch.mean(cos(original_mu, tup_5x[key_name])).data
    cos_6x = torch.mean(cos(original_mu, tup_6x[key_name])).data
    return [[original_recon_loss.data, recon_loss_1x.data, recon_loss_2x.data,
             recon_loss_3x.data, recon_loss_4x.data, recon_loss_5x.data,
             recon_loss_6x.data],
            [cos_1x, cos_2x, cos_3x, cos_4x, cos_5x, cos_6x]]
def analysis_eval_word_importance(self, feed, batch, bit):
    """
    Given a sentence, replace each word in turn by UNK and see how far the latent code
    moves from the original one.
    :param feed:
    :param batch:
    :param bit:
    :return:
    """
    seq_len, batch_sz = batch.size()
    target = GVar(batch)
    origin_feed = feed.clone()
    original_recon_loss, kld, _, original_tup, original_vecs, _ = self.model(origin_feed, target, bit)
    original_mu = original_tup['mu']
    table_of_mu = torch.FloatTensor(seq_len, batch_sz)
    for t in range(seq_len):
        cur_feed = feed.clone()
        cur_feed[t, :] = 2  # replace time step t with UNK
        cur_recon, _, _, cur_tup, cur_vec, _ = self.model(cur_feed, target, bit)
        cur_mu = cur_tup['mu']
        y = cos(original_mu, cur_mu)
        y = y.squeeze()
        table_of_mu[t, :] = y.data
    bag = []
    for b in range(batch_sz):
        weight = table_of_mu[:, b]
        word_ids = feed[:, b]
        words = self.ids_to_words(word_ids.data.tolist())
        seq_of_words = words.split(" ")
        s = ""
        for t in range(seq_len):
            if weight[t] < 0.98:
                # mark words whose removal noticeably moves the latent code
                s += "*" + seq_of_words[t] + "* "
            else:
                s += seq_of_words[t] + " "
        bag.append(s)
    return bag
class vMF(torch.nn.Module):
    def __init__(self, hid_dim, lat_dim, kappa=1):
        super().__init__()
        self.hid_dim = hid_dim
        self.lat_dim = lat_dim
        self.kappa = kappa
        # self.func_kappa = torch.nn.Linear(hid_dim, lat_dim)
        self.func_mu = torch.nn.Linear(hid_dim, lat_dim)
        self.kld = GVar(torch.from_numpy(vMF._vmf_kld(kappa, lat_dim)).float())
        print('KLD: {}'.format(self.kld.data[0]))

    def estimate_param(self, latent_code):
        ret_dict = {}
        ret_dict['kappa'] = self.kappa
        # Only compute mu; use mu/||mu|| as the direction, 1 as the norm,
        # and (||mu|| - 1)^2 as the redundant_norm penalty.
        mu = self.func_mu(latent_code)
        norm = torch.norm(mu, 2, 1, keepdim=True)
        mu_norm_sq_diff_from_one = torch.pow(torch.add(norm, -1), 2)
        redundant_norm = torch.sum(mu_norm_sq_diff_from_one, dim=1, keepdim=True)
        ret_dict['norm'] = torch.ones_like(mu)
        ret_dict['redundant_norm'] = redundant_norm
        mu = mu / torch.norm(mu, p=2, dim=1, keepdim=True)
        ret_dict['mu'] = mu
        return ret_dict

    def compute_KLD(self, tup, batch_sz):
        return self.kld.expand(batch_sz)

    @staticmethod
    def _vmf_kld(k, d):
        tmp = (k * ((sp.iv(d / 2.0 + 1.0, k) + sp.iv(d / 2.0, k) * d / (2.0 * k)) / sp.iv(d / 2.0, k)
                    - d / (2.0 * k))
               + d * np.log(k) / 2.0
               - np.log(sp.iv(d / 2.0, k))
               - sp.loggamma(d / 2 + 1)
               - d * np.log(2) / 2).real
        if tmp != tmp:  # NaN check
            exit()
        return np.array([tmp])

    def build_bow_rep(self, lat_code, n_sample):
        batch_sz = lat_code.size()[0]
        tup = self.estimate_param(latent_code=lat_code)
        mu = tup['mu']
        norm = tup['norm']
        kappa = tup['kappa']
        kld = self.compute_KLD(tup, batch_sz)
        if n_sample == 1:
            return tup, kld, self.sample_cell(mu, norm, kappa)
        vecs = []
        for n in range(n_sample):
            sample = self.sample_cell(mu, norm, kappa)
            vecs.append(sample)
        vecs = torch.cat(vecs, dim=0)
        return tup, kld, vecs

    def sample_cell(self, mu, norm, kappa):
        batch_sz, lat_dim = mu.size()
        sampled_vecs = GVar(torch.FloatTensor(batch_sz, lat_dim))
        for b in range(batch_sz):
            this_mu = mu[b]
            # kappa = np.linalg.norm(this_theta)
            this_mu = this_mu / torch.norm(this_mu, p=2)
            w = self._sample_weight(kappa, lat_dim)
            w_var = GVar(w * torch.ones(lat_dim))
            v = self._sample_orthonormal_to(this_mu, lat_dim)
            scale_factr = torch.sqrt(GVar(torch.ones(lat_dim)) - torch.pow(w_var, 2))
            orth_term = v * scale_factr
            muscale = this_mu * w_var
            sampled_vecs[b] = orth_term + muscale
        return sampled_vecs.unsqueeze(0)

    def _sample_weight(self, kappa, dim):
        """Rejection sampling scheme for sampling the distance from the center
        on the surface of the sphere.
        """
        dim = dim - 1  # since S^{n-1}
        b = dim / (np.sqrt(4. * kappa ** 2 + dim ** 2) + 2 * kappa)
        x = (1. - b) / (1. + b)
        c = kappa * x + dim * np.log(1 - x ** 2)
        while True:
            z = np.random.beta(dim / 2., dim / 2.)  # concentrates towards 0.5 as d -> inf
            w = (1. - (1. + b) * z) / (1. - (1. - b) * z)
            u = np.random.uniform(low=0, high=1)
            # threshold is dim * (kdiv * (w - x) + log(1 - x * w) - log(1 - x ** 2))
            if kappa * w + dim * np.log(1. - x * w) - c >= np.log(u):
                return w

    def _sample_orthonormal_to(self, mu, dim):
        """Sample a point on the sphere orthogonal to mu (mu is assumed unit-norm)."""
        v = GVar(torch.randn(dim))
        rescale_value = mu.dot(v) / mu.norm()
        proj_mu_v = mu * rescale_value.expand(dim)
        ortho = v - proj_mu_v
        ortho_norm = torch.norm(ortho)
        return ortho / ortho_norm.expand_as(ortho)
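# For reference, the constant computed by vMF._vmf_kld above is the closed-form KL
# divergence between vMF(mu, kappa) and the uniform distribution on the sphere, written
# with modified Bessel functions I_v = sp.iv:
#     KL = kappa * I_{d/2+1}(kappa) / I_{d/2}(kappa)
#          + (d/2) * log(kappa) - log(I_{d/2}(kappa))
#          - log(Gamma(d/2 + 1)) - (d/2) * log(2)
# It depends only on kappa and the latent dimension d, not on mu, which is why a single
# precomputed value can simply be expanded over the batch in compute_KLD.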
def train_epo(self, args, model, train_batches, epo, epo_start_time):
    model.train()
    start_time = time.time()
    if self.args.optim == 'sgd':
        self.optim = torch.optim.SGD(model.parameters(), lr=self.args.cur_lr)
    else:
        raise NotImplementedError
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    word_cnt = 0
    doc_cnt = 0
    cnt = 0
    random.shuffle(train_batches)
    for idx, batch in enumerate(train_batches):
        self.optim.zero_grad()
        self.glob_iter += 1
        data_batch, count_batch = DataNg.fetch_data(
            self.data.train[0], self.data.train[1], batch, self.data.vocab_size)
        model.zero_grad()
        data_batch = GVar(torch.FloatTensor(data_batch))
        recon_loss, kld, aux_loss, tup, vecs = model(data_batch)
        total_loss = torch.mean(recon_loss + kld * args.kl_weight + aux_loss * args.aux_weight)
        total_loss.backward()
        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        self.optim.step()
        count_batch = GVar(torch.FloatTensor(count_batch))
        doc_num = len(count_batch)
        acc_loss += torch.sum(recon_loss).item()
        acc_kl_loss += torch.sum(kld).item()
        acc_aux_loss += torch.sum(aux_loss).item()
        acc_avg_cos += tup['avg_cos'].item()
        acc_avg_norm += tup['avg_norm'].item()
        cnt += 1
        count_batch = count_batch + 1e-12
        word_cnt += torch.sum(count_batch).item()
        doc_cnt += doc_num
        if idx % args.log_interval == 0 and idx > 0:
            cur_loss = acc_loss / word_cnt  # word loss
            cur_kl = acc_kl_loss / word_cnt
            cur_aux_loss = acc_aux_loss / word_cnt
            cur_avg_cos = acc_avg_cos / cnt
            cur_avg_norm = acc_avg_norm / cnt
            cur_real_loss = cur_loss + cur_kl
            Runner.log_instant(self.writer, self.args, self.glob_iter, epo,
                               start_time, cur_avg_cos, cur_avg_norm, cur_loss,
                               cur_kl, cur_aux_loss, cur_real_loss)
            acc_loss = 0
            acc_kl_loss = 0
            acc_aux_loss = 0
            acc_avg_cos = 0
            acc_avg_norm = 0
            word_cnt = 0
            doc_cnt = 0
            cnt = 0
        if idx % (3 * args.log_interval) == 0 and idx > 0:
            with torch.no_grad():
                self.eval_interface()
def train_epo(self, args, model, train_batches, epo, epo_start_time, glob_iter):
    model.train()
    model.FLAG_train = True
    start_time = time.time()
    if self.args.optim == 'sgd':
        self.optim = torch.optim.SGD(model.parameters(), lr=self.args.cur_lr)
    acc_loss = 0
    acc_kl_loss = 0
    acc_aux_loss = 0
    acc_avg_cos = 0
    acc_avg_norm = 0
    batch_cnt = 0
    all_cnt = 0
    cnt = 0
    random.shuffle(train_batches)
    for idx, batch in enumerate(train_batches):
        self.optim.zero_grad()
        seq_len, batch_sz = batch.size()
        if self.data.condition:
            seq_len -= 1
            if self.model.input_cd_bit > 1:
                bit = batch[0, :]
                bit = GVar(bit)
            else:
                bit = None
            batch = batch[1:, :]
        else:
            bit = None
        feed = self.data.get_feed(batch)
        if self.args.swap > 0.00001:
            feed = swap_by_batch(feed, self.args.swap)
        if self.args.replace > 0.00001:
            feed = replace_by_batch(feed, self.args.replace, self.model.ntoken)
        self.glob_iter += 1
        target = GVar(batch)
        recon_loss, kld, aux_loss, tup, vecs, _ = model(feed, target, bit)
        total_loss = recon_loss * seq_len \
                     + torch.mean(kld) * self.args.kl_weight \
                     + torch.mean(aux_loss) * args.aux_weight
        total_loss.backward()
        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
        # (clip_grad_norm was renamed to clip_grad_norm_ in PyTorch 0.4.1.)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip, norm_type=2)
        self.optim.step()
        acc_loss += recon_loss.data * seq_len * batch_sz
        acc_kl_loss += torch.sum(kld).data
        acc_aux_loss += torch.sum(aux_loss).data
        acc_avg_cos += tup['avg_cos'].data
        acc_avg_norm += tup['avg_norm'].data
        cnt += 1
        batch_cnt += batch_sz
        all_cnt += batch_sz * seq_len
        if idx % args.log_interval == 0 and idx > 0:
            cur_loss = acc_loss.item() / all_cnt
            cur_kl = acc_kl_loss.item() / all_cnt
            cur_aux_loss = acc_aux_loss.item() / all_cnt
            cur_avg_cos = acc_avg_cos.item() / cnt
            cur_avg_norm = acc_avg_norm.item() / cnt
            cur_real_loss = cur_loss + cur_kl
            Runner.log_instant(self.writer, self.args, self.glob_iter, epo,
                               start_time, cur_avg_cos, cur_avg_norm, cur_loss,
                               cur_kl, cur_aux_loss, cur_real_loss)