Exemplo n.º 1
0
def test_construct_kernels(tmp_path_factory):
    """Plot the Fourier response of the 1D anti-aliasing kernels.

    Each kernel (``kbsinc`` of order 0 and 15, ``hanningsinc`` and ``sinc``)
    is masked by ``sel``, FFT'd, and plotted as ``10 * log10(|FFT|)``
    against a scaled version of the ``kernels.uspace`` sample positions.
    The figure is saved to a pytest temporary directory for visual
    inspection; nothing is asserted. Skipped when matplotlib is missing.
    """
    matplotlib = pytest.importorskip("matplotlib")
    matplotlib.use("agg")
    plt = pytest.importorskip("matplotlib.pyplot")

    WIDTH = 5
    OVERSAMP = 101
    ll = kernels.uspace(WIDTH, OVERSAMP)
    sel = ll <= (WIDTH + 2) // 2
    # Horizontal axis shared by every curve (labelled "FoV" below).
    xaxis = ll[sel] * OVERSAMP / 2 / np.pi

    # (legend label, sampled kernel), plotted in this order.
    curves = (
        ("kbsinc order 0",
         kernels.kbsinc(WIDTH, oversample=OVERSAMP, order=0)),
        ("kbsinc order 15",
         kernels.kbsinc(WIDTH, oversample=OVERSAMP, order=15)),
        ("hanning sinc", kernels.hanningsinc(WIDTH, oversample=OVERSAMP)),
        ("sinc", kernels.sinc(WIDTH, oversample=OVERSAMP)),
    )

    fig = plt.figure()
    plt.axvline(0.5, 0, 1, ls="--", c="k")
    plt.axvline(-0.5, 0, 1, ls="--", c="k")
    for label, kern in curves:
        response = np.abs(np.fft.fftshift(np.fft.fft(kern[sel])))
        plt.plot(xaxis, 10 * np.log10(response), label=label)
    plt.xlim(-10, 10)
    plt.legend()
    plt.ylabel("Response [dB]")
    plt.xlabel("FoV")
    plt.grid(True)
    plt.savefig(tmp_path_factory.mktemp("plots") / "aakernels.png")
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
Exemplo n.º 2
0
def test_detaper(tmp_path_factory):
    """Check that the FFT, DFT and separable-DFT detapers agree.

    The 2D kernel is the outer product of a 1D ``kbsinc`` kernel. The
    separable DFT detaper (built from the 1D kernel) must match the 2D DFT
    detaper everywhere, and the FFT detaper must match the DFT detaper at
    the 99th percentile of absolute error, both to 1e-14. A diagnostic
    plot is written when matplotlib is available.
    """
    W = 5
    OS = 3
    K1D = kernels.kbsinc(W, oversample=OS)
    K2D = np.outer(K1D, K1D)
    detaper = kernels.compute_detaper(128, K2D, W, OS)
    detaperdft = kernels.compute_detaper_dft(128, K2D, W, OS)
    detaperdftsep = kernels.compute_detaper_dft_seperable(128, K1D, W, OS)
    fft_vs_dft_err = np.abs(detaper - detaperdft)

    # The plot is optional: without matplotlib the asserts still run.
    try:
        import matplotlib
    except ImportError:
        pass
    else:
        matplotlib.use("agg")
        from matplotlib import pyplot as plt
        fig = plt.figure()
        plt.subplot(131)
        plt.title("FFT detaper")
        plt.imshow(detaper)
        plt.colorbar()
        plt.subplot(132)
        plt.title("DFT detaper")
        plt.imshow(detaperdft)
        plt.colorbar()
        plt.subplot(133)
        plt.title("ABS error")
        plt.imshow(fft_vs_dft_err)
        plt.colorbar()
        plt.savefig(tmp_path_factory.mktemp("detaper") / "detaper.png")
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

    assert np.percentile(fft_vs_dft_err, 99.0) < 1.0e-14
    assert np.max(np.abs(detaperdft - detaperdftsep)) < 1.0e-14
Exemplo n.º 3
0
def test_degrid_dft_packed_dask_dft_check():
    """Compare dask-wrapped degridding against a direct DFT prediction.

    A unit point source five pixels off the image centre on both axes is
    predicted onto a circular uv track (radius 5000, w = 0) over 16
    channels, once with ``dwrap.degridder`` and once with ``im_to_vis``.
    The 99th percentile of the absolute real and imaginary residuals in
    the first channel/correlation must stay below 0.05.
    """
    da = pytest.importorskip("dask.array")

    # Packed Kaiser-Bessel sinc convolution kernel.
    support, oversampling = 5, 3
    packed_kernel = kernels.pack_kernel(
        kernels.kbsinc(support, oversample=oversampling),
        support,
        oversample=oversampling)

    nrow = 100
    row_chunk = nrow // 8
    theta = np.linspace(0, 2 * np.pi, nrow)
    uvw = np.column_stack(
        (5000.0 * np.cos(theta), 5000.0 * np.sin(theta), np.zeros(nrow)))

    nchan = 16
    frequency = np.linspace(1.0e9, 1.4e9, nchan)
    wavelength = lightspeed / frequency

    # Cell size in degrees, from the first channel's wavelength and the
    # largest |u| or |v| on the track.
    px_across_beam = 10
    max_uv = max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1])))
    cell = np.rad2deg(wavelength[0] / (2 * max_uv * px_across_beam))

    npix = 512
    model = np.zeros((1, npix, npix), dtype=np.complex64)
    model[0, npix // 2 - 5, npix // 2 - 5] = 1.0
    image = model[0, :, :]

    ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(image)))
    ftmod = ftmod.reshape((1, 1, npix, npix))
    chanmap = np.zeros(nchan, dtype=np.int64)

    # Per-pixel coordinate offsets for the direct DFT.
    pixel_offsets = np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)
    dec, ra = np.meshgrid(pixel_offsets, pixel_offsets)
    radec = np.column_stack((ra.flatten(), dec.flatten()))
    dft_input = image.reshape(1, 1, npix * npix).T.copy()
    vis_dft = im_to_vis(dft_input, uvw, radec, frequency)

    vis_degrid = dwrap.degridder(
        da.from_array(uvw, chunks=(row_chunk, 3)),
        da.from_array(ftmod, chunks=(1, 1, npix, npix)),
        da.from_array(wavelength, chunks=(nchan, )),
        da.from_array(chanmap, chunks=(nchan, )),
        cell * 3600.0,  # degrees -> arcseconds
        da.from_array(np.array([[0, np.pi / 4.0]]), chunks=(1, 2)),
        (0, np.pi / 4.0),
        packed_kernel,
        support,
        oversampling,
        "None",  # no faceting
        "None",  # no faceting
        "XXYY_FROM_I",
        "conv_1d_axisymmetric_packed_gather").compute()

    for component in (np.real, np.imag):
        residual = np.abs(
            component(vis_dft[:, 0, 0]) - component(vis_degrid[:, 0, 0]))
        assert np.percentile(residual, 99.0) < 0.05
Exemplo n.º 4
0
def test_degrid_dft_packed_dask():
    """Time dask-wrapped degridding of a large synthetic observation.

    Predicts a single offset point source onto 50 000 rows x 1024 channels
    with ``dwrap.degridder`` and prints the elapsed wall time. This is a
    benchmark: the predicted visibilities are not checked.
    """
    da = pytest.importorskip("dask.array")

    # Packed Kaiser-Bessel sinc convolution kernel.
    support, oversampling = 5, 3
    packed_kernel = kernels.pack_kernel(
        kernels.kbsinc(support, oversample=oversampling),
        support,
        oversample=oversampling)

    nrow = int(5e4)
    row_chunk = nrow // 32
    theta = np.linspace(0, 2 * np.pi, nrow)
    uvw = np.column_stack(
        (5000.0 * np.cos(theta), 5000.0 * np.sin(theta), np.zeros(nrow)))

    nchan = 1024
    wavelength = lightspeed / np.linspace(1.0e9, 1.4e9, nchan)

    px_across_beam = 10
    max_uv = max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1])))
    cell = np.rad2deg(wavelength[0] / (2 * max_uv * px_across_beam))

    npix = 512
    model = np.zeros((1, npix, npix), dtype=np.complex64)
    model[0, npix // 2 - 5, npix // 2 - 5] = 1.0
    ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(model[0, :, :])))
    ftmod = ftmod.reshape((1, 1, npix, npix))
    chanmap = np.zeros(nchan, dtype=np.int64)

    with clock("DASK degridding") as tictoc:
        dwrap.degridder(
            da.from_array(uvw, chunks=(row_chunk, 3)),
            da.from_array(ftmod, chunks=(1, 1, npix, npix)),
            da.from_array(wavelength, chunks=(nchan, )),
            da.from_array(chanmap, chunks=(nchan, )),
            cell * 3600.0,  # degrees -> arcseconds
            da.from_array(np.array([[0, np.pi / 4.0]]), chunks=(1, 2)),
            (0, np.pi / 4.0),
            packed_kernel,
            support,
            oversampling,
            "None",  # no faceting
            "None",  # no faceting
            "XXYY_FROM_I",
            "conv_1d_axisymmetric_packed_gather").compute()

    print(tictoc)
Exemplo n.º 5
0
def test_facetcodepath():
    """Smoke-test the faceted ("rotate"/"phase_rotate") gridding path.

    Grids one visibility sampled at the uv origin onto a 64-pixel grid; the
    test passes as long as ``gridder.gridder`` completes without raising.
    """
    support, oversampling = 5, 3
    packed_kernel = kernels.pack_kernel(
        kernels.kbsinc(support, oversample=oversampling),
        support,
        oversample=oversampling)

    # A single sample at zero offset in uv.
    origin_uvw = np.array([[0, 0, 0]])
    unit_vis = np.array([[[1.0 + 0j, 1.0 + 0j]]])
    gridder.gridder(origin_uvw,
                    unit_vis,
                    np.array([1.0]),
                    np.array([0]),
                    64,
                    30,
                    (0, 0),
                    (0, 0),
                    packed_kernel,
                    support,
                    oversampling,
                    "rotate",
                    "phase_rotate",
                    "I_FROM_XXYY",
                    "conv_1d_axisymmetric_packed_scatter")
Exemplo n.º 6
0
def test_degrid_dft_packed_nondask():
    """Time plain (non-dask) degridding of a large synthetic observation.

    Predicts a single offset point source onto 50 000 rows x 1024 channels
    with ``degridder.degridder`` and prints the elapsed wall time. This is
    a benchmark counterpart of the dask test: the output is not checked.
    """
    # Packed Kaiser-Bessel sinc convolution kernel.
    support, oversampling = 5, 3
    packed_kernel = kernels.pack_kernel(
        kernels.kbsinc(support, oversample=oversampling),
        support,
        oversample=oversampling)

    nrow = int(5e4)
    theta = np.linspace(0, 2 * np.pi, nrow)
    uvw = np.column_stack(
        (5000.0 * np.cos(theta), 5000.0 * np.sin(theta), np.zeros(nrow)))

    nchan = 1024
    wavelength = lightspeed / np.linspace(1.0e9, 1.4e9, nchan)

    px_across_beam = 10
    max_uv = max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1])))
    cell = np.rad2deg(wavelength[0] / (2 * max_uv * px_across_beam))

    npix = 512
    model = np.zeros((1, npix, npix), dtype=np.complex64)
    model[0, npix // 2 - 5, npix // 2 - 5] = 1.0
    ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(model[0, :, :])))
    ftmod = ftmod.reshape((1, npix, npix))
    chanmap = np.zeros(nchan, dtype=np.int64)

    centre = (0, np.pi / 4.0)
    with clock("Non-DASK degridding") as tictoc:
        degridder.degridder(
            uvw,
            ftmod,
            wavelength,
            chanmap,
            cell * 3600.0,  # degrees -> arcseconds
            centre,
            centre,
            packed_kernel,
            support,
            oversampling,
            "None",  # no faceting
            "None",  # no faceting
            "XXYY_FROM_I",
            "conv_1d_axisymmetric_packed_gather")

    print(tictoc)
Exemplo n.º 7
0
def test_wcorrection_faceting_forward(tmp_path_factory):
    """Degrid an offset facet and compare against a direct DFT prediction.

    Simulated baselines are rotated through +/-20 degrees of hour angle
    (with d0 = pi/4). A unit point source sits at the centre of a facet
    offset by (20, 20) cells; degridding that facet with the
    "rotate"/"phase_rotate" faceting options must match ``im_to_vis`` to
    within 0.05 (99th percentile of the absolute real and imaginary
    residuals in the first channel/correlation). Diagnostic plots are
    written when matplotlib is available.
    """
    W = 5
    OS = 9
    kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, OS)
    nrow = 5000
    nbl = 25
    np.random.seed(0)
    # Simulate some fictitious baselines rotated by an hour angle.
    blpos = np.random.uniform(26, 10000, size=(nbl, 3))
    ntime = nrow // nbl
    d0 = np.pi / 4.0
    hour_angles = np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)
    sin_h, cos_h = np.sin(hour_angles), np.cos(hour_angles)
    sin_d, cos_d = np.sin(d0), np.cos(d0)
    uvw = np.zeros((nrow, 3), dtype=np.float64)
    for n, (bx, by, bz) in enumerate(blpos):
        # Row n * ntime + i holds R(h_i) @ blpos[n], vectorized over h with
        # R(h) = [[ sin h,         cos h,        0     ],
        #         [-sin d0 cos h,  sin d0 sin h, cos d0],
        #         [ cos d0 cos h, -cos d0 sin h, sin d0]]
        rows = slice(n * ntime, (n + 1) * ntime)
        uvw[rows, 0] = sin_h * bx + cos_h * by
        uvw[rows, 1] = -sin_d * cos_h * bx + sin_d * sin_h * by + cos_d * bz
        uvw[rows, 2] = cos_d * cos_h * bx - cos_d * sin_h * by + sin_d * bz

    pxacrossbeam = 5
    frequency = np.array([1.4e9])
    wavelength = lightspeed / frequency

    max_uv = max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1])))
    cell = np.rad2deg(wavelength[0] / (max_uv * pxacrossbeam))
    npixfacet = 100
    mod = np.ones((1, 1, 1), dtype=np.complex64)
    deltaradec = np.array([[20 * np.deg2rad(cell), 20 * np.deg2rad(cell)]])
    facet_radec = deltaradec + np.array([[0, d0]])
    lm = radec_to_lmn(facet_radec, phase_centre=np.array([0, d0]))

    vis_dft = im_to_vis(mod, uvw, lm[:, 0:2],
                        frequency).repeat(2).reshape(nrow, 1, 2)
    chanmap = np.array([0])
    ftmod = np.ones((1, npixfacet, npixfacet),
                    dtype=np.complex64)  # point source at centre of facet
    vis_degrid = degridder.degridder(
        uvw,
        ftmod,
        wavelength,
        chanmap,
        cell * 3600.0,  # degrees -> arcseconds
        facet_radec[0, :],
        (0, d0),
        kern,
        W,
        OS,
        "rotate",  # faceting enabled (other tests pass "None" to disable)
        "phase_rotate",  # faceting enabled
        "XXYY_FROM_I",
        "conv_1d_axisymmetric_packed_gather")

    # The plots are optional: without matplotlib the asserts still run.
    try:
        import matplotlib
    except ImportError:
        pass
    else:
        matplotlib.use("agg")
        from matplotlib import pyplot as plt
        plot_dir = tmp_path_factory.mktemp("wcorrection_forward")

        panels = (
            (np.real, r"$\Re(\mathtt{degrid facet})$", r"$\Re(\mathtt{dft})$",
             "Real of predicted", "facet_degrid_vs_dft_re_packed.png"),
            (np.imag, r"$\Im(\mathtt{degrid facet})$", r"$\Im(\mathtt{dft})$",
             "Imag of predicted", "facet_degrid_vs_dft_im_packed.png"),
        )
        for part, degrid_label, dft_label, ylabel, fname in panels:
            degrid_part = part(vis_degrid[:, 0, 0])
            dft_part = part(vis_dft[:, 0, 0])
            fig = plt.figure()
            plt.plot(degrid_part, label=degrid_label)
            plt.plot(dft_part, label=dft_label)
            plt.plot(np.abs(dft_part - degrid_part), label="Error")
            plt.legend()
            plt.xlabel("sample")
            plt.ylabel(ylabel)
            plt.savefig(plot_dir / fname)
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

    assert np.percentile(
        np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), 99.0) < 0.05
    assert np.percentile(
        np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), 99.0) < 0.05
Exemplo n.º 8
0
def test_wcorrection_faceting_backward(tmp_path_factory):
    """Image an offset point source with and without faceting.

    DFT visibilities of a unit point source offset by (600, 600) cells are
    gridded twice. The first grid is a padded 2048-pixel grid at the phase
    centre, with no faceting, used only for the reference plot. The second
    is a padded 100-pixel facet centred on the source, using
    "rotate"/"phase_rotate". The detapered facet image must peak at 1.0
    to within 1e-6.
    """
    W = 5
    OS = 9
    kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, OS)
    nrow = 5000
    nbl = 25
    np.random.seed(0)
    # Simulate some fictitious baselines rotated by an hour angle.
    blpos = np.random.uniform(26, 10000, size=(nbl, 3))
    ntime = nrow // nbl
    d0 = np.pi / 4.0
    hour_angles = np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)
    sin_h, cos_h = np.sin(hour_angles), np.cos(hour_angles)
    sin_d, cos_d = np.sin(d0), np.cos(d0)
    uvw = np.zeros((nrow, 3), dtype=np.float64)
    for n, (bx, by, bz) in enumerate(blpos):
        # Row n * ntime + i holds R(h_i) @ blpos[n], vectorized over h with
        # R(h) = [[ sin h,         cos h,        0     ],
        #         [-sin d0 cos h,  sin d0 sin h, cos d0],
        #         [ cos d0 cos h, -cos d0 sin h, sin d0]]
        rows = slice(n * ntime, (n + 1) * ntime)
        uvw[rows, 0] = sin_h * bx + cos_h * by
        uvw[rows, 1] = -sin_d * cos_h * bx + sin_d * sin_h * by + cos_d * bz
        uvw[rows, 2] = cos_d * cos_h * bx - cos_d * sin_h * by + sin_d * bz

    pxacrossbeam = 5
    frequency = np.array([1.4e9])
    wavelength = lightspeed / frequency

    max_uv = max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1])))
    cell = np.rad2deg(wavelength[0] / (max_uv * pxacrossbeam))
    npix = 2048
    npixfacet = 100
    fftpad = 1.1
    # Padded grid sizes used for gridding and the FFT.
    npad = int(npix * fftpad)
    npadfacet = int(npixfacet * fftpad)
    mod = np.ones((1, 1, 1), dtype=np.complex64)
    deltaradec = np.array([[600 * np.deg2rad(cell), 600 * np.deg2rad(cell)]])
    facet_radec = deltaradec + np.array([[0, d0]])
    lm = radec_to_lmn(facet_radec, phase_centre=np.array([0, d0]))

    vis_dft = im_to_vis(mod, uvw, lm[:, 0:2],
                        frequency).repeat(2).reshape(nrow, 1, 2)
    chanmap = np.array([0])

    def grid_to_image(grid, taper, npadded, ncrop):
        """Inverse-FFT plane 0 of ``grid``, detaper and crop the centre."""
        dirty = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(grid[0, :, :])))
        image = dirty.reshape((1, npadded, npadded)).real / taper * npadded**2
        lo = npadded // 2 - ncrop // 2
        return image[:, lo:lo + ncrop, lo:lo + ncrop]

    detaper = kernels.compute_detaper_dft_seperable(
        npad, kernels.unpack_kernel(kern, W, OS), W, OS)
    vis_grid_nofacet = gridder.gridder(
        uvw,
        vis_dft,
        wavelength,
        chanmap,
        npad,
        cell * 3600.0,  # degrees -> arcseconds
        (0, d0),
        (0, d0),
        kern,
        W,
        OS,
        "None",  # no faceting
        "None",  # no faceting
        "I_FROM_XXYY",
        "conv_1d_axisymmetric_packed_scatter",
        do_normalize=True)
    ftvis = grid_to_image(vis_grid_nofacet, detaper, npad, npix)

    detaper_facet = kernels.compute_detaper_dft_seperable(
        npadfacet, kernels.unpack_kernel(kern, W, OS), W, OS)
    vis_grid_facet = gridder.gridder(uvw,
                                     vis_dft,
                                     wavelength,
                                     chanmap,
                                     npadfacet,
                                     cell * 3600.0,
                                     facet_radec[0, :],
                                     (0, d0),
                                     kern,
                                     W,
                                     OS,
                                     "rotate",
                                     "phase_rotate",
                                     "I_FROM_XXYY",
                                     "conv_1d_axisymmetric_packed_scatter",
                                     do_normalize=True)
    ftvisfacet = grid_to_image(vis_grid_facet, detaper_facet, npadfacet,
                               npixfacet)

    # The plot is optional: without matplotlib the assert still runs.
    try:
        import matplotlib
    except ImportError:
        pass
    else:
        matplotlib.use("agg")
        from matplotlib import pyplot as plt
        plot_dir = tmp_path_factory.mktemp("wcorrection_backward")

        fig = plt.figure()
        plt.subplot(121)
        # NOTE(review): hard-coded 100x100 window, presumably around where
        # the offset source lands in the 2048-pixel image; revisit if npix
        # or the 600-cell offset change.
        plt.imshow(ftvis[0, 1624 - 50:1624 + 50, 1447 - 50:1447 + 50])
        plt.colorbar()
        plt.title("Offset FFT (peak={0:.1f})".format(np.max(ftvis)))
        plt.subplot(122)
        plt.imshow(ftvisfacet[0, :, :])
        plt.colorbar()
        plt.title("Faceted FFT (peak={0:.1f})".format(np.max(ftvisfacet)))
        plt.savefig(plot_dir / "facet_imaging.png")
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

    assert np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6
Exemplo n.º 9
0
def test_grid_dft_packed(tmp_path_factory):
    """Compare an FFT-gridded image against a direct DFT image.

    A pattern of unit point sources is placed in a 256-pixel model. For
    each of 5 radii there are 8 sources around the centre: the diagonals
    plus the two axes. Visibilities are predicted with ``im_to_vis`` on
    random uv samples with w = 0. They are then imaged two ways: by
    gridding onto a 1.25x padded grid, inverse FFT, detapering and
    cropping; and by ``vis_to_im``. The 95th percentile of the absolute
    difference must be below 0.15.
    """
    W = 7
    OS = 1009
    kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, OS)
    nrow = 5000
    np.random.seed(0)
    uvw = np.random.normal(scale=6000, size=(nrow, 3))
    uvw[:, 2] = 0.0  # ignore widefield effects for now

    pxacrossbeam = 10
    frequency = np.array([30.0e9])
    wavelength = lightspeed / frequency

    max_uv = max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1])))
    cell = np.rad2deg(wavelength[0] / (2 * max_uv * pxacrossbeam))
    npix = 256
    fftpad = 1.25
    npad = int(npix * fftpad)  # padded grid size used for gridding/FFT
    mod = np.zeros((1, npix, npix), dtype=np.complex64)
    centre = npix // 2
    for r in np.linspace(npix // 8, 2 * npix // 5, 5).astype(int):
        for dy, dx in ((r, r), (r, -r), (-r, -r), (-r, r),
                       (0, r), (0, -r), (-r, 0), (r, 0)):
            mod[0, centre + dy, centre + dx] = 1.0

    pixel_offsets = np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)
    dec, ra = np.meshgrid(pixel_offsets, pixel_offsets)
    radec = np.column_stack((ra.flatten(), dec.flatten()))

    vis_dft = im_to_vis(mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), uvw,
                        radec, frequency).repeat(2).reshape(nrow, 1, 2)
    chanmap = np.array([0])
    detaper = kernels.compute_detaper_dft_seperable(
        npad, kernels.unpack_kernel(kern, W, OS), W, OS)
    vis_grid = gridder.gridder(
        uvw,
        vis_dft,
        wavelength,
        chanmap,
        npad,
        cell * 3600.0,  # degrees -> arcseconds
        (0, np.pi / 4.0),
        (0, np.pi / 4.0),
        kern,
        W,
        OS,
        "None",  # no faceting
        "None",  # no faceting
        "I_FROM_XXYY",
        "conv_1d_axisymmetric_packed_scatter",
        do_normalize=True)

    dirty = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(vis_grid[0, :, :])))
    ftvis = dirty.reshape((1, npad, npad)).real / detaper * npad**2
    lo = npad // 2 - npix // 2
    ftvis = ftvis[:, lo:lo + npix, lo:lo + npix]

    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    flags = np.zeros(vis_dft.shape, dtype=bool)
    dftvis = vis_to_im(vis_dft, uvw, radec, frequency,
                       flags).T.copy().reshape(2, 1, npix, npix) / nrow

    # The plot is optional: without matplotlib the assert still runs.
    try:
        import matplotlib
    except ImportError:
        pass
    else:
        matplotlib.use("agg")
        from matplotlib import pyplot as plt
        fig = plt.figure()
        plt.subplot(131)
        plt.title("FFT")
        plt.imshow(ftvis[0, :, :])
        plt.colorbar()
        plt.subplot(132)
        plt.title("DFT")
        plt.imshow(dftvis[0, 0, :, :])
        plt.colorbar()
        plt.subplot(133)
        plt.title("ABS diff")
        plt.imshow(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :]))
        plt.colorbar()
        plt.savefig(
            tmp_path_factory.mktemp("grid_dft_packed") /
            "grid_diff_dft_packed.png")
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

    assert (np.percentile(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :]), 95.0) <
            0.15)
Exemplo n.º 10
0
def test_degrid_dft_packed(tmp_path_factory):
    """Compare packed-kernel degridding against a direct DFT prediction.

    A unit point source five pixels off the image centre on both axes is
    predicted onto a circular uv track (radius 5000, w = 0, 1000 samples)
    at 1.4 GHz with ``degridder.degridder`` and with ``im_to_vis``. The
    99th percentile of the absolute real and imaginary residuals in the
    first channel/correlation must stay below 0.05. Diagnostic plots go to
    a pytest temporary directory when matplotlib is available.
    """
    W = 5
    OS = 3
    kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS),
                               W,
                               oversample=OS)
    nrow = 1000
    theta = np.linspace(0, 2 * np.pi, nrow)
    uvw = np.column_stack(
        (5000.0 * np.cos(theta), 5000.0 * np.sin(theta), np.zeros(nrow)))

    pxacrossbeam = 10
    frequency = np.array([1.4e9])
    wavelength = lightspeed / frequency

    max_uv = max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1])))
    cell = np.rad2deg(wavelength[0] / (2 * max_uv * pxacrossbeam))
    npix = 512
    mod = np.zeros((1, npix, npix), dtype=np.complex64)
    mod[0, npix // 2 - 5, npix // 2 - 5] = 1.0

    ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(
        mod[0, :, :]))).reshape((1, npix, npix))
    chanmap = np.array([0])
    vis_degrid = degridder.degridder(
        uvw,
        ftmod,
        wavelength,
        chanmap,
        cell * 3600.0,  # degrees -> arcseconds
        (0, np.pi / 4.0),
        (0, np.pi / 4.0),
        kern,
        W,
        OS,
        "None",  # no faceting
        "None",  # no faceting
        "XXYY_FROM_I",
        "conv_1d_axisymmetric_packed_gather")

    pixel_offsets = np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)
    dec, ra = np.meshgrid(pixel_offsets, pixel_offsets)
    radec = np.column_stack((ra.flatten(), dec.flatten()))

    vis_dft = im_to_vis(mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), uvw,
                        radec, frequency)

    # The plots are optional: without matplotlib the asserts still run.
    try:
        import matplotlib
    except ImportError:
        pass
    else:
        matplotlib.use("agg")
        from matplotlib import pyplot as plt
        # Both plots go to the per-test temporary directory (the real-part
        # plot previously went to $TMPDIR or /tmp directly).
        plot_dir = tmp_path_factory.mktemp("degrid_dft_packed")

        panels = (
            (np.real, r"$\Re(\mathtt{degrid})$", r"$\Re(\mathtt{dft})$",
             "Real of predicted", "degrid_vs_dft_re_packed.png"),
            (np.imag, r"$\Im(\mathtt{degrid})$", r"$\Im(\mathtt{dft})$",
             "Imag of predicted", "degrid_vs_dft_im_packed.png"),
        )
        for part, degrid_label, dft_label, ylabel, fname in panels:
            degrid_part = part(vis_degrid[:, 0, 0])
            dft_part = part(vis_dft[:, 0, 0])
            fig = plt.figure()
            plt.plot(degrid_part, label=degrid_label)
            plt.plot(dft_part, label=dft_label)
            plt.plot(np.abs(dft_part - degrid_part), label="Error")
            plt.legend()
            plt.xlabel("sample")
            plt.ylabel(ylabel)
            plt.savefig(plot_dir / fname)
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

    assert np.percentile(
        np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), 99.0) < 0.05
    assert np.percentile(
        np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), 99.0) < 0.05
Exemplo n.º 11
0
def test_gridder_dask():
    """Time dask-wrapped gridding and check the detapered image peak.

    1e6 rows x 128 channels of unit visibilities are gridded with
    ``dwrap.gridder`` (no faceting) onto a 1.1x padded 100-pixel grid. Only
    the gridding (graph construction plus compute) is timed and printed.
    The inverse-FFT'd, detapered and cropped image must peak at 1.0 to
    within 1e-6.
    """
    da = pytest.importorskip("dask.array")

    W = 5
    OS = 9
    kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, OS)
    nrow = int(1e6)
    nbl = 25
    np.random.seed(0)
    row_chunks = nrow // 10
    # Simulate some fictitious baselines rotated by an hour angle. This is
    # vectorized over hour angle: a per-row Python loop building a 3x3
    # matrix is very slow at 1e6 rows.
    blpos = np.random.uniform(26, 10000, size=(nbl, 3))
    ntime = nrow // nbl
    d0 = np.pi / 4.0
    hour_angles = np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)
    sin_h, cos_h = np.sin(hour_angles), np.cos(hour_angles)
    sin_d, cos_d = np.sin(d0), np.cos(d0)
    uvw_np = np.zeros((nrow, 3), dtype=np.float64)
    for n, (bx, by, bz) in enumerate(blpos):
        # Row n * ntime + i holds R(h_i) @ blpos[n] with
        # R(h) = [[ sin h,         cos h,        0     ],
        #         [-sin d0 cos h,  sin d0 sin h, cos d0],
        #         [ cos d0 cos h, -cos d0 sin h, sin d0]]
        rows = slice(n * ntime, (n + 1) * ntime)
        uvw_np[rows, 0] = sin_h * bx + cos_h * by
        uvw_np[rows, 1] = -sin_d * cos_h * bx + sin_d * sin_h * by + cos_d * bz
        uvw_np[rows, 2] = cos_d * cos_h * bx - cos_d * sin_h * by + sin_d * bz
    uvw = da.from_array(uvw_np, chunks=(row_chunks, 3))

    pxacrossbeam = 5
    nchan = 128
    frequency = da.from_array(np.linspace(1.0e9, 1.4e9, nchan),
                              chunks=(nchan, ))
    wavelength = lightspeed / frequency
    # da.maximum stays lazy; builtin max() would have to evaluate the
    # truth value of a dask comparison, forcing an early compute.
    max_uv = da.maximum(da.max(da.absolute(uvw[:, 0])),
                        da.max(da.absolute(uvw[:, 1])))
    cell = da.rad2deg(wavelength[0] / (max_uv * pxacrossbeam))
    npixfacet = 100
    fftpad = 1.1
    npadfacet = int(npixfacet * fftpad)  # padded grid size

    image_centres = da.from_array(np.array([[0, d0]]), chunks=(1, 2))
    chanmap = da.from_array(np.zeros(nchan, dtype=np.int64),
                            chunks=(nchan, ))
    detaper_facet = kernels.compute_detaper_dft_seperable(
        npadfacet, kernels.unpack_kernel(kern, W, OS), W, OS)
    vis_dft = da.ones(shape=(nrow, nchan, 2),
                      chunks=(row_chunks, nchan, 2),
                      dtype=np.complex64)

    with clock("DASK gridding") as tictoc:
        vis_grid_facet = dwrap.gridder(uvw,
                                       vis_dft,
                                       wavelength,
                                       chanmap,
                                       npadfacet,
                                       cell * 3600.0,
                                       image_centres, (0, d0),
                                       kern,
                                       W,
                                       OS,
                                       "None",
                                       "None",
                                       "I_FROM_XXYY",
                                       "conv_1d_axisymmetric_packed_scatter",
                                       do_normalize=True).compute()
    print(tictoc)

    dirty = np.fft.fftshift(
        np.fft.ifft2(np.fft.ifftshift(vis_grid_facet[0, :, :])))
    ftvisfacet = (dirty.reshape((1, npadfacet, npadfacet)).real /
                  detaper_facet * npadfacet**2)
    lo = npadfacet // 2 - npixfacet // 2
    ftvisfacet = ftvisfacet[:, lo:lo + npixfacet, lo:lo + npixfacet]
    assert np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6