# Example #1
0
def test_modulus(shape):
    """``transforms.modulus`` must agree with ``np.abs`` of the complex input.

    Args:
        shape: Base tensor shape as a list (presumably supplied by pytest
            parametrization); a trailing axis of size 2 is appended to it
            before the input tensor is created.
    """
    input_tensor = create_input(shape + [2])
    torch_modulus = transforms.modulus(input_tensor).numpy()
    numpy_modulus = np.abs(tensor_to_complex_numpy(input_tensor))
    assert np.allclose(torch_modulus, numpy_modulus)
# Example #2
0
def test_conjugate(shape):
    """``transforms.conjugate`` must agree with ``np.conjugate``.

    The complex input has real part ``0, 1, 2, ...`` and imaginary part
    ``1, 2, 3, ...`` laid out in ``shape``. It is converted to a torch tensor,
    given dimension names, conjugated, and compared with NumPy's result.

    Args:
        shape: Shape of the complex test array (presumably supplied by pytest
            parametrization).
    """
    # np.product was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # np.prod is the supported spelling and returns the same value.
    ramp = np.arange(np.prod(shape)).reshape(shape)
    data = ramp + 1j * (ramp + 1)
    torch_tensor = transforms.to_tensor(data)
    torch_tensor = add_names(torch_tensor, named=True)

    out_torch = tensor_to_complex_numpy(transforms.conjugate(torch_tensor))
    out_numpy = np.conjugate(data)
    assert np.allclose(out_torch, out_numpy)
# Example #3
0
def test_root_sum_of_squares_complex(shape, dims):
    """``transforms.root_sum_of_squares`` must match a NumPy RSS reduction.

    Args:
        shape: Base tensor shape as a list; a trailing axis of size 2 is
            appended before the (named) input tensor is created.
        dims: Reduction dimension(s) passed to the torch implementation.
            The string ``"coils"`` is mapped to NumPy axis 0 for the
            reference computation; any other value is used as-is.
    """
    shape = shape + [2]
    data = create_input(shape, named=True)  # noqa
    out_torch = transforms.root_sum_of_squares(data, dims).numpy()
    input_numpy = tensor_to_complex_numpy(data)
    # NumPy has no named axes, so the "coils" name is translated to the
    # leading axis; positional dims are forwarded unchanged.
    axis = 0 if dims == "coils" else dims
    out_numpy = np.sqrt(np.sum(np.abs(input_numpy) ** 2, axis))
    assert np.allclose(out_torch, out_numpy)
# Example #4
0
def test_complex_multiplication(shape):
    """``transforms.complex_multiplication`` must match NumPy complex ``*``.

    Two complex arrays are built: ``data_0`` with real part ``0, 1, 2, ...``
    and imaginary part ``1, 2, 3, ...``, and ``data_1 = data_0 + 0.5 + 1j``.
    Both are converted to named torch tensors, multiplied, and compared with
    the elementwise NumPy product.

    Args:
        shape: Shape of the complex test arrays (presumably supplied by
            pytest parametrization).
    """
    # np.product was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # np.prod is the supported spelling and returns the same value.
    ramp = np.arange(np.prod(shape)).reshape(shape)
    data_0 = ramp + 1j * (ramp + 1)
    data_1 = data_0 + 0.5 + 1j
    torch_tensor_0 = transforms.to_tensor(data_0)
    torch_tensor_1 = transforms.to_tensor(data_1)

    torch_tensor_0 = add_names(torch_tensor_0, named=True)
    torch_tensor_1 = add_names(torch_tensor_1, named=True)

    out_torch = tensor_to_complex_numpy(
        transforms.complex_multiplication(torch_tensor_0, torch_tensor_1))
    out_numpy = data_0 * data_1
    assert np.allclose(out_torch, out_numpy)
# Example #5
0
def test_ifft2(shape, named):
    """``transforms.ifft2`` must match a centered orthonormal NumPy ``ifft2``.

    The NumPy reference applies ``ifftshift`` -> ``ifft2(norm="ortho")`` ->
    ``fftshift`` over the last two axes of the complex input.

    NOTE(review): for unnamed tensors this passes ``dim=(-2, -1)`` even though
    the torch tensor carries a trailing size-2 axis, while ``test_fft2`` uses
    ``(-3, -2)`` for the same layout; confirm which convention
    ``transforms.ifft2`` expects.

    Args:
        shape: Base tensor shape as a list; a trailing axis of size 2 is
            appended before the input tensor is created.
        named: Whether the input tensor carries dimension names.
    """
    data = create_input(shape + [2], named=named)
    dim = ("height", "width") if named else (-2, -1)

    result = transforms.ifft2(data, dim=dim).numpy()
    # Recombine the last axis into complex values: index 0 real, index 1 imaginary.
    result = result[..., 0] + 1j * result[..., 1]

    spatial_axes = (-2, -1)
    shifted = np.fft.ifftshift(tensor_to_complex_numpy(data), spatial_axes)
    reference = np.fft.fftshift(np.fft.ifft2(shifted, norm="ortho"), spatial_axes)
    assert np.allclose(result, reference)
# Example #6
0
def test_fft2(shape, named):
    """``transforms.fft2`` must match a centered orthonormal NumPy ``fft2``.

    The NumPy reference applies ``ifftshift`` -> ``fft2(norm="ortho")`` ->
    ``fftshift`` over the last two axes of the complex input.

    Args:
        shape: Base tensor shape as a list; a trailing axis of size 2 is
            appended before the input tensor is created.
        named: Whether the input tensor carries dimension names.
    """
    shape = shape + [2]
    data = create_input(shape, named=named)
    # Named tensors address the spatial axes by name; unnamed ones by
    # position, skipping the trailing size-2 axis appended above.
    dim = ("height", "width") if named else (-3, -2)

    out_torch = transforms.fft2(data, dim=dim).numpy()
    # Recombine the last axis into complex values: index 0 real, index 1 imaginary.
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    data_numpy = tensor_to_complex_numpy(data)
    data_numpy = np.fft.ifftshift(data_numpy, (-2, -1))
    out_numpy = np.fft.fft2(data_numpy, norm="ortho")
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))
    # Report the discrepancy only when the comparison fails, instead of
    # printing debug output on every run.
    assert np.allclose(out_torch, out_numpy), (
        f"max abs error: {np.abs(out_torch - out_numpy).max()}")
# Example #7
0
import numpy as np
import torch

from direct.data.transforms import tensor_to_complex_numpy
from direct.nn.rim.mri_models import MRILogLikelihood
from direct.data import transforms

# Setup for comparing torch and NumPy versions of the coil-sensitivity
# multiplication.
# NOTE(review): `create_input` is not imported in this snippet; presumably it
# is a helper defined elsewhere in the test module — confirm.
# Image tensor of shape (1, 4, 4, 2), explicitly given dimension names.
input_image = create_input([1, 4, 4, 2]).rename("batch", "height", "width",
                                                "complex")
# Sensitivity map scaled by 0.1 and k-space shifted by 0.33, both of shape
# (1, 15, 4, 4, 2).
sensitivity_map = create_input([1, 15, 4, 4, 2]) * 0.1
masked_kspace = create_input([1, 15, 4, 4, 2]) + 0.33
# Random 0/1 mask of shape (1, 1, 4, 4, 1): size 1 in the second and last
# axes, presumably so it broadcasts over coils and complex components. It
# reuses sensitivity_map's dimension names.
# NOTE(review): np.random is not seeded here, so the mask differs between runs.
# NOTE(review): the align_to calls below require named tensors; this assumes
# create_input names its output by default — TODO confirm.
sampling_mask = torch.from_numpy(
    np.random.binomial(size=(1, 1, 4, 4, 1), n=1,
                       p=0.5)).refine_names(*sensitivity_map.names)

# NumPy reference copies of the inputs; the mask drops its trailing size-1 axis.
input_image_numpy = tensor_to_complex_numpy(input_image)
sensitivity_map_numpy = tensor_to_complex_numpy(sensitivity_map)
masked_kspace_numpy = tensor_to_complex_numpy(masked_kspace)
sampling_mask_numpy = sampling_mask.numpy()[..., 0]

# Torch
# Put each tensor into an explicit named axis order before combining them.
input_image = input_image.align_to("batch", "height", "width", "complex")
sensitivity_map = sensitivity_map.align_to("batch", "coil", "height", "width",
                                           "complex")
masked_kspace = masked_kspace.align_to("batch", "coil", "height", "width",
                                       "complex")

# align_as reorders input_image's axes to match sensitivity_map, adding a
# size-1 dimension for the "coil" name it lacks, before the complex product.
mul = transforms.complex_multiplication(sensitivity_map,
                                        input_image.align_as(sensitivity_map))

mul_names = mul.names