Example #1
    def _loop(self, iter, optimizer=None, clip=0, learn=False, re=None):
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, lens = batch.text
                x = text[:-1]
                y = text[1:]
                lens = lens - 1

                e, lene = batch.entities
                t, lent = batch.types
                v, lenv = batch.values
                #rlen, N = e.shape
                #r = torch.stack([e, t, v], dim=-1)
                r = [e, t, v]
                assert (lene == lent).all()
                lenr = lene

                # should i include <eos> in ppl?
                nwords = y.ne(1).sum()
                # assert nwords == lens.sum()
                T, N = y.shape
                #if states is None:
                states = self.init_state(N)
                logits, _ = self(x, states, lens, r, lenr)

                nll = self.loss(logits, y)
                kl = 0
                nelbo = nll + kl
                if learn:
                    nelbo.div(nwords.item()).backward()
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nelbo.item()
                cum_ntokens += nwords.item()
                batch_loss += nelbo.item()
                batch_ntokens += nwords.item()
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(loss=batch_loss / batch_ntokens,
                                      gnorm=gnorm)
                    batch_loss = 0
                    batch_ntokens = 0
        return cum_loss, cum_ntokens
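
The reporting guard that appears in each of these loops, "re is not None and i % re == -1 % re", is a compact way of firing the progress report once every re batches. A minimal standalone sketch of the equivalence; the names here are illustrative only:

    # Sketch: i % re == -1 % re is the same test as (i + 1) % re == 0,
    # so the postfix update runs on batches re-1, 2*re-1, 3*re-1, ...
    re = 4
    fires = [i for i in range(12) if i % re == -1 % re]
    assert fires == [i for i in range(12) if (i + 1) % re == 0]
    print(fires)  # [3, 7, 11]
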
Example #2
    def _loop(self,
              diter,
              optimizer=None,
              clip=0,
              learn=False,
              re=None,
              once=False):
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        with context():
            titer = tqdm(diter) if learn else diter
            for i, batch in enumerate(
                    titer if not once else [next(iter(diter))]):
                if learn:
                    optimizer.zero_grad()
                x, lens = batch.text

                lx, _ = batch.locations_text
                ax, _ = batch.aspects_text
                l = batch.locations
                a = batch.aspects
                y = batch.sentiments

                # keys
                k = [l, a]
                kx = [lx, ax]

                # N x y for dealing w imbalance
                logits = self(x, lens, k, kx)

                nll = self.loss(logits, y)
                nelbo = nll
                N = y.shape[0]
                if learn:
                    nelbo.div(N).backward()
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                    optimizer.step()
                cum_loss += nelbo.item()
                cum_ntokens += N
                batch_loss += nelbo.item()
                batch_ntokens += N
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(loss=batch_loss / batch_ntokens,
                                      gnorm=gnorm)
                    batch_loss = 0
                    batch_ntokens = 0
        return cum_loss, cum_ntokens
Example #3
File: lm.py Project: justinchiu/3
    def _loop(self, iter, learn, args):
        context = torch.enable_grad if learn else torch.no_grad

        loss = 0
        ntokens = 0
        rloss = 0
        rntokens = 0
        hidden_states = None
        with context():
            t = tqdm(iter)
            for i, batch in enumerate(t):
                if learn:
                    args.optimizer.zero_grad()
                x = batch.text
                y = batch.target
                nwords = y.ne(1).sum()
                T, N = y.shape
                if hidden_states is None:
                    hidden_states = self.init_hidden(N)
                logits, hidden_states = self(x, hidden_states)
                if learn:
                    hidden_states = (
                        [
                            tuple(x.detach() for x in tup)
                            for tup in hidden_states[0]
                        ],
                        tuple(x.detach() for x in hidden_states[1]),
                        hidden_states[2].detach(),
                    )
                logprobs = F.log_softmax(logits, dim=-1)
                logp = logprobs.view(T * N, -1).gather(-1, y.view(T * N, 1))
                kl = 0
                nll = -logp[y.view(-1, 1) != 1].sum()
                nelbo = nll + kl
                if learn:
                    nelbo.div(nwords.item()).backward()
                    if args.clip > 0:
                        gnorm = clip_(self.parameters(), args.clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, args.clip)
                    args.optimizer.step()
                loss += nelbo.item()
                ntokens += nwords.item()
                rloss += nelbo.item()
                rntokens += nwords.item()
                if args.report_interval is not None and i % args.report_interval == -1 % args.report_interval:
                    t.set_postfix(loss=rloss / rntokens, gnorm=gnorm)
                    rloss = 0
                    rntokens = 0
        return loss, ntokens
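
Example #3 above carries the recurrent state across batches and detaches it in place so that gradients do not flow back through earlier batches (truncated backpropagation through time). The same idea can be written as a small recursive helper instead of spelling out the exact (list-of-tuples, tuple, tensor) layout; this is an illustrative sketch, not part of the project:

    import torch

    def detach_states(states):
        # Recursively detach tensors inside nested tuples/lists so the next
        # batch's backward pass stops at the carried-over hidden state.
        if torch.is_tensor(states):
            return states.detach()
        if isinstance(states, (list, tuple)):
            return type(states)(detach_states(s) for s in states)
        return states
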
Example #4
    def _loop_ie(self, iter, optimizer=None, clip=0, learn=False, re=None):
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        with context():
            t = tqdm(iter) if learn else iter
            for i, batch in enumerate(t):
                if learn:
                    optimizer.zero_grad()
                text, lens = batch.text
                x = text[:-1]
                y = text[1:]
                lens = lens - 1
                # should i include <eos> in ppl?
                nwords = y.ne(1).sum()
                # assert nwords == lens.sum()
                T, N = y.shape
                #if states is None:
                states = self.init_state(N)
                logits, _ = self(x, states, lens)
                #logits, states = self(x, states, lens)
                logprobs = F.log_softmax(logits, dim=-1)
                logp = logprobs.view(T * N, -1).gather(-1, y.view(T * N, 1))
                kl = 0
                nll = -logp[y.view(-1, 1) != 1].sum()
                nelbo = nll + kl
                if learn:
                    nelbo.div(nwords.item()).backward()
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nelbo.item()
                cum_ntokens += nwords.item()
                batch_loss += nelbo.item()
                batch_ntokens += nwords.item()
                if re is not None and i % re == -1 % re:
                    t.set_postfix(loss=batch_loss / batch_ntokens, gnorm=gnorm)
                    batch_loss = 0
                    batch_ntokens = 0
        return cum_loss, cum_ntokens
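
Examples #3 and #4 compute the masked negative log-likelihood inline: a log-softmax over the vocabulary, a gather of the gold token's log-probability, and a sum over non-padding positions (the padding index 1 is hard-coded). For reference, the same computation as a standalone function, assuming plain (T, N, V) logits and (T, N) targets rather than the project's batch objects:

    import torch
    import torch.nn.functional as F

    def masked_nll(logits, y, pad_idx=1):
        # logits: (T, N, V) scores, y: (T, N) gold token indices.
        T, N = y.shape
        logprobs = F.log_softmax(logits, dim=-1)
        logp = logprobs.view(T * N, -1).gather(-1, y.view(T * N, 1))
        keep = y.view(-1, 1) != pad_idx        # drop padding positions
        return -logp[keep].sum(), keep.sum()   # summed NLL, token count
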
Example #5
    def _loop(
        self,
        iter,
        optimizer=None,
        clip=0,
        learn=False,
        re=None,
        exact=False,
        elbo=True,
        T=64,
        E=128,
        supattn=False,
        supcopy=False,
    ):
        context = torch.enable_grad if learn else torch.no_grad

        # DBG
        self.copied = 0
        couldve_copied = 0

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        batch_rx = 0
        batch_kl = 0
        states = None
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, x_info = batch.text
                mask = x_info.mask
                lens = x_info.lengths

                L = text.shape["time"]
                x = text.narrow("time", 0, L - 1)
                y = text.narrow("time", 1, L - 1)
                #x = text[:-1]
                #y = text[1:]
                x_info.lengths.sub_(1)

                e, e_info = batch.entities
                t, t_info = batch.types
                v, v_info = batch.values
                lene = e_info.lengths
                lent = t_info.lengths
                lenv = v_info.lengths
                #rlen, N = e.shape
                #r = torch.stack([e, t, v], dim=-1)
                r = [e, t, v]
                assert (lene == lent).all()
                lenr = lene
                r_info = e_info

                # values text
                vt, vt_info = batch.values_text
                # DBG
                couldve_copied += (vt == y).sum().item()

                ue, ue_info = batch.uentities
                ut, ut_info = batch.utypes

                v2d = batch.v2d

                # should i include <eos> in ppl?
                nwords = y.ne(1).sum()
                # assert nwords == lens.sum()
                #T = y.shape["time"]
                N = y.shape["batch"]
                #if states is None:
                states = self.init_state(N)

                # should i include <eos> in ppl? no, should not.
                mask = y.ne(1)  #* y.ne(3)
                nwords = mask.sum()

                # ugh.......refactor this
                if exact:
                    rvinfo, states = self.marginal_nll(x,
                                                       states,
                                                       x_info,
                                                       r,
                                                       r_info,
                                                       vt,
                                                       y,
                                                       x_info,
                                                       T=T,
                                                       E=E,
                                                       learn=learn)
                    nll = -rvinfo.log_py[mask].sum()
                    cum_loss += nll.item()
                    batch_loss += nll.item()
                else:
                    rvinfo, _, nll, kl = self(x,
                                              states,
                                              x_info,
                                              r,
                                              r_info,
                                              vt,
                                              y,
                                              x_info,
                                              T=T,
                                              E=E,
                                              learn=learn)

                    nelbo = nll + kl
                    if learn:
                        if clip > 0:
                            gnorm = clip_(self.parameters(), clip)
                        optimizer.step()
                    cum_rx += nll.item()
                    batch_rx += nll.item()
                    cum_kl += kl.item()
                    batch_kl += kl.item()
                    cum_loss += nelbo.item()
                    batch_loss += nelbo.item()
                cum_ntokens += nwords.item()
                batch_ntokens += nwords.item()
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(
                        elbo=batch_loss / batch_ntokens,
                        nll=batch_rx / batch_ntokens,
                        kl=batch_kl / batch_ntokens,
                        gnorm=gnorm,
                    )
                    batch_loss = 0
                    batch_ntokens = 0
                    batch_rx = 0
                    batch_kl = 0
        print(
            f"COPIED: {self.copied:,} vs {couldve_copied:,} / {cum_ntokens:,}")
        print(f"NLL: {cum_rx / cum_ntokens} || KL: {cum_kl / cum_ntokens}")
        return cum_loss, cum_ntokens
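
Each of these loops returns the pair (cum_loss, cum_ntokens). For the language-model variants the caller can turn that into a per-token loss and a perplexity (when cum_loss is a negative ELBO, as in Example #5, the exponential is an upper bound on perplexity). A hypothetical caller-side helper, not taken from the project:

    import math

    def summarize(cum_loss, cum_ntokens):
        # Per-token loss and its exponential: perplexity for a plain NLL,
        # a perplexity upper bound when cum_loss is a negative ELBO.
        avg = cum_loss / cum_ntokens
        return avg, math.exp(avg)
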
Example #6
    def _loop(
        self,
        iter,
        optimizer=None,
        clip=0,
        learn=False,
        re=None,
        exact=False,
        elbo=True,
        T=64,
        E=32,
        R=4,
        supattn=False,
        supcopy=False,
    ):
        context = torch.enable_grad if learn else torch.no_grad

        # DBG
        self.copied = 0
        couldve_copied = 0

        nthe = 0
        n4 = 0
        nfour = 0
        n6 = 0
        nsix = 0
        n8 = 0
        neight = 0
        nchi = 0
        nleb = 0

        bthe = 0
        b4 = 0
        bfour = 0
        b6 = 0
        bsix = 0
        b8 = 0
        beight = 0
        bchi = 0
        bleb = 0

        athe = 0
        a4 = 0
        afour = 0
        a6 = 0
        asix = 0
        a8 = 0
        aeight = 0
        achi = 0
        aleb = 0

        abthe = 0
        ab4 = 0
        abfour = 0
        ab6 = 0
        absix = 0
        ab8 = 0
        abeight = 0
        abchi = 0
        ableb = 0

        d1 = 0
        d2 = 0
        d4 = 0
        dn = 0

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        batch_rx = 0
        batch_kl = 0
        states = None
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, x_info = batch.text
                mask = x_info.mask
                lens = x_info.lengths

                L = text.shape["time"]
                x = text.narrow("time", 0, L - 1)
                y = text.narrow("time", 1, L - 1)
                #x = text[:-1]
                #y = text[1:]
                x_info.lengths.sub_(1)

                e, e_info = batch.entities
                t, t_info = batch.types
                v, v_info = batch.values
                lene = e_info.lengths
                lent = t_info.lengths
                lenv = v_info.lengths
                #rlen, N = e.shape
                #r = torch.stack([e, t, v], dim=-1)
                r = [e, t, v]
                assert (lene == lent).all()
                lenr = lene
                r_info = e_info

                # values text
                vt, vt_info = batch.values_text
                # DBG
                couldve_copied += (vt == y).transpose(
                    "batch", "time", "els").any(-1).sum().item()

                ue, ue_info = batch.uentities
                ut, ut_info = batch.utypes

                v2d = batch.v2d
                vt2d = batch.vt2d

                # should i include <eos> in ppl?
                nwords = y.ne(1).sum()
                # assert nwords == lens.sum()
                #T = y.shape["time"]
                N = y.shape["batch"]
                #if states is None:
                states = self.init_state(N)

                # should i include <eos> in ppl? no, should not.
                mask = y.ne(1)  #* y.ne(3)
                nwords = mask.sum()

                # ugh.......refactor this
                if exact:
                    rvinfo, states = self.marginal_nll(x,
                                                       states,
                                                       x_info,
                                                       r,
                                                       r_info,
                                                       vt,
                                                       ue,
                                                       ue_info,
                                                       ut,
                                                       ut_info,
                                                       v2d,
                                                       vt2d,
                                                       y,
                                                       x_info,
                                                       T=T,
                                                       E=E,
                                                       R=R,
                                                       learn=learn)
                    nll = -rvinfo.log_py[mask].sum()
                    cum_loss += nll.item()
                    batch_loss += nll.item()
                else:
                    rvinfo, _, nll, kl = self(
                        x,
                        states,
                        x_info,
                        r,
                        r_info,
                        vt,
                        ue,
                        ue_info,
                        ut,
                        ut_info,
                        v2d,
                        vt2d,
                        y,
                        x_info,
                        T=T,
                        E=E,
                        R=R,
                        learn=learn,
                        supattn=supattn,  # hacky
                        idxs=batch.idxs,
                    )
                    if rvinfo.log_qc_y is None:
                        self.copied += (rvinfo.log_pc.exp().get("copy", 1) >
                                        0.5).sum().item()
                    else:
                        self.copied += (rvinfo.log_qc_y.exp().get("copy", 1) >
                                        0.5).sum().item()

                    nelbo = nll + kl

                    # check p(c|y) for subset of words
                    # likelihood ratio of content vs noncontent
                    pyn = rvinfo.log_py_c0
                    pyc = rvinfo.log_py_ac1

                    ythe = y.eq(self.Vx.stoi["the"])
                    # find "4" and "four"
                    y4 = y.eq(self.Vx.stoi["4"])
                    yfour = y.eq(self.Vx.stoi["four"])
                    # find "6" and "six"
                    y6 = y.eq(self.Vx.stoi["6"])
                    ysix = y.eq(self.Vx.stoi["six"])
                    # find "8" and "eight"
                    y8 = y.eq(self.Vx.stoi["8"])
                    yeight = y.eq(self.Vx.stoi["eight"])
                    ychi = y.eq(self.Vx.stoi["chicago"])
                    yleb = y.eq(self.Vx.stoi["lebron"])

                    nthe += ythe.sum()
                    n4 += y4.sum()
                    nfour += yfour.sum()
                    n6 += y6.sum()
                    nsix += ysix.sum()
                    n8 += y8.sum()
                    neight += yeight.sum()
                    nchi += ychi.sum()
                    nleb += yleb.sum()

                    vt2d_f = vt2d.stack(("e", "t"), "r")
                    alignments = vt2d_f.gather("r", rvinfo.a_s, "k")

                    a4 += alignments.eq(
                        self.Vx.stoi["4"]).values.any(0)[y4.values].sum()
                    afour += alignments.eq(
                        self.Vx.stoi["4"]).values.any(0)[yfour.values].sum()
                    a6 += alignments.eq(
                        self.Vx.stoi["6"]).values.any(0)[y6.values].sum()
                    asix += alignments.eq(
                        self.Vx.stoi["6"]).values.any(0)[ysix.values].sum()
                    a8 += alignments.eq(
                        self.Vx.stoi["8"]).values.any(0)[y8.values].sum()
                    aeight += alignments.eq(
                        self.Vx.stoi["8"]).values.any(0)[yeight.values].sum()
                    achi += alignments.eq(self.Vx.stoi["chicago"]).values.any(
                        0)[ychi.values].sum()
                    aleb += alignments.eq(self.Vx.stoi["lebron"]).values.any(
                        0)[yleb.values].sum()

                    if pyn is not None:
                        better = (pyn < pyc).narrow("k", self.Kb, self.K)

                        betterthe = (pyn.get("k", 0) < pyc.get("k", 0))[ythe]
                        better4 = (pyn.get("k", 0) < pyc.get("k", 0))[y4]
                        betterfour = (pyn.get("k", 0) < pyc.get("k", 0))[yfour]
                        better6 = (pyn.get("k", 0) < pyc.get("k", 0))[y6]
                        bettersix = (pyn.get("k", 0) < pyc.get("k", 0))[ysix]
                        better8 = (pyn.get("k", 0) < pyc.get("k", 0))[y8]
                        bettereight = (pyn.get("k", 0) < pyc.get("k",
                                                                 0))[yeight]
                        betterchi = (pyn.get("k", 0) < pyc.get("k", 0))[ychi]
                        betterleb = (pyn.get("k", 0) < pyc.get("k", 0))[yleb]

                        bthe += betterthe.sum()
                        b4 += better4.sum()
                        bfour += betterfour.sum()
                        b6 += better6.sum()
                        bsix += bettersix.sum()
                        b8 += better8.sum()
                        beight += bettereight.sum()
                        bchi += betterchi.sum()
                        bleb += betterleb.sum()

                        # if any alignments have the right word AND are better
                        ab4 += (alignments.eq(self.Vx.stoi["4"]) *
                                better).values.any(0)[y4.values].sum()
                        abfour += (alignments.eq(self.Vx.stoi["4"]) *
                                   better).values.any(0)[yfour.values].sum()
                        ab6 += (alignments.eq(self.Vx.stoi["6"]) *
                                better).values.any(0)[y6.values].sum()
                        absix += (alignments.eq(self.Vx.stoi["6"]) *
                                  better).values.any(0)[ysix.values].sum()
                        ab8 += (alignments.eq(self.Vx.stoi["8"]) *
                                better).values.any(0)[y8.values].sum()
                        abeight += (alignments.eq(self.Vx.stoi["8"]) *
                                    better).values.any(0)[yeight.values].sum()
                        abchi += (alignments.eq(self.Vx.stoi["chicago"]) *
                                  better).values.any(0)[ychi.values].sum()
                        ableb += (alignments.eq(self.Vx.stoi["lebron"]) *
                                  better).values.any(0)[yleb.values].sum()
                    else:
                        qc = rvinfo.log_qc_y.softmax("copy")
                        better = qc.get("copy", 1) > 0.5
                        bthe += better[ythe].sum()
                        b4 += better[y4].sum()
                        bfour += better[yfour].sum()
                        b6 += better[y6].sum()
                        bsix += better[ysix].sum()
                        b8 += better[y8].sum()
                        beight += better[yeight].sum()
                        bchi += better[ychi].sum()
                        bleb += better[yleb].sum()

                        # if any alignments have the right word AND are better
                        ab4 += (alignments.eq(self.Vx.stoi["4"]) *
                                better).values.any(0)[y4.values].sum()
                        abfour += (alignments.eq(self.Vx.stoi["4"]) *
                                   better).values.any(0)[yfour.values].sum()
                        ab6 += (alignments.eq(self.Vx.stoi["6"]) *
                                better).values.any(0)[y6.values].sum()
                        absix += (alignments.eq(self.Vx.stoi["6"]) *
                                  better).values.any(0)[ysix.values].sum()
                        ab8 += (alignments.eq(self.Vx.stoi["8"]) *
                                better).values.any(0)[y8.values].sum()
                        abeight += (alignments.eq(self.Vx.stoi["8"]) *
                                    better).values.any(0)[yeight.values].sum()
                        abchi += (alignments.eq(self.Vx.stoi["chicago"]) *
                                  better).values.any(0)[ychi.values].sum()
                        ableb += (alignments.eq(self.Vx.stoi["lebron"]) *
                                  better).values.any(0)[yleb.values].sum()

                    maxd, _ = self.delta.max("k")
                    d1 += (maxd > 0.1).sum()
                    d2 += (maxd > 0.2).sum()
                    d4 += (maxd > 0.4).sum()
                    dn += maxd.values.nelement()
                    """
                    print(f"Proportion greater than .1: {float(d1.item()) / float(dn)}")
                    print(f"Proportion greater than .2: {float(d2.item()) / float(dn)}")
                    print(f"Proportion greater than .4: {float(d4.item()) / float(dn)}")
                    """

                    if learn:
                        if clip > 0:
                            gnorm = clip_(self.parameters(), clip)
                        optimizer.step()
                    cum_rx += nll.item()
                    batch_rx += nll.item()
                    cum_kl += kl.item()
                    batch_kl += kl.item()
                    cum_loss += nelbo.item() if not hasattr(
                        self, "nokl") else nll.item()
                    batch_loss += nelbo.item()
                cum_ntokens += nwords.item()
                batch_ntokens += nwords.item()
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(
                        elbo=batch_loss / batch_ntokens,
                        nll=batch_rx / batch_ntokens,
                        kl=batch_kl / batch_ntokens,
                        gnorm=gnorm,
                        #b4 = b4.item() / n4.item(),
                        #bfour = bfour.item() / nfour.item(),
                        #b6 = b6.item() / n6.item(),
                        #bsix = bsix.item() / nsix.item(),
                        #b8 = b8.item() / n8.item(),
                        #beight = beight.item() / neight.item(),
                    )
                    batch_loss = 0
                    batch_ntokens = 0
                    batch_rx = 0
                    batch_kl = 0
                if re is not None and i % 100 == -1 % 100:
                    #if re is not None and i % 10 == -1 % 10:
                    print("better prob")
                    print(f"the: {bthe.item()} / {nthe.item()}")
                    print(f"4: {b4.item()} / {n4.item()}")
                    print(f"four: {bfour.item()} / {nfour.item()}")
                    print(f"6: {b6.item()} / {n6.item()}")
                    print(f"six: {bsix.item()} / {nsix.item()}")
                    print(f"8: {b8.item()} / {n8.item()}")
                    print(f"eight: {beight.item()} / {neight.item()}")
                    print(f"chicago: {bchi.item()} / {nchi.item()}")
                    print(f"lebron: {bleb.item()} / {nleb.item()}")
                    print("correct alignment")
                    print(f"4: {a4.item()} / {n4.item()}")
                    print(f"four: {afour.item()} / {nfour.item()}")
                    print(f"6: {a6.item()} / {n6.item()}")
                    print(f"six: {asix.item()} / {nsix.item()}")
                    print(f"8: {a8.item()} / {n8.item()}")
                    print(f"eight: {aeight.item()} / {neight.item()}")
                    print(f"chi: {achi.item()} / {nchi.item()}")
                    print(f"leb: {aleb.item()} / {nleb.item()}")
                    print("correct alignment and better prob")
                    print(f"4: {ab4.item()} / {n4.item()}")
                    print(f"four: {abfour.item()} / {nfour.item()}")
                    print(f"6: {ab6.item()} / {n6.item()}")
                    print(f"six: {absix.item()} / {nsix.item()}")
                    print(f"8: {ab8.item()} / {n8.item()}")
                    print(f"eight: {abeight.item()} / {neight.item()}")
                    print(f"chi: {abchi.item()} / {nchi.item()}")
                    print(f"leb: {ableb.item()} / {nleb.item()}")

                    print(
                        f"Proportion greater than .1: {float(d1.item()) / float(dn)}"
                    )
                    print(
                        f"Proportion greater than .2: {float(d2.item()) / float(dn)}"
                    )
                    print(
                        f"Proportion greater than .4: {float(d4.item()) / float(dn)}"
                    )

                    words = ["2", "4", "6", "8", "bulls", "lebron", "chicago"]
                    for word in words:
                        if self.v2d:
                            input = self.lutv.weight[self.Vv.stoi[word]]
                        elif self.untie:
                            input = self.lutgx.weight[self.Vx.stoi[word]]
                        else:
                            input = self.lutx.weight[self.Vx.stoi[word]]
                        weight = self.lutx.weight if not self.untie else self.lutgx.weight
                        if self.mlp:
                            probs, idx = (weight @ self.Wvy1(
                                self.Wvy0(
                                    NamedTensor(
                                        input,
                                        names=("ctxt", ),
                                    )).tanh()).tanh().values
                                          ).softmax(0).topk(5)
                        else:
                            probs, idx = (weight @ input).softmax(0).topk(5)
                        print(f"{word} probs " + " || ".join(
                            f"{self.Vx.itos[x]}: {p:.2f}"
                            for p, x in zip(probs.tolist(), idx.tolist())))

        print("better prob")
        print(f"the: {bthe.item()} / {nthe.item()}")
        print(f"4: {b4.item()} / {n4.item()}")
        print(f"four: {bfour.item()} / {nfour.item()}")
        print(f"6: {b6.item()} / {n6.item()}")
        print(f"six: {bsix.item()} / {nsix.item()}")
        print(f"8: {b8.item()} / {n8.item()}")
        print(f"eight: {beight.item()} / {neight.item()}")
        print("correct alignment")
        print(f"4: {a4.item()} / {n4.item()}")
        print(f"four: {afour.item()} / {nfour.item()}")
        print(f"6: {a6.item()} / {n6.item()}")
        print(f"six: {asix.item()} / {nsix.item()}")
        print(f"8: {a8.item()} / {n8.item()}")
        print(f"eight: {aeight.item()} / {neight.item()}")
        print("correct alignment and better prob")
        print(f"4: {ab4.item()} / {n4.item()}")
        print(f"four: {abfour.item()} / {nfour.item()}")
        print(f"6: {ab6.item()} / {n6.item()}")
        print(f"six: {absix.item()} / {nsix.item()}")
        print(f"8: {ab8.item()} / {n8.item()}")
        print(f"eight: {abeight.item()} / {neight.item()}")
        print(f"Proportion greater than .1: {float(d1.item()) / float(dn)}")
        print(f"Proportion greater than .2: {float(d2.item()) / float(dn)}")
        print(f"Proportion greater than .4: {float(d4.item()) / float(dn)}")
        #print(f"COPIED: {self.copied:,} vs {couldve_copied:,} / {cum_ntokens:,}")
        print(f"NLL: {cum_rx / cum_ntokens} || KL: {cum_kl / cum_ntokens}")
        return cum_loss, cum_ntokens
Example #7
    def _ie_loop(
        self,
        iter,
        optimizer=None,
        clip=0,
        learn=False,
        re=None,
        T=256,
        E=None,
        R=None,
    ):
        self.train(learn)
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        cum_e_correct = 0
        cum_t_correct = 0
        batch_e_correct = 0
        batch_t_correct = 0
        cum_correct = 0
        batch_correct = 0
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, x_info = batch.text
                mask = x_info.mask
                lens = x_info.lengths

                L = text.shape["time"]
                x = text.narrow("time", 0, L - 1)
                y = text.narrow("time", 1, L - 1)
                #x = text[:-1]
                #y = text[1:]
                x_info.lengths.sub_(1)

                e, e_info = batch.entities
                t, t_info = batch.types
                v, v_info = batch.values
                lene = e_info.lengths
                lent = t_info.lengths
                lenv = v_info.lengths
                #rlen, N = e.shape
                #r = torch.stack([e, t, v], dim=-1)
                r = [e, t, v]
                assert (lene == lent).all()
                lenr = lene
                r_info = e_info

                # values text
                vt, vt_info = batch.values_text

                ue, ue_info = batch.uentities
                ut, ut_info = batch.utypes

                v2d = batch.v2d
                vt2d = batch.vt2d

                # should i include <eos> in ppl?
                nwords = y.ne(1).sum()
                # assert nwords == lens.sum()
                N = y.shape["batch"]
                #if states is None:
                self.K = 1  # whatever
                states = self.init_state(N)
                rvinfo, _, nll, kl = self(
                    x,
                    states,
                    x_info,
                    r,
                    r_info,
                    vt,
                    ue,
                    ue_info,
                    ut,
                    ut_info,
                    v2d,
                    vt2d,
                    y=y,
                    y_info=x_info,
                    T=T,
                )

                # only for crnnlma though
                """
                cor, ecor, tcor, tot = Ie.pat1(
                    self.ea.rename("els", "e"),
                    self.ta.rename("els", "t"),
                    batch.ie_d,
                )
                """

                #ds = batch.ie_d
                ets = batch.ie_etv

                #num_cells = batch.num_cells
                log_pe = rvinfo.log_qe_y.rename(
                    "els", "e") if not self.evalp else rvinfo.log_pe.rename(
                        "els", "e")
                log_pt = rvinfo.log_qt_y.rename(
                    "els", "t") if not self.evalp else rvinfo.log_pt.rename(
                        "els", "t")

                log_pe = log_pe.cpu()
                log_pt = log_pt.cpu()
                ue = ue.cpu()
                ut = ut.cpu()

                #import pdb; pdb.set_trace()

                # calculate accuracy
                #import pdb; pdb.set_trace()
                ds = batch.ie_et_d
                num_cells = float(sum(len(d) for d in ds))

                # calculate accuracy
                #import pdb; pdb.set_trace()
                for batch, d in enumerate(ds):
                    for t, (es, ts) in d.items():
                        #t = t + 1
                        _, e_max = log_pe.get("batch", batch).get("time",
                                                                  t).max("e")
                        _, t_max = log_pt.get("batch", batch).get("time",
                                                                  t).max("t")
                        e_preds = ue.get("batch",
                                         batch).get("els", e_max.item())
                        t_preds = ut.get("batch",
                                         batch).get("els", t_max.item())
                        #import pdb; pdb.set_trace()

                        correct = (es.eq(e_preds) *
                                   ts.eq(t_preds)).any().float().item()
                        e_correct = es.eq(e_preds).any().float().item()
                        t_correct = ts.eq(t_preds).any().float().item()
                        #import pdb; pdb.set_trace()

                        cum_correct += correct
                        batch_correct += correct
                        cum_e_correct += e_correct
                        cum_t_correct += t_correct
                        batch_e_correct += e_correct
                        batch_t_correct += t_correct

                #import pdb; pdb.set_trace()
                if learn:
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nll.item()
                cum_ntokens += num_cells
                batch_loss += nll.item()
                batch_ntokens += num_cells
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(
                        loss=batch_loss / batch_ntokens,
                        gnorm=gnorm,
                        acc=batch_correct / batch_ntokens,
                        e=batch_e_correct / batch_ntokens,
                        t=batch_t_correct / batch_ntokens,
                    )
                    batch_loss = 0
                    batch_ntokens = 0
                    batch_correct = 0
                    batch_e_correct = 0
                    batch_t_correct = 0

        print(
            f"acc: {cum_correct / cum_ntokens} | E acc: {cum_e_correct / cum_ntokens} | T acc: {cum_t_correct / cum_ntokens}"
        )
        print(f"total supervised cells: {cum_ntokens}")
        return cum_loss, cum_ntokens
Example #8
    def _loop(
        self,
        iter,
        optimizer=None,
        clip=0,
        learn=False,
        re=None,
        supattn=False,
        supcopy=False,
        T=None,
        E=None,
        R=None,
    ):
        self.train(learn)
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, x_info = batch.text
                mask = x_info.mask
                lens = x_info.lengths

                L = text.shape["time"]
                x = text.narrow("time", 0, L - 1)
                y = text.narrow("time", 1, L - 1)
                #x = text[:-1]
                #y = text[1:]
                x_info.lengths.sub_(1)

                e, e_info = batch.entities
                t, t_info = batch.types
                v, v_info = batch.values
                lene = e_info.lengths
                lent = t_info.lengths
                lenv = v_info.lengths
                #rlen, N = e.shape
                #r = torch.stack([e, t, v], dim=-1)
                r = [e, t, v]
                assert (lene == lent).all()
                lenr = lene
                r_info = e_info

                ue, ue_info = batch.uentities
                ut, ut_info = batch.utypes

                v2d = batch.v2d
                vt2d = batch.vt2d

                vt, vt_info = batch.values_text

                # should i include <eos> in ppl?
                nwords = y.ne(1).sum()
                # assert nwords == lens.sum()
                T = y.shape["time"]
                N = y.shape["batch"]
                #if states is None:
                states = self.init_state(N)
                logits, _ = self(x, states, x_info, r, r_info, vt, ue, ue_info,
                                 ut, ut_info, v2d, vt2d, y)

                nll = self.loss(logits, y)

                if self.maskedc:
                    slist = [
                        "one",
                        "two",
                        "three",
                        "four",
                        "five",
                        "six",
                        "seven",
                        "eight",
                        "nine",
                        "1",
                        "2",
                        "3",
                        "4",
                        "5",
                        "6",
                        "7",
                        "8",
                        "9",
                    ]
                    mask = sum(y.eq(self.Vx.stoi[x]) for x in slist)
                    nllc = -self.log_pc[mask].sum()
                else:
                    nllc = 0

                kl = 0
                nelbo = nll + kl
                if learn:
                    (nelbo + nllc).div(nwords.item()).backward()
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nelbo.item()
                cum_ntokens += nwords.item()
                batch_loss += nelbo.item()
                batch_ntokens += nwords.item()
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(loss=batch_loss / batch_ntokens,
                                      gnorm=gnorm)
                    batch_loss = 0
                    batch_ntokens = 0
        return cum_loss, cum_ntokens
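
The maskedc branch in Example #8 builds a mask over digit-like target tokens by summing equality tests against a word list and uses it for an auxiliary copy loss. A plain-tensor sketch of the same mask construction, assuming a stoi dictionary from token string to vocabulary index; the NamedTensor indexing of the original is omitted:

    import torch

    def vocab_mask(y, stoi, words):
        # Boolean mask that is True wherever y equals any token in `words`.
        mask = torch.zeros_like(y, dtype=torch.bool)
        for w in words:
            mask |= y.eq(stoi[w])
        return mask
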
Example #9
    def _ie_loop(
        self,
        iter,
        optimizer=None,
        clip=0,
        learn=False,
        re=None,
        T=None,
        E=None,
        R=None,
    ):
        self.train(learn)
        context = torch.enable_grad if learn else torch.no_grad

        if not self.v2d:
            self.copy_x_to_v()

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        cum_TP, cum_FP, cum_FN, cum_TM, cum_TG = 0, 0, 0, 0, 0
        batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0
        # for debugging
        cum_e_correct = 0
        cum_t_correct = 0
        cum_correct = 0
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, x_info = batch.text
                mask = x_info.mask
                lens = x_info.lengths

                L = text.shape["time"]
                x = text.narrow("time", 0, L - 1)
                y = text.narrow("time", 1, L - 1)
                #x = text[:-1]
                #y = text[1:]
                x_info.lengths.sub_(1)

                e, e_info = batch.entities
                t, t_info = batch.types
                v, v_info = batch.values
                lene = e_info.lengths
                lent = t_info.lengths
                lenv = v_info.lengths
                #rlen, N = e.shape
                #r = torch.stack([e, t, v], dim=-1)
                r = [e, t, v]
                assert (lene == lent).all()
                lenr = lene
                r_info = e_info

                vt, vt_info = batch.values_text

                ue, ue_info = batch.uentities
                ut, ut_info = batch.utypes

                v2d = batch.v2d
                vt2d = batch.vt2d

                # should i include <eos> in ppl?
                nwords = y.ne(1).sum()
                # assert nwords == lens.sum()
                T = y.shape["time"]
                N = y.shape["batch"]
                #if states is None:
                states = self.init_state(N)
                logits, _ = self(x, states, x_info, r, r_info, vt, ue, ue_info,
                                 ut, ut_info, v2d, vt2d)
                nll = self.loss(logits, y)

                # only for crnnlma though
                """
                cor, ecor, tcor, tot = Ie.pat1(
                    self.ea.rename("els", "e"),
                    self.ta.rename("els", "t"),
                    batch.ie_d,
                )
                """

                #ds = batch.ie_d
                etvs = batch.ie_etv

                #num_cells = batch.num_cells
                num_cells = float(sum(len(d) for d in etvs))
                log_pe = self.ea.cpu()
                log_pt = self.ta.cpu()
                log_pc = self.log_pc.cpu()
                ue = ue.cpu()
                ut = ut.cpu()

                # need log_pv, need to check if self.v2d
                # batch x time x hid
                h = self.lutx(y).values
                log_pv = torch.einsum(
                    "nth,vh->ntv", [h, self.lutv.weight.data]).log_softmax(-1)
                log_pv = NamedTensor(log_pv, names=("batch", "time", "v"))

                tp, fp, fn, tm, tg = pr(
                    etvs,
                    ue,
                    ut,
                    log_pe,
                    log_pt,
                    log_pv,
                    log_pc=log_pc,
                    lens=lens,
                    Ve=self.Ve,
                    Vt=self.Vt,
                    Vv=self.Vv,
                    Vx=self.Vx,
                    text=text,
                )
                cum_TP += tp
                cum_FP += fp
                cum_FN += fn
                cum_TM += tm
                cum_TG += tg

                batch_TP += tp
                batch_FP += fp
                batch_FN += fn
                batch_TM += tm
                batch_TG += tg

                # For debugging
                ds = batch.ie_et_d
                for batch, d in enumerate(ds):
                    for t, (es, ts) in d.items():
                        #t = t + 1
                        _, e_max = log_pe.get("batch", batch).get("time",
                                                                  t).max("e")
                        _, t_max = log_pt.get("batch", batch).get("time",
                                                                  t).max("t")
                        e_preds = ue.get("batch",
                                         batch).get("els", e_max.item())
                        t_preds = ut.get("batch",
                                         batch).get("els", t_max.item())

                        correct = (es.eq(e_preds) *
                                   ts.eq(t_preds)).any().float().item()
                        e_correct = es.eq(e_preds).any().float().item()
                        t_correct = ts.eq(t_preds).any().float().item()

                        cum_correct += correct
                        cum_e_correct += e_correct
                        cum_t_correct += t_correct

                if learn:
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nll.item()
                cum_ntokens += num_cells
                batch_loss += nll.item()
                batch_ntokens += num_cells
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(
                        loss=batch_loss / batch_ntokens,
                        gnorm=gnorm,
                        p=batch_TP / (batch_TP + batch_FP),
                        r=batch_TP / (batch_TP + batch_FN),
                    )
                    batch_loss = 0
                    batch_ntokens = 0
                    batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0

        print(
            f"DBG acc: {cum_correct / cum_ntokens} | E acc: {cum_e_correct / cum_ntokens} | T acc: {cum_t_correct / cum_ntokens}"
        )
        print(
            f"p: {cum_TP / (cum_TP + cum_FP)} || r: {cum_TP / (cum_TP + cum_FN)}"
        )
        print(f"total supervised cells: {cum_ntokens}")
        return cum_loss, cum_ntokens
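
Example #9 accumulates TP/FP/FN counters and reports precision and recall as raw ratios, which divides by zero if a reporting window contains no predicted or no gold positives. A small defensive sketch of the same computation, illustrative only:

    def precision_recall(tp, fp, fn):
        # Precision/recall from raw counts, guarding the empty cases.
        p = tp / (tp + fp) if tp + fp > 0 else 0.0
        r = tp / (tp + fn) if tp + fn > 0 else 0.0
        return p, r
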
Example #10
    def _loop(
        self, iter, supiter=[], optimizer=None, clip=0,
        learn=False, re=None,
        elbo=True,
        T=64, E=32, R=4,
        supattn=False, supcopy=False,
    ):
        context = torch.enable_grad if learn else torch.no_grad

        # DBG
        self.copied = 0
        self.couldve_copied = 0

        # unsup
        cum_loss = 0.
        cum_ntokens = 0.
        cum_rx = 0.
        cum_klv = 0.
        cum_kla = 0.
        batch_loss = 0.
        batch_ntokens = 0.
        batch_rx = 0.
        batch_klv = 0.
        batch_kla = 0.
        cum_N = 0.
        batch_N = 0.

        # sup
        cum_nllpy = 0.
        cum_nllpv = 0.
        cum_nllqv = 0.
        cum_Ny = 0.
        cum_Nv = 0.
        batch_nllpy = 0.
        batch_nllpv = 0.
        batch_nllqv = 0.
        batch_Ny = 0.
        batch_Nv = 0.

        ziter = zip_longest(iter, supiter)

        states = None
        with context():
            titer = tqdm(ziter) if learn else ziter
            for i, (batch, supbatch) in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                if batch is not None:
                    rvinfo, _, nll, klv, kla, Nv, Ny = self._batch(
                        batch,
                        T = T, E = E, R = R, # ?
                        learn = learn,
                        sup = False,
                    )
                    nelbo = nll + klv + kla

                    cum_rx += nll.item()
                    batch_rx += nll.item()
                    cum_klv += klv.item()
                    batch_klv += klv.item()
                    cum_kla += kla.item()
                    batch_kla += kla.item()
                    cum_loss += nelbo.item() if not hasattr(self, "nokl") else nll.item()
                    batch_loss += nelbo.item()

                    cum_ntokens += Ny.item()
                    batch_ntokens += Ny.item()
                    cum_N += Nv.item()
                    batch_N += Nv.item()
                if supbatch is not None:
                    (
                        rvinfosup, _,
                        nll_pv, nll_py_v, nll_qv_y,
                        v_total, y_total
                    ) = self._batch(
                        supbatch,
                        T = T, E = E, R = R,
                        learn = learn,
                        sup = True,
                    )
                    cum_nllpy += nll_py_v.item()
                    cum_nllpv += nll_pv.item()
                    cum_nllqv += nll_qv_y.item()
                    cum_Ny += y_total.item()
                    cum_Nv += v_total.item()
                    batch_nllpy += nll_py_v.item()
                    batch_nllpv += nll_pv.item()
                    batch_nllqv += nll_qv_y.item()
                    batch_Ny += y_total.item()
                    batch_Nv += v_total.item()

                if learn:
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                    optimizer.step()

                if re is not None and i % re == -1 % re:
                    titer.set_postfix(
                        elbo = batch_loss / batch_ntokens if batch_loss != 0 else 0,
                        nll = batch_rx / batch_ntokens if batch_rx != 0 else 0,
                        klv = batch_klv / cum_N if batch_klv != 0 else 0,
                        kla = batch_kla / batch_N if batch_kla != 0 else 0,
                        py = batch_nllpy / batch_Ny if batch_nllpy != 0 else 0,
                        pv = batch_nllpv / batch_Nv if batch_nllpv != 0 else 0,
                        qv = batch_nllqv / batch_Nv if batch_nllqv != 0 else 0,
                        gnorm = gnorm,
                    )
                    batch_loss = 0.
                    batch_ntokens = 0.
                    batch_N = 0.
                    batch_rx = 0.
                    batch_klv = 0.
                    batch_kla = 0.
                    # sup
                    batch_nllpy = 0.
                    batch_nllpv = 0.
                    batch_nllqv = 0.
                    batch_Ny = 0.
                    batch_Nv = 0.
        #print(f"COPIED: {self.copied:,} vs {couldve_copied:,} / {cum_ntokens:,}")
        if cum_N == 0:
            cum_N = 1
        if cum_ntokens == 0:
            cum_ntokens = 1
        if cum_Ny == 0:
            cum_Ny = 1
        if cum_Nv == 0:
            cum_Nv = 1
        print(f"NLL: {cum_rx / cum_ntokens} || KLv: {cum_klv / cum_N} || KLa: {cum_kla / cum_N}")
        print(f"py: {cum_nllpy / cum_Ny} || pv: {cum_nllpv / cum_Nv} || qv: {cum_nllqv / cum_Nv}")
        return cum_loss, cum_ntokens
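
Example #10 interleaves an unsupervised iterator with a (possibly empty) supervised iterator via zip_longest, which pads the shorter one with None; that is why both branches are guarded with is-not-None checks. A minimal sketch of that behavior:

    from itertools import zip_longest

    unsup = ["u0", "u1", "u2"]
    sup = ["s0"]                       # supiter defaults to [] above
    print(list(zip_longest(unsup, sup)))
    # [('u0', 's0'), ('u1', None), ('u2', None)]
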
Example #11
    def _loop(self, iter, optimizer=None, clip=0, learn=False, re=None):
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, lens = batch.text
                x = text[:-1]
                y = text[1:]
                lens = lens - 1

                e, lene = batch.entities
                t, lent = batch.types
                v, lenv = batch.values
                #rlen, N = e.shape
                #r = torch.stack([e, t, v], dim=-1)
                r = [e, t, v]
                assert (lene == lent).all()
                lenr = lene

                # should i include <eos> in ppl? no, should not.
                mask = y.ne(1) #* y.ne(3)
                nwords = mask.sum()
                # assert nwords == lens.sum()
                T, N = y.shape
                R = e.shape[0]
                #if states is None:
                states = self.init_state(N)
                logits, _, sampled_log_pa, log_pa, log_qay = self(x, states, lens, r, lenr, y)

                nll = self.loss(logits, y)
                B = nll[-1]
                reward = (nll[:-1] - B.unsqueeze(0)).detach() * sampled_log_pa
                reward = reward[mask.unsqueeze(0).expand_as(reward)].sum()
                nll = nll[:-1].sum(0)[mask].sum()


                # giving nans because of masking, sigh
                # hmm...this worked before?
                """
                qa = log_qay.exp()
                qa.data[log_qay == float("-inf")] = 0
                kl0 = qa * (log_qay - log_pa)
                kl0[log_qay == float("-inf")] = 0
                kl = kl0.sum()
                #import pdb; pdb.set_trace()
                """
                #"""
                kl = 0
                qa = log_qay.exp()
                for i, l in enumerate(lenr.tolist()):
                    #p = Categorical(logits=log_pa[:,i,:l])
                    #q = Categorical(logits=log_qay[:,i,:l])
                    #kl0 = kl_divergence(q, p).sum()
                    kl0 = qa[:,i,:l] * (log_qay[:,i,:l] - log_pa[:,i,:l])
                    kl0[log_qay[:,i,:l] == float("-inf")] = 0
                    kl0 = kl0.sum()
                    kl += kl0
                    #"""


                #p = Categorical(logits=log_pa[mask])
                #q = Categorical(logits=log_qay[mask])
                #kl = kl_divergence(q, p).sum()

                nelbo = nll + kl
                if learn:
                    (nelbo - reward).div(nwords.item()).backward()
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                            #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nelbo.item()
                cum_ntokens += nwords.item()
                batch_loss += nelbo.item()
                batch_ntokens += nwords.item()
                if re is not None and i % re == -1 % re:
                    # report running averages every `re` batches
                    titer.set_postfix(loss=batch_loss / batch_ntokens, gnorm=gnorm)
                    batch_loss = 0
                    batch_ntokens = 0
        return cum_loss, cum_ntokens
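A note on the score-function term in Example #11: the NLL of the last drawn
sample serves as a baseline for the remaining samples, and the centred NLL is
detached so gradients reach the model only through sampled_log_pa. A minimal
standalone sketch with toy tensors (all shapes, names, and the mask here are
illustrative assumptions, not taken from the model):

import torch

# Toy sizes: S samples, T time steps, N sequences.
S, T, N = 4, 7, 3
nll = torch.rand(S, T, N)                          # per-sample, per-token NLL
sampled_log_pa = torch.rand(S - 1, T, N, requires_grad=True)
mask = torch.ones(T, N, dtype=torch.bool)          # non-pad positions

baseline = nll[-1]                                 # last sample's NLL
reward = (nll[:-1] - baseline.unsqueeze(0)).detach() * sampled_log_pa
reward = reward[mask.unsqueeze(0).expand_as(reward)].sum()

surrogate = nll[:-1].sum(0)[mask].sum() - reward   # mirrors (nelbo - reward)
surrogate.backward()                               # fills sampled_log_pa.grad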
Example #12
    def _ie_loop(
        self,
        iter,
        optimizer=None,
        clip=0,
        learn=False,
        re=None,
        T=None,
        E=None,
        R=None,
    ):
        self.train(learn)
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        cum_TP, cum_FP, cum_FN, cum_TM, cum_TG = 0, 0, 0, 0, 0
        batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, x_info = batch.text
                mask = x_info.mask
                lens = x_info.lengths

                x = text

                # ie_idx is off by one because of bos
                #ie_idx = batch.ie_idx
                #ie_e = batch.ie_e
                #ie_t = batch.ie_t

                etvs = batch.ie_etv

                #num_cells = batch.num_cells
                num_cells = float(sum(len(d) for d in etvs))

                nll, log_pe, log_pt, log_pv = self(x,
                                                   x_info,
                                                   etvs,
                                                   learn=learn)

                # gather predictions
                batch_positives = []
                for b, etv in enumerate(etvs):
                    # Get model positives for example b: at every time step,
                    # take the argmax entity, type, and value.
                    positives = {}
                    T_b = lens.get("batch", b).item()
                    for t in range(T_b):
                        e_p, e_max = log_pe.get("batch", b).get("time", t).max("e")
                        t_p, t_max = log_pt.get("batch", b).get("time", t).max("t")
                        v_p, v_preds = log_pv.get("batch", b).get("time", t).max("v")

                        e_pred = e_max.item()
                        t_pred = t_max.item()
                        v_pred = v_preds.item()
                        # Keep non-NONE (e, t, v) triples, deduplicated by best score.
                        if (e_pred != self.Ve.stoi[self.NONE]
                                and t_pred != self.Vt.stoi[self.NONE]
                                and v_pred != self.Vv.stoi[self.NONE]):
                            key = (e_pred, t_pred, v_pred)
                            score = (e_p + t_p + v_p).item()
                            if key not in positives or score > positives[key]:
                                positives[key] = score
                    batch_positives.append(positives)

                    # Compare against true positives
                    true_positives = set()
                    for es, ts, vs, _ in etv:
                        true_positives |= set(
                            zip(es.tolist(), ts.tolist(), vs.tolist()))

                    total_m = len(positives)
                    total_g = len(true_positives)
                    tp = len(set(positives) & true_positives)
                    fp = len(set(positives) - true_positives)
                    fn = len(true_positives - set(positives))

                    cum_TP += tp
                    cum_FP += fp
                    cum_FN += fn
                    cum_TM += total_m
                    cum_TG += total_g

                    batch_TP += tp
                    batch_FP += fp
                    batch_FN += fn
                    batch_TM += total_m
                    batch_TG += total_g

                if learn:
                    # Note: there is no explicit backward() here; the backward
                    # pass is presumably taken inside self(...) when learn=True,
                    # leaving only gradient clipping and the optimizer step.
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nll.item()
                cum_ntokens += num_cells
                batch_loss += nll.item()
                batch_ntokens += num_cells
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(
                        loss=batch_loss / batch_ntokens,
                        gnorm=gnorm,
                        p=batch_TP / (batch_TP + batch_FP) if batch_TP > 0 else 0,
                        r=batch_TP / (batch_TP + batch_FN) if batch_TP > 0 else 0,
                    )
                    batch_loss = 0
                    batch_ntokens = 0
                    batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0

        print(
            f"p: {cum_TP / (cum_TP + cum_FP) if cum_TP > 0 else 0} || r: {cum_TP / (cum_TP + cum_FN) if cum_TP > 0 else 0}"
        )
        print(f"total supervised cells: {cum_ntokens}")
        return cum_loss, cum_ntokens, cum_TP, cum_FP, cum_FN, cum_TM, cum_TG
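The evaluation in Example #12 reduces to set arithmetic over deduplicated
(entity, type, value) triples. A standalone sketch of that bookkeeping (the
function name and example values are illustrative only):

def triple_prf(predicted, gold):
    """Precision/recall counts over (e, t, v) triples.

    predicted: dict mapping triple -> best score, as built in the loop above
    gold:      set of true triples
    """
    pred_set = set(predicted)
    tp = len(pred_set & gold)
    fp = len(pred_set - gold)
    fn = len(gold - pred_set)
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return tp, fp, fn, precision, recall

# triple_prf({(2, 5, 9): -1.3, (4, 1, 7): -0.8}, {(2, 5, 9), (3, 3, 3)})
# -> (1, 1, 1, 0.5, 0.5)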
Example #13
    def _vie_loop(
        self,
        iter,
        optimizer=None,
        clip=0,
        learn=False,
        re=None,
        T=None,
        E=None,
        R=None,
    ):
        self.train(learn)
        context = torch.enable_grad if learn else torch.no_grad

        cum_loss = 0
        cum_ntokens = 0
        cum_rx = 0
        cum_kl = 0
        batch_loss = 0
        batch_ntokens = 0
        states = None
        cum_e_correct = 0
        cum_t_correct = 0
        batch_e_correct = 0
        batch_t_correct = 0
        cum_correct = 0
        batch_correct = 0
        cum_copyable = 0
        batch_copyable = 0
        with context():
            titer = tqdm(iter) if learn else iter
            for i, batch in enumerate(titer):
                if learn:
                    optimizer.zero_grad()
                text, x_info = batch.text
                mask = x_info.mask
                lens = x_info.lengths

                x = text

                # ie_idx is off by one because of bos
                #ie_idx = batch.ie_idx
                #ie_e = batch.ie_e
                #ie_t = batch.ie_t

                #ds = batch.ie_d
                ets = batch.ie_et

                #num_cells = batch.num_cells
                #num_cells = float(sum(len(d) for d in ds))

                e, r_info = batch.entities
                t, _ = batch.types
                v, _ = batch.values
                vt, _ = batch.values_text
                num_cells = r_info.mask.sum()

                nll, log_pv, attn = self(x,
                                         x_info,
                                         e=e,
                                         t=t,
                                         v=v,
                                         r_info=r_info,
                                         ets=ets,
                                         learn=learn)

                hv = log_pv.max("v")[1]
                correct = (hv == v)[r_info.mask].sum()
                batch_correct += correct
                cum_correct += correct

                # Count value tokens that could be copied verbatim from the
                # text; padding (index 1) is remapped to -1 so it never matches.
                vt_nopad = vt.clone()
                vt_nopad[vt_nopad == 1] = -1
                num_copyable = vt_nopad.eq(x).sum()

                batch_copyable += num_copyable
                cum_copyable += num_copyable

                # Deal with n/a tokens
                na_v_idx = self.Vv.stoi["n/a"]
                #import pdb; pdb.set_trace()

                if learn:
                    if clip > 0:
                        gnorm = clip_(self.parameters(), clip)
                        #for param in self.rnn_parameters():
                        #gnorm = clip_(param, clip)
                    optimizer.step()
                cum_loss += nll
                cum_ntokens += num_cells
                batch_loss += nll
                batch_ntokens += num_cells
                if re is not None and i % re == -1 % re:
                    titer.set_postfix(
                        loss=batch_loss.item() / batch_ntokens.item(),
                        gnorm=gnorm,
                        acc=batch_correct.item() / batch_ntokens.item(),
                        copyable=batch_copyable.item() / batch_ntokens.item(),
                    )
                    batch_loss = 0
                    batch_ntokens = 0
                    batch_correct = 0
                    batch_copyable = 0

        print(
            f"acc: {cum_correct.item() / cum_ntokens.item()} || copyable: {cum_copyable.item() / cum_ntokens.item()}"
        )
        print(f"total supervised cells: {cum_ntokens.item()}")
        return cum_loss.item(), cum_ntokens.item()
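The accuracy and "copyable" numbers reported at the end of Example #13 are
plain masked counts. A toy reproduction (the pad index, shapes, and the
stand-in mask below are assumptions for illustration only):

import torch

v  = torch.tensor([[4, 7, 1], [2, 9, 5]])   # gold value ids
hv = torch.tensor([[4, 3, 1], [2, 9, 8]])   # argmax predictions
mask = v.ne(1)                              # stand-in for r_info.mask

correct = (hv == v)[mask].sum()             # -> tensor(3)

# "Copyable": gold value tokens that also appear verbatim in the text.
vt = torch.tensor([[4, 7, 1], [2, 9, 5]])   # value tokens (text form)
x  = torch.tensor([[4, 7, 0], [1, 9, 5]])   # source text tokens
vt_nopad = vt.clone()
vt_nopad[vt_nopad == 1] = -1                # padding never matches
num_copyable = vt_nopad.eq(x).sum()         # -> tensor(4)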