Example #1
def test(model, testloader, use_gpu):
    accs = AverageMeter()
    test_accuracies = []
    model.eval()
    with torch.no_grad():
        for batch_idx, (images_train, labels_train, images_test, labels_test) in enumerate(testloader):
            if use_gpu:
                images_train = images_train.cuda()
                images_test = images_test.cuda()
            end = time.time()
            batch_size, num_train_examples, channels, height, width = images_train.size()
            num_test_examples = images_test.size(1)
            labels_train_1hot = one_hot(labels_train).cuda()
            labels_test_1hot = one_hot(labels_test).cuda()
            cls_scores = model(images_train, images_test,
                               labels_train_1hot, labels_test_1hot)
            cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
            labels_test = labels_test.view(batch_size * num_test_examples)
            _, preds = torch.max(cls_scores.detach().cpu(), 1)
            acc = (torch.sum(preds == labels_test.detach().cpu()
                             ).float()) / labels_test.size(0)
            accs.update(acc.item(), labels_test.size(0))
            gt = (preds == labels_test.detach().cpu()).float()
            gt = gt.view(batch_size, num_test_examples).numpy()  # [b, n]
            acc = np.sum(gt, 1) / num_test_examples
            acc = np.reshape(acc, (batch_size))
            test_accuracies.append(acc)
    accuracy = accs.avg
    test_accuracies = np.array(test_accuracies)
    test_accuracies = np.reshape(test_accuracies, -1)
    stds = np.std(test_accuracies, 0)
    ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
    print('Accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy, ci95))
    return accuracy
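
# Hedged sketch (not from the original file): minimal stand-ins for the two helpers every
# snippet here relies on. The repository's real AverageMeter and one_hot may differ in
# detail (e.g. one_hot probably derives the class count from the episode's N-way setting
# rather than from labels.max()).
import torch

class AverageMeter(object):
    """Tracks the running average of a scalar such as per-batch accuracy."""
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def one_hot(labels):
    """[batch, n] integer episode labels -> [batch, n, num_classes] one-hot floats."""
    batch, n = labels.size()
    num_classes = int(labels.max().item()) + 1
    out = torch.zeros(batch, n, num_classes)
    out.scatter_(2, labels.unsqueeze(2).long(), 1)
    return out
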
def test_vis(model, testloader, use_gpu):
    accs = AverageMeter()
    test_accuracies = []
    model.eval()

    with torch.no_grad():
        for batch_idx , (images_train, labels_train, images_test, labels_test) in enumerate(testloader):
            if use_gpu:
                images_train = images_train.cuda()
                images_test = images_test.cuda()
            # ImageNet channel statistics, shaped (3, 1, 1), used to de-normalize CHW images for saving.
            std = np.expand_dims(np.expand_dims(np.array([0.229, 0.224, 0.225]), axis=1), axis=2)
            mean = np.expand_dims(np.expand_dims(np.array([0.485, 0.456, 0.406]), axis=1), axis=2)
            end = time.time()

            batch_size, num_train_examples, channels, height, width = images_train.size()
            num_test_examples = images_test.size(1)

            labels_train_1hot = one_hot(labels_train).cuda()
            labels_test_1hot = one_hot(labels_test).cuda()

            # The model additionally returns attention maps: a1 is visualized on support images,
            # a2 on query images (e.g. shape [4, 5, 75, 6, 6]).
            cls_scores, a1, a2 = model(images_train, images_test, labels_train_1hot, labels_test_1hot, True)
            # Min-max normalize each attention map to [0, 1] over its spatial dimensions.
            max_a1 = a1.max(3)[0].max(3)[0].unsqueeze(3).unsqueeze(3)
            min_a1 = a1.min(3)[0].min(3)[0].unsqueeze(3).unsqueeze(3)
            max_a2 = a2.max(3)[0].max(3)[0].unsqueeze(3).unsqueeze(3)
            min_a2 = a2.min(3)[0].min(3)[0].unsqueeze(3).unsqueeze(3)
            scale_a1 = torch.div(a1 - min_a1, max_a1 - min_a1)
            scale_a2 = torch.div(a2 - min_a2, max_a2 - min_a2)
            # Canvas for a 5-row x 4-column grid (one row per support class):
            # query image | support image | query + a2 heatmap | support + a1 heatmap.
            result_surpport_imgs = np.zeros((84 * 5 + 8 * 4, 84 * 4 + 8 * 3, 3)).astype(dtype=np.uint8)
        
            cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
            labels_test = labels_test.view(batch_size * num_test_examples)

            _, preds = torch.max(cls_scores.detach().cpu(), 1)
            acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0)
            accs.update(acc.item(), labels_test.size(0))
            if only_test:
                for i in range(images_test.shape[0]):
                    for k in range(images_test.shape[1]):
                        for j in range(images_train.shape[1]):
                            images_temp_test=images_test[i,k,:,:].cpu().numpy()
                            images_temp_train=images_train[i,j,:,:].cpu().numpy()                        
                            index_support = labels_train[i, j]
                            index_test = labels_test[i * num_test_examples + k]
                            # Undo ImageNet normalization and convert CHW RGB -> HWC BGR for OpenCV.
                            images_temp_test = images_temp_test * std + mean
                            images_ori_test = images_temp_test.transpose((1, 2, 0))[:, :, ::-1]
                            images_temp_train = images_temp_train * std + mean
                            images_ori_train = images_temp_train.transpose((1, 2, 0))[:, :, ::-1]
                            # Resize the normalized attention maps to 84x84 and colorize them.
                            hot_a1 = cv2.resize(np.uint8(scale_a1[i, index_support, k].cpu().numpy() * 255), (84, 84))
                            hot_a2 = cv2.resize(np.uint8(scale_a2[i, index_support, k].cpu().numpy() * 255), (84, 84))
                            heatmap_a1 = cv2.applyColorMap(hot_a1, cv2.COLORMAP_JET)
                            heatmap_a2 = cv2.applyColorMap(hot_a2, cv2.COLORMAP_JET)
                            images_ori_test = np.uint8(images_ori_test * 255)
                            images_ori_train = np.uint8(images_ori_train * 255)
                            # Blend the heatmaps onto the de-normalized images.
                            vis_test = images_ori_test * 0.7 + heatmap_a2 * 0.3
                            vis_train = images_ori_train * 0.7 + heatmap_a1 * 0.3
                            # Paste query, support, and both overlays into this class's grid row.
                            row_start = int(84 * index_support + 8 * index_support)
                            row_end = int(84 * (index_support + 1) + 8 * index_support)
                            result_surpport_imgs[row_start:row_end, :84, :] = images_ori_test
                            result_surpport_imgs[row_start:row_end, 84 + 8:84 * 2 + 8, :] = images_ori_train
                            result_surpport_imgs[row_start:row_end, 84 * 2 + 8 * 2:84 * 3 + 8 * 2, :] = vis_test
                            result_surpport_imgs[row_start:row_end, 84 * 3 + 8 * 3:84 * 4 + 8 * 3, :] = vis_train
                        label_gt = int(labels_test.numpy()[i * num_test_examples + k])
                        label_pred = int(preds.numpy()[i * num_test_examples + k])
                        cv2.imwrite('./result/vis_images/vis' + '_' + str(batch_idx) + '_' + str(i) + '_' + str(k) + '_' + str(label_gt) + '_' + str(label_pred) + '.jpg', result_surpport_imgs)
            gt = (preds == labels_test.detach().cpu()).float()
            gt = gt.view(batch_size, num_test_examples).numpy() #[b, n]
            acc = np.sum(gt, 1) / num_test_examples
            acc = np.reshape(acc, (batch_size))
            test_accuracies.append(acc)
    accuracy = accs.avg
    test_accuracies = np.array(test_accuracies)
    test_accuracies = np.reshape(test_accuracies, -1)
    stds = np.std(test_accuracies, 0)
    ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
    print('Accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy, ci95))
    return accuracy
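
# Hedged sketch (not part of the original file): the per-image de-normalization and heatmap
# blending that test_vis performs inline, written as a standalone helper. `img_chw` is an
# ImageNet-normalized float CHW array and `attn` is an HxW attention map in [0, 1].
import numpy as np
import cv2

def overlay_attention_sketch(img_chw, attn, alpha=0.3):
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    imagenet_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    # Undo normalization, then convert CHW RGB -> HWC BGR for OpenCV.
    img = img_chw * imagenet_std + imagenet_mean
    img = np.uint8(np.clip(img.transpose(1, 2, 0)[:, :, ::-1], 0, 1) * 255)
    # Colorize the attention map and alpha-blend it onto the image.
    heat = cv2.applyColorMap(np.uint8(attn * 255), cv2.COLORMAP_JET)
    heat = cv2.resize(heat, (img.shape[1], img.shape[0]))
    return np.uint8(img * (1 - alpha) + heat * alpha)
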
def test_ori_5(model,best_path, testloader, use_gpu,topK=28):
    accs = AverageMeter()
    test_accuracies = []
    final_accs = AverageMeter()
    final_test_accuracies = [] 
    params = torch.load(best_path)
    model.load_state_dict(params['state_dict'], strict=True)            
    model.eval()

    with torch.no_grad():
        #for batch_idx , (images_train, labels_train,Xt_img_ori,Xt_img_gray, images_test, labels_test) in enumerate(testloader):
        for batch_idx , (images_train, images_train2,images_train3,images_train4,images_train5,labels_train, images_test, labels_test) in enumerate(testloader):   
            shape_test = images_train.shape[0]
            # Reshape each of the five augmented support views to [batch, n_support, 1, 3, 84, 84],
            # concatenate along dim 2, and flatten so the five views of one support image stay adjacent.
            images_train1 = images_train.reshape(shape_test, -1, 1, 3, 84, 84)
            images_train2 = images_train2.reshape(shape_test, -1, 1, 3, 84, 84)
            images_train3 = images_train3.reshape(shape_test, -1, 1, 3, 84, 84)
            images_train4 = images_train4.reshape(shape_test, -1, 1, 3, 84, 84)
            images_train5 = images_train5.reshape(shape_test, -1, 1, 3, 84, 84)
            # Repeat each support label five times to match the interleaved views.
            labels_train_5 = labels_train.reshape(shape_test, -1, 1)
            labels_train_5 = labels_train_5.repeat(1, 1, 5)
            labels_train = labels_train_5.reshape(shape_test, -1)
            images_train_5 = torch.cat((images_train1, images_train2, images_train3, images_train4, images_train5), 2)
            images_train = images_train_5.reshape(shape_test, -1, 3, 84, 84)
            if use_gpu:
                images_train = images_train.cuda()
                #images_train_5 = images_train_5.cuda()
                images_test = images_test.cuda()

            end = time.time()
            #print(images_train.shape,labels_train.shape)
            #exit()
            batch_size, num_train_examples, channels, height, width = images_train.size()
            num_test_examples = images_test.size(1)

            labels_train_1hot = one_hot(labels_train).cuda()
            labels_test_1hot = one_hot(labels_test).cuda()
            #print(images_train.shape,images_test.shape)
            cls_scores ,cls_scores_final= model(images_train, images_test, labels_train_1hot, labels_test_1hot,topK)
            #print(cls_scores.shape,cls_scores_final.shape)
            #exit(0)
            cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
            cls_scores_final = cls_scores_final.view(batch_size * num_test_examples, -1)            
            labels_test = labels_test.view(batch_size * num_test_examples)

            _, preds = torch.max(cls_scores.detach().cpu(), 1)
            _, preds_final = torch.max(cls_scores_final.detach().cpu(), 1)            
            acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0)
            accs.update(acc.item(), labels_test.size(0))

            acc_final = (torch.sum(preds_final == labels_test.detach().cpu()).float()) / labels_test.size(0)
            final_accs.update(acc_final.item(), labels_test.size(0))
            
            gt = (preds == labels_test.detach().cpu()).float()
            gt = gt.view(batch_size, num_test_examples).numpy() #[b, n]
            
            gt_final = (preds_final == labels_test.detach().cpu()).float()
            gt_final = gt_final.view(batch_size, num_test_examples).numpy() #[b, n]
            
            acc = np.sum(gt, 1) / num_test_examples
            acc = np.reshape(acc, (batch_size))
            test_accuracies.append(acc)

            acc_final = np.sum(gt_final, 1) / num_test_examples
            acc_final = np.reshape(acc_final, (batch_size))
            final_test_accuracies.append(acc_final)
            
    accuracy = accs.avg
    test_accuracies = np.array(test_accuracies)
    test_accuracies = np.reshape(test_accuracies, -1)
    stds = np.std(test_accuracies, 0)
    ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
    
    accuracy_final = final_accs.avg
    test_accuracies_final = np.array(final_test_accuracies)
    test_accuracies_final = np.reshape(test_accuracies_final, -1)
    stds_final = np.std(test_accuracies_final, 0)
    ci95_final = 1.96 * stds_final / np.sqrt(args.epoch_size)    
    print('Accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy, ci95))
    print('Final accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy_final, ci95_final))
    return accuracy        
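
# Hedged shape sketch (illustrative only): how test_ori_5 interleaves the five augmented
# support views so the five crops of each support image stay adjacent, and how the support
# labels are repeated to match.
import torch

_batch, _n_support = 2, 5
_views = [torch.randn(_batch, _n_support, 3, 84, 84) for _ in range(5)]
_views = [v.reshape(_batch, -1, 1, 3, 84, 84) for v in _views]                 # [2, 5, 1, 3, 84, 84] each
_support = torch.cat(_views, 2).reshape(_batch, -1, 3, 84, 84)                 # [2, 25, 3, 84, 84]
_labels = torch.arange(_n_support).repeat(_batch, 1)                           # [2, 5]
_labels = _labels.reshape(_batch, -1, 1).repeat(1, 1, 5).reshape(_batch, -1)   # [2, 25]
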
def train(epoch, model, criterion,loss_fn, optimizer, trainloader, learning_rate, use_gpu):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    std=np.expand_dims(np.array([0.229, 0.224, 0.225]),axis=1)
    std=np.expand_dims(std,axis=2) 
    mean=np.expand_dims(np.array([0.485, 0.456, 0.406]),axis=1)             
    mean=np.expand_dims(mean,axis=2) 
    model.train()
    #model_edge.eval()
    #model_tradclass.eval()
    end = time.time()
    #for batch_idx, (images_train, labels_train,tpids,Xt_img_ori,Xt_img_gray,images_test, labels_test, pids) in enumerate(trainloader):
    for batch_idx, (images_train,images_train1,images_train2,images_train3,images_train4,images_train5,images_train6,images_train7,images_train8, labels_train,tpids, images_test,images_test1,images_test2,images_test3,images_test4, labels_test, pids) in enumerate(trainloader):    
        data_time.update(time.time() - end)
        #print(Xt_img_ori.shape,Xt_img_gray.shape,images_train.shape,'lll')
        edges=[]
        if only_CSEI:
            augment_k=4
        else:
            augment_k=8
        tpids_4 = tpids.reshape(4,-1,1)#[:,:,0]

        tpids_4 = tpids_4.repeat(1,1,augment_k).reshape(4,-1)  

        K_shot=images_train.shape[1]/5
        images_train1=images_train1.reshape(4,-1,1,3,84,84)
        images_train2=images_train2.reshape(4,-1,1,3,84,84)
        images_train3=images_train3.reshape(4,-1,1,3,84,84)
        images_train4=images_train4.reshape(4,-1,1,3,84,84) 
        images_train5=images_train5.reshape(4,-1,1,3,84,84)
        images_train6=images_train6.reshape(4,-1,1,3,84,84)
        images_train7=images_train7.reshape(4,-1,1,3,84,84)
        images_train8=images_train8.reshape(4,-1,1,3,84,84)         
        #print(images_test.shape)
        #exit(0)
        #images_test1=images_test1.reshape(4,30,1,3,84,84)
        #images_test2=images_test2.reshape(4,30,1,3,84,84)
        #images_test3=images_test3.reshape(4,30,1,3,84,84)
        #images_test4=images_test4.reshape(4,30,1,3,84,84)
       
        # If GPU memory is sufficient, concatenate all augmented views:
        if only_CSEI:
            images_train_4=torch.cat((images_train1, images_train2,images_train3,images_train4), 2)
        else:
            images_train_4=torch.cat((images_train1, images_train2,images_train3,images_train4,images_train5,images_train6,images_train7,images_train8), 2)   
        # If GPU memory is not sufficient, concatenate fewer views instead, e.g.:
        #images_train_4=torch.cat((images_train1, images_train2,images_train3), 2)
        #images_train_fuse=   torch.cat((images_train.reshape(4,-1,1,3,84,84), images_train1, images_train2,images_train3), 2)     
        #images_test=images_test.reshape(4,30,1,3,84,84)        
        #images_test_4=torch.cat((images_test,images_test1, images_test2,images_test3, images_test4), 2)
        #images_test_4=torch.cat((images_test,images_test3, images_test4), 2)        
        #images_test=images_test_4.reshape(4,-1,3,84,84)       
        labels_train_4 = labels_train.reshape(4,-1,1)#[:,:,0]

        labels_train_4 = labels_train_4.repeat(1,1,augment_k)
        labels_test_4=labels_train_4[:,:,:augment_k]
        labels_train_4 = labels_train_4.reshape(4,-1)  
        labels_test_4=labels_test_4.reshape(4,-1)       
       
        if use_gpu:
            images_train, labels_train,images_train_4 = images_train.cuda(), labels_train.cuda(),images_train_4.cuda()
            #images_train_fuse=images_train_fuse.cuda()
            
            images_test, labels_test = images_test.cuda(), labels_test.cuda()
            pids = pids.cuda()
            labels_train_4=labels_train_4.cuda()
            labels_test_4=labels_test_4.cuda()
            tpids_4 = tpids_4.cuda()
            tpids=tpids.cuda()
        pids_con=torch.cat((pids, tpids_4), 1)
        labels_test_4=torch.cat((labels_test, labels_test_4), 1)
        #tpids

        batch_size, num_train_examples, channels, height, width = images_train.size()
        num_test_examples = images_test.size(1)
        
        labels_train_1hot = one_hot(labels_train).cuda()
    
        train_pid=torch.matmul(labels_train_1hot.transpose(1, 2),tpids.unsqueeze(2).float()).squeeze()
        train_pid=(train_pid/K_shot).long()

        
        #exit()
        labels_train_1hot_4 = one_hot(labels_train_4).cuda()        
        #labels_train = labels_train.view(batch_size * num_train_examples)   
        #print( labels_train)
        #exit(0)        
        labels_test_1hot = one_hot(labels_test).cuda()
        labels_test_1hot_4 = one_hot(labels_test_4).cuda()
 
        # Support set: the random switch below is effectively disabled, since
        # np.random.uniform(0, 1) is always > -1, so the original support images are used.
        switch = np.random.uniform(0, 1)
        if switch > -1:
            images_train = images_train.reshape(4, -1, 3, 84, 84)
        else:
            images_train = images_train1.cuda().reshape(4, -1, 3, 84, 84)
        images_train_4=images_train_4.reshape(4,-1,3,84,84)
        #inpaint_tensor=torch.from_numpy(inpaint_img_np).cuda().reshape(4,20,3,84,84).float()        
        # Append the augmented support views to the query set so they are classified too.
        images_test = torch.cat((images_test, images_train_4), 1).reshape(4, -1, 3, 84, 84)

        # ytest: logits over all base classes; cls_scores: episodic N-way classification logits.
        ytest, cls_scores, features, params_classifier, spatial = model(images_train, images_test, labels_train_1hot, labels_test_1hot_4)

        loss1 = criterion(ytest, pids_con.view(-1))                  
        loss2 = criterion(cls_scores, labels_test_4.view(-1))        

        if epoch>900: 
            loss3 = loss_fn(params_classifier,ytest, features,pids_con)        
            loss = loss1 + 0.5 * loss2+loss3
        else:
            loss= loss1 + 0.5 * loss2#+0.5*loss_contrast
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), pids.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

    print('Epoch{0} '
          'lr: {1} '
          'Time:{batch_time.sum:.1f}s '
          'Data:{data_time.sum:.1f}s '
          'Loss:{loss.avg:.4f} '.format(
           epoch+1, learning_rate, batch_time=batch_time, 
           data_time=data_time, loss=losses))
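
# Hedged sketch (assumption: `criterion` above is nn.CrossEntropyLoss): the two heads being
# combined in the training loop are a global classifier over all base-class ids
# (ytest vs. pids_con) and an episodic N-way classifier (cls_scores vs. labels_test_4),
# weighted 1.0 and 0.5 respectively.
import torch
import torch.nn as nn

_criterion = nn.CrossEntropyLoss()
_global_logits = torch.randn(8, 64)              # 8 query samples, 64 base classes
_global_ids = torch.randint(0, 64, (8,))
_episode_logits = torch.randn(8, 5)              # same queries, 5-way episode
_episode_lbls = torch.randint(0, 5, (8,))
_loss = _criterion(_global_logits, _global_ids) + 0.5 * _criterion(_episode_logits, _episode_lbls)
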
Example #5
def train(epoch, model, criterion, optimizer, trainloader, learning_rate,
          use_gpu):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()
    end = time.time()
    for batch_idx, (images_train, labels_train, images_test, labels_test,
                    pids) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            images_train, labels_train = images_train.cuda(
            ), labels_train.cuda()
            images_test, labels_test = images_test.cuda(), labels_test.cuda()
            pids = pids.cuda()
        batch_size, num_train_examples, channels, height, width = images_train.size(
        )
        num_test_examples = images_test.size(1)
        labels_train_1hot = one_hot(labels_train).cuda()
        labels_test_1hot = one_hot(labels_test).cuda()
        ytest = model(images_train, images_test, labels_train_1hot,
                      labels_test_1hot)
        loss = criterion(ytest, pids.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), pids.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
    print('Epoch{0} '
          'lr: {1} '
          'Time:{batch_time.sum:.1f}s '
          'Data:{data_time.sum:.1f}s '
          'Loss:{loss.avg:.4f} '.format(epoch + 1,
                                        learning_rate,
                                        batch_time=batch_time,
                                        data_time=data_time,
                                        loss=losses))
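
# Hedged usage sketch (not from the original file): how a driver script might alternate the
# train/test pair above. The scheduler, loaders, and the max_epoch/eval_freq values are
# assumed names, not taken from this repository.
def run_training_sketch(model, criterion, optimizer, scheduler, trainloader, testloader,
                        use_gpu, max_epoch=90, eval_freq=10):
    for epoch in range(max_epoch):
        lr = optimizer.param_groups[0]['lr']
        train(epoch, model, criterion, optimizer, trainloader, lr, use_gpu)
        scheduler.step()
        if (epoch + 1) % eval_freq == 0:
            test(model, testloader, use_gpu)
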
Example #6
def test(model_edge, model, model_tradclass,weight_softmax, testloader, use_gpu):
    accs = AverageMeter()
    test_accuracies = []
    # ImageNet channel statistics, shaped (3, 1, 1), used to re-normalize inpainted images.
    std = np.expand_dims(np.expand_dims(np.array([0.229, 0.224, 0.225]), axis=1), axis=2)
    mean = np.expand_dims(np.expand_dims(np.array([0.485, 0.456, 0.406]), axis=1), axis=2)
    model.eval()
    model_tradclass.eval()
    with torch.no_grad():
        for batch_idx , (images_train, labels_train,Xt_img_ori,Xt_img_gray, images_test, labels_test) in enumerate(testloader):
            if use_gpu:
                images_train = images_train.cuda()
                images_test = images_test.cuda()

            end = time.time()
            #print(images_train.shape,images_test.shape)
            #exit(0)
            batch_size, num_train_examples, channels, height, width = images_train.size()
            num_test_examples = images_test.size(1)
            labels_train_4 = labels_train.reshape(4,5,1)#[:,:,0]

            labels_train_4 = labels_train_4.repeat(1,1,5).reshape(4,-1) 
            labels_train_4=labels_train_4.cuda()           
            labels_train_1hot = one_hot(labels_train).cuda()
            labels_test_1hot = one_hot(labels_test).cuda()
            labels_train_1hot_4 = one_hot(labels_train_4).cuda()            
            ytest, feature = model_tradclass(images_train, images_train, labels_train_1hot, labels_test_1hot)
            images_train=images_train.reshape(4,5,1,3,84,84)
            feature_cpu=feature.detach().cpu().numpy()
            probs, idx = ytest.detach().sort(1, True)
            probs = probs.cpu().numpy()
            idx = idx.cpu().numpy() 
            # Compute class-activation masks for the top-4 predicted classes of each support image.
            masks = []
            for i in range(feature.shape[0]):
                CAMs = returnCAM(feature_cpu[i], weight_softmax, [idx[i, :4, 0, 0]], masks)
                masks = CAMs
            masks_tensor = torch.stack(masks, dim=0)
            Xt_masks = masks_tensor.reshape(4, 5, 4, 1, 84, 84)
            # Tile each original support image (and its grayscale version) once per CAM mask.
            Xt_img_ori_repeat = Xt_img_ori.reshape(4, 5, 1, 3, 84, 84).repeat(1, 1, 4, 1, 1, 1)
            Xt_img_gray_repeat = Xt_img_gray.reshape(4, 5, 1, 1, 84, 84).repeat(1, 1, 4, 1, 1, 1)
            edges = []
            mask_numpy = np.uint8(Xt_masks.numpy() * 255)
            Xt_img_gray_numpy = np.uint8(Xt_img_gray.numpy() * 255)
            for i in range(4):
                for j in range(5):
                    for k in range(4):
                        edge_PIL=Image.fromarray(load_edge(Xt_img_gray_numpy[i,j,0], mask_numpy[i,j,k,0]))
                        edges.append(Funljj.to_tensor(edge_PIL).float())        
            edges = torch.stack(edges, dim=0) 
            edge_sh=edges#.reshape(4,5,1,84,84)
            # Edge-guided inpainting: fill each CAM-masked region, re-normalize the result, and
            # append the 4 inpainted variants of every support image to the support set (5 -> 25 images).
            inpaint_img = model_edge.test(Xt_img_ori_repeat.reshape(80, 3, 84, 84), edge_sh, Xt_img_gray_repeat.reshape(80, 1, 84, 84), masks_tensor)
            inpaint_img_np = inpaint_img.detach().cpu().numpy()
            inpaint_img_np = (inpaint_img_np - mean) / std
            inpaint_tensor = torch.from_numpy(inpaint_img_np).cuda().reshape(4, 5, 4, 3, 84, 84).float()
            images_train = torch.cat((images_train, inpaint_tensor), 2).reshape(4, 25, 3, 84, 84)
            cls_scores = model(images_train, images_test, labels_train_1hot_4, labels_test_1hot)
            cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
            labels_test = labels_test.view(batch_size * num_test_examples)

            _, preds = torch.max(cls_scores.detach().cpu(), 1)
            acc = (torch.sum(preds == labels_test.detach().cpu()).float()) / labels_test.size(0)
            accs.update(acc.item(), labels_test.size(0))

            gt = (preds == labels_test.detach().cpu()).float()
            gt = gt.view(batch_size, num_test_examples).numpy() #[b, n]
            acc = np.sum(gt, 1) / num_test_examples
            acc = np.reshape(acc, (batch_size))
            test_accuracies.append(acc)

    accuracy = accs.avg
    test_accuracies = np.array(test_accuracies)
    test_accuracies = np.reshape(test_accuracies, -1)
    stds = np.std(test_accuracies, 0)
    ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
    print('Accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy, ci95))

    return accuracy
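
# Hedged sketch of the standard class-activation-map computation (Zhou et al., 2016) that
# returnCAM builds on; the repository's returnCAM additionally threads the running `masks`
# list through its arguments and may binarize the maps into inpainting masks.
import numpy as np
import cv2

def simple_cam_sketch(feature_conv, weight_softmax, class_ids, size=(84, 84)):
    """feature_conv: [C, h, w] conv features; weight_softmax: [num_classes, C] fc weights."""
    c, h, w = feature_conv.shape
    cams = []
    for cls in class_ids:
        cam = weight_softmax[cls].dot(feature_conv.reshape(c, h * w)).reshape(h, w)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-12)   # min-max normalize to [0, 1]
        cams.append(cv2.resize(cam.astype(np.float32), size))
    return cams
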
Example #7
def train(epoch,model_edge, model, model_tradclass,weight_softmax, criterion, optimizer, trainloader, learning_rate, use_gpu):
    
    if not os.path.isdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_1"):
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_1")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_2")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_3")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_4") 
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_5")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_6")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_7")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_8")  
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_9")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_10")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_11")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_12") 
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_13")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_14")
        os.mkdir("/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train_15")        
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # ImageNet channel statistics, shaped (3, 1, 1).
    std = np.expand_dims(np.expand_dims(np.array([0.229, 0.224, 0.225]), axis=1), axis=2)
    mean = np.expand_dims(np.expand_dims(np.array([0.485, 0.456, 0.406]), axis=1), axis=2)
    model.eval()
    #model_edge.eval()
    model_tradclass.eval()
    end = time.time()
    for root, dirs, _ in os.walk('/data4/lijunjie/tiered-imagenet-tools-master/tiered_imagenet/train'):
        for d in dirs:
            path = os.path.join(root, d)
            # Mirror this class directory into each of the 15 augmentation trees.
            Paths = [path.replace('train', 'train_' + str(k)) for k in range(1, 16)]
            if not os.path.isdir(Paths[0]):
                for p in Paths:
                    os.mkdir(p)
            files = os.listdir(path)
            for file in files:
                images=[]
                imgs_gray=[]
                Xt_img_ori=[]            
                img_ori = read_image(os.path.join(path, file))
                # Keep an un-normalized tensor copy of the image for the inpainting model.
                masked_img = Image.fromarray(np.array(img_ori))
                masked_img_tensor = Funljj.to_tensor(masked_img).float()
                Xt_img_ori.append(masked_img_tensor)
                img = transform_test(img_ori)
                img_gray = rgb2gray(np.array(img_ori))
                img_gray=Image.fromarray(img_gray)
                img_gray_tensor=Funljj.to_tensor(img_gray).float()            
                imgs_gray.append(img_gray_tensor)                
                images.append(img)
                images = torch.stack(images, dim=0)
                imgs_gray = torch.stack(imgs_gray, dim=0) 
                Xt_img_ori = torch.stack(Xt_img_ori, dim=0)
                if use_gpu:
                    images_train = images.cuda()
                    imgs_gray = imgs_gray.cuda()
                    Xt_img_ori = Xt_img_ori.cuda()
                    
                with torch.no_grad():
                    ytest,feature= model_tradclass(images_train.reshape(1,1,3,84,84), images_train.reshape(1,1,3,84,84),images_train.reshape(1,1,3,84,84), images_train.reshape(1,1,3,84,84))               
                feature_cpu=feature.detach().cpu().numpy()
                probs, idx = ytest.detach().sort(1, True)
                probs = probs.cpu().numpy()
                idx = idx.cpu().numpy() 
                masks = []
                edges = []
                # Compute CAM masks for the top-15 predicted classes of this image.
                for i in range(feature.shape[0]):
                    CAMs = returnCAM(feature_cpu[i], weight_softmax, [idx[i, :15, 0, 0]], masks)
                    masks = CAMs
                masks_tensor = torch.stack(masks, dim=0)
                Xt_masks = masks_tensor.reshape(1, 1, 15, 1, 84, 84)
                Xt_img_ori_repeat = Xt_img_ori.reshape(1, 1, 1, 3, 84, 84).repeat(1, 1, 15, 1, 1, 1)
                Xt_img_gray_repeat = imgs_gray.reshape(1, 1, 1, 1, 84, 84).repeat(1, 1, 15, 1, 1, 1)
                mask_numpy = np.uint8(Xt_masks.numpy() * 255)
                Xt_img_gray_numpy = np.uint8(imgs_gray.cpu().numpy() * 255).reshape(1, 1, 1, 84, 84)
                # Build an edge map for each of the 15 CAM-masked regions.
                for k in range(15):
                    edge_PIL = Image.fromarray(load_edge(Xt_img_gray_numpy[0, 0, 0], mask_numpy[0, 0, k, 0]))
                    edges.append(Funljj.to_tensor(edge_PIL).float())
                edges = torch.stack(edges, dim=0)
                edge_sh = edges
                with torch.no_grad():
                    inpaint_img = model_edge.test(Xt_img_ori_repeat.reshape(15, 3, 84, 84), edge_sh, Xt_img_gray_repeat.reshape(15, 1, 84, 84), masks_tensor)
                inpaint_img_np = inpaint_img.detach().cpu().numpy()
                # Save each of the 15 inpainted variants into its mirrored class directory.
                for id in range(15):
                    images_ori_train = inpaint_img_np[id].transpose((1, 2, 0))[:, :, ::-1]
                    images_ori_train = np.uint8(images_ori_train * 255)
                    cv2.imwrite(Paths[id] + '/' + file, images_ori_train)
    # The preprocessing pass above writes every inpainted variant to disk and then stops;
    # the episodic training loop below is unreachable in this version of the script.
    exit(0)
    for batch_idx, (images_train, labels_train,tpids,Xt_img_ori,Xt_img_gray,images_test, labels_test, pids) in enumerate(trainloader):
    
    #for batch_idx, (images_train, labels_train, images_test, labels_test, pids) in enumerate(trainloader):    
        data_time.update(time.time() - end)
        #print(Xt_img_ori.shape,Xt_img_gray.shape,images_train.shape,'lll')
        edges=[]
        if use_gpu:
            images_train = images_train.cuda()

        batch_size, num_train_examples, channels, height, width = images_train.size()
        num_test_examples = images_test.size(1)
        
        labels_train_1hot = one_hot(labels_train).cuda()
        labels_train_1hot_4 = one_hot(labels_train_4).cuda()        
        #labels_train = labels_train.view(batch_size * num_train_examples)   
        #print( labels_train)
        #exit(0)        
        labels_test_1hot = one_hot(labels_test).cuda()
        labels_test_1hot_4 = one_hot(labels_test_4).cuda()
        #print(labels_test_1hot_4.shape,labels_test_1hot.shape)        
        #labels_test_1hot_4 = torch.cat((labels_test_1hot , labels_test_1hot_4), 1)
        #print(labels_test_1hot.shape,labels_test_1hot_4.shape)
        #exit(0)
        with torch.no_grad():
            ytest,feature= model_tradclass(images_train, images_train, labels_train_1hot, labels_test_1hot)
        #print(ytest.shape)
        #exit(0)
        images_train=images_train.reshape(4,5,1,3,84,84)
        #images_test=images_test.reshape(4,30,1,3,84,84)        
        feature_cpu=feature.detach().cpu().numpy()
        probs, idx = ytest.detach().sort(1, True)
        probs = probs.cpu().numpy()
        idx = idx.cpu().numpy() 
        #print(pids)
        #print(idx[:,0,0,0])
        #print(idx.shape)
        #exit(0)
        #print(feature.shape)
        #exit(0)
        masks=[]
        #output_cam=[]
        for i in range(feature.shape[0]):
            CAMs=returnCAM(feature_cpu[i], weight_softmax, [idx[i,:4,0,0]],masks)
            masks=CAMs
        #print(len(masks),masks[0].shape)
        masks_tensor = torch.stack(masks, dim=0)
        Xt_masks = masks_tensor.reshape(1,1,4,1,84,84)#[:,:,0]
        Xt_img_ori_repeat=Xt_img_ori.reshape(1,1,1,3,84,84)

        Xt_img_ori_repeat = Xt_img_ori_repeat.repeat(1,1,4,1,1,1)    
        Xt_img_gray_repeat=Xt_img_gray.reshape(1,1,1,1,84,84)

        Xt_img_gray_repeat = Xt_img_gray_repeat.repeat(1,1,4,1,1,1)          
        #print(Xt_img_ori.shape,Xt_masks.shape)
        #exit(0)
        mask_numpy=np.uint8(Xt_masks.numpy()*255)
        #print(mask_numpy.shape,Xt_img_gray_numpy.shape)
        Xt_img_gray_numpy=np.uint8(Xt_img_gray.numpy()*255)
        #print(Xt_img_gray_numpy.shape)
        for i in range(1):
            for j in range(1):
                for k in range(4):
                    edge_PIL=Image.fromarray(load_edge(Xt_img_gray_numpy[i,j,0], mask_numpy[i,j,k,0]))
                    edges.append(Funljj.to_tensor(edge_PIL).float())        
        edges = torch.stack(edges, dim=0) 
        edge_sh=edges#.reshape(4,5,1,84,84)
        #exit(0)        
        #model_edge.test(Xt_img_ori,edge_sh,Xt_img_gray,Xt_masks)
        with torch.no_grad():
            inpaint_img=model_edge.test(Xt_img_ori_repeat.reshape(4,3,84,84),edge_sh,Xt_img_gray_repeat.reshape(4,1,84,84),masks_tensor)
        inpaint_img_np=inpaint_img.detach().cpu().numpy()
        # Debug dump of the four inpainted variants (note: inpaint_img_np is already a NumPy
        # array, and labels_train_ex from the original filename is undefined in this function,
        # so the filename here uses only the variant index).
        for i in range(4):
            images_ori_train = inpaint_img_np[i].transpose((1, 2, 0))[:, :, ::-1]
            images_ori_train = np.uint8(images_ori_train * 255)
            cv2.imwrite('./result/inpaint_img/' + str(i) + '.jpg', images_ori_train)
        exit(0)
        inpaint_img_np=(inpaint_img_np-mean)/std
        #support set
        inpaint_tensor=torch.from_numpy(inpaint_img_np).cuda().reshape(4,5,4,3,84,84).float()        
Example #8
def train(epoch, model_edge, model, model_tradclass, weight_softmax, criterion,
          optimizer, trainloader, learning_rate, use_gpu):

    # Create the 8 mirror directories (train_1 ... train_8) that will hold the inpainted augmentations.
    mini_root = "/data4/lijunjie/mini-imagenet-tools/processed_images_84"
    if not os.path.isdir(os.path.join(mini_root, "train_1")):
        for k in range(1, 9):
            os.mkdir(os.path.join(mini_root, "train_" + str(k)))
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    std = np.expand_dims(np.array([0.229, 0.224, 0.225]), axis=1)
    std = np.expand_dims(std, axis=2)
    mean = np.expand_dims(np.array([0.485, 0.456, 0.406]), axis=1)
    mean = np.expand_dims(mean, axis=2)
    model.eval()
    #model_edge.eval()
    model_tradclass.eval()
    end = time.time()
    #print('llllllllllllll','located in train_with_inpaint_final.py at 264')
    #exit(0)
    for root, dirs, _ in os.walk(
            '/data4/lijunjie/mini-imagenet-tools/processed_images_84/train'):
        #for f in files:
        #print(os.path.join(root, f))

        for d in dirs:
            path = os.path.join(root, d)
            # Mirror this class directory into each of the 8 augmentation trees.
            Paths = [path.replace('train', 'train_' + str(k)) for k in range(1, 9)]
            if not os.path.isdir(Paths[0]):
                for p in Paths:
                    os.mkdir(p)
            files = os.listdir(path)
            for file in files:
                images = []
                imgs_gray = []
                Xt_img_ori = []
                img_ori = read_image(os.path.join(path, file))
                #print(file)
                #exit(0)
                masked_img = np.array(img_ori)  #*(1-mask_3)+mask_3*255
                masked_img = Image.fromarray(masked_img)
                masked_img_tensor = Funljj.to_tensor(masked_img).float()
                Xt_img_ori.append(masked_img_tensor)
                img = transform_test(img_ori)
                img_gray = rgb2gray(np.array(img_ori))
                img_gray = Image.fromarray(img_gray)
                img_gray_tensor = Funljj.to_tensor(img_gray).float()
                imgs_gray.append(img_gray_tensor)
                images.append(img)
                images = torch.stack(images, dim=0)
                imgs_gray = torch.stack(imgs_gray, dim=0)
                Xt_img_ori = torch.stack(Xt_img_ori, dim=0)
                if use_gpu:
                    images_train = images.cuda()
                    imgs_gray = imgs_gray.cuda()
                    Xt_img_ori = Xt_img_ori.cuda()

                with torch.no_grad():
                    ytest, feature = model_tradclass(
                        images_train.reshape(1, 1, 3, 84, 84),
                        images_train.reshape(1, 1, 3, 84, 84),
                        images_train.reshape(1, 1, 3, 84, 84),
                        images_train.reshape(1, 1, 3, 84, 84))
                feature_cpu = feature.detach().cpu().numpy()
                probs, idx = ytest.detach().sort(1, True)
                probs = probs.cpu().numpy()
                idx = idx.cpu().numpy()
                #print(pids)
                #print(idx[:,0,0,0])
                #print(idx.shape)
                #exit(0)
                #print(feature.shape)
                #exit(0)
                masks = []
                edges = []
                #output_cam=[]
                for i in range(feature.shape[0]):
                    CAMs = returnCAM(feature_cpu[i], weight_softmax,
                                     [idx[i, :8, 0, 0]], masks)
                    #for j in range(4):
                    #print(CAMs[j].shape,CAMs[j].max(),CAMs[j].min(),CAMs[j].sum())
                    #exit(0)
                    masks = CAMs
        #print(len(masks),masks[0].shape)
                masks_tensor = torch.stack(masks, dim=0)
                Xt_masks = masks_tensor.reshape(1, 1, 8, 1, 84, 84)  #[:,:,0]
                Xt_img_ori_repeat = Xt_img_ori.reshape(1, 1, 1, 3, 84, 84)

                Xt_img_ori_repeat = Xt_img_ori_repeat.repeat(1, 1, 8, 1, 1, 1)
                Xt_img_gray_repeat = imgs_gray.reshape(1, 1, 1, 1, 84, 84)

                Xt_img_gray_repeat = Xt_img_gray_repeat.repeat(
                    1, 1, 8, 1, 1, 1)  # one copy per CAM mask, matching the reshape(8, 1, 84, 84) below
                #print(Xt_img_ori.shape,Xt_masks.shape)
                #exit(0)
                mask_numpy = np.uint8(Xt_masks.numpy() * 255)
                print(mask_numpy.shape)
                #exit(0)
                Xt_img_gray_numpy = np.uint8(imgs_gray.cpu().numpy() *
                                             255).reshape(1, 1, 1, 84, 84)
                #print(Xt_img_gray_numpy.shape)
                for i in range(1):
                    for j in range(1):
                        for k in range(8):  # one edge map per CAM mask
                            edge_PIL = Image.fromarray(
                                load_edge(Xt_img_gray_numpy[i, j, 0],
                                          mask_numpy[i, j, k, 0]))
                            edges.append(Funljj.to_tensor(edge_PIL).float())
                edges = torch.stack(edges, dim=0)
                edge_sh = edges  #.reshape(4,5,1,84,84)
                #print(edge_sh.shape,Xt_img_gray_repeat.shape,masks_tensor.shape)
                #exit(0)
                #exit(0)
                #model_edge.test(Xt_img_ori,edge_sh,Xt_img_gray,Xt_masks)
                with torch.no_grad():
                    inpaint_img = model_edge.test(
                        Xt_img_ori_repeat.reshape(8, 3, 84, 84), edge_sh,
                        Xt_img_gray_repeat.reshape(8, 1, 84, 84), masks_tensor)
                inpaint_img_np = inpaint_img.detach().cpu().numpy()
                Xt_img_ori_np = Xt_img_ori_repeat.detach().cpu().numpy()
                #print(inpaint_img_np.shape)
                #exit(0)
                for id in range(8):
                    images_temp_train1 = inpaint_img_np[id, :, :]
                    Xt_img_ori_repeat1 = Xt_img_ori_np.reshape(8, 3, 84,
                                                               84)[id, :, :]
                    print(Xt_img_ori_repeat1.shape)
                    #images_temp_train=images_temp_train1*std+mean
                    images_ori_train = images_temp_train1.transpose(
                        (1, 2, 0))[:, :, ::-1]
                    Xt_img_ori_repeat1 = Xt_img_ori_repeat1.transpose(
                        (1, 2, 0))[:, :, ::-1]
                    images_ori_train = np.uint8(images_ori_train * 255)
                    Xt_img_ori_repeat1 = np.uint8(Xt_img_ori_repeat1 * 255)
                    cv2.imwrite(Paths[id] + '/' + file, images_ori_train)
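
# Hedged follow-up (an assumption, not the author's code): the mkdir chains above can be made
# idempotent with os.makedirs(exist_ok=True), so an interrupted preprocessing run can resume
# without the isdir guards.
import os

def make_mirror_dirs_sketch(src_root='/data4/lijunjie/mini-imagenet-tools/processed_images_84/train',
                            num_mirrors=8):
    for k in range(1, num_mirrors + 1):
        mirror = src_root + '_%d' % k
        os.makedirs(mirror, exist_ok=True)
        for class_dir in sorted(os.listdir(src_root)):
            os.makedirs(os.path.join(mirror, class_dir), exist_ok=True)
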
Example #9
def test_trans(model, testloader, candicate_num, use_gpu):
    accs = AverageMeter()
    test_accuracies = []
    model.eval()

    with torch.no_grad():
        for batch_idx, (images_train, labels_train, images_test, labels_test,
                        _) in enumerate(testloader):
            if use_gpu:
                images_train = images_train.cuda()
                images_test = images_test.cuda()
                labels_train = labels_train.cuda()
            # print(images_train.shape)  #[8, 25, 3, 84, 84]
            # print(labels_train.shape)  #8, 25
            # print(images_test.shape)  #[8, 75, 3, 84, 84]
            # print(labels_test.shape)   #8, 75
            end = time.time()
            for transductive_iter in range(2):
                batch_size, num_train_examples, channels, height, width = images_train.size(
                )
                num_test_examples = images_test.size(1)

                labels_test_1hot = one_hot(labels_test).cuda()
                labels_train_1hot = one_hot(labels_train).cuda()
                cls_scores = model(images_train, images_test,
                                   labels_train_1hot, labels_test_1hot)
                cls_scores = cls_scores.view(batch_size * num_test_examples,
                                             -1)
                labels_test = labels_test.view(batch_size * num_test_examples)

                probs, preds = torch.max(cls_scores, 1)

                preds = preds.view(batch_size, num_test_examples)
                probs = probs.view(batch_size, num_test_examples)
                top_k, top_k_id = torch.topk(probs,
                                             candicate_num *
                                             (transductive_iter + 1),
                                             largest=True,
                                             sorted=True)  #  (bs, K)
                candicate_img_index = top_k_id.unsqueeze(-1).unsqueeze(
                    -1).unsqueeze(-1).expand(
                        batch_size, candicate_num * (transductive_iter + 1),
                        *images_test.size()[2:])
                candicate_img = torch.gather(images_test,
                                             index=candicate_img_index,
                                             dim=1)

                candicate_label = torch.gather(preds, dim=1, index=top_k_id)
                images_train = torch.cat((images_train, candicate_img), dim=1)
                labels_train = torch.cat((labels_train, candicate_label),
                                         dim=1)

            # print(top_k_id.shape, top_k.shape)
            batch_size, num_train_examples, channels, height, width = images_train.size(
            )
            num_test_examples = images_test.size(1)

            labels_train_1hot = one_hot(labels_train).cuda()
            labels_test_1hot = one_hot(labels_test).cuda()

            cls_scores = model(images_train, images_test, labels_train_1hot,
                               labels_test_1hot)
            cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
            labels_test = labels_test.view(batch_size * num_test_examples)

            _, preds = torch.max(cls_scores.detach().cpu(), 1)
            acc = (torch.sum(preds == labels_test.detach().cpu()).float()
                   ) / labels_test.size(0)
            accs.update(acc.item(), labels_test.size(0))

            gt = (preds == labels_test.detach().cpu()).float()
            gt = gt.view(batch_size, num_test_examples).numpy()  # [b, n]
            acc = np.sum(gt, 1) / num_test_examples
            acc = np.reshape(acc, (batch_size))
            test_accuracies.append(acc)

    accuracy = accs.avg
    test_accuracies = np.array(test_accuracies)
    test_accuracies = np.reshape(test_accuracies, -1)
    stds = np.std(test_accuracies, 0)
    ci95 = 1.96 * stds / np.sqrt(args.epoch_size)
    print('Accuracy: {:.2%}, 95% CI: {:.2%}'.format(accuracy, ci95))

    return accuracy
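
# Hedged sketch (illustrative shapes only) of the topk + gather step test_trans uses to
# promote the most confident query predictions into the support set as pseudo-labels.
import torch

_bs, _n_query, _k = 2, 6, 3
_probs = torch.rand(_bs, _n_query)                     # max class score per query
_preds = torch.randint(0, 5, (_bs, _n_query))          # predicted class per query
_imgs = torch.randn(_bs, _n_query, 3, 84, 84)

_top_p, _top_id = torch.topk(_probs, _k, dim=1, largest=True, sorted=True)
# Expand the indices so gather can select whole images along the query dimension.
_img_index = _top_id.view(_bs, _k, 1, 1, 1).expand(_bs, _k, 3, 84, 84)
_pseudo_imgs = torch.gather(_imgs, dim=1, index=_img_index)   # [2, 3, 3, 84, 84]
_pseudo_lbls = torch.gather(_preds, dim=1, index=_top_id)     # [2, 3]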