def log_val_images(self, pred_img, fbp_fc, y_fc, y_real, amp_min, amp_max):
    """Log up to three (input, prediction, target) image triplets.

    Image sources:
    - Input image: rebuilt from ``fbp_fc``.
    - Target image when ``self.bin_factor`` is not 1: rebuilt from ``y_fc``.
    - Target image when ``self.bin_factor == 1``: ``y_real[i]`` used directly.

    Each rebuild runs ``convert2DFT`` and multiplies by ``self.mask``. It then
    takes an inverse real FFT of size ``(img_shape, img_shape)`` and rolls by
    ``img_shape // 2`` on both axes.

    Every image is min-max scaled to ``[0, 1]`` and given a leading axis
    (``unsqueeze(0)``). It is then written with
    ``self.trainer.logger.experiment.add_image``.

    Args:
        pred_img: Predicted images. ``pred_img[i]`` is assumed to be a 2-D
            (H, W) tensor; the leading axis added here is presumably the
            channel axis expected by ``add_image`` (TODO confirm).
        fbp_fc: Fourier coefficients of the FBP input, presumably normalised
            with ``amp_min``/``amp_max``.
        y_fc: Fourier coefficients of the target, same convention as ``fbp_fc``.
        y_real: Real-space target images, used when ``bin_factor == 1``.
        amp_min: Amplitude minimum forwarded to ``convert2DFT``.
        amp_max: Amplitude maximum forwarded to ``convert2DFT``.
    """
    img_shape = self.hparams.img_shape
    irfft_size = 2 * (img_shape, )
    roll_shift = 2 * (img_shape // 2, )

    def _spectrum_to_image(dft):
        # Masked inverse real FFT followed by a half-size roll on both axes.
        return torch.roll(
            torch.fft.irfftn(self.mask * dft, s=irfft_size), roll_shift, (0, 1))

    def _normalize(img):
        # Min-max scale to [0, 1]. A constant image would give 0 / 0 = NaN,
        # so an all-zero image is returned for it instead.
        lo, hi = img.min(), img.max()
        span = hi - lo
        if span <= 0:
            return torch.zeros_like(img)
        return torch.clamp((img - lo) / span, 0, 1)

    dft_fbp = convert2DFT(x=fbp_fc, amp_min=amp_min, amp_max=amp_max,
                          dst_flatten_order=self.dst_flatten_order,
                          img_shape=img_shape)
    # The target spectrum is only needed when the target image is not taken
    # straight from ``y_real``.
    dft_target = None
    if self.bin_factor != 1:
        dft_target = convert2DFT(x=y_fc, amp_min=amp_min, amp_max=amp_max,
                                 dst_flatten_order=self.dst_flatten_order,
                                 img_shape=img_shape)

    experiment = self.trainer.logger.experiment
    step = self.trainer.global_step
    for i in range(min(3, len(pred_img))):
        fbp_img = _spectrum_to_image(dft_fbp[i])
        if dft_target is None:
            y_img = y_real[i]
        else:
            y_img = _spectrum_to_image(dft_target[i])

        experiment.add_image(
            'inputs/img_{}'.format(i), _normalize(fbp_img).unsqueeze(0),
            global_step=step)
        # NOTE(review): 'predcitions' is misspelt. It is kept as-is because
        # renaming the tag would log predictions under a different series name.
        experiment.add_image(
            'predcitions/img_{}'.format(i), _normalize(pred_img[i]).unsqueeze(0),
            global_step=step)
        experiment.add_image(
            'targets/img_{}'.format(i), _normalize(y_img).unsqueeze(0),
            global_step=step)
def convert2img(self, fc, mag_min, mag_max):
    """Turn Fourier coefficients ``fc`` into a real-space image batch.

    ``fc`` is expanded to a spectrum by ``convert2DFT`` using ``mag_min`` /
    ``mag_max`` and ``self.dst_flatten_order``. The result is the inverse real
    FFT of that spectrum over dims 1 and 2, with output size
    ``(img_shape, img_shape)``.

    No mask is applied and no roll is performed here, unlike the other image
    reconstructions in this module.
    """
    side = self.hparams.img_shape
    spectrum = convert2DFT(x=fc, amp_min=mag_min, amp_max=mag_max,
                           dst_flatten_order=self.dst_flatten_order,
                           img_shape=side)
    return torch.fft.irfftn(spectrum, s=(side, side), dim=[1, 2])
def forward(self, x, fbp, mag_min, mag_max, dst_flatten_coords, img_shape, attenuation):
    """Decode ``fbp`` coefficients to an image and refine it with ``conv_block``.

    ``x`` is part of the interface but is not used. The image path is:
    1. ``fbp`` is expanded with ``convert2DFT``.
    2. The spectrum is multiplied in place by ``attenuation``.
    3. It is inverse-real-FFT'd over dims 1 and 2 to size
       ``(img_shape, img_shape)``.
    4. The result is rolled by ``img_shape // 2`` on those dims and given a
       new axis at dim 1.

    The ``conv_block`` output is then added to that image as a residual.

    Returns:
        ``(fbp, refined)``: ``refined`` is the residual-refined image with
        axis 1 dropped via ``[:, 0]``, and ``fbp`` is returned exactly as
        received.
    """
    half = img_shape // 2
    spectrum = convert2DFT(fbp, amp_min=mag_min, amp_max=mag_max,
                           dst_flatten_order=dst_flatten_coords,
                           img_shape=img_shape)
    spectrum *= attenuation
    spatial = torch.fft.irfftn(spectrum, dim=[1, 2], s=(img_shape, img_shape))
    coarse = torch.roll(spatial, (half, half), (1, 2)).unsqueeze(1)
    refined = self.conv_block(coarse)
    refined += coarse  # residual connection, added in place to the conv output
    return fbp, refined[:, 0]
def _gt_bin_mse(self, y_fc, y_real, amp_min, amp_max):
    """Return the MSE between the image rebuilt from ``y_fc`` and ``y_real``.

    ``y_fc`` is expanded by ``convert2DFT`` and inverse-real-FFT'd over dims
    1 and 2 to size ``(img_shape, img_shape)``. It is then rolled by
    ``img_shape // 2`` on those dims. No mask is applied to the spectrum.
    """
    n = self.hparams.img_shape
    spectrum = convert2DFT(x=y_fc, amp_min=amp_min, amp_max=amp_max,
                           dst_flatten_order=self.dst_flatten_order,
                           img_shape=n)
    recon = torch.fft.irfftn(spectrum, dim=[1, 2], s=(n, n))
    recon = torch.roll(recon, (n // 2, n // 2), (1, 2))
    return F.mse_loss(recon, y_real)
def _real_loss(self, pred_img, target_fc, amp_min, amp_max):
    """Return the real-space MSE between ``pred_img`` and the image from ``target_fc``.

    ``target_fc`` is expanded by ``convert2DFT``. When ``self.bin_factor > 1``,
    the spectrum is multiplied in place by ``self.mask``. The spectrum is then
    inverse-real-FFT'd over dims 1 and 2 to size ``(img_shape, img_shape)``
    and rolled by ``img_shape // 2`` on those dims.
    """
    n = self.hparams.img_shape
    spectrum = convert2DFT(x=target_fc, amp_min=amp_min, amp_max=amp_max,
                           dst_flatten_order=self.dst_flatten_order,
                           img_shape=n)
    if self.bin_factor > 1:
        spectrum *= self.mask
    spatial = torch.fft.irfftn(spectrum, dim=[1, 2], s=(n, n))
    target_img = torch.roll(spatial, (n // 2, n // 2), (1, 2))
    return F.mse_loss(pred_img, target_img)
def forward(self, x, fbp, mag_min, mag_max, dst_flatten_coords, img_shape, attenuation):
    """Predict Fourier coefficients from ``fbp`` and decode them to a refined image.

    ``x`` is part of the interface but is not used.

    Coefficient prediction:
    - ``fbp`` is embedded with ``fbp_fourier_coefficient_embedding``, then
      ``pos_embedding_target``, and passed through ``encoder``.
    - ``predictor_amp`` produces the amplitude part.
    - ``predictor_phase``, squashed with tanh into (-1, 1), produces the
      phase part.
    - The two parts are concatenated on the last axis.

    Image path:
    1. The concatenated coefficients are expanded with ``convert2DFT``.
    2. The spectrum is multiplied in place by ``attenuation``.
    3. It is inverse-real-FFT'd over dims 1 and 2 to size
       ``(img_shape, img_shape)``.
    4. The result is rolled by ``img_shape // 2`` and given a new axis at
       dim 1.

    The ``conv_block`` output is then added as a residual.

    Returns:
        ``(coeffs, refined)``: ``coeffs`` is the concatenated amplitude/phase
        prediction, and ``refined`` is the refined image with axis 1 dropped
        via ``[:, 0]``.
    """
    tokens = self.pos_embedding_target(self.fbp_fourier_coefficient_embedding(fbp))
    latent = self.encoder(tokens)
    amplitude = self.predictor_amp(latent)
    phase = torch.tanh(self.predictor_phase(latent))
    coeffs = torch.cat([amplitude, phase], dim=-1)

    half = img_shape // 2
    spectrum = convert2DFT(coeffs, amp_min=mag_min, amp_max=mag_max,
                           dst_flatten_order=dst_flatten_coords,
                           img_shape=img_shape)
    spectrum *= attenuation
    spatial = torch.fft.irfftn(spectrum, dim=[1, 2], s=(img_shape, img_shape))
    coarse = torch.roll(spatial, (half, half), (1, 2)).unsqueeze(1)
    refined = self.conv_block(coarse)
    refined += coarse  # residual connection, added in place to the conv output
    return coeffs, refined[:, 0]
def get_imgs(self, x, fbp, y, amp_min, amp_max):
    """Run the reconstruction and return images after and before the conv block.

    Side effects, none of which are restored afterwards:
    - The module is switched to eval mode.
    - ``self.bin_factor`` is set to 1.
    - The ``mask`` buffer is re-registered from ``psf_rfft`` for that bin
      factor.

    The inputs are passed through ``self._bin_data`` and ``self.trec.forward``,
    with the new mask as ``attenuation``.

    Returns:
        ``(pred_img, img_pred_before_conv)``:
        - ``pred_img`` is the image output of ``self.trec.forward``.
        - ``img_pred_before_conv`` is rebuilt from the predicted coefficients
          ``pred_fc``. The rebuild uses ``convert2DFT``, an inverse real FFT
          over dims 1 and 2, and a roll by ``img_shape // 2``.
        No ``self.mask`` multiplication is applied in that rebuild.
    """
    self.eval()
    self.bin_factor = 1
    self.register_buffer(
        'mask',
        psf_rfft(self.bin_factor, pixel_res=self.hparams.img_shape).to(self.device))
    x_fc_, fbp_fc_, _ = self._bin_data(x, fbp, y)
    pred_fc, pred_img = self.trec.forward(
        x_fc_, fbp_fc_, amp_min=amp_min, amp_max=amp_max,
        dst_flatten_coords=self.dst_flatten_order,
        img_shape=self.hparams.img_shape, attenuation=self.mask)

    # convert2DFT receives pred_fc together with amp_min/amp_max, so no
    # separate denormalisation of the coefficients is needed here.
    dft_pred_fc = convert2DFT(x=pred_fc, amp_min=amp_min, amp_max=amp_max,
                              dst_flatten_order=self.dst_flatten_order,
                              img_shape=self.hparams.img_shape)
    img_pred_before_conv = torch.roll(
        torch.fft.irfftn(dft_pred_fc, dim=[1, 2],
                         s=2 * (self.hparams.img_shape, )),
        2 * (self.hparams.img_shape // 2, ), (1, 2))
    return pred_img, img_pred_before_conv