Example #1
def eval_loop(
    args,
    V,
    iter,
    model,
):
    total_ll = 0
    total_elbo = 0
    n = 0
    lpz, last_states = None, None
    with th.no_grad():
        for i, batch in enumerate(iter):
            model.train(False)
            if hasattr(model, "noise_scale"):
                model.noise_scale = 0
            mask, lengths, n_tokens = get_mask_lengths(batch.text, V)
            if args.iterator != "bptt":
                lpz, last_states = None, None
            # keep lpz/last_states returned by score so consecutive BPTT segments
            # chain together (mirrors the unpacking in train_loop)
            losses, lpz, last_states = model.score(
                batch.text,
                lpz=lpz,
                last_states=last_states,
                mask=mask,
                lengths=lengths,
            )
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
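
For orientation: the returned Pack carries the summed log-evidence (and ELBO, when defined) over n tokens, so per-token perplexity follows from the average negative log-likelihood, mirroring the commented-out wandb logging in train_loop below. A minimal usage sketch, assuming a valid_iter and a trained model exist elsewhere:

# Minimal usage sketch (valid_iter and the trained model are assumed to exist elsewhere).
import math

valid_losses, valid_n = eval_loop(args, V, valid_iter, model)
avg_nll = -(valid_losses.evidence / valid_n).item()     # negative log-likelihood per token
print(f"valid ppl: {math.exp(min(avg_nll, 700)):.2f}")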
Example #2
def cached_eval_loop(
    args,
    V,
    iter,
    model,
):
    total_ll = 0
    total_elbo = 0
    n = 0
    with th.no_grad():
        model.train(False)
        lpz = None
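        # Start/transition/emission are computed once up front and reused for every
        # batch below; this is the "cached" part of cached_eval_loop.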
        start, transition, emission = model.compute_parameters(
            model.word2state)
        word2state = model.word2state
        for i, batch in enumerate(iter):
            if hasattr(model, "noise_scale"):
                model.noise_scale = 0

            text = batch.text

            mask, lengths, n_tokens = get_mask_lengths(text, V)
            N, T = text.shape

            if lpz is not None and args.iterator == "bptt":
                # fold the carried final-state distribution into this segment's start
                start = (lpz[:, :, None] +
                         transition[last_states, :]).logsumexp(1)

            clamp_fn = model.clamp if model.eval_shorter else model.clamp2
            log_potentials = clamp_fn(text, start, transition, emission,
                                      word2state)
            losses, lpz = model.compute_loss(log_potentials, mask, lengths)

            if word2state is not None:
                idx = th.arange(N, device=model.device)
                last_words = text[idx, lengths - 1]
                last_states = model.word2state[last_words]

            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
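
When the iterator is BPTT, the start distribution for each new segment is obtained by pushing the carried per-sequence log-marginals lpz through the transition rows of the previous segment's final states and summing them out in log space. A shape-only sketch of that single step with made-up sizes (the per-word state subset of size S is an assumption about word2state):

# Shape-only sketch of the BPTT carry-over above; all sizes are made up.
import torch as th

N, S, C = 2, 3, 5                              # batch, states per last word, total states
lpz = th.randn(N, S).log_softmax(-1)           # carried log-marginals over each sequence's final states
transition = th.randn(C, C).log_softmax(-1)    # log p(z' | z)
last_states = th.randint(C, (N, S))            # state indices assigned to each sequence's last word

# start[n, c] = logsumexp_s( lpz[n, s] + transition[last_states[n, s], c] )
start = (lpz[:, :, None] + transition[last_states, :]).logsumexp(1)
assert start.shape == (N, C)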
Example #3
def train_loop(
    args,
    V,
    iter,
    model,
    parameters,
    optimizer,
    scheduler,
    valid_iter=None,
    verbose=False,
):
    global WANDB_STEP

    noise_scales = np.linspace(1, 0, args.noise_anneal_steps)
    total_ll = 0
    total_elbo = 0
    n = 0
    # check is performed at end of epoch outside loop as well
    checkpoint = len(iter) // (args.num_checks - 1)
    with th.enable_grad():
        lpz = None
        last_states = None
        for i, batch in enumerate(iter):
            model.train(True)
            WANDB_STEP += 1
            optimizer.zero_grad()

            text = batch.textp1 if "lstm" in args.model else batch.text
            if args.iterator == "bucket":
                lpz = None
                last_states = None

            mask, lengths, n_tokens = get_mask_lengths(text, V)
            if model.timing:
                start_forward = timep.time()

            # check if iterator == bptt
            losses, lpz, last_states = model.score(text,
                                                   lpz=lpz,
                                                   last_states=last_states,
                                                   mask=mask,
                                                   lengths=lengths)

            if model.timing:
                print(f"forward time: {timep.time() - start_forward}")
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens

            # normalize by token count so the gradient scale is independent of segment length
            loss = -losses.loss / n_tokens
            if model.timing:
                start_backward = timep.time()
            loss.backward()
            if model.timing:
                print(f"backward time: {timep.time() - start_backward}")
            clip_grad_norm_(parameters, args.clip)
            if args.schedule not in valid_schedules:
                # sched before opt since we want step = 1?
                # this is how huggingface does it
                scheduler.step()
            optimizer.step()
            #wandb.log({
            #"running_training_loss": total_ll / n,
            #"running_training_ppl": math.exp(min(-total_ll / n, 700)),
            #}, step=WANDB_STEP)

            if verbose and i % args.report_every == args.report_every - 1:
                report(
                    Pack(evidence=total_ll, elbo=total_elbo),
                    n,
                    f"Train batch {i}",
                )

            if valid_iter is not None and i % checkpoint == checkpoint - 1:
                v_start_time = time.time()
                #eval_fn = cached_eval_loop if args.model == "mshmm" else eval_loop
                #valid_losses, valid_n  = eval_loop(
                #valid_losses, valid_n  = cached_eval_loop(
                if args.model == "mshmm" or args.model == "factoredhmm":
                    if args.num_classes > 2**15:
                        eval_fn = mixed_cached_eval_loop
                    else:
                        eval_fn = cached_eval_loop
                elif args.model == "hmm":
                    eval_fn = cached_eval_loop
                else:
                    eval_fn = eval_loop
                valid_losses, valid_n = eval_fn(
                    args,
                    V,
                    valid_iter,
                    model,
                )
                report(valid_losses, valid_n, "Valid eval", v_start_time)
                #wandb.log({
                #"valid_loss": valid_losses.evidence / valid_n,
                #"valid_ppl": math.exp(-valid_losses.evidence / valid_n),
                #}, step=WANDB_STEP)

                update_best_valid(valid_losses, valid_n, model, optimizer,
                                  scheduler, args.name)

                #wandb.log({
                #"lr": optimizer.param_groups[0]["lr"],
                #}, step=WANDB_STEP)
                scheduler.step(valid_losses.evidence)

                # remove this later?
                if args.log_counts > 0 and args.keep_counts > 0:
                    # TODO: FACTOR OUT
                    counts = (model.counts /
                              model.counts.sum(0, keepdim=True))[:, 4:]
                    c, v = counts.shape
                    #cg4 = counts > 1e-4
                    #cg3 = counts > 1e-3
                    cg2 = counts > 1e-2

                    #wandb.log({
                    #"avgcounts@1e-4": cg4.sum().item() / float(v),
                    #"avgcounts@1e-3": cg3.sum().item() / float(v),
                    #"avgcounts@1e-2": cg2.sum().item() / float(v),
                    #"maxcounts@1e-4": cg4.sum(0).max().item() / float(v),
                    #"maxcounts@1e-3": cg3.sum(0).max().item() / float(v),
                    #"maxcounts@1e-2": cg2.sum(0).max().item(),
                    #"mincounts@1e-4": cg4.sum(0).min().item() / float(v),
                    #"mincounts@1e-3": cg3.sum(0).min().item() / float(v),
                    #"mincounts@1e-2": cg2.sum(0).min().item(),
                    #"maxcounts": counts.sum(0).max().item(),
                    #"mincounts": counts.sum(0).min().item(),
                    #}, step=WANDB_STEP)
                    del cg2
                    del counts

    return Pack(evidence=total_ll, elbo=total_elbo), n
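
A hypothetical epoch-level driver for this loop; args.epochs, train_iter, optimizer, and scheduler are assumed to be constructed elsewhere, and only the signatures of train_loop and report are taken from the code above:

# Hypothetical epoch driver; everything not defined in this file is an assumption.
for epoch in range(args.epochs):
    train_losses, train_n = train_loop(
        args, V, train_iter, model,
        parameters=list(model.parameters()),
        optimizer=optimizer,
        scheduler=scheduler,
        valid_iter=valid_iter,
        verbose=True,
    )
    report(train_losses, train_n, f"Train epoch {epoch}")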
Example #4
def mixed_cached_eval_loop(
    args,
    V,
    iter,
    model,
):
    total_ll = 0
    total_elbo = 0
    n = 0
    with th.no_grad():
        model.train(False)
        lpz = None

        start = model.start().cpu()
        emission = model.mask_emission(model.emission_logits(),
                                       model.word2state).cpu()

        # blocked transition
        num_blocks = 128
        block_size = model.C // num_blocks
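        # The C x C transition matrix is assembled in row blocks so that only a
        # block_size x C slice is materialized on the GPU at a time; each finished
        # block is copied into the CPU-resident `transition` tensor.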
        next_state_proj = (
            model.next_state_proj.weight
            if hasattr(model, "next_state_proj")
            else model.next_state_emb()
        )
        transition = th.empty(model.C,
                              model.C,
                              device=th.device("cpu"),
                              dtype=emission.dtype)
        for s in range(0, model.C, block_size):
            states = range(s, s + block_size)
            if hasattr(model.state_emb, "weight"):
                state_embs = model.state_emb.weight[states]
            else:
                state_embs = model.state_emb(
                    th.LongTensor(states).to(model.device))
            x = model.trans_mlp(model.dropout(state_embs))
            y = (x @ next_state_proj.t()).log_softmax(-1)
            transition[states] = y.to(transition.device)

        #start0, transition0, emission0 = model.compute_parameters(model.word2state)
        # th.allclose(transition, transition0)
        word2state = model.word2state
        for i, batch in enumerate(iter):
            if hasattr(model, "noise_scale"):
                model.noise_scale = 0

            text = batch.text

            mask, lengths, n_tokens = get_mask_lengths(text, V)
            N, T = text.shape

            if lpz is not None and args.iterator == "bptt":
                # hopefully this isn't too slow on cpu
                start = (lpz[:, :, None] +
                         transition[last_states, :]).logsumexp(1)

            log_potentials = model.clamp(text, start, transition, emission,
                                         word2state).to(model.device)

            losses, lpz = model.compute_loss(log_potentials, mask, lengths)
            lpz = lpz.cpu()

            idx = th.arange(N, device=model.device)
            last_words = text[idx, lengths - 1]
            last_states = model.word2state[last_words]

            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens
    return Pack(evidence=total_ll, elbo=total_elbo), n
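
The transition matrix in this loop is assembled in row blocks and kept on the CPU, presumably to bound GPU memory for very large state spaces. A standalone sketch of the same row-blocked log-softmax pattern, with made-up sizes and plain torch.nn modules standing in for the model's own layers:

# Standalone sketch of the row-blocked log-softmax pattern; sizes and layers are made up.
import torch as th

C, H, num_blocks = 1024, 64, 8
block_size = C // num_blocks
state_emb = th.nn.Embedding(C, H)
next_state_proj = th.nn.Linear(H, C, bias=False)

transition = th.empty(C, C)                    # full matrix lives on the CPU
with th.no_grad():
    for s in range(0, C, block_size):
        rows = th.arange(s, s + block_size)
        x = state_emb(rows)                            # (block_size, H)
        y = next_state_proj(x).log_softmax(-1)         # (block_size, C) rows of log p(z' | z)
        transition[rows] = y                           # copy the finished block into place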