def forward(self, x, input):
    """Apply a per-slice center mask along dimension 2 of ``x``.

    For every index ``j`` along dim 2, the slice ``x[:, :, j, ...]`` is
    passed through ``T.mask_center`` together with that slice's
    low-frequency count from ``input['num_lf'][j]``.
    # NOTE(review): assumes input['num_lf'] is indexable with one entry
    # per dim-2 slice of x — confirm against the caller.

    Args:
        x: Input tensor; dim 2 enumerates the slices to mask.
        input: Mapping that provides ``'num_lf'`` per slice.

    Returns:
        A tensor shaped like ``x`` holding the center-masked slices.
    """
    out = torch.zeros_like(x)
    num_slices = x.size(2)
    for idx in range(num_slices):
        num_lf = input['num_lf'][idx]
        out[:, :, idx, ...] = T.mask_center(x[:, :, idx, ...], num_lf)
    return out
def forward(self, masked_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Estimate coil sensitivity maps from the low-frequency k-space lines.

    Locates the contiguous block of sampled low-frequency lines around the
    center of ``mask``, zeroes out everything outside it, transforms to image
    space, runs the normalizing U-Net, and normalizes by the
    root-sum-of-squares across coils.

    Args:
        masked_kspace: Undersampled multi-coil k-space data.
        mask: Sampling mask; ``mask.squeeze()`` is assumed to yield a 1-D
            line pattern along dim -2 — TODO confirm for batched masks.

    Returns:
        Estimated sensitivity maps, same shape as ``masked_kspace``.
    """
    # get low frequency line locations and mask them out
    cent = mask.shape[-2] // 2
    squeezed = mask.squeeze()
    # Last unsampled line strictly below center, first unsampled line at or
    # above center. NOTE(review): raises IndexError if either side of the
    # mask is sampled all the way to the edge (no zero found).
    left = torch.nonzero(squeezed[:cent] == 0)[-1]
    right = torch.nonzero(squeezed[cent:] == 0)[0] + cent
    # The sampled block lies strictly between the two zeros, so it holds
    # right - left - 1 lines. (Was `right - left`, which overcounted by one
    # and disagreed with the while-loop variant of this method in this file.)
    num_low_freqs = right - left - 1
    pad = (mask.shape[-2] - num_low_freqs + 1) // 2
    x = transforms.mask_center(masked_kspace, pad, pad + num_low_freqs)

    # convert to image space
    x = fastmri.ifft2c(x)
    x, b = self.chans_to_batch_dim(x)

    # estimate sensitivities
    x = self.norm_unet(x)
    x = self.batch_chans_to_chan_dim(x, b)
    x = self.divide_root_sum_of_squares(x)
    return x
def forward(self, masked_kspace, mask):
    """Estimate coil sensitivity maps from the central low-frequency lines.

    Scans outward from the center of ``mask`` to find the contiguous block
    of sampled low-frequency lines, keeps only those lines in k-space,
    transforms to image space, runs the normalizing U-Net, and normalizes by
    the root-sum-of-squares across coils.

    Args:
        masked_kspace: Undersampled multi-coil k-space data.
        mask: Sampling mask; ``mask[..., i, :]`` is assumed to reduce to a
            single truth value per line index ``i`` — TODO confirm.

    Returns:
        Estimated sensitivity maps, same shape as ``masked_kspace``.
    """
    def get_low_frequency_lines(mask):
        # Sampled low-frequency block is [l + 1, r) after the scans below.
        size = mask.shape[-2]
        l = r = size // 2
        # Bounds guards added: the original scanned unconditionally, so a
        # mask sampled to the top edge raised IndexError, and one sampled to
        # the bottom edge wrapped around via negative indexing (mask[..., -1, :])
        # and scanned the wrong end of the array.
        while r < size and mask[..., r, :]:
            r += 1
        while l >= 0 and mask[..., l, :]:
            l -= 1
        return l + 1, r

    l, r = get_low_frequency_lines(mask)
    num_low_freqs = r - l
    pad = (mask.shape[-2] - num_low_freqs + 1) // 2

    # keep only the center lines, then move to image space
    x = T.mask_center(masked_kspace, pad, pad + num_low_freqs)
    x = fastmri.ifft2c(x)
    x, b = self.chans_to_batch_dim(x)

    # estimate per-coil sensitivities and normalize across coils
    x = self.norm_unet(x)
    x = self.batch_chans_to_chan_dim(x, b)
    x = self.divide_root_sum_of_squares(x)
    return x