def _loop(self, iter, optimizer=None, clip=0, learn=False, re=None):
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, lens = batch.text
            x = text[:-1]
            y = text[1:]
            lens = lens - 1

            e, lene = batch.entities
            t, lent = batch.types
            v, lenv = batch.values
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            T, N = y.shape
            #if states is None:
            states = self.init_state(N)
            logits, _ = self(x, states, lens, r, lenr)
            nll = self.loss(logits, y)
            kl = 0
            nelbo = nll + kl
            if learn:
                nelbo.div(nwords.item()).backward()
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            cum_loss += nelbo.item()
            cum_ntokens += nwords.item()
            batch_loss += nelbo.item()
            batch_ntokens += nwords.item()
            if re is not None and i % re == -1 % re:
                titer.set_postfix(loss=batch_loss / batch_ntokens, gnorm=gnorm)
                batch_loss = 0
                batch_ntokens = 0
    return cum_loss, cum_ntokens
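# --- Illustrative usage sketch (not part of the original module) ---
# A `_loop`-style method like the one above is typically driven once per epoch
# with learn=True for training and learn=False for evaluation, and the summed
# token-level loss is then converted to perplexity. The `model`, `train_iter`,
# and `valid_iter` names below are assumptions for illustration only.
import math

def run_epochs(model, train_iter, valid_iter, optimizer, epochs=10, clip=5, re=100):
    best_ppl = float("inf")
    for epoch in range(epochs):
        train_loss, train_tokens = model._loop(
            train_iter, optimizer=optimizer, clip=clip, learn=True, re=re)
        valid_loss, valid_tokens = model._loop(valid_iter, learn=False)
        valid_ppl = math.exp(valid_loss / valid_tokens)
        print(f"epoch {epoch}: train nll/tok {train_loss / train_tokens:.3f}"
              f" | valid ppl {valid_ppl:.3f}")
        best_ppl = min(best_ppl, valid_ppl)
    return best_ppl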
def _loop(self, diter, optimizer=None, clip=0, learn=False, re=None, once=False):
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    gnorm = None
    with context():
        titer = tqdm(diter) if learn else diter
        for i, batch in enumerate(titer if not once else [next(iter(diter))]):
            if learn:
                optimizer.zero_grad()
            x, lens = batch.text
            lx, _ = batch.locations_text
            ax, _ = batch.aspects_text
            l = batch.locations
            a = batch.aspects
            y = batch.sentiments

            # keys
            k = [l, a]
            kx = [lx, ax]

            # N x y for dealing w imbalance
            logits = self(x, lens, k, kx)
            nll = self.loss(logits, y)
            nelbo = nll
            N = y.shape[0]
            if learn:
                nelbo.div(N).backward()
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                optimizer.step()
            cum_loss += nelbo.item()
            cum_ntokens += N
            batch_loss += nelbo.item()
            batch_ntokens += N
            if re is not None and i % re == -1 % re:
                titer.set_postfix(loss=batch_loss / batch_ntokens, gnorm=gnorm)
                batch_loss = 0
                batch_ntokens = 0
    return cum_loss, cum_ntokens
def _loop(self, iter, learn, args):
    context = torch.enable_grad if learn else torch.no_grad

    loss = 0
    ntokens = 0
    rloss = 0
    rntokens = 0
    hidden_states = None
    gnorm = None
    with context():
        t = tqdm(iter)
        for i, batch in enumerate(t):
            if learn:
                args.optimizer.zero_grad()
            x = batch.text
            y = batch.target
            nwords = y.ne(1).sum()
            T, N = y.shape
            if hidden_states is None:
                hidden_states = self.init_hidden(N)
            logits, hidden_states = self(x, hidden_states)
            if learn:
                # detach the carried state so gradients do not flow across batches
                hidden_states = (
                    [
                        tuple(x.detach() for x in tup)
                        for tup in hidden_states[0]
                    ],
                    tuple(x.detach() for x in hidden_states[1]),
                    hidden_states[2].detach(),
                )
            logprobs = F.log_softmax(logits, dim=-1)
            logp = logprobs.view(T * N, -1).gather(-1, y.view(T * N, 1))
            kl = 0
            nll = -logp[y.view(-1, 1) != 1].sum()
            nelbo = nll + kl
            if learn:
                nelbo.div(nwords.item()).backward()
                if args.clip > 0:
                    gnorm = clip_(self.parameters(), args.clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, args.clip)
                args.optimizer.step()
            loss += nelbo.item()
            ntokens += nwords.item()
            rloss += nelbo.item()
            rntokens += nwords.item()
            if args is not None and i % args.report_interval == -1 % args.report_interval:
                t.set_postfix(loss=rloss / rntokens, gnorm=gnorm)
                rloss = 0
                rntokens = 0
    return loss, ntokens
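# --- Illustrative helper sketch (not part of the original module) ---
# The loop above detaches the carried hidden state by hand, mirroring its exact
# nesting. A generic alternative, assuming the state is any nesting of tensors
# in tuples/lists, is to detach recursively so truncated BPTT never
# backpropagates across batch boundaries regardless of the state's shape.
import torch

def detach_states(state):
    """Recursively detach every tensor in a nested tuple/list of states."""
    if torch.is_tensor(state):
        return state.detach()
    if isinstance(state, (tuple, list)):
        return type(state)(detach_states(s) for s in state)
    return state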
def _loop_ie(self, iter, optimizer=None, clip=0, learn=False, re=None):
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None
    with context():
        t = tqdm(iter) if learn else iter
        for i, batch in enumerate(t):
            if learn:
                optimizer.zero_grad()
            text, lens = batch.text
            x = text[:-1]
            y = text[1:]
            lens = lens - 1

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            T, N = y.shape
            #if states is None:
            states = self.init_state(N)
            logits, _ = self(x, states, lens)
            #logits, states = self(x, states, lens)
            logprobs = F.log_softmax(logits, dim=-1)
            logp = logprobs.view(T * N, -1).gather(-1, y.view(T * N, 1))
            kl = 0
            nll = -logp[y.view(-1, 1) != 1].sum()
            nelbo = nll + kl
            if learn:
                nelbo.div(nwords.item()).backward()
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            cum_loss += nelbo.item()
            cum_ntokens += nwords.item()
            batch_loss += nelbo.item()
            batch_ntokens += nwords.item()
            if re is not None and i % re == -1 % re:
                t.set_postfix(loss=batch_loss / batch_ntokens, gnorm=gnorm)
                batch_loss = 0
                batch_ntokens = 0
    return cum_loss, cum_ntokens
def _loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    exact=False,
    elbo=True,
    T=64,
    E=128,
    supattn=False,
    supcopy=False,
):
    context = torch.enable_grad if learn else torch.no_grad

    # DBG
    self.copied = 0
    couldve_copied = 0

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    batch_rx = 0
    batch_kl = 0
    states = None
    gnorm = None
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            L = text.shape["time"]
            x = text.narrow("time", 0, L - 1)
            y = text.narrow("time", 1, L - 1)
            #x = text[:-1]
            #y = text[1:]
            x_info.lengths.sub_(1)

            e, e_info = batch.entities
            t, t_info = batch.types
            v, v_info = batch.values
            lene = e_info.lengths
            lent = t_info.lengths
            lenv = v_info.lengths
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene
            r_info = e_info

            # values text
            vt, vt_info = batch.values_text
            # DBG
            couldve_copied += (vt == y).sum().item()

            ue, ue_info = batch.uentities
            ut, ut_info = batch.utypes
            v2d = batch.v2d

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            #T = y.shape["time"]
            N = y.shape["batch"]
            #if states is None:
            states = self.init_state(N)

            # should i include <eos> in ppl? no, should not.
            mask = y.ne(1) #* y.ne(3)
            nwords = mask.sum()

            # ugh.......refactor this
            if exact:
                rvinfo, states = self.marginal_nll(
                    x, states, x_info, r, r_info, vt, y, x_info,
                    T=T, E=E, learn=learn)
                nll = -rvinfo.log_py[mask].sum()
                cum_loss += nll.item()
                batch_loss += nll.item()
            else:
                rvinfo, _, nll, kl = self(
                    x, states, x_info, r, r_info, vt, y, x_info,
                    T=T, E=E, learn=learn)
                nelbo = nll + kl
                if learn:
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                    optimizer.step()
                cum_rx += nll.item()
                batch_rx += nll.item()
                cum_kl += kl.item()
                batch_kl += kl.item()
                cum_loss += nelbo.item()
                batch_loss += nelbo.item()
            cum_ntokens += nwords.item()
            batch_ntokens += nwords.item()
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    elbo=batch_loss / batch_ntokens,
                    nll=batch_rx / batch_ntokens,
                    kl=batch_kl / batch_ntokens,
                    gnorm=gnorm,
                )
                batch_loss = 0
                batch_ntokens = 0
                batch_rx = 0
                batch_kl = 0
    print(f"COPIED: {self.copied:,} vs {couldve_copied:,} / {cum_ntokens:,}")
    print(f"NLL: {cum_rx / cum_ntokens} || KL: {cum_kl / cum_ntokens}")
    return cum_loss, cum_ntokens
def _loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    exact=False,
    elbo=True,
    T=64,
    E=32,
    R=4,
    supattn=False,
    supcopy=False,
):
    context = torch.enable_grad if learn else torch.no_grad

    # DBG
    self.copied = 0
    couldve_copied = 0
    nthe = 0
    n4 = 0
    nfour = 0
    n6 = 0
    nsix = 0
    n8 = 0
    neight = 0
    nchi = 0
    nleb = 0
    bthe = 0
    b4 = 0
    bfour = 0
    b6 = 0
    bsix = 0
    b8 = 0
    beight = 0
    bchi = 0
    bleb = 0
    athe = 0
    a4 = 0
    afour = 0
    a6 = 0
    asix = 0
    a8 = 0
    aeight = 0
    achi = 0
    aleb = 0
    abthe = 0
    ab4 = 0
    abfour = 0
    ab6 = 0
    absix = 0
    ab8 = 0
    abeight = 0
    abchi = 0
    ableb = 0
    d1 = 0
    d2 = 0
    d4 = 0
    dn = 0

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    batch_rx = 0
    batch_kl = 0
    states = None
    gnorm = None
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            L = text.shape["time"]
            x = text.narrow("time", 0, L - 1)
            y = text.narrow("time", 1, L - 1)
            #x = text[:-1]
            #y = text[1:]
            x_info.lengths.sub_(1)

            e, e_info = batch.entities
            t, t_info = batch.types
            v, v_info = batch.values
            lene = e_info.lengths
            lent = t_info.lengths
            lenv = v_info.lengths
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene
            r_info = e_info

            # values text
            vt, vt_info = batch.values_text
            # DBG
            couldve_copied += (vt == y).transpose(
                "batch", "time", "els").any(-1).sum().item()

            ue, ue_info = batch.uentities
            ut, ut_info = batch.utypes
            v2d = batch.v2d
            vt2d = batch.vt2d

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            #T = y.shape["time"]
            N = y.shape["batch"]
            #if states is None:
            states = self.init_state(N)

            # should i include <eos> in ppl? no, should not.
            mask = y.ne(1) #* y.ne(3)
            nwords = mask.sum()

            # ugh.......refactor this
            if exact:
                rvinfo, states = self.marginal_nll(
                    x, states, x_info, r, r_info, vt,
                    ue, ue_info, ut, ut_info, v2d, vt2d,
                    y, x_info, T=T, E=E, R=R, learn=learn)
                nll = -rvinfo.log_py[mask].sum()
                cum_loss += nll.item()
                batch_loss += nll.item()
            else:
                rvinfo, _, nll, kl = self(
                    x,
                    states,
                    x_info,
                    r,
                    r_info,
                    vt,
                    ue,
                    ue_info,
                    ut,
                    ut_info,
                    v2d,
                    vt2d,
                    y,
                    x_info,
                    T=T,
                    E=E,
                    R=R,
                    learn=learn,
                    supattn=supattn,
                    # hacky
                    idxs=batch.idxs,
                )
                if rvinfo.log_qc_y is None:
                    self.copied += (rvinfo.log_pc.exp().get("copy", 1)
                                    > 0.5).sum().item()
                else:
                    self.copied += (rvinfo.log_qc_y.exp().get("copy", 1)
                                    > 0.5).sum().item()
                nelbo = nll + kl

                # check p(c|y) for subset of words
                # likelihood ratio of content vs noncontent
                pyn = rvinfo.log_py_c0
                pyc = rvinfo.log_py_ac1
                ythe = y.eq(self.Vx.stoi["the"])
                # find "4" and "four"
                y4 = y.eq(self.Vx.stoi["4"])
                yfour = y.eq(self.Vx.stoi["four"])
                # find "6" and "six"
                y6 = y.eq(self.Vx.stoi["6"])
                ysix = y.eq(self.Vx.stoi["six"])
                # find "8" and "eight"
                y8 = y.eq(self.Vx.stoi["8"])
                yeight = y.eq(self.Vx.stoi["eight"])
                ychi = y.eq(self.Vx.stoi["chicago"])
                yleb = y.eq(self.Vx.stoi["lebron"])
                nthe += ythe.sum()
                n4 += y4.sum()
                nfour += yfour.sum()
                n6 += y6.sum()
                nsix += ysix.sum()
                n8 += y8.sum()
                neight += yeight.sum()
                nchi += ychi.sum()
                nleb += yleb.sum()

                vt2d_f = vt2d.stack(("e", "t"), "r")
                alignments = vt2d_f.gather("r", rvinfo.a_s, "k")
                a4 += alignments.eq(
                    self.Vx.stoi["4"]).values.any(0)[y4.values].sum()
                afour += alignments.eq(
                    self.Vx.stoi["4"]).values.any(0)[yfour.values].sum()
                a6 += alignments.eq(
                    self.Vx.stoi["6"]).values.any(0)[y6.values].sum()
                asix += alignments.eq(
                    self.Vx.stoi["6"]).values.any(0)[ysix.values].sum()
                a8 += alignments.eq(
                    self.Vx.stoi["8"]).values.any(0)[y8.values].sum()
                aeight += alignments.eq(
                    self.Vx.stoi["8"]).values.any(0)[yeight.values].sum()
                achi += alignments.eq(self.Vx.stoi["chicago"]).values.any(
                    0)[ychi.values].sum()
                aleb += alignments.eq(self.Vx.stoi["lebron"]).values.any(
                    0)[yleb.values].sum()

                if pyn is not None:
                    better = (pyn < pyc).narrow("k", self.Kb, self.K)
                    betterthe = (pyn.get("k", 0) < pyc.get("k", 0))[ythe]
                    better4 = (pyn.get("k", 0) < pyc.get("k", 0))[y4]
                    betterfour = (pyn.get("k", 0) < pyc.get("k", 0))[yfour]
                    better6 = (pyn.get("k", 0) < pyc.get("k", 0))[y6]
                    bettersix = (pyn.get("k", 0) < pyc.get("k", 0))[ysix]
                    better8 = (pyn.get("k", 0) < pyc.get("k", 0))[y8]
                    bettereight = (pyn.get("k", 0) < pyc.get("k", 0))[yeight]
                    betterchi = (pyn.get("k", 0) < pyc.get("k", 0))[ychi]
                    betterleb = (pyn.get("k", 0) < pyc.get("k", 0))[yleb]
                    bthe += betterthe.sum()
                    b4 += better4.sum()
                    bfour += betterfour.sum()
                    b6 += better6.sum()
                    bsix += bettersix.sum()
                    b8 += better8.sum()
                    beight += bettereight.sum()
                    bchi += betterchi.sum()
                    bleb += betterleb.sum()
                    # if any alignments have the right word AND are better
                    ab4 += (alignments.eq(self.Vx.stoi["4"]) *
                            better).values.any(0)[y4.values].sum()
                    abfour += (alignments.eq(self.Vx.stoi["4"]) *
                               better).values.any(0)[yfour.values].sum()
                    ab6 += (alignments.eq(self.Vx.stoi["6"]) *
                            better).values.any(0)[y6.values].sum()
                    absix += (alignments.eq(self.Vx.stoi["6"]) *
                              better).values.any(0)[ysix.values].sum()
                    ab8 += (alignments.eq(self.Vx.stoi["8"]) *
                            better).values.any(0)[y8.values].sum()
                    abeight += (alignments.eq(self.Vx.stoi["8"]) *
                                better).values.any(0)[yeight.values].sum()
                    abchi += (alignments.eq(self.Vx.stoi["chicago"]) *
                              better).values.any(0)[ychi.values].sum()
                    ableb += (alignments.eq(self.Vx.stoi["lebron"]) *
                              better).values.any(0)[yleb.values].sum()
                else:
                    qc = rvinfo.log_qc_y.softmax("copy")
                    better = qc.get("copy", 1) > 0.5
                    bthe += better[ythe].sum()
                    b4 += better[y4].sum()
                    bfour += better[yfour].sum()
                    b6 += better[y6].sum()
                    bsix += better[ysix].sum()
                    b8 += better[y8].sum()
                    beight += better[yeight].sum()
                    bchi += better[ychi].sum()
                    bleb += better[yleb].sum()
                    # if any alignments have the right word AND are better
                    ab4 += (alignments.eq(self.Vx.stoi["4"]) *
                            better).values.any(0)[y4.values].sum()
                    abfour += (alignments.eq(self.Vx.stoi["4"]) *
                               better).values.any(0)[yfour.values].sum()
                    ab6 += (alignments.eq(self.Vx.stoi["6"]) *
                            better).values.any(0)[y6.values].sum()
                    absix += (alignments.eq(self.Vx.stoi["6"]) *
                              better).values.any(0)[ysix.values].sum()
                    ab8 += (alignments.eq(self.Vx.stoi["8"]) *
                            better).values.any(0)[y8.values].sum()
                    abeight += (alignments.eq(self.Vx.stoi["8"]) *
                                better).values.any(0)[yeight.values].sum()
                    abchi += (alignments.eq(self.Vx.stoi["chicago"]) *
                              better).values.any(0)[ychi.values].sum()
                    ableb += (alignments.eq(self.Vx.stoi["lebron"]) *
                              better).values.any(0)[yleb.values].sum()

                maxd, _ = self.delta.max("k")
                d1 += (maxd > 0.1).sum()
                d2 += (maxd > 0.2).sum()
                d4 += (maxd > 0.4).sum()
                dn += maxd.values.nelement()
                """
                print(f"Proportion greater than .1: {float(d1.item()) / float(dn)}")
                print(f"Proportion greater than .2: {float(d2.item()) / float(dn)}")
                print(f"Proportion greater than .4: {float(d4.item()) / float(dn)}")
                """

                if learn:
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                    optimizer.step()
                cum_rx += nll.item()
                batch_rx += nll.item()
                cum_kl += kl.item()
                batch_kl += kl.item()
                cum_loss += nelbo.item() if not hasattr(
                    self, "nokl") else nll.item()
                batch_loss += nelbo.item()
            cum_ntokens += nwords.item()
            batch_ntokens += nwords.item()
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    elbo=batch_loss / batch_ntokens,
                    nll=batch_rx / batch_ntokens,
                    kl=batch_kl / batch_ntokens,
                    gnorm=gnorm,
                    #b4 = b4.item() / n4.item(),
                    #bfour = bfour.item() / nfour.item(),
                    #b6 = b6.item() / n6.item(),
                    #bsix = bsix.item() / nsix.item(),
                    #b8 = b8.item() / n8.item(),
                    #beight = beight.item() / neight.item(),
                )
                batch_loss = 0
                batch_ntokens = 0
                batch_rx = 0
                batch_kl = 0
            if re is not None and i % 100 == -1 % 100:
                #if re is not None and i % 10 == -1 % 10:
                print("better prob")
                print(f"the: {bthe.item()} / {nthe.item()}")
                print(f"4: {b4.item()} / {n4.item()}")
                print(f"four: {bfour.item()} / {nfour.item()}")
                print(f"6: {b6.item()} / {n6.item()}")
                print(f"six: {bsix.item()} / {nsix.item()}")
                print(f"8: {b8.item()} / {n8.item()}")
                print(f"eight: {beight.item()} / {neight.item()}")
                print(f"chicago: {bchi.item()} / {nchi.item()}")
                print(f"lebron: {bleb.item()} / {nleb.item()}")
                print("correct alignment")
                print(f"4: {a4.item()} / {n4.item()}")
                print(f"four: {afour.item()} / {nfour.item()}")
                print(f"6: {a6.item()} / {n6.item()}")
                print(f"six: {asix.item()} / {nsix.item()}")
                print(f"8: {a8.item()} / {n8.item()}")
                print(f"eight: {aeight.item()} / {neight.item()}")
                print(f"chi: {achi.item()} / {nchi.item()}")
                print(f"leb: {aleb.item()} / {nleb.item()}")
                print("correct alignment and better prob")
                print(f"4: {ab4.item()} / {n4.item()}")
                print(f"four: {abfour.item()} / {nfour.item()}")
                print(f"6: {ab6.item()} / {n6.item()}")
                print(f"six: {absix.item()} / {nsix.item()}")
                print(f"8: {ab8.item()} / {n8.item()}")
                print(f"eight: {abeight.item()} / {neight.item()}")
                print(f"chi: {abchi.item()} / {nchi.item()}")
                print(f"leb: {ableb.item()} / {nleb.item()}")
                print(f"Proportion greater than .1: {float(d1.item()) / float(dn)}")
                print(f"Proportion greater than .2: {float(d2.item()) / float(dn)}")
                print(f"Proportion greater than .4: {float(d4.item()) / float(dn)}")
                words = ["2", "4", "6", "8", "bulls", "lebron", "chicago"]
                for word in words:
                    if self.v2d:
                        input = self.lutv.weight[self.Vv.stoi[word]]
                    elif self.untie:
                        input = self.lutgx.weight[self.Vx.stoi[word]]
                    else:
                        input = self.lutx.weight[self.Vx.stoi[word]]
                    weight = self.lutx.weight if not self.untie else self.lutgx.weight
                    if self.mlp:
                        probs, idx = (weight @ self.Wvy1(
                            self.Wvy0(
                                NamedTensor(
                                    input,
                                    names=("ctxt", ),
                                )).tanh()).tanh().values
                        ).softmax(0).topk(5)
                    else:
                        probs, idx = (weight @ input).softmax(0).topk(5)
                    print(f"{word} probs " + " || ".join(
                        f"{self.Vx.itos[x]}: {p:.2f}"
                        for p, x in zip(probs.tolist(), idx.tolist())))
    print("better prob")
    print(f"the: {bthe.item()} / {nthe.item()}")
    print(f"4: {b4.item()} / {n4.item()}")
    print(f"four: {bfour.item()} / {nfour.item()}")
    print(f"6: {b6.item()} / {n6.item()}")
    print(f"six: {bsix.item()} / {nsix.item()}")
    print(f"8: {b8.item()} / {n8.item()}")
    print(f"eight: {beight.item()} / {neight.item()}")
    print("correct alignment")
    print(f"4: {a4.item()} / {n4.item()}")
    print(f"four: {afour.item()} / {nfour.item()}")
    print(f"6: {a6.item()} / {n6.item()}")
    print(f"six: {asix.item()} / {nsix.item()}")
    print(f"8: {a8.item()} / {n8.item()}")
    print(f"eight: {aeight.item()} / {neight.item()}")
    print("correct alignment and better prob")
    print(f"4: {ab4.item()} / {n4.item()}")
    print(f"four: {abfour.item()} / {nfour.item()}")
    print(f"6: {ab6.item()} / {n6.item()}")
    print(f"six: {absix.item()} / {nsix.item()}")
    print(f"8: {ab8.item()} / {n8.item()}")
    print(f"eight: {abeight.item()} / {neight.item()}")
    print(f"Proportion greater than .1: {float(d1.item()) / float(dn)}")
    print(f"Proportion greater than .2: {float(d2.item()) / float(dn)}")
    print(f"Proportion greater than .4: {float(d4.item()) / float(dn)}")
    #print(f"COPIED: {self.copied:,} vs {couldve_copied:,} / {cum_ntokens:,}")
    print(f"NLL: {cum_rx / cum_ntokens} || KL: {cum_kl / cum_ntokens}")
    return cum_loss, cum_ntokens
def _ie_loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    T=256,
    E=None,
    R=None,
):
    self.train(learn)
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None

    cum_e_correct = 0
    cum_t_correct = 0
    batch_e_correct = 0
    batch_t_correct = 0
    cum_correct = 0
    batch_correct = 0
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            L = text.shape["time"]
            x = text.narrow("time", 0, L - 1)
            y = text.narrow("time", 1, L - 1)
            #x = text[:-1]
            #y = text[1:]
            x_info.lengths.sub_(1)

            e, e_info = batch.entities
            t, t_info = batch.types
            v, v_info = batch.values
            lene = e_info.lengths
            lent = t_info.lengths
            lenv = v_info.lengths
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene
            r_info = e_info

            # values text
            vt, vt_info = batch.values_text

            ue, ue_info = batch.uentities
            ut, ut_info = batch.utypes
            v2d = batch.v2d
            vt2d = batch.vt2d

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            N = y.shape["batch"]
            #if states is None:
            self.K = 1  # whatever
            states = self.init_state(N)
            rvinfo, _, nll, kl = self(
                x,
                states,
                x_info,
                r,
                r_info,
                vt,
                ue,
                ue_info,
                ut,
                ut_info,
                v2d,
                vt2d,
                y=y,
                y_info=x_info,
                T=T,
            )
            # only for crnnlma though
            """
            cor, ecor, tcor, tot = Ie.pat1(
                self.ea.rename("els", "e"),
                self.ta.rename("els", "t"),
                batch.ie_d,
            )
            """
            #ds = batch.ie_d
            ets = batch.ie_etv
            #num_cells = batch.num_cells

            log_pe = rvinfo.log_qe_y.rename(
                "els", "e") if not self.evalp else rvinfo.log_pe.rename(
                    "els", "e")
            log_pt = rvinfo.log_qt_y.rename(
                "els", "t") if not self.evalp else rvinfo.log_pt.rename(
                    "els", "t")
            log_pe = log_pe.cpu()
            log_pt = log_pt.cpu()
            ue = ue.cpu()
            ut = ut.cpu()
            #import pdb; pdb.set_trace()

            # calculate accuracy
            ds = batch.ie_et_d
            num_cells = float(sum(len(d) for d in ds))
            # use a separate index `b` so the outer `batch` object is not clobbered
            for b, d in enumerate(ds):
                for t, (es, ts) in d.items():
                    #t = t + 1
                    _, e_max = log_pe.get("batch", b).get("time", t).max("e")
                    _, t_max = log_pt.get("batch", b).get("time", t).max("t")
                    e_preds = ue.get("batch", b).get("els", e_max.item())
                    t_preds = ut.get("batch", b).get("els", t_max.item())
                    #import pdb; pdb.set_trace()  # debug breakpoint disabled
                    correct = (es.eq(e_preds) *
                               ts.eq(t_preds)).any().float().item()
                    e_correct = es.eq(e_preds).any().float().item()
                    t_correct = ts.eq(t_preds).any().float().item()
                    cum_correct += correct
                    batch_correct += correct
                    cum_e_correct += e_correct
                    cum_t_correct += t_correct
                    batch_e_correct += e_correct
                    batch_t_correct += t_correct
            if learn:
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            cum_loss += nll.item()
            cum_ntokens += num_cells
            batch_loss += nll.item()
            batch_ntokens += num_cells
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    loss=batch_loss / batch_ntokens,
                    gnorm=gnorm,
                    acc=batch_correct / batch_ntokens,
                    e=batch_e_correct / batch_ntokens,
                    t=batch_t_correct / batch_ntokens,
                )
                batch_loss = 0
                batch_ntokens = 0
                batch_correct = 0
                batch_e_correct = 0
                batch_t_correct = 0
    print(
        f"acc: {cum_correct / cum_ntokens} | E acc: {cum_e_correct / cum_ntokens} | T acc: {cum_t_correct / cum_ntokens}"
    )
    print(f"total supervised cells: {cum_ntokens}")
    return cum_loss, cum_ntokens
def _loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    supattn=False,
    supcopy=False,
    T=None,
    E=None,
    R=None,
):
    self.train(learn)
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            L = text.shape["time"]
            x = text.narrow("time", 0, L - 1)
            y = text.narrow("time", 1, L - 1)
            #x = text[:-1]
            #y = text[1:]
            x_info.lengths.sub_(1)

            e, e_info = batch.entities
            t, t_info = batch.types
            v, v_info = batch.values
            lene = e_info.lengths
            lent = t_info.lengths
            lenv = v_info.lengths
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene
            r_info = e_info

            ue, ue_info = batch.uentities
            ut, ut_info = batch.utypes
            v2d = batch.v2d
            vt2d = batch.vt2d
            vt, vt_info = batch.values_text

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            T = y.shape["time"]
            N = y.shape["batch"]
            #if states is None:
            states = self.init_state(N)
            logits, _ = self(x, states, x_info, r, r_info, vt, ue, ue_info,
                             ut, ut_info, v2d, vt2d, y)
            nll = self.loss(logits, y)
            if self.maskedc:
                slist = [
                    "one", "two", "three", "four", "five", "six", "seven",
                    "eight", "nine",
                    "1", "2", "3", "4", "5", "6", "7", "8", "9",
                ]
                mask = sum(y.eq(self.Vx.stoi[x]) for x in slist)
                nllc = -self.log_pc[mask].sum()
            else:
                nllc = 0
            kl = 0
            nelbo = nll + kl
            if learn:
                (nelbo + nllc).div(nwords.item()).backward()
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            cum_loss += nelbo.item()
            cum_ntokens += nwords.item()
            batch_loss += nelbo.item()
            batch_ntokens += nwords.item()
            if re is not None and i % re == -1 % re:
                titer.set_postfix(loss=batch_loss / batch_ntokens, gnorm=gnorm)
                batch_loss = 0
                batch_ntokens = 0
    return cum_loss, cum_ntokens
def _ie_loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    T=None,
    E=None,
    R=None,
):
    self.train(learn)
    context = torch.enable_grad if learn else torch.no_grad
    if not self.v2d:
        self.copy_x_to_v()

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None

    cum_TP, cum_FP, cum_FN, cum_TM, cum_TG = 0, 0, 0, 0, 0
    batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0
    # for debugging
    cum_e_correct = 0
    cum_t_correct = 0
    cum_correct = 0
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            L = text.shape["time"]
            x = text.narrow("time", 0, L - 1)
            y = text.narrow("time", 1, L - 1)
            #x = text[:-1]
            #y = text[1:]
            x_info.lengths.sub_(1)

            e, e_info = batch.entities
            t, t_info = batch.types
            v, v_info = batch.values
            lene = e_info.lengths
            lent = t_info.lengths
            lenv = v_info.lengths
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene
            r_info = e_info

            vt, vt_info = batch.values_text
            ue, ue_info = batch.uentities
            ut, ut_info = batch.utypes
            v2d = batch.v2d
            vt2d = batch.vt2d

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            T = y.shape["time"]
            N = y.shape["batch"]
            #if states is None:
            states = self.init_state(N)
            logits, _ = self(x, states, x_info, r, r_info, vt, ue, ue_info,
                             ut, ut_info, v2d, vt2d)
            nll = self.loss(logits, y)

            # only for crnnlma though
            """
            cor, ecor, tcor, tot = Ie.pat1(
                self.ea.rename("els", "e"),
                self.ta.rename("els", "t"),
                batch.ie_d,
            )
            """
            #ds = batch.ie_d
            etvs = batch.ie_etv
            #num_cells = batch.num_cells
            num_cells = float(sum(len(d) for d in etvs))

            log_pe = self.ea.cpu()
            log_pt = self.ta.cpu()
            log_pc = self.log_pc.cpu()
            ue = ue.cpu()
            ut = ut.cpu()
            # need log_pv, need to check if self.v2d
            # batch x time x hid
            h = self.lutx(y).values
            log_pv = torch.einsum(
                "nth,vh->ntv", [h, self.lutv.weight.data]).log_softmax(-1)
            log_pv = NamedTensor(log_pv, names=("batch", "time", "v"))
            tp, fp, fn, tm, tg = pr(
                etvs,
                ue,
                ut,
                log_pe,
                log_pt,
                log_pv,
                log_pc=log_pc,
                lens=lens,
                Ve=self.Ve,
                Vt=self.Vt,
                Vv=self.Vv,
                Vx=self.Vx,
                text=text,
            )
            cum_TP += tp
            cum_FP += fp
            cum_FN += fn
            cum_TM += tm
            cum_TG += tg
            batch_TP += tp
            batch_FP += fp
            batch_FN += fn
            batch_TM += tm
            batch_TG += tg

            # For debugging
            ds = batch.ie_et_d
            # use a separate index `b` so the outer `batch` object is not clobbered
            for b, d in enumerate(ds):
                for t, (es, ts) in d.items():
                    #t = t + 1
                    _, e_max = log_pe.get("batch", b).get("time", t).max("e")
                    _, t_max = log_pt.get("batch", b).get("time", t).max("t")
                    e_preds = ue.get("batch", b).get("els", e_max.item())
                    t_preds = ut.get("batch", b).get("els", t_max.item())
                    correct = (es.eq(e_preds) *
                               ts.eq(t_preds)).any().float().item()
                    e_correct = es.eq(e_preds).any().float().item()
                    t_correct = ts.eq(t_preds).any().float().item()
                    cum_correct += correct
                    cum_e_correct += e_correct
                    cum_t_correct += t_correct
            if learn:
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            cum_loss += nll.item()
            cum_ntokens += num_cells
            batch_loss += nll.item()
            batch_ntokens += num_cells
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    loss=batch_loss / batch_ntokens,
                    gnorm=gnorm,
                    p=batch_TP / (batch_TP + batch_FP),
                    r=batch_TP / (batch_TP + batch_FN),
                )
                batch_loss = 0
                batch_ntokens = 0
                batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0
    print(
        f"DBG acc: {cum_correct / cum_ntokens} | E acc: {cum_e_correct / cum_ntokens} | T acc: {cum_t_correct / cum_ntokens}"
    )
    print(
        f"p: {cum_TP / (cum_TP + cum_FP)} || r: {cum_TP / (cum_TP + cum_FN)}"
    )
    print(f"total supervised cells: {cum_ntokens}")
    return cum_loss, cum_ntokens
def _loop(
    self,
    iter,
    supiter=[],
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    elbo=True,
    T=64,
    E=32,
    R=4,
    supattn=False,
    supcopy=False,
):
    context = torch.enable_grad if learn else torch.no_grad

    # DBG
    self.copied = 0
    self.couldve_copied = 0

    # unsup
    cum_loss = 0.
    cum_ntokens = 0.
    cum_rx = 0.
    cum_klv = 0.
    cum_kla = 0.
    batch_loss = 0.
    batch_ntokens = 0.
    batch_rx = 0.
    batch_klv = 0.
    batch_kla = 0.
    cum_N = 0.
    batch_N = 0.
    # sup
    cum_nllpy = 0.
    cum_nllpv = 0.
    cum_nllqv = 0.
    cum_Ny = 0.
    cum_Nv = 0.
    batch_nllpy = 0.
    batch_nllpv = 0.
    batch_nllqv = 0.
    batch_Ny = 0.
    batch_Nv = 0.

    ziter = zip_longest(iter, supiter)
    states = None
    gnorm = None
    with context():
        titer = tqdm(ziter) if learn else ziter
        for i, (batch, supbatch) in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            if batch is not None:
                rvinfo, _, nll, klv, kla, Nv, Ny = self._batch(
                    batch,
                    T = T,
                    E = E,
                    R = R, # ?
                    learn = learn,
                    sup = False,
                )
                nelbo = nll + klv + kla

                cum_rx += nll.item()
                batch_rx += nll.item()
                cum_klv += klv.item()
                batch_klv += klv.item()
                cum_kla += kla.item()
                batch_kla += kla.item()
                cum_loss += nelbo.item() if not hasattr(self, "nokl") else nll.item()
                batch_loss += nelbo.item()
                cum_ntokens += Ny.item()
                batch_ntokens += Ny.item()
                cum_N += Nv.item()
                batch_N += Nv.item()
            if supbatch is not None:
                (
                    rvinfosup, _,
                    nll_pv, nll_py_v, nll_qv_y,
                    v_total, y_total,
                ) = self._batch(
                    supbatch,
                    T = T,
                    E = E,
                    R = R,
                    learn = learn,
                    sup = True,
                )
                cum_nllpy += nll_py_v.item()
                cum_nllpv += nll_pv.item()
                cum_nllqv += nll_qv_y.item()
                cum_Ny += y_total.item()
                cum_Nv += v_total.item()
                batch_nllpy += nll_py_v.item()
                batch_nllpv += nll_pv.item()
                batch_nllqv += nll_qv_y.item()
                batch_Ny += y_total.item()
                batch_Nv += v_total.item()
            if learn:
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                optimizer.step()
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    elbo = batch_loss / batch_ntokens if batch_loss != 0 else 0,
                    nll = batch_rx / batch_ntokens if batch_rx != 0 else 0,
                    klv = batch_klv / cum_N if batch_klv != 0 else 0,
                    kla = batch_kla / batch_N if batch_kla != 0 else 0,
                    py = batch_nllpy / batch_Ny if batch_nllpy != 0 else 0,
                    pv = batch_nllpv / batch_Nv if batch_nllpv != 0 else 0,
                    qv = batch_nllqv / batch_Nv if batch_nllqv != 0 else 0,
                    gnorm = gnorm,
                )
                batch_loss = 0.
                batch_ntokens = 0.
                batch_N = 0.
                batch_rx = 0.
                batch_klv = 0.
                batch_kla = 0.
                # sup
                batch_nllpy = 0.
                batch_nllpv = 0.
                batch_nllqv = 0.
                batch_Ny = 0.
                batch_Nv = 0.
    #print(f"COPIED: {self.copied:,} vs {couldve_copied:,} / {cum_ntokens:,}")
    if cum_N == 0:
        cum_N = 1
    if cum_ntokens == 0:
        cum_ntokens = 1
    if cum_Ny == 0:
        cum_Ny = 1
    if cum_Nv == 0:
        cum_Nv = 1
    print(f"NLL: {cum_rx / cum_ntokens} || KLv: {cum_klv / cum_N} || KLa: {cum_kla / cum_N}")
    print(f"py: {cum_nllpy / cum_Ny} || pv: {cum_nllpv / cum_Nv} || qv: {cum_nllqv / cum_Nv}")
    return cum_loss, cum_ntokens
def _loop(self, iter, optimizer=None, clip=0, learn=False, re=None):
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, lens = batch.text
            x = text[:-1]
            y = text[1:]
            lens = lens - 1

            e, lene = batch.entities
            t, lent = batch.types
            v, lenv = batch.values
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene

            # should i include <eos> in ppl? no, should not.
            mask = y.ne(1) #* y.ne(3)
            nwords = mask.sum()
            # assert nwords == lens.sum()
            T, N = y.shape
            R = e.shape[0]
            #if states is None:
            states = self.init_state(N)
            logits, _, sampled_log_pa, log_pa, log_qay = self(
                x, states, lens, r, lenr, y)
            nll = self.loss(logits, y)
            B = nll[-1]
            reward = (nll[:-1] - B.unsqueeze(0)).detach() * sampled_log_pa
            reward = reward[mask.unsqueeze(0).expand_as(reward)].sum()
            nll = nll[:-1].sum(0)[mask].sum()
            # giving nans because of masking, sigh
            # hmm...this worked before?
            """
            qa = log_qay.exp()
            qa.data[log_qay == float("-inf")] = 0
            kl0 = qa * (log_qay - log_pa)
            kl0[log_qay == float("-inf")] = 0
            kl = kl0.sum()
            #import pdb; pdb.set_trace()
            """
            #"""
            kl = 0
            qa = log_qay.exp()
            # use a fresh index here so the outer batch counter `i` is not clobbered
            for j, l in enumerate(lenr.tolist()):
                #p = Categorical(logits=log_pa[:, j, :l])
                #q = Categorical(logits=log_qay[:, j, :l])
                #kl0 = kl_divergence(q, p).sum()
                kl0 = qa[:, j, :l] * (log_qay[:, j, :l] - log_pa[:, j, :l])
                kl0[log_qay[:, j, :l] == float("-inf")] = 0
                kl0 = kl0.sum()
                kl += kl0
            #"""
            #p = Categorical(logits=log_pa[mask])
            #q = Categorical(logits=log_qay[mask])
            #kl = kl_divergence(q, p).sum()
            nelbo = nll + kl
            if learn:
                (nelbo - reward).div(nwords.item()).backward()
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            cum_loss += nelbo.item()
            cum_ntokens += nwords.item()
            batch_loss += nelbo.item()
            batch_ntokens += nwords.item()
            if re is not None and i % re == -1 % re:
                titer.set_postfix(loss=batch_loss / batch_ntokens, gnorm=gnorm)
                batch_loss = 0
                batch_ntokens = 0
    return cum_loss, cum_ntokens
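# --- Illustrative sketch (not part of the original module) ---
# The loop above builds a score-function (REINFORCE) surrogate that uses the
# last sample's NLL as a baseline. A stripped-down version of that surrogate
# term, with hypothetical tensor shapes, looks roughly like this:
import torch

def reinforce_surrogate(nll_samples, sampled_log_pa):
    """nll_samples: (S, ...) per-sample losses; sampled_log_pa: (S-1, ...)
    log-probs of the sampled alignments. The last sample serves as a baseline;
    the returned term's gradient w.r.t. the log-probs is the score-function
    estimate of the gradient of the expected loss."""
    baseline = nll_samples[-1].unsqueeze(0)
    advantage = (nll_samples[:-1] - baseline).detach()
    return (advantage * sampled_log_pa).sum()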
def _ie_loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    T=None,
    E=None,
    R=None,
):
    self.train(learn)
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None

    cum_TP, cum_FP, cum_FN, cum_TM, cum_TG = 0, 0, 0, 0, 0
    batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            x = text

            # ie_idx is off by one because of bos
            #ie_idx = batch.ie_idx
            #ie_e = batch.ie_e
            #ie_t = batch.ie_t
            etvs = batch.ie_etv
            #num_cells = batch.num_cells
            num_cells = float(sum(len(d) for d in etvs))

            nll, log_pe, log_pt, log_pv = self(x, x_info, etvs, learn=learn)

            # gather predictions; use a separate index `b` so the outer
            # `batch` object is not clobbered
            batch_positives = []
            for b, etv in enumerate(etvs):
                # Get model positives
                positives = {}
                T = lens.get("batch", b).item()
                for t in range(T):
                    e_p, e_max = log_pe.get("batch", b).get("time", t).max("e")
                    t_p, t_max = log_pt.get("batch", b).get("time", t).max("t")
                    v_p, v_preds = log_pv.get("batch", b).get("time", t).max("v")
                    e_pred = e_max.item()
                    t_pred = t_max.item()
                    v_pred = v_preds.item()
                    if (e_pred != self.Ve.stoi[self.NONE]
                            and t_pred != self.Vt.stoi[self.NONE]
                            and v_pred != self.Vv.stoi[self.NONE]):
                        key = (e_pred, t_pred, v_pred)
                        score = (e_p + t_p + v_p).item()
                        if key not in positives or score > positives[key]:
                            positives[key] = score
                batch_positives.append(positives)

                # Compare against true positives
                true_positives = set()
                for es, ts, vs, _ in etv:
                    true_positives |= set(
                        zip(es.tolist(), ts.tolist(), vs.tolist()))
                total_m = len(positives)
                total_g = len(true_positives)
                tp = len(set(positives) & true_positives)
                fp = len(set(positives) - true_positives)
                fn = len(true_positives - set(positives))
                cum_TP += tp
                cum_FP += fp
                cum_FN += fn
                cum_TM += total_m
                cum_TG += total_g
                batch_TP += tp
                batch_FP += fp
                batch_FN += fn
                batch_TM += total_m
                batch_TG += total_g
            if learn:
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            cum_loss += nll.item()
            cum_ntokens += num_cells
            batch_loss += nll.item()
            batch_ntokens += num_cells
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    loss=batch_loss / batch_ntokens,
                    gnorm=gnorm,
                    p=batch_TP / (batch_TP + batch_FP) if batch_TP > 0 else 0,
                    r=batch_TP / (batch_TP + batch_FN) if batch_TP > 0 else 0,
                )
                batch_loss = 0
                batch_ntokens = 0
                batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0
    print(
        f"p: {cum_TP / (cum_TP + cum_FP) if cum_TP > 0 else 0} || r: {cum_TP / (cum_TP + cum_FN) if cum_TP > 0 else 0}"
    )
    print(f"total supervised cells: {cum_ntokens}")
    return cum_loss, cum_ntokens, cum_TP, cum_FP, cum_FN, cum_TM, cum_TG
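# --- Illustrative helper sketch (not part of the original module) ---
# `_ie_loop` above returns cumulative TP/FP/FN counts; a small helper for
# turning them into precision, recall, and F1 with the same zero guards the
# loop uses might look like this:
def prf1(tp, fp, fn):
    p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1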
def _vie_loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    T=None,
    E=None,
    R=None,
):
    self.train(learn)
    context = torch.enable_grad if learn else torch.no_grad

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    gnorm = None

    cum_e_correct = 0
    cum_t_correct = 0
    batch_e_correct = 0
    batch_t_correct = 0
    cum_correct = 0
    batch_correct = 0
    cum_copyable = 0
    batch_copyable = 0
    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            x = text

            # ie_idx is off by one because of bos
            #ie_idx = batch.ie_idx
            #ie_e = batch.ie_e
            #ie_t = batch.ie_t
            #ds = batch.ie_d
            ets = batch.ie_et
            #num_cells = batch.num_cells
            #num_cells = float(sum(len(d) for d in ds))

            e, r_info = batch.entities
            t, _ = batch.types
            v, _ = batch.values
            vt, _ = batch.values_text
            num_cells = r_info.mask.sum()

            nll, log_pv, attn = self(
                x, x_info, e=e, t=t, v=v, r_info=r_info, ets=ets, learn=learn)

            hv = log_pv.max("v")[1]
            correct = (hv == v)[r_info.mask].sum()
            batch_correct += correct
            cum_correct += correct

            vt_nopad = vt.clone()
            vt_nopad[vt_nopad == 1] = -1
            num_copyable = vt_nopad.eq(x).sum()
            batch_copyable += num_copyable
            cum_copyable += num_copyable

            # Deal with n/a tokens
            na_v_idx = self.Vv.stoi["n/a"]
            #import pdb; pdb.set_trace()

            if learn:
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                optimizer.step()
            # keep the running sums as detached tensors so the graph is not
            # retained across batches; .item() is still valid below
            cum_loss += nll.detach()
            cum_ntokens += num_cells
            batch_loss += nll.detach()
            batch_ntokens += num_cells
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    loss=batch_loss.item() / batch_ntokens.item(),
                    gnorm=gnorm,
                    acc=batch_correct.item() / batch_ntokens.item(),
                    copyable=batch_copyable.item() / batch_ntokens.item(),
                )
                batch_loss = 0
                batch_ntokens = 0
                batch_correct = 0
                batch_copyable = 0
    print(
        f"acc: {cum_correct.item() / cum_ntokens.item()} || copyable: {cum_copyable.item() / cum_ntokens.item()}"
    )
    print(f"total supervised cells: {cum_ntokens.item()}")
    return cum_loss.item(), cum_ntokens.item()