Exemplo n.º 1
0
def test_vsnet(shape, cfg, center_fractions, accelerations):
    """
    Run a VSNet forward pass on masked input and check the output dimensions.

    Args:
        shape: shape handed to ``create_input`` to build the input tensor
        cfg: model configuration, converted through OmegaConf before use
        center_fractions: center fractions given to ``RandomMaskFunc``
        accelerations: accelerations given to ``RandomMaskFunc``

    Returns:
        None

    Raises:
        AssertionError: if ``y.shape[1:]`` differs from ``x.shape[2:4]``
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # Mask every entry along the first dimension separately, always with the
    # same fixed seed, then stitch the per-entry results back together.
    per_sample = [
        transforms.apply_mask(x[idx : idx + 1], mask_func, seed=123)
        for idx in range(x.shape[0])
    ]
    output = torch.cat([masked for masked, _, _ in per_sample])
    mask = torch.cat([sample_mask for _, sample_mask, _ in per_sample])

    # Round-trip through a plain container so interpolations are resolved
    # before the model receives the configuration.
    raw_cfg = OmegaConf.create(cfg)
    resolved = OmegaConf.to_container(raw_cfg, resolve=True)
    model = VSNet(OmegaConf.create(resolved))

    with torch.no_grad():
        target = torch.abs(torch.view_as_complex(output))
        y = model.forward(output, output, mask, output, target=target)

    if y.shape[1:] == x.shape[2:4]:
        return
    raise AssertionError
Exemplo n.º 2
0
def test_centered_fft2_forward_normalization(shape):
    """
    Compare the centered, forward-normalized ``fft2`` against a NumPy reference.

    Args:
        shape: shape of the input; a trailing dimension of size 2 is appended

    Returns:
        None

    Raises:
        AssertionError: if the two results are not close per ``np.allclose``
    """
    x = create_input(shape + [2])

    # Result under test: the last axis carries the two components that are
    # recombined into a single complex array.
    transformed = fft2(
        x,
        centered=True,
        normalization="forward",
        spatial_dims=[-2, -1],
    )
    components = transformed.numpy()
    out_torch = components[..., 0] + 1j * components[..., 1]

    # NumPy reference: ifftshift -> fft2 -> fftshift over the last two axes.
    reference = tensor_to_complex_np(x)
    reference = np.fft.ifftshift(reference, axes=(-2, -1))
    reference = np.fft.fft2(reference, norm="forward")
    out_numpy = np.fft.fftshift(reference, axes=(-2, -1))

    if np.allclose(out_torch, out_numpy):
        return
    raise AssertionError
Exemplo n.º 3
0
def test_complex_abs(shape):
    """
    Compare ``complex_abs`` against ``np.abs`` on the same input.

    Args:
        shape: shape of the input; a trailing dimension of size 2 is appended

    Returns:
        None

    Raises:
        AssertionError: if the two results are not close per ``np.allclose``
    """
    x = create_input(shape + [2])

    # Value under test versus the NumPy reference on the complex view of x.
    magnitude = complex_abs(x)
    out_torch = magnitude.numpy()
    out_numpy = np.abs(tensor_to_complex_np(x))

    if np.allclose(out_torch, out_numpy):
        return
    raise AssertionError
Exemplo n.º 4
0
def test_non_centered_fft2(shape):
    """
    Compare the non-centered, ortho-normalized ``fft2`` against ``np.fft.fft2``.

    Args:
        shape: shape of the input; a trailing dimension of size 2 is appended

    Returns:
        None

    Raises:
        AssertionError: if the two results are not close per ``np.allclose``
    """
    x = create_input(shape + [2])

    # Result under test: the last axis carries the two components that are
    # recombined into a single complex array.
    transformed = fft2(
        x,
        centered=False,
        normalization="ortho",
        spatial_dims=[-2, -1],
    )
    components = transformed.numpy()
    out_torch = components[..., 0] + 1j * components[..., 1]

    # NumPy reference: no shifts are applied in the non-centered case.
    out_numpy = np.fft.fft2(tensor_to_complex_np(x), norm="ortho")

    if np.allclose(out_torch, out_numpy):
        return
    raise AssertionError