Code example #1
def get_cutoff(data, model, pred, h0=None, nratio=10):
    with torch.no_grad():
        zs, h0 = model(data.x,
                       data.eis,
                       data.all,
                       ew_fn=data.all_w,
                       h_0=h0,
                       include_h=True)

        # Get optimal cutoff point for LR
        if pred:
            p, n, z = g.dynamic_link_prediction(data,
                                                data.all,
                                                zs,
                                                include_tr=False,
                                                nratio=nratio)
        else:
            p, n, z = g.link_prediction(data,
                                        data.all,
                                        zs,
                                        include_tr=False,
                                        nratio=nratio)

        dt, df = model.score_fn(p, n, z)

    model.cutoff = get_optimal_cutoff(dt, df, fw=0.6)
    return model.cutoff, h0
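
The helper get_optimal_cutoff is called above but never shown. As a rough sketch of what such a helper could do, assuming dt and df are 1-D tensors of scores for true and false edges and fw weights the false distribution (the project's actual implementation may differ):

import torch

def get_optimal_cutoff(dt, df, fw=0.5):
    # Hypothetical sketch: place the decision threshold between the mean
    # score of true edges (dt) and false edges (df), weighted by fw
    return (1.0 - fw) * dt.mean().item() + fw * df.mean().item()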
Code example #2
File: evo_serial_train.py  Project: zazyzaya/TGCN
def train(model, data, epochs=1500, dynamic=False, nratio=10, lr=0.01):
    print(lr)
    end_tr = data.T-TEST_TS

    tr_adjs, tr_xs = convert_to_dense(data, data.TR, end=end_tr)
    opt = Adam(model.parameters(), lr=lr)

    best = (0, None)
    no_improvement = 0
    for e in range(epochs):
        model.train()
        opt.zero_grad()

        # Get embedding   
        zs = model(tr_adjs, tr_xs)

        if not dynamic:
            p,n,z = g.link_prediction(data, data.tr, zs, include_tr=False, nratio=nratio)
        else:
            p,n,z = g.dynamic_link_prediction(data, data.tr, zs, include_tr=False, nratio=nratio)
        
        loss = model.loss_fn(p,n,z)
        loss.backward()

        # Done by VGRNN to improve convergence; must run before opt.step()
        # for the clipping to have any effect on the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        opt.step()

        trloss = loss.item() 
        with torch.no_grad():
            model.eval()
            zs = model(tr_adjs, tr_xs)

            if not dynamic:
                p,n,z = g.link_prediction(data, data.va, zs)
                st, sf = model.score_fn(p,n,z)
                sscores = get_score(st, sf)

                print(
                    '[%d] Loss: %0.4f  \n\tSt %s ' %
                    (e, trloss, fmt_score(sscores) )
                )

                avg = sscores[0] + sscores[1]

            else:
                dp,dn,dz = g.dynamic_link_prediction(data, data.va, zs)
                dt, df = model.score_fn(dp,dn,dz)
                dscores = get_score(dt, df)

                dp,dn,dz = g.dynamic_new_link_prediction(data, data.va, zs)
                dt, df = model.score_fn(dp,dn,dz)
                dnscores = get_score(dt, df)

                print(
                    '[%d] Loss: %0.4f  \n\tDet %s  \n\tNew %s' %
                    (e, trloss, fmt_score(dscores), fmt_score(dnscores) )
                )

                avg = dscores[0] + dscores[1]
            
            if e == KL_ANNEALING:
                model.kld_weight = KL_WEIGHT

            if avg > best[0]:
                best = (avg, deepcopy(model))
                no_improvement = 0
            else:
                # Though it's not reflected in their code, the VGRNN authors imply in the
                # supplemental material that early stopping may kick in after 500 epochs
                if e > 100:
                    no_improvement += 1
                if no_improvement == PATIENCE:
                    print("Early stopping...\n")
                    break


    model = best[1]
    with torch.no_grad():
        model.eval()
        adjs, xs = convert_to_dense(data, data.TR)
        zs = model(adjs, xs)[end_tr:]

        if not dynamic:
            p,n,z = g.link_prediction(data, data.te, zs, start=end_tr)
            t, f = model.score_fn(p,n,z)
            sscores = get_score(t, f)

            print(
                '''
                Final scores: 
                    Static LP:  %s
                '''
            % fmt_score(sscores))

            return {'auc': sscores[0], 'ap': sscores[1]}

        else:    
            p,n,z = g.dynamic_link_prediction(data, data.te, zs, start=end_tr)
            t, f = model.score_fn(p,n,z)
            dscores = get_score(t, f)

            p,n,z = g.dynamic_new_link_prediction(data, data.te, zs, start=end_tr)
            t, f = model.score_fn(p,n,z)
            nscores = get_score(t, f)

            print(
                '''
                Final scores: 
                    Dynamic LP:     %s 
                    Dynamic New LP: %s 
                ''' %
                (fmt_score(dscores),
                 fmt_score(nscores))
            )

            return {
                'pred-auc': dscores[0],
                'pred-ap': dscores[1],
                'new-auc': nscores[0], 
                'new-ap': nscores[1],
            }
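
The scoring helpers get_score and fmt_score are assumed by every example here but never shown. A minimal sketch consistent with how they are used above (returning an [AUC, AP] pair), assuming t_scores and f_scores are score tensors for true and false edges:

import torch
from sklearn.metrics import average_precision_score, roc_auc_score

def get_score(t_scores, f_scores):
    # Hypothetical reconstruction: label true-edge scores 1 and false-edge
    # scores 0, then return [AUC, AP] as indexed by the trainers above
    t = t_scores.squeeze(-1) if t_scores.dim() > 1 else t_scores
    f = f_scores.squeeze(-1) if f_scores.dim() > 1 else f_scores
    scores = torch.cat([t, f], dim=0)
    labels = torch.zeros(scores.size(0))
    labels[:t.size(0)] = 1
    return [roc_auc_score(labels, scores),
            average_precision_score(labels, scores)]

def fmt_score(s):
    # Hypothetical formatter matching the print statements above
    return 'AUC: %0.4f  AP: %0.4f' % (s[0], s[1])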
Code example #3
File: tedge_train.py  Project: zazyzaya/TGCN
def train(model, data, epochs=1500, dynamic=False, nratio=10, lr=0.01):
    print(lr)
    end_tr = data.T - TEST_TS

    opt = Adam(model.parameters(), lr=lr)

    best = (0, None)
    no_improvement = 0
    for e in range(epochs):
        model.train()
        opt.zero_grad()

        # Get embedding
        zs, _ = model(data.x, data.eis[:end_tr], data.tr)

        if not dynamic:
            p, n, z = g.link_prediction(data,
                                        data.tr,
                                        zs,
                                        include_tr=False,
                                        nratio=nratio)

        else:
            p, n, z = g.dynamic_link_prediction(data,
                                                data.tr,
                                                zs,
                                                include_tr=False,
                                                nratio=nratio)

        loss = model.loss_fn(p, n, z)
        loss.backward()
        opt.step()

        trloss = loss.item()
        with torch.no_grad():
            model.eval()
            zs, hs = model(data.x, data.eis[:end_tr], data.tr)

            if not dynamic:
                p, n, h = g.link_prediction(data, data.va, hs)
                st, sf = model.score_fn(p, n, h)
                sscores = get_score(st, sf)

                print('[%d] Loss: %0.4f  \n\tSt %s ' %
                      (e, trloss, fmt_score(sscores)))

                avg = sscores[0] + sscores[1]

            elif type(model) == HybridProbTGCN:
                dp, dn, _ = g.dynamic_link_prediction(data, data.va, zs)
                dt, df = model.score_fn(dp, dn, hs)
                dscores = get_score(dt, df)

                dp, dn, _ = g.dynamic_new_link_prediction(data, data.va, zs)
                dt, df = model.score_fn(dp, dn, hs)
                dnscores = get_score(dt, df)

                print('[%d] Loss: %0.4f  \n\tDet %s  \n\tNew %s' %
                      (e, trloss, fmt_score(dscores), fmt_score(dnscores)))

                avg = (dscores[0] + dscores[1])

            else:
                dp, dn, dh = g.dynamic_link_prediction(data, data.va, hs)
                dt, df = model.score_fn(dp, dn, dh)
                dscores = get_score(dt, df)

                dp, dn, dh = g.dynamic_new_link_prediction(data, data.va, hs)
                dt, df = model.score_fn(dp, dn, dh)
                dnscores = get_score(dt, df)

                print('[%d] Loss: %0.4f  \n\tDet %s  \n\tNew %s' %
                      (e, trloss, fmt_score(dscores), fmt_score(dnscores)))

                avg = (dscores[0] + dscores[1])

            if avg > best[0]:
                best = (avg, deepcopy(model))
                no_improvement = 0
            else:
                # Though it's not reflected in their code, the VGRNN authors imply in the
                # supplemental material that early stopping may kick in after 500 epochs
                no_improvement += 1
                if no_improvement == PATIENCE:
                    print("Early stopping...\n")
                    break

    # Test the data that hasn't been touched
    model = best[1]
    with torch.no_grad():
        model.eval()
        zs, hs = model(data.x, data.eis, data.tr)
        hs = hs[end_tr:] if not dynamic else hs[end_tr - 1:]

        if not dynamic:
            p, n, h = g.link_prediction(data, data.te, hs, start=end_tr)
            t, f = model.score_fn(p, n, h)
            sscores = get_score(t, f)

            print('''
                Final scores: 
                    Static LP:  %s
                ''' % fmt_score(sscores))

            return {'auc': sscores[0], 'ap': sscores[1]}

        if type(model) == HybridProbTGCN:
            # The H matrix is already correctly aligned, so ignore the returned z
            p, n, _ = g.dynamic_link_prediction(data,
                                                data.te,
                                                zs[end_tr - 1:],
                                                start=end_tr - 1)
            # Sanity check that the sampled edges align with the hidden states
            print(len(p))
            print(hs.size(0))
            t, f = model.score_fn(p, n, hs)
            dscores = get_score(t, f)

            p, n, _ = g.dynamic_new_link_prediction(data,
                                                    data.te,
                                                    zs[end_tr - 1:],
                                                    start=end_tr - 1)
            t, f = model.score_fn(p, n, hs)
            nscores = get_score(t, f)

            print('''
                Final scores: 
                    Dynamic LP:     %s 
                    Dynamic New LP: %s 
                ''' % (fmt_score(dscores), fmt_score(nscores)))

            return {
                'pred-auc': dscores[0],
                'pred-ap': dscores[1],
                'new-auc': nscores[0],
                'new-ap': nscores[1],
            }

        else:
            p, n, h = g.dynamic_link_prediction(data,
                                                data.te,
                                                hs,
                                                start=end_tr - 1)
            # Sanity check that the sampled edges align with the embeddings
            print(len(p))
            print(h.size(0))
            t, f = model.score_fn(p, n, h)
            dscores = get_score(t, f)

            p, n, h = g.dynamic_new_link_prediction(data,
                                                    data.te,
                                                    hs,
                                                    start=end_tr - 1)
            t, f = model.score_fn(p, n, h)
            nscores = get_score(t, f)

            print('''
                Final scores: 
                    Dynamic LP:     %s 
                    Dynamic New LP: %s 
                ''' % (fmt_score(dscores), fmt_score(nscores)))

            return {
                'pred-auc': dscores[0],
                'pred-ap': dscores[1],
                'new-auc': nscores[0],
                'new-ap': nscores[1],
            }
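
The end_tr - 1 offsets above exist because dynamic link prediction scores the edges of snapshot t+1 against the embedding of snapshot t. A small sketch of that alignment, assuming zs is a (T, N, d) tensor and eis a list of T edge-index tensors:

def align_dynamic_test(zs, eis, end_tr):
    # Sketch of the one-step offset: embedding t predicts edges at t+1,
    # so test snapshots [end_tr, T-1] pair with embeds [end_tr-1, T-2]
    zs_te = zs[end_tr - 1 : zs.size(0) - 1]
    eis_te = eis[end_tr:]
    assert zs_te.size(0) == len(eis_te)
    return zs_te, eis_te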
Code example #4
def train(model,
          data,
          epochs=1500,
          nratio=1,
          dynamic=True,
          min_epochs=100,
          lr_nratio=1,
          single_prior=False,
          lp_epochs=100,
          no_test=False):

    TE_STARTS = data.T if no_test else data.te_starts
    WINDOW = TE_STARTS

    opt = Adam(model.parameters(), lr=LR)
    best = (0, None)
    no_improvement = 0

    for e in range(epochs):
        model.train()
        opt.zero_grad()

        st = 0  #e % (max(1, TE_STARTS-WINDOW))
        zs = model(data.x, data.eis[st:st + WINDOW], data.tr, start_idx=st)

        # If dynamic, TGCN uses the embeds of timestep t to predict t+1. It uses a
        # sparse loss, so we have to generate negative samples and time-shift them
        if model.__class__ == SerialTGCN:
            if dynamic:
                p, n, z = g.dynamic_link_prediction(data,
                                                    data.tr,
                                                    zs,
                                                    end=st + WINDOW,
                                                    start=st,
                                                    nratio=nratio,
                                                    include_tr=False)
            else:
                p, n, z = g.link_prediction(data,
                                            data.tr,
                                            zs,
                                            start=st,
                                            end=st + WINDOW,
                                            nratio=nratio,
                                            include_tr=False)

        # VGRNN uses dense loss, so no need to do neg sampling or timeshift
        else:
            if model.adj_loss:
                p = data.eis[:TE_STARTS]
                n = None
                z = zs
            else:
                p, n, z = g.link_prediction(data,
                                            data.tr,
                                            zs,
                                            end=TE_STARTS,
                                            nratio=nratio,
                                            include_tr=False)

        loss = model.loss_fn(p, n, z)
        loss.backward()

        # Done by VGRNN to improve convergence; must run before opt.step()
        # for the clipping to have any effect on the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        opt.step()
        trloss = loss.item()

        with torch.no_grad():
            model.eval()
            zs = model(data.x, data.eis[:TE_STARTS], data.tr)

            if not dynamic:
                p, n, z = g.link_prediction(data, data.va, zs, end=TE_STARTS)
                st, sf = model.score_fn(p, n, z)
                sscores = get_score(st, sf)

                print('[%d] Loss: %0.4f  \n\tSt %s ' %
                      (e, trloss, fmt_score(sscores)))

                avg = sscores[1]

            else:
                # VGRNN provides priors that are already built from the previous
                # timestep, so there is no need to shift the selected ei's as the
                # dynamic functions do
                if model.__class__ == VGRNN:
                    dp, dn, dz = g.link_prediction(data,
                                                   data.va,
                                                   zs,
                                                   end=TE_STARTS)
                else:
                    dp, dn, dz = g.dynamic_link_prediction(data,
                                                           data.va,
                                                           zs,
                                                           end=TE_STARTS)

                dt, df = model.score_fn(dp, dn, dz)
                dscores = get_score(dt, df)

                dp, dn, dz = g.dynamic_new_link_prediction(data,
                                                           data.all,
                                                           zs,
                                                           end=TE_STARTS)

                # Again, we don't need to shift the VGRNN embeds backward
                if model.__class__ == VGRNN:
                    dz = zs

                dt, df = model.score_fn(dp, dn, dz)
                dnscores = get_score(dt, df)

                print('[%d] Loss: %0.4f  \n\tDet %s  \n\tNew %s' %
                      (e, trloss, fmt_score(dscores), fmt_score(dnscores)))

                avg = (dscores[0] + dscores[1] + dnscores[0] + dnscores[1])

            if avg > best[0]:
                best = (avg, deepcopy(model))
                no_improvement = 0
            else:
                # Though it's not reflected in their code, the VGRNN authors imply in the
                # supplemental material that early stopping may kick in after 500 epochs
                if e > min_epochs:
                    no_improvement += 1
                if no_improvement > PATIENCE:
                    print("Early stopping...\n")
                    break

    model = best[1]
    if no_test:
        return model

    with torch.no_grad():
        model.eval()
        if model.__class__ == SerialTGCN:
            zs_all, h0 = model(data.x,
                               data.eis[:TE_STARTS],
                               data.all,
                               include_h=True)
        else:
            zs_all = model(data.x, data.eis[:TE_STARTS], data.all)

        # Generate all future embeds using prior from last normal state
        if single_prior:
            zs = torch.cat([
                model(data.x, [data.eis[i]], data.all, h_0=h0, start_idx=i)
                for i in range(TE_STARTS - 1, data.T)
            ], dim=0)
        else:
            zs = model(data.x, data.eis, data.all)[TE_STARTS - 1:]

    if dynamic:
        if model.__class__ == VGRNN:
            zs = zs[1:]
        else:
            zs = zs[:-1]
    else:
        zs = zs[1:]

    # Train the link classifier
    lp = LP_Classifier()
    tr_mask = None
    X_pos_tr = None
    X_pos_va = None
    neg_val_size = None
    for _ in range(lp_epochs):
        if dynamic:
            if model.__class__ == VGRNN:
                p, n, z = g.link_prediction(data,
                                            data.all,
                                            zs_all,
                                            end=TE_STARTS,
                                            nratio=lr_nratio)
            else:
                p, n, z = g.dynamic_link_prediction(data,
                                                    data.all,
                                                    zs_all,
                                                    end=TE_STARTS,
                                                    nratio=lr_nratio)
        else:
            p, n, z = g.link_prediction(data,
                                        data.all,
                                        zs_all,
                                        end=TE_STARTS,
                                        nratio=lr_nratio)

        X_neg = cat_embeds(z, n)

        # Only need to do this once: the size of the tr set isn't known
        # until now, but the partitions don't change between epochs
        if tr_mask is None:
            X_pos = cat_embeds(z, p)

            tr_mask = torch.zeros(X_pos.size(0), dtype=torch.bool)
            prm = torch.randperm(X_pos.size(0))

            val_size = int(X_pos.size(0) * 0.05)
            tr_mask[prm[val_size:]] = True
            X_pos_tr, X_pos_va = X_pos[tr_mask], X_pos[~tr_mask]
            neg_val_size = val_size * lr_nratio

        # This is recalculated each epoch, because a different set of randomly
        # generated negative samples is drawn each time
        X_neg_tr, X_neg_va = X_neg[neg_val_size:], X_neg[:neg_val_size]

        X_tr = torch.cat([X_neg_tr, X_pos_tr], dim=0)
        y_tr = torch.zeros(X_neg_tr.size(0) + X_pos_tr.size(0))
        y_tr[:X_neg_tr.size(0)] = 1  # non-edges labeled 1, matching the anomaly convention

        X_va = torch.cat([X_neg_va, X_pos_va], dim=0)
        y_va = torch.zeros(neg_val_size + val_size)
        y_va[:neg_val_size] = 1

        lp.train_lp_step(X_tr, y_tr, X_va, y_va)

    likelihood = [
        model.decode(data.eis[TE_STARTS + i][0], data.eis[TE_STARTS + i][1],
                     zs[i]) for i in range(zs.size(0))
    ]
    likelihood = torch.cat(likelihood, dim=0)

    # Statistical measures on likelihood scores
    y = torch.cat(data.y[TE_STARTS:], dim=0).squeeze(-1)
    ap = average_precision_score(y, likelihood)
    auc = roc_auc_score(y, likelihood)

    X_te = cat_embeds(zs, data.eis[TE_STARTS:])
    lp.score(y, lp(X_te))
    lp.dumb_predict(likelihood, y)

    print('\n%s' % fmt_score([auc, ap]))
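
cat_embeds is another unshown helper. Judging from its use (turning node embeddings plus an edge list into one feature row per edge for LP_Classifier), a plausible sketch, assuming per-snapshot embeddings and (2, E) edge-index tensors:

import torch

def cat_embeds(zs, eis):
    # Hypothetical reconstruction: concatenate source and destination
    # node embeddings into a single feature vector per edge, per snapshot
    rows = [
        torch.cat([z[ei[0]], z[ei[1]]], dim=1)
        for z, ei in zip(zs, eis)
    ]
    return torch.cat(rows, dim=0)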
Code example #5
def train(data, model, dynamic, epochs=1500, nratio=10):
    # Leave all params as default for now
    opt = Adam(model.parameters(), lr=LR)

    times = []

    best = (0, None)
    no_improvement = 0
    for e in range(epochs):
        model.train()
        opt.zero_grad()
        start = time.time()

        # Generate positive and negative samples from this and the next time step
        if dynamic:
            zs = model(data.x, data.eis, data.all, ew_fn=data.all_w)
            p, n, z = g.dynamic_link_prediction(data,
                                                data.tr,
                                                zs,
                                                include_tr=False,
                                                nratio=nratio)
        else:
            zs = model(data.x, data.eis, data.tr, ew_fn=data.tr_w)
            p, n, z = g.link_prediction(data,
                                        data.tr,
                                        zs,
                                        include_tr=False,
                                        nratio=nratio)

        loss = model.loss_fn(p, n, z)
        loss.backward()

        # Done by VGRNN to improve convergence; must run before opt.step()
        # for the clipping to have any effect on the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        opt.step()

        elapsed = time.time() - start
        times.append(elapsed)
        trloss = loss.item()

        with torch.no_grad():
            model.eval()
            if not dynamic:
                zs = model(data.x, data.eis, data.tr, ew_fn=data.tr_w)
                p, n, z = g.link_prediction(data, data.va, zs)
                sp, sf = model.score_fn(p, n, z)
                sscores = get_score(sp, sf)
                #vloss = model.loss_fn(p,n,z).item()

                print('[%d] Loss: %0.4f \t%0.4fs  \n\tDet %s\n' %
                      (e, trloss, elapsed, fmt_score(sscores)))

                avg = sum(sscores)

            else:
                zs = model(data.x, data.eis, data.all, ew_fn=data.all_w)
                dp, dn, dz = g.dynamic_link_prediction(data, data.va, zs)
                dt, df = model.score_fn(dp, dn, dz)
                dscores = get_score(dt, df)

                #dp,dn,dz = g.dynamic_new_link_prediction(data, data.va, zs)
                #dt, df = model.score_fn(dp,dn,dz)
                #dnscores = get_score(dt, df)

                print('[%d] Loss: %0.4f \t%0.4fs  \n\tPred %s' %
                      (e, trloss, elapsed, fmt_score(dscores)))

                avg = sum(dscores)

            if avg > best[0]:
                best = (avg, deepcopy(model))
                no_improvement = 0
            else:
                no_improvement += 1
                if no_improvement > PATIENCE:
                    print("Early stopping...\n")
                    break

    print("Avg. TPE: %0.4fs" % (sum(times) / len(times)))
    model = best[1]
    _, h0 = model(data.x, data.eis, data.all, ew_fn=data.all_w, include_h=True)

    return model, h0
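
Since this trainer returns both the best model and its final hidden state, a plausible way to chain it into get_cutoff from example #1 (names exactly as used above):

# Hypothetical usage, wiring this trainer into example #1's helper
model, h0 = train(data, model, dynamic=True, epochs=1500, nratio=10)
cutoff, h0 = get_cutoff(data, model, pred=True, h0=h0)
print('Decision threshold: %0.4f' % cutoff)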
Code example #6
def train_cyber(data, model, dynamic, single_prior=False, epochs=1500):
    # Leave all params as default for now
    SKIP = data.te_starts
    opt = Adam(model.parameters(), lr=LR)

    best = (0, None)
    no_improvement = 0
    for e in range(epochs):
        model.train()
        opt.zero_grad()

        # Get embeddings
        zs = model(data.x, data.eis[:SKIP], data.tr)

        # Generate positive and negative samples from this and the next time step
        if dynamic:
            p, n, z = g.dynamic_link_prediction(data,
                                                data.tr,
                                                zs,
                                                include_tr=False,
                                                end=SKIP)
        else:
            p, n, z = g.link_prediction(data,
                                        data.tr,
                                        zs,
                                        include_tr=False,
                                        end=SKIP)

        loss = model.loss_fn(p, n, z)
        loss.backward()

        # Done by VGRNN to improve convergence; must run before opt.step()
        # for the clipping to have any effect on the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        opt.step()
        trloss = loss.item()

        with torch.no_grad():
            model.eval()
            zs = model(data.x, data.eis[:SKIP], data.tr)

            if not dynamic:
                p, n, z = g.link_prediction(data, data.va, zs, end=SKIP)
                sp, sf = model.score_fn(p, n, z)
                sscores = get_score(sp, sf)

                print('[%d] Loss: %0.4f  \n\tDet %s' %
                      (e, trloss, fmt_score(sscores)))

                avg = sum(sscores)

            else:
                dp, dn, dz = g.dynamic_link_prediction(data,
                                                       data.va,
                                                       zs,
                                                       end=SKIP)
                dt, df = model.score_fn(dp, dn, dz)
                dscores = get_score(dt, df)

                dp, dn, dz = g.dynamic_new_link_prediction(data,
                                                           data.va,
                                                           zs,
                                                           end=SKIP)
                dt, df = model.score_fn(dp, dn, dz)
                dnscores = get_score(dt, df)

                print('[%d] Loss: %0.4f  \n\tPred %s  \n\tNew %s' %
                      (e, trloss, fmt_score(dscores), fmt_score(dnscores)))

                avg = sum(dscores) + sum(dnscores)

            if avg > best[0]:
                best = (avg, deepcopy(model))
                no_improvement = 0
            else:
                no_improvement += 1
                if no_improvement > PATIENCE:
                    print("Early stopping...\n")
                    break

    model = best[1]
    with torch.no_grad():
        model.eval()
        zs, h = model(data.x, data.eis[:SKIP], data.all, include_h=True)

        # Generate all future embeds using prior from last normal state
        if single_prior:
            zs = torch.cat([
                model(data.x, [data.eis[i]], data.all, h_0=h, start_idx=i)
                for i in range(SKIP - 1, data.T)
            ], dim=0)
        else:
            zs = model(data.x, data.eis, data.all)[SKIP - 1:]

    zs = zs[:-1] if dynamic else zs[1:]

    # Score all edges and match them with name/timestamp
    edges = []
    with open(ld.LANL_FOLDER + 'nmap.pkl', 'rb') as f:
        data.node_map = pickle.load(f)

    for i in range(zs.size(0)):
        idx = i + data.te_starts

        ei = data.eis[idx]
        scores = model.decode(ei[0], ei[1], zs[i])
        names = data.format_edgelist(idx)

        for j in range(len(names)):
            edges.append((scores[j].item(), names[j]))

    max_anom = (0, 0.0)
    edges.sort(key=lambda x: x[0])
    anoms = 0
    tot_anoms = 9  # Hardcoded for now
    with open('out.txt', 'w+') as f:
        for i in range(len(edges)):
            e = edges[i]
            f.write('%0.4f %s\n' % e)

            if 'ANOM' in e[1]:
                anoms += 1
                max_anom = (i, e[0])
                stats = tpr_fpr(i, anoms, len(edges), tot_anoms)
                print('[%d/%d] %0.4f %s  %s' %
                      (i, len(edges), e[0], e[1], stats))

    print('Maximum anomaly scored %d out of %d edges' %
          (max_anom[0], len(edges)))
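
tpr_fpr is not shown either. From its call site (rank of the current edge in the sorted list, anomalies found so far, total edges, total anomalies), a hedged sketch:

def tpr_fpr(rank, anoms_found, n_edges, tot_anoms):
    # Hypothetical reconstruction: TPR/FPR if the anomaly cutoff were
    # drawn just after position `rank` in the score-sorted edge list
    tpr = anoms_found / tot_anoms
    fpr = (rank + 1 - anoms_found) / (n_edges - tot_anoms)
    return 'TPR: %0.4f  FPR: %0.4f' % (tpr, fpr)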
Code example #7
File: serial_train.py  Project: zazyzaya/TGCN
def train(model, data, epochs=1500, dynamic=False, nratio=10, lr=0.01):
    print(lr)
    end_tr = data.T - TEST_TS

    opt = Adam(model.parameters(), lr=lr)

    best = (0, None)
    no_improvement = 0
    for e in range(epochs):
        model.train()
        opt.zero_grad()

        # Get embedding
        zs = model(data.x, data.eis, data.tr)[:end_tr]

        if not dynamic or model.__class__ in uses_priors:
            p, n, z = g.link_prediction(data,
                                        data.tr,
                                        zs,
                                        include_tr=False,
                                        nratio=nratio)

        else:
            p, n, z = g.dynamic_link_prediction(data,
                                                data.tr,
                                                zs,
                                                include_tr=False,
                                                nratio=nratio)

        loss = model.loss_fn(p, n, z)
        loss.backward()

        # Done by VGRNN to improve convergence; must run before opt.step()
        # for the clipping to have any effect on the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        opt.step()

        trloss = loss.item()
        with torch.no_grad():
            model.eval()
            zs = model(data.x, data.eis, data.tr)[:end_tr]

            if not dynamic:
                p, n, z = g.link_prediction(data, data.va, zs)
                st, sf = model.score_fn(p, n, z)
                sscores = get_score(st, sf)

                print('[%d] Loss: %0.4f  \n\tSt %s ' %
                      (e, trloss, fmt_score(sscores)))

                avg = sscores[0] + sscores[1]

            else:
                # VGRNN provides priors that are already built from the previous
                # timestep, so there is no need to shift the selected ei's as the
                # dynamic functions do
                if model.__class__ in uses_priors:
                    zs = zs[1:]
                    dp, dn, dz = g.link_prediction(data, data.va, zs)
                else:
                    dp, dn, dz = g.dynamic_link_prediction(data, data.va, zs)

                dt, df = model.score_fn(dp, dn, dz)
                dscores = get_score(dt, df)

                dp, dn, dz = g.dynamic_new_link_prediction(data, data.va, zs)
                if model.__class__ in uses_priors:
                    dz = zs  # Again, we don't need to shift the VGRNN embeds backward

                dt, df = model.score_fn(dp, dn, dz)
                dnscores = get_score(dt, df)

                print('[%d] Loss: %0.4f  \n\tDet %s  \n\tNew %s' %
                      (e, trloss, fmt_score(dscores), fmt_score(dnscores)))

                avg = (dscores[0] + dscores[1])

            if e == KL_ANNEALING:
                model.kld_weight = KL_WEIGHT

            if avg > best[0]:
                best = (avg, deepcopy(model))
                no_improvement = 0
            else:
                # Though it's not reflected in their code, the VGRNN authors imply in the
                # supplemental material that early stopping may kick in after 500 epochs
                if e > 100:
                    no_improvement += 1
                if no_improvement == PATIENCE:
                    print("Early stopping...\n")
                    break

    model = best[1]
    with torch.no_grad():
        model.eval()
        zs = model(data.x, data.eis, data.tr)[end_tr - 1:]

        if not dynamic:
            zs = zs[1:]
            p, n, z = g.link_prediction(data, data.te, zs, start=end_tr)
            t, f = model.score_fn(p, n, z)
            sscores = get_score(t, f)

            print('''
                Final scores: 
                    Static LP:  %s
                ''' % fmt_score(sscores))

            return {'auc': sscores[0], 'ap': sscores[1]}

        else:
            if model.__class__ in uses_priors:
                zs = zs[1:]
                p, n, z = g.link_prediction(data, data.te, zs, start=end_tr)
            else:
                p, n, z = g.dynamic_link_prediction(data,
                                                    data.te,
                                                    zs,
                                                    start=end_tr - 1)
                # Sanity check that the sampled edges align with the embeddings
                print(len(p))
                print(z.size(0))

            t, f = model.score_fn(p, n, z)
            dscores = get_score(t, f)

            p, n, z = g.dynamic_new_link_prediction(data,
                                                    data.te,
                                                    zs,
                                                    start=end_tr - 1)
            if model.__class__ in uses_priors:
                z = zs

            # Sanity check on the embedding count
            print(z.size(0))

            t, f = model.score_fn(p, n, z)
            nscores = get_score(t, f)

            print('''
                Final scores: 
                    Dynamic LP:     %s 
                    Dynamic New LP: %s 
                ''' % (fmt_score(dscores), fmt_score(nscores)))

            return {
                'pred-auc': dscores[0],
                'pred-ap': dscores[1],
                'new-auc': nscores[0],
                'new-ap': nscores[1],
            }
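
The uses_priors set is assumed throughout this example but never defined here. A minimal sketch of the convention it encodes, assuming it is simply a collection of model classes such as VGRNN:

# Hypothetical sketch, assuming uses_priors is a set like {VGRNN}
def align_validation_embeds(model, zs, uses_priors):
    if model.__class__ in uses_priors:
        # z[t] already encodes a prior built from t-1, so drop z[0] and
        # score the remaining embeds against their own snapshots
        return zs[1:]
    # other models are scored one step ahead by the dynamic_* helpers,
    # which shift the edge lists internally
    return zs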