Exemplo n.º 1
0
def validate(train_loader, visae, attae, epoch):
    """Evaluate the visual/attribute autoencoder pair on seen classes.

    Both models are put in eval mode. For every ``(emb, target)`` batch the
    combined loss (hinge prediction + visual contractive reconstruction +
    attribute reconstruction + MMD) and the accuracy against the seen-class
    embeddings are accumulated. No optimizer step is taken.

    Autograd is left enabled for the visual branch because
    ``contractive_loss`` receives both ``vis_emb`` and ``vis_hidden``;
    presumably it differentiates the hidden code w.r.t. the input
    (TODO confirm against ``contractive_loss``).

    Args:
        train_loader: iterable yielding ``(emb, target)`` batches.
        visae: visual autoencoder whose forward returns
            ``(hidden, out, trans_out)``.
        attae: attribute autoencoder whose forward returns
            ``(hidden, out, trans_out)``; element 2 is used as the class
            embedding.
        epoch: epoch number, used only in the progress log.

    Returns:
        ``(average loss, average accuracy)`` taken from ``AverageMeter.avg``.
    """
    visae.eval()
    attae.eval()
    losses = AverageMeter()
    acces = AverageMeter()

    # The seen-class label input does not depend on the batch, so its
    # embeddings are computed once rather than per batch (assumes attae's
    # forward is deterministic in eval mode).
    with torch.no_grad():
        seen_class_embeddings = attae(seen_s2v_labels.to(device).float())[2]

    for i, (emb, target) in enumerate(train_loader):
        emb = emb.to(device)
        with torch.no_grad():
            vis_emb = torch.log(1 + emb)
        # vis_emb is a leaf tensor created under no_grad. Enable grad on it so
        # the contractive penalty can be computed. retain_grad() is not used:
        # it is a no-op on leaves and raises RuntimeError when called while
        # requires_grad is still False.
        vis_emb.requires_grad_(True)

        vis_hidden, vis_out, vis_trans_out = visae(vis_emb)
        vis_recons_loss = contractive_loss(vis_out,
                                           vis_emb,
                                           vis_hidden,
                                           criterion2,
                                           device,
                                           gamma=0.0001)

        att_emb = get_text_data(target, s2v_labels.float()).to(device)
        att_hidden, att_out, att_trans_out = attae(att_emb)
        att_recons_loss = criterion2(att_out, att_emb)

        # mmd loss between the visual and attribute hidden codes
        loss_mmd = MMDLoss(vis_hidden, att_hidden).to(device)

        # supervised binary prediction loss
        pred_loss = multi_class_hinge_loss(vis_trans_out, att_trans_out,
                                           target, seen_class_embeddings)

        loss = pred_loss + vis_recons_loss + att_recons_loss + loss_mmd

        acc, _ = accuracy(seen_class_embeddings, vis_trans_out.detach(),
                          target.to(device), seen_inds)
        losses.update(loss.item(), emb.size(0))
        acces.update(acc, emb.size(0))
        if i % 20 == 0:
            print('Epoch-{:<3d} {:3d}/{:3d} batches\t'
                  'loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'accu {acc.val:.3f} ({acc.avg:.3f})\t'.format(
                      epoch,
                      i,
                      len(train_loader),
                      loss=losses,
                      acc=acces))
            print('Vis Reconstruction Loss {:.4f}\t'
                  'Att Reconstruction Loss {:.4f}\t'
                  'MMD Loss {:.4f}\t'
                  'Supervised Binary Prediction Loss {:.4f}'.format(
                      vis_recons_loss.item(), att_recons_loss.item(),
                      loss_mmd.item(), pred_loss.item()))

    return losses.avg, acces.avg
Exemplo n.º 2
0
def zsl_validate(val_loader, visae, attae, epoch):
    """Evaluate zero-shot performance on unseen classes with autograd off.

    Both models are put in eval mode. For every ``(emb, target)`` batch the
    combined loss (hinge prediction + visual reconstruction + attribute
    reconstruction + MMD) and the accuracy against the unseen-class
    embeddings are accumulated.

    Unlike ``validate``, the visual reconstruction term here is plain
    ``criterion2`` with no contractive penalty, so the two functions' loss
    values are not directly comparable.

    Args:
        val_loader: iterable yielding ``(emb, target)`` batches.
        visae: visual autoencoder whose forward returns
            ``(hidden, out, trans_out)``.
        attae: attribute autoencoder whose forward returns
            ``(hidden, out, trans_out)``; element 2 is used as the class
            embedding.
        epoch: epoch number, used only in the progress log.

    Returns:
        ``(average loss, average accuracy)`` taken from ``AverageMeter.avg``.
    """
    visae.eval()
    attae.eval()
    losses = AverageMeter()
    acces = AverageMeter()

    with torch.no_grad():
        # Unseen-class embeddings do not depend on the batch: compute once.
        class_embeddings = attae(unseen_s2v_labels.to(device).float())[2]
        for i, (emb, target) in enumerate(val_loader):
            emb = emb.to(device)
            vis_emb = torch.log(1 + emb)

            vis_hidden, vis_out, vis_trans_out = visae(vis_emb)
            vis_recons_loss = criterion2(vis_out, vis_emb)

            att_emb = get_text_data(target, s2v_labels.float()).to(device)
            att_hidden, att_out, att_trans_out = attae(att_emb)
            att_recons_loss = criterion2(att_out, att_emb)

            # mmd loss between the visual and attribute hidden codes
            loss_mmd = MMDLoss(vis_hidden, att_hidden).to(device)

            # supervised binary prediction loss
            pred_loss = multi_class_hinge_loss(vis_trans_out, att_trans_out,
                                               target, class_embeddings)

            loss = pred_loss + vis_recons_loss + att_recons_loss + loss_mmd
            acc, _ = accuracy(class_embeddings, vis_trans_out,
                              target.to(device), unseen_inds)
            losses.update(loss.item(), emb.size(0))
            acces.update(acc, emb.size(0))

            if i % 20 == 0:
                print('ZSL Validation Epoch-{:<3d} {:3d}/{:3d} batches \t'
                      'loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'accu {acc.val:.3f} ({acc.avg:.3f})\t'.format(
                          epoch, i, len(val_loader), loss=losses, acc=acces))
                print('Vis Reconstruction Loss {:.4f}\t'
                      'Att Reconstruction Loss {:.4f}\t'
                      'MMD Loss {:.4f}\t'
                      'Supervised Binary Prediction Loss {:.4f}'.format(
                          vis_recons_loss.item(), att_recons_loss.item(),
                          loss_mmd.item(), pred_loss.item()))

    return losses.avg, acces.avg