Beispiel #1
0
def robin_transform_accurate(bc, d_v, d_w, args, wl=1050e-9, dL=6.25e-9):
    """Apply a spectral Robin-type transform to a 1-D boundary field.

    Each derivative is FFT'd along its single remaining axis, its spectrum is
    multiplied by ``+sqrt(k_x**2 - omega**2)`` (``d_v``) or
    ``-sqrt(k_x**2 - omega**2)`` (``d_w``), and it is transformed back.
    The result is ``bc + (d_w_mod - d_v_mod)`` with (real, imag) stacked.

    :param bc: the boundary field to be transformed; must hold ``2 * N``
        elements, and its shape is used for the returned tensor.
    :param d_v: derivative of the fields in v; ``d_v.squeeze()`` must be
        ``(2, N)`` with channels (real, imag).
    :param d_w: derivative of the fields in w, same layout as ``d_v``.
    :param args: unused; kept for interface compatibility.
    :param wl: wavelength (same length unit as ``dL``).
    :param dL: grid spacing.
    :return: tensor shaped like ``bc``.
    """
    # first try the 0th order
    print("means: ", torch.mean(bc), torch.mean(d_v), torch.mean(d_w))
    d_v_complex = d_v.squeeze()[0] + 1j * d_v.squeeze()[1]
    d_w_complex = d_w.squeeze()[0] + 1j * d_w.squeeze()[1]
    d_v_fft = torch.fft.fft(d_v_complex)
    d_w_fft = torch.fft.fft(d_w_complex)
    size = d_v_fft.shape[0]
    omega = 2 * np.pi / (wl / dL)

    # Integer frequency indices in FFT order (0, 1, ..., then the negative
    # half). Splitting at (size + 1) // 2 gives exactly `size` entries for odd
    # sizes as well; the previous range-based list had only size - 1 there.
    k = torch.cat([
        torch.arange(0, (size + 1) // 2),
        torch.arange(-(size // 2), 0),
    ]).to(torch.float64)
    # Cast to complex before the sqrt so components with k_x < omega get the
    # principal imaginary root. (`np.complex` was removed in NumPy 1.24.)
    kx_sq_minus_omega_sq = (2 * np.pi * k / size)**2 - omega**2
    mod_fre_v = torch.sqrt(kx_sq_minus_omega_sq.to(torch.complex128))
    # Put the multiplier on the same device as the data (it was CPU-only).
    mod_fre_v = mod_fre_v.to(d_v_fft.device)
    mod_fre_w = -mod_fre_v

    d_v_modulated_ifft = torch.fft.ifft(d_v_fft * mod_fre_v)
    d_w_modulated_ifft = torch.fft.ifft(d_w_fft * mod_fre_w)

    d_v_modulated_ifft_RI = torch.stack(
        [torch.real(d_v_modulated_ifft),
         torch.imag(d_v_modulated_ifft)]).reshape(bc.shape)
    d_w_modulated_ifft_RI = torch.stack(
        [torch.real(d_w_modulated_ifft),
         torch.imag(d_w_modulated_ifft)]).reshape(bc.shape)

    return bc + (d_w_modulated_ifft_RI - d_v_modulated_ifft_RI)
Beispiel #2
0
def get_group_delay(
    raw_data: torch.Tensor,
    sampling_rate_in_hz: int,
    window_length_in_s: float,
    window_shift_in_s: float,
    num_fft_points: int,
    window_type: str,
):
    """Compute a group-delay spectrogram from two STFTs of ``raw_data``.

    ``X`` is the plain STFT and ``Y`` the STFT returned by ``_get_stft`` with
    ``data_transformation="group_delay"``. The result is
    ``(Re X * Re Y + Im X * Im Y) / (|X|^2 + 1e-10)``. The small constant
    keeps the division finite.

    :return: the group delay with its first two axes swapped.
    :raises AssertionError: if the result contains NaN values.
    """
    stft_args = (raw_data, sampling_rate_in_hz, window_length_in_s,
                 window_shift_in_s, num_fft_points)
    spec_x = _get_stft(*stft_args, window_type=window_type)
    spec_y = _get_stft(*stft_args,
                       window_type=window_type,
                       data_transformation="group_delay")

    real_x, imag_x = torch.real(spec_x), torch.imag(spec_x)
    real_y, imag_y = torch.real(spec_y), torch.imag(spec_y)
    numerator = (torch.multiply(real_x, real_y)
                 + torch.multiply(imag_x, imag_y))
    power = torch.square(torch.abs(spec_x))

    group_delay = torch.divide(numerator, power + 1e-10)
    assert not torch.isnan(group_delay).any(), "There are NaN values in group delay"
    return torch.transpose(group_delay, 0, 1)
Beispiel #3
0
def heightmap_initializer(focal_length,
                          resolution=1248,
                          pixel_pitch=6.4e-6,
                          refractive_idc=1.43,
                          wavelength=530e-9,
                          init_lens='fresnel'):
    """
    Initialize heightmap before training
    :param focal_length: float - distance between phase mask and sensor
    :param resolution: int - size of phase mask
    :param pixel_pitch: float - pixel size of phase mask
    :param refractive_idc: float - refractive index of phase mask
    :param wavelength: float - wavelength of light
    :param init_lens: str - type of lens to initialize: 'plano' (paraxial
        lens profile), 'fresnel' (the same profile wrapped by
        ``simple_to_fresnel_lens``), 'flat' (constant 1e-4), or anything
        else (uniform random heights smoothed with a Gaussian filter)
    :return: height map as a ``torch.Tensor`` of shape (resolution, resolution)
    """
    if init_lens in ('fresnel', 'plano'):
        convex_radius = (refractive_idc -
                         1.) * focal_length  # based on lens maker formula

        N = resolution
        M = resolution
        [x, y] = np.mgrid[-(N // 2):(N + 1) // 2,
                          -(M // 2):(M + 1) // 2].astype(np.float64)

        x = x * pixel_pitch
        y = y * pixel_pitch

        # get lens thickness by paraxial approximations
        heightmap = -(x**2 + y**2) / 2. * (1. / convex_radius)
        if init_lens == 'fresnel':
            phases = utils.heightmap_to_phase(heightmap, wavelength,
                                              refractive_idc)
            fresnel = simple_to_fresnel_lens(phases)
            heightmap = utils.phase_to_heightmap(fresnel, wavelength,
                                                 refractive_idc)

    elif init_lens == 'flat':
        heightmap = torch.ones((resolution, resolution)) * 0.0001
    else:
        heightmap = torch.rand((resolution, resolution)) * pixel_pitch
        gauss_filter = fspecial_gauss(10, 5)

        def _split_complex(t):
            # torch.imag raises RuntimeError on real tensors, so real input
            # gets an explicit zero imaginary part instead.
            t = torch.as_tensor(t)
            if torch.is_complex(t):
                return torch.real(t), torch.imag(t)
            return t, torch.zeros_like(t)

        heightmap = utils.stack_complex(*_split_complex(heightmap))
        gauss_filter = utils.stack_complex(*_split_complex(gauss_filter))
        heightmap = utils.conv_fft(heightmap, gauss_filter)
        heightmap = heightmap[:, :, 0]

    return torch.Tensor(heightmap)
Beispiel #4
0
def complex_to_channels(image, requires_grad=False):
    """Convert data from complex to channels.

    Real and imaginary parts are interleaved along the last axis. A complex
    tensor of shape ``(..., C)`` becomes a real tensor of shape
    ``(..., 2 * C)`` laid out as ``[re_0, im_0, re_1, im_1, ...]``.

    :param image: complex tensor with at least one dimension.
    :param requires_grad: unused; kept for interface compatibility.
    :return: real tensor of shape ``(..., 2 * C)``.
    """
    image_out = torch.stack([torch.real(image), torch.imag(image)], dim=-1)
    # `torch.shape` does not exist (TensorFlow leftover), so the target
    # shape is built from the `torch.Size` directly.
    shape_out = (*image.shape[:-1], image.shape[-1] * 2)
    return torch.reshape(image_out, shape_out)
Beispiel #5
0
def reshape_complex_vals_to_adj_channels(arr):
    """Split a complex tensor into stacked real and imaginary channel blocks.

    A complex tensor of shape [nc, x, y] becomes a real tensor of shape
    [2*nc, x, y]: all real channels first, then all imaginary channels
    (not alternating). This is the inverse of
    reshape_adj_channels_to_complex_vals().
    """
    assert is_complex(arr)  # input should be complex-valued

    real_block = torch.real(arr)
    imag_block = torch.imag(arr)
    return torch.cat((real_block, imag_block), dim=0)
Beispiel #6
0
def test_imag(dtype, input_cur):
    """Check that ``PyTorchBackend.imag`` matches ``torch.imag``.

    Also checks that it yields zeros for a real-valued integer array.
    """
    backend = pytorch_backend.PyTorchBackend()

    tensor_in = backend.convert_to_tensor(input_cur)
    np.testing.assert_allclose(backend.imag(tensor_in), torch.imag(tensor_in))

    real_only = backend.convert_to_tensor(np.array([1, 2]))
    np.testing.assert_allclose(backend.imag(real_only), np.array([0, 0]))
Beispiel #7
0
def imag(input_):
    """Return the imaginary part of a dense tensor.

    Thin wrapper applying `torch.imag` to the tensor stored in
    ``input_._data``.

    Parameters
    ----------
    input_ : DTensor
        Input dense tensor.
    """
    raw = input_._data
    return torch.imag(raw)
Beispiel #8
0
 def _get_fft_basis(self):
     fourier_basis = torch.fft.rfft(torch.eye(self.filter_length))
     cutoff = 1 + self.filter_length // 2
     fourier_basis = torch.cat([
         torch.real(fourier_basis[:, :cutoff]),
         torch.imag(fourier_basis[:, :cutoff])
     ],
                               dim=1)
     return fourier_basis.float()
Beispiel #9
0
def SLsheardec2D_pytorch(X, shearlets):
    """Shearlet decomposition of a 2-D image (PyTorch port of SLsheardec2D).

    :param X: 2-D image tensor; the warning below assumes it is real-valued.
    :param shearlets: tensor of shape (H, W, n_shearlets). Judging by the
        fftshift/multiply pattern, the shearlets are presumably given in the
        centred frequency domain; confirm against the caller.
    :return: real tensor of shape ``shearlets.shape`` with the coefficients.
        A warning is printed when the discarded imaginary part exceeds 5e-8
        in magnitude.
    """
    #def SLsheardec2D(X, shearletSystem):
    _, _, n_shearlets = shearlets.shape

    coeffs = torch.zeros(shearlets.shape,
                         dtype=torch.complex128,
                         device=X.device)
    Xfreq = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(X)))
    for j in range(n_shearlets):
        coeffs[:, :, j] = torch.fft.fftshift(
            torch.fft.ifft2(
                torch.fft.ifftshift(Xfreq * torch.conj(shearlets[:, :, j]))))
    # Compare the *magnitude*: a large negative imaginary part signals
    # non-real data just as much as a large positive one. Skip the check
    # when there are no coefficients (max() of an empty tensor raises).
    if coeffs.numel() > 0:
        max_imag = torch.imag(coeffs).abs().max()
        if max_imag > 5e-8:
            print("Warning: magnitude in imaginary part exceeded 5e-08.")
            print("Data is probably not real-valued. Largest magnitude: " +
                  str(max_imag))
            print("Imaginary part neglected.")
    return torch.real(coeffs)
Beispiel #10
0
def get_inv_spatial_weight(psf_grid):
    """Mean spectral power of every PSF in a grid.

    :param psf_grid: real tensor of shape (N, C, W, H). The original comment
        indicates C == 3.
    :return: tensor of shape (N, C). Each entry is ``|FFT2(psf)|**2`` averaged
        over the one-sided 2-D spectrum of that PSF.
    """
    #N,3,W,H
    # 2-D real FFT over the two spatial axes. The former call
    # `torch.fft.rfft(psf_grid, 2)` passed 2 as the signal *length* `n`
    # (a 1-D FFT of only the first two samples of the last axis). That was
    # a mis-port of the legacy `torch.rfft(x, signal_ndim=2)`.
    F_psf_grid = torch.fft.rfft2(psf_grid, dim=(-2, -1))
    F_psf_grid_norm = torch.real(F_psf_grid)**2 + torch.imag(F_psf_grid)**2
    F_psf_grid_norm = torch.mean(F_psf_grid_norm, dim=(2, 3))
    return F_psf_grid_norm
Beispiel #11
0
def rfft(t):
    """Full (two-sided) 2-D DFT of ``t`` in the legacy real/imag layout.

    On PyTorch >= 1.8 this uses ``torch.fft.fft2`` and stacks (real, imag)
    into a trailing axis of size 2. On older versions it calls the legacy
    ``torch.rfft(t, 2, onesided=False)``.
    """
    # Real-to-complex Discrete Fourier Transform
    import re  # local: only needed to parse the version string

    # Parse only "major.minor" so version strings with extra parts or
    # suffixes ("2.1.0+cu118", "1.14.0.dev20221001") do not break the
    # check. Unpacking exactly three dot-separated parts raised ValueError.
    match = re.match(r'(\d+)\.(\d+)', torch.__version__)
    ver_int = int(match.group(1)) * 100 + int(match.group(2))
    if ver_int >= 108:
        ft = torch.fft.fft2(t)
        ft = torch.stack([torch.real(ft), torch.imag(ft)], dim=-1)
    else:
        ft = torch.rfft(t, 2, onesided=False)
    return ft
Beispiel #12
0
 def FT(f_i, x):
     """Evaluate a truncated Fourier series of ``f_i`` at the point ``sum(x)``.

     :param f_i: tensor of shape (N, 2) with N complex samples as
         (real, imag) pairs. This is the layout the removed
         ``torch.fft(f_i, 1)`` required.
     :param x: tensor whose elements are summed to give the evaluation point.
     :return: float tensor ``[Re f*, Im f*]``, where
         ``f* = exp(2j*pi*shift*s) * sum_b C[b] * exp(2j*pi*b*s)``,
         ``s = sum(x)``, ``C = FFT(f_i)`` and ``b`` runs over
         ``range(-N//2, N//2)``.
     """
     shift = 1
     # The function form `torch.fft(input, signal_ndim)` was removed in
     # PyTorch 1.8 (torch.fft is now a module). Build the complex signal
     # explicitly and transform along the sample axis.
     C_b = torch.fft.fft(torch.complex(f_i[..., 0], f_i[..., 1]), dim=-1)
     N_2 = int(len(f_i) / 2)
     point = torch.sum(x)
     # Same frequencies, in the same order, as the former Python loop.
     # Negative indices address the upper half of the FFT output.
     b = torch.arange(-N_2, N_2)
     F_y = C_b[b] * torch.exp(1j * (2 * np.pi * b * point))
     f_star = torch.exp(1j * (2 * np.pi * shift * point)) * torch.sum(F_y)
     # detach() mirrors the former torch.tensor([...]) copy, which did not
     # track gradients.
     return torch.stack((torch.real(f_star), torch.imag(f_star))).detach()
Beispiel #13
0
    def backward(ctx, grad_output):
        """Autograd backward: turn the traveltime residual into an input gradient.

        ``grad_output`` is the (nrhs, nx) traveltime difference, per the
        original comment. The residual is scaled by ``-omega.real / frd``,
        passed to ``model.prop.solve_resid``, and the imaginary part of
        ``virt * solution`` is summed over dim 0 as float32.

        Returns ``(grad_input, None)``.
        """
        # resid = grad_output: (nrhs,nx), traveltime difference
        # b: (nrhs,nx,ny)
        virt, frd = ctx.saved_tensors
        model, ry = ctx.model, ctx.ry

        # NumPy array on CPU, CuPy array otherwise (presumably what the
        # solver backend expects on each device).
        on_cpu = model.device == 'cpu'
        grad_arr = grad_output.numpy() if on_cpu else to_cupy(grad_output)
        resid = -model.prop.omega.real * grad_arr / frd

        solution = model.prop.solve_resid(resid, ry)
        weighted = torch.imag(virt * to_tensor(solution)).to(torch.float32)
        grad_input = torch.sum(weighted, dim=0)
        return grad_input, None
Beispiel #14
0
    def training_step(self, batch, batch_idx):
        """Compute the training loss for a batch of trajectories.

        The total is the sum of three terms:

        * reconstruction MSE, scaled by ``krecon``;
        * variation of the complex magnitudes along dim 1, scaled by ``kmag``;
        * variation of the wrapped phase increments along dim 1, scaled by
          ``kphase``.

        With ``local_delta`` the variations are mean squared first differences
        (magnitude) and second differences (phase); otherwise they are
        variances over dim 1.
        """
        # Batch contains a set of trajectories
        encodings = self(batch)
        if self.decoder is None:
            # No decoder: decode with the transposed encoder weight.
            decodings = torch.nn.functional.linear(
                encodings, self.encoder.weight.transpose(0, 1))
        else:
            decodings = self.decoder(encodings)
        reconstruction_loss = self.krecon * F.mse_loss(
            decodings, batch, reduction="mean")

        # Consecutive feature pairs are interpreted as (real, imag).
        n_batch, n_steps = encodings.shape[0], encodings.shape[1]
        z = torch.view_as_complex(encodings.view(n_batch, n_steps, -1, 2))

        # A perfect network keeps the magnitude constant along a trajectory.
        mags = z.abs()
        if self.local_delta:
            mag_steps = mags[:, 1:, :] - mags[:, :-1, :]
            constant_mag_loss = self.kmag * mag_steps.square().mean()
        else:
            constant_mag_loss = self.kmag * torch.var(mags, dim=1).mean()

        # ...and advances the phase at a constant rate. atan2 is used instead
        # of `complex.angle()` (original note: not supported by autograd).
        phases = torch.atan2(torch.imag(z), torch.real(z))
        dphi = phases[:, 1:, :] - phases[:, :-1, :]
        # Wrap the increments back towards (-pi, pi].
        dphi = torch.where(dphi <= -math.pi, dphi + (2 * math.pi), dphi)
        dphi = torch.where(dphi > math.pi, dphi - (2 * math.pi), dphi)
        if self.local_delta:
            phase_curvature = dphi[:, 1:] - dphi[:, :-1]
            linear_phase_loss = self.kphase * phase_curvature.square().mean()
        else:
            linear_phase_loss = self.kphase * torch.var(dphi, dim=1).mean()

        return reconstruction_loss + constant_mag_loss + linear_phase_loss
def fft_new(z, x, label):
    """Correlation-filter style response of ``x`` against template ``z``.

    Computes, in the Fourier domain over the last two axes,
    ``irfftn(sum_c(X * conj(Z)) * label / (sum_c |Z|^2 + lambda0))``.
    This looks like a linear-kernel correlation filter; confirm against the
    training code.

    :param z: real tensor (batch, C, H, W), the template.
    :param x: real tensor (batch, C, H, W), the search region, same spatial
        size as ``z``.
    :param label: spectrum broadcastable to (batch, 1, H, W//2 + 1); it is
        moved to ``z``'s device.
    :return: real tensor (batch, 1, H, W), the spatial size of ``x``.
    """
    zf = fft.rfftn(z, dim=[-2, -1])
    xf = fft.rfftn(x, dim=[-2, -1])

    # R[batch, 1, H, W//2+1]: template power summed over channels
    kzzf = torch.sum(torch.real(zf)**2 + torch.imag(zf)**2,
                     dim=1,
                     keepdim=True)

    # C[batch, 1, H, W//2+1]: cross-spectrum summed over channels
    t = xf * torch.conj(zf)
    kxzf = torch.sum(t, dim=1, keepdim=True)

    # `lambda0` is a regulariser defined outside this function.
    alphaf = label.to(device=z.device) / (kzzf + lambda0)

    # R[batch, 1, H, W]: invert to the spatial size of `x`. This used to be
    # hard-coded to 121x121; results for 121x121 inputs are unchanged.
    return fft.irfftn(kxzf * alphaf, s=list(x.shape[-2:]), dim=[-2, -1])
Beispiel #16
0
def load_PSF_OTF(filename,
                 vol_size,
                 n_split=20,
                 n_depths=120,
                 downS=1,
                 device="cpu",
                 dark_current=106,
                 calc_max=False,
                 psfIn=None,
                 compute_transpose=False,
                 n_lenslets=29,
                 lenslet_centers_file_out='lenslet_centers_python.txt'):
    """Load a PSF stack and compute its OTF with ``fft_conv_split``.

    The PSF is loaded from ``filename`` unless ``psfIn`` is given. When
    ``lenslet_centers_file_out`` is non-empty, lenslet centres are located
    on the middle depth plane and written to that file. A random probe
    volume of size ``vol_size`` is convolved only to obtain the OTF.
    ``downS`` and ``dark_current`` are accepted but unused here.

    :return: ``(OTF, psf_shape)``, or ``(OTF, psf_shape, psfMaxCoeffs)``
        when ``calc_max``. With ``compute_transpose`` the OTF and its complex
        conjugate are concatenated along a new axis at dim 4.
    """
    # Load PSF
    if psfIn is None:
        psfIn = load_PSF(filename, n_depths)

    if len(lenslet_centers_file_out) > 0:
        middle_plane = psfIn[0, n_depths // 2, ...].numpy()
        find_lenslet_centers(middle_plane,
                             n_lenslets=n_lenslets,
                             file_out_name=lenslet_centers_file_out)

    max_coeffs = torch.amax(psfIn, dim=[0, 2, 3]) if calc_max else None

    psf_shape = torch.tensor(psfIn.shape[2:])
    probe_volume = torch.rand(1,
                              psfIn.shape[1],
                              vol_size[0],
                              vol_size[1],
                              device=device)
    psf_on_device = psfIn.float().detach().to(device)
    _, OTF = fft_conv_split(probe_volume,
                            psf_on_device,
                            psf_shape,
                            n_split=n_split,
                            device=device)
    OTF = OTF.detach()

    if compute_transpose:
        # Complex conjugate written out explicitly.
        OTFt = torch.real(OTF) - 1j * torch.imag(OTF)
        OTF = torch.cat((OTF.unsqueeze(-1), OTFt.unsqueeze(-1)), 4)

    if not calc_max:
        return OTF, psf_shape
    return OTF, psf_shape, max_coeffs
Beispiel #17
0
def init(
    flim: Tuple, b1lim: Tuple, nf: int, nb: int, b1max: Number,
    device: torch.device, dtype: torch.dtype
) -> Tuple[dict, mobjs.SpinCube, mobjs.Pulse, Tensor]:
    """Build the target, spin cube, full-passage pulse and B1 map.

    The cube has shape (1, 1, nb, nf). Off-resonance is swept linearly over
    ``flim`` along the last axis, and the B1 scaling over ``b1lim`` along
    the previous axis. The target magnetisation is (0, 0, -1) with unit
    weight. The RF pulse is ``fullpassage(1.5e-3, 5.3, 0.81e3, dt)`` scaled
    to ``0.8 * b1max``, with a zero gradient.

    :return: ``(target, cube, pulse, b1map)``, where ``target`` is a dict with
        keys ``'d_'`` and ``'weight_'``.
    """
    dkw = {'device': device, 'dtype': dtype}

    fov = tensor([[0., 0., 0.]], **dkw)
    ofst = tensor([[0., 0., 0.]], **dkw)
    imsize = (1, 1, nb, nf)

    ones = torch.ones(imsize, **dkw)
    b0map = ones * torch.linspace(*flim, nf, **dkw)
    b1map = ones * torch.linspace(*b1lim, nb, **dkw)[..., None]
    b1map = torch.stack((b1map, torch.zeros(b1map.shape, **dkw)), dim=-1)
    weight = torch.ones(imsize, **dkw)

    cube = mobjs.SpinCube(imsize, fov, ofst=ofst, Δf=b0map, **dkw)

    d = torch.stack((torch.zeros(imsize, **dkw),
                     torch.zeros(imsize, **dkw),
                     -torch.ones(imsize, **dkw)), dim=-1)
    target = {'d_': cube.extract(d), 'weight_': cube.extract(weight)}

    dt = mrphy.dt0.to(**dkw)  # Sec, torch.Tensor
    beta = 5.3
    tp = 1.5e-3  # Sec
    bw = 0.81e3  # Hz
    rf_peak = 0.8 * b1max

    rf_c = rf_peak * fullpassage(tp, beta, bw, dt.item())[None, None, ...]
    rf_c = rf_c.to(device=dkw['device'])
    rf = torch.cat((torch.real(rf_c), torch.imag(rf_c)), dim=1)
    gr = torch.zeros((1, 3, rf_c.shape[2]), **dkw)

    pulse = mobjs.Pulse(rf, gr, rfmax=b1max, dt=dt, **dkw)
    return target, cube, pulse, b1map
Beispiel #18
0
def steepest_ascent_direction(grad, norm_type, eps_tot):
    """Steepest-ascent step of size ``eps_tot`` under the given norm.

    Only ``'dftinf'`` is supported. The orthonormal 2-D DFT of ``grad`` is
    taken over the last two axes, and every complex coefficient is rescaled
    to modulus ``eps_tot``; coefficients that are exactly zero stay zero.
    The real part of the inverse DFT is returned, shaped like ``grad``.

    :raises ValueError: for any other ``norm_type``. Previously this fell
        through to an ``UnboundLocalError`` on ``adv_step``.
    """
    if norm_type != 'dftinf':
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")

    shape = grad.shape
    dftxgrad = torch.fft.fftn(grad, dim=(-2, -1), norm='ortho')
    dftz = dftxgrad.reshape(1, -1)
    # Row 0: real parts, row 1: imaginary parts -> one column per coefficient.
    dftz = torch.cat((torch.real(dftz), torch.imag(dftz)), dim=0)

    def l2_normalize(delta, eps):
        # Scale each column to L2 norm `eps`; the clamp avoids 0/0.
        avoid_zero_div = 1e-15
        norm2 = torch.sum(delta**2, dim=0, keepdim=True)
        norm = torch.sqrt(torch.clamp(norm2, min=avoid_zero_div))
        delta = delta * eps / norm
        return delta

    dftz = l2_normalize(dftz, eps_tot)
    dftz = (dftz[0, :] + 1j * dftz[1, :]).reshape(shape)
    delta = torch.fft.ifftn(dftz, dim=(-2, -1), norm='ortho')
    adv_step = torch.real(delta)
    return adv_step
Beispiel #19
0
def norm_projection(delta, norm_type, eps=1.):
    """Project onto a norm-ball of radius ``eps`` centred at 0.

    Args:
      delta: array of size dim x num containing the vectors to project.
      norm_type: ``'l2'`` (NumPy, per-column Euclidean ball) or ``'dftinf'``
        (torch; every coefficient of the orthonormal 2-D DFT over the last
        two axes is clipped to modulus <= eps). Any other value returns
        ``delta`` unchanged.
      eps: radius of the norm-ball.

    Returns:
      The projection of ``delta``, with the same shape as the input.
    """
    original_shape = delta.shape

    if norm_type == 'l2':
        # Shrink each column by a common factor; never enlarge it.
        sq_norm = np.sum(delta**2, axis=0, keepdims=True)
        col_norm = np.sqrt(np.maximum(1e-12, sq_norm))
        delta = delta * np.clip(eps / col_norm, a_min=None, a_max=1)
    elif norm_type == 'dftinf':
        # Move to the DFT domain, project there, then transform back.
        spectrum = torch.fft.fftn(delta, dim=(-2, -1), norm='ortho')
        flat = spectrum.reshape(1, -1)
        pairs = torch.cat((torch.real(flat), torch.imag(flat)), dim=0)
        modulus = torch.sqrt(
            torch.clamp(torch.sum(pairs**2, dim=0, keepdim=True), min=1e-15))
        # Only decrease the modulus, never increase it.
        pairs = pairs * torch.clamp(eps / modulus, max=1)
        spectrum = (pairs[0, :] + 1j * pairs[1, :]).reshape(delta.shape)
        # The projected vector can have an imaginary part; keep the real part.
        delta = torch.real(torch.fft.ifftn(spectrum, dim=(-2, -1), norm='ortho'))

    return delta.reshape(original_shape)
Beispiel #20
0
 def multi(x, y):
     """Imaginary part of the complex product ``x * y``.

     Computed as ``Im(x) * Re(y) + Re(x) * Im(y)``.
     """
     re_x, im_x = torch.real(x), torch.imag(x)
     re_y, im_y = torch.real(y), torch.imag(y)
     return im_x * re_y + re_x * im_y
Beispiel #21
0
 def multr(x, y):
     """Real part of the complex product ``x * y``.

     Computed as ``Re(x) * Re(y) - Im(x) * Im(y)``.
     """
     re_x, im_x = torch.real(x), torch.imag(x)
     re_y, im_y = torch.real(y), torch.imag(y)
     return re_x * re_y - im_x * im_y
Beispiel #22
0
    Ym1 = fft(Sm1, axis=0, shift=True)

    f = th.linspace(-Fsr / 2., Fsr / 2., len(Sm2))
    Ym2 = fft(Sm2, axis=0, shift=True)

    plt.figure(1)
    plt.subplot(221)
    plt.plot(t * 1e6, th.real(Sm1))
    plt.plot(t * 1e6, th.abs(Sm1))
    plt.grid()
    plt.legend(['Real part', 'Amplitude'])
    plt.title('Convolution matched filter')
    plt.xlabel(r'Time/$\mu s$')
    plt.ylabel('Amplitude')
    plt.subplot(222)
    plt.plot(t * 1e6, th.imag(Sm1))
    plt.plot(t * 1e6, th.abs(Sm1))
    plt.grid()
    plt.legend(['Imaginary part', 'Amplitude'])
    plt.title('Convolution matched filter')
    plt.xlabel(r'Time/$\mu s$')
    plt.ylabel('Amplitude')
    plt.subplot(223)
    plt.plot(f, th.abs(Ym1))
    plt.grid()
    plt.subplot(224)
    plt.plot(f, th.angle(Ym1))
    plt.grid()

    plt.figure(2)
    plt.subplot(221)
Beispiel #23
0
def main(args, solver):
    """Run the iterative subdomain (DDM) solver and record per-iteration errors.

    For each sample (at most ``args.num_device``):

    * the structure and field are cut into ``args.x_patches * args.y_patches``
      overlapping subdomains;
    * an initial guess comes from ``init_four_point_interp``;
    * ``args.DDM_iters`` iterations re-solve every subdomain with ``solver``,
      using transmission terms from
      ``trasmission_pade_g_alpha_beta_gamma_complex``.

    After each iteration the relative L1 error of the reconstructed field
    against ``DDM_Hy`` is printed and stored. At the end, the field history,
    input patches, ground-truth patches and losses are saved as ``.npy``
    files in ``args.output_folder``, together with a log-scale loss plot.
    """

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ds = DDM_Dataset(args.data_folder, total_sample_number = None)
    torch.manual_seed(42)
    DDM_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

    # df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    model_bs = args.x_patches*args.y_patches
    size_x = args.domain_sizex+(args.x_patches-1)*(args.domain_sizex-args.overlap_pixels)
    size_y = args.domain_sizey+(args.y_patches-1)*(args.domain_sizey-args.overlap_pixels)
    device_losses = np.zeros((args.num_device, args.DDM_iters))

    history_fields=np.zeros((len(DDM_loader), args.DDM_iters+1,model_bs,2,args.domain_sizex,args.domain_sizey))
    x_batch_trains=np.zeros((len(DDM_loader), model_bs,1,args.domain_sizex,args.domain_sizey))
    y_batch_trains=np.zeros((len(DDM_loader), model_bs,2,args.domain_sizex,args.domain_sizey))

    print("shapes:", history_fields.shape, x_batch_trains.shape, y_batch_trains.shape)

    for sample_id, sample_batched in enumerate(DDM_loader):
        if sample_id>=args.num_device:
            break
        if sample_id%20 == 0:
            print("sample_id: ", sample_id, flush=True)

        DDM_img, DDM_Hy= sample_batched['structure'], sample_batched['field']

        # prepare the input batched subdomains to model:
        x_batch_train = [DDM_img[0, 0, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                        args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] for i in range(args.x_patches) for j in range(args.y_patches)]
        x_batch_train = torch.stack(x_batch_train).reshape(model_bs,1,args.domain_sizex,args.domain_sizey)

        yeex_batch_train = [1/2*(DDM_img[0, 0, 0+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0,-1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -2+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeex_batch_train = torch.stack(yeex_batch_train).reshape(model_bs,1,args.domain_sizex-1,args.domain_sizey-2)

        yeey_batch_train = [1/2*(DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                               0+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -1+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)] + \
                                 DDM_img[0, 0, 1+args.starting_x+i*(args.domain_sizex-args.overlap_pixels) : -1+args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                              -1+args.starting_y+j*(args.domain_sizey-args.overlap_pixels) : -2+args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]) \
                            for i in range(args.x_patches) for j in range(args.y_patches)]
        yeey_batch_train = torch.stack(yeey_batch_train).reshape(model_bs,1,args.domain_sizex-2,args.domain_sizey-1)

        y_batch_train = [DDM_Hy[0, :, args.starting_x+i*(args.domain_sizex-args.overlap_pixels):args.starting_x+args.domain_sizex+i*(args.domain_sizex-args.overlap_pixels),\
                                       args.starting_y+j*(args.domain_sizey-args.overlap_pixels):args.starting_y+args.domain_sizey+j*(args.domain_sizey-args.overlap_pixels)]  for i in range(args.x_patches) for j in range(args.y_patches)]
        # Last axis is the y extent of a subdomain (was domain_sizex, which
        # only worked for square subdomains).
        y_batch_train = torch.stack(y_batch_train).reshape(model_bs,2,args.domain_sizex,args.domain_sizey)

        intep_field, patched_solved = init_four_point_interp(DDM_Hy, prop2, args)

        b, _, n, m = patched_solved.shape
        last_gs = np.zeros((b,n,m), dtype=np.csingle)

        # history_fields=np.zeros((args.DDM_iters+1,b,2,n,m))
        # history_fields[0, :, :, :, :] = patched_solved
        history_fields[sample_id, 0, :, :, :, :] = patched_solved
        x_batch_trains[sample_id] = x_batch_train
        y_batch_trains[sample_id] = y_batch_train


        for k in range(args.DDM_iters):
            for idx in range(model_bs):
                # left, right, top, bottom
                ops = [solver.bc_pade_operator(1/2*(x_batch_train[idx,0, :, 0]+x_batch_train[idx,0, :, 1]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, :, -2]+x_batch_train[idx,0, :,-1]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, 1, :]+x_batch_train[idx,0, 0, :]).numpy()), solver.bc_pade_operator(1/2*(x_batch_train[idx,0, -2, :]+x_batch_train[idx,0,-1, :]).numpy())]
                if k==0:
                    g, alpha, beta, gamma, g_mul = trasmission_pade_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, idx, args.transmission_func, args, ops)
                else:
                    g, alpha, beta, gamma, g_mul = trasmission_pade_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, idx, args.transmission_func, args, ops, last_gs[idx])
                last_gs[idx] = g

                # NOTE: `b` here shadows the batch size unpacked above; it is
                # not used again after `last_gs` was allocated.
                A,b = solver.construct_matrices_complex(g, ops, yeex_batch_train[idx,0], yeey_batch_train[idx,0], alpha, beta, gamma, g_mul)

                field_vec = torch.tensor(solver.solve(A, b).reshape((args.domain_sizex, args.domain_sizey)), dtype=torch.cfloat)
                field_vec_real = torch.real(field_vec)
                field_vec_imag = torch.imag(field_vec)

                solved = torch.stack([field_vec_real, field_vec_imag], dim=0)
                patched_solved[idx] = solved

            history_fields[sample_id, k+1, :, :, :, :] = patched_solved

            # reconstruct the whole field
            intermediate_result = reconstruct(patched_solved, args)
            # print("shapes: ",intermediate_result.shape, DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].shape )
            # Relative L1 error against the ground-truth field on the covered region.
            diff = intermediate_result.contiguous().view(1,-1) - DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].contiguous().view(1,-1)
            loss = torch.mean(torch.abs(diff)) / \
                   torch.mean(torch.abs(DDM_Hy[:, :, args.starting_x:args.starting_x+size_x, args.starting_y:args.starting_y+size_y].contiguous().view(1,-1)))
            print(f"iter {k}, loss {loss}")
            device_losses[sample_id, k] = loss

    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_field_history.npy", history_fields)
    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_eps.npy", x_batch_trains)
    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_Hy_gt.npy", y_batch_trains)

    plt.figure()
    plt.plot(list(range(args.DDM_iters)), device_losses.T)
    plt.legend([f"device_{name}" for name in range(args.num_device)])
    plt.xlabel("iteration")
    plt.yscale('log')
    plt.ylabel("Relative Error")
    plt.savefig(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_device_loss.png", dpi=300)
    np.save(args.output_folder+f"/sx_{args.starting_x}_sy_{args.starting_y}_device_losses.npy", device_losses)
Beispiel #24
0
 def imag(self, ):
     """Return the imaginary part of the vector

     Overwrites the stored array in place with its imaginary part and
     returns ``self``.
     """
     data = self.getNdArray()
     # torch.imag returns a view sharing `data`'s storage; clone it so the
     # in-place copy never depends on the kernel reading each element
     # before overwriting the memory it aliases.
     data[:] = torch.imag(data).clone()
     return self
Beispiel #25
0
        segment_noisy_mag = segment[2]
        segment_noisy_pha = segment[3]
        center_noisy_mag = segment_noisy_mag[:, 2]
        center_noisy_pha = segment_noisy_pha[:, 2]
        centers_noisy_mag.append(center_noisy_mag)
        centers_noisy_pha.append(center_noisy_pha)

    clean_mag = np.stack(centers_clean_mag).T
    clean_pha = np.stack(centers_clean_pha).T

    noisy_mag = np.stack(centers_noisy_mag).T
    noisy_pha = np.stack(centers_noisy_pha).T

    clean_mag = np.exp(clean_mag)
    clean_stft = clean_mag * np.exp(1j * clean_pha)
    clean_real, clean_imag = torch.real(torch.from_numpy(clean_stft)), torch.imag(torch.from_numpy(clean_stft))
    clean_stft = torch.stack([clean_real, clean_imag], dim=-1)
    # clean = np.stack([clean_mag, np.expand_dims(clean_pha, axis=0)], -1)
    clean_wav = torch.istft(clean_stft, 400, 160)
    wavwrite('../save_wav/test_clean.wav', 16000, clean_wav.numpy())

    noisy_mag = np.exp(noisy_mag)
    noisy_stft = noisy_mag * np.exp(1j * noisy_pha)
    noisy_real, noisy_imag = torch.real(torch.from_numpy(noisy_stft)), torch.imag(torch.from_numpy(noisy_stft))
    noisy_stft = torch.stack([noisy_real, noisy_imag], dim=-1)
    noisy_wav = torch.istft(noisy_stft, 400, 160)
    wavwrite('../save_wav/test_noisy.wav', 16000, noisy_wav.numpy())

    # timestamps = len(id_files)
    #
    # plt.figure()
Beispiel #26
0
def imag(a: Numeric):
    """Return the imaginary part of ``a`` (delegates to `torch.imag`)."""
    imaginary_part = torch.imag(a)
    return imaginary_part
Beispiel #27
0
 def imag(self, x):
     """Imaginary part of ``x``.

     For non-complex input, returns zeros of the same shape built with
     ``self.zeros`` and a float ``DType`` of the input's precision.
     """
     dtype = self.dtype(x)
     if dtype.kind != complex:
         return self.zeros(x.shape, DType(float, dtype.precision))
     return torch.imag(x)
Beispiel #28
0
def hartley(im):
    """Hartley transform of ``im`` derived from its Fourier transform.

    Returns ``Re(F) - Im(F)`` with ``F = fft2c(im)``. ``fft2c`` is defined
    elsewhere; it is presumably a centred 2-D FFT.
    """
    spectrum = fft2c(im)
    return torch.real(spectrum) - torch.imag(spectrum)
Beispiel #29
0
def mother_wavelet(x, F, L, S):
    """Fourier-domain wavelet features of ``x``.

    Evaluates ``2**(S/2) * (exp(2a) - exp(a)) / a`` with
    ``a = 2j*pi*(2**S * x @ F.T - L)``, normalises by the largest modulus,
    and returns the real and imaginary parts concatenated along the last
    axis.

    Where ``a == 0`` the expression is replaced by its limit ``2**(S/2)``.
    Previously this was 0/0 = NaN, which also poisoned the normalisation.
    """
    arg = 1j * 2 * np.pi * ((2**S) * torch.matmul(x, F.T) - L)
    is_zero = arg == 0
    # Divide by 1 at zero arguments so both the value and its gradient stay
    # finite; those entries are overwritten with the analytic limit below.
    safe_arg = torch.where(is_zero, torch.ones_like(arg), arg)
    scale = 2**(S / 2)
    wl = scale * (torch.exp(2 * safe_arg) - torch.exp(safe_arg)) / safe_arg
    wl = torch.where(is_zero, scale * torch.ones_like(wl), wl)
    wl = wl / torch.max(torch.abs(wl))
    return torch.cat([torch.real(wl), torch.imag(wl)], dim=-1)
Beispiel #30
0
 def imag(self):
     """Replace the stored array with its imaginary part, in place.

     Returns ``self`` so calls can be chained.
     """
     data = self.getNdArray()
     # torch.imag returns a view sharing `data`'s storage; clone it so the
     # in-place copy never depends on the kernel reading each element
     # before overwriting the memory it aliases.
     data[:] = torch.imag(data).clone()
     return self