Example no. 1
def fn(a):
    b = a.clone()
    b1 = torch.view_as_complex(b)
    b2 = b1.reshape(b1.numel())
    return b2
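A minimal usage sketch for the snippet above (an illustration, not part of the original example): the input's last dimension must have size 2 and hold (real, imaginary) pairs for torch.view_as_complex to accept the cloned tensor.

import torch

a = torch.randn(4, 3, 2)      # real tensor whose last dim holds (re, im) pairs
b2 = fn(a)                    # complex tensor flattened to shape (12,)
print(b2.dtype, b2.shape)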
Example no. 2
def table_interp_adjoint(
    data: Tensor,
    omega: Tensor,
    tables: List[Tensor],
    n_shift: Tensor,
    numpoints: Tensor,
    table_oversamp: Tensor,
    offsets: Tensor,
    grid_size: Tensor,
) -> Tensor:
    """Table interpolation adjoint backend.

    This interpolates from an off-grid set of data at coordinates given by
    ``omega`` to on-grid locations.

    Args:
        data: Off-grid data to interpolate from.
        omega: Fourier coordinates to interpolate to (in radians/voxel, -pi to
            pi).
        tables: List of tables for each image dimension.
        n_shift: Size of desired fftshift.
        numpoints: Number of neighbors in each dimension.
        table_oversamp: Size of table in each dimension.
        offsets: A list of offset values for interpolation.
        grid_size: Size of grid to use for interpolation.

    Returns:
        ``data`` interpolated to gridded locations.
    """
    dtype = data.dtype
    device = data.device
    int_type = torch.long

    # we fork processes for accumulation, so we need to do a bit of thread management
    # for OMP to make sure we don't oversubscribe (management not necessary for non-OMP)
    num_threads = torch.get_num_threads()
    factors = torch.arange(1, math.sqrt(num_threads)).flip(0)
    factors = factors[torch.remainder(torch.tensor(num_threads), factors) == 0]
    threads_per_fork = 1
    for factor in factors:
        if factor <= num_threads / (data.shape[0] * data.shape[1]):
            threads_per_fork = int(factor)
            break
    num_forks = num_threads // threads_per_fork

    # calculate output size
    output_prod = int(torch.prod(grid_size))
    output_size = [data.shape[0], data.shape[1]]
    for el in grid_size:
        output_size.append(int(el))

    # convert to normalized freq locs and sort
    tm = omega / (2 * np.pi / grid_size.to(omega).unsqueeze(-1))
    tm, omega, data = sort_data(tm, omega, data, grid_size)

    # compute interpolation centers
    centers = torch.floor(numpoints * table_oversamp / 2).to(dtype=int_type)

    # offset from k-space to first coef loc
    base_offset = 1 + torch.floor(tm - numpoints.unsqueeze(1) / 2.0).to(
        dtype=int_type)

    # initialize flattened image
    image = torch.zeros(
        size=(data.shape[0], data.shape[1], output_prod),
        dtype=dtype,
        device=device,
    )

    # phase for fftshift
    data = (data * imag_exp(
        torch.mv(torch.transpose(omega, 1, 0), n_shift),
        return_complex=True,
    ).conj())

    # necessary for index_add_
    # TODO: change when PyTorch supports complex numbers for index_add_, index_put_
    if not device == torch.device("cpu"):
        image = torch.view_as_real(image)

    # loop over offsets and take advantage of broadcasting
    for offset in offsets:
        coef, arr_ind = calc_coef_and_indices(
            tm=tm,
            base_offset=base_offset,
            offset_increments=offset,
            tables=tables,
            centers=centers,
            table_oversamp=table_oversamp,
            grid_size=grid_size,
            conjcoef=True,
        )

        # we have to fork this multiply ourselves
        tmp = coef * data

        if not device == torch.device("cpu"):
            tmp = torch.view_as_real(tmp)

        if USING_OMP and device == torch.device("cpu"):
            torch.set_num_threads(threads_per_fork)
        # this is a much faster way of doing index accumulation
        fork_and_accum(image, arr_ind, tmp, num_forks)
        if USING_OMP and device == torch.device("cpu"):
            torch.set_num_threads(num_threads)

    if not device == torch.device("cpu"):
        image = torch.view_as_complex(image)

    return image.view(output_size)
Example no. 3
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO."""

    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1, 1, 1, 1)

    if getattr(params, 'reg_window_square', False):
        target_sz = target_sz.prod().sqrt() * torch.ones(2)

    # Normalization factor
    reg_scale = 0.5 * target_sz

    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        wrg = torch.arange(-int((sz[0] - 1) / 2),
                           int(sz[0] / 2 + 1),
                           dtype=torch.float32).view(1, 1, -1, 1)
        wcg = torch.arange(-int((sz[1] - 1) / 2),
                           int(sz[1] / 2 + 1),
                           dtype=torch.float32).view(1, 1, 1, -1)
    else:
        wrg = torch.cat([
            torch.arange(0, int(sz[0] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, -1, 1)
        wcg = torch.cat([
            torch.arange(0, int(sz[1] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, 1, -1)

    # Construct regularization window
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
                 (torch.abs(wrg / reg_scale[0]) ** params.reg_window_power +
                  torch.abs(wcg / reg_scale[1]) ** params.reg_window_power) + params.reg_window_min

    # Compute DFT and enforce sparsity
    reg_window_dft = torch.view_as_real(
        torch_fft.rfftn(reg_window, dim=[-2, -1])) / sz.prod()
    reg_window_dft_abs = complex.abs(reg_window_dft)
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold *
                   reg_window_dft_abs.max(), :] = 0

    # Do the inverse transform to correct for the window minimum

    reg_window_sparse = torch_fft.irfftn(torch.view_as_complex(reg_window_dft),
                                         s=sz.long().tolist(),
                                         dim=[-2, -1])
    reg_window_dft[
        0, 0, 0, 0,
        0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))

    # Remove zeros
    max_inds, _ = reg_window_dft.nonzero(as_tuple=False).max(dim=0)
    mid_ind = int((reg_window_dft.shape[2] - 1) / 2)
    top = max_inds[-2].item() + 1
    bottom = 2 * mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        reg_window_dft = torch.cat(
            [reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)

    return reg_window_dft
Example no. 4
def propagation_ASM(u_in, feature_size, wavelength, z, linear_conv=True,
                    padtype='zero', return_H=False, precomped_H=None,
                    return_H_exp=False, precomped_H_exp=None,
                    dtype=torch.float32):
    """Propagates the input field using the angular spectrum method

    Inputs
    ------
    u_in: PyTorch Complex tensor (torch.cfloat) of size (num_images, 1, height, width) -- updated with PyTorch 1.7.0
    feature_size: (height, width) of individual holographic features in m
    wavelength: wavelength in m
    z: propagation distance
    linear_conv: if True, pad the input to obtain a linear convolution
    padtype: 'zero' to pad with zeros, 'median' to pad with median of u_in's
        amplitude
    return_H[_exp]: used for precomputing H or H_exp, ends the computation early
        and returns the desired variable
    precomped_H[_exp]: the precomputed value for H or H_exp
    dtype: torch dtype for computation at different precision

    Output
    ------
    complex tensor of size (num_images, 1, height, width)
    """

    if linear_conv:
        # preprocess with padding for linear conv.
        input_resolution = u_in.size()[-2:]
        conv_size = [i * 2 for i in input_resolution]
        if padtype == 'zero':
            padval = 0
        elif padtype == 'median':
            padval = torch.median(torch.pow((u_in**2).sum(-1), 0.5))
        u_in = utils.pad_image(u_in, conv_size, padval=padval, stacked_complex=False)

    if precomped_H is None and precomped_H_exp is None:
        # resolution of input field, should be: (num_images, num_channels, height, width, 2)
        field_resolution = u_in.size()

        # number of pixels
        num_y, num_x = field_resolution[2], field_resolution[3]

        # sampling interval size
        dy, dx = feature_size

        # size of the field
        y, x = (dy * float(num_y), dx * float(num_x))

        # frequency coordinates sampling
        fy = np.linspace(-1 / (2 * dy) + 0.5 / (2 * y), 1 / (2 * dy) - 0.5 / (2 * y), num_y)
        fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x), 1 / (2 * dx) - 0.5 / (2 * x), num_x)

        # momentum/reciprocal space
        FX, FY = np.meshgrid(fx, fy)

        # transfer function in numpy (omit distance)
        HH = 2 * math.pi * np.sqrt(1 / wavelength**2 - (FX**2 + FY**2))

        # create tensor & upload to device (GPU)
        H_exp = torch.tensor(HH, dtype=dtype).to(u_in.device)

        ###
        # here one may iterate over multiple distances, once H_exp is uploaded on GPU

        # reshape tensor and multiply
        H_exp = torch.reshape(H_exp, (1, 1, *H_exp.size()))

    # handle loading the precomputed H_exp value, or saving it for later runs
    elif precomped_H_exp is not None:
        H_exp = precomped_H_exp

    if precomped_H is None:
        # multiply by distance
        H_exp = torch.mul(H_exp, z)

        # band-limited ASM - Matsushima et al. (2009)
        fy_max = 1 / np.sqrt((2 * z * (1 / y))**2 + 1) / wavelength
        fx_max = 1 / np.sqrt((2 * z * (1 / x))**2 + 1) / wavelength
        H_filter = torch.tensor(((np.abs(FX) < fx_max) & (np.abs(FY) < fy_max)).astype(np.uint8), dtype=dtype)

        # get real/img components
        H_real, H_imag = utils.polar_to_rect(H_filter.to(u_in.device), H_exp)

        H = torch.stack((H_real, H_imag), 4)
        H = utils.ifftshift(H)
        H = torch.view_as_complex(H)
    else:
        H = precomped_H

    # return for use later as precomputed inputs
    if return_H_exp:
        return H_exp
    if return_H:
        return H

    # For those who cannot use PyTorch 1.7.0 and its complex tensor support:
    # # angular spectrum
    # U1 = torch.fft(utils.ifftshift(u_in), 2, True)
    #
    # # convolution of the system
    # U2 = utils.mul_complex(H, U1)
    #
    # # Fourier transform of the convolution to the observation plane
    # u_out = utils.fftshift(torch.ifft(U2, 2, True))

    U1 = torch.fft.fftn(utils.ifftshift(u_in), dim=(-2, -1), norm='ortho')

    U2 = H * U1

    u_out = utils.fftshift(torch.fft.ifftn(U2, dim=(-2, -1), norm='ortho'))

    if linear_conv:
        # return utils.crop_image(u_out, input_resolution) # using stacked version
        return utils.crop_image(u_out, input_resolution, pytorch=True, stacked_complex=False)  # using complex tensor
    else:
        return u_out
Example no. 5
def kb_table_nufft_adjoint(
    data: Tensor,
    scaling_coef: Tensor,
    im_size: Tensor,
    grid_size: Tensor,
    omega: Tensor,
    tables: List[Tensor],
    n_shift: Tensor,
    numpoints: Tensor,
    table_oversamp: Tensor,
    offsets: Tensor,
    norm: Optional[str] = None,
) -> Tensor:
    """Kaiser-Bessel NUFFT adjoint with table interpolation.

    See :py:class:`~torchkbnufft.KbNufftAdjoint` for an overall description of
    the adjoint NUFFT.

    Args:
        data: Scattered data to be iNUFFT'd to an image.
        scaling_coef: Image-domain coefficients to compensate for interpolation
            errors.
        im_size: Size of image with length being the number of dimensions.
        grid_size: Size of grid to use for interpolation, typically 1.25 to 2
            times ``im_size``.
        omega: k-space trajectory (in radians/voxel).
        tables: Interpolation tables (one table for each dimension).
        n_shift: Size for fftshift, usually ``im_size // 2``.
        numpoints: Number of neighbors to use for interpolation.
        table_oversamp: Table oversampling factor.
        offsets: A list of offsets, looping over all possible combinations of
            `numpoints`.
        norm: Whether to apply normalization with the FFT operation.
            Options are ``"ortho"`` or ``None``.

    Returns:
        ``data`` transformed to an image.
    """
    is_complex = True
    if not data.is_complex():
        if not data.shape[-1] == 2:
            raise ValueError("For real inputs, last dimension must be size 2.")

        is_complex = False
        data = torch.view_as_complex(data)

    image = ifft_and_scale(
        image=kb_table_interp_adjoint(
            data=data,
            omega=omega,
            tables=tables,
            n_shift=n_shift,
            numpoints=numpoints,
            table_oversamp=table_oversamp,
            offsets=offsets,
            grid_size=grid_size,
        ),
        scaling_coef=scaling_coef,
        im_size=im_size,
        grid_size=grid_size,
        norm=norm,
    )

    if is_complex is False:
        image = torch.view_as_real(image)

    return image
Example no. 6
def _multi_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
):

    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(grads)

    if has_sparse_grad is None:
        has_sparse_grad = any(grad.is_sparse for grad in grads)

    if has_sparse_grad:
        return _single_tensor_adagrad(
            params,
            grads,
            state_sums,
            state_steps,
            lr=lr,
            weight_decay=weight_decay,
            lr_decay=lr_decay,
            eps=eps,
            has_sparse_grad=has_sparse_grad,
            maximize=False,
        )

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in state_steps]

    grads = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in grads
    ]
    state_sums = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in state_sums
    ]
    torch._foreach_addcmul_(state_sums, grads, grads, value=1)
    std = torch._foreach_add(torch._foreach_sqrt(state_sums), eps)
    toAdd = torch._foreach_div(torch._foreach_mul(grads, minus_clr), std)
    toAdd = [
        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
        for i, x in enumerate(toAdd)
    ]
    torch._foreach_add_(params, toAdd)
    state_sums = [
        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
        for i, x in enumerate(state_sums)
    ]
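A standalone sketch of the trick used above (illustrative only, not part of torch.optim): complex tensors are viewed as real so that the ``_foreach_*`` kernels can update them in place, and the update is visible through the original complex tensor because ``view_as_real`` shares storage.

import torch

p = torch.randn(4, dtype=torch.complex64)
g = torch.randn(4, dtype=torch.complex64)
p_real = torch.view_as_real(p)            # shape (4, 2), shares storage with p
g_real = torch.view_as_real(g)
torch._foreach_add_([p_real], [g_real], alpha=-0.1)
print(p)                                  # p reflects the in-place update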
Example no. 7
def view_as_complex(x):
    sh = list(x.shape)
    sh[-1] //= 2
    sh += [2]
    x = x.view(sh)
    return torch.view_as_complex(x)
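Hypothetical usage of the helper above: a real tensor whose last dimension stores interleaved (re, im) pairs is reinterpreted as complex, halving the last dimension.

import torch

x = torch.randn(5, 8)          # 8 reals per row -> 4 complex values per row
z = view_as_complex(x)
print(z.shape)                 # torch.Size([5, 4])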
Example no. 8
def trasmission_2nd_g_alpha_beta_gamma_complex(patched_solved,
                                               x_batch_train,
                                               index,
                                               transmission_func,
                                               args,
                                               diffL,
                                               last_g=None):
    # g = alpha * u + beta * d2u/dt2 + gamma * du/dn

    transmission_func = find_transmission_function(transmission_func)
    row = int(index / args.y_patches)
    col = index % args.y_patches

    patched_solved_complex = torch.view_as_complex(
        patched_solved.permute(0, 2, 3, 1).contiguous())

    g = torch.zeros(patched_solved_complex[index].shape, dtype=torch.cfloat)
    alpha = torch.zeros(patched_solved_complex[index].shape,
                        dtype=torch.cfloat)
    beta = torch.zeros(patched_solved_complex[index].shape, dtype=torch.cfloat)
    gamma = torch.zeros(patched_solved_complex[index].shape,
                        dtype=torch.cfloat)

    if row == 0:  # top
        alpha[0, :] = 1 + 0j
        beta[0, :] = 0 + 0j
        gamma[0, :] = 0 + 0j
        g[0, :] = patched_solved_complex[index, 0, :]
        print("means: u, dudn, d2udt2: ", torch.mean(patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,:]),\
                                          torch.mean((patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,:] - patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels+1,:])/diffL), \
                                          torch.mean((patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,  :-2] + \
                                                      patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels, 2:  ] - \
                                                    2*patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels, 1:-1] )/diffL**2))
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, 0, :], args)
        alpha[0, :] = this_alpha
        beta[0, :] = this_beta
        gamma[0, :] = this_gamma
        g[0,:] = gamma[0,:]*(patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,:] - patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels+1,:])/diffL + \
                 alpha[0,:]* patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,:]
        g[0,1:-1] += beta[0,1:-1]*(patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,  :-2] + \
                                   patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels, 2:  ] - \
                                 2*patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels, 1:-1] )/diffL**2

    if row == args.x_patches - 1:  # bottom
        alpha[-1, :] = 1 + 0j
        beta[-1, :] = 0 + 0j
        gamma[-1, :] = 0 + 0j
        g[-1, :] = patched_solved_complex[index, -1, :]
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, -1, :], args)
        alpha[-1, :] = this_alpha
        beta[-1, :] = this_beta
        gamma[-1, :] = this_gamma
        g[-1,:] = gamma[-1,:]*(patched_solved_complex[index+args.y_patches,args.overlap_pixels-1,:] - patched_solved_complex[index+args.y_patches,args.overlap_pixels-2,:])/diffL + \
                  alpha[-1,:]* patched_solved_complex[index+args.y_patches,args.overlap_pixels-1,:]
        g[-1,1:-1] += beta[-1,1:-1]*(patched_solved_complex[index+args.y_patches,args.overlap_pixels-1,  :-2] + \
                                     patched_solved_complex[index+args.y_patches,args.overlap_pixels-1, 2:  ] - \
                                   2*patched_solved_complex[index+args.y_patches,args.overlap_pixels-1, 1:-1] )/diffL**2

    if col == 0:  # left
        alpha[:, 0] = 1 + 0j
        beta[:, 0] = 0 + 0j
        gamma[:, 0] = 0 + 0j
        g[:, 0] = patched_solved_complex[index, :, 0]
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, :, 0], args)
        alpha[:, 0] = this_alpha
        beta[:, 0] = this_beta
        gamma[:, 0] = this_gamma
        g[:,0] = gamma[:,0]*(patched_solved_complex[index-1,:,args.domain_sizey-args.overlap_pixels] - patched_solved_complex[index-1,:,args.domain_sizey-args.overlap_pixels+1])/diffL + \
                 alpha[:,0]* patched_solved_complex[index-1,:,args.domain_sizey-args.overlap_pixels]
        g[1:-1,0] += beta[1:-1,0]*(patched_solved_complex[index-1, :-2,args.domain_sizey-args.overlap_pixels] + \
                                   patched_solved_complex[index-1,2:  ,args.domain_sizey-args.overlap_pixels] - \
                                 2*patched_solved_complex[index-1,1:-1,args.domain_sizey-args.overlap_pixels])/diffL**2

    if col == args.y_patches - 1:  # right
        alpha[:, -1] = 1 + 0j
        beta[:, -1] = 0 + 0j
        gamma[:, -1] = 0 + 0j
        g[:, -1] = patched_solved_complex[index, :, -1]
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, :, -1], args)
        alpha[:, -1] = this_alpha
        beta[:, -1] = this_beta
        gamma[:, -1] = this_gamma
        g[:,-1] = gamma[:,-1]*(patched_solved_complex[index+1,:,args.overlap_pixels-1] - patched_solved_complex[index+1,:,args.overlap_pixels-2])/diffL + \
                  alpha[:,-1]* patched_solved_complex[index+1,:,args.overlap_pixels-1]
        g[1:-1,-1] += beta[1:-1,-1]*(patched_solved_complex[index+1, :-2,args.overlap_pixels-1] + \
                                     patched_solved_complex[index+1,2:  ,args.overlap_pixels-1] - \
                                   2*patched_solved_complex[index+1,1:-1,args.overlap_pixels-1])/diffL**2

    if last_g is not None:
        g = (1 - args.relaxation) * last_g + args.relaxation * g

    return g, alpha, beta, gamma
Example no. 9
def trasmission_pade_g_alpha_beta_gamma_complex(patched_solved,
                                                x_batch_train,
                                                index,
                                                transmission_func,
                                                args,
                                                operator,
                                                last_g=None):
    # g = alpha * u + beta*(du/dn - gamma * operator * u) # gamma should be 1 if the math is correct

    transmission_func = find_transmission_function(transmission_func)
    row = int(index / args.y_patches)
    col = index % args.y_patches

    patched_solved_complex = torch.view_as_complex(
        patched_solved.permute(0, 2, 3, 1).contiguous()).numpy()

    g = np.zeros(patched_solved_complex[index].shape, dtype=np.csingle)
    alpha = np.zeros(patched_solved_complex[index].shape, dtype=np.csingle)
    beta = np.zeros(patched_solved_complex[index].shape, dtype=np.csingle)
    gamma = np.zeros(patched_solved_complex[index].shape, dtype=np.csingle)

    if col == 0:  # left
        alpha[:, 0] = 1 + 0j
        beta[:, 0] = 0 + 0j
        gamma[:, 0] = 0 + 0j
        g[:, 0] = patched_solved_complex[index, :, 0]
    else:
        this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, :, 0], args)
        alpha[1:-1, 0] = 0 + 0j
        beta[:, 0] = this_beta
        gamma[:, 0] = this_gamma
        g[ :, 0] =               beta[:,0]*(patched_solved_complex[index-1,:,args.domain_sizey-args.overlap_pixels] - patched_solved_complex[index-1,:,args.domain_sizey-args.overlap_pixels+1] + \
                   -gamma[:,0]*operator.dot(patched_solved_complex[index-1,:,args.domain_sizey-args.overlap_pixels]))
        g[0,
          0] = patched_solved_complex[index - 1, 0,
                                      args.domain_sizey - args.overlap_pixels]
        g[-1,
          0] = patched_solved_complex[index - 1, -1,
                                      args.domain_sizey - args.overlap_pixels]

    if col == args.y_patches - 1:  # right
        alpha[:, -1] = 1 + 0j
        beta[:, -1] = 0 + 0j
        gamma[:, -1] = 0 + 0j
        g[:, -1] = patched_solved_complex[index, :, -1]
    else:
        this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, :, -1], args)
        alpha[1:-1, -1] = 0 + 0j
        beta[:, -1] = this_beta
        gamma[:, -1] = this_gamma
        g[ :,-1] =               beta[:,-1]*(patched_solved_complex[index+1,:,args.overlap_pixels-1] - patched_solved_complex[index+1,:,args.overlap_pixels-2] + \
                   -gamma[:,-1]*operator.dot(patched_solved_complex[index+1,:,args.overlap_pixels-1]))
        g[0, -1] = patched_solved_complex[index + 1, 0,
                                          args.overlap_pixels - 1]
        g[-1, -1] = patched_solved_complex[index + 1, -1,
                                           args.overlap_pixels - 1]

    if row == 0:  # top
        alpha[0, :] = 1 + 0j
        beta[0, :] = 0 + 0j
        gamma[0, :] = 0 + 0j
        g[0, :] = patched_solved_complex[index, 0, :]
    else:
        this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, 0, :], args)
        alpha[0, 1:-1] = 0 + 0j
        beta[0, :] = this_beta
        gamma[0, :] = this_gamma
        g[ 0, :] =               beta[0,:]*(patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,:] - patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels+1,:] + \
                   -gamma[0,:]*operator.dot(patched_solved_complex[index-args.y_patches,args.domain_sizex-args.overlap_pixels,:]))
        g[0,
          0] = patched_solved_complex[index - args.y_patches,
                                      args.domain_sizex - args.overlap_pixels,
                                      0]
        g[0,
          -1] = patched_solved_complex[index - args.y_patches,
                                       args.domain_sizex - args.overlap_pixels,
                                       -1]

    if row == args.x_patches - 1:  # bottom
        alpha[-1, :] = 1 + 0j
        beta[-1, :] = 0 + 0j
        gamma[-1, :] = 0 + 0j
        g[-1, :] = patched_solved_complex[index, -1, :]
    else:
        this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, -1, :], args)
        alpha[-1, 1:-1] = 0 + 0j
        beta[-1, :] = this_beta
        gamma[-1, :] = this_gamma
        g[-1,:] =               beta[-1,:]*(patched_solved_complex[index+args.y_patches,args.overlap_pixels-1,:] - patched_solved_complex[index+args.y_patches,args.overlap_pixels-2,:] + \
                  -gamma[-1,:]*operator.dot(patched_solved_complex[index+args.y_patches,args.overlap_pixels-1,:]))
        g[-1, 0] = patched_solved_complex[index + args.y_patches,
                                          args.overlap_pixels - 1, 0]
        g[-1, -1] = patched_solved_complex[index + args.y_patches,
                                           args.overlap_pixels - 1, -1]

    if last_g is not None:
        g = (1 - args.relaxation) * last_g + args.relaxation * g

    return g, alpha, beta, gamma
Example no. 10
    def forward(
        self,
        data: Tensor,
        omega: Tensor,
        interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
        smaps: Optional[Tensor] = None,
        norm: Optional[str] = None,
    ) -> Tensor:
        """Interpolate from scattered data to gridded data and then iFFT.

        Input tensors should be of shape ``(N, C) + klength``, where ``N`` is
        the batch size and ``C`` is the number of sensitivity coils. ``omega``,
        the k-space trajectory, should be of size ``(len(grid_size), klength)``
        or ``(N, len(grid_size), klength)``, where ``klength`` is the length of
        the k-space trajectory.

        Note:

            If the batch dimension is included in ``omega``, the interpolator
            will parallelize over the batch dimension. This is efficient for
            many small trajectories that might occur in dynamic imaging
            settings.

        If your tensors are real, ensure that 2 is the size of the last
        dimension.

        Args:
            data: Data to be gridded and then inverse FFT'd.
            omega: k-space trajectory (in radians/voxel).
            interp_mats: 2-tuple of real, imaginary sparse matrices to use for
                sparse matrix NUFFT interpolation (overrides default table
                interpolation).
            smaps: Sensitivity maps. If input, these will be multiplied before
                the forward NUFFT.
            norm: Whether to apply normalization with the FFT operation.
                Options are ``"ortho"`` or ``None``.

        Returns:
            ``data`` transformed to the image domain.
        """
        if smaps is not None:
            if not smaps.dtype == data.dtype:
                raise TypeError("data dtype does not match smaps dtype.")

        is_complex = True
        if not data.is_complex():
            if not data.shape[-1] == 2:
                raise ValueError(
                    "For real inputs, last dimension must be size 2.")
            if smaps is not None:
                if not smaps.shape[-1] == 2:
                    raise ValueError(
                        "For real inputs, last dimension must be size 2.")

                smaps = torch.view_as_complex(smaps)

            is_complex = False
            data = torch.view_as_complex(data)

        if interp_mats is not None:
            assert isinstance(self.scaling_coef, Tensor)
            assert isinstance(self.im_size, Tensor)
            assert isinstance(self.grid_size, Tensor)

            output = tkbnF.kb_spmat_nufft_adjoint(
                data=data,
                scaling_coef=self.scaling_coef,
                im_size=self.im_size,
                grid_size=self.grid_size,
                interp_mats=interp_mats,
                norm=norm,
            )
        else:
            tables = []
            for i in range(len(self.im_size)):  # type: ignore
                tables.append(getattr(self, f"table_{i}"))

            assert isinstance(self.scaling_coef, Tensor)
            assert isinstance(self.im_size, Tensor)
            assert isinstance(self.grid_size, Tensor)
            assert isinstance(self.n_shift, Tensor)
            assert isinstance(self.numpoints, Tensor)
            assert isinstance(self.table_oversamp, Tensor)
            assert isinstance(self.offsets, Tensor)

            output = tkbnF.kb_table_nufft_adjoint(
                data=data,
                scaling_coef=self.scaling_coef,
                im_size=self.im_size,
                grid_size=self.grid_size,
                omega=omega,
                tables=tables,
                n_shift=self.n_shift,
                numpoints=self.numpoints,
                table_oversamp=self.table_oversamp,
                offsets=self.offsets.to(torch.long),
                norm=norm,
            )

        if smaps is not None:
            output = torch.sum(output * smaps.conj(), dim=1, keepdim=True)

        if not is_complex:
            output = torch.view_as_real(output)

        return output
Example no. 11
    def forward(
        self,
        image: Tensor,
        kernel: Tensor,
        smaps: Optional[Tensor] = None,
        norm: Optional[str] = "ortho",
    ) -> Tensor:
        """Toeplitz NUFFT forward function.

        Args:
            image: The image to apply the forward/backward Toeplitz-embedded
                NUFFT to.
            kernel: The filter response taking into account Toeplitz embedding.
            smaps: Sensitivity maps. If input, these will be multiplied before
                applying the Toeplitz NUFFT.
            norm: Whether to apply normalization with the FFT operation.
                Options are ``"ortho"`` or ``None``.

        Returns:
            ``image`` after applying the Toeplitz forward/backward NUFFT.
        """
        if not kernel.dtype == image.dtype:
            raise TypeError("kernel and image must have same dtype.")

        if smaps is not None:
            if not smaps.dtype == image.dtype:
                raise TypeError("image dtype does not match smaps dtype.")

        is_complex = True
        if not image.is_complex():
            if not image.shape[-1] == 2:
                raise ValueError(
                    "For real inputs, last dimension must be size 2.")
            if not kernel.shape[-1] == 2:
                raise ValueError(
                    "For real inputs, last dimension must be size 2.")
            if smaps is not None:
                if not smaps.shape[-1] == 2:
                    raise ValueError(
                        "For real inputs, last dimension must be size 2.")

                smaps = torch.view_as_complex(smaps)

            is_complex = False
            image = torch.view_as_complex(image)
            kernel = torch.view_as_complex(kernel)

        if len(kernel.shape) > len(image.shape[2:]):
            if kernel.shape[0] == 1:
                kernel = kernel[0]
            elif not kernel.shape[0] == image.shape[0]:
                raise ValueError("If using batch dimension, "
                                 "kernel must have same batch size as image")

        if smaps is None:
            output = tkbnF.fft_filter(image=image, kernel=kernel, norm=norm)
        else:
            output = self.toep_batch_loop(image=image,
                                          smaps=smaps,
                                          kernel=kernel,
                                          norm=norm)

        if not is_complex:
            output = torch.view_as_real(output)

        return output
Example no. 12
def matnoise(mat, noise='wgn', snr=30, peak='maxv'):
    """add noise to an matrix

    Add noise to an matrix (real or complex)

    Args:
        mat (torch.Tensor): Input tensor, can be real or complex valued
        noise (str, optional): type of noise (default: ``'wgn'``)
        snr (float, optional): Signal-to-noise ratio (default: 30)
        peak (None, float or str, optional): Peak value of the input. For complex data, use ``peak=[peakr, peaki]``. If ``None``, the peak is detected automatically; if ``'maxv'`` (the default), the maximum value is used as the peak.

    Returns:
        (torch.Tensor): Output tensor.
    """

    if th.is_complex(mat):
        cplxflag = True
        if peak is None:
            # peakr = mat.real.abs().max()
            peakr = mat.real.max()
            peaki = mat.imag.max()
            peakr = 2**nextpow2(peakr) - 1
            peaki = 2**nextpow2(peaki) - 1
        elif peak == 'maxv':
            # peakr = mat.real.abs().max()
            peakr = mat.real.max()
            peaki = mat.imag.max()
        else:
            peakr, peaki = peak
        mat = th.view_as_real(mat)
        mat[..., 0] = awgn(mat[..., 0],
                           snr=snr,
                           peak=peakr,
                           pmode='db',
                           measMode='measured')
        mat[..., 1] = awgn(mat[..., 1],
                           snr=snr,
                           peak=peaki,
                           pmode='db',
                           measMode='measured')
    else:
        cplxflag = False
        if peak is None:
            # peakr = mat.real.abs().max()
            peakr = mat.real.max()
            peaki = mat.imag.max()
            peakr = 2**nextpow2(peakr) - 1
            peaki = 2**nextpow2(peaki) - 1
        elif peak == 'maxv':
            # peakr = mat.real.abs().max()
            peakr = mat[..., 0].max()
            peaki = mat[..., 1].max()
        else:
            peakr, peaki = peak
        if mat.shape[-1] == 2:
            mat[..., 0] = awgn(mat[..., 0],
                               snr=snr,
                               peak=peakr,
                               pmode='db',
                               measMode='measured')
            mat[..., 1] = awgn(mat[..., 1],
                               snr=snr,
                               peak=peaki,
                               pmode='db',
                               measMode='measured')
        else:
            if peak is None:
                # peak = mat.abs().max()
                peak = mat.max()
                peak = 2**nextpow2(peak) - 1
            elif peak == 'maxv':
                peak = mat.max()
            mat = awgn(mat,
                       snr=snr,
                       peak=peak,
                       pmode='db',
                       measMode='measured')

    if cplxflag:
        mat = th.view_as_complex(mat)

    return mat
Example no. 13
    x : tensor
        Input tensor with at least one axis.
    i0 : int
        Start of original signal before padding.
    i1 : int
        End of original signal before padding.

    Returns
    -------
    x_unpadded : tensor
        The tensor x[..., i0:i1].
    """
    return x[..., i0:i1]


fft = FFT(lambda x: torch.view_as_real(torch.fft.fft(torch.view_as_complex(x))),
    lambda x: torch.view_as_real(torch.fft.ifft(torch.view_as_complex(x))),
    lambda x: torch.fft.ifft(torch.view_as_complex(x)).real,
    type_checks)


backend = namedtuple('backend', ['name', 'modulus_complex', 'subsample_fourier', 'real', 'unpad', 'fft', 'concatenate'])
backend.name = 'torch'
backend.modulus_complex = Modulus()
backend.subsample_fourier = subsample_fourier
backend.real = real
backend.unpad = unpad
backend.cdgmm = cdgmm
backend.pad = pad
backend.pad_1d = pad_1d
backend.fft = fft
Example no. 14
def func(z):
    z_ = torch.view_as_complex(z)
    z_select = torch.select(z_, z_.dim() - 1, 0)
    z_select_real = torch.view_as_real(z_select)
    return z_select_real.sum()
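A sketch of how such a function is typically exercised (for instance in an autograd check); the double dtype and the requires_grad flag are assumptions for illustration.

import torch

z = torch.randn(3, 4, 2, dtype=torch.double, requires_grad=True)
out = func(z)                  # selects index 0 along the complex view's last dim and sums its parts
out.backward()
print(z.grad.shape)            # torch.Size([3, 4, 2])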
Example no. 15
def torch_complex_from_magphase(mag, phase):
    return torch.view_as_complex(
        torch.stack((mag * torch.cos(phase), mag * torch.sin(phase)), dim=-1))
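A brief usage sketch (the shapes are assumptions): magnitude and phase tensors of the same shape are combined into a complex tensor whose modulus equals the magnitude.

import math
import torch

mag = torch.rand(2, 257)
phase = (torch.rand(2, 257) - 0.5) * 2 * math.pi
z = torch_complex_from_magphase(mag, phase)
print(z.shape, torch.allclose(z.abs(), mag, atol=1e-6))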
Example no. 16
def view_as_complex(data):
    """Named version of `torch.view_as_complex()`"""
    assert_complex(data)
    return torch.view_as_complex(data.rename(None)).refine_names(*data.names[:-1])
Example no. 17
def torch_complex_from_reim(re, im):
    return torch.view_as_complex(torch.stack([re, im], dim=-1))
Example no. 18
    def forward(
        self,
        data: Tensor,
        omega: Tensor,
        interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
        grid_size: Optional[Tensor] = None,
    ) -> Tensor:
        """Interpolate from scattered data to gridded data.

        Input tensors should be of shape ``(N, C) + klength``, where ``N`` is
        the batch size and ``C`` is the number of sensitivity coils. ``omega``,
        the k-space trajectory, should be of size ``(len(im_size), klength)``,
        where ``klength`` is the length of the k-space trajectory.

        If your tensors are real-valued, ensure that 2 is the size of the last
        dimension.

        Args:
            data: Data to be gridded.
            omega: k-space trajectory (in radians/voxel).
            interp_mats: 2-tuple of real, imaginary sparse matrices to use for
                sparse matrix KB interpolation (overrides default table
                interpolation).
            grid_size: Optional override of the grid size to interpolate to;
                if ``None``, the module's ``grid_size`` is used.

        Returns:
            ``data`` interpolated to the grid.
        """
        is_complex = True
        if not data.is_complex():
            if not data.shape[-1] == 2:
                raise ValueError(
                    "For real inputs, last dimension must be size 2.")

            is_complex = False
            data = torch.view_as_complex(data)

        if grid_size is None:
            assert isinstance(self.grid_size, Tensor)
            grid_size = self.grid_size
        if interp_mats is not None:
            output = tkbnF.kb_spmat_interp_adjoint(data=data,
                                                   interp_mats=interp_mats,
                                                   grid_size=grid_size)
        else:
            tables = []
            for i in range(len(self.im_size)):  # type: ignore
                tables.append(getattr(self, f"table_{i}"))

            assert isinstance(self.n_shift, Tensor)
            assert isinstance(self.numpoints, Tensor)
            assert isinstance(self.table_oversamp, Tensor)
            assert isinstance(self.offsets, Tensor)

            output = tkbnF.kb_table_interp_adjoint(
                data=data,
                omega=omega,
                tables=tables,
                n_shift=self.n_shift,
                numpoints=self.numpoints,
                table_oversamp=self.table_oversamp,
                offsets=self.offsets,
                grid_size=grid_size,
            )

        if not is_complex:
            output = torch.view_as_real(output)

        return output
Example no. 19
def irfft(input, signal_ndim, normalized=False, onesided=True, signal_sizes=None):
    if LooseVersion(torch.__version__) < LooseVersion("1.8.0"): 
        return torch.irfft(input, signal_ndim, normalized, onesided, signal_sizes)
    else:
        assert signal_sizes, "Parameter signal_sizes is required"
        if onesided: 
            if normalized:
                if signal_ndim == 1:
                    y = torch.fft.irfft(torch.view_as_complex(input), signal_sizes[-1], -1, "ortho")
                elif signal_ndim == 2:
                    y = torch.fft.irfft2(torch.view_as_complex(input), signal_sizes, (-2, -1), "ortho")
                elif signal_ndim == 3:
                    y = torch.fft.irfftn(torch.view_as_complex(input), signal_sizes, (-3, -2, -1), "ortho")
                else:
                    assert False, "Ortho-normalized irfft() has illegal number of dimensions %s" % (signal_ndim)
            else:
                if signal_ndim == 1:
                    y = torch.fft.irfft(torch.view_as_complex(input), signal_sizes[-1], -1, "backward")
                elif signal_ndim == 2:
                    y = torch.fft.irfft2(torch.view_as_complex(input), signal_sizes, (-2, -1), "backward")
                elif signal_ndim == 3:
                    y = torch.fft.irfftn(torch.view_as_complex(input), signal_sizes, (-3, -2, -1), "backward")
                else:
                    assert False, "Backward-normalized irfft() has illegal number of dimensions %s" % (signal_ndim)
        else: 
            if normalized:
                if signal_ndim == 1:
                    y = torch.fft.irfft(torch.view_as_complex(input), signal_sizes[-1], -1, "ortho")
                elif signal_ndim == 2:
                    y = torch.fft.irfft2(torch.view_as_complex(input), signal_sizes, (-2, -1), "ortho")
                elif signal_ndim == 3:
                    y = torch.fft.irfftn(torch.view_as_complex(input), signal_sizes, (-3, -2, -1), "ortho")
                else:
                    assert False, "Ortho-normalized ifft() has illegal number of dimensions %s" % (signal_ndim)
            else:
                if signal_ndim == 1:
                    y = torch.fft.irfft(torch.view_as_complex(input), signal_sizes[-1], -1, "backward")
                elif signal_ndim == 2:
                    y = torch.fft.irfft2(torch.view_as_complex(input), signal_sizes, (-2, -1), "backward")
                elif signal_ndim == 3:
                    y = torch.fft.irfftn(torch.view_as_complex(input), signal_sizes, (-3, -2, -1), "backward")
                else:
                    assert False, "Backward-normalized ifft() has illegal number of dimensions %s" % (signal_ndim)
        assert not y.is_complex()
        return y.contiguous()
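A hedged usage sketch for the compatibility wrapper above, assuming torch >= 1.8 and the module's own imports (e.g. LooseVersion) are in place: a one-sided, real/imaginary-stacked spectrum (last dimension of size 2) is inverse-transformed back to a length-16 real signal.

import torch

spec = torch.randn(4, 9, 2)    # one-sided spectrum of a length-16 real signal
sig = irfft(spec, signal_ndim=1, normalized=False,
            onesided=True, signal_sizes=[16])
print(sig.shape)               # torch.Size([4, 16])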
Example no. 20
def tocplx(x):
    return torch.view_as_complex(x)
Example no. 21
    # [batch, 1, 121, 61, 2]
    alphaf = label.to(device=z.device) / (kzzf + lambda0)

    # [batch, 1, 121, 121]
    return torch.irfft(cn.mul(kxzf, alphaf), signal_ndim=2)


##############################################

x = torch.rand((42, 32, 121, 121))

a = torch.rfft(x, signal_ndim=2, onesided=False)
b = fft.fftn(x, dim=[-2, -1])

ca = torch.view_as_complex(a)
print(a.shape)
print(b.shape)
print(torch.allclose(ca, b))

u = ca - b
v = u.abs()
h = torch.histc(v)

import matplotlib.pyplot as plt
plt.hist(v.flatten().numpy(), bins=500, log=True)
plt.show()

exit()

# fft_comparison()
Example no. 22
    i0 : int
        Start of original signal before padding.
    i1 : int
        End of original signal before padding.

    Returns
    -------
    x_unpadded : tensor
        The tensor x[..., i0:i1].
    """
    return x[..., i0:i1]


if version.parse(torch.__version__) >= version.parse('1.8'):
    fft = FFT(
        lambda x: torch.view_as_real(torch.fft.fft(torch.view_as_complex(x))),
        lambda x: torch.view_as_real(torch.fft.ifft(torch.view_as_complex(x))),
        lambda x: torch.fft.ifft(torch.view_as_complex(x)).real, type_checks)
else:
    fft = FFT(lambda x: torch.fft(x, 1, normalized=False),
              lambda x: torch.ifft(x, 1, normalized=False),
              lambda x: torch.irfft(x, 1, normalized=False, onesided=False),
              type_checks)

backend = namedtuple('backend', [
    'name', 'modulus_complex', 'subsample_fourier', 'real', 'unpad', 'fft',
    'concatenate'
])
backend.name = 'torch'
backend.version = torch.__version__
backend.modulus_complex = Modulus()
Example no. 23
def complex_ifft(A, *args, **kwargs):
    return torch.view_as_complex(
        torch.ifft(torch.view_as_real(A), *args, **kwargs))
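The snippet above relies on the pre-1.8 ``torch.ifft`` API, which was removed in later releases. A rough modern equivalent for the common 1-D case (an assumption, since the original forwards ``*args`` to ``torch.ifft``) would be:

import torch

def complex_ifft_modern(A: torch.Tensor) -> torch.Tensor:
    # 1-D inverse FFT over the last dimension of a complex tensor
    return torch.fft.ifft(A, dim=-1)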
Example no. 24
def defocus(x, pa=None, pr=None, isfft=True, ftshift=True):
    r"""Defocus image with given phase error

    Defocus image in azimuth by

    .. math::
        Y(k, n_r)=\sum_{n_a=0}^{N_a-1} X(n_a, n_r) \exp \left(j \varphi_{n_a}\right) \exp \left(-j \frac{2 \pi}{N_a} k n_a\right)

    where, :math:`\varphi_{n_a}` is the estimated azimuth phase error of the :math:`n_a`-th azimuth line, :math:`y(k, n_r)` is
    the focused pixel.

    The defocus method in range is the same as azimuth.

    Args:
        x (Tensor): Focused image data :math:`{\mathbf X} \in{\mathbb C}^{N\times N_c\times N_a\times N_r}` or
            :math:`{\mathbf X} \in{\mathbb R}^{N\times N_a\times N_r\times 2}` or
            :math:`{\mathbf X} \in{\mathbb R}^{N\times N_c\times N_a\times N_r\times 2}`.
        pa (Tensor, optional): Defocus parameters in azimuth, phase error in rad unit. (the default is None, not focus)
        pr (Tensor, optional): Defocus parameters in range, phase error in rad unit. (the default is None, not focus)
        isfft (bool, optional): Is need do fft (the default is True)
        ftshift (bool, optional): Is shift zero frequency to center when do fft/ifft/fftfreq (the default is True)

    Returns:
        (Tensor): A tensor of defocused images.

    Raises:
        TypeError: :attr:`x` should be given in complex or in real (last dimension of size 2) representation.
    """

    if type(x) is not th.Tensor:
        x = th.tensor(x)

    if th.is_complex(x):
        # N, Na, Nr = x.size(0), x.size(-2), x.size(-1)
        x = th.view_as_real(x)
        crepflag = True
    elif x.size(-1) == 2:
        # N, Na, Nr = x.size(0), x.size(-3), x.size(-2)
        crepflag = False
    else:
        raise TypeError('x should be in complex or real (last dimension is 2) representation!')

    d = x.dim()
    sizea, sizer = [1] * d, [1] * d

    if pa is not None:
        sizea[0], sizea[-3], sizea[-1] = pa.size(0), pa.size(1), 2
        epa = th.stack((th.cos(pa), th.sin(pa)), dim=-1)
        epa = epa.reshape(sizea)

        if isfft:
            x = ts.fft(x, axis=-3, shift=ftshift)
        x = ts.ebemulcc(x, epa)
        x = ts.ifft(x, axis=-3, shift=ftshift)

    if pr is not None:
        sizer[0], sizer[-2], sizer[-1] = pr.size(0), pr.size(1), 2
        epr = th.stack((th.cos(pr), th.sin(pr)), dim=-1)
        epr = epr.reshape(sizer)

        if isfft:
            x = ts.fft(x, axis=-2, shift=ftshift)
        x = ts.ebemulcc(x, epr)
        x = ts.ifft(x, axis=-2, shift=ftshift)

    if crepflag:
        x = th.view_as_complex(x)

    return x
Example no. 25
def view_complex_native(x: torch.FloatTensor) -> torch.Tensor:
    """Convert a PyKEEN complex tensor representation into a torch one using :func:`torch.view_as_complex`."""
    return torch.view_as_complex(x.view(*x.shape[:-1], -1, 2))
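Sketch of the conversion (assuming the convention that the last dimension stores interleaved real/imaginary pairs):

import torch

x = torch.randn(3, 10)         # 5 complex entries per row, stored as (re, im) pairs
z = view_complex_native(x)
print(z.shape)                 # torch.Size([3, 5])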
Example no. 26
def test_fn(x):
    return torch_fn(torch.view_as_complex(x), *args)
Example no. 27
    def meta_tensor(self, t):
        # see expired-storages
        self.check_expired_count += 1
        if self.check_expired_count >= self.check_expired_frequency:
            self.check_for_expired_weak_storages()
            self.check_expired_count = 0

        if self.get_tensor_memo(t) is None:
            with torch.inference_mode(t.is_inference()):
                if t._is_view():
                    # Construct views in two steps: recursively meta-fy their
                    # base, and then create the view off that.  NB: doing it
                    # directly from storage is WRONG because this won't cause
                    # version counters to get shared.
                    assert t._is_view()
                    base = self.meta_tensor(t._base)

                    def is_c_of_r(complex_dtype, real_dtype):
                        return (
                            utils.is_complex_dtype(complex_dtype)
                            and utils.corresponding_real_dtype(complex_dtype)
                            == real_dtype
                        )

                    if base.dtype == t.dtype:
                        pass
                    elif is_c_of_r(base.dtype, t.dtype):
                        base = torch.view_as_real(base)
                    elif is_c_of_r(t.dtype, base.dtype):
                        base = torch.view_as_complex(base)
                    else:
                        # This is not guaranteed to succeed.  If it fails, it
                        # means there is another dtype-converting view function
                        # that hasn't been handled here
                        base = base.view(t.dtype)

                    with torch.enable_grad():
                        r = base.as_strided(t.size(), t.stride(), t.storage_offset())
                else:
                    is_leaf = safe_is_leaf(t)
                    # Fake up some autograd history.
                    if t.requires_grad:
                        r = torch.empty(
                            (0,), dtype=t.dtype, device="meta", requires_grad=True
                        )
                        if not is_leaf:
                            with torch.enable_grad():
                                # The backward function here will be wrong, but
                                # that's OK; our goal is just to get the metadata
                                # looking as close as possible; we're not going to
                                # actually try to backward() on these produced
                                # metas.  TODO: would be safer to install some
                                # sort of unsupported grad_fn here
                                r = r.clone()
                    else:
                        r = torch.empty((0,), dtype=t.dtype, device="meta")
                    # As long as meta storage is not supported, need to prevent
                    # redispatching on set_(Storage, ...) which will choke with
                    # meta storage
                    s = self.meta_storage(t.storage())
                    with no_dispatch():
                        with torch.no_grad():
                            r.set_(s, t.storage_offset(), t.size(), t.stride())

                torch._C._set_conj(r, t.is_conj())
                torch._C._set_neg(r, t.is_neg())
            self.set_tensor_memo(t, r)

        return self.get_tensor_memo(t)
Example no. 28
def ifft(input, signal_ndim, normalized=True):
    return torch.view_as_real(
        torch.fft.ifft2(torch.view_as_complex(input),
                        norm="ortho" if normalized else "backward"))
Example no. 29
def fft(x, n=None, axis=0, norm="backward", shift=False):
    """FFT in torchsar

    FFT in torchsar.

    Parameters
    ----------
    x : {torch array}
        Complex representation is supported. Since torch 1.7 and above support complex
        tensors, when :attr:`x` is in real representation (last dimension is 2: real,
        imag), it is converted to complex form and converted back after the FFT.
    n : int, optional
        number of fft points (the default is None, which equals the signal dimension)
    axis : int, optional
        axis of fft (the default is 0, the first dimension)
    norm : {None or str}, optional
        Normalization mode. For the forward transform (fft()), these correspond to:
        - "forward" - normalize by ``1/n``
        - "backward" - no normalization (default)
        - "ortho" - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
    shift : bool, optional
        shift the zero frequency to center (the default is False)

    Returns
    -------
    y : {torch array}
        fft results torch array with the same type as :attr:`x`

    Raises
    ------
    ValueError
        nfft is smaller than the signal dimension.
    """

    if norm is None:
        norm = 'backward'

    if (x.size(-1) == 2) and (not th.is_complex(x)):
        realflag = True
        x = th.view_as_complex(x)
        if axis < 0:
            axis += 1
    else:
        realflag = False

    d = x.size(axis)
    if n is None:
        n = d
    if d < n:
        x = padfft(x, n, axis, shift)
    elif d > n:
        raise ValueError('nfft is smaller than signal dimension!')

    if shift:
        y = thfft.fftshift(thfft.fft(thfft.fftshift(x, dim=axis),
                                     n=n,
                                     dim=axis,
                                     norm=norm),
                           dim=axis)
    else:
        y = thfft.fft(x, n=n, dim=axis, norm=norm)

    if realflag:
        y = th.view_as_real(y)

    return y
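A hedged usage sketch, assuming the module-level imports of the original file (``torch as th``, ``torch.fft as thfft``) are available: a real-represented signal keeps its representation through the wrapper.

import torch as th

x = th.randn(8, 100, 2)        # 100-sample signal in (re, im) representation
y = fft(x, n=100, axis=1, norm='backward', shift=False)
print(y.shape)                 # torch.Size([8, 100, 2])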
Example no. 30
def _spectral_residual_visual_saliency(
        x: torch.Tensor,
        scale: float = 0.25,
        kernel_size: int = 3,
        sigma: float = 3.8,
        gaussian_size: int = 10) -> torch.Tensor:
    r"""Compute Spectral Residual Visual Saliency
    Credits X. Hou and L. Zhang, CVPR 07, 2007
    Reference:
        http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.125.5641&rep=rep1&type=pdf

    Args:
        x: Tensor with shape (N, 1, H, W).
        scale: Resizing factor
        kernel_size: Kernel size of average blur filter
        sigma: Sigma of gaussian filter applied on saliency map
        gaussian_size: Size of gaussian filter applied on saliency map
    Returns:
        saliency_map: Tensor with shape BxHxW

    """
    eps = torch.finfo(x.dtype).eps
    for kernel in kernel_size, gaussian_size:
        if x.size(-1) * scale < kernel or x.size(-2) * scale < kernel:
            raise ValueError(
                f'Kernel size can\'t be greater than actual input size. '
                f'Input size: {x.size()} x {scale}. Kernel size: {kernel}')

    # Downsize image
    in_img = imresize(x, scale=scale)

    # Fourier transform (use complex format [a,b] instead of a + ib
    # because torch<1.8.0 autograd does not support the latter)
    recommended_torch_version = _parse_version('1.8.0')
    torch_version = _parse_version(torch.__version__)
    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        imagefft = torch.fft.fft2(in_img)
        log_amplitude = torch.log(imagefft.abs() + eps)
        phase = torch.angle(imagefft)
    else:
        imagefft = torch.rfft(in_img, 2, onesided=False)
        # Compute log of absolute value and angle of fourier transform
        log_amplitude = torch.log(imagefft.pow(2).sum(dim=-1).sqrt() + eps)
        phase = torch.atan2(imagefft[..., 1], imagefft[..., 0] + eps)

    # Compute spectral residual using average filtering
    padding = kernel_size // 2
    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        # replicate padding before average filtering
        spectral_residual = F.pad(log_amplitude,
                                  pad=pad_to_use,
                                  mode='replicate')
    else:
        spectral_residual = log_amplitude

    spectral_residual = log_amplitude - F.avg_pool2d(
        spectral_residual, kernel_size=kernel_size, stride=1)
    # Saliency map
    # representation of complex exp(spectral_residual + j * phase)
    compx = torch.stack((torch.exp(spectral_residual) * torch.cos(phase),
                         torch.exp(spectral_residual) * torch.sin(phase)), -1)

    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        saliency_map = torch.abs(torch.fft.ifft2(
            torch.view_as_complex(compx)))**2

    else:
        saliency_map = torch.sum(torch.ifft(compx, 2)**2, dim=-1)

    # After effect for SR-SIM
    # Apply gaussian blur
    kernel = gaussian_filter(gaussian_size, sigma)
    if gaussian_size % 2 == 0:  # matlab pads upper and lower borders with 0s for even kernels
        kernel = torch.cat((torch.zeros(1, 1, gaussian_size), kernel), 1)
        kernel = torch.cat((torch.zeros(1, gaussian_size + 1, 1), kernel), 2)
        gaussian_size += 1
    kernel = kernel.view(1, 1, gaussian_size, gaussian_size).to(saliency_map)
    saliency_map = F.conv2d(saliency_map,
                            kernel,
                            padding=(gaussian_size - 1) // 2)

    # normalize between [0, 1]
    min_sal = torch.min(saliency_map[:])
    max_sal = torch.max(saliency_map[:])
    saliency_map = (saliency_map - min_sal) / (max_sal - min_sal + eps)

    # scale to original size
    saliency_map = imresize(saliency_map, sizes=x.size()[-2:])
    return saliency_map