def metrics_evaluation(self):
    """Compute 3D metrics (MSE, cosine similarity, PSNR) for the current volumes.

    ``self.G_fake`` and ``self.G_real`` are first mapped back through
    ``self.ct_unGaussian``. The three metrics are then computed on that
    (fake, real) pair and stored on the instance as ``metrics_Mse``,
    ``metrics_CosineSimilarity`` and ``metrics_PSNR``. PSNR is evaluated with
    ``PIXEL_MAX=1.0``.
    """
    fake_volume = self.ct_unGaussian(self.G_fake)
    real_volume = self.ct_unGaussian(self.G_real)
    volumes = (fake_volume, real_volume)

    self.metrics_Mse = Metrics.Mean_Squared_Error(*volumes)
    self.metrics_CosineSimilarity = Metrics.Cosine_Similarity(*volumes)
    self.metrics_PSNR = Metrics.Peak_Signal_to_Noise_Rate(*volumes, PIXEL_MAX=1.0)
def forward(self):
    """Run the generator and derive the tensors used for losses, visuals and metrics.

    Reads ``self.G_input`` and ``self.G_real`` and sets:

    * ``G_fake``: the generator output ``self.netG(self.G_input)`` ([B D H W]).
      In evaluation mode it is clamped to [0, 1] when
      ``self.opt.CT_MEAN_STD[0] == 0``, and to [-1, 1] when it is 0.5.
    * ``G_fake_D`` / ``G_real_D``: the volumes with a channel axis added
      ([B 1 D H W]), as consumed by the discriminator.
    * ``G_condition_D``: ``G_input`` expanded to ``G_real_D``'s shape. Only
      set when ``self.conditional_D`` is truthy.
    * ``G_Map_{real,fake}_F`` / ``G_Map_{real,fake}_S``: projection maps
      from ``self.output_map`` along dim 2 / dim 4, passed through
      ``self.transition``.
    * ``G_Map_<i>_real`` / ``G_Map_<i>_fake``: training mode only, one map
      per ``i`` in ``self.multi_view``, taken along dim ``i + 1`` and passed
      through ``self.transition`` then ``self.ct_Gaussian``.
    * The 3D metrics, via ``self.metrics_evaluation()``.

    Raises:
        NotImplementedError: in evaluation mode, when ``CT_MEAN_STD[0]`` is
            neither 0 nor 0.5.
    """
    # Generator output is [B D H W].
    self.G_fake = self.netG(self.G_input)

    # Evaluation-only post-processing: clamp the prediction to the value
    # range selected by the normalization mean.
    if not self.training:
        ct_mean = self.opt.CT_MEAN_STD[0]
        if ct_mean == 0:
            self.G_fake = torch.clamp(self.G_fake, 0, 1)
        elif ct_mean == 0.5:
            self.G_fake = torch.clamp(self.G_fake, -1, 1)
        else:
            raise NotImplementedError(
                'forward: unsupported CT_MEAN_STD[0]={!r} '
                '(expected 0 or 0.5)'.format(ct_mean)
            )

    # Discriminator / visual volumes are [B 1 D H W]. G_fake_D is derived
    # *after* clamping. In evaluation mode this makes the projection maps
    # below describe the same volume that the metrics are computed on.
    self.G_fake_D = torch.unsqueeze(self.G_fake, 1)
    self.G_real_D = torch.unsqueeze(self.G_real, 1)
    if self.conditional_D:
        # Broadcast the generator input to the discriminator volume shape.
        self.G_condition_D = self.G_input.unsqueeze(1).expand_as(self.G_real_D)

    # Undo the normalization once and reuse the result for every projection.
    # NOTE(review): assumes ct_unGaussian is a pure transform -- confirm.
    real_unnorm = self.ct_unGaussian(self.G_real_D)
    fake_unnorm = self.ct_unGaussian(self.G_fake_D)

    # Projection maps along dim 2 ("F") and dim 4 ("S").
    self.G_Map_real_F = self.transition(self.output_map(real_unnorm, 2).squeeze(1))
    self.G_Map_fake_F = self.transition(self.output_map(fake_unnorm, 2).squeeze(1))
    self.G_Map_real_S = self.transition(self.output_map(real_unnorm, 4).squeeze(1))
    self.G_Map_fake_S = self.transition(self.output_map(fake_unnorm, 4).squeeze(1))

    # Training only: one map per configured view, re-normalized with
    # ct_Gaussian and stored as G_Map_<i>_real / G_Map_<i>_fake.
    if self.training:
        for i in self.multi_view:
            for tag, volume in (('real', real_unnorm), ('fake', fake_unnorm)):
                out_map = self.output_map(volume, i + 1).squeeze(1)
                out_map = self.ct_Gaussian(self.transition(out_map))
                setattr(self, 'G_Map_{}_{}'.format(i, tag), out_map)

    # 3D metrics (MSE, cosine similarity, PSNR). Delegate to
    # metrics_evaluation instead of duplicating its body here.
    self.metrics_evaluation()