コード例 #1
0
  def metrics_evaluation(self):
    """Compute 3D metrics (MSE, cosine similarity, PSNR) for the current sample.

    ``self.G_fake`` and ``self.G_real`` are each passed through
    ``self.ct_unGaussian`` (fake first, then real), and the resulting pair is
    compared. Results are stored on the instance as ``metrics_Mse``,
    ``metrics_CosineSimilarity`` and ``metrics_PSNR``.
    """
    # NOTE(review): ct_unGaussian presumably reverses the CT mean/std
    # normalisation so metrics are taken on de-normalised values — confirm.
    volumes = (self.ct_unGaussian(self.G_fake), self.ct_unGaussian(self.G_real))

    self.metrics_Mse = Metrics.Mean_Squared_Error(*volumes)
    self.metrics_CosineSimilarity = Metrics.Cosine_Similarity(*volumes)
    # PIXEL_MAX is fixed at 1.0; assumes de-normalised volumes lie in [0, 1]
    # — TODO confirm against ct_unGaussian's output range.
    self.metrics_PSNR = Metrics.Peak_Signal_to_Noise_Rate(*volumes, PIXEL_MAX=1.0)
コード例 #2
0
  def forward(self):
    """Run the generator and build every tensor used by losses, visuals and metrics.

    Sets on ``self``:
      * ``G_fake``: generator output for ``G_input``; clamped when not training.
      * ``G_fake_D`` / ``G_real_D``: ``G_fake`` / ``G_real`` with a singleton
        dimension inserted at axis 1.
      * ``G_condition_D`` (only if ``conditional_D``): ``G_input`` unsqueezed at
        axis 1 and expanded to ``G_real_D``'s shape.
      * ``G_Map_{real,fake}_{F,S}``: ``transition(output_map(unGaussian(x), k))``
        with ``k`` = 2 (F) and 4 (S), after squeezing axis 1.
      * ``G_Map_{i}_{real,fake}`` for each ``i`` in ``multi_view`` (training
        only): ``ct_Gaussian(transition(output_map(unGaussian(x), i + 1)))``.
      * ``metrics_Mse``, ``metrics_CosineSimilarity``, ``metrics_PSNR``.

    Raises:
      NotImplementedError: when not training and ``opt.CT_MEAN_STD[0]`` is
        neither 0 nor 0.5 (no clamp range is defined for it).
    """
    # output is [B D H W]
    self.G_fake = self.netG(self.G_input)
    # visual object should be [B 1 D H W]
    # NOTE(review): torch.unsqueeze returns a view of the *unclamped* output and
    # the clamp below is not in-place, so G_fake_D (and the fake maps derived
    # from it) stay unclamped in eval while G_fake is clamped. Kept as-is —
    # confirm this is intended.
    self.G_fake_D = torch.unsqueeze(self.G_fake, 1)
    if not self.training:
      # Clamp range is selected by CT_MEAN_STD[0] (presumably the normalisation
      # mean: 0 -> [0, 1], 0.5 -> [-1, 1]).
      ct_mean = self.opt.CT_MEAN_STD[0]
      if ct_mean == 0:
        self.G_fake = torch.clamp(self.G_fake, 0, 1)
      elif ct_mean == 0.5:
        self.G_fake = torch.clamp(self.G_fake, -1, 1)
      else:
        raise NotImplementedError(
            'Unsupported CT_MEAN_STD[0]={!r}; expected 0 or 0.5'.format(ct_mean))
    # input of Discriminator is [B 1 D H W]
    self.G_real_D = torch.unsqueeze(self.G_real, 1)
    if self.conditional_D:
      self.G_condition_D = self.G_input.unsqueeze(1).expand_as(self.G_real_D)

    # De-normalise each volume once and reuse it for every projection below,
    # instead of recomputing it for the F map, the S map and every view.
    real_unNorm_D = self.ct_unGaussian(self.G_real_D)
    fake_unNorm_D = self.ct_unGaussian(self.G_fake_D)

    # map
    self.G_Map_real_F = self.transition(self.output_map(real_unNorm_D, 2).squeeze(1))
    self.G_Map_fake_F = self.transition(self.output_map(fake_unNorm_D, 2).squeeze(1))
    self.G_Map_real_S = self.transition(self.output_map(real_unNorm_D, 4).squeeze(1))
    self.G_Map_fake_S = self.transition(self.output_map(fake_unNorm_D, 4).squeeze(1))

    if self.training:
      for i in self.multi_view:
        out_map = self.output_map(real_unNorm_D, i + 1).squeeze(1)
        out_map = self.ct_Gaussian(self.transition(out_map))
        setattr(self, 'G_Map_{}_real'.format(i), out_map)

        out_map = self.output_map(fake_unNorm_D, i + 1).squeeze(1)
        out_map = self.ct_Gaussian(self.transition(out_map))
        setattr(self, 'G_Map_{}_fake'.format(i), out_map)

    # metrics (computed on the possibly-clamped G_fake, not on G_fake_D)
    g_fake_unNorm = self.ct_unGaussian(self.G_fake)
    g_real_unNorm = self.ct_unGaussian(self.G_real)

    self.metrics_Mse = Metrics.Mean_Squared_Error(g_fake_unNorm, g_real_unNorm)
    self.metrics_CosineSimilarity = Metrics.Cosine_Similarity(g_fake_unNorm, g_real_unNorm)
    # PIXEL_MAX=1.0 assumes de-normalised volumes lie in [0, 1] — TODO confirm.
    self.metrics_PSNR = Metrics.Peak_Signal_to_Noise_Rate(g_fake_unNorm, g_real_unNorm, PIXEL_MAX=1.0)