def train(self, epoch=0, listTarget=None, forwardStage=1, need_evaluate=True):
    """Train ``self.net`` on ``self.trainloader`` for a number of epochs.

    For every batch, the undersampled k-space is simulated from the
    fully-sampled ``label`` and ``mask`` (``kspace_subsampling_pytorch``).
    It is converted to the network input (``imgFromSubF_pytorch``), passed
    through ``self.net`` and scored against the label with
    ``self.lossForward``.

    Every ``self.saveEpoch`` epochs this method:
      * prints a summary line,
      * optionally runs ``self.validation()``,
      * calls ``self.record.write_to_file(state_dict, False)``.

    After the last epoch, ``self.ckp_epoch`` is advanced by ``epoch`` and
    ``self.record.write_to_file(state_dict, True)`` is called.

    Args:
        epoch: number of epochs to run; ``0`` means use ``self.epoch``.
        listTarget: unused; kept for call-site compatibility. The default was
            the mutable list ``[0, 1]``; ``None`` avoids a shared mutable
            default.
        forwardStage: unused; kept for call-site compatibility.
        need_evaluate: when true, run ``self.validation()`` at each save
            point and log its loss/PSNR/SSIM.

    Raises:
        RuntimeError: if an epoch yields no samples. The per-sample mean
            loss is undefined in that case; previously this surfaced as a
            ZeroDivisionError.
    """
    if epoch == 0:
        epoch = self.epoch
    # A fresh optimizer is built on every call to train().
    self.optimizer = getOptimizer(self.net.parameters(),
                                  self.config['train']['optimizer'],
                                  self.LR, self.weightDecay)
    msg = "start training: epoch = %d" % (epoch)
    self.record.log(msg)
    print(msg)
    for j in range(1, epoch + 1):
        self.net.train()
        i = 0  # samples seen so far in this epoch
        total_loss = 0
        for mode, label, mask in self.trainloader:
            self.optimizer.zero_grad()
            netLabel = Variable(label).type(self.dtype)
            mask_var = Variable(mask).type(self.dtype)
            # Simulate the undersampled acquisition from the fully-sampled label.
            subF = kspace_subsampling_pytorch(netLabel, mask_var)
            complexFlag = (mode[0] == 'complex')
            netInput = imgFromSubF_pytorch(subF, complexFlag)
            netOutput = self.net(netInput, subF, mask_var)
            loss = self.lossForward(netOutput, netLabel)
            loss.backward()
            batch_size = label.shape[0]
            # Weight by batch size so the epoch figure is a per-sample mean
            # (assumes lossForward returns a batch mean -- TODO confirm).
            total_loss = total_loss + loss.item() * batch_size
            self.optimizer.step()
            i += batch_size
            # '\r' plus end='' keeps the progress report on one console line.
            print('Epoch %05d [%04d/%04d] loss %.8f' % (j + self.ckp_epoch, i, self.trainsetSize, loss.item()), '\r', end='')
        if i == 0:
            raise RuntimeError('trainloader yielded no samples')
        # Use one mean for both the log and the SAVED line so they agree.
        avg_loss = total_loss / i
        self.record.log_train(j + self.ckp_epoch, avg_loss)
        if j % self.saveEpoch == 0:
            print('Epoch %05d [%04d/%04d] loss %.8f SAVED' % (j + self.ckp_epoch, i, self.trainsetSize, avg_loss))
            if need_evaluate:
                l, p1, p2, s1, s2 = self.validation()
                self.record.log_valid(j + self.ckp_epoch, l, p1, p2, s1, s2)
                self.record.log(
                    "Evaluate psnr(before|after) = %.2f|%.2f ssim = %.4f|%.4f" % (p1, p2, s1, s2))
            self.record.write_to_file(self.net.state_dict(), False)
    self.ckp_epoch += epoch
    self.record.write_to_file(self.net.state_dict(), True)
def test(self, mode, label, mask):
    """Reconstruct one batch and score its first sample against the label.

    The undersampled k-space is simulated from ``label`` and ``mask``,
    converted to the network input, passed through ``self.net``, and the
    output is compared with the label using PSNR and SSIM.

    Args:
        mode: per-sample mode values from the loader; ``mode[0] == 'complex'``
            is passed to ``imgFromSubF_pytorch`` as the complex flag.
        label: fully-sampled target tensor (must support ``.numpy()``).
        mask: undersampling mask tensor.

    Returns:
        dict with:
          * ``"loss"``: ``self.lossForward`` value as a float.
          * ``"psnr1"`` / ``"ssim1"``: scores of ``result1``, the first output
            channel of sample 0, clipped to [0, 1].
          * ``"psnr2"`` / ``"ssim2"``: scores of ``result2``, the magnitude of
            channels 0/1 when the output has 2 channels (otherwise the output
            itself), clipped to [0, 1].
          * ``"label"``: the first channel of the label for sample 0.

    Raises:
        AssertionError: if ``self.mode`` is not ``'inNetDC'``.
    """
    # Explicit raise instead of `assert False`. An assert is stripped under
    # `python -O`, and the code below would then fail with an unrelated
    # UnboundLocalError. AssertionError is kept so callers see the same type.
    if self.mode != 'inNetDC':
        raise AssertionError('only for inNetDC mode')

    y = label.numpy()
    netLabel = Variable(label).type(self.dtype)
    mask_var = Variable(mask).type(self.dtype)
    # Simulate the undersampled acquisition from the fully-sampled label.
    subF = kspace_subsampling_pytorch(netLabel, mask_var)
    complexFlag = (mode[0] == 'complex')
    netInput = imgFromSubF_pytorch(subF, complexFlag)

    netOutput = self.net(netInput, subF, mask_var)
    loss = self.lossForward(netOutput, netLabel)

    netOutput_np = netOutput.cpu().data.numpy()
    # "before" image: first channel of sample 0.
    img1 = netOutput_np[0, 0:1].astype('float64')
    # "after" image: with 2 channels, treat them as real/imag and take the magnitude.
    if netOutput_np.shape[1] == 2:
        netOutput_np = np.abs(netOutput_np[:, 0:1] + netOutput_np[:, 1:2] * 1j)
    img2 = netOutput_np[0].astype('float64')
    y2 = y[0][0:1]  # first channel of sample 0's label

    img1 = np.clip(img1, 0, 1)
    img2 = np.clip(img2, 0, 1)
    if self.isFastMRI:
        # NOTE(review): 12 is psnr()'s third argument for fastMRI data; its
        # meaning is defined by psnr() -- confirm there.
        psnrBefore = psnr(y2, img1, 12)
        psnrAfter = psnr(y2, img2, 12)
    else:
        psnrBefore = psnr(y2, img1)
        psnrAfter = psnr(y2, img2)
    ssimBefore = ssim(y2[0], img1[0])
    ssimAfter = ssim(y2[0], img2[0])
    return {
        "loss": loss.item(),
        "psnr1": psnrBefore,
        "psnr2": psnrAfter,
        "ssim1": ssimBefore,
        "ssim2": ssimAfter,
        "result1": img1,
        "result2": img2,
        'label': y2,
    }
def test(self, mode, name, label, mask, png, CSk):
    """Reconstruct one batch, dump input/output PNGs, and score sample 0.

    Same evaluation as the 3-argument ``test`` variant, except that:
      * the label is additionally passed through ``self.standardization``
        before PSNR/SSIM;
      * grayscale PNGs of the network input and output are written to
        ``train_test/input_<stem>_y.png`` and ``train_test/output_<stem>_y.png``.

    Args:
        mode: per-sample mode values; ``mode[0] == 'complex'`` is passed to
            ``imgFromSubF_pytorch`` as the complex flag.
        name: per-sample file names; ``name[0]`` up to its first '.' is used
            as the PNG file stem.
        label: fully-sampled target tensor (must support ``.numpy()``).
        mask: undersampling mask tensor.
        png, CSk: accepted for loader-signature compatibility; not used.

    Returns:
        dict with keys "loss", "psnr1", "psnr2", "ssim1", "ssim2",
        "result1", "result2" and "label" (the standardized label).

    Raises:
        AssertionError: if ``self.mode`` is not ``'inNetDC'``.
    """
    import os

    # Explicit raise instead of `assert False`, so the check survives
    # `python -O`. AssertionError is kept for callers that expect it.
    if self.mode != 'inNetDC':
        raise AssertionError('only for inNetDC mode')

    stem = name[0].split('.')[0]

    def _save_gray(img2d, path):
        # Save a 2-D array as an 8-bit grayscale ('L') PNG.
        Image.fromarray(img2d).convert('L').save(path)

    y = label.numpy()
    netLabel = Variable(label).type(self.dtype)
    mask_var = Variable(mask).type(self.dtype)
    # Simulate the undersampled acquisition from the fully-sampled label.
    subF = kspace_subsampling_pytorch(netLabel, mask_var)
    complexFlag = (mode[0] == 'complex')
    netInput = imgFromSubF_pytorch(subF, complexFlag)

    # Saving fails with FileNotFoundError when the directory is missing.
    os.makedirs("train_test", exist_ok=True)

    # Dump the network input (channel 0 of sample 0), rescaled to 0-255.
    # Work on a copy: for CPU tensors `.numpy()` shares storage, so editing
    # it in place would rescale `netInput` before the forward pass.
    x = netInput.detach().cpu().numpy()[0].copy()
    x[0] = self.standardization(x[0]) * 255
    _save_gray(x[0], "train_test/input_" + stem + '_y.png')

    netOutput = self.net(netInput, subF, mask_var)

    # Dump the output: negatives zeroed, then rescaled to 0-255.
    # Copy for the same reason: the loss and metrics below must see the raw output.
    y1 = netOutput.detach().cpu().numpy()[0].copy()
    y1[0][y1[0] < 0] = 0
    y1[0] = self.standardization(y1[0]) * 255
    _save_gray(y1[0], "train_test/output_" + stem + '_y.png')

    loss = self.lossForward(netOutput, netLabel)
    netOutput_np = netOutput.detach().cpu().numpy()
    # "before" image: first channel of sample 0.
    img1 = netOutput_np[0, 0:1].astype('float64')
    # "after" image: with 2 channels, treat them as real/imag and take the magnitude.
    if netOutput_np.shape[1] == 2:
        netOutput_np = np.abs(netOutput_np[:, 0:1] + netOutput_np[:, 1:2] * 1j)
    img2 = netOutput_np[0].astype('float64')
    y = y[0]
    y2 = y[0:1]
    y2 = self.standardization(y2)

    img1 = np.clip(img1, 0, 1)
    img2 = np.clip(img2, 0, 1)
    if self.isFastMRI:
        # NOTE(review): 12 is psnr()'s third argument for fastMRI data -- confirm meaning in psnr().
        psnrBefore = psnr(y2, img1, 12)
        psnrAfter = psnr(y2, img2, 12)
    else:
        psnrBefore = psnr(y2, img1)
        psnrAfter = psnr(y2, img2)
    ssimBefore = ssim(y2[0], img1[0])
    ssimAfter = ssim(y2[0], img2[0])
    return {
        "loss": loss.item(),
        "psnr1": psnrBefore,
        "psnr2": psnrAfter,
        "ssim1": ssimBefore,
        "ssim2": ssimAfter,
        "result1": img1,
        "result2": img2,
        'label': y2,
    }
def train(self, epoch=0, listTarget=None, forwardStage=1, need_evaluate=True, debug_dir='test2'):
    """Train ``self.net`` on ``self.trainloader``, optionally dumping debug PNGs.

    Each batch is ``(mode, name, label, mask, png, CSk)``. The undersampled
    k-space is re-simulated from the fully-sampled ``label`` and ``mask``
    (``kspace_subsampling_pytorch``), converted to the network input
    (``imgFromSubF_pytorch``), passed through ``self.net`` and scored against
    the label with ``self.lossForward``.

    When ``debug_dir`` is not None, every step writes grayscale PNGs of the
    label, the simulated k-space (channel 0), the network input and the
    output to ``<debug_dir>/{label1_,subF1_,input1_,output1_}<stem>_y.png``.

    Every ``self.saveEpoch`` epochs this method:
      * prints a summary line,
      * optionally runs ``self.validation()``,
      * calls ``self.record.write_to_file(state_dict, False)``.

    After the last epoch, ``self.ckp_epoch`` is advanced by ``epoch`` and
    ``self.record.write_to_file(state_dict, True)`` is called.

    Args:
        epoch: number of epochs; ``0`` means use ``self.epoch``.
        listTarget: unused; kept for call-site compatibility (``None``
            replaces the former shared mutable default ``[0, 1]``).
        forwardStage: unused; kept for call-site compatibility.
        need_evaluate: run ``self.validation()`` at each save point.
        debug_dir: directory for the per-step debug PNGs (created if
            missing). Defaults to ``'test2'`` as before; pass ``None`` to
            disable the dumps.

    Raises:
        RuntimeError: if an epoch yields no samples.
    """
    import os

    if epoch == 0:
        epoch = self.epoch
    # A fresh optimizer is built on every call to train().
    self.optimizer = getOptimizer(self.net.parameters(),
                                  self.config['train']['optimizer'],
                                  self.LR, self.weightDecay)
    msg = "start training: epoch = %d" % (epoch)
    self.record.log(msg)
    print(msg)
    if debug_dir is not None:
        os.makedirs(debug_dir, exist_ok=True)

    def _dump(img2d, prefix, stem):
        # Save a 2-D array as an 8-bit grayscale ('L') PNG in debug_dir.
        Image.fromarray(img2d).convert('L').save(
            os.path.join(debug_dir, prefix + stem + '_y.png'))

    for j in range(1, epoch + 1):
        self.net.train()
        i = 0  # samples seen so far in this epoch
        total_loss = 0
        # png (undersampled image) and CSk (undersampled k-space) come from
        # the loader but are not used. The undersampled data is re-simulated
        # from the label; the original author noted the simulated subF
        # matches CSk (not verified here).
        for mode, name, label, mask, png, CSk in self.trainloader:
            self.optimizer.zero_grad()
            netLabel = Variable(label).type(self.dtype)
            mask_var = Variable(mask).type(self.dtype)
            subF = kspace_subsampling_pytorch(netLabel, mask_var)
            complexFlag = (mode[0] == 'complex')
            netInput = imgFromSubF_pytorch(subF, complexFlag)
            netOutput = self.net(netInput, subF, mask_var)
            # Loss is computed once, before any debug dumping.
            loss = self.lossForward(netOutput, netLabel)

            if debug_dir is not None:
                stem = name[0].split('.')[0]
                # Every dump works on a copy. For CPU tensors `.numpy()`
                # shares storage, so in-place edits would overwrite the label
                # and output that the loss and gradients depend on.
                lab = netLabel.detach().cpu().numpy()[0].copy()
                lab[0] = self.standardization(lab[0]) * 255
                _dump(lab[0], 'label1_', stem)

                # Channel 0 of the last axis of subF for sample 0 (presumably
                # the real part of (H, W, 2) k-space -- TODO confirm layout).
                kdat = subF.detach().cpu().numpy()[0].copy()
                _dump(self.standardization(kdat[:, :, 0]) * 255, 'subF1_', stem)

                inp = netInput.detach().cpu().numpy()[0].copy()
                inp[0] = self.standardization(inp[0]) * 255
                _dump(inp[0], 'input1_', stem)

                out = netOutput.detach().cpu().numpy()[0].copy()
                out[0][out[0] < 0] = 0
                out[0] = self.standardization(out[0]) * 255
                _dump(out[0], 'output1_', stem)

            loss.backward()
            batch_size = label.shape[0]
            # Weight by batch size for a per-sample epoch mean
            # (assumes lossForward returns a batch mean -- TODO confirm).
            total_loss = total_loss + loss.item() * batch_size
            self.optimizer.step()
            i += batch_size
            # '\r' plus end='' keeps the progress report on one console line.
            print('Epoch %05d [%04d/%04d] loss %.8f' % (j + self.ckp_epoch, i, self.trainsetSize, loss.item()), '\r', end='')
        if i == 0:
            raise RuntimeError('trainloader yielded no samples')
        # Use one mean for both the log and the SAVED line so they agree.
        avg_loss = total_loss / i
        self.record.log_train(j + self.ckp_epoch, avg_loss)
        if j % self.saveEpoch == 0:
            print('Epoch %05d [%04d/%04d] loss %.8f SAVED' % (j + self.ckp_epoch, i, self.trainsetSize, avg_loss))
            if need_evaluate:
                l, p1, p2, s1, s2 = self.validation()
                self.record.log_valid(j + self.ckp_epoch, l, p1, p2, s1, s2)
                self.record.log(
                    "Evaluate psnr(before|after) = %.2f|%.2f ssim = %.4f|%.4f" % (p1, p2, s1, s2))
            self.record.write_to_file(self.net.state_dict(), False)
    self.ckp_epoch += epoch
    self.record.write_to_file(self.net.state_dict(), True)