Example #1
def test_takagi_factorization_real_diagonal(rank, dtype):
    real_a = torch.diag_embed(torch.rand((10, rank), dtype=dtype).real * 10)
    a = torch.complex(real_a, torch.zeros_like(real_a))

    eigenvalues, s = sm.takagi_eig(a)

    assert_sds_equals_a(s, eigenvalues, a)
    # the real part of the eigenvectors consists of vectors with a single 1 and zeros elsewhere
    real_part = torch.sum(torch.abs(s.real), dim=-1)
    np.testing.assert_allclose(torch.ones_like(real_part),
                               real_part,
                               atol=1e-5,
                               rtol=1e-5)
    # imaginary part of eigenvectors is all zeros
    np.testing.assert_allclose(torch.zeros(1),
                               torch.sum(s.imag),
                               atol=1e-5,
                               rtol=1e-5)
Example #2
File: utils.py  Project: huerd/torch-dreams
def fft_to_rgb(height, width, image_parameter, device = 'cuda'):
    """convert image param to NCHW 

    WARNING: torch v1.7.0 works differently from torch v1.8.0 on fft. 
    Hence you might find some weird workarounds in this function.

    Latest docs: https://pytorch.org/docs/stable/fft.html

    Also refer:
        https://github.com/pytorch/pytorch/issues/49637

    Args:
        height (int): height of image
        width (int): width of image 
        image_parameter (auto_image_param): instance of class auto_image_param()

    Returns:
        torch.tensor: NCHW tensor

    size log:
        before: 
            torch.Size([1, 3, height, width//2, 2]) 
            OR 
            torch.Size([1, 3, height, (width+1)//2, 2])

        after: 
            torch.Size([1, 3, height, width])

    """
    scale = get_fft_scale(height, width, device = device)

    t = scale * image_parameter.to(device)

    if torch.__version__[:3] == "1.7":
        t = torch.irfft(t, 2, normalized=True, signal_sizes=(height,width))
    elif  torch.__version__[:3] == '1.8':
        """
        hacky workaround to fix issues for the new torch.fft on torch 1.8.0

        """
        t = torch.complex(t[..., 0], t[..., 1])
        t = torch.fft.irfftn(t, s = (3, height, width), dim = (1,2,3), norm = 'ortho')

    return t
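For reference (separate from the snippet above): on torch >= 1.8 the old (..., 2) real/imag layout that torch.irfft consumed can also be reinterpreted with torch.view_as_complex, which pairs the two trailing slices exactly like the torch.complex call in the 1.8 branch. A minimal sketch with a placeholder stacked tensor t:

import torch

t = torch.randn(1, 3, 8, 5, 2)              # stacked real/imag layout
z1 = torch.complex(t[..., 0], t[..., 1])    # what the 1.8 branch above builds
z2 = torch.view_as_complex(t.contiguous())  # zero-copy reinterpretation of the same layout
assert torch.equal(z1, z2)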
Example #3
    def forward(self, tf_rep):
        """forward.

        Args:
            tf_rep (torch.Tensor): 4D tensor (multi-channel complex STFT of mixture)
                        of shape [B, T, C, F] batch, frames, microphones, frequencies.

        Returns:
            out (torch.Tensor): complex 3D tensor monaural STFT of the targets
                shape is [B, T, F] batch, frames, frequencies.

        """
        # [B, T, C, F] -> [B, C, F, T]
        tf_rep = tf_rep.permute(0, 2, 3, 1)
        bsz, mics, _, frames = tf_rep.shape
        assert mics == self.mic_channels

        inp_feats = torch.cat((tf_rep.real, tf_rep.imag), 1)
        inp_feats = inp_feats.transpose(-1, -2)
        inp_feats = inp_feats.reshape(bsz, self.mic_channels * 2, frames,
                                      self.in_channels)

        enc_out = []
        buffer = inp_feats
        for enc_layer in self.encoder:
            buffer = enc_layer(buffer)
            enc_out.append(buffer)

        assert buffer.shape[-1] == 1
        tcn_out = self.tcn(buffer.squeeze(-1)).unsqueeze(-1)

        buffer = tcn_out
        for indx, dec_layer in enumerate(self.decoder):
            c_input = torch.cat((buffer, enc_out[-(indx + 1)]), 1)
            buffer = dec_layer(c_input)

        buffer = buffer.reshape(bsz, 2, self.n_spk, -1, self.in_channels)

        if is_torch_1_9_plus:
            out = torch.complex(buffer[:, 0], buffer[:, 1])
        else:
            out = ComplexTensor(buffer[:, 0], buffer[:, 1])
        # bsz, complex_chans, frames or bsz, spk, complex_chans, frames
        return out  # bsz, spk, time, freq -> bsz, time, spk, freq
Example #4
def fft_to_rgb(height, width, image_parameter, device='cuda'):
    """convert image param to NCHW 

    WARNING: torch v1.7.0 works differently from torch v1.8.0 on fft. 
    torch-dreams supports ONLY 1.8.x 

    Latest docs: https://pytorch.org/docs/stable/fft.html

    Also refer:
        https://github.com/pytorch/pytorch/issues/49637

    Args:
        height (int): height of image
        width (int): width of image 
        image_parameter (auto_image_param): auto_image_param.param

    Returns:
        torch.tensor: NCHW tensor

    """
    scale = get_fft_scale(height, width,
                          device=device).to(image_parameter.device)
    # print(scale.shape, image_parameter.shape)
    if width % 2 == 1:
        image_parameter = image_parameter.reshape(1, 3, height,
                                                  (width + 1) // 2, 2)
    else:
        image_parameter = image_parameter.reshape(1, 3, height, width // 2, 2)

    image_parameter = torch.complex(image_parameter[..., 0],
                                    image_parameter[..., 1])
    t = scale * image_parameter

    version = torch.__version__.split('.')[:2]
    main_version = int(version[0])
    sub_version = int(version[1])

    if (main_version, sub_version) >= (1, 8):  ## torch.__version__ is 1.8 or newer
        t = torch.fft.irfft2(t, s=(height, width), norm='ortho')
    else:
        raise PytorchVersionError(version=torch.__version__)

    return t
Example #5
def test_dc_crn_separator_output():
    real = torch.rand(2, 10, 17)
    imag = torch.rand(2, 10, 17)
    x = torch.complex(real, imag) if is_torch_1_9_plus else ComplexTensor(
        real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = DC_CRNSeparator(
            input_dim=17,
            num_spk=num_spk,
            input_channels=[2, 2, 4],
        )
        model.eval()
        specs, _, others = model(x, x_lens)
        assert isinstance(specs, list)
        assert isinstance(others, dict)
        for n in range(num_spk):
            assert "mask_spk{}".format(n + 1) in others
            assert specs[n].shape == others["mask_spk{}".format(n + 1)].shape
Example #6
def complex(real, imag):
    """Return a 'complex' tensor
        - If `fft` module is present, returns a propert complex tensor
        - Otherwise, stack the real and imaginary compoenents along the last
        dimension.

    Parameters
    ----------
    real : tensor
    imag : tensor

    Returns
    -------
    complex : tensor

    """
    if _torch_has_complex:
        return torch.complex(real, imag)
    else:
        return torch.stack([real, imag], -1)
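A quick sketch of the two branches (the _torch_has_complex flag is the module's own guard and is only assumed here): with native complex support the result is a complex64 tensor, otherwise the two components are simply stacked, and view_as_real maps between the two representations.

import torch

real, imag = torch.randn(4), torch.randn(4)
z = torch.complex(real, imag)            # native path: shape (4,), dtype complex64
stacked = torch.stack([real, imag], -1)  # fallback path: shape (4, 2), dtype float32
assert torch.equal(torch.view_as_real(z), stacked)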
Example #7
def test_dc_crn_separator_multich_input(
    num_spk,
    input_channels,
    enc_kernel_size,
    enc_padding,
    enc_last_kernel_size,
    enc_last_stride,
    enc_last_padding,
    skip_last_kernel_size,
    skip_last_stride,
    skip_last_padding,
):
    model = DC_CRNSeparator(
        input_dim=33,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=2,
        enc_kernel_size=enc_kernel_size,
        enc_padding=enc_padding,
        enc_last_kernel_size=enc_last_kernel_size,
        enc_last_stride=enc_last_stride,
        enc_last_padding=enc_last_padding,
        enc_layers=3,
        skip_last_kernel_size=skip_last_kernel_size,
        skip_last_stride=skip_last_stride,
        skip_last_padding=skip_last_padding,
    )
    model.train()

    real = torch.rand(2, 10, input_channels[0] // 2, 33)
    imag = torch.rand(2, 10, input_channels[0] // 2, 33)
    x = torch.complex(real, imag) if is_torch_1_9_plus else ComplexTensor(
        real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    masked, flens, others = model(x, ilens=x_lens)

    assert is_complex(masked[0])
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()
Example #8
def proj_forward(sinogram):
    alpha = torch.linspace(0, 180, 361) * np.pi / 180 + np.pi / 2
    a = torch.linspace(0, 183, 184)
    b = torch.linspace(-183, -1, 183)
    f = torch.cat((a, b)) / sinogram.shape[0]
    f = torch.unsqueeze(f, dim=1)
    #ramp filter
    fourier_filter = 2 * torch.abs(f)
    fourier_filter_ = fourier_filter.expand(367, 361).unsqueeze(-1)
    fourier_filter_ = torch.transpose(fourier_filter_, 0, 1)
    fourier_filter_ = torch.cat((fourier_filter_, fourier_filter_), -1)

    #     projection = torch.rfft(sinogram, 2, onesided=False).double() * fourier_filter_.double()
    #     proj_ifft = torch.irfft(projection, 2, onesided=False).float()
    output_fft_new = torch.fft.fft2(sinogram, dim=(-2, -1))
    projection = torch.stack((output_fft_new.real, output_fft_new.imag),
                             -1).double() * fourier_filter_.double()
    proj_ifft = torch.fft.ifft2(torch.complex(projection[..., 0],
                                              projection[..., 1]),
                                dim=(-2, -1)).float()

    proj_ifft = proj_ifft.contiguous()

    fbp_host = np.zeros((256, 256))
    fbp_dev = cuda.to_device(fbp_host)
    alpha_dev = cuda.to_device(alpha)
    proj_ifft_dev = cuda.to_device(proj_ifft)
    TPB = 16
    threadperblock = (TPB, TPB)
    blockpergrid_x = int(math.ceil(fbp_dev.shape[0] / threadperblock[0]))
    blockpergrid_y = int(math.ceil(fbp_dev.shape[1] / threadperblock[1]))
    blockpergrid = (blockpergrid_x, blockpergrid_y)

    func[blockpergrid, threadperblock](alpha_dev, proj_ifft_dev, fbp_dev)
    cuda.synchronize()

    im = fbp_dev.copy_to_host()
    im = im / 0.06

    return im
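The stack/recombine sequence above is equivalent to multiplying the complex FFT directly by the real-valued ramp filter; a small self-contained check with placeholder shapes and a placeholder filter (not the function's actual values):

import torch

sino = torch.randn(367, 361)
filt = torch.rand(361)                                 # real-valued filter per column (placeholder)
spec = torch.fft.fft2(sino)
direct = torch.fft.ifft2(spec * filt).real
stacked = torch.stack((spec.real, spec.imag), -1) * filt.unsqueeze(-1)
via_stack = torch.fft.ifft2(torch.complex(stacked[..., 0], stacked[..., 1])).real
assert torch.allclose(direct, via_stack, atol=1e-5)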
Example #9
    def check_children_attributes(self):
        if self.stft is None:
            try:
                n_fft = self.audio_to_melspec_precessor.n_fft
                hop_length = self.audio_to_melspec_precessor.hop_length
                win_length = self.audio_to_melspec_precessor.win_length
                window = self.audio_to_melspec_precessor.window.to(self.device)
            except AttributeError as e:
                raise AttributeError(
                    f"{self} could not find a valid audio_to_melspec_precessor. GlowVocoder requires child class "
                    "to have audio_to_melspec_precessor defined to obtain stft parameters. "
                    "audio_to_melspec_precessor requires n_fft, hop_length, win_length, window, and nfilt to be "
                    "defined."
                ) from e

            def yet_another_patch(audio, n_fft, hop_length, win_length, window):
                spec = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window)
                if spec.dtype in [torch.cfloat, torch.cdouble]:
                    spec = torch.view_as_real(spec)
                return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])

            self.stft = lambda x: yet_another_patch(
                x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
            )
            self.istft = lambda x, y: torch.istft(
                torch.complex(x * torch.cos(y), x * torch.sin(y)),
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
            )

        if self.n_mel is None:
            try:
                self.n_mel = self.audio_to_melspec_precessor.nfilt
            except AttributeError as e:
                raise AttributeError(
                    f"{self} could not find a valid audio_to_melspec_precessor. GlowVocoder requires child class to "
                    "have audio_to_melspec_precessor defined to obtain stft parameters. audio_to_melspec_precessor "
                    "requires nfilt to be defined."
                ) from e
Example #10
def FDA_source_to_target(src_img, trg_img, beta=1e-2):
    # exchange magnitude
    # input: src_img, trg_img

    # get fft of both source and target
    fft_src = torch.fft.fftn(src_img.clone(), dim=(2, 3)) # check if fft2 is enough
    fft_trg = torch.fft.fftn(trg_img.clone(), dim=(2, 3))

    assert fft_src.dtype == torch.complex64, fft_src.dtype
    assert fft_trg.dtype == torch.complex64, fft_trg.dtype
    assert fft_src.shape == (BATCH_SIZE_CAMVID, 3, 720, 960), fft_src.shape
    assert fft_trg.shape == (BATCH_SIZE_CAMVID, 3, 720, 960), fft_trg.shape

    # extract amplitude and phase of both ffts
    amp_src, pha_src = fft_src.abs(), fft_src.angle()
    amp_trg, pha_trg = fft_trg.abs(), fft_trg.angle()

    assert amp_src.dtype == torch.float32, f"assertion failure {amp_src.dtype}"
    assert amp_trg.dtype == torch.float32, f"assertion failure {amp_trg.dtype}"
    assert amp_src.shape == (BATCH_SIZE_CAMVID, 3, 720, 960), f"assertion failure {amp_src.shape}"
    assert amp_trg.shape == (BATCH_SIZE_CAMVID, 3, 720, 960), f"assertion failure {amp_trg.shape}"

    # replace the low frequency amplitude part of source with that from target
    amp_src_ = low_freq_mutate(amp_src.clone(), amp_trg.clone(), beta=beta)

    assert amp_src_.dtype == torch.float32, f"assertion failure {amp_src_.dtype}"
    assert amp_src_.shape == (BATCH_SIZE_CAMVID, 3, 720, 960), f"assertion failure {amp_src_.shape}"

    # recompose fft of source
    fft_src_real = torch.cos(pha_src.clone()) * amp_src_.clone()
    fft_src_imag = torch.sin(pha_src.clone()) * amp_src_.clone()
    fft_src_ = torch.complex(fft_src_real, fft_src_imag)
    assert fft_src_.shape == (BATCH_SIZE_CAMVID, 3, 720, 960), f"assertion failure {fft_src_.shape}"
  
    # get the recomposed image: source content, target style
    _, _, imgH, imgW = src_img.size()
    src_in_trg = torch.fft.ifftn(fft_src_, dim=(2, 3))
    assert src_in_trg.shape == (BATCH_SIZE_CAMVID, 3, 720, 960), f"assertion failure {src_in_trg.shape}"

    return src_in_trg
Example #11
def double_phase_amplitude_coding(target_phase,
                                  target_amp,
                                  prop_dist,
                                  wavelength,
                                  feature_size,
                                  prop_model='ASM',
                                  propagator=None,
                                  dtype=torch.float32,
                                  precomputed_H=None):
    """
    Uses a single propagation and converts amplitude and phase to double phase coding.

    Input
    -----
    :param target_phase: The phase at the target image plane
    :param target_amp: A tensor, (B,C,H,W), the amplitude at the target image plane.
    :param prop_dist: propagation distance, in m.
    :param wavelength: wavelength, in m.
    :param feature_size: The SLM pixel pitch, in meters.
    :param prop_model: The light propagation model to use for prop from target plane to slm plane
    :param propagator: propagation_ASM
    :param dtype: torch datatype for computation at different precision.
    :param precomputed_H: pre-computed kernel - to make it faster over multiple iteration/images - calculate it once

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    """
    real, imag = utils.polar_to_rect(target_amp, target_phase)
    target_field = torch.complex(real, imag)

    slm_field = utils.propagate_field(target_field, propagator, prop_dist,
                                      wavelength, feature_size, prop_model,
                                      dtype, precomputed_H)

    slm_phase = double_phase(slm_field, three_pi=False, mean_adjust=True)

    return slm_phase
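utils.polar_to_rect is an external helper of this codebase; judging from how it is used here and from the inline version in Example #25 below, it presumably reduces to the cosine/sine recomposition sketched here (an assumption, not the project's actual implementation):

import torch

def polar_to_rect(amp, phase):
    # assumed behaviour: magnitude/phase -> real and imaginary parts
    return amp * torch.cos(phase), amp * torch.sin(phase)

amp, phase = torch.rand(1, 1, 4, 4), torch.rand(1, 1, 4, 4)
real, imag = polar_to_rect(amp, phase)
target_field = torch.complex(real, imag)  # same construction as in the function above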
Example #12
    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)

        r_size = x.size()
        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        if self.spectral_pos_encoding:
            height, width = ffted.shape[-2:]
            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

        if self.use_se:
            ffted = self.se(ffted)

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        return output
Example #13
def combine_zernike_basis(coeffs, basis, return_phase=False):
    """
    Multiplies the Zernike coefficients and basis functions while preserving
    dimensions

    :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike
    :param basis: the output of compute_zernike_basis, must be same length as coeffs
    :param return_phase:
    :return: A Complex64 tensor that combines coeffs and basis.
    """

    if len(coeffs.shape) < 3:
        coeffs = torch.reshape(coeffs, (coeffs.shape[0], 1, 1))

    # combine zernike basis and coefficients
    zernike = (coeffs * basis).sum(0, keepdim=True)

    # shape to [1, len(coeffs), H, W]
    zernike = zernike.unsqueeze(0)

    # convert to Pytorch Complex tensor
    real, imag = utils.polar_to_rect(torch.ones_like(zernike), zernike)
    return torch.complex(real, imag)
Example #14
    def transform(self, inputs, return_type='complex'):
        """Take input data (audio) to STFT domain.

        Args:
            inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
            return_type (str, optional): return (mag, phase) when `magphase`,
            return (real, imag) when `realimag` and complex(real, imag) when `complex`.
            Defaults to 'complex'.

        Returns:
            tuple: (mag, phase) when `magphase`, (real, imag) when `realimag`,
            or a complex tensor when `complex` (the default). Each element has
            shape [num_batch, num_frequencies, num_frames].
        """
        assert return_type in ['magphase', 'realimag', 'complex']
        if inputs.dim() == 2:
            inputs = th.unsqueeze(inputs, 1)
        self.num_samples = inputs.size(-1)
        if self.pad_center:
            inputs = F.pad(inputs, (self.pad_amount, self.pad_amount),
                           mode='reflect')
        enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
        outputs = th.transpose(enframe_inputs, 1, 2)
        outputs = F.linear(outputs, self.fft_k)
        outputs = th.transpose(outputs, 1, 2)
        dim = self.fft_len // 2 + 1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        if return_type == 'realimag':
            return real, imag
        elif return_type == 'complex':
            assert support_clp_op
            return th.complex(real, imag)
        else:
            mags = th.sqrt(real**2 + imag**2)
            phase = th.atan2(imag, real)
            return mags, phase
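For comparison only: torch.stft with return_complex=True yields the same kind of [num_batch, num_frequencies, num_frames] complex output as the 'complex' branch above. This sketch does not reproduce the module's conv1d kernels, padding, or window, so the parameters below are assumptions:

import torch

wave = torch.randn(2, 16000)                       # (num_batch, num_samples)
spec = torch.stft(wave, n_fft=512, hop_length=256,
                  window=torch.hann_window(512),
                  center=True, return_complex=True)
print(spec.shape, spec.dtype)                      # torch.Size([2, 257, 63]) torch.complex64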
Example #15
def upper_half_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.UpperHalf]

    # x is the result of projecting ex
    ex = torch.randn(*shape, dtype=torch.complex128)
    ex = geoopt.linalg.batch_linalg.sym(ex)
    x = ex.clone()
    x.imag = geoopt.manifolds.siegel.csym_math.positive_conjugate_projection(x.imag)

    # ev is in the tangent space
    ev = torch.randn(*shape, dtype=torch.complex128) / 10
    ev = geoopt.linalg.batch_linalg.sym(ev)

    # v is the result of projecting ev at x
    real_ev, imag_ev = ev.real, ev.imag
    real_v = x.imag @ real_ev @ x.imag
    imag_v = x.imag @ imag_ev @ x.imag
    v = torch.complex(real_v, imag_v)

    manifold = geoopt.UpperHalf()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
Example #16
def fft_image(shape, sd=None, decay_power=1):
    """An image paramaterization using 2D Fourier coefficients."""
    sd = sd or 0.01
    # batch, h, w, ch = shape # tf style: [N, H, W, C]
    batch, ch, h, w = shape  # torch style: [N, C, H, W]

    freqs = rfft2d_freqs(h, w)
    init_val_size = (2, batch, ch) + freqs.shape

    init_val = np.random.normal(size=init_val_size,
                                scale=sd).astype(np.float32)
    # spectrum_real_imag_t = tf.Variable(init_val)
    spectrum_real_imag_t = Variable(torch.from_numpy(init_val))

    # spectrum_t = tf.complex(spectrum_real_imag_t[0], spectrum_real_imag_t[1])
    spectrum_t = torch.complex(spectrum_real_imag_t[0],
                               spectrum_real_imag_t[1])

    # Scale the spectrum. First normalize energy, then scale by the square-root
    # of the number of pixels to get a unitary transformation.
    # This allows using learning rates similar to pixel-wise optimisation.
    scale = 1.0 / np.maximum(freqs, 1.0 / max(w, h))**decay_power
    scale *= np.sqrt(w * h)
    print(scale.shape, spectrum_t.shape)
    scaled_spectrum_t = torch.from_numpy(scale) * spectrum_t

    # convert complex scaled spectrum to shape (h, w, ch) image tensor
    # needs to transpose because irfft2d returns channels first
    # image_t = tf.transpose(tf.spectral.irfft2d(scaled_spectrum_t), (0, 2, 3, 1))
    image_t = torch.fft.irfft2(scaled_spectrum_t, s=(h, w))  # shape: [N, C, H, W]

    # in case of odd spatial input dimensions we need to crop
    # image_t = image_t[:batch, :h, :w, :ch] # tf style
    image_t = image_t[:batch, :ch, :h, :w]  # torch style
    image_t = image_t / 4.0  # TODO: is that a magic constant?
    return image_t
Example #17
def vector_to_Hermitian(vec, use_builtin_complex=False):
    """Construct a Hermitian matrix from a vector of N**2 independent
    real-valued elements.

    Args:
        vec (torch.Tensor): (..., N ** 2)
        use_builtin_complex (bool): Whether to use builtin complex support
    Returns:
        mat (torch.complex64/ComplexTensor): (..., N, N)
    """  # noqa: H405, D205, D400
    N = int(np.sqrt(vec.shape[-1]))
    mat = torch.zeros(size=vec.shape[:-1] + (N, N, 2), device=vec.device)

    # real component
    triu = np.triu_indices(N, 0)
    triu2 = np.triu_indices(N, 1)  # above main diagonal
    tril = (triu2[1], triu2[0])  # below main diagonal; for symmetry
    mat[(..., ) + triu +
        (np.zeros(triu[0].shape[0]), )] = vec[..., :triu[0].shape[0]]
    start = triu[0].shape[0]
    mat[(..., ) + tril +
        (np.zeros(tril[0].shape[0]), )] = mat[(..., ) + triu2 +
                                              (np.zeros(triu2[0].shape[0]), )]

    # imaginary component
    mat[(..., ) + triu2 +
        (np.ones(triu2[0].shape[0]), )] = vec[...,
                                              start:start + triu2[0].shape[0]]
    mat[(..., ) + tril +
        (np.ones(tril[0].shape[0]), )] = -mat[(..., ) + triu2 +
                                              (np.ones(triu2[0].shape[0]), )]

    if is_torch_1_9_plus and use_builtin_complex:
        return torch.complex(mat[..., 0], mat[..., 1])
    else:
        return ComplexTensor(mat[..., 0], mat[..., 1])
Example #18
def test_gev_phase_correction():
    mat = ComplexTensor(torch.rand(2, 3, 4), torch.rand(2, 3, 4))
    mat_th = torch.complex(mat.real, mat.imag)
    norm = gev_phase_correction(mat)
    norm_th = gev_phase_correction(mat_th)
    assert np.allclose(norm.numpy(), norm_th.numpy())
Example #19
File: utils.py  Project: cxdcxd/pykeen
def view_complex(x: torch.FloatTensor) -> torch.Tensor:
    """Convert a PyKEEN complex tensor representation into a torch one."""
    real, imag = split_complex(x=x)
    return torch.complex(real=real, imag=imag)
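split_complex is PyKEEN's own helper; for illustration, it presumably splits the last dimension into real and imaginary halves, roughly as in this hypothetical sketch (not PyKEEN's actual code):

import torch

x = torch.randn(2, 8)                # flat real representation, last dim = 2 * embedding_dim
real, imag = x[..., :4], x[..., 4:]  # assumed split_complex behaviour
z = torch.complex(real, imag)        # (2, 4) complex64 tensor, as view_complex returns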
Example #20
    def forward(self, phases, skip_lut=False, skip_tm=False):

        # Section 5.1.3. Modeling Phase Nonlinearity
        if self.process_phase is not None and not skip_lut:
            if self.latent_code is not None:
                # support mini-batch
                processed_phase = self.process_phase(phases, self.latent_code.repeat(phases.shape[0], 1, 1, 1))
            else:
                processed_phase = self.process_phase(phases)
        else:
            processed_phase = phases

        # Section 5.1.1. Create Source Amplitude (DC + gaussians)
        if self.source_amp is not None:
            source_amp = self.source_amp(processed_phase)
        else:
            source_amp = torch.ones_like(processed_phase)

        # convert phase to real and imaginary
        real, imag = utils.polar_to_rect(source_amp, processed_phase)
        processed_complex = torch.complex(real, imag)

        # Section 5.1.2. precompute the zernike basis only once
        if self.zernike is None and self.coeffs is not None:
            self.zernike = compute_zernike_basis(self.coeffs.size()[0],
                                                 phases.size()[-2:], wo_piston=True)
            self.zernike = self.zernike.to(self.dev).detach()
            self.zernike.requires_grad = False

        if self.zernike_fourier is None and self.coeffs_fourier is not None:
            self.zernike_fourier = compute_zernike_basis(self.coeffs_fourier.size()[0],
                                                         [i * 2 for i in phases.size()[-2:]],
                                                         wo_piston=True)
            self.zernike_fourier = self.zernike_fourier.to(self.dev).detach()
            self.zernike_fourier.requires_grad = False

        if not self.training and self.zernike_eval is None and self.coeffs is not None:
            # sum the phases
            self.zernike_eval = combine_zernike_basis(self.coeffs, self.zernike)
            self.zernike_eval = self.zernike_eval.to(self.coeffs.device).detach()
            self.zernike_eval.requires_grad = False

        if not self.training and self.zernike_eval_fourier is None and self.coeffs_fourier is not None:
            # sum the phases
            self.zernike_eval_fourier = combine_zernike_basis(self.coeffs_fourier, self.zernike_fourier)
            self.zernike_eval_fourier = utils.ifftshift(self.zernike_eval_fourier)
            self.zernike_eval_fourier = self.zernike_eval_fourier.to(self.coeffs_fourier.device).detach()
            self.zernike_eval_fourier.requires_grad = False

        # precompute the kernel only once
        if self.learn_dist and self.training:
            self.precompute_H_exp(processed_complex)
        else:
            self.precompute_H(processed_complex)

        # Section 5.1.2. apply zernike and propagate
        if self.training:
            if self.coeffs_fourier is None:
                output_complex = self.prop_zernike(processed_complex,
                                                   self.feature_size,
                                                   self.wavelength,
                                                   self.distance,
                                                   coeffs=self.coeffs,
                                                   zernike=self.zernike,
                                                   precomped_H=self.precomped_H,
                                                   precomped_H_exp=self.precomped_H_exp,
                                                   linear_conv=self.linear_conv)
            else:
                output_complex = self.prop_zernike_fourier(processed_complex,
                                                           self.feature_size,
                                                           self.wavelength,
                                                           self.distance,
                                                           coeffs=self.coeffs_fourier,
                                                           zernike=self.zernike_fourier,
                                                           precomped_H=self.precomped_H,
                                                           precomped_H_exp=self.precomped_H_exp,
                                                           linear_conv=self.linear_conv)

        else:
            if self.coeffs is not None:
                # in primal domain
                processed_zernike = self.zernike_eval * processed_complex
            else:
                processed_zernike = processed_complex

            if self.coeffs_fourier is not None:
                # in fourier domain
                precomped_H = self.zernike_eval_fourier * self.precomped_H
            else:
                precomped_H = self.precomped_H

            output_complex = self.prop(processed_zernike,
                                       self.feature_size,
                                       self.wavelength,
                                       self.distance,
                                       precomped_H=precomped_H,
                                       linear_conv=self.linear_conv)

        # Section 5.1.1. Content-independent field at target plane
        if self.target_constant_amp is not None:
            real, imag = utils.polar_to_rect(self.target_constant_amp, self.target_constant_phase)
            target_field = torch.complex(real, imag)
            output_complex = output_complex + target_field

        # Section 5.1.4. Content-dependent Undiffracted light
        if self.content_dependent_field is not None:
            if self.latent_coords is not None:
                cdf = self.content_dependent_field(phases, self.latent_coords.repeat(phases.shape[0], 1, 1, 1))
            else:
                cdf = self.content_dependent_field(phases)
            real, imag = utils.polar_to_rect(cdf[..., 0], cdf[..., 1])
            cdf_rect = torch.complex(real, imag)
            output_complex = output_complex + cdf_rect

        amp = output_complex.abs()
        _, phase = utils.rect_to_polar(output_complex.real, output_complex.imag)

        if self.blur is not None:
            amp = self.blur(amp)

        real, imag = utils.polar_to_rect(amp, phase)

        return torch.complex(real, imag)
Example #21
    def __call__(self, inputs):
        """
        Args:
            inputs: shape: [..., T], T is #samples
            num_samples, list or tensor of #samples

        >>> mixture = torch.rand((2, 6, 203))
        >>> torch_stft = STFT(512, 20, window_length=40,\
                              complex_representation='concat')
        >>> torch_stft_out = torch_stft(mixture)
        >>> torch_stft_out.shape
        torch.Size([2, 6, 12, 514])
        >>> from paderbox.transform import stft
        >>> stft_out = stft(mixture.numpy(), 512, 20, window_length=40)
        >>> stft_signal = np.concatenate(\
                [np.real(stft_out), np.imag(stft_out)], axis=-1)
        >>> np.testing.assert_allclose(torch_stft_out, stft_signal, atol=1e-5)
        >>> mixture = torch.rand((2, 6, 203))
        >>> torch_stft = STFT(512, 20, window_length=40,\
                              complex_representation='complex')
        >>> torch_stft_out = torch_stft(mixture)
        >>> torch_stft_out.shape
        torch.Size([2, 6, 12, 257])
        >>> from paderbox.transform import stft
        >>> stft_out = stft(mixture.numpy(), 512, 20, window_length=40)
        >>> np.testing.assert_allclose(torch_stft_out.numpy(), stft_out, atol=1e-5)

        """
        org_shape = inputs.shape
        stride = self.shift
        length = self.window_length
        x = inputs.view((-1, org_shape[-1]))
        # Pad with zeros to have enough samples for the window function to fade.
        assert self.fading in [None, True, False, 'full', 'half'], self.fading
        if self.fading not in [False, None]:
            if self.fading == 'half':
                pad_width = (
                    (self.window_length - self.shift) // 2,
                    ceil((self.window_length - self.shift) / 2)
                )
            else:
                pad_width = self.window_length - self.shift
                pad_width = (pad_width, pad_width)
            x = F.pad(x, pad_width, mode='constant')

        if self.pad:
            if x.shape[-1] < length:
                pad_size = length - x.shape[-1]
                x = F.pad(x, (0, pad_size))
            elif stride != 1 and (x.shape[-1] + stride - length) % stride != 0:
                pad_size = stride - ((x.shape[-1] + stride - length) % stride)
                x = F.pad(x, (0, pad_size))

        x = torch.unsqueeze(x, 1) # [..., 1, T]
        weights = self.stft_kernel.to(x)
        encoded = F.conv1d(x, weight=weights, stride=stride)

        encoded = encoded.view(*org_shape[:-1], *encoded.shape[-2:])
        encoded = rearrange(encoded, '... feat frames -> ... frames feat')
        encoded = torch.chunk(encoded, 2, dim=-1)
        if self.complex_representation == 'stacked':
            encoded = torch.stack(encoded, dim=-1)
        elif self.complex_representation == 'concat':
            encoded = torch.cat(encoded, dim=-1)
        elif self.complex_representation == 'complex':
            encoded = torch.complex(*encoded)
        else:
            raise ValueError(
                f'Please choose one of the predefined output_types '
                f'{self.possible_out_types}, not {self.complex_representation}'
            )
        return encoded
Example #22
# torch.quantize_per_tensor
reveal_type(
    torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10,
                              torch.quint8))  # E: torch.tensor.Tensor

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                   torch.tensor([10, 0]), 0, torch.quint8)
reveal_type(x)  # E: torch.tensor.Tensor

# torch.dequantize
reveal_type(torch.dequantize(x))  # E: torch.tensor.Tensor

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
reveal_type(torch.complex(real, imag))  # E: torch.tensor.Tensor

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
reveal_type(torch.polar(abs, angle))  # E: torch.tensor.Tensor

# torch.heaviside
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
reveal_type(torch.heaviside(inp, values))  # E: torch.tensor.Tensor
Example #23
    def forward(self, input):
        return torch.complex(self.activation(input.real),
                             self.activation(input.imag))
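The module applies its activation to the real and imaginary parts independently. A minimal standalone version of the same idea, with torch.relu standing in for self.activation:

import torch

z = torch.complex(torch.randn(3), torch.randn(3))
out = torch.complex(torch.relu(z.real), torch.relu(z.imag))  # split complex activation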
Example #24
# torch.quantize_per_tensor
reveal_type(
    torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10,
                              torch.quint8))  # E: {Tensor}

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                   torch.tensor([10, 0]), 0, torch.quint8)
reveal_type(x)  # E: {Tensor}

# torch.dequantize
reveal_type(torch.dequantize(x))  # E: {Tensor}

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
reveal_type(torch.complex(real, imag))  # E: {Tensor}

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
reveal_type(torch.polar(abs, angle))  # E: {Tensor}

# torch.heaviside
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
reveal_type(torch.heaviside(inp, values))  # E: {Tensor}
Example #25
    def polar_to_rect(self, amp, phase):
        """from neural holo"""
        self.u = torch.complex(amp * torch.cos(phase),
                               amp * torch.sin(phase)).broadcast_to(self.shape)
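torch.polar builds the same complex tensor directly from magnitude and phase, so the cos/sin recomposition above can be checked against it (amp and phase here are placeholder tensors):

import math
import torch

amp = torch.rand(4, 4)
phase = torch.rand(4, 4) * 2 * math.pi
u1 = torch.complex(amp * torch.cos(phase), amp * torch.sin(phase))
u2 = torch.polar(amp, phase)
assert torch.allclose(u1, u2)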
Example #26
    def forward(self, target_amp):
        # compute some initial phase, convert to real+imag representation
        if self.initial_phase is not None:
            init_phase = self.initial_phase(target_amp)
            real, imag = utils.polar_to_rect(target_amp, init_phase)
            target_complex = torch.complex(real, imag)
        else:
            init_phase = torch.zeros_like(target_amp)
            # no need to convert, zero phase implies amplitude = real part
            target_complex = torch.complex(target_amp, init_phase)

        # subtract the additional target field
        if self.target_field is not None:
            target_complex_diff = target_complex - self.target_field
        else:
            target_complex_diff = target_complex

        # precompute the propagation kernel only once
        if self.precomped_H is None:
            self.precomped_H = self.prop(target_complex_diff,
                                         self.feature_size,
                                         self.wavelength,
                                         self.distance,
                                         return_H=True,
                                         linear_conv=self.linear_conv)
            self.precomped_H = self.precomped_H.to(self.dev).detach()
            self.precomped_H.requires_grad = False

        if self.precomped_H_zernike is None:
            if self.zernike is None and self.zernike_coeffs is not None:
                self.zernike_basis = compute_zernike_basis(self.zernike_coeffs.size()[0],
                                                           [i * 2 for i in target_amp.size()[-2:]], wo_piston=True)
                self.zernike_basis = self.zernike_basis.to(self.dev).detach()
                self.zernike = combine_zernike_basis(self.zernike_coeffs, self.zernike_basis)
                self.zernike = utils.ifftshift(self.zernike)
                self.zernike = self.zernike.to(self.dev).detach()
                self.zernike.requires_grad = False
                self.precomped_H_zernike = self.zernike * self.precomped_H
                self.precomped_H_zernike = self.precomped_H_zernike.to(self.dev).detach()
                self.precomped_H_zernike.requires_grad = False
            else:
                self.precomped_H_zernike = self.precomped_H

        # precompute the source amplitude, only once
        if self.source_amp is None and self.source_amplitude is not None:
            self.source_amp = self.source_amplitude(target_amp)
            self.source_amp = self.source_amp.to(self.dev).detach()
            self.source_amp.requires_grad = False

        # implement the basic propagation to the SLM plane
        slm_naive = self.prop(target_complex_diff, self.feature_size,
                              self.wavelength, self.distance,
                              precomped_H=self.precomped_H_zernike,
                              linear_conv=self.linear_conv)

        # switch to amplitude+phase and apply source amplitude adjustment
        amp, ang = utils.rect_to_polar(slm_naive.real, slm_naive.imag)
        # amp, ang = slm_naive.abs(), slm_naive.angle()  # PyTorch 1.7.0 Complex tensor doesn't support
                                                         # the gradient of angle() currently.

        if self.source_amp is not None and self.manual_aberr_corr:
            amp = amp / self.source_amp

        if self.final_phase_only is None:
            return amp, double_phase(amp, ang, three_pi=False)
        else:
            # note the change to usual complex number stacking!
            # We're making this the channel dim via cat instead of stack
            if (self.zernike is None and self.source_amp is None
                    or self.manual_aberr_corr):
                if self.latent_codes is not None:
                    slm_amp_phase = torch.cat((amp, ang, self.latent_codes.repeat(amp.shape[0], 1, 1, 1)), -3)
                else:
                    slm_amp_phase = torch.cat((amp, ang), -3)
            elif self.zernike is None:
                slm_amp_phase = torch.cat((amp, ang, self.source_amp), -3)
            elif self.source_amp is None:
                slm_amp_phase = torch.cat((amp, ang, self.zernike), -3)
            else:
                slm_amp_phase = torch.cat((amp, ang, self.zernike,
                                           self.source_amp), -3)
            return amp, self.final_phase_only(slm_amp_phase)
Example #27
def fourier_interpolate_torch(ain,
                              shapeout,
                              norm="conserve_val",
                              N=None,
                              qspace_in=False,
                              qspace_out=False):
    """
    Fourier interpolation of array ain to shape shapeout.

    If shapeout is smaller than ain.shape then Fourier downsampling is
    performed

    Parameters
    ----------
    ain : (...,Nn,..,Ny,Nx) torch.tensor
        Input array
    shapeout : (n,) array_like
        Shape of output array
    norm : str, optional  {'conserve_val','conserve_norm','conserve_L1'}
        Normalization of output. If 'conserve_val', array values are preserved;
        if 'conserve_norm', the L2 norm is conserved under interpolation; and if
        'conserve_L1', the L1 norm is conserved under interpolation.
    N : int, optional
        Number of (trailing) dimensions to Fourier interpolate. By default take
        the length of shapeout
    qspace_in : bool, optional
        If True expect a Fourier space input, otherwise (default) expect a
        real space input
    qspace_out : bool, optional
        If True return a Fourier space output, otherwise (default) return in
        real space
    """
    dtype = ain.dtype

    if N is None:
        N = len(shapeout)

    inputComplex = iscomplex(ain)

    # Get input dimensions
    shapein = ain.size()[-N:]

    # axes to Fourier transform
    axes = np.arange(-N, 0).tolist()

    # Now transfer over Fourier coefficients from input to output array
    if inputComplex:
        ain_ = ain
    else:
        ain_ = torch.complex(
            ain, torch.zeros(ain.shape, dtype=ain.dtype, device=ain.device))

    if not qspace_in:
        ain_ = torch.fft.fftn(ain_, dim=axes)

    aout = torch.fft.ifftshift(crop_torch(torch.fft.fftshift(ain_, dim=axes),
                                          shapeout),
                               dim=axes)

    # Fourier transform result with appropriate normalization
    if norm == "conserve_val":
        aout *= np.prod(shapeout) / np.prod(np.shape(ain)[-N:])
    elif norm == "conserve_norm":
        aout *= np.sqrt(np.prod(shapeout) / np.prod(np.shape(ain)[-N:]))

    if not qspace_out:
        aout = torch.fft.ifftn(aout, dim=axes)

    # Return correct array data type
    if inputComplex:
        return aout
    return torch.real(aout)
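A hypothetical call, assuming the helpers iscomplex and crop_torch from the same module are in scope; the trailing two dimensions are Fourier downsampled:

import torch

img = torch.randn(1, 64, 64)
down = fourier_interpolate_torch(img, (32, 32))  # real input -> real output
print(down.shape)                                # expected: torch.Size([1, 32, 32])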
Example #28
    # w = fourier.ifft_shift(w)
    # x = torch.fft.ifftn(w, dim=(-2, -1)).real + 0.5 

    # torchvision.utils.save_image(x, "logs/new_fft.png")

    basis_list = list()
    begin = int(-np.floor(grid_size / 2))
    end = int(np.ceil(grid_size / 2))
    for i_h in range(begin, end):
        for i_w in range(begin, end):
            height, width = image_size, image_size
            h_center_index = height // 2
            w_center_index = width // 2

            w = torch.zeros((3, height, width), dtype=torch.cfloat)
            w[:, h_center_index + i_h, w_center_index + i_w] = torch.complex(torch.tensor([1.0]), torch.tensor([1.0]))
            w[:, h_center_index - i_h, w_center_index - i_w] = torch.complex(torch.tensor([1.0]), torch.tensor([1.0]))

            # w = fourier.ifft_shift(w)

            w_np = w.real.numpy()
            w_np = np.fft.ifftshift(w_np)
            # fourier_basis = torch.fft.ifftn(w, dim=(-2, -1))
            fourier_basis = torch.from_numpy(np.fft.ifft2(w_np).real).float()

            # fourier_basis = torch.fft.ifftn(w, dim=(-2, -1))
            # fourier_basis = (fourier_basis.real + fourier_basis.imag) / 2.0
            fourier_basis[0, :, :] /= fourier_basis[0].norm()
            fourier_basis[1, :, :] /= fourier_basis[1].norm()
            fourier_basis[2, :, :] /= fourier_basis[2].norm()
Example #29
    def __init__(self):
        super().__init__()

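        # the third positional argument is persistent=False, so the complex buffer
        # is not included in the module's state_dict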
        self.register_buffer("complex_buffer",
                             torch.complex(torch.rand(10), torch.rand(10)),
                             False)
Example #30
                                wavelength=wavelength,
                                blur=blur).to(device)
    model_prop.load_state_dict(torch.load(opt.model_path, map_location=device))

    # Here, we crop model parameters to match the Holonet resolution, which is slightly different from 1080p.
    # parameters for CITL model
    zernike_coeffs = model_prop.coeffs_fourier
    source_amplitude = model_prop.source_amp
    latent_codes = model_prop.latent_code
    latent_codes = utils.crop_image(latent_codes, target_shape=image_res, pytorch=True, stacked_complex=False)

    # content independent target field (See Section 5.1.1.)
    u_t_amp = utils.crop_image(model_prop.target_constant_amp, target_shape=image_res, stacked_complex=False)
    u_t_phase = utils.crop_image(model_prop.target_constant_phase, target_shape=image_res, stacked_complex=False)
    real, imag = utils.polar_to_rect(u_t_amp, u_t_phase)
    u_t = torch.complex(real, imag)

    # match the shape of forward model parameters (1072, 1920)

    # If you make it nn.Parameter, the Holonet will include these parameters in the torch.save files
    model_prop.latent_code = nn.Parameter(latent_codes.detach(), requires_grad=False)
    model_prop.target_constant_amp = nn.Parameter(u_t_amp.detach(), requires_grad=False)
    model_prop.target_constant_phase = nn.Parameter(u_t_phase.detach(), requires_grad=False)

    # But as these parameters are already in the CITL-calibrated models,
    # you can load these from those models avoiding duplicated saves.

model_prop.eval()  # ensure freezing propagation model

# create new phase generator object
if opt.purely_unet: