Exemplo n.º 1
0
def interpolate_dft(a: torch.Tensor, interp_fs) -> torch.Tensor:
    """Multiply ``a`` pointwise (complex) by one or two interpolation filters.

    Args:
        a: Fourier coefficients to be filtered.
        interp_fs: Either a single filter tensor, or a tuple/list whose first
            two entries are applied in sequence (``interp_fs[0]`` first, then
            ``interp_fs[1]``).

    Returns:
        The complex product of ``a`` with the filter(s).

    Raises:
        ValueError: if ``interp_fs`` is neither a tensor nor a tuple/list.
    """
    if isinstance(interp_fs, (tuple, list)):
        # Separable case: apply the two factors one after the other.
        partial = complex.mult(a, interp_fs[0])
        return complex.mult(partial, interp_fs[1])

    if not isinstance(interp_fs, torch.Tensor):
        raise ValueError('"interp_fs" must be tensor or tuple of tensors.')

    return complex.mult(a, interp_fs)
Exemplo n.º 2
0
def get_interp_fourier(sz: torch.Tensor,
                       method='ideal',
                       bicubic_param=0.5,
                       centering=True,
                       windowing=False,
                       device='cpu'):
    """Return the per-axis Fourier interpolation filters for a grid of size ``sz``.

    Args:
        sz: Grid size; ``sz[0]`` is used for the vertical axis and ``sz[1]``
            for the horizontal axis.
        method: ``'ideal'`` gives a constant response ``1 / size``;
            ``'bicubic'`` evaluates ``cubic_spline_fourier`` at the
            normalized frequencies ``k / size`` and divides by ``size``.
        bicubic_param: Forwarded to ``cubic_spline_fourier`` (bicubic only).
        centering: When true, each filter is multiplied by the phase factor
            ``complex.exp_imag((-pi / size) * k)``.
        windowing: Not supported; ``True`` raises ``NotImplementedError``
            (after the filters have been built).
        device: Device the returned filters are moved to.

    Returns:
        Tuple ``(interp_y, interp_x)``.

    Raises:
        ValueError: for an unknown ``method``.
        NotImplementedError: when ``windowing`` is true.
    """
    ky, kx = fourier.get_frequency_coord(sz)

    def axis_filter(freqs, size):
        # Kernel response along a single axis, before optional centering.
        if method == 'ideal':
            return torch.ones(freqs.shape) / size
        if method == 'bicubic':
            return cubic_spline_fourier(freqs / size, bicubic_param) / size
        raise ValueError('Unknown method.')

    def centered(filt, freqs, size):
        # Apply the phase factor exp_imag((-pi / size) * k) to the filter.
        return complex.mult(filt, complex.exp_imag((-math.pi / size) * freqs))

    interp_y = axis_filter(ky, sz[0])
    interp_x = axis_filter(kx, sz[1])

    if centering:
        interp_y = centered(interp_y, ky, sz[0])
        interp_x = centered(interp_x, kx, sz[1])

    if windowing:
        raise NotImplementedError

    return interp_y.to(device), interp_x.to(device)
Exemplo n.º 3
0
    def A(self, hf: TensorList):
        """Apply the system operator to the Fourier-domain filter ``hf``.

        Two terms are summed:
          * Data term: ``hf`` is multiplied with ``self.training_samples``
            (``complex.mtimes``), weighted per sample by
            ``self.sample_weights``, then multiplied back with the conjugate
            of ``self.training_samples`` ("multiply with transpose").
          * Regularization term: for each filter element, the mirror-padded
            filter is convolved with its ``reg_filter`` twice (once padded,
            once unpadded) and added in place to the matching element of the
            data term.

        NOTE(review): the name and structure suggest this is the left-hand
        side operator of normal equations solved iteratively (e.g. conjugate
        gradient) — confirm against the optimizer that calls it.

        Args:
            hf: Filter coefficients; each element is treated below as a 5-D
                tensor whose last axis holds the 2 complex components.

        Returns:
            ``hf_out``: the data term with the regularization term added.
        """
        # Classify
        sh = complex.mtimes(self.training_samples, hf.permute(2,3,1,0,4)) # (h, w, num_samp, num_filt, 2)
        sh = complex.mult(self.sample_weights.view(1,1,-1,1), sh)

        # Multiply with transpose
        hf_out = complex.mtimes(sh.permute(0,1,3,2,4), self.training_samples, conj_b=True).permute(2,3,0,1,4)

        # Add regularization
        for hfe, hfe_out, reg_filter in zip(hf, hf_out, self.reg_filter):
            # Padding for the first convolution, bounded by both the
            # reg_filter size minus one and the filter size minus one.
            # NOTE(review): reg_pad2 may be as large as 2*w-2, but the slice
            # hfe[..., 1:reg_pad2+1, :] below yields at most w-1 columns, so
            # the mirrored part can be shorter than reg_pad2 — confirm intent.
            reg_pad1 = min(reg_filter.shape[-2] - 1, hfe.shape[-3] - 1)
            reg_pad2 = min(reg_filter.shape[-1] - 1, 2*hfe.shape[-2]- 2)

            # Add part needed for convolution
            # Prepend the conjugated, flipped columns 1..reg_pad2 along the
            # width axis (presumably reconstructing the missing half of a
            # Hermitian-symmetric spectrum — TODO confirm).
            if reg_pad2 > 0:
                hfe_conv = torch.cat([complex.conj(hfe[...,1:reg_pad2+1,:].flip((2,3))), hfe], -2)
            else:
                hfe_conv = hfe.clone()

            # Shift data to batch dimension
            # Move the complex-component axis before the spatial axes and fold
            # all leading axes into the batch: (N, 1, H, W) input for conv2d.
            hfe_conv = hfe_conv.permute(0,1,4,2,3).reshape(-1, 1, hfe_conv.shape[-3], hfe_conv.shape[-2])

            # Do first convolution
            hfe_conv = F.conv2d(hfe_conv, reg_filter, padding=(reg_pad1, reg_pad2))

            # Do second convolution
            # remove_size equals the number of mirrored columns prepended above
            # (the slice 1:reg_pad2+1 yields min(reg_pad2, w-1) columns), so
            # they are dropped before the unpadded second convolution.
            remove_size = min(reg_pad2, hfe.shape[-2]-1)
            hfe_conv = F.conv2d(hfe_conv[...,remove_size:], reg_filter)

            # Reshape back and add
            # hfe_out is an element of hf_out, so the in-place += updates the
            # returned hf_out directly.
            hfe_out += hfe_conv.reshape(hfe.shape[0], hfe.shape[1], 2, hfe.shape[2], hfe.shape[3]).permute(0,1,3,4,2)
        return hf_out
Exemplo n.º 4
0
def shift_fs(a: torch.Tensor, shift: torch.Tensor):
    """Shift a sample in the Fourier domain by applying per-axis phase factors.

    Params:
        a : The Fourier coefficients of the sample; must be a 5-dimensional
            tensor.
        shift : The (y, x) shift to be performed, normalized to the range
            [-pi, pi].

    Returns:
        ``a`` itself when both shift components are zero; otherwise a new
        tensor multiplied by ``exp_imag(shift[0] * ky)`` and
        ``exp_imag(shift[1] * kx)``.

    Raises:
        ValueError: if ``a`` is not 5-dimensional.
    """
    if a.dim() != 5:
        raise ValueError(
            'a must be the Fourier coefficients, a 5-dimensional tensor.')

    # A zero shift is the identity: hand back the input unchanged.
    if not (shift[0] != 0 or shift[1] != 0):
        return a

    # Frequency grid for height a.shape[2] and width 2 * a.shape[3] - 1
    # (presumably the full width of a half-stored spectrum — TODO confirm).
    grid_size = (a.shape[2], 2 * a.shape[3] - 1)
    ky, kx = get_frequency_coord(grid_size, device=a.device)

    phase_y = complex.exp_imag(shift[0].item() * ky)
    phase_x = complex.exp_imag(shift[1].item() * kx)
    shifted_y = complex.mult(a, phase_y)
    return complex.mult(shifted_y, phase_x)
Exemplo n.º 5
0
    def __call__(self, x: TensorList):
        """
        Compute residuals.

        The first half of ``x`` holds the filters ``hf`` and the second half
        the projection matrices ``P``. The returned residual list contains, in
        order: the (optionally sample-weighted) data residuals, one spatial
        regularization residual per filter, and the projection-matrix
        regularization ``sqrt(self.params.projection_reg) * P``.

        :param x: [filters, projection_matrices]
        :return: [data_terms, filter_regularizations, proj_mat_regularizations]
        """
        # x is split in half: filters first, projection matrices second.
        hf = x[:len(x) // 2]
        P = x[len(x) // 2:]

        # Project the training samples with P, then apply the filters.
        compressed_samples = complex.mtimes(self.training_samples, P)
        residuals = complex.mtimes(compressed_samples, hf.permute(
            2, 3, 1, 0, 4))  # (h, w, num_samp, num_filt, 2)
        # Subtract the target self.yf (presumably the desired response).
        residuals = residuals - self.yf

        # Optional per-sample weighting. The attribute name suggests
        # square-rooted weights, so the squared residual would carry the full
        # sample weight — confirm where sample_weights_sqrt is set.
        if self.sample_weights_sqrt is not None:
            residuals = complex.mult(
                self.sample_weights_sqrt.view(1, 1, -1, 1), residuals)

        # Add spatial regularization
        for hfe, reg_filter in zip(hf, self.reg_filter):
            # Padding for the convolution, bounded by the reg_filter size
            # minus one and the filter size minus one.
            reg_pad1 = min(reg_filter.shape[-2] - 1, hfe.shape[-3] - 1)
            reg_pad2 = min(reg_filter.shape[-1] - 1, hfe.shape[-2] - 1)

            # Add part needed for convolution
            # Prepend conjugated, flipped columns 1..reg_pad2 along the width
            # axis. NOTE(review): .clone().detach() means no gradient flows
            # through this mirrored padding — confirm that is intended.
            if reg_pad2 > 0:
                hfe_left_padd = complex.conj(
                    hfe[..., 1:reg_pad2 + 1, :].clone().detach().flip((2, 3)))
                hfe_conv = torch.cat([hfe_left_padd, hfe], -2)
            else:
                hfe_conv = hfe.clone()

            # Shift data to batch dimension
            # Move the complex-component axis before the spatial axes and fold
            # all leading axes into the batch: (N, 1, H, W) input for conv2d.
            hfe_conv = hfe_conv.permute(0, 1, 4, 2,
                                        3).reshape(-1, 1, hfe_conv.shape[-3],
                                                   hfe_conv.shape[-2])

            # Do first convolution
            hfe_conv = F.conv2d(hfe_conv,
                                reg_filter,
                                padding=(reg_pad1, reg_pad2))

            # The convolved filter is appended as its own residual entry.
            residuals.append(hfe_conv)

        # Add regularization for projection matrix
        # One extra residual per projection matrix, scaled by
        # sqrt(projection_reg).
        residuals.extend(math.sqrt(self.params.projection_reg) * P)

        return residuals
Exemplo n.º 6
0
 def apply_filter(self, sample_xf: TensorList) -> torch.Tensor:
     """Apply ``self.filter`` to a Fourier-domain sample.

     Multiplies ``self.filter`` pointwise (complex) with ``sample_xf`` and sums
     the product over dimension 1, keeping that dimension with size 1.
     """
     product = complex.mult(self.filter, sample_xf)
     return product.sum(1, keepdim=True)
Exemplo n.º 7
0
 def apply_filters(self, sample_xf: TensorList) -> TensorList:
     """Apply every filter in ``self.filters`` to a Fourier-domain sample.

     For each filter ``f``, ``f`` is multiplied pointwise (complex) with
     ``sample_xf`` and summed over dimension 1 (kept with size 1).

     Returns:
         A ``TensorList`` with one response per entry of ``self.filters``, in
         the same order. (The previous ``-> torch.Tensor`` annotation did not
         match this return value.)
     """
     return TensorList([
         complex.mult(f, sample_xf).sum(1, keepdim=True)
         for f in self.filters
     ])