def validate(train_loader, epoch, margin):
    """Collect the device-resident input embeddings for every batch.

    Despite the name, no loss or accuracy is computed here; `epoch` and
    `margin` are accepted only for signature compatibility with the other
    validate/zsl_validate callers and are unused.

    Args:
        train_loader: iterable yielding (inputs, target) batches.
        epoch: unused.
        margin: unused.

    Returns:
        List with one tensor per batch, each moved to the module-level
        `device`.
    """
    scores = []
    for inputs, _target in train_loader:
        scores.append(inputs.to(device))
    return scores
# ----- Example 2 -----
def validate(train_loader, visae, attae, epoch):
    """Run one pass over `train_loader` with the visual/attribute autoencoders.

    Per batch, computes a contractive visual reconstruction loss, an
    attribute reconstruction loss, an MMD loss aligning the two hidden
    spaces, and a multi-class hinge prediction loss; accuracy is scored
    against the seen-class embeddings.

    Args:
        train_loader: iterable yielding (emb, target) batches.
        visae: visual autoencoder returning (hidden, out, trans_out).
        attae: attribute autoencoder returning (hidden, out, trans_out).
        epoch: current epoch number (logging only).

    Returns:
        (average loss, average accuracy) over the pass.
    """
    visae.eval()
    attae.eval()
    train_losses = AverageMeter()
    train_acces = AverageMeter()
    for i, (emb, target) in enumerate(train_loader):
        emb = emb.to(device)
        with torch.no_grad():
            seen_class_embeddings = attae(
                seen_s2v_labels.to(device).float())[2]
            vis_emb = torch.log(1 + emb)

        # BUG FIX: requires_grad must be enabled *before* retain_grad();
        # calling retain_grad() on a tensor whose requires_grad is False
        # (as vis_emb is, coming out of the no_grad block) raises a
        # RuntimeError in PyTorch. Original code had the calls reversed.
        vis_emb.requires_grad_(True)
        vis_emb.retain_grad()

        vis_hidden, vis_out, vis_trans_out = visae(vis_emb)
        vis_recons_loss = contractive_loss(vis_out,
                                           vis_emb,
                                           vis_hidden,
                                           criterion2,
                                           device,
                                           gamma=0.0001)

        att_emb = get_text_data(target, s2v_labels.float()).to(device)
        att_hidden, att_out, att_trans_out = attae(att_emb)
        att_recons_loss = criterion2(att_out, att_emb)

        # mmd loss: align visual and attribute hidden distributions
        loss_mmd = MMDLoss(vis_hidden, att_hidden).to(device)

        # supervised binary prediction loss
        pred_loss = multi_class_hinge_loss(vis_trans_out, att_trans_out,
                                           target, seen_class_embeddings)

        loss = pred_loss + vis_recons_loss + att_recons_loss + loss_mmd

        acc, _ = accuracy(seen_class_embeddings, vis_trans_out.detach(),
                          target.to(device), seen_inds)
        train_losses.update(loss.item(), emb.size(0))
        train_acces.update(acc, emb.size(0))
        if i % 20 == 0:
            print('Epoch-{:<3d} {:3d} batches\t'
                  'loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'accu {acc.val:.3f} ({acc.avg:.3f})\t'.format(
                      epoch,
                      len(train_loader),
                      loss=train_losses,
                      acc=train_acces))
            print('Vis Reconstruction Loss {:.4f}\t'
                  'Att Reconstruction Loss {:.4f}\t'
                  'MMD Loss {:.4f}\t'
                  'Supervised Binary Prediction Loss {:.4f}'.format(
                      vis_recons_loss.item(), att_recons_loss.item(),
                      loss_mmd.item(), pred_loss.item()))

    return train_losses.avg, train_acces.avg
# ----- Example 3 -----
def zsl_validate(val_loader, visae, attae, epoch):
    """Zero-shot validation against the unseen-class attribute embeddings.

    Mirrors the training objective (reconstruction + MMD + hinge loss) but
    runs entirely without gradients and scores accuracy over the unseen
    classes. Returns (average loss, average accuracy).
    """
    visae.eval()
    attae.eval()
    losses = AverageMeter()
    acces = AverageMeter()
    ce_loss_vals = []
    ce_acc_vals = []
    with torch.no_grad():
        # Unseen-class prototypes come from the attribute autoencoder's
        # transformed output (index 2 of its return tuple).
        class_embeddings = attae(unseen_s2v_labels.to(device).float())[2]
        for batch_idx, (batch_emb, target) in enumerate(val_loader):
            batch_emb = batch_emb.to(device)
            vis_emb = torch.log(1 + batch_emb)

            vis_hidden, vis_out, vis_trans_out = visae(vis_emb)
            att_emb = get_text_data(target, s2v_labels.float()).to(device)
            att_hidden, att_out, att_trans_out = attae(att_emb)

            # Reconstruction terms for both modalities.
            vis_recons_loss = criterion2(vis_out, vis_emb)
            att_recons_loss = criterion2(att_out, att_emb)

            # mmd loss
            loss_mmd = MMDLoss(vis_hidden, att_hidden).to(device)

            # supervised binary prediction loss
            pred_loss = multi_class_hinge_loss(vis_trans_out, att_trans_out,
                                               target, class_embeddings)

            loss = pred_loss + vis_recons_loss + att_recons_loss + loss_mmd
            acc, _ = accuracy(class_embeddings, vis_trans_out,
                              target.to(device), unseen_inds)

            batch_size = batch_emb.size(0)
            losses.update(loss.item(), batch_size)
            acces.update(acc, batch_size)
            ce_loss_vals.append(loss.cpu().detach().numpy())
            ce_acc_vals.append(acc)

            if batch_idx % 20 == 0:
                print('ZSL Validation Epoch-{:<3d} {:3d}/{:3d} batches \t'
                      'loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'accu {acc.val:.3f} ({acc.avg:.3f})\t'.format(
                          epoch, batch_idx, len(val_loader),
                          loss=losses, acc=acces))
                print('Vis Reconstruction Loss {:.4f}\t'
                      'Att Reconstruction Loss {:.4f}\t'
                      'MMD Loss {:.4f}\t'
                      'Supervised Binary Prediction Loss {:.4f}'.format(
                          vis_recons_loss.item(), att_recons_loss.item(),
                          loss_mmd.item(), pred_loss.item()))

    return losses.avg, acces.avg
def zsl_validate(val_loader, epoch, margin):
    """Zero-shot validation for the verb/noun/joint MMEN models.

    For every batch: project the clip features and the unseen verb/noun
    text embeddings into their shared spaces, L2-normalise each embedding,
    fuse the verb and noun branches through the joint model, and score
    hinge/triplet losses plus class accuracy against the unseen classes.

    Args:
        val_loader: iterable yielding (inputs, target) batches.
        epoch: current epoch number (logging only).
        margin: margin passed to the hinge and triplet losses.

    Returns:
        (average loss, average accuracy, list of per-batch score tensors).
    """

    def _row_normalize(t):
        # Divide each row by its L2 norm. Broadcasting via keepdim=True is
        # equivalent to the view([n, 1]).repeat([1, d]) expansion that was
        # previously copy-pasted for every embedding.
        return t / torch.norm(t, dim=1, keepdim=True)

    with torch.no_grad():
        losses = AverageMeter()
        acces = AverageMeter()
        ce_loss_vals = []
        ce_acc_vals = []
        NounPosMmen.eval()
        VerbPosMmen.eval()
        JointMmen.eval()
        tars = []
        scores = []
        for i, (inputs, target) in enumerate(val_loader):
            emb = _row_normalize(inputs.to(device))
            verb_emb_target, noun_emb_target = get_text_data(
                target, verb_emb, noun_emb)
            # Map corpus words to their integer indices for the triplet loss.
            verb_tar = [np.argwhere(verb_corp == p)[0][0]
                        for p in verbs[target]]
            noun_tar = [np.argwhere(noun_corp == p)[0][0]
                        for p in np.unique(nouns[target])]

            # Class-level text embeddings for all unseen verbs/nouns.
            verb_embeddings = _row_normalize(VerbPosMmen.TextArch(
                unseen_verb_emb.to(device).float()))
            noun_embeddings = _row_normalize(NounPosMmen.TextArch(
                unseen_noun_emb.to(device).float()))
            joint_text_embedding = torch.cat(
                [verb_embeddings, noun_embeddings], axis=1)
            fin_text_embedding = _row_normalize(
                JointMmen.TextArch(joint_text_embedding))

            # Verb branch: project the clip and its verb target jointly.
            vis_verb_transform, verb_transform = VerbPosMmen(
                emb,
                verb_emb_target.to(device).float())
            vis_verb_transform = _row_normalize(vis_verb_transform)
            verb_transform = _row_normalize(verb_transform)
            verb_vis_loss = multi_class_hinge_loss(vis_verb_transform,
                                                   verb_transform, target,
                                                   verb_embeddings,
                                                   margin).to(device)
            verb_verb_loss = triplet_loss(vis_verb_transform,
                                          torch.tensor(verb_tar), device,
                                          margin).to(device)

            # Noun branch, same structure as the verb branch.
            vis_noun_transform, noun_transform = NounPosMmen(
                emb,
                noun_emb_target.to(device).float())
            vis_noun_transform = _row_normalize(vis_noun_transform)
            noun_transform = _row_normalize(noun_transform)
            noun_vis_loss = multi_class_hinge_loss(vis_noun_transform,
                                                   noun_transform, target,
                                                   noun_embeddings,
                                                   margin).to(device)
            noun_noun_loss = triplet_loss(vis_noun_transform,
                                          torch.tensor(noun_tar), device,
                                          margin).to(device)

            # Joint branch over the concatenated verb+noun projections.
            joint_vis = torch.cat([vis_verb_transform, vis_noun_transform],
                                  axis=1)
            joint_text = torch.cat([verb_transform, noun_transform], axis=1)
            fin_vis, fin_text = JointMmen(joint_vis, joint_text)
            fin_text = _row_normalize(fin_text)
            fin_vis = _row_normalize(fin_vis)
            joint_loss = multi_class_hinge_loss(fin_vis, fin_text, target,
                                                fin_text_embedding,
                                                margin).to(device)

            # Triplet terms are down-weighted relative to the hinge terms.
            loss = verb_vis_loss + noun_vis_loss
            loss += 0.01 * (verb_verb_loss) + 0.01 * (noun_noun_loss)
            loss += joint_loss

            # ce acc
            ce_acc, score = accuracy(fin_text_embedding, fin_vis,
                                     target.to(device), unseen_inds)
            losses.update(loss.item(), inputs.size(0))
            acces.update(ce_acc, inputs.size(0))
            tars.append(target)
            scores.append(score)
            ce_loss_vals.append(loss.cpu().detach().numpy())
            ce_acc_vals.append(ce_acc)

            if i % 20 == 0:
                print('Epoch-{:<3d} {:3d}/{:3d} batches \t'
                      'loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'accu {acc.val:.3f} ({acc.avg:.3f})\t'.format(
                          epoch,
                          i + 1,
                          len(val_loader),
                          loss=losses,
                          acc=acces))
        return losses.avg, acces.avg, scores
# ----- Example 5 -----
def train_one_cycle(cycle_num, cycle_length):
    """Train the sequence/text VAE pair for one annealing cycle.

    Runs `cycle_length` epochs, one batch each (drawn fresh from
    train_loader every epoch). The KL weight ramps up only after the first
    1000 epochs of the cycle, and cross-reconstruction terms switch on
    after `cr_fact_epoch` epochs.

    Relies on module-level state: train_loader, sequence_encoder/decoder,
    text_encoder/decoder, criterion1, optimizer, labels_emb, device.
    Returns None; progress is printed every epoch.
    """
    s_epoch = (cycle_num)*(cycle_length)
    e_epoch = (cycle_num+1)*(cycle_length)
    # Cross-reconstruction starts later for the shorter ramp of a
    # 1700-epoch cycle; 1500 otherwise.
    if cycle_length == 1700:
        cr_fact_epoch = 1400
    else:
        cr_fact_epoch = 1500
        
    for epoch in range(s_epoch, e_epoch):
        # NOTE(review): meter is re-created every epoch, so avg == val in
        # the print below — presumably intentional single-batch logging.
        losses = AverageMeter()
        ce_loss_vals = []
        sequence_encoder.train()
        sequence_decoder.train()    
        text_encoder.train()
        text_decoder.train()
        k_trip = 0  # unused in this variant
        # KL annealing weight: 0 until s_epoch+1000, then linear ramp.
        k_fact = max((0.1*(epoch - (s_epoch+1000))/3000), 0)
        # Text-KL weight additionally gated until cr_fact_epoch is passed.
        k_fact2 = k_fact*(epoch > (s_epoch + cr_fact_epoch))
        # Cross-reconstruction on/off gate (0 or 1).
        cr_fact = 1*(epoch > (s_epoch + cr_fact_epoch))
        lw_fact = 0  # Wasserstein term disabled (still computed for logging)
        (inputs, target) = next(iter(train_loader))
        s = inputs.to(device)
        t = get_text_data(target, labels_emb).to(device)
        # Sequence VAE forward pass.
        smu, slv = sequence_encoder(s)
        sz = reparameterize(smu, slv)
        sout = sequence_decoder(sz)

        # Text VAE forward pass.
        tmu, tlv = text_encoder(t)
        tz = reparameterize(tmu, tlv)
        tout = text_decoder(tz)

        # cross reconstruction
        tfroms = text_decoder(sz)
        sfromt = sequence_decoder(tz)

        s_recons = criterion1(s, sout)
        t_recons = criterion1(t, tout)
        s_kld = KL_divergence(smu, slv).to(device) 
        t_kld = KL_divergence(tmu, tlv).to(device)
        s_crecons = criterion1(s, sfromt)
        t_crecons = criterion1(t, tfroms)
        l_wass = Wasserstein_distance(smu, slv, tmu, tlv)


        # NOTE(review): KL terms are *subtracted* here, which assumes
        # KL_divergence returns the negative KL (ELBO convention) — confirm
        # against its definition elsewhere in the project.
        loss = s_recons + t_recons 
        loss -= k_fact*(s_kld)
        loss -= k_fact2*(t_kld)
        loss += cr_fact*(s_crecons + t_crecons)
        loss += lw_fact*(l_wass)

        # backward

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), inputs.size(0))
        ce_loss_vals.append(loss.cpu().detach().numpy())
        if epoch % 1 == 0:
            print('Epoch-{:<3d} \t'
                'loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                epoch, loss=losses))
            print('srecons {:.4f}\ttrecons {:.4f}\t'.format(s_recons.item(), t_recons.item()))
            print('skld {:.4f}\ttkld {:.4f}\t'.format(s_kld.item(), t_kld.item()))
            print('screcons {:.4f}\ttcrecons {:.4f}\t'.format(s_crecons.item(), t_crecons.item()))        
            print('lwass {:.4f}\t'.format(l_wass.item()))

    return   
# ----- Example 6 -----
def train_one_cycle(cycle_num, cycle_length):
    """Train the sequence VAE with separate noun/verb text VAEs for one cycle.

    The sequence latent is split in half: the first half cross-reconstructs
    the noun text, the second half the verb text. KL and cross-reconstruction
    weights anneal within the cycle; text-KL terms only activate from
    cycle 2 onward.

    Relies on module-level state: train_loader, sequence_encoder/decoder,
    n_text_encoder/decoder, v_text_encoder/decoder, criterion1, optimizer,
    latent_size, device. Returns None; prints progress every 100 epochs.
    """
    s_epoch = (cycle_num)*(cycle_length)
    e_epoch = (cycle_num+1)*(cycle_length)
    # Cross-reconstruction starts later for a 1700-epoch cycle.
    if cycle_length == 1700:
        cr_fact_epoch = 1400
    else:
        cr_fact_epoch = 1500

    for epoch in range(s_epoch, e_epoch):
        losses = AverageMeter()
        ce_loss_vals = []

        # sequence + verb models
        sequence_encoder.train()
        sequence_decoder.train()
        v_text_encoder.train()
        v_text_decoder.train()

        # hyper params: sequence-KL ramps after 1000 epochs; text-KL ramps
        # after cr_fact_epoch and only from cycle 2 on; cross-recon gates
        # are 0/1 switches past cr_fact_epoch.
        k_fact = max((0.1*(epoch - (s_epoch+1000))/3000), 0)
        cr_fact = 1*(epoch > (s_epoch + cr_fact_epoch))
        v_k_fact2 = max((0.1*(epoch - (s_epoch + cr_fact_epoch))/3000), 0)*(cycle_num>1)
        n_k_fact2 = max((0.1*(epoch - (s_epoch + cr_fact_epoch))/3000), 0)*(cycle_num>1)
        v_cr_fact = 1*(epoch > (s_epoch + cr_fact_epoch))
        n_cr_fact = 1*(epoch > (s_epoch + cr_fact_epoch))

        # noun models
        n_text_encoder.train()
        n_text_decoder.train()

        (inputs, target) = next(iter(train_loader))
        s = inputs.to(device)
        nt, vt = get_text_data(target)
        nt = nt.to(device)
        vt = vt.to(device)

        # sequence forward pass
        smu, slv = sequence_encoder(s)
        sz = reparameterize(smu, slv)
        sout = sequence_decoder(sz)

        # noun forward pass; noun text is cross-reconstructed from the
        # first half of the sequence latent
        ntmu, ntlv = n_text_encoder(nt)
        ntz = reparameterize(ntmu, ntlv)
        ntout = n_text_decoder(ntz)

        ntfroms = n_text_decoder(sz[:, :latent_size//2])

        s_recons = criterion1(s, sout)
        nt_recons = criterion1(nt, ntout)
        s_kld = KL_divergence(smu, slv).to(device)
        nt_kld = KL_divergence(ntmu, ntlv).to(device)
        nt_crecons = criterion1(nt, ntfroms)

        # verb forward pass; verb text comes from the second latent half
        vtmu, vtlv = v_text_encoder(vt)
        vtz = reparameterize(vtmu, vtlv)
        vtout = v_text_decoder(vtz)

        vtfroms = v_text_decoder(sz[:, latent_size//2:])
        vt_recons = criterion1(vt, vtout)
        vt_kld = KL_divergence(vtmu, vtlv).to(device)
        vt_crecons = criterion1(vt, vtfroms)

        # sequence cross-reconstruction from the concatenated text latents
        sfromt = sequence_decoder(torch.cat([ntz, vtz], 1))
        s_crecons = criterion1(s, sfromt)

        # NOTE(review): KL terms are subtracted, which assumes
        # KL_divergence returns the negative KL (ELBO convention) — confirm.
        loss = s_recons + vt_recons + nt_recons
        loss -= k_fact*(s_kld) + v_k_fact2*(vt_kld) + n_k_fact2*(nt_kld)
        loss += n_cr_fact*(nt_crecons) + v_cr_fact*(vt_crecons) + cr_fact*(s_crecons)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), inputs.size(0))
        ce_loss_vals.append(loss.cpu().detach().numpy())
        if epoch % 100 == 0:
            print('---------------------')
            print('Epoch-{:<3d} \t'
                'loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                epoch, loss=losses))
            print('srecons {:.4f}\t ntrecons {:.4f}\t vtrecons {:.4f}\t'.format(s_recons.item(), nt_recons.item(), vt_recons.item()))
            print('skld {:.4f}\t ntkld {:.4f}\t vtkld {:.4f}\t'.format(s_kld.item(), nt_kld.item(), vt_kld.item()))
            # BUG FIX: the third field prints vt_crecons but was mislabelled
            # 'ntcrecons' (duplicated label) in the original log line.
            print('screcons {:.4f}\t ntcrecons {:.4f}\t vtcrecons {:.4f}\t'.format(s_crecons.item(), nt_crecons.item(), vt_crecons.item()))

    return