Example 1
def test():
    model_stage1.eval()
    model_select.eval()
    global bestloss
    test_loss = 0
    for sample in val_data.get_loader():
        if (use_gpu):
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)
        stage1_label = model_stage1(sample['image'])
        theta = model_select(stage1_label, sample['size'])

        theta_label = torch.zeros((sample['image'].size()[0], 6, 2, 3),
                                  device=device,
                                  requires_grad=False)  #[batch_size,6,2,3]
        W = 1024.0
        H = 1024.0
        '''        
        cens = torch.floor(calc_centroid_old(sample['label'])) #[batch_size,9,2]     
        for i in range(sample['image'].size()[0]):    
            for j in range(9):
                cens[i,j,0]=cens[i,j,0]*(sample['size'][0][i]-1.0)/(128.0-1.0)
                cens[i,j,1]=cens[i,j,1]*(sample['size'][1][i]-1.0)/(128.0-1.0)        
        points = torch.floor(torch.cat([cens[:, 1:6],cens[:, 6:9].mean(dim=1, keepdim=True)],dim=1)) #[batch_size,6,2]
        '''
        points = torch.floor(calc_centroid(sample['label_org']))
        for i in range(6):
            theta_label[:, i, 0, 0] = (81.0 - 1.0) / (W - 1.0)
            theta_label[:, i, 1, 1] = (81.0 - 1.0) / (H - 1.0)
            theta_label[:, i, 0, 2] = -1 + 2 * points[:, i, 0] / (W - 1.0)
            theta_label[:, i, 1, 2] = -1 + 2 * points[:, i, 1] / (H - 1.0)

        test_loss += fun.smooth_l1_loss(theta, theta_label).data
    test_loss /= len(val_data.get_loader().dataset)
    print('\nTest set: {} Cases,Average loss: {:.8f}\n'.format(
        len(val_data.get_loader().dataset), test_loss))
    loss_list.append(test_loss.data.cpu().numpy())
    if (test_loss < bestloss):
        bestloss = test_loss
        print("Best data Stage1 Update\n")
        torch.save(model_stage1, "./preBestNet_stage1")
        torch.save(model_select, "./preBestNet_select")
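The theta_label entries built above follow PyTorch's align_corners=True affine-grid convention: a scale of (81-1)/(W-1) together with an offset of -1 + 2*cx/(W-1) places the 81x81 sampling window so that its centre pixel lands on the centroid (cx, cy) of the 1024x1024 source, one source pixel per output pixel. A minimal, self-contained sketch that checks this numerically (W, cx and cy are illustrative values, not taken from the original code):

import torch
import torch.nn.functional as F

W = H = 1024.0
cx, cy = 500.0, 300.0  # hypothetical centroid in source pixel coordinates

theta = torch.tensor([[[(81.0 - 1.0) / (W - 1.0), 0.0, -1 + 2 * cx / (W - 1.0)],
                       [0.0, (81.0 - 1.0) / (H - 1.0), -1 + 2 * cy / (H - 1.0)]]])

# Source image whose pixel value equals its own column index.
src = torch.arange(int(W)).float().repeat(int(H), 1).view(1, 1, int(H), int(W))

grid = F.affine_grid(theta, (1, 1, 81, 81), align_corners=True)
patch = F.grid_sample(src, grid, align_corners=True)

print(patch[0, 0, 40, 40].item())  # ~500.0: the patch centre samples column cx
print(patch[0, 0, 40, 0].item())   # ~460.0: the left edge samples column cx - 40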
Example 2
def train(epoch):
    model_stage2.train()
    '''
    part1_time=0;    
    part2_time=0;    
    part3_time=0;         
    prev_time=time.time();             
    '''
    unloader = transforms.ToPILImage()
    losstmp = 0
    lossstep = 0
    k = 0

    for batch_idx, sample in enumerate(train_data.get_loader()):
        '''
        now_time=time.time();
        part3_time+=now_time-prev_time;        
        prev_time=now_time;
        '''
        if (use_gpu):
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['image_org'] = sample['image_org'].to(device)
            sample['label_org'] = sample['label_org'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)
        '''
        for i in range(batch_size):
            image=sample['image_org'][i].cpu().clone();                                 
            image=transforms.ToPILImage()(image).convert('RGB')
            plt.imshow(image);
            plt.show(block=True);            
        '''
        optimizer_stage2.zero_grad()

        theta_label = torch.zeros((sample['image'].size()[0], 6, 2, 3),
                                  device=device,
                                  requires_grad=False)
        #[batch_size,6,2,3]
        W = 1024.0
        H = 1024.0
        points = torch.floor(calc_centroid(sample['label_org']))
        for i in range(6):
            theta_label[:, i, 0, 0] = (81.0 - 1.0) / (W - 1.0)
            theta_label[:, i, 1, 1] = (81.0 - 1.0) / (H - 1.0)
            theta_label[:, i, 0, 2] = -1 + 2 * points[:, i, 0] / (W - 1.0)
            theta_label[:, i, 1, 2] = -1 + 2 * points[:, i, 1] / (H - 1.0)

        stage2_label = model_stage2(sample['image_org'], theta_label)

        parts2 = []
        parts_label2 = []
        loss = []

        for i in range(6):
            affine_stage2 = F.affine_grid(
                theta_label[:, i], (sample['image'].size()[0], 3, 81, 81),
                align_corners=True)
            parts2.append(
                F.grid_sample(sample['image_org'],
                              affine_stage2,
                              align_corners=True))
            affine_stage2 = F.affine_grid(
                theta_label[:, i],
                (sample['image'].size()[0], label_channel[i], 81, 81),
                align_corners=True)
            parts_label2.append(
                F.grid_sample(sample['label_org'][:, label_list[i]],
                              affine_stage2,
                              align_corners=True))
            parts_label2[i][:, 0] += 0.00001
            parts_label2[i] = parts_label2[i].detach()
            '''
            for j in range(sample['image'].size()[0]):
                if (not os.path.exists("./data/trainimg_output/"+train_data.get_namelist()[(k+j)%2000])):
                    os.mkdir("./data/trainimg_output/"+train_data.get_namelist()[(k+j)%2000]);                
                image3=transforms.ToPILImage()(sample['image_org'][j].cpu().clone()).convert('RGB')                   
                image3.save("./data/trainimg_output/"+train_data.get_namelist()[(k+j)%2000]+'/'+str((k+j)//2000)+'_orgimage'+'.jpg',quality=100);    
                     
                image3=transforms.ToPILImage()(parts2[i][j].cpu().clone()).convert('RGB')         
                image3.save("./data/trainimg_output/"+train_data.get_namelist()[(k+j)%2000]+'/'+str((k+j)//2000)+'lbl0'+str(i)+'_thetalabel'+'.jpg',quality=100);
                image3=unloader(np.uint8(parts_label2[i][j][1].cpu().detach().numpy()))                                       
                image3.save("./data/trainimg_output/"+train_data.get_namelist()[(k+j)%2000]+'/'+str((k+j)//2000)+'lbl0'+str(i)+'_label'+'_thetalabel'+'.jpg',quality=100); 
            '''

            loss_tmp = fun.cross_entropy(
                stage2_label[i], parts_label2[i].argmax(dim=1, keepdim=False))
            loss.append(loss_tmp)
        k += sample['image'].size()[0]
        if (batch_idx % 100 == 0):
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(sample['image']),
                len(train_data.get_loader().dataset),
                100. * batch_idx / len(train_data.get_loader()),
                torch.sum(torch.stack(loss))))
        '''    
        now_time=time.time();
        part1_time+=now_time-prev_time;        
        prev_time=now_time;        
        '''

        loss = torch.stack(loss)
        loss.backward(torch.ones(6, device=device, requires_grad=False))
        losstmp += torch.sum(loss).item()
        lossstep += sample['image'].size()[0]

        optimizer_stage2.step()
Example 3
def printoutput():
    model_stage2 = torch.load("./preBestNet_stage2", map_location="cpu")
    if (use_gpu):
        model_stage2 = model_stage2.to(device)
    model_stage2.eval()
    global bestloss, bestf1
    test_loss = 0
    hists = []
    k = 0
    for sample in test_data.get_loader():
        if (use_gpu):
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['image_org'] = sample['image_org'].to(device)
            sample['label_org'] = sample['label_org'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)

        theta_label = torch.zeros((sample['image'].size()[0], 6, 2, 3),
                                  device=device,
                                  requires_grad=False)
        #[batch_size,6,2,3]
        W = 1024.0
        H = 1024.0
        points = torch.floor(calc_centroid(sample['label_org']))
        for i in range(6):
            theta_label[:, i, 0, 0] = (81.0 - 1.0) / (W - 1.0)
            theta_label[:, i, 1, 1] = (81.0 - 1.0) / (H - 1.0)
            theta_label[:, i, 0, 2] = -1 + 2 * points[:, i, 0] / (W - 1.0)
            theta_label[:, i, 1, 2] = -1 + 2 * points[:, i, 1] / (H - 1.0)

        stage2_label = model_stage2(sample['image_org'], theta_label)

        #parts=[];
        parts_label = []
        for i in range(6):
            #affine_stage2=F.affine_grid(theta[:,i],(sample['image'].size()[0],3,81,81),align_corners=True);
            #parts.append(F.grid_sample(sample['image'],affine_stage2),align_corners=True);
            affine_stage2 = F.affine_grid(
                theta_label[:, i],
                (sample['image'].size()[0], label_channel[i], 81, 81),
                align_corners=True)
            parts_label.append(
                F.grid_sample(sample['label_org'][:, label_list[i]],
                              affine_stage2,
                              align_corners=True))

            parts_label[i][:, 0] += 0.00001
            test_loss += fun.cross_entropy(
                stage2_label[i], parts_label[i].argmax(dim=1,
                                                       keepdim=False)).data
            for j in range(sample['image'].shape[0]):
                path = pre_stage2_output_path + '/' + test_data.get_namelist()[
                    k + j]
                if not os.path.exists(path):
                    os.makedirs(path)
                image = TF.to_pil_image(
                    parts_label[i][j][1].unsqueeze(0).cpu())
                image.save(path + '/' + test_data.get_namelist()[k + j] +
                           'lbl0' + str(i) + '_label.jpg',
                           quality=100)
                image = torch.softmax(stage2_label[i][j].cpu(),
                                      dim=0).argmax(dim=0, keepdim=True)
                image = torch.zeros(label_channel[i], 81,
                                    81).scatter_(0, image, 1)
                image = TF.to_pil_image(image[1].unsqueeze(0), mode="L")
                image.save(path + '/' + test_data.get_namelist()[k + j] +
                           'lbl0' + str(i) + '_train.jpg',
                           quality=100)

            output_2 = torch.softmax(stage2_label[i],
                                     dim=1).argmax(dim=1, keepdim=False)
            output_2 = output_2.cpu().clone()
            target_2 = parts_label[i].argmax(dim=1, keepdim=False)
            target_2 = target_2.cpu().clone()
            hist = np.bincount(9 * target_2.reshape([-1]) +
                               output_2.reshape([-1]),
                               minlength=81).reshape(9, 9)
            hists.append(hist)
        k += sample['image'].shape[0]

    hists_sum = np.sum(np.stack(hists, axis=0), axis=0)
    tp = 0
    tpfn = 0
    tpfp = 0
    f1score = 0.0
    for i in range(9):
        for j in range(9):
            print(hists_sum[i][j], end=' ')
        print()
    for i in range(1, 9):
        tp += hists_sum[i][i].sum()
        tpfn += hists_sum[i, :].sum()
        tpfp += hists_sum[:, i].sum()
    f1score = 2 * tp / (tpfn + tpfp)
    test_loss /= len(test_data.get_loader().dataset)
    print('\nPrintoutput Average loss: {:.4f}\n'.format(test_loss))
    print("STN-iCNN stage2 tp=", tp)
    print("STN-iCNN stage2 tpfp=", tpfp)
    print("STN-iCNN stage2 tpfn=", tpfn)
    print('\nPrintoutputF1 Score: {:.4f}\n'.format(f1score))
    print("printoutput Finish")
Example 4
def test(epoch):
    model_stage2.eval()
    global bestloss, bestf1
    test_loss = 0
    hists = []
    for sample in val_data.get_loader():
        if (use_gpu):
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['image_org'] = sample['image_org'].to(device)
            sample['label_org'] = sample['label_org'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)

        theta_label = torch.zeros((sample['image'].size()[0], 6, 2, 3),
                                  device=device,
                                  requires_grad=False)
        #[batch_size,6,2,3]
        W = 1024.0
        H = 1024.0
        points = torch.floor(calc_centroid(sample['label_org']))
        for i in range(6):
            theta_label[:, i, 0, 0] = (81.0 - 1.0) / (W - 1.0)
            theta_label[:, i, 1, 1] = (81.0 - 1.0) / (H - 1.0)
            theta_label[:, i, 0, 2] = -1 + 2 * points[:, i, 0] / (W - 1.0)
            theta_label[:, i, 1, 2] = -1 + 2 * points[:, i, 1] / (H - 1.0)

        stage2_label = model_stage2(sample['image_org'], theta_label)

        #parts=[];
        parts_label = []
        for i in range(6):
            #affine_stage2=F.affine_grid(theta[:,i],(sample['image'].size()[0],3,81,81),align_corners=True);
            #parts.append(F.grid_sample(sample['image'],affine_stage2),align_corners=True);
            affine_stage2 = F.affine_grid(
                theta_label[:, i],
                (sample['image'].size()[0], label_channel[i], 81, 81),
                align_corners=True)
            parts_label.append(
                F.grid_sample(sample['label_org'][:, label_list[i]],
                              affine_stage2,
                              align_corners=True))

            parts_label[i][:, 0] += 0.00001
            test_loss += fun.cross_entropy(
                stage2_label[i], parts_label[i].argmax(dim=1,
                                                       keepdim=False)).data

            output_2 = torch.softmax(stage2_label[i],
                                     dim=1).argmax(dim=1, keepdim=False)
            output_2 = output_2.cpu().clone()
            target_2 = parts_label[i].argmax(dim=1, keepdim=False)
            target_2 = target_2.cpu().clone()
            hist = np.bincount(9 * target_2.reshape([-1]) +
                               output_2.reshape([-1]),
                               minlength=81).reshape(9, 9)
            hists.append(hist)
    hists_sum = np.sum(np.stack(hists, axis=0), axis=0)
    for i in range(9):
        for j in range(9):
            print(hists_sum[i][j], end=' ')
        print()
    print()
    tp = 0
    tpfn = 0
    tpfp = 0
    f1score = 0.0
    for i in range(1, 9):
        tp += hists_sum[i][i].sum()
        tpfn += hists_sum[i, :].sum()
        tpfp += hists_sum[:, i].sum()

    f1score = 2 * tp / (tpfn + tpfp)
    test_loss /= len(val_data.get_loader().dataset)
    print('\nTest set: {} Cases,Average loss: {:.4f}\n'.format(
        len(val_data.get_loader().dataset), test_loss))
    print("STN-iCNN tp=", tp)
    print("STN-iCNN tpfp=", tpfp)
    print("STN-iCNN tpfn=", tpfn)
    print('\nTest set: {} Cases,F1 Score: {:.4f}\n'.format(
        len(val_data.get_loader().dataset), f1score))

    loss_list.append(test_loss.data.cpu().numpy())
    f1_list.append(f1score)
    if (UseF1):
        if (f1score > bestf1):
            bestf1 = f1score
            print("Best data Update\n")
            torch.save(model_stage2, "./preBestNet_stage2")
    else:
        if (test_loss < bestloss):
            bestloss = test_loss
            print("Best data Update\n")
            torch.save(model_stage2, "./preBestNet_stage2")
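For reference, the tp/tpfn/tpfp accumulation above is a micro-averaged F1 over the eight foreground classes of the 9x9 confusion histogram (class 0, the background, is excluded). A small stand-alone sketch of the same computation on a synthetic confusion matrix (not part of the original code), showing that 2*tp/(tpfn+tpfp) equals the usual 2*P*R/(P+R):

import numpy as np

hists_sum = np.random.randint(1, 100, size=(9, 9))  # synthetic 9x9 confusion matrix

tp = sum(hists_sum[i, i] for i in range(1, 9))  # correctly classified foreground pixels
tpfn = hists_sum[1:, :].sum()                   # pixels whose target class is foreground
tpfp = hists_sum[:, 1:].sum()                   # pixels predicted as a foreground class
f1 = 2 * tp / (tpfn + tpfp)                     # micro-averaged F1, background excluded

precision = tp / tpfp
recall = tp / tpfn
assert np.isclose(f1, 2 * precision * recall / (precision + recall))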
Example 5
def printoutput(in_data):
    unloader = transforms.ToPILImage()
    k = 0
    hists = []
    global label_list
    for sample in in_data.get_loader():
        if (use_gpu):
            sample['image_org'] = sample['image_org'].to(device)
            sample['label_org'] = sample['label_org'].to(device)
        N = sample['image_org'].shape[0]
        theta_label = torch.zeros((N, 6, 2, 3),
                                  device=device,
                                  requires_grad=False)
        #[batch_size,6,2,3]
        W = 1024.0
        H = 1024.0
        cens = torch.floor(calc_centroid(
            sample['label_org']))  #[batch_size,9,2]
        points = torch.floor(
            torch.cat([cens[:, 1:6], cens[:, 6:9].mean(dim=1, keepdim=True)],
                      dim=1))  #[batch_size,6,2]
        for i in range(6):
            theta_label[:, i, 0, 0] = (81.0 - 1.0) / (W - 1.0)
            theta_label[:, i, 1, 1] = (81.0 - 1.0) / (H - 1.0)
            theta_label[:, i, 0, 2] = -1 + 2 * points[:, i, 0] / (W - 1.0)
            theta_label[:, i, 1, 2] = -1 + 2 * points[:, i, 1] / (H - 1.0)
        parts = []
        parts_label = []
        for i in range(6):
            affine_stage2 = F.affine_grid(theta_label[:, i], (N, 3, 81, 81),
                                          align_corners=True)
            parts.append(
                F.grid_sample(sample['image_org'],
                              affine_stage2,
                              align_corners=True))
            affine_stage2 = F.affine_grid(theta_label[:, i],
                                          (N, label_channel[i], 81, 81),
                                          align_corners=True)
            parts_label.append(
                F.grid_sample(sample['label_org'][:, label_list[i]],
                              affine_stage2,
                              align_corners=True))
            parts_label[i][:, 0] += 0.00001
            for j in range(sample['image_org'].shape[0]):
                parts_label_tmp = parts_label[i][j].argmax(dim=0, keepdim=True)
                parts_label[i][j] = torch.zeros(label_channel[i], 81,
                                                81).to(device).scatter_(
                                                    0, parts_label_tmp, 255)
            for j in range(N):
                path = "./data/facial_parts/" + in_data.get_namelist()[k + j]
                if (not os.path.exists(path)):
                    os.mkdir(path)
                image3 = transforms.ToPILImage()(
                    parts[i][j].cpu().clone()).convert('RGB')
                image3.save(path + '/' + 'lbl0' + str(i) + '_img' + '.jpg',
                            quality=100)
                for l in range(label_channel[i]):
                    image3 = unloader(
                        np.uint8(parts_label[i][j][l].cpu().detach().numpy()))
                    image3.save(path + '/' + 'lbl0' + str(i) + '_label0' +
                                str(l) + '.jpg',
                                quality=100)
        k += N
        if (k % 200 == 0): print(k)
    print("Printoutput Finish!")
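calc_centroid itself is not shown in these examples; judging from the shape comment above (it returns [batch_size, 9, 2], one (x, y) point per label channel), a plausible minimal implementation is the intensity-weighted centroid of each channel of the label map. The sketch below is an assumption for illustration only; the real function may differ:

import torch

def calc_centroid_sketch(labels):
    # labels: [batch, C, H, W] soft label maps -> [batch, C, 2] (x, y) centroids.
    n, c, h, w = labels.shape
    ys = torch.arange(h, dtype=labels.dtype, device=labels.device).view(1, 1, h, 1)
    xs = torch.arange(w, dtype=labels.dtype, device=labels.device).view(1, 1, 1, w)
    mass = labels.sum(dim=(2, 3)).clamp(min=1e-6)  # avoid division by zero on empty channels
    cx = (labels * xs).sum(dim=(2, 3)) / mass      # [batch, C]
    cy = (labels * ys).sum(dim=(2, 3)) / mass      # [batch, C]
    return torch.stack([cx, cy], dim=2)            # [batch, C, 2]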
Example 6
def train(epoch):
    model_stage1.train()
    model_select.train()
    '''
    part1_time=0;    
    part2_time=0;    
    part3_time=0;         
    prev_time=time.time();  
    '''
    k = 0
    for batch_idx, sample in enumerate(train_data.get_loader()):
        '''
        now_time=time.time();
        part3_time+=now_time-prev_time;        
        prev_time=now_time;
        '''
        if (use_gpu):
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)

        optimizer_stage1.zero_grad()
        optimizer_select.zero_grad()

        stage1_label = model_stage1(sample['image'])
        theta = model_select(stage1_label, sample['size'])

        theta_label = torch.zeros((sample['image'].size()[0], 6, 2, 3),
                                  device=device,
                                  requires_grad=False)  #[batch_size,6,2,3]
        W = 1024.0
        H = 1024.0
        '''
        cens = torch.floor(calc_centroid_old(sample['label'])) #[batch_size,9,2]           
        for i in range(sample['image'].size()[0]):    
            for j in range(9):
                cens[i,j,0]=cens[i,j,0]*(sample['size'][0][i]-1.0)/(128.0-1.0)
                cens[i,j,1]=cens[i,j,1]*(sample['size'][1][i]-1.0)/(128.0-1.0)        
        points = torch.floor(torch.cat([cens[:, 1:6],cens[:, 6:9].mean(dim=1, keepdim=True)],dim=1)) #[batch_size,6,2]
        '''
        '''
        points2 = torch.floor(calc_centroid(sample['label_org'])) #[batch_size,9,2]
        print("cens resize:");
        print(points);
        print("cens org:");   
        print(points2);
        print("delta");
        print(points.cpu()-points2);
        input("wait");
        '''
        points = torch.floor(calc_centroid(sample['label_org']))
        for i in range(6):
            theta_label[:, i, 0, 0] = (81.0 - 1.0) / (W - 1.0)
            theta_label[:, i, 1, 1] = (81.0 - 1.0) / (H - 1.0)
            theta_label[:, i, 0, 2] = -1 + 2 * points[:, i, 0] / (W - 1.0)
            theta_label[:, i, 1, 2] = -1 + 2 * points[:, i, 1] / (H - 1.0)
        if (torch.min(theta_label) < -1 or torch.max(theta_label) > 1):
            print("Warning: theta_label outside [-1, 1] range at sample index", k)

        '''
        for i in range(sample['image'].shape[0]):            
            if (not os.path.exists("./data/select_pre/"+train_data.get_namelist()[(k+i)%2000])):
                os.mkdir("./data/select_pre/"+train_data.get_namelist()[(k+i)%2000]);
            image=sample['image_org'][i].cpu().clone();                                 
            image=transforms.ToPILImage()(image).convert('RGB')
            plt.imshow(image);        
            plt.show(block=True);                
            image.save('./data/select_pre/'+train_data.get_namelist()[(k+i)%2000]+'/'+str((k+i)//2000)+'_img'+'.jpg',quality=100);
            for j in range(6):                            
                affine_stage2=F.affine_grid(theta_label[i][j].unsqueeze(0),(1,1,81,81),align_corners=True);    
                image=F.grid_sample(sample['label_org'][i][label_list[j][1]].unsqueeze(0).unsqueeze(0).to(device),affine_stage2,align_corners=True);
                image=image.squeeze(0).cpu();                                                
                image=transforms.ToPILImage()(image);
                image.save('./data/select_pre/'+train_data.get_namelist()[(k+i)%2000]+'/'+str((k+i)//2000)+'_'+str(j)+'_thetalabel'+'.jpg',quality=100);
                image=sample['label_org'][i][label_list[j][1]]
                image=transforms.ToPILImage()(image);
                image.save('./data/select_pre/'+train_data.get_namelist()[(k+i)%2000]+'/'+str((k+i)//2000)+'_'+str(j)+'_orglabel'+'.jpg',quality=100);
                #plt.imshow(image);        
                #plt.show(block=True);      
        '''
        
        k += sample['image'].shape[0]
        loss = fun.smooth_l1_loss(theta, theta_label)
        '''
        now_time=time.time();
        part1_time+=now_time-prev_time;        
        prev_time=now_time;  
        '''
        loss.backward()

        optimizer_select.step()
        optimizer_stage1.step()
        '''
        now_time=time.time();
        part2_time+=now_time-prev_time;        
        prev_time=now_time;
        '''
        if (batch_idx % 250 == 0):
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(sample['image']),
                len(train_data.get_loader().dataset),
                100. * batch_idx / len(train_data.get_loader()), loss))
Example 7
def printoutput():
    CheckBest = True
    if (CheckBest):
        model_stage1 = torch.load("./preBestnet_stage1", map_location="cpu")
        model_select = torch.load("./preBestnet_select", map_location="cpu")
    else:
        model_stage1 = torch.load("./preNetdata_stage1", map_location="cpu")
        model_select = torch.load("./preNetdata_select", map_location="cpu")
    model_select.select.change_device(device)
    if (use_gpu):
        model_stage1 = model_stage1.to(device)
        model_select = model_select.to(device)
    unloader = transforms.ToPILImage()
    k = 0
    loss = 0
    for sample in test_data.get_loader():
        if (use_gpu):
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)
            sample['label_org'] = sample['label_org'].to(device)

        stage1_label = model_stage1(sample['image'])
        theta = model_select(stage1_label, sample['size'])
        theta_label = torch.zeros((sample['image'].size()[0], 6, 2, 3),
                                  device=device,
                                  requires_grad=False)  #[batch_size,6,2,3]
        W = 1024.0
        H = 1024.0
        '''
        cens = torch.floor(calc_centroid_old(sample['label'])) #[batch_size,9,2]           
        for i in range(sample['image'].size()[0]):    
            for j in range(9):
                cens[i,j,0]=cens[i,j,0]*(sample['size'][0][i]-1.0)/(128.0-1.0)
                cens[i,j,1]=cens[i,j,1]*(sample['size'][1][i]-1.0)/(128.0-1.0)        
        points = torch.floor(torch.cat([cens[:, 1:6],cens[:, 6:9].mean(dim=1, keepdim=True)],dim=1)) #[batch_size,6,2]                    
        points2 = torch.floor(calc_centroid(sample['label_org'])) #[batch_size,9,2]
        theta_label2 = torch.zeros((sample['image'].size()[0],6,2,3),device=device,requires_grad=False); #[batch_size,6,2,3]    
        for i in range(6):
            theta_label2[:,i,0,0]=(81.0-1.0)/(W-1.0);
            theta_label2[:,i,1,1]=(81.0-1.0)/(H-1.0);
            theta_label2[:,i,0,2]=-1+2*points2[:,i,0]/(W-1.0);
            theta_label2[:,i,1,2]=-1+2*points2[:,i,1]/(H-1.0);           
        if (abs(torch.max(points.cpu()-points2.cpu()))>20 or abs(torch.min(points.cpu()-points2.cpu()))>20):
            print("points resize:");
            print(points);
            print("points org:");   
            print(points2);
            print("delta");
            print(points.cpu()-points2.cpu());
            for i in range(sample['image'].size()[0]):
                print(test_data.get_namelist()[k+i]);
            input("wait");
        '''        
        points = torch.floor(calc_centroid(sample['label_org']))
        for i in range(6):
            theta_label[:, i, 0, 0] = (81.0 - 1.0) / (W - 1.0)
            theta_label[:, i, 1, 1] = (81.0 - 1.0) / (H - 1.0)
            theta_label[:, i, 0, 2] = -1 + 2 * points[:, i, 0] / (W - 1.0)
            theta_label[:, i, 1, 2] = -1 + 2 * points[:, i, 1] / (H - 1.0)

        loss += fun.smooth_l1_loss(theta, theta_label).detach().data

        '''        
        f=open("1.txt",mode='w');
        print(sample['label'].size())
        for k in range(9):
            for i in range(sample['label'][0][k].size()[0]):
                for j in range(sample['label'][0][k].size()[1]):
                    print(float(sample['label'][0][k][i][j].data),end=' ',file=f);
                print(file=f);
        f.close();
        for i in range(batch_size):
            for j in range(9):
                image=sample['label'][i][j].cpu().clone();                                 
                image=transforms.ToPILImage()(image).convert('L')
                plt.imshow(image);
                plt.show(block=True);   
        input("check")
        '''
        
        '''
        print(theta);
        print(theta_label);
        input("check")
        '''
        output = []
        for i in range(sample['image'].size()[0]):
            '''
            if (test_data.get_namelist()[k]=="13601661_1"):
                print(cens[i]);
                print(points[i]);
                print(theta_label[i]);
                input("wait")
            '''
            path = pre_output_path + '/' + test_data.get_namelist()[k]
            if not os.path.exists(path):
                os.makedirs(path)
            image = sample['image'][i].cpu().clone()
            image = unloader(image)
            image.save(path + '/' + test_data.get_namelist()[k] + '_img.jpg', quality=100)

            image = sample['image_org'][i].cpu().clone()
            image2 = unloader(image)
            image2.save(path + '/' + test_data.get_namelist()[k] + '_img_org.jpg', quality=100)

            output2 = stage1_label[i].cpu().clone()
            output2 = torch.softmax(output2, dim=0).argmax(dim=0, keepdim=False)
            output2 = output2.unsqueeze(0)
            output2 = torch.zeros(9, 128, 128).scatter_(0, output2, 255)

            output.append(output2)
            for j in range(9):
                image3 = unloader(np.uint8(output[i][j].numpy()))
                #image3=transforms.ToPILImage()(output[i][j].cpu().detach()).convert('L')
                image3.save(path + '/' + 'stage1_' + test_data.get_namelist()[k] + 'lbl0' + str(j) + '.jpg', quality=100)
                image3 = transforms.ToPILImage()(sample['label_org'][i][j].cpu().detach()).convert('L')
                image3.save(path + '/' + test_data.get_namelist()[k] + 'lbl0' + str(j) + '_label' + '.jpg', quality=100)
            image = image.to(device).unsqueeze(0)
            for j in range(6):
                affine_stage2 = F.affine_grid(theta[i, j].unsqueeze(0), (1, 3, 81, 81), align_corners=True)
                image3 = F.grid_sample(sample['image_org'][i].to(device).unsqueeze(0), affine_stage2, align_corners=True)
                image3 = unloader(image3[0].cpu())
                image3.save(path + '/' + 'select_' + test_data.get_namelist()[k] + 'lbl0' + str(j) + '.jpg', quality=100)
                affine_stage2 = F.affine_grid(theta_label[i, j].unsqueeze(0), (1, 3, 81, 81), align_corners=True)
                image3 = F.grid_sample(sample['image_org'][i].to(device).unsqueeze(0), affine_stage2, align_corners=True)
                image3 = unloader(image3[0].cpu())
                image3.save(path + '/' + 'select_' + test_data.get_namelist()[k] + 'lbl0' + str(j) + '_thetalabel' + '.jpg', quality=100)
                '''
                affine_stage2=F.affine_grid(theta_label2[i,j].unsqueeze(0),(1,3,81,81),align_corners=True);                                
                image3=F.grid_sample(sample['image_org'][i].to(device).unsqueeze(0),affine_stage2,align_corners=True);   
                image3=unloader(image3[0].cpu());
                image3.save(path+'/'+'select_'+test_data.get_namelist()[k]+'lbl0'+str(j)+'_thetalabel2'+'.jpg',quality=100);   
                '''
                affine_stage2 = F.affine_grid(theta[i, j].unsqueeze(0), (1, 1, 81, 81), align_corners=True)
                image3 = F.grid_sample(sample['label_org'][i][label_list[j][1]].unsqueeze(0).unsqueeze(0), affine_stage2, align_corners=True)
                image3 = transforms.ToPILImage()(image3[0].squeeze(0).cpu().detach()).convert('L')
                image3.save(path + '/' + 'select_' + test_data.get_namelist()[k] + 'lbl0' + str(j) + '_label' + '.jpg', quality=100)
                affine_stage2 = F.affine_grid(theta_label[i, j].unsqueeze(0), (1, 1, 81, 81), align_corners=True)
                image3 = F.grid_sample(sample['label_org'][i][label_list[j][1]].unsqueeze(0).unsqueeze(0), affine_stage2, align_corners=True)
                image3 = transforms.ToPILImage()(image3[0].squeeze(0).cpu().detach()).convert('L')
                image3.save(path + '/' + 'select_' + test_data.get_namelist()[k] + 'lbl0' + str(j) + '_label' + '_thetalabel' + '.jpg', quality=100)
            k += 1
            if (k >= test_data.get_len()): break

        if (k >= test_data.get_len()): break
    loss /= len(test_data.get_loader().dataset)
    print('\nTest set: {} Cases,Average loss: {:.8f}\n'.format(
        len(test_data.get_loader().dataset), loss))
    print("printoutput Finish")