コード例 #1
0
ファイル: core.py プロジェクト: tinyRattar/CSMRI_0325
    def train(self,
              epoch=0,
              listTarget=[0, 1],
              forwardStage=1,
              need_evaluate=True):
        """Train ``self.net`` on ``self.trainloader`` for ``epoch`` epochs.

        A fresh optimizer is created from ``self.config['train']['optimizer']``
        on every call. For each batch:

        - the under-sampled k-space is simulated from the fully-sampled
          ``label`` and ``mask`` via ``kspace_subsampling_pytorch``;
        - it is converted to the network input with ``imgFromSubF_pytorch``;
        - the network output is compared to the label with
          ``self.lossForward``.

        Every ``self.saveEpoch`` epochs the weights are passed to
        ``self.record.write_to_file(..., False)``, optionally after running
        ``self.validation()``. After the last epoch ``self.ckp_epoch`` is
        advanced by ``epoch`` and the weights are passed to
        ``self.record.write_to_file(..., True)``.

        Args:
            epoch: number of epochs to run; ``0`` means use ``self.epoch``.
            listTarget: not used by this method.
            forwardStage: not used by this method.
            need_evaluate: if True, run ``self.validation()`` at each save
                point and log its loss / PSNR / SSIM.
        """
        if epoch == 0:
            epoch = self.epoch
        self.optimizer = getOptimizer(self.net.parameters(),
                                      self.config['train']['optimizer'],
                                      self.LR, self.weightDecay)
        msg = "start training: epoch = %d" % (epoch)
        self.record.log(msg)
        print(msg)
        for j in range(1, epoch + 1):
            self.net.train()
            i = 0  # number of samples processed so far in this epoch
            total_loss = 0  # batch losses weighted by batch size
            for mode, label, mask in self.trainloader:
                self.optimizer.zero_grad()
                netLabel = Variable(label).type(self.dtype)
                mask_var = Variable(mask).type(self.dtype)
                subF = kspace_subsampling_pytorch(netLabel, mask_var)
                complexFlag = (mode[0] == 'complex')
                netInput = imgFromSubF_pytorch(subF, complexFlag)
                netOutput = self.net(netInput, subF, mask_var)
                loss = self.lossForward(netOutput, netLabel)
                loss.backward()
                total_loss = total_loss + loss.item() * label.shape[0]

                self.optimizer.step()
                i += label.shape[0]
                # '\r' with end='' keeps the progress on one console line.
                print('Epoch %05d [%04d/%04d] loss %.8f' %
                      (j + self.ckp_epoch, i, self.trainsetSize, loss.item()),
                      '\r',
                      end='')
            # Size-weighted mean loss over the samples actually seen. The same
            # value is used for the log and the "SAVED" console line, so they
            # agree even if ``i`` differs from ``self.trainsetSize`` (e.g. a
            # loader that drops a final partial batch). An empty loader yields
            # NaN instead of a ZeroDivisionError.
            mean_loss = total_loss / i if i else float('nan')
            self.record.log_train(j + self.ckp_epoch, mean_loss)

            if j % self.saveEpoch == 0:
                print('Epoch %05d [%04d/%04d] loss %.8f SAVED' %
                      (j + self.ckp_epoch, i, self.trainsetSize, mean_loss))
                if need_evaluate:
                    l, p1, p2, s1, s2 = self.validation()
                    self.record.log_valid(j + self.ckp_epoch, l, p1, p2, s1,
                                          s2)
                    self.record.log(
                        "Evaluate psnr(before|after) =  %.2f|%.2f ssim = %.4f|%.4f"
                        % (p1, p2, s1, s2))
                self.record.write_to_file(self.net.state_dict(), False)
        self.ckp_epoch += epoch
        self.record.write_to_file(self.net.state_dict(), True)
コード例 #2
0
ファイル: core.py プロジェクト: tinyRattar/CSMRI_0325
    def test(self, mode, label, mask):
        """Reconstruct one batch and score it against the fully-sampled label.

        The under-sampled k-space is simulated from ``label`` and ``mask``
        with ``kspace_subsampling_pytorch``. It is turned into the network
        input with ``imgFromSubF_pytorch``. Metrics use only the first sample
        of the batch.

        Args:
            mode: sequence whose first entry selects complex handling
                (``mode[0] == 'complex'``).
            label: fully-sampled image tensor (must support ``.numpy()``).
            mask: sampling mask tensor.

        Returns:
            dict with keys:

            - ``loss``: ``self.lossForward`` value.
            - ``result1``: channel 0 of the first output, clipped to [0, 1].
            - ``result2``: the magnitude of channels 0/1 when the output has
              two channels, otherwise the first output itself; clipped to
              [0, 1].
            - ``psnr1``/``ssim1`` and ``psnr2``/``ssim2``: metrics of
              ``result1`` and ``result2`` against ``label``.
            - ``label``: channel 0 of the first label.

        Raises:
            AssertionError: if ``self.mode`` is not ``'inNetDC'``.
        """
        if self.mode != 'inNetDC':
            # Explicit raise instead of ``assert`` so the guard survives
            # ``python -O``. Otherwise the code below would fail later with a
            # NameError on undefined network inputs.
            raise AssertionError('only for inNetDC mode')
        import torch

        y = label.numpy()
        netLabel = Variable(label).type(self.dtype)

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            mask_var = Variable(mask).type(self.dtype)
            subF = kspace_subsampling_pytorch(netLabel, mask_var)
            complexFlag = (mode[0] == 'complex')
            netInput = imgFromSubF_pytorch(subF, complexFlag)

            netOutput = self.net(netInput, subF, mask_var)
            loss = self.lossForward(netOutput, netLabel)

        y = y[0]
        netOutput_np = netOutput.cpu().data.numpy()
        img1 = netOutput_np[0, 0:1].astype('float64')
        if netOutput_np.shape[1] == 2:
            # Two output channels are combined as real/imaginary parts into a
            # magnitude image.
            netOutput_np = abs(netOutput_np[:, 0:1] +
                               netOutput_np[:, 1:2] * 1j)
        img2 = netOutput_np[0].astype('float64')
        y2 = y[0:1]

        img1 = np.clip(img1, 0, 1)
        img2 = np.clip(img2, 0, 1)

        if self.isFastMRI:
            # NOTE(review): the third psnr argument (12) is presumably the
            # peak value for fastMRI data; confirm against psnr().
            psnrBefore = psnr(y2, img1, 12)
            psnrAfter = psnr(y2, img2, 12)
        else:
            psnrBefore = psnr(y2, img1)
            psnrAfter = psnr(y2, img2)

        ssimBefore = ssim(y2[0], img1[0])
        ssimAfter = ssim(y2[0], img2[0])

        return {
            "loss": loss.item(),
            "psnr1": psnrBefore,
            "psnr2": psnrAfter,
            "ssim1": ssimBefore,
            "ssim2": ssimAfter,
            "result1": img1,
            "result2": img2,
            'label': y2
        }
コード例 #3
0
ファイル: core.py プロジェクト: Cassie317/CSMRI
    def test(self, mode, name, label, mask, png, CSk):
        """Reconstruct one batch, save input/output PNGs and compute metrics.

        The under-sampled k-space is re-simulated from ``label`` and ``mask``
        with ``kspace_subsampling_pytorch``, so ``png`` and ``CSk`` are
        accepted but not used. The first sample's zero-filled input and
        network output are saved as grayscale PNGs under ``train_test/``.
        Their file names come from ``name[0]`` with its extension removed.

        Args:
            mode: sequence whose first entry selects complex handling
                (``mode[0] == 'complex'``).
            name: sequence of file names; only ``name[0]`` is used.
            label: fully-sampled image tensor (must support ``.numpy()``).
            mask: sampling mask tensor.
            png: unused.
            CSk: unused.

        Returns:
            dict with keys:

            - ``loss``: ``self.lossForward`` value.
            - ``result1``: channel 0 of the first output, clipped to [0, 1].
            - ``result2``: the magnitude of channels 0/1 when the output has
              two channels, otherwise the first output itself; clipped to
              [0, 1].
            - ``psnr1``/``ssim1`` and ``psnr2``/``ssim2``: metrics of
              ``result1`` and ``result2`` against ``label``.
            - ``label``: channel 0 of the first label, passed through
              ``self.standardization``.

        Raises:
            AssertionError: if ``self.mode`` is not ``'inNetDC'``.
        """
        if self.mode != 'inNetDC':
            # Explicit raise instead of ``assert`` so the guard survives
            # ``python -O``.
            raise AssertionError('only for inNetDC mode')
        import os
        import torch

        y = label.numpy()
        netLabel = Variable(label).type(self.dtype)
        stem = name[0].split('.')[0]
        os.makedirs("train_test", exist_ok=True)

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            mask_var = Variable(mask).type(self.dtype)
            subF = kspace_subsampling_pytorch(netLabel, mask_var)
            complexFlag = (mode[0] == 'complex')
            netInput = imgFromSubF_pytorch(subF, complexFlag)

            # Save the network input. ``.copy()`` is essential: on CPU,
            # ``.cpu().data.numpy()`` shares memory with the tensor, so the
            # in-place rescaling would otherwise overwrite ``netInput`` before
            # it reaches the network.
            x = netInput.cpu().data.numpy()[0].copy()
            x[0] = self.standardization(x[0]) * 255
            # convert('L') produces an 8-bit grayscale image.
            Image.fromarray(x[0]).convert('L').save(
                "train_test/input_" + stem + '_y.png')

            netOutput = self.net(netInput, subF, mask_var)

            # Same aliasing concern: copy so the loss/metrics below see the
            # unmodified network output.
            out = netOutput.cpu().data.numpy()[0].copy()
            out[0][out[0] < 0] = 0
            out[0] = self.standardization(out[0]) * 255
            Image.fromarray(out[0]).convert('L').save(
                "train_test/output_" + stem + '_y.png')

            loss = self.lossForward(netOutput, netLabel)

        netOutput_np = netOutput.cpu().data.numpy()
        img1 = netOutput_np[0, 0:1].astype('float64')
        if netOutput_np.shape[1] == 2:
            # Two output channels are combined as real/imaginary parts into a
            # magnitude image.
            netOutput_np = abs(netOutput_np[:, 0:1] +
                               netOutput_np[:, 1:2] * 1j)
        img2 = netOutput_np[0].astype('float64')
        y = y[0]
        y2 = y[0:1]
        # NOTE(review): only the label is standardized here, while the
        # predictions are just clipped to [0, 1]; confirm this is intended.
        y2 = self.standardization(y2)

        img1 = np.clip(img1, 0, 1)
        img2 = np.clip(img2, 0, 1)

        if self.isFastMRI:
            # NOTE(review): the third psnr argument (12) is presumably the
            # peak value for fastMRI data; confirm against psnr().
            psnrBefore = psnr(y2, img1, 12)
            psnrAfter = psnr(y2, img2, 12)
        else:
            psnrBefore = psnr(y2, img1)
            psnrAfter = psnr(y2, img2)

        ssimBefore = ssim(y2[0], img1[0])
        ssimAfter = ssim(y2[0], img2[0])

        return {
            "loss": loss.item(),
            "psnr1": psnrBefore,
            "psnr2": psnrAfter,
            "ssim1": ssimBefore,
            "ssim2": ssimAfter,
            "result1": img1,
            "result2": img2,
            'label': y2
        }
コード例 #4
0
ファイル: core.py プロジェクト: Cassie317/CSMRI
    def train(self,
              epoch=0,
              listTarget=[0, 1],
              forwardStage=1,
              need_evaluate=True):
        """Train ``self.net`` on ``self.trainloader`` for ``epoch`` epochs.

        Each batch is unpacked as ``(mode, name, label, mask, png, CSk)``.
        The under-sampled k-space is re-simulated from the fully-sampled
        ``label`` and ``mask`` with ``kspace_subsampling_pytorch``, so the
        pre-computed ``png`` and ``CSk`` are not used.

        For debugging, four grayscale PNGs are written to ``test2/`` for the
        first sample of every batch: the label, channel 0 of the simulated
        k-space, the network input and the network output.

        Every ``self.saveEpoch`` epochs the weights are passed to
        ``self.record.write_to_file(..., False)``, optionally after running
        ``self.validation()``. After the last epoch ``self.ckp_epoch`` is
        advanced by ``epoch`` and the weights are passed to
        ``self.record.write_to_file(..., True)``.

        Args:
            epoch: number of epochs to run; ``0`` means use ``self.epoch``.
            listTarget: not used by this method.
            forwardStage: not used by this method.
            need_evaluate: if True, run ``self.validation()`` at each save
                point and log its loss / PSNR / SSIM.
        """
        import os

        if epoch == 0:
            epoch = self.epoch
        self.optimizer = getOptimizer(self.net.parameters(),
                                      self.config['train']['optimizer'],
                                      self.LR, self.weightDecay)
        msg = "start training: epoch = %d" % (epoch)
        self.record.log(msg)
        print(msg)

        debug_dir = "test2"
        os.makedirs(debug_dir, exist_ok=True)

        def save_debug_png(img, prefix, stem):
            # Write a 2-D array as an 8-bit grayscale PNG into debug_dir.
            Image.fromarray(img).convert('L').save(
                debug_dir + "/" + prefix + stem + '_y.png')

        for j in range(1, epoch + 1):
            self.net.train()
            i = 0  # number of samples processed so far in this epoch
            total_loss = 0  # batch losses weighted by batch size
            loss = None  # last batch loss; stays None for an empty loader
            for mode, name, label, mask, png, CSk in self.trainloader:
                self.optimizer.zero_grad()

                netLabel = Variable(label).type(self.dtype)
                mask_var = Variable(mask).type(self.dtype)
                complexFlag = (mode[0] == 'complex')
                # Simulate the under-sampled k-space from the fully-sampled
                # label and the mask.
                subF = kspace_subsampling_pytorch(netLabel, mask_var)
                # Turn the under-sampled k-space back into the input image.
                netInput = imgFromSubF_pytorch(subF, complexFlag)
                netOutput = self.net(netInput, subF, mask_var)
                # Loss between the network output and the fully-sampled label.
                loss = self.lossForward(netOutput, netLabel)

                # Debug dumps. Every array is ``.copy()``-ed: on CPU,
                # ``.cpu().data.numpy()`` shares memory with the tensor, and
                # the in-place rescaling would otherwise corrupt
                # ``netLabel``/``netInput``/``netOutput`` before backward().
                stem = name[0].split('.')[0]

                lab = netLabel.cpu().data.numpy()[0].copy()
                lab[0] = self.standardization(lab[0]) * 255
                save_debug_png(lab[0], "label1_", stem)

                # assumes subF[0] is indexed (H, W, component) - TODO confirm
                kspace = subF.cpu().data.numpy()[0].copy()
                save_debug_png(self.standardization(kspace[:, :, 0]) * 255,
                               "subF1_", stem)

                inp = netInput.cpu().data.numpy()[0].copy()
                inp[0] = self.standardization(inp[0]) * 255
                save_debug_png(inp[0], "input1_", stem)

                out = netOutput.cpu().data.numpy()[0].copy()
                out[0][out[0] < 0] = 0
                out[0] = self.standardization(out[0]) * 255
                save_debug_png(out[0], "output1_", stem)

                loss.backward()  # compute gradients
                total_loss = total_loss + loss.item() * label.shape[0]
                self.optimizer.step()  # update network parameters

                i += label.shape[0]
                # Drop references to the batch before the next one is loaded.
                del mode, label, mask, png, CSk

            last_loss = loss.item() if loss is not None else float('nan')
            # Size-weighted mean loss over the samples actually seen. The same
            # value is used for the log and the "SAVED" line, and an empty
            # loader yields NaN instead of a ZeroDivisionError.
            mean_loss = total_loss / i if i else float('nan')

            print('Epoch %05d [%04d/%04d] loss %.8f' %
                  (j + self.ckp_epoch, i, self.trainsetSize, last_loss))

            self.record.log_train(j + self.ckp_epoch, mean_loss)

            if j % self.saveEpoch == 0:
                print('Epoch %05d [%04d/%04d] loss %.8f SAVED' %
                      (j + self.ckp_epoch, i, self.trainsetSize, mean_loss))
                if need_evaluate:
                    l, p1, p2, s1, s2 = self.validation()
                    self.record.log_valid(j + self.ckp_epoch, l, p1, p2, s1,
                                          s2)
                    self.record.log(
                        "Evaluate psnr(before|after) =  %.2f|%.2f ssim = %.4f|%.4f"
                        % (p1, p2, s1, s2))
                self.record.write_to_file(self.net.state_dict(), False)
        self.ckp_epoch += epoch
        self.record.write_to_file(self.net.state_dict(), True)