Example No. 1
def robin_transform_accurate(bc, d_v, d_w, args, wl=1050e-9, dL=6.25e-9):
    # bc: the boundary field to be transformed
    # d_v: the derivative of the fields along v
    # d_w: the derivative of the fields along w

    # first try the 0th order
    print("means: ", torch.mean(bc), torch.mean(d_v), torch.mean(d_w))
    d_v_complex = d_v.squeeze()[0] + 1j * d_v.squeeze()[1]
    d_w_complex = d_w.squeeze()[0] + 1j * d_w.squeeze()[1]
    d_v_fft = torch.fft.fft(d_v_complex)
    d_w_fft = torch.fft.fft(d_w_complex)
    size = d_v_fft.shape[0]
    omega = 2 * np.pi / (wl / dL)
    freq_indices = list(range(0, size // 2)) + list(range(-(size // 2), 0))
    # np.complex was removed from NumPy; the builtin complex works here
    mod_fre_v = [
        np.sqrt(complex((2 * np.pi * k / size)**2 - omega**2))
        for k in freq_indices
    ]
    mod_fre_w = [
        -np.sqrt(complex((2 * np.pi * k / size)**2 - omega**2))
        for k in freq_indices
    ]
    d_v_modulated_fft = d_v_fft * torch.tensor(mod_fre_v)
    d_v_modulated_ifft = torch.fft.ifft(d_v_modulated_fft)
    d_w_modulated_fft = d_w_fft * torch.tensor(mod_fre_w)
    d_w_modulated_ifft = torch.fft.ifft(d_w_modulated_fft)

    d_v_modulated_ifft_RI = torch.stack(
        [torch.real(d_v_modulated_ifft),
         torch.imag(d_v_modulated_ifft)]).reshape(bc.shape)
    d_w_modulated_ifft_RI = torch.stack(
        [torch.real(d_w_modulated_ifft),
         torch.imag(d_w_modulated_ifft)]).reshape(bc.shape)

    return bc + (d_w_modulated_ifft_RI - d_v_modulated_ifft_RI)
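A shape sketch for calling the transform with hypothetical sizes (note that the args parameter is never used in the body):

import numpy as np
import torch

bc = torch.randn(2, 128)   # real/imag channels stacked along dim 0
d_v = torch.randn(2, 128)
d_w = torch.randn(2, 128)
out = robin_transform_accurate(bc, d_v, d_w, args=None)
print(out.shape)  # torch.Size([2, 128])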
Example No. 2
def get_group_delay(
    raw_data: torch.Tensor,
    sampling_rate_in_hz: int,
    window_length_in_s: float,
    window_shift_in_s: float,
    num_fft_points: int,
    window_type: str,
):
    X_stft_transform = _get_stft(raw_data,
                                 sampling_rate_in_hz,
                                 window_length_in_s,
                                 window_shift_in_s,
                                 num_fft_points,
                                 window_type=window_type)
    Y_stft_transform = _get_stft(
        raw_data,
        sampling_rate_in_hz,
        window_length_in_s,
        window_shift_in_s,
        num_fft_points,
        window_type=window_type,
        data_transformation="group_delay",
    )
    X_stft_transform_real = torch.real(X_stft_transform)
    X_stft_transform_imag = torch.imag(X_stft_transform)
    Y_stft_transform_real = torch.real(Y_stft_transform)
    Y_stft_transform_imag = torch.imag(Y_stft_transform)
    numerator = torch.multiply(
        X_stft_transform_real, Y_stft_transform_real) + torch.multiply(
            X_stft_transform_imag, Y_stft_transform_imag)
    denominator = torch.square(torch.abs(X_stft_transform))
    group_delay = torch.divide(numerator, denominator + 1e-10)
    assert not torch.isnan(
        group_delay).any(), "There are NaN values in group delay"
    return torch.transpose(group_delay, 0, 1)
Example No. 3
    def forward(self, y):
        with torch.no_grad():
            self.mean_eps = torch.mean(self.eps)
            #print('epsilon is {}'.format(self.eps))
        x = self.A.adjoint(y)
        z = self.A(x)
        z_old = z
        u = z.new_zeros(z.shape)

        x.requires_grad = False
        z.requires_grad = False
        z_old.requires_grad = False
        u.requires_grad = False

        self.num_cg = np.zeros((
            self.hparams.num_unrolls,
            self.hparams.num_admm,
        ))

        for i in range(self.hparams.num_unrolls):
            r = self.denoiser(x)

            for j in range(self.hparams.num_admm):

                rhs = self.l2lam * self.A.adjoint(z - u) + r
                fun = lambda xx: self.l2lam * self.A.normal(xx) + xx
                cg_op = ConjGrad(rhs,
                                 fun,
                                 max_iter=self.hparams.cg_max_iter,
                                 eps=self.hparams.cg_eps,
                                 verbose=False)
                x = cg_op.forward(x)
                n_cg = cg_op.num_cg
                self.num_cg[i, j] = n_cg

                Ax_plus_u = self.A(x) + u
                z_old = z
                z = y + opt.l2ball_proj_batch(Ax_plus_u - y, self.eps)
                u = Ax_plus_u - z

                # check ADMM convergence
                with torch.no_grad():
                    Ax = self.A(x)
                    tmp = Ax - z
                    tmp = tmp.contiguous()
                    r_norm = torch.real(opt.zdot_single_batch(tmp)).sqrt()

                    tmp = self.l2lam * self.A.adjoint(z - z_old)
                    tmp = tmp.contiguous()
                    s_norm = torch.real(opt.zdot_single_batch(tmp)).sqrt()

                    if (r_norm + s_norm).max() < 1E-2:
                        if self.debug_level > 0:
                            # 'a' was undefined here; report the unroll index
                            tqdm.tqdm.write('stopping early, i={}'.format(i))
                        break
                    tmp = y - Ax
                    self.mean_residual_norm = torch.mean(
                        torch.sqrt(torch.real(opt.zdot_single_batch(tmp))))
        return x
Example No. 4
def heightmap_initializer(focal_length,
                          resolution=1248,
                          pixel_pitch=6.4e-6,
                          refractive_idc=1.43,
                          wavelength=530e-9,
                          init_lens='fresnel'):
    """
    Initialize heightmap before training
    :param focal_length: float - distance between phase mask and sensor
    :param resolution: int - size of phase mask
    :param pixel_pitch: float - pixel size of phase mask
    :param refractive_idc: float - refractive index of phase mask
    :param wavelength: float - wavelength of light
    :param init_lens: str - type of lens to initialize
    :return: height map
    """
    if init_lens == 'fresnel' or init_lens == 'plano':
        convex_radius = (refractive_idc -
                         1.) * focal_length  # based on lens maker formula

        N = resolution
        M = resolution
        [x, y] = np.mgrid[-(N // 2):(N + 1) // 2,
                          -(M // 2):(M + 1) // 2].astype(np.float64)

        x = x * pixel_pitch
        y = y * pixel_pitch

        # get lens thickness by paraxial approximations
        heightmap = -(x**2 + y**2) / 2. * (1. / convex_radius)
        if init_lens == 'fresnel':
            phases = utils.heightmap_to_phase(heightmap, wavelength,
                                              refractive_idc)
            fresnel = simple_to_fresnel_lens(phases)
            heightmap = utils.phase_to_heightmap(fresnel, wavelength,
                                                 refractive_idc)

    elif init_lens == 'flat':
        heightmap = torch.ones((resolution, resolution)) * 0.0001
    else:
        heightmap = torch.rand((resolution, resolution)) * pixel_pitch
        gauss_filter = fspecial_gauss(10, 5)

        # heightmap (and, assuming fspecial_gauss returns a real tensor,
        # gauss_filter) is real-valued; torch.imag raises on real tensors,
        # so use an explicit zero imaginary part instead
        heightmap = utils.stack_complex(heightmap,
                                        torch.zeros_like(heightmap))
        gauss_filter = utils.stack_complex(gauss_filter,
                                           torch.zeros_like(gauss_filter))
        heightmap = utils.conv_fft(heightmap, gauss_filter)
        heightmap = heightmap[:, :, 0]

    return torch.Tensor(heightmap)
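A minimal call exercising the self-contained 'flat' branch (the 'fresnel' and random branches also require the utils and fspecial_gauss helpers assumed by this snippet); the focal length value here is hypothetical:

import torch

flat_map = heightmap_initializer(focal_length=50e-3, resolution=64,
                                 init_lens='flat')
print(flat_map.shape)  # torch.Size([64, 64]), constant 1e-4 heightmap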
Example No. 5
def foa_intensity_vectors(complex_specs: torch.Tensor) -> torch.Tensor:
    if not torch.is_complex(complex_specs):
        complex_specs = torch.view_as_complex(complex_specs)

    # complex_specs: [chan, freq, time]
    IVx = torch.real(torch.conj(complex_specs[0]) * complex_specs[3])
    IVy = torch.real(torch.conj(complex_specs[0]) * complex_specs[1])
    IVz = torch.real(torch.conj(complex_specs[0]) * complex_specs[2])

    norm = torch.sqrt(IVx**2 + IVy**2 + IVz**2)
    IVx = IVx / norm
    IVy = IVy / norm
    IVz = IVz / norm

    # apply mel matrix without db ...
    return torch.stack([IVx, IVy, IVz], dim=0)
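A quick sanity check with random data, assuming the FOA channel ordering [W, Y, Z, X] implied by the indexing above:

import torch

spec = torch.randn(4, 64, 100, dtype=torch.complex64)  # [chan, freq, time]
iv = foa_intensity_vectors(spec)
print(iv.shape)  # torch.Size([3, 64, 100]); unit-norm intensity vectors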
Example No. 6
def complex_to_channels(image, requires_grad=False):
    """Convert data from complex to channels (real/imag doubled along the
    last dimension)."""
    image_out = torch.stack([torch.real(image), torch.imag(image)], dim=-1)
    # torch has no torch.shape(); build the target shape from image.shape
    shape_out = image.shape[:-1] + (image.shape[-1] * 2,)
    image_out = torch.reshape(image_out, shape_out)
    return image_out
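For example, a complex tensor of shape [4, 3] becomes a real tensor of shape [4, 6], with real and imaginary parts interleaved along the last dimension:

import torch

image = torch.randn(4, 3, dtype=torch.complex64)
out = complex_to_channels(image)
print(out.shape, out.dtype)  # torch.Size([4, 6]) torch.float32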
Example No. 7
def getLamdaGaplist(lambdas: torch.Tensor):
    """
    Calculate the gaps between lambda values.
    """
    if torch.is_complex(lambdas):
        lambdas = torch.real(lambdas)
    return lambdas[1:] - lambdas[:-1]
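For instance, the consecutive gaps of a sorted eigenvalue tensor:

import torch

lambdas = torch.tensor([0.1, 0.5, 2.0, 2.1])
print(getLamdaGaplist(lambdas))  # tensor([0.4000, 1.5000, 0.1000])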
Example No. 8
def sinc_impulse_response(cutoff_frequency, window_size=512, sample_rate=None):
    """Get a sinc impulse response for a set of low-pass cutoff frequencies.

    Args:
        cutoff_frequency: Frequency cutoff for low-pass sinc filter. If the
            sample_rate is given, cutoff_frequency is in Hertz. If sample_rate
            is None, cutoff_frequency is normalized ratio (frequency/nyquist)
            in the range [0, 1.0]. Shape [batch_size, n_time, 1].
        window_size: Size of the Hamming window to apply to the impulse.
        sample_rate: Optionally provide the sample rate.

    Returns:
        impulse_response: A series of impulse responses. Shape
            [batch_size, n_time, (window_size // 2) * 2 + 1].
    """
    if sample_rate is not None:
        cutoff_frequency *= 2 / sample_rate
    half_size = window_size // 2
    full_size = half_size * 2 + 1
    idx = th.arange(-half_size, half_size + 1, dtype=th.float)[None, None, :]

    impulse_response = sinc(cutoff_frequency * idx)
    window = th.hamming_window(full_size).expand_as(impulse_response)
    impulse_response = window * th.real(impulse_response)
    return impulse_response / impulse_response.sum(-1, keepdim=True)
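A usage sketch, assuming th is torch and that the sinc referenced above is an elementwise cardinal sine such as torch.sinc:

import torch as th

sinc = th.sinc  # assumption: the snippet's sinc behaves like torch.sinc
cutoff = th.full((1, 1, 1), 0.25)  # normalized cutoff (frequency/nyquist)
ir = sinc_impulse_response(cutoff, window_size=64)
print(ir.shape)  # torch.Size([1, 1, 65]); each response sums to 1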
Example No. 9
def test_real(dtype, input_cur):
    backend = pytorch_backend.PyTorchBackend()
    cur = backend.convert_to_tensor(input_cur)
    actual = backend.real(cur)
    expected = torch.real(cur)
    np.testing.assert_allclose(actual, expected)
    cur = backend.convert_to_tensor(np.array([1, 2]))
    np.testing.assert_allclose(backend.real(cur), np.array([1, 2]))
Example No. 10
def reshape_complex_vals_to_adj_channels(arr):
    ''' reshape complex tensor dim [nc,x,y] --> real tensor dim [2*nc,x,y]
        s.t. concat([nc,x,y] real, [nc,x,y] imag), i.e. not alternating real/imag 
        inverse operation of reshape_adj_channels_to_complex_vals() '''

    assert is_complex(arr)  # input should be complex-valued

    return torch.cat([torch.real(arr), torch.imag(arr)])
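Usage sketch, assuming the is_complex helper wraps torch.is_complex:

import torch

arr = torch.randn(8, 32, 32, dtype=torch.complex64)  # [nc, x, y]
out = reshape_complex_vals_to_adj_channels(arr)
print(out.shape)  # torch.Size([16, 32, 32]); first 8 real, last 8 imag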
Example No. 11
def _complex_native_complex(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Use torch built-ins for computation with complex numbers."""
    h, r, t = [view_complex(x=x) for x in (h, r, t)]
    return torch.real(tensor_product(h, r, torch.conj(t)).sum(dim=-1))
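This is the ComplEx interaction score Re(<h, r, conj(t)>); view_complex and tensor_product come from the surrounding codebase. A self-contained equivalent with plain torch, assuming the last dimension stores interleaved real/imag pairs, might look like:

import torch

def complex_score(h, r, t):
    # reinterpret [..., 2*d] float tensors as [..., d] complex tensors
    h, r, t = (torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
               for x in (h, r, t))
    return torch.real((h * r * torch.conj(t)).sum(dim=-1))

h, r, t = (torch.randn(5, 8) for _ in range(3))  # batch of 5, 2*4 floats
print(complex_score(h, r, t).shape)  # torch.Size([5])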
Example No. 12
    def construct_pyramid(self, image, n_levels, n_orientations):
        if image.size() != self.image_size or \
            n_levels != self.n_levels or \
            n_orientations != self.n_orientations:
            # Need to recalculate the filters.
            self.image_size = image.size()
            self.n_levels = n_levels
            self.n_orientations = n_orientations
            self.calculate_filters()
        ft = fftshift(fft2(image))

        curr_level = {}
        pyramid = []
        h0 = self.H0_FILT * ft
        curr_level['h'] = torch.real(ifft2(ifftshift(h0)))

        l0 = self.L0_FILT * ft
        curr_level['l'] = torch.real(ifft2(ifftshift(l0)))

        # apply bandpass filter(B) and downsample iteratively. save pyramid
        _last = l0
        for i in range(self.n_levels):
            curr_level['b'] = []
            for j in range(len(self.B_FILT[i])):
                lb = _last * self.B_FILT[i][j]
                curr_level['b'].append(torch.real(ifft2(ifftshift(lb))))

            # apply the lowpass filter (L) to the image in the Fourier
            # domain before downsampling
            l1 = _last * self.L_FILT[i]

            ## Downsampling
            down_size = [l1.size(-2) // 4, l1.size(-1) // 4]

            # downsample by extracting the central part of the DFT
            down_image = l1[:, :, down_size[0]:3 * down_size[0],
                            down_size[1]:3 * down_size[1]] / 4
            _last = down_image.clone()
            pyramid.append(curr_level)
            curr_level = {}

        # lowpass residual
        curr_level['l'] = torch.real(ifft2(ifftshift(_last)))
        pyramid.append(curr_level)
        return pyramid
Example No. 13
def real(input_):
    """Wrapper of `torch.real`.

    Parameters
    ----------
    input_ : DTensor
        Input dense tensor.
    """
    return torch.real(input_._data)
Example No. 14
def get_principal_curvatures(shape_operator):
    """
    Performs an eigen decomposition of the shape operator
    Parameters:
        shape_operator - [M, M] dimensional array
    Returns:
        principal_curvatures - sorted eigenvalues of shape_operator
        principal_directions - sorted eigenvectors of shape_operator
            where principal_directions[:, i] is the vector corresponding to principal_curvatures[i]
    """
    dtype = shape_operator.dtype
    principal_curvatures, principal_directions = torch.linalg.eig(
        shape_operator)
    principal_curvatures = torch.real(principal_curvatures).type(dtype)
    principal_directions = torch.real(principal_directions).type(dtype)
    sort_indices = torch.argsort(principal_curvatures, descending=True)
    return principal_curvatures[
        sort_indices], principal_directions[:, sort_indices]
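For a symmetric shape operator the eigenvalues are already real, so torch.real only drops a zero imaginary part:

import torch

S = torch.tensor([[2.0, 1.0], [1.0, 2.0]])  # symmetric, eigenvalues 3 and 1
k, v = get_principal_curvatures(S)
print(k)  # tensor([3., 1.]), sorted in descending order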
Example No. 15
    def forward(self, x):

        # DFT along hidden dimension followed by DFT along sequence dimension. Only keep real part of the result
        x_fft = torch.real(torch.fft.fft(torch.fft.fft(x).T).T)
        x = self.LayerNorm1(x + x_fft)

        x_ff = self.feedforward(x)
        x = self.LayerNorm2(x + x_ff)
        return x
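For a 2-D input of shape [seq_len, hidden], applying the FFT along the hidden dimension and then along the sequence dimension is the same separable mixing as a single 2-D FFT, so the line above is equivalent to torch.real(torch.fft.fft2(x)):

import torch

x = torch.randn(16, 64)  # [seq_len, hidden]
a = torch.real(torch.fft.fft(torch.fft.fft(x).T).T)
b = torch.real(torch.fft.fft2(x))
print(torch.allclose(a, b, atol=1e-4))  # True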
Example No. 16
    def _get_fft_basis(self):
        fourier_basis = torch.fft.rfft(torch.eye(self.filter_length))
        cutoff = 1 + self.filter_length // 2
        fourier_basis = torch.cat([
            torch.real(fourier_basis[:, :cutoff]),
            torch.imag(fourier_basis[:, :cutoff])
        ], dim=1)
        return fourier_basis.float()
Example No. 17
    def forward(self, inputs):
        suscp = inputs[0]
        kernel = inputs[1]

        ks = fft.fftn(suscp, dim=[-3, -2, -1])

        ks = ks * kernel
        fm = torch.real(fft.ifftn(ks, dim=[-3, -2, -1]))

        return fm
Example No. 18
def steepest_ascent_direction(grad, norm_type, eps_tot):
    shape = grad.shape
    if norm_type == 'dftinf':
        dftxgrad = torch.fft.fftn(grad, dim=(-2, -1), norm='ortho')
        dftz = dftxgrad.reshape(1, -1)
        dftz = torch.cat((torch.real(dftz), torch.imag(dftz)), dim=0)

        def l2_normalize(delta, eps):
            avoid_zero_div = 1e-15
            norm2 = torch.sum(delta**2, dim=0, keepdim=True)
            norm = torch.sqrt(torch.clamp(norm2, min=avoid_zero_div))
            delta = delta * eps / norm
            return delta

        dftz = l2_normalize(dftz, eps_tot)
        dftz = (dftz[0, :] + 1j * dftz[1, :]).reshape(shape)
        delta = torch.fft.ifftn(dftz, dim=(-2, -1), norm='ortho')
        adv_step = torch.real(delta)
    else:
        # avoid returning an unbound name for unsupported norm types
        raise ValueError('unsupported norm_type: {}'.format(norm_type))
    return adv_step
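Usage sketch: for 'dftinf', every DFT coefficient of the gradient is rescaled to magnitude eps_tot (the steepest-ascent direction under the DFT-infinity norm), and the result is mapped back to a real tensor:

import torch

grad = torch.randn(1, 3, 8, 8)
step = steepest_ascent_direction(grad, 'dftinf', eps_tot=0.1)
print(step.shape)  # torch.Size([1, 3, 8, 8]), real-valued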
Example No. 19
def amplitude(r):
    """
    Calculate the amplitude of a complex tensor.

    If the tensor is not complex then calculate square.
    """
    if torch.is_complex(r):
        return torch.real(r * torch.conj(r))
    else:
        return r * r
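For example:

import torch

z = torch.tensor([3.0 + 4.0j])
print(amplitude(z))                    # tensor([25.]), i.e. |z|**2
print(amplitude(torch.tensor([5.0])))  # tensor([25.])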
Example No. 20
def get_inv_spatial_weight(psf_grid):
    # N,3,W,H
    # torch.fft.rfft(x, 2) would truncate the last dim to length 2;
    # the intended 2-D real FFT is torch.fft.rfft2
    F_psf_grid = torch.fft.rfft2(psf_grid)
    F_psf_grid = torch.stack([torch.real(F_psf_grid),
                              torch.imag(F_psf_grid)],
                             dim=-1)
    F_psf_grid_norm = F_psf_grid[..., 0]**2 + F_psf_grid[..., 1]**2
    F_psf_grid_norm = torch.mean(F_psf_grid_norm, dim=(2, 3))
    #F_psf_grid_norm = torch.mean(F_psf_grid_norm, dim=2)
    return F_psf_grid_norm
Example No. 21
def norm_projection(delta, norm_type, eps=1.):
    """Projects to a norm-ball centered at 0.

  Args:
    delta: An array of size dim x num containing vectors to be projected.
    norm_type: A string denoting the type of the norm-ball.
    eps: A float denoting the radius of the norm-ball.

  Returns:
    An array of size dim x num, the projection of delta to the norm-ball.
  """
    shape = delta.shape
    if norm_type == 'l2':
        # Euclidean projection: divide all elements by a constant factor
        avoid_zero_div = 1e-12
        norm2 = np.sum(delta**2, axis=0, keepdims=True)
        norm = np.sqrt(np.maximum(avoid_zero_div, norm2))
        # only decrease the norm, never increase
        delta = delta * np.clip(eps / norm, a_min=None, a_max=1)
    elif norm_type == 'dftinf':
        # transform to DFT, project using known projections, then transform back
        # n x d x h x w
        dftxdelta = torch.fft.fftn(delta, dim=(-2, -1), norm='ortho')
        # L2 projection of each coordinate to the L2-ball in the complex plane
        dftz = dftxdelta.reshape(1, -1)
        dftz = torch.cat((torch.real(dftz), torch.imag(dftz)), dim=0)

        def l2_proj(delta, eps):
            avoid_zero_div = 1e-15
            norm2 = torch.sum(delta**2, dim=0, keepdim=True)
            norm = torch.sqrt(torch.clamp(norm2, min=avoid_zero_div))
            # only decrease the norm, never increase
            delta = delta * torch.clamp(eps / norm, max=1)
            return delta

        dftz = l2_proj(dftz, eps)
        dftz = (dftz[0, :] + 1j * dftz[1, :]).reshape(delta.shape)
        # project back from DFT
        delta = torch.fft.ifftn(dftz, dim=(-2, -1), norm='ortho')
        # Projected vector can have an imaginary part
        delta = torch.real(delta)
    return delta.reshape(shape)
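Usage sketch for the torch-based 'dftinf' branch (the 'l2' branch as written operates on NumPy arrays):

import torch

delta = torch.randn(1, 3, 8, 8)
proj = norm_projection(delta, 'dftinf', eps=0.05)
print(proj.shape)  # torch.Size([1, 3, 8, 8]); every DFT coefficient now has magnitude at most 0.05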
Example No. 22
def rfft(t):
    # Real-to-complex Discrete Fourier Transform
    ver = torch.__version__
    major, minor, ver = ver.split('.')
    ver_int = int(major) * 100 + int(minor)
    if ver_int >= 108:
        ft = torch.fft.fft2(t)
        ft = torch.stack([torch.real(ft), torch.imag(ft)], dim=-1)
    else:
        ft = torch.rfft(t, 2, onesided=False)
    return ft
Example No. 23
    def __init__(self, D, m, b, wG, device, lambda_TV, P=1, alpha=0.5, rho=10):
        self.D = D
        self.m = m
        self.b = b
        self.wG = wG
        self.device = device
        self.lambda_TV = lambda_TV
        self.P = P
        self.alpha = alpha
        self.rho = rho
        self.Dconv = lambda x: torch.real(
            fft.ifftn(self.D * fft.fftn(x, dim=[0, 1, 2])))
Example No. 24
def func(tensor):
    n_freq = tensor.size(-2)
    rate = 0.5
    hop_length = 256
    phase_advance = torch.linspace(
        0,
        3.14 * hop_length,
        n_freq,
        dtype=torch.real(tensor).dtype,
        device=tensor.device,
    )[..., None]
    return F.phase_vocoder(tensor, rate, phase_advance)
Example No. 25
    def forward(self, x, FB, FBC, F2B, FBFy, alpha, sf):

        FR = FBFy + torch.fft.fftn(alpha * x, dim=(-2, -1))
        x1 = FB.mul(FR)
        FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
        invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
        invWBR = FBR.div(invW + alpha)
        FCBinvWBR = FBC * invWBR.repeat(1, 1, sf, sf)
        FX = (FR - FCBinvWBR) / alpha
        Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1)))

        return Xest
Example No. 26
def irfft(t):
    # Complex-to-real Inverse Discrete Fourier Transform
    ver = torch.__version__
    major, minor, ver = ver.split('.')
    ver_int = int(major) * 100 + int(minor)
    if ver_int >= 108:
        t = torch.complex(t[..., 0], t[..., 1])
        ft = torch.fft.ifft2(t)
        ft = torch.real(ft)
    else:
        ft = torch.irfft(t, 2, onesided=False)
    return ft
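The two helpers mirror each other on both sides of the 1.8 API change, so a roundtrip through rfft from Example No. 22 recovers the input:

import torch

t = torch.randn(4, 8, 8)
print(torch.allclose(irfft(rfft(t)), t, atol=1e-5))  # True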
Example No. 27
def apply_window_to_impulse_response(impulse_response: th.Tensor,
                                     window_size: int = 0,
                                     causal: bool = False) -> th.Tensor:
    """Apply a window to an impulse response and put in causal form.

    Args:
        impulse_response: A series of impulse responses frames to window, of
            shape [batch_size, n_frames, ir_size].
        window_size: Size of the window to apply in the time domain. If
            window_size is less than 1, it defaults to the impulse_response
            size.
        causal: Impulse response input is in causal form (peak in the middle).

    Returns:
        impulse_response: Windowed impulse response in causal form, with last
            dimension cropped to window_size if window_size is greater than 0
            and less than ir_size.
    """
    impulse_response = f32(impulse_response)
    if causal:
        impulse_response = fftshift(impulse_response)

    ir_size = impulse_response.shape[-1]
    if window_size <= 0 or window_size > ir_size:
        window_size = ir_size
    window = th.hann_window(window_size)

    padding = ir_size - window_size
    if padding > 0:
        half_idx = (window_size + 1) // 2
        window = th.cat(
            [window[half_idx:],
             th.zeros(padding), window[:half_idx]], dim=-1)
    else:
        window = fftshift(window)

    window = window.expand_as(impulse_response)
    impulse_response = window * th.real(impulse_response)

    if padding > 0:
        first_half_start = (ir_size - (half_idx - 1)) + 1
        second_half_end = half_idx + 1
        impulse_response = th.cat(
            [
                impulse_response[..., first_half_start:],
                impulse_response[..., :second_half_end],
            ],
            dim=-1,
        )
    else:
        impulse_response = fftshift(impulse_response)
    return impulse_response
Example No. 28
def reg2_proj(usph, imsizer, imrizec, niter=100, alpha=0.05):
    # A smoothness-based projection: regularization method 2 from
    # "Separate Magnitude and Phase Regularization via Compressed Sensing", Feng Zhao et al., IEEE TMI, 2012
    device = usph.device

    ims = torch.zeros((imsizer, imrizec, niter), device=device)
    ims[:, :, 0] = usph.clone() + np.pi
    for ix in range(niter - 1):
        ims[:, :, ix + 1] = ims[:, :, ix] - 2 * alpha * torch.real(
            1j * torch.exp(-1j * ims[:, :, ix]) *
            _fdivg(_fgrad(torch.exp(1j * ims[:, :, ix]))))

    return ims[:, :, -1] - np.pi
Example No. 29
def FT(f_i, x):
    shift = 1
    # torch.fft(input, 1) was removed; it expected complex input stored as
    # [..., 2] real/imag pairs, which view_as_complex/view_as_real recover
    C_b = torch.view_as_real(torch.fft.fft(torch.view_as_complex(f_i)))
    N_2 = int(len(f_i) / 2)
    zer = torch.Tensor([0])
    im_shift = torch.Tensor([2 * np.pi * shift * torch.sum(x)])
    F_y = torch.tensor([
        torch.complex(C_b[b][0], C_b[b][1]) * torch.exp(
            torch.complex(
                zer, torch.Tensor([2 * np.pi * b * (torch.sum(x))])))
        for b in range(-N_2, N_2)
    ])
    f_star = (torch.exp(torch.complex(zer, im_shift)) * torch.sum(F_y))
    return torch.tensor([torch.real(f_star), torch.imag(f_star)])
Example No. 30
def func(tensor):
    is_complex = tensor.is_complex()

    n_freq = tensor.size(-2 if is_complex else -3)
    rate = 0.5
    hop_length = 256
    phase_advance = torch.linspace(
        0,
        3.14 * hop_length,
        n_freq,
        dtype=(torch.real(tensor) if is_complex else tensor).dtype,
        device=tensor.device,
    )[..., None]
    return F.phase_vocoder(tensor, rate, phase_advance)
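Usage sketch, assuming F is torchaudio.functional (whose phase_vocoder accepts complex spectrograms):

import torch
import torchaudio.functional as F

spec = torch.randn(2, 1025, 400, dtype=torch.complex64)  # [chan, freq, time]
stretched = func(spec)
print(stretched.shape)  # time axis roughly doubled, since rate = 0.5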