def eval_loop(
    args, V, iter, model,
):
    total_ll = 0
    total_elbo = 0
    n = 0
    lpz, last_states = None, None
    with th.no_grad():
        for i, batch in enumerate(iter):
            model.train(False)
            if hasattr(model, "noise_scale"):
                model.noise_scale = 0
            mask, lengths, n_tokens = get_mask_lengths(batch.text, V)
            if args.iterator != "bptt":
                lpz, last_states = None, None
            losses, lpz, _ = model.score(
                batch.text,
                lpz=lpz,
                last_states=last_states,
                mask=mask,
                lengths=lengths,
            )
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
def cached_eval_loop(
    args, V, iter, model,
):
    total_ll = 0
    total_elbo = 0
    n = 0
    with th.no_grad():
        model.train(False)
        lpz = None
        # compute the HMM parameters once and reuse them for every batch
        start, transition, emission = model.compute_parameters(model.word2state)
        word2state = model.word2state
        for i, batch in enumerate(iter):
            if hasattr(model, "noise_scale"):
                model.noise_scale = 0
            text = batch.text
            mask, lengths, n_tokens = get_mask_lengths(text, V)
            N, T = text.shape
            if lpz is not None and args.iterator == "bptt":
                # carry the previous chunk's final-state posterior over as the
                # start distribution for this chunk
                start = (lpz[:, :, None] + transition[last_states, :]).logsumexp(1)
            log_potentials = (
                model.clamp(text, start, transition, emission, word2state)
                if model.eval_shorter
                else model.clamp2(text, start, transition, emission, word2state)
            )
            losses, lpz = model.compute_loss(log_potentials, mask, lengths)
            if word2state is not None:
                idx = th.arange(N, device=model.device)
                last_words = text[idx, lengths - 1]
                last_states = model.word2state[last_words]
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
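# Illustrative sketch (for illustration only; not used by the loops in this
# file): the BPTT carry-over above forms the next chunk's start distribution
# by marginalizing the previous chunk's final-state posterior through the
# transition matrix,
#     log p(z'_1 = j) = logsumexp_i [ log p(z_T = i) + log p(z' = j | z = i) ].
# The shapes and random parameters below are assumptions for the example;
# cached_eval_loop above additionally indexes only the transition rows that
# word2state allows for the final word of each sequence.
def _demo_bptt_start_carryover(N=4, C=8):
    import torch

    lpz = torch.randn(N, C).log_softmax(-1)         # log p(z_T | x) per example
    transition = torch.randn(C, C).log_softmax(-1)  # log p(z' | z)
    start = (lpz[:, :, None] + transition[None]).logsumexp(1)  # (N, C)
    # each row is again a normalized distribution over states
    assert torch.allclose(start.exp().sum(-1), torch.ones(N))
    return start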
def train_loop(
    args, V, iter, model,
    parameters, optimizer, scheduler,
    valid_iter=None,
    verbose=False,
):
    global WANDB_STEP
    # noise annealing schedule (not consumed elsewhere in this loop)
    noise_scales = np.linspace(1, 0, args.noise_anneal_steps)

    total_ll = 0
    total_elbo = 0
    n = 0
    # check is performed at end of epoch outside loop as well
    checkpoint = len(iter) // (args.num_checks - 1)
    with th.enable_grad():
        lpz = None
        last_states = None
        for i, batch in enumerate(iter):
            model.train(True)
            WANDB_STEP += 1
            optimizer.zero_grad()
            text = batch.textp1 if "lstm" in args.model else batch.text
            if args.iterator == "bucket":
                # bucket iterators do not preserve document order, so state is
                # not carried across batches
                lpz = None
                last_states = None
            mask, lengths, n_tokens = get_mask_lengths(text, V)
            if model.timing:
                start_forward = timep.time()
            # check if iterator == bptt
            losses, lpz, last_states = model.score(
                text, lpz=lpz, last_states=last_states, mask=mask, lengths=lengths)
            if model.timing:
                print(f"forward time: {timep.time() - start_forward}")
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens

            loss = -losses.loss / n_tokens
            if model.timing:
                start_backward = timep.time()
            loss.backward()
            if model.timing:
                print(f"backward time: {timep.time() - start_backward}")
            clip_grad_norm_(parameters, args.clip)
            if args.schedule not in valid_schedules:
                # sched before opt since we want step = 1?
                # this is how huggingface does it
                scheduler.step()
            optimizer.step()

            #import pdb; pdb.set_trace()
            #wandb.log({
                #"running_training_loss": total_ll / n,
                #"running_training_ppl": math.exp(min(-total_ll / n, 700)),
            #}, step=WANDB_STEP)

            if verbose and i % args.report_every == args.report_every - 1:
                report(
                    Pack(evidence=total_ll, elbo=total_elbo),
                    n,
                    f"Train batch {i}",
                )

            if valid_iter is not None and i % checkpoint == checkpoint - 1:
                v_start_time = time.time()
                #eval_fn = cached_eval_loop if args.model == "mshmm" else eval_loop
                #valid_losses, valid_n = eval_loop(
                #valid_losses, valid_n = cached_eval_loop(
                if args.model in ("mshmm", "factoredhmm"):
                    if args.num_classes > 2**15:
                        eval_fn = mixed_cached_eval_loop
                    else:
                        eval_fn = cached_eval_loop
                elif args.model == "hmm":
                    eval_fn = cached_eval_loop
                else:
                    eval_fn = eval_loop
                valid_losses, valid_n = eval_fn(
                    args, V, valid_iter, model,
                )
                report(valid_losses, valid_n, "Valid eval", v_start_time)
                #wandb.log({
                    #"valid_loss": valid_losses.evidence / valid_n,
                    #"valid_ppl": math.exp(-valid_losses.evidence / valid_n),
                #}, step=WANDB_STEP)
                update_best_valid(
                    valid_losses, valid_n, model, optimizer, scheduler, args.name)
                #wandb.log({
                    #"lr": optimizer.param_groups[0]["lr"],
                #}, step=WANDB_STEP)
                scheduler.step(valid_losses.evidence)

                # remove this later?
                if args.log_counts > 0 and args.keep_counts > 0:
                    # TODO: FACTOR OUT
                    counts = (model.counts / model.counts.sum(0, keepdim=True))[:, 4:]
                    c, v = counts.shape
                    #cg4 = counts > 1e-4
                    #cg3 = counts > 1e-3
                    cg2 = counts > 1e-2
                    #wandb.log({
                        #"avgcounts@1e-4": cg4.sum().item() / float(v),
                        #"avgcounts@1e-3": cg3.sum().item() / float(v),
                        #"avgcounts@1e-2": cg2.sum().item() / float(v),
                        #"maxcounts@1e-4": cg4.sum(0).max().item() / float(v),
                        #"maxcounts@1e-3": cg3.sum(0).max().item() / float(v),
                        #"maxcounts@1e-2": cg2.sum(0).max().item(),
                        #"mincounts@1e-4": cg4.sum(0).min().item() / float(v),
                        #"mincounts@1e-3": cg3.sum(0).min().item() / float(v),
                        #"mincounts@1e-2": cg2.sum(0).min().item(),
                        #"maxcounts": counts.sum(0).max().item(),
                        #"mincounts": counts.sum(0).min().item(),
                    #}, step=WANDB_STEP)
                    del cg2
                    del counts
    return Pack(evidence=total_ll, elbo=total_elbo), n
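# Illustrative sketch (for illustration only; not used by the loops in this
# file): how the accumulated totals returned by these loops map to per-token
# negative log likelihood and perplexity, matching the commented-out
# running_training_ppl logging in train_loop. The default arguments are
# made-up numbers for the example.
def _demo_token_ppl(total_ll=-52000.0, n=10000):
    import math

    nll_per_token = -total_ll / n            # average negative log likelihood
    ppl = math.exp(min(nll_per_token, 700))  # clamp to avoid overflow in exp
    return nll_per_token, ppl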
def mixed_cached_eval_loop(
    args, V, iter, model,
):
    total_ll = 0
    total_elbo = 0
    n = 0
    with th.no_grad():
        model.train(False)
        lpz = None
        start = model.start().cpu()
        emission = model.mask_emission(
            model.emission_logits(), model.word2state).cpu()

        # blocked transition: materialize the C x C transition matrix on CPU
        # one block of rows at a time instead of all at once
        num_blocks = 128
        block_size = model.C // num_blocks
        next_state_proj = (
            model.next_state_proj.weight
            if hasattr(model, "next_state_proj")
            else model.next_state_emb()
        )
        transition = th.empty(
            model.C, model.C, device=th.device("cpu"), dtype=emission.dtype)
        for s in range(0, model.C, block_size):
            states = range(s, s + block_size)
            x = model.trans_mlp(model.dropout(
                model.state_emb.weight[states]
                if hasattr(model.state_emb, "weight")
                else model.state_emb(th.LongTensor(states).to(model.device))
            ))
            y = (x @ next_state_proj.t()).log_softmax(-1)
            transition[states] = y.to(transition.device)
        #start0, transition0, emission0 = model.compute_parameters(model.word2state)
        # th.allclose(transition, transition0)

        word2state = model.word2state
        for i, batch in enumerate(iter):
            if hasattr(model, "noise_scale"):
                model.noise_scale = 0
            text = batch.text
            mask, lengths, n_tokens = get_mask_lengths(text, V)
            N, T = text.shape
            if lpz is not None and args.iterator == "bptt":
                # hopefully this isn't too slow on cpu
                start = (lpz[:, :, None] + transition[last_states, :]).logsumexp(1)
            log_potentials = model.clamp(
                text, start, transition, emission, word2state).to(model.device)
            losses, lpz = model.compute_loss(log_potentials, mask, lengths)
            lpz = lpz.cpu()

            idx = th.arange(N, device=model.device)
            last_words = text[idx, lengths - 1]
            last_states = model.word2state[last_words]

            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
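# Illustrative sketch (for illustration only; not used by the loops in this
# file): the blocked construction used in mixed_cached_eval_loop, where a
# large row-normalized transition matrix is built on CPU a block of rows at a
# time so the full C x C log-softmax never has to sit in GPU memory. The
# embedding, MLP, and projection here are stand-ins for the model's
# state_emb, trans_mlp, and next_state_proj, not the real modules.
def _demo_blocked_transition(C=512, D=32, block_size=128):
    import torch
    import torch.nn as nn

    state_emb = nn.Embedding(C, D)
    trans_mlp = nn.Linear(D, D)
    next_state_proj = nn.Linear(D, C, bias=False)

    transition = torch.empty(C, C)
    with torch.no_grad():
        for s in range(0, C, block_size):
            states = torch.arange(s, min(s + block_size, C))
            x = trans_mlp(state_emb(states))                      # (block, D)
            y = (x @ next_state_proj.weight.t()).log_softmax(-1)  # (block, C)
            transition[states] = y
    # every row should be a normalized distribution over next states
    assert torch.allclose(transition.exp().sum(-1), torch.ones(C))
    return transition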