def test_vsnet(shape, cfg, center_fractions, accelerations):
    """
    Check that VSNet returns an image whose spatial size matches the input.

    Each batch element of a freshly created input is masked separately with a
    ``RandomMaskFunc`` (fixed seed 123). The masked data is then fed to VSNet,
    and the trailing dimensions of the result are compared with ``x.shape[2:4]``.

    Args:
        shape: shape of the input
        cfg: configuration of the model
        center_fractions: center fractions
        accelerations: accelerations

    Returns:
        None
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # Mask every slice along the first axis independently, using the same seed each time.
    masked_slices, slice_masks = [], []
    for idx in range(x.shape[0]):
        masked_kspace, sample_mask, _ = transforms.apply_mask(x[idx : idx + 1], mask_func, seed=123)
        masked_slices.append(masked_kspace)
        slice_masks.append(sample_mask)
    output = torch.cat(masked_slices)
    mask = torch.cat(slice_masks)

    # Round-trip through a plain container so any interpolations are resolved before building the model.
    cfg = OmegaConf.create(OmegaConf.to_container(OmegaConf.create(cfg), resolve=True))

    model = VSNet(cfg)
    with torch.no_grad():
        target = torch.abs(torch.view_as_complex(output))
        y = model.forward(output, output, mask, output, target=target)

    if y.shape[1:] != x.shape[2:4]:
        raise AssertionError
def test_centered_fft2_forward_normalization(shape):
    """
    Test centered 2D Fast Fourier Transform with forward normalization.

    The result of ``fft2(..., centered=True, normalization="forward")`` is compared
    against a NumPy reference. The reference applies ``ifftshift``, then
    ``np.fft.fft2(..., norm="forward")``, then ``fftshift`` over the last two axes.

    Args:
        shape: shape of the input without the trailing real/imaginary axis.
            Any sequence (list or tuple) is accepted.

    Returns:
        None

    Raises:
        AssertionError: if the two results are not element-wise close.
    """
    # Append the size-2 real/imaginary axis; ``[*shape, 2]`` works for lists and tuples alike.
    x = create_input([*shape, 2])

    out_torch = fft2(x, centered=True, normalization="forward", spatial_dims=[-2, -1]).numpy()
    # Recombine the stacked real/imaginary channels into a complex array.
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = tensor_to_complex_np(x)
    # NumPy reference for a centered FFT: ifftshift -> fft2 -> fftshift.
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.fft2(input_numpy, norm="forward")
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))

    if not np.allclose(out_torch, out_numpy):
        max_err = np.max(np.abs(out_torch - out_numpy))
        raise AssertionError(f"centered forward-normalized fft2 mismatch (max abs error {max_err})")
def test_complex_abs(shape):
    """
    Test complex absolute value.

    ``complex_abs`` on a real/imaginary-stacked tensor is compared against
    ``np.abs`` applied to the equivalent NumPy complex array.

    Args:
        shape: shape of the input without the trailing real/imaginary axis.
            Any sequence (list or tuple) is accepted.

    Returns:
        None

    Raises:
        AssertionError: if the two results are not element-wise close.
    """
    # Append the size-2 real/imaginary axis; ``[*shape, 2]`` works for lists and tuples alike.
    x = create_input([*shape, 2])

    out_torch = complex_abs(x).numpy()
    out_numpy = np.abs(tensor_to_complex_np(x))

    if not np.allclose(out_torch, out_numpy):
        max_err = np.max(np.abs(out_torch - out_numpy))
        raise AssertionError(f"complex_abs mismatch (max abs error {max_err})")
def test_non_centered_fft2(shape):
    """
    Test non-centered 2D Fast Fourier Transform.

    The result of ``fft2(..., centered=False, normalization="ortho")`` is compared
    against ``np.fft.fft2(..., norm="ortho")``, with no shifting on either side.

    Args:
        shape: shape of the input without the trailing real/imaginary axis.
            Any sequence (list or tuple) is accepted.

    Returns:
        None

    Raises:
        AssertionError: if the two results are not element-wise close.
    """
    # Append the size-2 real/imaginary axis; ``[*shape, 2]`` works for lists and tuples alike.
    x = create_input([*shape, 2])

    out_torch = fft2(x, centered=False, normalization="ortho", spatial_dims=[-2, -1]).numpy()
    # Recombine the stacked real/imaginary channels into a complex array.
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = tensor_to_complex_np(x)
    out_numpy = np.fft.fft2(input_numpy, norm="ortho")

    if not np.allclose(out_torch, out_numpy):
        max_err = np.max(np.abs(out_torch - out_numpy))
        raise AssertionError(f"non-centered ortho fft2 mismatch (max abs error {max_err})")