def test():
    """Evaluate the stage-1 and select networks on the validation loader.

    For every validation batch the predicted affine parameters ``theta`` are
    compared (smooth-L1) with ``theta_label``, which is built from the floored
    centroids returned by ``calc_centroid``.  The summed loss is divided by the
    validation-set size, appended to ``loss_list`` and, when it is lower than
    the global ``bestloss``, both models are saved to ``./preBestNet_stage1``
    and ``./preBestNet_select``.
    """
    global bestloss
    model_stage1.eval()
    model_select.eval()

    # Normalisation constants; presumably the size of the original-resolution
    # images -- TODO confirm against the data pipeline.
    width = 1024.0
    height = 1024.0

    loader = val_data.get_loader()
    # Start from a tensor so an empty loader does not break the bookkeeping
    # below (the original started from the int 0 and then read ``.data``).
    test_loss = torch.zeros((), device=device)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for sample in loader:
            if use_gpu:
                sample['image'] = sample['image'].to(device)
                sample['label'] = sample['label'].to(device)
                sample['size'][0] = sample['size'][0].to(device)
                sample['size'][1] = sample['size'][1].to(device)

            stage1_label = model_stage1(sample['image'])
            theta = model_select(stage1_label, sample['size'])

            batch = sample['image'].size()[0]
            # [batch_size, 6, 2, 3]
            theta_label = torch.zeros((batch, 6, 2, 3), device=device,
                                      requires_grad=False)
            # ``label_org`` is not moved above, so move the centroids
            # explicitly onto the device that holds ``theta_label``.
            points = torch.floor(calc_centroid(sample['label_org'])).to(device)
            # Fixed scale (81-pixel window) and centroid-based translation for
            # each of the six parts.
            theta_label[:, :, 0, 0] = (81.0 - 1.0) / (width - 1.0)
            theta_label[:, :, 1, 1] = (81.0 - 1.0) / (height - 1.0)
            theta_label[:, :, 0, 2] = -1 + 2 * points[:, :6, 0] / (width - 1.0)
            theta_label[:, :, 1, 2] = -1 + 2 * points[:, :6, 1] / (height - 1.0)

            test_loss += fun.smooth_l1_loss(theta, theta_label).detach()

    # NOTE(review): each term above is one value per batch, yet the sum is
    # divided by the number of samples (kept from the original code).
    num_cases = len(loader.dataset)
    test_loss /= num_cases
    # Report the size of the set that was actually evaluated (val_data); the
    # original printed the size of test_data here.
    print('\nTest set: {} Cases,Average loss: {:.8f}\n'.format(
        num_cases, test_loss.item()))
    loss_list.append(test_loss.cpu().numpy())

    if test_loss < bestloss:
        bestloss = test_loss
        print("Best data Stage1 Updata\n")
        torch.save(model_stage1, "./preBestNet_stage1")
        torch.save(model_select, "./preBestNet_select")
def train(epoch):
    """Train the stage-2 network for one epoch on ``train_data``.

    For every batch the ground-truth affine parameters of the six part windows
    are built from the floored centroids returned by ``calc_centroid``.  The
    matching 81x81 label crops are resampled with them, and the six per-part
    cross-entropy losses are summed and back-propagated before
    ``optimizer_stage2`` takes a step.

    Args:
        epoch: Epoch number; only used in the progress message.
    """
    model_stage2.train()

    # Running totals kept from the original code.  Nothing in the visible part
    # of this function reads them; the tail of the original definition is a
    # truncated commented-out block.
    losstmp = 0
    lossstep = 0

    # Normalisation constants; presumably the size of the original-resolution
    # images -- TODO confirm against the data pipeline.
    width = 1024.0
    height = 1024.0

    loader = train_data.get_loader()
    for batch_idx, sample in enumerate(loader):
        if use_gpu:
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['image_org'] = sample['image_org'].to(device)
            sample['label_org'] = sample['label_org'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)

        optimizer_stage2.zero_grad()

        batch = sample['image'].size()[0]
        # [batch_size, 6, 2, 3]
        theta_label = torch.zeros((batch, 6, 2, 3), device=device,
                                  requires_grad=False)
        points = torch.floor(calc_centroid(sample['label_org']))
        # With align_corners=True this scale/translation pair selects an
        # axis-aligned 81-pixel window centred on each centroid.
        theta_label[:, :, 0, 0] = (81.0 - 1.0) / (width - 1.0)
        theta_label[:, :, 1, 1] = (81.0 - 1.0) / (height - 1.0)
        theta_label[:, :, 0, 2] = -1 + 2 * points[:, :6, 0] / (width - 1.0)
        theta_label[:, :, 1, 2] = -1 + 2 * points[:, :6, 1] / (height - 1.0)

        stage2_label = model_stage2(sample['image_org'], theta_label)

        # The original also resampled the image crops ("parts2") here, but
        # only disabled debug code consumed them, so that work is dropped.
        part_losses = []
        for i in range(6):
            grid = F.affine_grid(theta_label[:, i],
                                 (batch, label_channel[i], 81, 81),
                                 align_corners=True)
            part_label = F.grid_sample(sample['label_org'][:, label_list[i]],
                                       grid, align_corners=True).detach()
            # Tiny bias on channel 0 so argmax resolves to channel 0 where all
            # channels are equal (presumably background -- confirm).
            part_label[:, 0] += 0.00001
            part_losses.append(
                fun.cross_entropy(stage2_label[i],
                                  part_label.argmax(dim=1, keepdim=False)))

        total_loss = torch.stack(part_losses).sum()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(sample['image']),
                len(loader.dataset),
                100. * batch_idx / len(loader), total_loss.item()))

        # Same gradients as backward(torch.ones(6)) on the stacked losses.
        total_loss.backward()
        losstmp += total_loss.item()
        lossstep += batch
        optimizer_stage2.step()
def printoutput():
    """Run the saved best stage-2 model on ``test_data`` and dump its output.

    Loads ``./preBestNet_stage2``, evaluates it on every test batch, writes
    the channel-1 ground-truth crop and the channel-1 predicted mask of each
    part as JPEG files under ``pre_stage2_output_path/<name>/``, and prints
    the 9x9 confusion histogram, the averaged cross-entropy loss and the
    foreground F1 score.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    model_stage2 = torch.load("./preBestNet_stage2", map_location="cpu")
    if use_gpu:
        model_stage2 = model_stage2.to(device)
    model_stage2.eval()

    # Normalisation constants; presumably the size of the original-resolution
    # images -- TODO confirm against the data pipeline.
    width = 1024.0
    height = 1024.0

    loader = test_data.get_loader()
    test_loss = torch.zeros((), device=device)
    # Running confusion histogram (rows: target index, columns: prediction).
    hists_sum = np.zeros((9, 9), dtype=np.int64)
    k = 0  # index of the first sample of the current batch in the name list

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for sample in loader:
            if use_gpu:
                sample['image'] = sample['image'].to(device)
                sample['label'] = sample['label'].to(device)
                sample['image_org'] = sample['image_org'].to(device)
                sample['label_org'] = sample['label_org'].to(device)
                sample['size'][0] = sample['size'][0].to(device)
                sample['size'][1] = sample['size'][1].to(device)

            batch = sample['image'].size()[0]
            # [batch_size, 6, 2, 3]
            theta_label = torch.zeros((batch, 6, 2, 3), device=device,
                                      requires_grad=False)
            points = torch.floor(calc_centroid(sample['label_org']))
            theta_label[:, :, 0, 0] = (81.0 - 1.0) / (width - 1.0)
            theta_label[:, :, 1, 1] = (81.0 - 1.0) / (height - 1.0)
            theta_label[:, :, 0, 2] = -1 + 2 * points[:, :6, 0] / (width - 1.0)
            theta_label[:, :, 1, 2] = -1 + 2 * points[:, :6, 1] / (height - 1.0)

            stage2_label = model_stage2(sample['image_org'], theta_label)
            names = test_data.get_namelist()

            for i in range(6):
                grid = F.affine_grid(theta_label[:, i],
                                     (batch, label_channel[i], 81, 81),
                                     align_corners=True)
                part_label = F.grid_sample(
                    sample['label_org'][:, label_list[i]], grid,
                    align_corners=True)
                # Tiny bias on channel 0 so argmax resolves to channel 0 where
                # all channels are equal (presumably background -- confirm).
                part_label[:, 0] += 0.00001
                target = part_label.argmax(dim=1, keepdim=False)
                # Softmax is monotonic, so the argmax of the raw scores gives
                # the same class as the argmax of the softmax.
                pred = stage2_label[i].argmax(dim=1, keepdim=False)

                test_loss += fun.cross_entropy(stage2_label[i], target)

                for j in range(batch):
                    name = names[k + j]
                    path = pre_stage2_output_path + '/' + name
                    os.makedirs(path, exist_ok=True)

                    image = TF.to_pil_image(
                        part_label[j][1].unsqueeze(0).cpu())
                    image.save(path + '/' + name + 'lbl0' + str(i) +
                               '_label.jpg', quality=100)

                    # Same image as channel 1 of a one-hot encoding of the
                    # predicted class map.
                    mask = (pred[j] == 1).float().unsqueeze(0).cpu()
                    image = TF.to_pil_image(mask, mode="L")
                    image.save(path + '/' + name + 'lbl0' + str(i) +
                               '_train.jpg', quality=100)

                flat = (9 * target + pred).reshape(-1).cpu().numpy()
                hists_sum += np.bincount(flat, minlength=81).reshape(9, 9)

            k += batch

    for row in hists_sum:
        print(*row, end=' \n')

    # Indices 1..8 are counted as foreground; index 0 is excluded.
    tp = hists_sum[1:, 1:].diagonal().sum()
    tpfn = hists_sum[1:, :].sum()
    tpfp = hists_sum[:, 1:].sum()
    denom = tpfn + tpfp
    # Guard the degenerate case with no foreground target or prediction.
    f1score = 2 * tp / denom if denom > 0 else 0.0

    test_loss /= len(loader.dataset)
    print('\nPrintoutput Average loss: {:.4f}\n'.format(test_loss.item()))
    print("STN-iCNN stage2 tp=", tp)
    print("STN-iCNN stage2 tpfp=", tpfp)
    print("STN-iCNN stage2 tpfn=", tpfn)
    print('\nPrintoutputF1 Score: {:.4f}\n'.format(f1score))
    print("printoutput Finish")
def test(epoch):
    """Evaluate the stage-2 network on the validation loader.

    Accumulates the per-part cross-entropy loss and a 9x9 confusion histogram
    over all validation batches, prints the histogram, the averaged loss and
    the foreground F1 score, appends both metrics to ``loss_list`` and
    ``f1_list``, and saves ``model_stage2`` to ``./preBestNet_stage2`` when
    the selected metric (``UseF1`` chooses F1 or loss) improves.

    Args:
        epoch: Epoch number; not used inside the function.
    """
    global bestloss, bestf1
    model_stage2.eval()

    # Normalisation constants; presumably the size of the original-resolution
    # images -- TODO confirm against the data pipeline.
    width = 1024.0
    height = 1024.0

    loader = val_data.get_loader()
    test_loss = torch.zeros((), device=device)
    # Running confusion histogram (rows: target index, columns: prediction).
    hists_sum = np.zeros((9, 9), dtype=np.int64)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for sample in loader:
            if use_gpu:
                sample['image'] = sample['image'].to(device)
                sample['label'] = sample['label'].to(device)
                sample['image_org'] = sample['image_org'].to(device)
                sample['label_org'] = sample['label_org'].to(device)
                sample['size'][0] = sample['size'][0].to(device)
                sample['size'][1] = sample['size'][1].to(device)

            batch = sample['image'].size()[0]
            # [batch_size, 6, 2, 3]
            theta_label = torch.zeros((batch, 6, 2, 3), device=device,
                                      requires_grad=False)
            points = torch.floor(calc_centroid(sample['label_org']))
            theta_label[:, :, 0, 0] = (81.0 - 1.0) / (width - 1.0)
            theta_label[:, :, 1, 1] = (81.0 - 1.0) / (height - 1.0)
            theta_label[:, :, 0, 2] = -1 + 2 * points[:, :6, 0] / (width - 1.0)
            theta_label[:, :, 1, 2] = -1 + 2 * points[:, :6, 1] / (height - 1.0)

            stage2_label = model_stage2(sample['image_org'], theta_label)

            for i in range(6):
                grid = F.affine_grid(theta_label[:, i],
                                     (batch, label_channel[i], 81, 81),
                                     align_corners=True)
                part_label = F.grid_sample(
                    sample['label_org'][:, label_list[i]], grid,
                    align_corners=True)
                # Tiny bias on channel 0 so argmax resolves to channel 0 where
                # all channels are equal (presumably background -- confirm).
                part_label[:, 0] += 0.00001
                target = part_label.argmax(dim=1, keepdim=False)
                # Softmax is monotonic, so the argmax of the raw scores gives
                # the same class as the argmax of the softmax.
                pred = stage2_label[i].argmax(dim=1, keepdim=False)

                test_loss += fun.cross_entropy(stage2_label[i], target)

                flat = (9 * target + pred).reshape(-1).cpu().numpy()
                hists_sum += np.bincount(flat, minlength=81).reshape(9, 9)

    for row in hists_sum:
        print(*row, end=' \n')
    print()

    # Indices 1..8 are counted as foreground; index 0 is excluded.
    tp = hists_sum[1:, 1:].diagonal().sum()
    tpfn = hists_sum[1:, :].sum()
    tpfp = hists_sum[:, 1:].sum()
    denom = tpfn + tpfp
    # Guard the degenerate case with no foreground target or prediction.
    f1score = 2 * tp / denom if denom > 0 else 0.0

    # Normalise and report with the set that was actually evaluated
    # (val_data); the original used the size of test_data here.
    num_cases = len(loader.dataset)
    test_loss /= num_cases
    print('\nTest set: {} Cases,Average loss: {:.4f}\n'.format(
        num_cases, test_loss.item()))
    print("STN-iCNN tp=", tp)
    print("STN-iCNN tpfp=", tpfp)
    print("STN-iCNN tpfn=", tpfn)
    print('\nTest set: {} Cases,F1 Score: {:.4f}\n'.format(
        num_cases, f1score))

    loss_list.append(test_loss.cpu().numpy())
    f1_list.append(f1score)

    if UseF1:
        if f1score > bestf1:
            bestf1 = f1score
            print("Best data Updata\n")
            torch.save(model_stage2, "./preBestNet_stage2")
    else:
        if test_loss < bestloss:
            bestloss = test_loss
            print("Best data Updata\n")
            torch.save(model_stage2, "./preBestNet_stage2")
def printoutput(in_data):
    """Crop the six part windows of every sample and save them as JPEGs.

    For each sample of ``in_data`` the image crop and the one-hot (0/255)
    label channels of every part are written to
    ``./data/facial_parts/<name>/``.

    Args:
        in_data: Object providing ``get_loader()`` and ``get_namelist()``;
            batches must contain ``'image_org'`` and ``'label_org'``.
    """
    unloader = transforms.ToPILImage()
    # Normalisation constants; presumably the size of the original-resolution
    # images -- TODO confirm against the data pipeline.
    width = 1024.0
    height = 1024.0
    k = 0  # index of the first sample of the current batch in the name list

    for sample in in_data.get_loader():
        if use_gpu:
            sample['image_org'] = sample['image_org'].to(device)
            sample['label_org'] = sample['label_org'].to(device)

        batch = sample['image_org'].shape[0]
        # [batch_size, 6, 2, 3]
        theta_label = torch.zeros((batch, 6, 2, 3), device=device,
                                  requires_grad=False)
        # [batch_size, 9, 2]
        cens = torch.floor(calc_centroid(sample['label_org']))
        # Entries 1..5 are used as they are; entries 6..8 are merged into
        # their mean, giving [batch_size, 6, 2].
        points = torch.floor(
            torch.cat([cens[:, 1:6],
                       cens[:, 6:9].mean(dim=1, keepdim=True)], dim=1))
        # With align_corners=True this scale/translation pair selects an
        # axis-aligned 81-pixel window centred on each point.
        theta_label[:, :, 0, 0] = (81.0 - 1.0) / (width - 1.0)
        theta_label[:, :, 1, 1] = (81.0 - 1.0) / (height - 1.0)
        theta_label[:, :, 0, 2] = -1 + 2 * points[:, :, 0] / (width - 1.0)
        theta_label[:, :, 1, 2] = -1 + 2 * points[:, :, 1] / (height - 1.0)

        names = in_data.get_namelist()
        for i in range(6):
            grid = F.affine_grid(theta_label[:, i], (batch, 3, 81, 81),
                                 align_corners=True)
            part_img = F.grid_sample(sample['image_org'], grid,
                                     align_corners=True)

            grid = F.affine_grid(theta_label[:, i],
                                 (batch, label_channel[i], 81, 81),
                                 align_corners=True)
            part_label = F.grid_sample(sample['label_org'][:, label_list[i]],
                                       grid, align_corners=True)
            # Tiny bias on channel 0 so argmax resolves to channel 0 where all
            # channels are equal (presumably background -- confirm).
            part_label[:, 0] += 0.00001
            # One-hot encode the winning channel with value 255 for the whole
            # batch at once instead of one sample at a time.
            winner = part_label.argmax(dim=1, keepdim=True)
            part_label = torch.zeros_like(part_label).scatter_(1, winner, 255)

            # Move each part to the CPU once rather than once per saved file.
            part_img_cpu = part_img.cpu()
            part_label_np = np.uint8(part_label.cpu().detach().numpy())

            for j in range(batch):
                path = "./data/facial_parts/" + names[k + j]
                # makedirs also creates missing parent directories and does
                # not fail when the directory already exists.
                os.makedirs(path, exist_ok=True)

                image3 = unloader(part_img_cpu[j]).convert('RGB')
                image3.save(path + '/' + 'lbl0' + str(i) + '_img' + '.jpg',
                            quality=100)
                for l in range(label_channel[i]):
                    image3 = unloader(part_label_np[j][l])
                    image3.save(path + '/' + 'lbl0' + str(i) + '_label0' +
                                str(l) + '.jpg', quality=100)

        k += batch
        if k % 200 == 0:
            print(k)

    print("Printoutput Finish!")
def train(epoch):
    """Train the stage-1 and select networks for one epoch on ``train_data``.

    For every batch the predicted affine parameters ``theta`` are regressed
    (smooth-L1) onto ``theta_label``, which is built from the floored
    centroids returned by ``calc_centroid``.  Both optimizers are stepped
    after each backward pass.

    Args:
        epoch: Epoch number; only used in the progress message.
    """
    model_stage1.train()
    model_select.train()

    # Normalisation constants; presumably the size of the original-resolution
    # images -- TODO confirm against the data pipeline.
    width = 1024.0
    height = 1024.0

    k = 0  # number of samples processed before the current batch
    loader = train_data.get_loader()
    for batch_idx, sample in enumerate(loader):
        if use_gpu:
            sample['image'] = sample['image'].to(device)
            sample['label'] = sample['label'].to(device)
            sample['size'][0] = sample['size'][0].to(device)
            sample['size'][1] = sample['size'][1].to(device)

        optimizer_stage1.zero_grad()
        optimizer_select.zero_grad()

        stage1_label = model_stage1(sample['image'])
        theta = model_select(stage1_label, sample['size'])

        batch = sample['image'].size()[0]
        # [batch_size, 6, 2, 3]
        theta_label = torch.zeros((batch, 6, 2, 3), device=device,
                                  requires_grad=False)
        # ``label_org`` is not moved above, so move the centroids explicitly
        # onto the device that holds ``theta_label``.
        points = torch.floor(calc_centroid(sample['label_org'])).to(device)
        # Fixed scale (81-pixel window) and centroid-based translation for
        # each of the six parts.
        theta_label[:, :, 0, 0] = (81.0 - 1.0) / (width - 1.0)
        theta_label[:, :, 1, 1] = (81.0 - 1.0) / (height - 1.0)
        theta_label[:, :, 0, 2] = -1 + 2 * points[:, :6, 0] / (width - 1.0)
        theta_label[:, :, 1, 2] = -1 + 2 * points[:, :6, 1] / (height - 1.0)

        # A translation outside [-1, 1] means a centroid fell outside the
        # assumed image extent; report where it happened.
        if torch.min(theta_label) < -1 or torch.max(theta_label) > 1:
            print('Warning: theta_label outside [-1, 1] in the batch '
                  'starting at sample {}'.format(k))

        k += batch

        loss = fun.smooth_l1_loss(theta, theta_label)
        loss.backward()
        optimizer_select.step()
        optimizer_stage1.step()

        if batch_idx % 250 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(sample['image']),
                len(loader.dataset),
                100. * batch_idx / len(loader), loss.item()))
def printoutput():
    """Run the saved stage-1 and select models on ``test_data`` and dump results.

    For every test sample the following JPEGs are written to
    ``pre_output_path/<name>/``: the input image, the original-resolution
    image, the nine one-hot stage-1 predictions, the nine original label
    channels, and, for each of the six parts, the 81x81 image and label crops
    obtained with the predicted ``theta`` and with the centroid-derived
    ``theta_label``.  The smooth-L1 loss between the two parameter sets is
    summed, divided by the test-set size and printed.
    """
    check_best = True
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    if check_best:
        # File names match the case used when the best models are saved
        # ("preBestNet_*"); the previous lower-case "net" spelling cannot be
        # found on case-sensitive filesystems.
        model_stage1 = torch.load("./preBestNet_stage1", map_location="cpu")
        model_select = torch.load("./preBestNet_select", map_location="cpu")
    else:
        model_stage1 = torch.load("./preNetdata_stage1", map_location="cpu")
        model_select = torch.load("./preNetdata_select", map_location="cpu")
    # NOTE(review): applied to whichever checkpoint was loaded; confirm this
    # call was not meant for the non-best branch only.
    model_select.select.change_device(device)
    if use_gpu:
        model_stage1 = model_stage1.to(device)
        model_select = model_select.to(device)
    # Make sure the loaded models run in inference mode.
    model_stage1.eval()
    model_select.eval()

    def _crop(source, affine, channels):
        """Resample one 81x81 window of ``source`` (1xCxHxW) with ``affine``."""
        grid = F.affine_grid(affine.unsqueeze(0), (1, channels, 81, 81),
                             align_corners=True)
        return F.grid_sample(source, grid, align_corners=True)[0]

    unloader = transforms.ToPILImage()
    # Normalisation constants; presumably the size of the original-resolution
    # images -- TODO confirm against the data pipeline.
    width = 1024.0
    height = 1024.0

    loader = test_data.get_loader()
    total = test_data.get_len()
    k = 0  # index of the current sample in the name list
    loss = torch.zeros((), device=device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for sample in loader:
            if use_gpu:
                sample['image'] = sample['image'].to(device)
                sample['label'] = sample['label'].to(device)
                sample['size'][0] = sample['size'][0].to(device)
                sample['size'][1] = sample['size'][1].to(device)
                sample['label_org'] = sample['label_org'].to(device)

            stage1_label = model_stage1(sample['image'])
            theta = model_select(stage1_label, sample['size'])

            batch = sample['image'].size()[0]
            # [batch_size, 6, 2, 3]
            theta_label = torch.zeros((batch, 6, 2, 3), device=device,
                                      requires_grad=False)
            points = torch.floor(calc_centroid(sample['label_org']))
            # With align_corners=True this scale/translation pair selects an
            # axis-aligned 81-pixel window centred on each centroid.
            theta_label[:, :, 0, 0] = (81.0 - 1.0) / (width - 1.0)
            theta_label[:, :, 1, 1] = (81.0 - 1.0) / (height - 1.0)
            theta_label[:, :, 0, 2] = -1 + 2 * points[:, :6, 0] / (width - 1.0)
            theta_label[:, :, 1, 2] = -1 + 2 * points[:, :6, 1] / (height - 1.0)

            loss += fun.smooth_l1_loss(theta, theta_label).detach()

            names = test_data.get_namelist()
            for i in range(batch):
                name = names[k]
                path = pre_output_path + '/' + name
                os.makedirs(path, exist_ok=True)

                image = unloader(sample['image'][i].cpu().clone())
                image.save(path + '/' + name + '_img.jpg', quality=100)
                image2 = unloader(sample['image_org'][i].cpu().clone())
                image2.save(path + '/' + name + '_img_org.jpg', quality=100)

                # One-hot (0/255) encoding of the stage-1 prediction; the
                # argmax of the raw scores equals the argmax of their softmax.
                scores = stage1_label[i].cpu()
                winner = scores.argmax(dim=0, keepdim=True)
                onehot = torch.zeros_like(scores).scatter_(0, winner, 255)

                for j in range(9):
                    image3 = unloader(np.uint8(onehot[j].numpy()))
                    image3.save(path + '/' + 'stage1_' + name + 'lbl0' +
                                str(j) + '.jpg', quality=100)
                    image3 = transforms.ToPILImage()(
                        sample['label_org'][i][j].cpu().detach()).convert('L')
                    image3.save(path + '/' + name + 'lbl0' + str(j) +
                                '_label' + '.jpg', quality=100)

                # ``image_org`` is not moved above, so move this one sample.
                image_org = sample['image_org'][i].to(device).unsqueeze(0)
                for j in range(6):
                    label_org = sample['label_org'][i][
                        label_list[j][1]].unsqueeze(0).unsqueeze(0)

                    # Image crop with the predicted parameters.
                    image3 = unloader(_crop(image_org, theta[i, j], 3).cpu())
                    image3.save(path + '/' + 'select_' + name + 'lbl0' +
                                str(j) + '.jpg', quality=100)
                    # Image crop with the centroid-derived parameters.
                    image3 = unloader(
                        _crop(image_org, theta_label[i, j], 3).cpu())
                    image3.save(path + '/' + 'select_' + name + 'lbl0' +
                                str(j) + '_thetalabel' + '.jpg', quality=100)
                    # Label crop with the predicted parameters.
                    image3 = transforms.ToPILImage()(
                        _crop(label_org, theta[i, j], 1)
                        .squeeze(0).cpu().detach()).convert('L')
                    image3.save(path + '/' + 'select_' + name + 'lbl0' +
                                str(j) + '_label' + '.jpg', quality=100)
                    # Label crop with the centroid-derived parameters.
                    image3 = transforms.ToPILImage()(
                        _crop(label_org, theta_label[i, j], 1)
                        .squeeze(0).cpu().detach()).convert('L')
                    image3.save(path + '/' + 'select_' + name + 'lbl0' +
                                str(j) + '_label' + '_thetalabel' + '.jpg',
                                quality=100)

                k += 1
                if k >= total:
                    break
            if k >= total:
                break

    # Normalise by the set that was actually evaluated (test_data); the
    # original divided by the size of val_data.
    num_cases = len(loader.dataset)
    loss /= num_cases
    print('\nTest set: {} Cases,Average loss: {:.8f}\n'.format(
        num_cases, loss.item()))
    print("printoutput Finish")