def forward(self, data_shot, data_query):
        proto = self.encoder(data_shot)
        if self.args.hyperbolic:
            proto = self.e2p(proto)

            if self.training:
                proto = proto.reshape(self.args.shot, self.args.way, -1)
            else:
                proto = proto.reshape(self.args.shot, self.args.validation_way,
                                      -1)

            proto = poincare_mean(proto, dim=0, c=self.e2p.c)
            data_query = self.e2p(self.encoder(data_query))
            logits = -dist_matrix(data_query, proto,
                                  c=self.e2p.c) / self.args.temperature

        else:
            if self.training:
                proto = proto.reshape(self.args.shot, self.args.way,
                                      -1).mean(dim=0)
            else:
                proto = proto.reshape(self.args.shot, self.args.validation_way,
                                      -1).mean(dim=0)

            logits = euclidean_metric(self.encoder(data_query),
                                      proto) / self.args.temperature
        return logits
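Every snippet in this collection scores query embeddings against class prototypes with a shared euclidean_metric helper. Below is a minimal sketch of the usual prototypical-network form of that helper (negative squared Euclidean distance, so a larger logit means a closer prototype); the exact version in each repository may differ slightly, and the hyperbolic branch above relies on analogous poincare_mean / dist_matrix helpers from a hyperbolic-embeddings library instead.

import torch

def euclidean_metric(a, b):
    # a: [n_query, dim] query embeddings, b: [n_way, dim] class prototypes.
    n, m = a.shape[0], b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    # Negative squared distance: larger logit = closer to the prototype.
    return -((a - b) ** 2).sum(dim=2)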
def fold(model, args, meta_support, true_labels):
    p = args.shot * args.train_way
    q = args.query * args.train_way
    # Assumed initialization (not shown in the original fragment): prototypes
    # from the first fold and one-hot support labels for label mixing.
    meta_proto = model(meta_support[:p]).reshape(args.shot, args.train_way,
                                                 -1).mean(dim=0)
    s_onehot = F.one_hot(true_labels, args.train_way).float()
    for j in range(1, args.folds):
        next_proto = model(meta_support[p * j:p * (j + 1)])
        meta_logits = euclidean_metric(next_proto, meta_proto)
        soft_labels = (F.softmax(meta_logits, dim=1) +
                       args.lam * s_onehot) / (1 + args.lam)
        meta_proto = torch.mm(soft_labels.permute((1, 0)), next_proto)
    return meta_proto
def inter_fold(model, args, meta_support):
    p = args.shot * args.train_way
    q = args.query * args.train_way
    meta_protos = []
    features = model(meta_support)
    for i in range(1, args.folds):
        # Leave example i out and build prototypes from the remaining features.
        proto = torch.cat([features[0:i], features[i + 1:]])
        proto = proto.reshape(args.shot - 1, args.train_way, -1).mean(dim=0)
        current_ex = features[i:i + 1]  # keep a 2-D [1, dim] view for the metric
        meta_logits = euclidean_metric(current_ex, proto)
        soft_labels = F.softmax(meta_logits, dim=1)
        # (+ args.lam * true_labels) / (1 + args.lam)
        meta_proto = torch.mm(soft_labels.permute((1, 0)), current_ex)
        meta_protos.append(meta_proto)
    return torch.stack(meta_protos)
Example #4
    def forward(self, data):

        shot = self.args.shot
        if self.training:
            ways = self.args.train_way
        else:
            ways = self.args.test_way

        nk = shot * ways
        x, data_query = data[:nk], data[nk:]

        x = self.base_model(x)
        if self.reshaper is not None:
            proto = self.reshaper(x)
        else:
            proto = x
        proto = proto.view(shot, ways, -1).mean(0)

        if self.args.base_model.startswith('resnet'):
            c, w, h = 64, 6, 6
        else:
            c, w, h = 64, 5, 5
        concentrated = self.concentrator(x)
        concentrated = concentrated.view(shot, ways, c, w, h).mean(0)
        stacked = concentrated.view(-1, w, h).unsqueeze(0)
        mask = self.projector(stacked)
        mask = F.softmax(mask, dim=1)
        mask = mask.view(1, -1)
        proto = torch.mul(proto, mask)

        # query = self.reshaper(self.base_model(data_query))
        query = self.base_model(data_query)
        if self.reshaper is not None:
            query = self.reshaper(query)
        query = query.view(query.size(0), -1)
        query = torch.mul(query, mask)

        # print(mask.shape, proto.shape, query.shape)

        logits = euclidean_metric(query, proto)

        return logits
Example #5
    def forward(self, data):
        shot = self.args.shot
        if self.training:
            ways = self.args.train_way
        else:
            ways = self.args.test_way

        nk = shot * ways
        x, data_query = data[:nk], data[nk:]

        x = self.base_model(x)
        if self.reshaper is not None:
            x = self.reshaper(x)
        proto = x.view(shot, ways, -1).mean(0)

        query = self.base_model(data_query)
        if self.reshaper is not None:
            query = self.reshaper(query)
        query = query.view(query.size(0), -1)

        logits = euclidean_metric(query, proto)

        return logits
Example #6
def main(args):
    device = torch.device(args.device)
    ensure_path(args.save_path)

    data = Data(args.dataset, args.n_batches, args.train_way, args.test_way, args.shot, args.query)
    train_loader = data.train_loader
    val_loader = data.valid_loader

    model = Convnet(x_dim=2).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    def save_model(name):
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
    
    trlog = dict(
        args=vars(args),
        train_loss=[],
        val_loss=[],
        train_acc=[],
        val_acc=[],
        max_acc=0.0,
    )

    timer = Timer()

    for epoch in range(1, args.max_epoch + 1):
        lr_scheduler.step()

        model.train()

        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            data, _ = [_.to(device) for _ in batch]
            data = data.reshape(-1, 2, 105, 105)
            p = args.shot * args.train_way
            embedded = model(data)
            embedded_shot, embedded_query = embedded[:p], embedded[p:]

            proto = embedded_shot.reshape(args.shot, args.train_way, -1).mean(dim=0)

            label = torch.arange(args.train_way).repeat(args.query).to(device)

            logits = euclidean_metric(embedded_query, proto)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'
                  .format(epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tl = tl.item()
        ta = ta.item()

        model.eval()

        vl = Averager()
        va = Averager()

        for i, batch in enumerate(val_loader, 1):
            data, _ = [_.to(device) for _ in batch]
            data = data.reshape(-1, 2, 105, 105)
            p = args.shot * args.test_way
            data_shot, data_query = data[:p], data[p:]

            proto = model(data_shot)
            proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)

            label = torch.arange(args.test_way).repeat(args.query).to(device)

            logits = euclidean_metric(model(data_query), proto)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)

            vl.add(loss.item())
            va.add(acc)

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['train_loss'].append(tl)
        trlog['train_acc'].append(ta)
        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
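The training and evaluation loops in these examples also lean on small bookkeeping helpers such as Averager and count_acc. The sketch below shows plausible minimal versions, assuming Averager keeps a running mean of scalar batch statistics and count_acc is plain top-1 accuracy; the actual helpers in the source repositories may carry extra state.

import torch

class Averager:
    # Running mean of scalar values (per-batch loss or accuracy).
    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v

def count_acc(logits, label):
    # Fraction of queries whose highest logit matches the episode label.
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()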
Example #7
            p = args.shot * args.train_way
            data_spt, data_query = data[:p], data[p:]
            img_mask_spt, img_mask_qry = img_mask[:p], img_mask[p:]
            img_mask_mo_spt, img_mask_mo_qry = img_mask_mo[:p], img_mask_mo[p:]

            #query set include img and img_mask
            # merge_qry = torch.cat([data_query,img_mask_qry],0)

            label = torch.arange(args.train_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)
            #proto = img proto + img_mask proto
            proto = model(img_mask_spt)
            proto2 = model(img_mask_mo_spt)
            proto = (proto + proto2) / 2
            proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
            logits = euclidean_metric(model(img_mask_qry), proto)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            print('epoch {}, train image {}/{}, loss={:.4f} acc={:.4f}'.format(
                epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            proto = None
            logits = None
            loss = None
Example #8
    base_net.train()
    for i, batch in enumerate(train_loader, 1):
        inputs_ori, inputs_aug = batch[0], batch[1]
        # [100, 3, 80, 80], [100, 3, 80, 80]
        inputs_aug = inputs_aug[:n_shot] # [25, 3, 80, 80]
        inputs = torch.cat([inputs_ori, inputs_aug], 0).cuda()
        # [125, 3, 80, 80]

        (fea_ori, fea_aug) = torch.split(base_net(inputs), n_epi, 0)
        # [100, 640], [25, 640]
        fea_shot, fea_query = fea_ori[:n_shot], fea_ori[n_shot:]
        # [25, 640], [75, 640]

        proto = fea_shot.reshape(args.shot, args.way, -1).mean(dim=0) # [5, 640]
        proto_aug = fea_aug.reshape(args.shot, args.way, -1).mean(dim=0) # [5, 640]
        logits = euclidean_metric(fea_query, proto) # [75, 5]
        logits_aug = euclidean_metric(fea_query, proto_aug) # [75, 5]

        # fsl loss
        fsl_loss = ce_loss(logits, label)
        fsl_acc = count_acc(logits, label)

        # align
        if args.lambda_align > 0:
            probs1 = F.softmax(logits.detach(), 1)
            probs2 = F.softmax(logits_aug.detach(), 1)
            log_probs1 = F.log_softmax(logits / args.T, 1)
            log_probs2 = F.log_softmax(logits_aug / args.T, 1)
            align_loss = args.T * (kl_loss(log_probs2, probs1) + kl_loss(log_probs1, probs2))
            total_loss = fsl_loss + args.lambda_align * align_loss
        else:
            total_loss = fsl_loss  # assumed: no alignment term when lambda_align == 0
Example #9
    ave_acc = Averager()

    for i, batch in enumerate(loader, 1):
        data, img_mask, img_mask_mo, _ = [_.cuda() for _ in batch]
        k = args.way * args.shot

        data_spt, data_query = data[:k], data[k:]
        img_mask_spt, img_mask_qry = img_mask[:k], img_mask[k:]
        img_mask_mo_spt, img_mask_mo_qry = img_mask_mo[:k], img_mask_mo[k:]
        
        
        proto = model(img_mask_spt)
        proto2 = model(img_mask_mo_spt)
        proto = (proto + proto2) / 2
        proto = proto.reshape(args.shot, args.way, -1).mean(dim=0)
        p = proto

        logits = euclidean_metric(model(img_mask_mo_qry), p)

        label = torch.arange(args.way).repeat(args.query)
        label = label.type(torch.cuda.LongTensor)

        acc = count_acc(logits, label)
        ave_acc.add(acc)
        print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))
        
        proto = None; p = None; logits = None
    
Example #10
def do_train_pass(train_batches_shot,
                  train_batches_query,
                  train_batches_labels,
                  shot,
                  way,
                  query,
                  expressions,
                  train,
                  test,
                  id_to_token=None,
                  id_to_tag=None,
                  tag_to_id=None,
                  test_cls=None):
    model, optimizer = expressions
    llog, alog = Averager(), Averager()

    for i, (batch_shot, batch_query, batch_labels) in enumerate(
            zip(train_batches_shot, train_batches_query, train_batches_labels),
            1):

        flog = Aggregate_F()
        flog_old = Aggregate_F()
        data_token_shot = [x for _, _, x, _, _, _, _ in batch_shot]
        data_sentence_shot = [
            sent[sent_id] for sent, _, _, _, _, sent_id, _ in batch_shot
        ]
        data_sentence_labels_shot = [
            label[sent_id] for _, label, _, _, _, sent_id, _ in batch_shot
        ]
        data_sentence_bert_shot = [
            bert_emb[sent_id]
            for _, _, _, bert_emb, _, sent_id, _ in batch_shot
        ]

        data_sentence_labels_shot = [[
            int(batch_labels[token]) for token in sent
        ] for sent in data_sentence_labels_shot]

        #print(data_sentence_labels_shot)

        (data_sentence_shot, data_sentence_labels_shot,
         data_sentence_bert_shot,
         sentence_shot_lens) = pad_query_sentences(data_sentence_shot,
                                                   data_sentence_labels_shot,
                                                   data_sentence_bert_shot,
                                                   MAX_SENT_LEN,
                                                   PAD_CLS=way + 1)

        #data_sentence_labels_shot = [[int(batch_labels[token]) for token in sent] for sent in data_sentence_labels_shot]
        #print(data_sentence_labels_shot)
        #exit()

        proto = model(data_sentence_shot,
                      data_token_shot,
                      data_sentence_bert_shot,
                      sentence_shot_lens,
                      shot=True)

        sorted_batch_labels = sorted(batch_labels.items(),
                                     key=lambda kv: (kv[1], kv[0]))
        ###print(sorted_batch_labels)

        #start with the zero!!
        zero_indices = np.argwhere(data_sentence_labels_shot == 0)
        ###print(zero_indices)
        ###print(zero_indices.size())
        old_proto = proto[zero_indices[0], zero_indices[1]]
        ###print(proto.size())
        ###print(old_proto.size())
        new_proto = model.return_attn()(old_proto)
        #print(model.return_attn())
        #print(new_proto.size())
        weights = F.softmax(new_proto, dim=0)
        #print(weights)
        #print(weights.size())
        ###print(weights.size())
        new_proto = (weights * old_proto).sum(dim=0, keepdim=True)
        #print(new_proto)
        ###print(new_proto.size())
        # exit()
        #print(sorted_batch_labels)

        #Aggregate all the values of the same label and then take mean!!
        for (key, val) in sorted_batch_labels:
            if val != 0:
                val_indices = np.argwhere(data_sentence_labels_shot == val)
                #print(val)
                #print(val_indices.size())
                new_proto = torch.cat([
                    new_proto, proto[val_indices[0], val_indices[1]].mean(
                        dim=0, keepdim=True)
                ],
                                      dim=0)

        ###print(new_proto.size())
        ###exit()

        ###proto = proto.reshape(shot, way-1, -1).mean(dim=0)

        ###dim_size = proto.size()[1]

        ###proto = torch.cat([torch.zeros(1, dim_size).to(device), proto])

        data_token_query = [x for _, _, x, _, _, _, _ in batch_query]
        data_sentence_query = [
            sent[sent_id] for sent, _, _, _, _, sent_id, _ in batch_query
        ]
        data_sentence_labels_query = [
            label[sent_id] for _, label, _, _, _, sent_id, _ in batch_query
        ]
        data_sentence_bert_query = [
            bert_emb[sent_id]
            for _, _, _, bert_emb, _, sent_id, _ in batch_query
        ]
        '''zero_indices = np.argwhere(np.array(batch_labels) == 0)        
        nonzero_indices_part = np.argwhere(np.array(batch_labels) > 0)
        nonzero_indices = []

        for i in range(int(len(zero_indices)/ float(query))):
            nonzero_indices += [ind[0] for ind in nonzero_indices_part]'''

        #zero_indices = np.argwhere(np.array(batch_labels) == 0)
        #batch_labels = [batch_labels[ind] for ind in nonzero_indices] + [batch_labels[ind[0]] for ind in zero_indices]
        #data_token_query = [data_token_query[ind] for ind in nonzero_indices] + [data_token_query[ind[0]] for ind in zero_indices]
        #data_sentence_query = [data_sentence_query[ind] for ind in nonzero_indices] + [data_sentence_query[ind[0]] for ind in zero_indices]
        ##batch_labels = [label-1 for label in batch_labels]
        '''count6 = 0
        count7 = 0
        for sent_label in data_sentence_labels_query:
            for token in sent_label:
                if token == 6 :
                    count6+=1
                if token == 7:
                    count7+=1
        print("Count of 6\t"+str(count6)+"\tCount of 7\t"+str(count7))'''

        data_sentence_labels_query = [[
            int(batch_labels[token]) for token in sent
        ] for sent in data_sentence_labels_query]
        #print(batch_labels)

        #print(np.argwhere(data_sentence_labels_query == np.array(batch_labels).any()))

        #labels = torch.LongTensor(np.array(batch_labels)).to(device)

        (data_sentence_query, labels, data_sentence_bert_query,
         sentence_query_lens) = pad_query_sentences(data_sentence_query,
                                                    data_sentence_labels_query,
                                                    data_sentence_bert_query,
                                                    MAX_SENT_LEN,
                                                    PAD_CLS=way + 1)

        query_matrix = model(data_sentence_query, data_token_query,
                             data_sentence_bert_query, sentence_query_lens)

        for vec, index, sentence in zip(query_matrix, data_token_query,
                                        data_sentence_query):
            if vec.sum() == 0.:
                print(index[0])
                for token in sentence:
                    token = token.to('cpu').item()
                    if token == "__PAD__":
                        continue
                    print(token)
                    #print(id_to_token)
                    print(id_to_token[int(token)])
                print("Finally-------------------------")
                print(id_to_token[sentence[index[0]].to('cpu').item()])
                print("---")

        logits = euclidean_metric(query_matrix, new_proto)
        #print("Training_logits\t")
        #print(logits)
        #print(logits.size())

        logits[:, :, 0] = model.return_0class()
        ###print(logits)

        softmax_scores = F.softmax(logits, dim=2)
        #print(softmax_scores)

        #labels = torch.LongTensor(np.array(data_sentence_labels_query)).to(device)
        logits_t = logits.transpose(2, 1)
        #print(logits_t.size())
        #print(labels.size())
        #exit()

        loss_function = torch.nn.CrossEntropyLoss(ignore_index=way + 1)
        loss = loss_function(logits_t, labels)
        #loss = F.cross_entropy(logits_t, labels)
        llog.add(loss.item())

        correct, total_preds, total_gold, confidence = count_F(
            softmax_scores,
            labels,
            batch_labels.values(),
            train=True,
            PAD_CLS=way + 1,
            id_to_tag=id_to_tag)

        flog.add(correct, total_preds, total_gold)
        item1, item2, item3 = flog.item()
        print("Correct")
        print(item1)
        print("Predicted")
        print(item2)
        print("Gold")
        print(item3)

        f_score1 = flog.f_score()
        print(f_score1)

        print("Token-level accuracy----")
        correct, total_preds, total_gold, confidence = count_F_old(
            softmax_scores,
            labels,
            batch_labels.values(),
            train=True,
            PAD_CLS=way + 1)

        flog_old.add(correct, total_preds, total_gold)
        item1, item2, item3 = flog_old.item()
        print("Correct")
        print(item1)
        print("Predicted")
        print(item2)
        print("Gold")
        print(item3)

        f_score1_old = flog_old.f_score()
        print(f_score1_old)

        #exit()

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return llog, flog
Example #11
                                            -1).mean(dim=0)
            proto = model(data_shot)
            proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
            #meta_logits = euclidean_metric(proto, meta_proto)
            #lam = .01
            lam = .5
            lam_labs = 1
            #soft_labels = ((F.sigmoid(meta_logits)) + lam*s_onehot) / (1 + lam)
            #soft_labels = F.softmax(meta_logits, dim=1) #* lam + s_onehot) / (1 + lam)
            #soft_labels = soft_labels / soft_labels.sum(dim=0)
            #proto = torch.mm(soft_labels.permute((1, 0)), proto)

            # proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
            #label = torch.arange(args.train_way).repeat(args.query)
            #label = label.type(torch.cuda.LongTensor)
            logits = euclidean_metric(model(udata), proto)
            logits2 = euclidean_metric(model(udata), meta_proto)
            #loss = F.binary_cross_entropy_with_logits(logits, q_onehot)
            #loss = lam_labs*F.cross_entropy(euclidean_metric(model(data_query), proto), label) + lam_labs*F.cross_entropy(logits2, label) + lam * F.kl_div(F.log_softmax(logits, dim=1), F.softmax(logits2, dim=1))
            loss = F.kl_div(F.log_softmax(logits, dim=1),
                            F.softmax(logits2, dim=1))
            acc = count_acc(logits, label)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
                epoch, i, len(ss_loader), loss.item(), acc))
            tl.add(loss.item())
            ta.add(acc)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            proto = None
            logits = None
Example #12
    path = 'videos/trump.mp4'
    cap = cv2.VideoCapture(path)
    cv2.namedWindow('demo_video_haar')

    while (cap.isOpened()):

        ret, frame = cap.read()
        frame = cv2.resize(frame, (640, 360))

        target = None
        target = realtime.real_time_detect_haar(frame)

        if target is not None:
            data_query = data[k:-1]
            data_query = torch.cat([data_query, target])
            logits = euclidean_metric(model(data_query), p)  # [20, 20]

            max_dis_idx = torch.argmax(logits, dim=1)[-1]
            max_dis = logits[-1][max_dis_idx]

            sum_dis = logits[-1].sum()
            other_avr_dis = (sum_dis - logits[-1][max_dis_idx]) / (args.way -
                                                                   1)
            # print('max distance: {}, other distance average: {}'.format(
            # max_dis, other_avr_dis))
            if (max_dis - other_avr_dis > 20):
                # if(not isUnknown):
                text = samplers.ground_truth[max_dis_idx]
            else:
                text = 'Unknown'
            cv2.putText(frame, text, (10, 340), cv2.FONT_HERSHEY_SIMPLEX, 1,
Example #13
    def log(out_str):
        print(out_str)
        logfile.write(out_str+'\n')
        logfile.flush()

    model_cnn.eval()
    for epoch in range(1, args.max_epoch + 1):

        for i, batch in enumerate(val_loader, 1):
            data, lab = [_.cuda() for _ in batch]

            data_shot = data[:, 3:, :]
            proto = model_cnn(data_shot)
            global_set = torch.cat([global_base[0], global_novel[0]])
            logits = euclidean_metric(proto, global_set)
            loss = F.cross_entropy(logits, lab)
            acc = count_acc(logits, lab)

            vl.add(loss.item())
            va.add(acc)
            proto = None; logits = None; loss = None

        vl = vl.item()
        va = va.item()
        log('both epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        vl = Averager()
        va = Averager()

        for i, batch in enumerate(val_loader2, 1):
Example #14
        lr_scheduler.step()

        model.train()

        tl = Averager()
        ta = Averager()
        lam = .01
        for i, batch in enumerate(train_loader, 1):
            data, _ = [_.cuda() for _ in batch]
            p = args.shot * args.train_way
            q = args.query * args.train_way
            support, data_query = data[:p], data[p:]
            features = model(support)
            proto = features.reshape(args.shot, args.train_way, -1).mean(dim=0)

            meta_logits = euclidean_metric(features, proto)
            soft_labels = (torch.sigmoid(meta_logits) +
                           lam * s_onehot) / (1 + lam)
            #soft_labels = soft_labels / soft_labels.sum(dim=0)
            meta_proto = torch.mm(soft_labels.permute((1, 0)), proto)
            proto.retain_grad()

            # soft_labels = (F.softmax(meta_logits, dim=1))  # * lam + s_onehot) / (1 + lam)
            # soft_labels = soft_labels / soft_labels.sum(dim=0)
            # proto = torch.mm(soft_labels.permute((1, 0)), proto)
            # proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)

            label = torch.arange(args.train_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)
            logits = euclidean_metric(model(data_query), meta_proto)
            loss = F.cross_entropy(logits, label)
Example #15
            test_shot, test_query = test_data[:p], test_data[p:]
            for i, batch in enumerate(train_loader, 1):
                data, _ = [_.cuda() for _ in batch]
                p = args.shot * args.train_way
                train_shot, train_query = data[:p], data[p:]
                proto = model(train_shot, vars=None, bn_training=True)
                unsup_logits = classifier(proto)

                soft_labels = F.softmax(unsup_logits, dim=1)
                soft_labels = soft_labels / soft_labels.sum(dim=0)
                proto = torch.mm(soft_labels.permute((1, 0)), proto)

                label = torch.arange(args.train_way).repeat(args.query)
                label = label.type(torch.cuda.LongTensor)

                logits = euclidean_metric(model(train_query), proto)
                loss = F.cross_entropy(logits, label)
                grad = torch.autograd.grad(loss, model.parameters())

                fast_weights = list(map(lambda p: p[1] - update_lr * p[0], zip(grad, model.parameters())))

            acc = count_acc(logits, label)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'
                  .format(epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Example #16
def do_pass(batches,
            counters,
            shot,
            way,
            query,
            expressions,
            train,
            test,
            id_to_token=None,
            id_to_tag=None,
            test_cls=None):
    model, optimizer = expressions
    llog, alog = Averager(), Averager()

    if test:
        output_file = open("./output.txt" + str(test_cls), 'w')

    for i, (batch, counter) in enumerate(zip(batches, counters), 1):
        #print("Batch number\t"+str(i))
        data_token = [x for _, x, _, _ in batch]
        data_sentence = [sent for sent, _, _, _ in batch]
        data_label = [label for _, _, label, _ in batch]
        p = shot * way
        #print(len(data_token))
        #print(p)
        #print(shot)
        #print(way)
        data_token_shot, data_token_query = data_token[:p], data_token[p:]
        data_sentence_shot, data_sentence_query = data_sentence[:p], data_sentence[p:]
        counter_token, counter_query = counter[:p], counter[p:]

        (data_sentence_shot,
         sentence_shot_lens), (data_sentence_query,
                               query_shot_lens) = pad_sentences(
                                   data_sentence_shot,
                                   MAX_SENT_LEN), pad_sentences(
                                       data_sentence_query, MAX_SENT_LEN)

        proto = model(data_sentence_shot, data_token_shot, sentence_shot_lens)
        proto = proto.reshape(shot, way, -1).mean(dim=0)

        ####label = torch.arange(way).repeat(query)
        if not train:
            #print(len(data_token))
            #print(p)
            #print(way)
            query = int((len(data_token) - p) / way)
            #print(query)
            #exit()

        label = torch.arange(way).repeat(query)
        label = label.type(torch.LongTensor).to(device)

        logits = euclidean_metric(
            model(data_sentence_query, data_token_query, query_shot_lens),
            proto)

        #print(list(model.parameters()))
        #print(model.return_0class())

        #print(logits.size())
        logits[:, 0] = model.return_0class()
        #print(logits.size())
        #print(label.size())
        #print(len(counter_query))
        #print(counter_query)
        #print("---")

        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label, counter_query)

        llog.add(loss.item())
        alog.add(acc)

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if test:
            #print the outputs to a file
            save_dev_output(output_file, logits, label, data_label,
                            data_sentence_query, data_token_query,
                            query_shot_lens, id_to_token, id_to_tag)

    if test:
        output_file.close()
    return llog, alog
Example #17
            data, udata = batch  #[_.cuda() for _ in batch]
            data = data.cuda()
            p = args.shot * args.train_way
            m = args.meta_size * args.train_way
            q = args.query * args.train_way
            data_shot, data_shot2, data_query = (data[:p], data[p:2 * p],
                                                 data[2 * p:2 * p + q])
            proto = model(data_shot)
            proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
            proto2 = model(data_shot2)
            proto2 = proto2.reshape(args.shot, args.train_way, -1).mean(dim=0)
            label = torch.arange(args.train_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)
            logits = euclidean_metric(model(data_query), proto)
            #loss = F.binary_cross_entropy_with_logits(logits, q_onehot)

            #load unsupervised
            q = args.uquery * args.train_way
            udata_query = udata
            udata_query = udata_query.cuda()
            ulogits1 = euclidean_metric(model(udata_query), proto)
            ulogits2 = euclidean_metric(model(udata_query), proto2)
            agree = torch.argmax(ulogits1, dim=1) == torch.argmax(ulogits2,
                                                                  dim=1)
            #loss_const = F.kl_div(F.log_softmax(ulogits1, dim=1), F.softmax(ulogits2, dim=1))
            feats = model(data_query)
            loss = F.cross_entropy(
                euclidean_metric(feats, proto), label
            ) + F.cross_entropy(
Example #18
 fea_all = torch.cat((features, features_aug1, features_aug2, features_aug3), dim=0)
 fea_all2 = torch.cat((features, features_aug2, features_aug3, features_aug1), dim=0)
 fea_all = fea_all.transpose(0, 1)
 fea_all2 = fea_all2.transpose(0, 1)
 fea_all = attn_net(fea_all, fea_all, fea_all)
 fea_all2 = attn_net(fea_all2, fea_all2, fea_all2)
 fea = fea_all.reshape(n_all, -1)
 fea2 = fea_all2.reshape(n_all, -1)
 if args.mix == 1:
     fea_shot, fea_query = fea[:n_shot], fea2[n_shot:]
 else:
     fea_shot, fea_query = fea[:n_shot], fea[n_shot:]
 # fea_shot: [25, 640]
 # fea_query: [75, 640]
 proto = fea_shot.reshape(args.shot, args.way, -1).mean(dim = 0) # [5, 640]
 logits = euclidean_metric(fea_query, proto)/args.temperature #[75, 5]
 fsl_loss = ce_loss(logits,label_fsl_s)
 acc = count_acc(logits, label_fsl_s)
 #con_loss
 con_loss = 0
 if args.lambda_con > 0:
     similarity_f = nn.CosineSimilarity()
     if args.proj == 1:
         fea = proj_net(fea)
     fea_shot, fea_query = fea[:n_shot], fea[n_shot:]
     fea_query2 = fea2[n_shot:]
     proto = fea_shot.reshape(args.shot, args.way, -1).mean(dim = 0) # [5, 640]
     ind = torch.arange(args.query)
     for index in range(args.way):
         p = proto[index].unsqueeze(0).repeat(args.way*args.query,1)#[75, 640]
         s = similarity_f(p,fea_query)/args.T#[75]
Example #19
                weight_arr = weight_arr / np.sum(weight_arr)
                proto_novel_f = (torch.from_numpy(weight_arr.reshape(
                    -1, 1, 1)).type(torch.float).cuda() *
                                 proto_novel_f).sum(dim=0)
                proto_base = proto_base.mean(dim=0)
                proto_final = torch.cat([proto_base, proto_novel_f], 0)
            else:
                proto_final = proto.reshape(args.shot, args.train_way,
                                            -1).mean(dim=0)

            label = torch.arange(args.train_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)
            global_new, proto_new = model_reg(support_set=torch.cat(
                [global_base[0], global_novel[0]]),
                                              query_set=proto_final)
            logits2 = euclidean_metric(proto_new, global_new)
            loss2 = F.cross_entropy(logits2, train_gt)

            similarity = F.softmax(logits2, dim=1)
            feature = torch.matmul(
                similarity, torch.cat([global_base[0], global_novel[0]]))
            logits = euclidean_metric(model_cnn(data_query), feature)
            loss1 = F.cross_entropy(logits, label)

            acc1 = count_acc(logits, label)
            acc2 = count_acc(similarity, train_gt)

            tl1.add(loss1.item())
            tl2.add(loss2.item())
            ta1.add(acc1)
            ta2.add(acc2)
Example #20
def do_test_pass(test_batches_shot,
                 test_batches_query,
                 test_batches_labels,
                 shot,
                 way,
                 query,
                 expressions,
                 train,
                 test,
                 id_to_token=None,
                 id_to_tag=None,
                 test_cls=None):
    model, optimizer = expressions
    llog, alog = Averager(), Averager()

    if test:
        output_file = open("./output.txt" + str(test_cls), 'w')

    for i, (batch_shot, batch_query, batch_labels) in enumerate(
            zip(test_batches_shot, test_batches_query, test_batches_labels),
            1):

        #print(batch_labels.keys())

        #batch_query = batch_shot

        data_token_shot = [x for _, _, x, _, _, _, _ in batch_shot]
        data_sentence_shot = [
            sent[sent_ind] for sent, _, _, _, _, sent_ind, _ in batch_shot
        ]
        data_sentence_labels_shot = [
            label[sent_ind] for _, label, _, _, _, sent_ind, _ in batch_shot
        ]
        data_sentence_bert_shot = [
            bert_emb[sent_ind]
            for _, _, _, bert_emb, _, sent_ind, _ in batch_shot
        ]

        #(data_sentence_shot, sentence_shot_lens)= pad_sentences(data_sentence_shot, MAX_SENT_LEN)

        data_sentence_labels_shot = [[
            int(batch_labels[token]) for token in sent
        ] for sent in data_sentence_labels_shot]

        (data_sentence_shot, data_sentence_labels_shot,
         data_sentence_bert_shot,
         sentence_shot_lens) = pad_query_sentences(data_sentence_shot,
                                                   data_sentence_labels_shot,
                                                   data_sentence_bert_shot,
                                                   MAX_SENT_LEN,
                                                   PAD_CLS=way + 1)

        sorted_batch_labels = sorted(batch_labels.items(),
                                     key=lambda kv: (kv[1], kv[0]))

        zero_indices = np.argwhere(data_sentence_labels_shot == 0)

        proto = model(data_sentence_shot,
                      data_token_shot,
                      data_sentence_bert_shot,
                      sentence_shot_lens,
                      shot=True)

        old_proto = proto[zero_indices[0], zero_indices[1]]

        new_proto = model.return_attn()(old_proto)
        weights = F.softmax(new_proto, dim=0)
        new_proto = (weights * old_proto).sum(dim=0, keepdim=True)

        for (key, val) in sorted_batch_labels:
            if val != 0:
                val_indices = np.argwhere(data_sentence_labels_shot == val)
                #print(val)
                #print(val_indices.size())
                new_proto = torch.cat([
                    new_proto, proto[val_indices[0], val_indices[1]].mean(
                        dim=0, keepdim=True)
                ],
                                      dim=0)

        #proto = model(data_sentence_shot, data_token_shot, sentence_shot_lens, shot=True)
        #proto = proto.reshape(shot, way-1, -1).mean(dim=0)

        #dim_size = proto.size()[1]

        #proto = torch.cat([torch.zeros(1, dim_size).to(device), proto])

        #batch_query = batch_shot

        data_token_query = [x for _, _, x, _, _, _, _ in batch_query]
        data_sentence_query = [
            sent[ind] for sent, _, _, _, _, ind, _ in batch_query
        ]
        data_sentence_labels_query = [
            label[ind] for _, label, _, _, _, ind, _ in batch_query
        ]
        data_sentence_bert_query = [
            bert_emb[ind] for _, _, _, bert_emb, _, ind, _ in batch_query
        ]

        data_sentence_labels_query = [[
            int(batch_labels[token]) for token in sent
        ] for sent in data_sentence_labels_query]

        #print(len(data_sentence_query))
        #print(len(data_sentence_labels_query))

        #batch_index = 0
        block_size = 1000

        flog = Aggregate_F()
        flog_old = Aggregate_F()
        #confidence = []

        for batch_index in np.arange(0, len(data_sentence_labels_query),
                                     block_size):
            #print(batch_index)
            #if(batch_index > 0):
            #    continue
            ##exit()
            mini_data_token_query = data_token_query[batch_index:batch_index +
                                                     block_size]
            mini_data_sentence_query = data_sentence_query[
                batch_index:batch_index + block_size]
            mini_data_sentences_labels_query = data_sentence_labels_query[
                batch_index:batch_index + block_size]
            mini_data_sentence_bert_query = data_sentence_bert_query[
                batch_index:batch_index + block_size]

            #print(mini_data_token_query.size())
            #print(mini_sentence_query.size())
            #print(mini_data_sentences)

            (mini_data_sentence_query, labels, mini_data_sentence_bert_query,
             mini_sentence_query_lens) = pad_query_sentences(
                 mini_data_sentence_query,
                 mini_data_sentences_labels_query,
                 mini_data_sentence_bert_query,
                 MAX_SENT_LEN,
                 PAD_CLS=way + 1)
            '''class_indices = np.argwhere(np.array(batch_labels) > 0)
            new_batch_labels = np.array(batch_labels)
            new_batch_labels[class_indices] = 1
                                                
        
            old_mini_labels = batch_labels[batch_index: batch_index+block_size]'''

            #print(len(labels))
            #print(old_mini_labels)
            '''class_indices = np.argwhere(np.array(old_mini_labels) > 0)
            mini_labels = np.array(old_mini_labels)
            mini_labels[class_indices] = 1'''
            #print(mini_labels)
            #print(set(batch_labels))
            #print("--")
            #mini_labels = torch.LongTensor(labels).to(device)
            #print(labels)
            mini_labels = labels

            logits = euclidean_metric(
                model(mini_data_sentence_query, mini_data_token_query,
                      mini_data_sentence_bert_query, mini_sentence_query_lens),
                new_proto)
            ###print("Test_logits\t")
            ###print(logits)

            logits[:, :, 0] = model.return_0class()
            ###print(logits)

            softmax_scores = F.softmax(logits, dim=2)
            #print(softmax_scores)

            ##loss = F.cross_entropy(logits, labels)
            ##acc = count_acc(logits, mini_labels)
            ##llog.add(loss.item())
            ##alog.add(acc)
            if test:
                #print the outputs to a file
                save_dev_span_output(output_file, softmax_scores,
                                     np.array(mini_labels), batch_labels,
                                     mini_data_sentence_query,
                                     mini_data_token_query,
                                     mini_sentence_query_lens, id_to_token,
                                     id_to_tag, test_cls)
            correct, total_preds, total_gold, mini_confidence = count_F(
                softmax_scores,
                mini_labels,
                batch_labels.values(),
                train=False,
                id_to_tag=id_to_tag,
                PAD_CLS=way + 1)

            #confidence += mini_confidence

            flog.add(correct, total_preds, total_gold)
            item1, item2, item3 = flog.item()
            #print(item1)
            #print(item2)
            #print(item3)

            correct, total_preds, total_gold, mini_confidence = count_F_old(
                softmax_scores,
                mini_labels,
                batch_labels.values(),
                train=False,
                PAD_CLS=way + 1)

            #confidence += mini_confidence

            flog_old.add(correct, total_preds, total_gold)
            #item1, item2, item3 = flog.item()

        f_score1 = flog.f_score()
        print(f_score1)
        item1, item2, item3 = flog.item()
        print("Correct")
        print(item1)
        print("Predicted")
        print(item2)
        print("Gold")
        print(item3)

        print("Token level accuracy-----")

        f_score1 = flog_old.f_score()
        print(f_score1)
        item1, item2, item3 = flog_old.item()
        print("Correct")
        print(item1)
        print("Predicted")
        print(item2)
        print("Gold")
        print(item3)

        #confidence.sort(key=lambda pair: pair[0])
        #confidence.sort(key=lambda pair: pair[1])
        #print(confidence)

        #exit()
    if test:
        output_file.close()
    return flog_old, flog
Example #21
    s_onehot = s_onehot.scatter_(1, s_label.unsqueeze(dim=1), 1).cuda()

    for i, batch in enumerate(loader, 1):
        data, _ = [_.cuda() for _ in batch]
        k = args.way * args.shot
        data_shot, meta_support, data_query = data[:k], data[k:2*k], data[2*k:]

        #p = inter_fold(model, args, data_shot)

        x = model(data_shot)
        x = x.reshape(args.shot, args.way, -1).mean(dim=0)
        p = x

        lam = 0.01
        proto = model(meta_support)
        meta_logits = euclidean_metric(proto, p)
        soft_labels = (torch.sigmoid(meta_logits) + lam * s_onehot) / (1 + lam)
        #soft_labels_norm2 = soft_labels / soft_labels.sum(dim=0)
        proto = torch.mm(soft_labels.permute((1, 0)), proto)

        logits = euclidean_metric(model(data_query), proto)

        label = torch.arange(args.way).repeat(args.query)
        label = label.type(torch.cuda.LongTensor)

        acc = count_acc(logits, label)
        ave_acc.add(acc)
        print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

        x = None;
        p = None;
Example #22
    print('No implementation!')
    exit()
saved_models = torch.load(args.load)
base_net.load_state_dict(saved_models['base_net'])
base_net.eval()

n_shot = args.way * args.shot  # 25
label = torch.arange(args.way).repeat(args.query)  #75
label = label.type(torch.cuda.LongTensor)
test_accuracies = []
with torch.no_grad():
    for i, batch in enumerate(test_loader, 1):
        inputs = batch[0].cuda()  # [100, 3, 80, 80]
        features = base_net(inputs)  # [100, 640]

        fea_shot, fea_query = features[:n_shot], features[n_shot:]
        # [25, 640], [75, 640]
        proto = fea_shot.reshape(args.shot, args.way,
                                 -1).mean(dim=0)  # [5, 640]
        logits = euclidean_metric(fea_query, proto)
        acc = count_acc(logits, label)
        test_accuracies.append(acc)

        if i % 50 == 0:
            avg = np.mean(np.array(test_accuracies))
            std = np.std(np.array(test_accuracies))
            ci95 = 1.96 * std / np.sqrt(i + 1)
            log_str = 'batch {}: Accuracy: {:.4f} +- {:.4f} % ({:.4f} %)'.format(
                i, avg, ci95, acc)
            log(log_file_path, log_str)