def exact_train(corpus, model, optim, args, device):
    """Exact maximum-likelihood training: backprop through the exact log-marginal
    computed by the forward algorithm."""
    model.train()
    K = args.K
    elbo, ntokens = 0.0, 0
    perm = torch.randperm(len(corpus))
    # exact_logmarg = batch_fwdalg if args.markov_order == 1 else batch_var_elim
    exact_logmarg = batch_fwdalg
    for i, idx in enumerate(perm):
        optim.zero_grad()
        batch = corpus[idx.item()].to(device)
        emdist, transdist = model.get_emdist(), model.get_transdist()
        btrull = exact_logmarg(batch, emdist, transdist).sum()
        belbo = btrull.item()
        btrull.div(-batch.size(1)).backward()
        elbo += belbo
        ntokens += batch.nelement()
        clip_opt_params(optim, args.clip)
        optim.step()
        if (i + 1) % args.log_interval == 0:
            print("{:5d}/{:5d} | lr {:02.4f} | ppl {:8.2f}".format(
                i + 1, perm.size(0), args.lr, math.exp(-elbo / ntokens)))
    return elbo, ntokens
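# `clip_opt_params` is called throughout this file but defined elsewhere. The
# sketch below shows one plausible implementation (gradient-norm clipping over
# an optimizer's parameter groups); it is an assumption, not the repo's helper.
def clip_opt_params_sketch(optimizer, max_norm):
    # gather every parameter the optimizer updates and clip their joint grad norm
    params = [p for group in optimizer.param_groups for p in group["params"]
              if p.grad is not None]
    torch.nn.utils.clip_grad_norm_(params, max_norm)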
def train(corpus, model, infnet, optim, args, device):
    """Train the model and inference network with sampled ELBO gradients
    (REINFORCE/VIMCO or Gumbel-softmax straight-through)."""
    model.train()
    infnet.train()
    K = args.K
    elbo, ntokens = 0.0, 0
    perm = torch.randperm(len(corpus))
    mean, var = None, None  # running baseline statistics for REINFORCE
    for i, idx in enumerate(perm):
        optim.zero_grad()
        batch = corpus[idx.item()].to(device)
        if args.reinforce:
            if mean is None:
                mean, var, balph = 0, 0, 0
            else:
                balph = args.alpha
            belbo, mean, var = reinforce_elbo(
                model, infnet, batch, mean, var, alph=balph,
                nsamps=args.nsamps, vimco=args.vimco)
        else:  # gumbel + straight-through
            belbo = gumbel_st_elbo(model, infnet, batch)
        # note: no explicit backward() here; the ELBO helpers presumably
        # backprop internally and return the ELBO value.
        elbo += belbo
        ntokens += batch.nelement()
        clip_opt_params(optim, args.clip)
        optim.step()
        if (i + 1) % args.log_interval == 0:
            print("{:5d}/{:5d} | lr {:02.4f} | ppl {:8.2f}".format(
                i + 1, perm.size(0), args.lr, math.exp(-elbo / ntokens)))
    return elbo, ntokens
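# `reinforce_elbo` and `gumbel_st_elbo` are defined elsewhere. As a rough,
# hypothetical illustration of the REINFORCE branch above, the sketch below
# shows a single-sample score-function ELBO estimator with a moving-average
# baseline; `logp_fn` and `q_dist` are stand-ins, not names from this repo.
def reinforce_elbo_sketch(logp_fn, q_dist, baseline, alpha=0.9):
    z = q_dist.sample()                       # discrete latent sample (no reparam.)
    logq = q_dist.log_prob(z).sum()           # log q(z | x)
    elbo = logp_fn(z) - logq                  # single-sample ELBO estimate
    # surrogate whose gradient is the score-function (REINFORCE) estimator
    surrogate = (elbo.detach() - baseline) * logq + elbo
    new_baseline = alpha * baseline + (1 - alpha) * elbo.detach()
    return elbo.detach(), surrogate, new_baseline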
def lbp_train(corpus, model, optimizer, edges, ne, device, args):
    """Train the RBM against a loopy-BP/Bethe estimate of log Z.
    The optimizer covers all model parameters.  ne: 1 x nvis*nhid."""
    model.train()
    total_out_loss, nexamples = 0.0, 0
    niter = 0
    for i, batchthing in enumerate(corpus):
        optimizer.zero_grad()
        batch = batchthing[0]  # the dataloader yields a list/tuple
        batch = batch.view(batch.size(0), -1).to(device)  # bsz x V
        bsz, nvis = batch.size()
        nedges = nvis * model.nhid  # V*H edges
        ed_lpots = model.get_edge_scores().unsqueeze(0)  # 1 x nvis x nhid x 4
        un_lpots = model.get_unary_scores().unsqueeze(0)  # 1 x (nvis + nhid) x 2
        with torch.no_grad():
            exed_lpots = ed_lpots.view(nedges, 1, 2, 2)  # V*H x 1 x 2 x 2
            exun_lpots = un_lpots.transpose(0, 1)  # V+H x 1 x 2
            nodebs, facbs, ii, _, _ = dolbp(exed_lpots, edges, x=None, emlps=exun_lpots,
                                            niter=args.lbp_iter, renorm=True,
                                            tol=args.lbp_tol, randomize=args.randomize_lbp)
            niter += ii
            # reshape log unary marginals: V+H x 1 x 2 -> 1 x V+H x 2
            tau_u = torch.stack([nodebs[t] for t in range(nvis + model.nhid)]).transpose(0, 1)
            # reshape log fac marginals: nedges x 1 x 2 x 2 -> 1 x nedges x 2 x 2
            tau_e = torch.stack([facbs[e] for e in range(nedges)]).transpose(0, 1)
            # exponentiate
            tau_u, tau_e = (tau_u.exp() + EPS), (tau_e.exp() + EPS).view(1, nvis, model.nhid, -1)
        fz = rbm_bethe_fez2(tau_u, tau_e, un_lpots, ed_lpots, ne)
        out_loss = -model._neg_free_energy(batch).sum() - fz * bsz
        total_out_loss += out_loss.item()
        out_loss.div(bsz).backward()
        clip_opt_params(optimizer, args.clip)
        optimizer.step()
        nexamples += bsz
        if (i + 1) % args.log_interval == 0:
            print("{:5d}/{:5d} | its {:3.2f} | out_loss {:8.5f}".format(
                i + 1, len(corpus), niter / (i + 1), total_out_loss / nexamples))
    return total_out_loss, nexamples
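# The standard Bethe approximation to log Z, which is presumably (up to sign)
# what `rbm_bethe_fez2` / `bethe_fez` compute from the BP pseudomarginals:
# average energy plus edge entropies minus degree-corrected node entropies.
# This is a minimal illustrative sketch; names and shapes are assumptions.
def bethe_log_z_sketch(tau_u, tau_e, un_lpots, ed_lpots, degrees):
    """tau_u: nnodes x K node marginals, tau_e: nedges x K x K edge marginals
    (already smoothed away from zero), un_lpots/ed_lpots: matching log
    potentials, degrees: nnodes node degrees."""
    avg_energy = (tau_e * ed_lpots).sum() + (tau_u * un_lpots).sum()
    edge_ent = -(tau_e * tau_e.log()).sum()
    node_ent = -(tau_u * tau_u.log()).sum(-1)            # per-node entropies
    bethe_ent = edge_ent - ((degrees - 1.0) * node_ent).sum()
    return avg_energy + bethe_ent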
def train(corpus, model, infnet, moptim, ioptim, ne, penfunc, device, args): """ opt is just gonna be for everyone ne - 1 x nvis*nhid """ #import time model.train() infnet.train() total_out_loss, total_in_loss, nexamples = 0.0, 0.0, 0 total_pen_loss = 0.0 for i, batchthing in enumerate(corpus): batch = batchthing[0] # i guess it's a list batch = batch.view(batch.size(0), -1).to(device) # bsz x V bsz, nvis = batch.size() npenterms = 2*nvis*model.nhid # V for each H marginal and H for each V marginal # maximize wrt rho with torch.no_grad(): ed_lpots = model.get_edge_scores().unsqueeze(0) # 1 x nvis x nhid x 4 un_lpots = model.get_unary_scores().unsqueeze(0) # 1 x (nvis + nhid) x 2 if args.reset_adam: ioptim = torch.optim.Adam(infnet.parameters(), lr=args.ilr) for _ in range(args.inf_iter): ioptim.zero_grad() pred_rho = infnet.q() # V*H x K^2 logits in_loss, pen_loss = rbm_inner_loss(pred_rho, un_lpots, ed_lpots, ne, penfunc) total_in_loss += in_loss.item()*bsz total_pen_loss += args.pen_mult/npenterms * pen_loss.item()*bsz in_loss = in_loss + args.pen_mult/npenterms * pen_loss in_loss.backward() clip_opt_params(ioptim, args.clip) ioptim.step() pred_rho = pred_rho.detach() # min wrt params moptim.zero_grad() ed_lpots = model.get_edge_scores().unsqueeze(0) # 1 x nvis x nhid x 4 un_lpots = model.get_unary_scores().unsqueeze(0) # 1 x (nvis + nhid) x 2 out_loss, open_loss = rbm_outer_loss(batch, model, pred_rho, un_lpots, ed_lpots, ne, penfunc=None) total_out_loss += out_loss.item() out_loss.div(bsz).backward() clip_opt_params(moptim, args.clip) moptim.step() nexamples += bsz if (i+1) % args.log_interval == 0: print("{:5d}/{:5d} | out_loss {:8.5f} | in_loss {:8.5f} | pen_loss {:8.6f}".format( i+1, len(corpus), total_out_loss/nexamples, total_in_loss/nexamples, total_pen_loss/(nexamples*args.pen_mult))) return total_out_loss, total_in_loss, total_pen_loss, nexamples
def exact_train(corpus, model, optim, helper, device, args):
    """Exact maximum-likelihood training via the unnormalized forward algorithm:
    log p(x) = lnZx - lnZ, where lnZx clamps the observations and lnZ is the
    unclamped normalizer."""
    model.train()
    ll, ntokens = 0.0, 0
    perm = torch.randperm(len(corpus))
    for i, idx in enumerate(perm):
        optim.zero_grad()
        batch = corpus[idx.item()].to(device)
        T, bsz = batch.size()
        # get normalizer; depends only on edges
        if T > 1:
            edges = helper.get_edges(T).to(device)  # symbolic edge representation
            edge_scores = model.get_edge_scores(edges, T)  # nedges x K*K
            ending_at = helper.get_ending_at(T)  # used to make the inference code less annoying
            lnZ = batch_ufwdalg(edge_scores, T, helper.K, helper.markov_order, edges,
                                ending_at=ending_at)  # 1
        else:
            edges, edge_scores, ending_at, lnZ = None, None, 0, 0
        # get unnormalized marginal
        obs_lpots = model.get_obs_lps(batch)  # bsz x T x K
        lnZx = batch_ufwdalg(edge_scores, T, helper.K, helper.markov_order, edges,
                             ulpots=obs_lpots, ending_at=ending_at)  # bsz
        logmarg = (lnZx - lnZ).sum()
        ll += logmarg.item()
        logmarg.div(-bsz).backward()
        ntokens += batch.nelement()
        clip_opt_params(optim, args.clip)
        optim.step()
        if (i + 1) % args.log_interval == 0:
            print("{:5d}/{:5d} | lr {:02.4f} | ppl {:8.2f}".format(
                i + 1, perm.size(0), args.lr, math.exp(-ll / ntokens)))
    return ll, ntokens
def train_pcd(corpus, model, optimizer, device, args):
    model.train()
    total_loss, nexamples = 0.0, 0
    for i, batchthing in enumerate(corpus):
        batch = batchthing[0]  # the dataloader yields a list/tuple
        batch = batch.view(batch.size(0), -1).to(device)  # bsz x V
        bsz, nvis = batch.size()
        optimizer.zero_grad()
        loss = model.rb_pcd_loss(batch)
        total_loss += loss.item()
        loss.div(bsz).backward()
        clip_opt_params(optimizer, args.clip)
        optimizer.step()
        nexamples += bsz
        if (i + 1) % args.log_interval == 0:
            print("{:5d}/{:5d} | loss {:8.5f}".format(
                i + 1, len(corpus), total_loss / nexamples))
    return total_loss, nexamples
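# `rb_pcd_loss` is a model method defined elsewhere. As a hypothetical
# illustration of persistent contrastive divergence, the sketch below keeps a
# persistent chain of fantasy particles and returns a surrogate whose gradient
# matches the PCD likelihood gradient. `free_energy` and `gibbs_step` are
# stand-in helpers, not names from this repo.
def pcd_loss_sketch(free_energy, gibbs_step, v_data, v_fantasy, k=1):
    for _ in range(k):                        # k Gibbs sweeps on the persistent chain
        v_fantasy = gibbs_step(v_fantasy)
    v_fantasy = v_fantasy.detach()
    # push down the free energy of data, push up the free energy of model samples
    loss = free_energy(v_data).sum() - free_energy(v_fantasy).sum()
    return loss, v_fantasy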
def main(args, trdat, valdat, nvis, ne):
    print("main args", args)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    device = torch.device("cuda" if args.cuda else "cpu")

    if len(args.train_from) > 0:
        saved_stuff = torch.load(args.train_from)
        saved_args = saved_stuff["opt"]
        model = RBM(nvis, saved_args)
        model.load_state_dict(saved_stuff["mod_sd"])
        model = model.to(device)
        print("running ais...")
        valnpll, vnexagain = moar_validate(valdat, model, device, args, do_ais=True)
        print("Epoch {:3d} | val npll {:.3f}".format(0, valnpll / vnexagain))
        exit()
    else:
        model = RBM(nvis, args).to(device)
    infnet = SeqInfNet(nvis, args)
    infnet = infnet.to(device)
    ne = ne.to(device)
    if args.training == "lbp" or args.check_corr:
        edges = get_rbm_edges(nvis, args.nhid)  # .to(device)
    bestmodel = RBM(nvis, args)
    bestinfnet = SeqInfNet(nvis, args)

    if args.penfunc == "l2":
        penfunc = lambda x, y: ((x - y) * (x - y)).sum(-1)
    elif args.penfunc == "l1":
        penfunc = lambda x, y: (x - y).abs().sum(-1)
    elif args.penfunc == "kl1":
        penfunc = lambda x, y: batch_kl(x, y)
    elif args.penfunc == "kl2":
        penfunc = lambda x, y: batch_kl(y, x)
    else:
        penfunc = None

    best_loss, prev_loss = float("inf"), float("inf")
    lrdecay, pendecay = False, False
    if args.optalg == "sgd":
        popt1 = torch.optim.SGD(model.parameters(), lr=args.lr)
        popt2 = torch.optim.SGD(infnet.parameters(), lr=args.ilr)
    else:
        popt1 = torch.optim.Adam(model.parameters(), lr=args.lr)
        popt2 = torch.optim.Adam(infnet.parameters(), lr=args.ilr)

    if args.check_corr:
        from utils import corr
        nedges = nvis * model.nhid  # V*H edges
        npenterms = 2 * nvis * model.nhid
        V, H = nvis, model.nhid
        with torch.no_grad():
            ed_lpots = model.get_edge_scores().unsqueeze(0)  # 1 x nvis x nhid x 4
            un_lpots = model.get_unary_scores().unsqueeze(0)  # 1 x (nvis + nhid) x 2
            exed_lpots = ed_lpots.view(nedges, 1, 2, 2)  # V*H x 1 x 2 x 2
            exun_lpots = un_lpots.transpose(0, 1)  # V+H x 1 x 2
            nodebs, facbs, _, _, _ = dolbp(exed_lpots, edges, x=None, emlps=exun_lpots,
                                           niter=args.lbp_iter, renorm=True,
                                           tol=args.lbp_tol, randomize=args.randomize_lbp)
            # reshape log unary marginals: V+H x 1 x 2 -> 1 x V+H x 2
            tau_u = torch.stack([nodebs[t] for t in range(nvis + model.nhid)]).transpose(0, 1)
            # reshape log fac marginals: nedges x 1 x 2 x 2 -> 1 x nedges x 2 x 2
            tau_e = torch.stack([facbs[e] for e in range(nedges)]).transpose(0, 1)
            # exponentiate
            tau_u, tau_e = (tau_u.exp() + EPS), (tau_e.exp() + EPS)
        for i in range(args.inf_iter):
            with torch.no_grad():  # these functions are used in calc'ing the loss below too
                pred_rho = infnet.q()  # V*H x K^2 logits
                # should be 1 x nvis+nhid x 2 and 1 x V*H x K^2
                predtau_u, predtau_e, _ = get_taus_and_pens(pred_rho, V, H, penfunc=penfunc)
                predtau_u, predtau_e = predtau_u.exp() + EPS, predtau_e.exp() + EPS
                # just pick one entry from each
                un_margs = tau_u[0][:, 1]  # V+H
                bin_margs = tau_e[0][:, 1, 1]  # nedges
                pred_un_margs = predtau_u[0][:, 1]  # V+H
                pred_bin_margs = predtau_e[0].view(-1, 2, 2)[:, 1, 1]  # nedges
                print(i, "unary corr: %.4f, binary corr: %.4f" % (
                    corr(un_margs, pred_un_margs), corr(bin_margs, pred_bin_margs)))
            popt2.zero_grad()
            pred_rho = infnet.q()  # V*H x K^2 logits
            in_loss, ipen_loss = rbm_inner_loss(pred_rho, un_lpots, ed_lpots, ne, penfunc)
            in_loss = in_loss + args.pen_mult / npenterms * ipen_loss
            print("in_loss", in_loss.item())
            in_loss.backward()
            clip_opt_params(popt2, args.clip)
            popt2.step()
        exit()

    best_loss = float("inf")
    for ep in range(args.epochs):
        if args.training == "pcd":
            oloss, nex = train_pcd(trdat, model, popt1, device, args)
            print("Epoch {:3d} | train loss {:.3f}".format(ep, oloss / nex))
        elif args.training == "lbp":
            oloss, nex = lbp_train(trdat, model, popt1, edges, ne, device, args)
            print("Epoch {:3d} | train out_loss {:.3f}".format(ep, oloss / nex))
        else:
            oloss, iloss, ploss, nex = train(trdat, model, infnet, popt1, popt2, ne,
                                             penfunc, device, args)
            print("Epoch {:3d} | train out_loss {:.3f} | train in_loss {:.3f} | pen {:.3f}".format(
                ep, oloss / nex, iloss / nex, ploss / (nex * args.pen_mult)))

        with torch.no_grad():
            if args.training == "bethe":
                voloss, vploss, vnex = validate(valdat, model, infnet, ne, penfunc, device)
                print("Epoch {:3d} | val out_loss {:.3f} | val pen {:.3f}".format(
                    ep, voloss / vnex, vploss / (vnex * args.pen_mult)))
            elif args.training == "lbp":
                voloss, vnex = lbp_validate(valdat, model, edges, ne, device)
                print("Epoch {:3d} | val out_loss {:.3f}".format(ep, voloss / vnex))
            trnpll, nexagain = moar_validate(trdat, model, device, args)
            print("Epoch {:3d} | train npll {:.3f}".format(ep, trnpll / nexagain))
            valnpll, vnexagain = moar_validate(valdat, model, device, args, do_ais=args.do_ais)
            print("Epoch {:3d} | val npll {:.3f}".format(ep, valnpll / vnexagain))
            voloss = valnpll

        if voloss < best_loss:
            best_loss = voloss
            bad_epochs = -1
            print("updating best model")
            bestmodel.load_state_dict(model.state_dict())
            bestinfnet.load_state_dict(infnet.state_dict())
        if (voloss >= prev_loss or lrdecay) and args.optalg == "sgd":
            for group in popt1.param_groups:
                group['lr'] *= args.lrdecay
            for group in popt2.param_groups:
                group['lr'] *= args.lrdecay
            # decay = True
        if (voloss >= prev_loss or pendecay):
            args.pen_mult *= args.pendecay
            print("pen_mult now", args.pen_mult)
            pendecay = True
        prev_loss = voloss
        if args.lr < 1e-5:
            break
        print("")
        # if args.reset_adam:
        #     print("resetting adam...")
        #     popt2 = torch.optim.Adam(infnet.parameters(), lr=args.ilr)
    return bestmodel, bestinfnet, best_loss
def main(args, helper, cache, max_seqlen, max_verts, ntypes, trbatches, valbatches):
    print("main args", args)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    device = torch.device("cuda" if args.cuda else "cpu")

    if args.infarch == "rnnnode":
        infctor = RNodeInfNet
    else:
        infctor = TNodeInfNet
    model = HybEdgeModel(ntypes, max_verts, args).to(device)
    if "exact" not in args.loss:
        infnet = infctor(ntypes, max_seqlen, args).to(device)
    bestmodel = HybEdgeModel(ntypes, max_verts, args)
    if "exact" not in args.loss:
        bestinfnet = infctor(ntypes, max_seqlen, args)
    else:
        bestinfnet = None

    if args.penfunc == "l2":
        penfunc = lambda x, y: ((x - y) * (x - y)).sum(-1)
    elif args.penfunc == "l1":
        penfunc = lambda x, y: (x - y).abs().sum(-1)
    elif args.penfunc == "js":
        penfunc = lambda x, y: 0.5 * (batch_kl(x, y) + batch_kl(y, x))
    elif args.penfunc == "kl1":
        penfunc = lambda x, y: batch_kl(x, y)
    elif args.penfunc == "kl2":
        penfunc = lambda x, y: batch_kl(y, x)
    else:
        penfunc = None

    neginf = torch.Tensor(1, 1, 1).fill_(-1e18).to(device)
    best_loss, prev_loss = float("inf"), float("inf")
    lrdecay, pendecay = False, False
    if "exact" in args.loss:
        if args.optalg == "sgd":
            popt1 = torch.optim.SGD(model.parameters(), lr=args.lr)
        else:
            popt1 = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        if args.optalg == "sgd":
            popt1 = torch.optim.SGD(model.parameters(), lr=args.lr)
            popt2 = torch.optim.SGD(infnet.parameters(), lr=args.ilr)
        else:
            popt1 = torch.optim.Adam(model.parameters(), lr=args.lr)
            popt2 = torch.optim.Adam(infnet.parameters(), lr=args.ilr)

    if args.check_corr:
        from utils import corr
        # pick a graph to check
        T, K = 10, args.K
        edges, nodeidxs, ne = get_hmm_stuff(T, args.markov_order, K)
        edges, ne = edges.to(device), ne.view(1, -1).to(device)
        nodeidxs = nodeidxs.to(device)
        npenterms = (nodeidxs != 2 * edges.size(0)).sum().float()
        nedges = edges.size(0)
        with torch.no_grad():
            # un_lpots = model.get_obs_lps(batch)  # bsz x T x K log unary potentials
            ed_lpots = model.get_edge_scores(edges, T)  # nedges x K*K log potentials
            exed_lpots = ed_lpots.view(nedges, 1, K, K)
            # get approximate unclamped marginals
            nodebs, facbs, _, _, _ = dolbp(exed_lpots, edges, niter=args.lbp_iter,
                                           renorm=True, randomize=args.randomize_lbp,
                                           tol=args.lbp_tol)
            tau_u = torch.stack([nodebs[t] for t in range(T)]).transpose(0, 1)  # 1 x T x K
            tau_e = torch.stack([facbs[e] for e in range(nedges)]).transpose(0, 1)  # 1 x nedges x K x K
            tau_u, tau_e = (tau_u.exp() + EPS), (tau_e.exp() + EPS)
        for i in range(args.z_iter):
            with torch.no_grad():  # these functions are used in calc'ing the loss below too
                pred_rho = infnet.q(edges, T)  # nedges x K^2 logits
                # should be 1 x T x K and 1 x nedges x K^2
                predtau_u, predtau_e, _ = get_taus_and_pens(pred_rho, nodeidxs, K, neginf,
                                                            penfunc=penfunc)
                predtau_u, predtau_e = predtau_u.exp() + EPS, predtau_e.exp() + EPS
                # just pick one entry from each
                un_margs = tau_u[0][:, 0]  # T
                bin_margs = tau_e[0][:, K - 1, K - 1]  # nedges
                pred_un_margs = predtau_u[0][:, 0]  # T
                pred_bin_margs = predtau_e[0].view(nedges, K, K)[:, K - 1, K - 1]  # nedges
                print(i, "unary corr: %.4f, binary corr: %.4f" % (
                    corr(un_margs, pred_un_margs), corr(bin_margs, pred_bin_margs)))
            popt2.zero_grad()
            pred_rho = infnet.q(edges, T)  # nedges x K^2 logits
            in_loss, ipen_loss = inner_lossz(pred_rho.view(1, -1), ed_lpots.view(1, -1),
                                             nodeidxs, K, ne, neginf, penfunc)
            in_loss = in_loss + args.pen_mult / npenterms * ipen_loss
            print("in_loss", in_loss.item())
            in_loss.backward()
            clip_opt_params(popt2, args.clip)
            popt2.step()
        exit()

    bad_epochs = -1
    for ep in range(args.epochs):
        if args.loss == "exact":
            ll, ntokes = exact_train(trbatches, model, popt1, helper, device, args)
            print("Epoch {:3d} | train tru-ppl {:8.3f}".format(ep, math.exp(-ll / ntokes)))
            with torch.no_grad():
                vll, vntokes = exact_validate(valbatches, model, helper, device)
                print("Epoch {:3d} | val tru-ppl {:8.3f}".format(ep, math.exp(-vll / vntokes)))
            # if ep == 4 and math.exp(-vll/vntokes) >= 280:
            #     break
            voloss = -vll
        elif args.loss == "lbp":
            oloss, nex = lbp_train(trbatches, model, popt1, helper, cache, device, args)
            print("Epoch {:3d} | train out_loss {:8.3f}".format(ep, oloss / nex))
            with torch.no_grad():
                voloss, vnex = lbp_validate(valbatches, model, helper, device)
                print("Epoch {:3d} | val out_loss {:8.3f}".format(ep, voloss / vnex))
        else:  # infnet
            oloss, iloss, ploss, nex = train_unsup_am(trbatches, model, infnet, popt1, popt2,
                                                      cache, penfunc, neginf, device, args)
            print("Epoch {:3d} | train out_loss {:.3f} | train in_loss {:.3f}".format(
                ep, oloss / nex, iloss / nex))
            with torch.no_grad():
                voloss, vploss, vnex = validate_unsup_am(valbatches, model, infnet, cache,
                                                         penfunc, neginf, device, args)
                print("Epoch {:3d} | val out_loss {:.3f} | val barr_loss {:.3f}".format(
                    ep, voloss / vnex, vploss / vnex))

        if args.loss != "exact":
            with torch.no_grad():
                # trull, ntokes = exact_validate(trbatches, model, helper, device)
                # print("Epoch {:3d} | train tru-ppl {:.3f}".format(
                #     ep, math.exp(-trull/ntokes)))
                vll, vntokes = exact_validate(valbatches, model, helper, device)
                print("Epoch {:3d} | val tru-ppl {:.3f}".format(ep, math.exp(-vll / vntokes)))
                voloss = -vll
                # trppl = math.exp(-trull/ntokes)
                # if (ep == 0 and trppl > 3000) or (ep > 0 and trppl > 1000):
                #     break

        if voloss < best_loss:
            best_loss = voloss
            bad_epochs = -1
            print("updating best model")
            bestmodel.load_state_dict(model.state_dict())
            if bestinfnet is not None:
                bestinfnet.load_state_dict(infnet.state_dict())
            if len(args.save) > 0 and not args.grid:
                print("saving model to", args.save)
                torch.save({"opt": args,
                            "mod_sd": bestmodel.state_dict(),
                            "inf_sd": bestinfnet.state_dict() if bestinfnet is not None else None,
                            "bestloss": best_loss},
                           args.save)
        if (voloss >= prev_loss or lrdecay) and args.optalg == "sgd":
            for group in popt1.param_groups:
                group['lr'] *= args.decay
            for group in popt2.param_groups:
                group['lr'] *= args.decay
            # decay = True
        if (voloss >= prev_loss or pendecay):
            args.pen_mult *= args.pendecay
            print("pen_mult now", args.pen_mult)
            pendecay = True
        prev_loss = voloss
        if ep >= 2 and math.exp(best_loss / vntokes) > 650:
            break
        print("")
        bad_epochs += 1
        if bad_epochs >= 5:
            break
        if args.reset_adam:  # bad_epochs == 1:
            print("resetting adam...")
            for group in popt2.param_groups:
                group['lr'] *= args.decay  # not really decay
        # if args.reset_adam and ep == 1:  # bad_epochs == 1:
        #     print("resetting adam...")
        #     popt2 = torch.optim.Adam(infnet.parameters(), lr=args.ilr)
    return bestmodel, bestinfnet, best_loss
def lbp_train(corpus, model, popt, helper, cache, device, args):
    """Train with the Bethe free-energy objective, using loopy-BP marginals.
    cache maps a sequence length T to its precomputed (edges, nodeidxs, ne)."""
    model.train()
    K, M = args.K, helper.markov_order
    total_loss, nexamples = 0.0, 0
    niter, nxiter = 0, 0
    perm = torch.randperm(len(corpus))
    for i, idx in enumerate(perm):
        popt.zero_grad()
        batch = corpus[idx.item()].to(device)
        T, bsz = batch.size()
        # if T <= 1 or (M == 3 and T <= 3):  # annoying
        #     continue
        if T <= 1:
            continue
        if T not in cache:
            edges, nodeidxs, ne = get_hmm_stuff(T, M, K)
            cache[T] = (edges, nodeidxs, ne)
        edges, nodeidxs, ne = cache[T]
        nedges = edges.size(0)
        edges, ne = edges.to(device), ne.view(1, -1).to(device)
        un_lpots = model.get_obs_lps(batch)  # bsz x T x K log unary potentials
        ed_lpots = model.get_edge_scores(edges, T)  # nedges x K*K log potentials
        with torch.no_grad():
            exed_lpots = ed_lpots.view(nedges, 1, K, K)
            # get approximate unclamped marginals
            nodebs, facbs, ii, _, _ = dolbp(exed_lpots, edges, niter=args.lbp_iter,
                                            renorm=True, randomize=args.randomize_lbp,
                                            tol=args.lbp_tol)
            # get approximate clamped marginals
            xnodebs, xfacbs, iix, _, _ = dolbp(exed_lpots.expand(nedges, bsz, K, K), edges,
                                               x=batch, emlps=un_lpots.transpose(0, 1),
                                               niter=args.lbp_iter, renorm=True,
                                               randomize=args.randomize_lbp, tol=args.lbp_tol)
            niter += ii
            nxiter += iix
            # reshape log unary marginals: T x bsz x K -> bsz x T x K
            tau_u = torch.stack([nodebs[t] for t in range(T)]).transpose(0, 1)
            taux_u = torch.stack([xnodebs[t] for t in range(T)]).transpose(0, 1)
            # reshape log fac marginals: nedges x bsz x K x K -> bsz x nedges x K x K
            tau_e = torch.stack([facbs[e] for e in range(nedges)]).transpose(0, 1)
            taux_e = torch.stack([xfacbs[e] for e in range(nedges)]).transpose(0, 1)
            # exponentiate
            tau_u, tau_e = (tau_u.exp() + EPS).view(1, -1), (tau_e.exp() + EPS).view(1, -1)
            taux_u, taux_e = (taux_u.exp() + EPS).view(bsz, -1), (taux_e.exp() + EPS).view(bsz, -1)
        fx, _, _, _ = bethe_fex(taux_u, taux_e, un_lpots.view(bsz, -1),
                                ed_lpots.view(1, -1).expand(bsz, -1), ne.expand(bsz, -1))
        fz, _, _, _ = bethe_fez(tau_u, tau_e, ed_lpots.view(1, -1), ne)
        loss = fx - fz * bsz
        total_loss += loss.item()
        loss.div(bsz).backward()
        clip_opt_params(popt, args.clip)
        popt.step()
        nexamples += bsz
        if (i + 1) % args.log_interval == 0:
            print("{:5d}/{:5d} | its {:3.2f}/{:3.2f} | out_loss {:8.3f}".format(
                i + 1, perm.size(0), niter / (i + 1), nxiter / (i + 1),
                total_loss / nexamples))
    return total_loss, nexamples
def train_unsup_am(corpus, model, infnet, moptim, ioptim, cache, penfunc, neginf, device, args):
    """Saddle-point training: the inference net maximizes the inner (log Z) bound,
    the model minimizes the outer loss. Each optimizer covers all of its parameters."""
    model.train()
    infnet.train()
    K, M = args.K, args.markov_order
    total_out_loss, total_in_loss, nexamples = 0.0, 0.0, 0
    total_pen_loss = 0.0
    perm = torch.randperm(len(corpus))
    for i, idx in enumerate(perm):
        batch = corpus[idx.item()].to(device)
        T, bsz = batch.size()
        if T <= 1:  # annoying
            continue
        if T not in cache:
            edges, nodeidxs, ne = get_hmm_stuff(T, M, K)
            cache[T] = (edges, nodeidxs, ne)
        edges, nodeidxs, ne = cache[T]
        edges = edges.to(device)  # symbolic edge representation
        ne, nodeidxs = ne.view(1, -1).to(device), nodeidxs.to(device)  # 1 x T*K, T x maxne
        npenterms = (nodeidxs != 2 * edges.size(0)).sum().float()

        # maximize wrt rho
        with torch.no_grad():
            ed_lpots = model.get_edge_scores(edges, T)  # nedges x K*K log potentials
        # if args.reset_adam:
        #     ioptim = torch.optim.Adam(infnet.parameters(), lr=args.ilr)
        for _ in range(args.z_iter):
            ioptim.zero_grad()
            pred_rho = infnet.q(edges, T)  # nedges x K^2 logits
            in_loss, ipen_loss = inner_lossz(pred_rho.view(1, -1), ed_lpots.view(1, -1),
                                             nodeidxs, K, ne, neginf, penfunc)
            total_in_loss += in_loss.item() * bsz
            total_pen_loss += args.pen_mult / npenterms * ipen_loss.item() * bsz
            in_loss = in_loss + args.pen_mult / npenterms * ipen_loss
            in_loss.backward()
            clip_opt_params(ioptim, args.clip)
            ioptim.step()
        pred_rho = pred_rho.detach()

        if args.loss == "alt3":  # min wrt rho_x
            with torch.no_grad():
                un_lpots = model.get_obs_lps(batch)  # bsz x T x K log unary potentials
            for _ in range(args.zx_iter):
                ioptim.zero_grad()
                pred_rho_x = infnet.qx(batch, edges, T)
                out_loss1, open_loss1 = inner_lossx(pred_rho_x, un_lpots.view(bsz, -1),
                                                    ed_lpots.view(1, -1).expand(bsz, -1),
                                                    nodeidxs, K, ne.expand(bsz, -1),
                                                    neginf, penfunc)
                out_loss1 = out_loss1 + args.pen_mult / npenterms * open_loss1
                total_pen_loss += args.pen_mult / npenterms * open_loss1.item()
                out_loss1.div(bsz).backward()
                clip_opt_params(ioptim, args.clip)
                ioptim.step()
            pred_rho_x = pred_rho_x.detach()

        # min wrt params
        moptim.zero_grad()
        # even though these don't change, we need to recompute them with grad enabled
        un_lpots = model.get_obs_lps(batch)  # bsz x T x K log unary potentials
        ed_lpots = model.get_edge_scores(edges, T)  # nedges x K*K log potentials
        if args.loss != "alt3":  # jointly minimizing over rho_x
            pred_rho_x = infnet.qx(batch, edges, T)
            openfunc = penfunc
        else:
            openfunc = None
        out_loss, open_loss = outer_loss(pred_rho_x, pred_rho, un_lpots.view(bsz, -1),
                                         ed_lpots.view(1, -1), nodeidxs, K, ne, neginf,
                                         penfunc=openfunc)
        total_out_loss += out_loss.item()
        if args.loss != "alt3":
            total_pen_loss += args.pen_mult / npenterms * open_loss.item()
            out_loss = out_loss + args.pen_mult / npenterms * open_loss
        out_loss.div(bsz).backward()
        clip_opt_params(moptim, args.clip)
        moptim.step()
        nexamples += bsz
        if (i + 1) % args.log_interval == 0:
            print("{:5d}/{:5d} | out_loss {:8.5f} | in_loss {:8.5f} | pen_loss {:8.6f}".format(
                i + 1, perm.size(0), total_out_loss / nexamples, total_in_loss / nexamples,
                total_pen_loss / (nexamples * args.pen_mult)))
    return total_out_loss, total_in_loss, total_pen_loss, nexamples