示例#1
0
    def log_val_images(self, pred_img, fbp_fc, y_fc, y_real, amp_min, amp_max):
        """Log up to three (input, prediction, target) image triples to TensorBoard.

        Args:
            pred_img: batch of predicted real-space images.
            fbp_fc: flattened Fourier coefficients of the FBP reconstruction.
            y_fc: flattened Fourier coefficients of the ground truth.
            y_real: batch of real-space ground-truth images (used directly
                when ``self.bin_factor == 1``).
            amp_min: amplitude normalization lower bound for ``convert2DFT``.
            amp_max: amplitude normalization upper bound for ``convert2DFT``.
        """
        dft_fbp = convert2DFT(x=fbp_fc,
                              amp_min=amp_min,
                              amp_max=amp_max,
                              dst_flatten_order=self.dst_flatten_order,
                              img_shape=self.hparams.img_shape)
        dft_target = convert2DFT(x=y_fc,
                                 amp_min=amp_min,
                                 amp_max=amp_max,
                                 dst_flatten_order=self.dst_flatten_order,
                                 img_shape=self.hparams.img_shape)

        for i in range(min(3, len(pred_img))):
            # The FBP image was reconstructed identically in both branches of
            # the original if/else (copy-paste duplication) -- hoisted here.
            fbp_img = torch.roll(
                torch.fft.irfftn(self.mask * dft_fbp[i],
                                 s=2 * (self.hparams.img_shape, )),
                2 * (self.hparams.img_shape // 2, ), (0, 1))
            if self.bin_factor == 1:
                # Full resolution: the stored real-space target is used as-is.
                y_img = y_real[i]
            else:
                # Binned: reconstruct the masked target from its spectrum so
                # it matches the resolution of the prediction.
                y_img = torch.roll(
                    torch.fft.irfftn(self.mask * dft_target[i],
                                     s=2 * (self.hparams.img_shape, )),
                    2 * (self.hparams.img_shape // 2, ), (0, 1))

            # Min-max normalize every image to [0, 1] for display.
            fbp_img = torch.clamp(
                (fbp_img - fbp_img.min()) / (fbp_img.max() - fbp_img.min()), 0,
                1)
            pred_img_ = pred_img[i]
            pred_img_ = torch.clamp((pred_img_ - pred_img_.min()) /
                                    (pred_img_.max() - pred_img_.min()), 0, 1)
            y_img = torch.clamp(
                (y_img - y_img.min()) / (y_img.max() - y_img.min()), 0, 1)

            self.trainer.logger.experiment.add_image(
                'inputs/img_{}'.format(i),
                fbp_img.unsqueeze(0),
                global_step=self.trainer.global_step)
            # Tag fixed from the original misspelling 'predcitions'.
            self.trainer.logger.experiment.add_image(
                'predictions/img_{}'.format(i),
                pred_img_.unsqueeze(0),
                global_step=self.trainer.global_step)
            self.trainer.logger.experiment.add_image(
                'targets/img_{}'.format(i),
                y_img.unsqueeze(0),
                global_step=self.trainer.global_step)
示例#2
0
 def convert2img(self, fc, mag_min, mag_max):
     """Invert flattened Fourier coefficients back to a real-space image batch.

     Args:
         fc: flattened Fourier coefficients.
         mag_min: amplitude normalization lower bound for ``convert2DFT``.
         mag_max: amplitude normalization upper bound for ``convert2DFT``.

     Returns:
         Real-space images obtained via an inverse real FFT over dims 1 and 2.
     """
     spectrum = convert2DFT(x=fc,
                            amp_min=mag_min,
                            amp_max=mag_max,
                            dst_flatten_order=self.dst_flatten_order,
                            img_shape=self.hparams.img_shape)
     side = self.hparams.img_shape
     # s=(side, side) is the same output shape as the 2 * (img_shape,) tuple
     # repetition used elsewhere in this file.
     return torch.fft.irfftn(spectrum, s=(side, side), dim=[1, 2])
示例#3
0
    def forward(self, x, fbp, mag_min, mag_max, dst_flatten_coords, img_shape, attenuation):
        """Reconstruct an image from FBP Fourier coefficients and refine it.

        Args:
            x: unused in this body; kept for interface compatibility.
            fbp: flattened Fourier coefficients of the FBP reconstruction.
            mag_min: amplitude normalization lower bound.
            mag_max: amplitude normalization upper bound.
            dst_flatten_coords: flattening order passed to ``convert2DFT``.
            img_shape: output image side length.
            attenuation: multiplicative filter applied to the spectrum.

        Returns:
            Tuple of (the unmodified ``fbp`` coefficients, the refined image).
        """
        dft_hat = convert2DFT(fbp, amp_min=mag_min, amp_max=mag_max,
                              dst_flatten_order=dst_flatten_coords,
                              img_shape=img_shape)
        dft_hat *= attenuation
        # Inverse real FFT, then recenter via roll and add a channel dim.
        half = img_shape // 2
        img_hat = torch.fft.irfftn(dft_hat, dim=[1, 2], s=(img_shape, img_shape))
        img_hat = torch.roll(img_hat, (half, half), (1, 2)).unsqueeze(1)
        # Residual connection around the post-processing conv block.
        img_post = self.conv_block(img_hat)
        img_post += img_hat

        return fbp, img_post[:, 0]
示例#4
0
    def _gt_bin_mse(self, y_fc, y_real, amp_min, amp_max):
        """MSE between the image reconstructed from ``y_fc`` and ``y_real``.

        Args:
            y_fc: flattened Fourier coefficients of the target.
            y_real: real-space target images to compare against.
            amp_min: amplitude normalization lower bound.
            amp_max: amplitude normalization upper bound.

        Returns:
            Scalar mean-squared-error loss.
        """
        spectrum = convert2DFT(x=y_fc,
                               amp_min=amp_min,
                               amp_max=amp_max,
                               dst_flatten_order=self.dst_flatten_order,
                               img_shape=self.hparams.img_shape)
        side = self.hparams.img_shape
        # Inverse real FFT followed by a recentering roll on the spatial dims.
        recon = torch.fft.irfftn(spectrum, dim=[1, 2], s=(side, side))
        recon = torch.roll(recon, (side // 2, side // 2), (1, 2))

        return F.mse_loss(recon, y_real)
示例#5
0
    def _real_loss(self, pred_img, target_fc, amp_min, amp_max):
        """MSE between predicted images and targets reconstructed from ``target_fc``.

        Args:
            pred_img: predicted real-space images.
            target_fc: flattened Fourier coefficients of the target.
            amp_min: amplitude normalization lower bound.
            amp_max: amplitude normalization upper bound.

        Returns:
            Scalar mean-squared-error loss.
        """
        dft_target = convert2DFT(x=target_fc,
                                 amp_min=amp_min,
                                 amp_max=amp_max,
                                 dst_flatten_order=self.dst_flatten_order,
                                 img_shape=self.hparams.img_shape)
        # When binning, restrict the target spectrum with the same mask the
        # model sees.
        if self.bin_factor > 1:
            dft_target *= self.mask

        side = self.hparams.img_shape
        y_target = torch.fft.irfftn(dft_target, dim=[1, 2], s=(side, side))
        y_target = torch.roll(y_target, (side // 2, side // 2), (1, 2))
        return F.mse_loss(pred_img, y_target)
示例#6
0
    def forward(self, x, fbp, mag_min, mag_max, dst_flatten_coords, img_shape, attenuation):
        """Predict Fourier coefficients from the FBP input and reconstruct an image.

        Args:
            x: unused in this body; kept for interface compatibility.
            fbp: flattened Fourier coefficients of the FBP reconstruction.
            mag_min: amplitude normalization lower bound.
            mag_max: amplitude normalization upper bound.
            dst_flatten_coords: flattening order passed to ``convert2DFT``.
            img_shape: output image side length.
            attenuation: multiplicative filter applied to the spectrum.

        Returns:
            Tuple of (predicted coefficients ``y_hat``, refined image).
        """
        fbp = self.fbp_fourier_coefficient_embedding(fbp)
        fbp = self.pos_embedding_target(fbp)
        y_hat = self.encoder(fbp)
        y_amp = self.predictor_amp(y_hat)
        # torch.tanh replaces the deprecated F.tanh (identical function;
        # torch.nn.functional.tanh has warned of deprecation since PyTorch 0.4).
        y_phase = torch.tanh(self.predictor_phase(y_hat))
        y_hat = torch.cat([y_amp, y_phase], dim=-1)

        dft_hat = convert2DFT(y_hat, amp_min=mag_min, amp_max=mag_max, dst_flatten_order=dst_flatten_coords,
                              img_shape=img_shape)
        dft_hat *= attenuation
        # Inverse real FFT, recenter, then refine with a residual conv block.
        img_hat = torch.roll(torch.fft.irfftn(dft_hat, dim=[1, 2], s=2 * (img_shape,)),
                             2 * (img_shape // 2,), (1, 2)).unsqueeze(1)
        img_post = self.conv_block(img_hat)
        img_post += img_hat

        return y_hat, img_post[:, 0]
示例#7
0
    def get_imgs(self, x, fbp, y, amp_min, amp_max):
        """Run inference at bin_factor 1 and return the model's images.

        Returns a tuple ``(pred_img, img_pred_before_conv)``: the refined
        prediction from the TRec forward pass and the image reconstructed
        directly from the predicted coefficients (i.e. before the conv block).

        NOTE(review): this method permanently mutates the module -- it calls
        ``self.eval()`` without restoring the previous mode, sets
        ``self.bin_factor = 1``, and overwrites the ``mask`` buffer.
        """
        self.eval()

        self.bin_factor = 1
        # Rebuild the PSF mask for the unbinned resolution.
        self.register_buffer(
            'mask',
            psf_rfft(self.bin_factor,
                     pixel_res=self.hparams.img_shape).to(self.device))

        x_fc_, fbp_fc_, y_fc_ = self._bin_data(x, fbp, y)

        pred_fc, pred_img = self.trec.forward(
            x_fc_,
            fbp_fc_,
            amp_min=amp_min,
            amp_max=amp_max,
            dst_flatten_coords=self.dst_flatten_order,
            img_shape=self.hparams.img_shape,
            attenuation=self.mask)

        # NOTE(review): ``tmp`` / ``pred_fc_`` build a zero-padded,
        # denormalized coefficient tensor but are never used afterwards --
        # ``convert2DFT`` below takes the raw ``pred_fc`` instead. Either this
        # is dead code or ``convert2DFT`` was meant to receive ``pred_fc_``;
        # confirm intent before removing.
        tmp = denormalize_FC(pred_fc, amp_min=amp_min, amp_max=amp_max)
        pred_fc_ = torch.ones(x.shape[0],
                              self.hparams.img_shape *
                              (self.hparams.img_shape // 2 + 1),
                              dtype=x.dtype,
                              device=x.device)
        pred_fc_[:, :tmp.shape[1]] = tmp

        dft_pred_fc = convert2DFT(x=pred_fc,
                                  amp_min=amp_min,
                                  amp_max=amp_max,
                                  dst_flatten_order=self.dst_flatten_order,
                                  img_shape=self.hparams.img_shape)
        # Inverse real FFT and recentering roll, matching the reconstruction
        # pattern used by the loss/logging helpers in this file.
        img_pred_before_conv = torch.roll(
            torch.fft.irfftn(dft_pred_fc,
                             dim=[1, 2],
                             s=2 * (self.hparams.img_shape, )),
            2 * (self.hparams.img_shape // 2, ), (1, 2))

        return pred_img, img_pred_before_conv