Пример #1
0
def angle(input_, deg=False):
    """Element-wise argument of a tensor, via `torch.angle`.

    Parameters
    ----------
    input_ : DTensor
        Input dense tensor.
    deg : bool, optional
        When truthy the result is converted to degrees; otherwise it is
        returned in radians. Defaults to False.

    Returns
    -------
    The angle of every element of ``input_``.
    """
    radians = torch.angle(input_)
    if not deg:
        return radians
    # Same arithmetic order as a plain "* 180 / pi" conversion.
    return radians * 180 / math.pi
Пример #2
0
def phase_comp(psi_comp, uwrap=False, dens=None):
    """Compute the phase (angle) of a single complex wavefunction component.

    Parameters
    ----------
    psi_comp : NumPy :obj:`array` or PyTorch :obj:`Tensor`
        A single wavefunction component.
    uwrap : bool, optional
        If true, the phase is passed through ``rest.unwrap_phase``. Only
        supported for NumPy input.
    dens : NumPy :obj:`array` or PyTorch :obj:`Tensor`, optional
        If given, the phase is set to zero wherever ``dens`` is below
        ``1e-6`` times its maximum.

    Returns
    -------
    angle : NumPy :obj:`array` or PyTorch :obj:`Tensor`
        The phase (angle) of the component's wavefunction.

    Raises
    ------
    NotImplementedError
        If ``uwrap`` is requested for a PyTorch tensor.
    TypeError
        If ``psi_comp`` is neither a NumPy array nor a PyTorch tensor.

    """
    if isinstance(psi_comp, np.ndarray):
        ang = np.angle(psi_comp)
        if uwrap:
            ang = rest.unwrap_phase(ang)
    elif isinstance(psi_comp, torch.Tensor):
        if uwrap:
            raise NotImplementedError("Unwrapping the complex phase is not "
                                      "implemented for PyTorch tensors.")
        ang = torch.angle(psi_comp)
    else:
        # Without this branch an unsupported input would only fail later
        # with an opaque "local variable referenced before assignment".
        raise TypeError("psi_comp must be a NumPy array or a PyTorch tensor, "
                        f"got {type(psi_comp).__name__}.")
    if dens is not None:
        # Suppress the (numerically meaningless) phase in near-empty regions.
        ang[dens < (dens.max() * 1e-6)] = 0
    return ang
Пример #3
0
def stft_to_phase_magn(
        complex_values: th.Tensor,
        nb_vec: int = constant.N_VEC) -> Tuple[th.Tensor, th.Tensor]:
    """Turn a complex STFT into stacked, normalised magnitude / phase chunks.

    The magnitude goes through ``bark_magn_scale`` and the phase through
    ``unwrap``; the phase is then differentiated along dim 1 (the first
    magnitude column is dropped to match). Both are min-max scaled to
    [-1, 1], trimmed at the front so dim 1 is a multiple of ``nb_vec``, and
    split into ``nb_vec``-wide pieces stacked along a new leading dim.

    Returns the pair ``(magn, phase)``.
    """
    magnitude = th.abs(complex_values)
    angle = th.angle(complex_values)

    magnitude = bark_magn_scale(magnitude, unscale=False)
    angle = unwrap(angle)

    # Finite difference of the unwrapped phase along dim 1.
    phase_delta = angle[:, 1:] - angle[:, :-1]
    magnitude = magnitude[:, 1:]

    def _to_signed_unit(values):
        # Global min-max normalisation followed by a shift to [-1, 1].
        upper = values.max()
        lower = values.min()
        return (values - lower) / (upper - lower) * 2. - 1.

    def _stack_chunks(values):
        # Drop leading columns that do not fill a whole chunk.
        trimmed = values[:, values.size()[1] % nb_vec:]
        return th.stack(trimmed.split(nb_vec, dim=1), dim=0)

    magnitude = _stack_chunks(_to_signed_unit(magnitude))
    phase_delta = _stack_chunks(_to_signed_unit(phase_delta))

    return magnitude, phase_delta
Пример #4
0
    def forward(self, input, angle):
        """Run the separation network on a multi-channel mixture.

        Builds inter-channel phase difference (IPD) features, an angle
        feature derived from the steering vector for ``angle``, and a
        directional power ratio from the weights ``self.w``; concatenates
        them with the encoder output, predicts masks with ``self.TCN`` and
        decodes the masked encoder output.

        Parameters
        ----------
        input : Tensor
            Mixture waveform(s); reshaped to ``(-1, T)`` for the STFT.
            Presumably ``(B, n_mic, T)`` -- TODO confirm against callers.
        angle : Tensor
            Target direction(s), forwarded to the steering-vector helper.

        Returns
        -------
        Tensor
            Separated signals viewed as ``(B, num_spk, T)``.
        """
        # STFT of every channel; one zero frame is appended along the last dim.
        mag, ph, _, _ = self.stft.transform(input.reshape(-1, input.size()[-1]))
        pad = Variable(torch.zeros(mag.size()[0], mag.size()[1], 1)).type(input.type())
        mag = torch.cat([mag, pad], -1)
        ph = torch.cat([ph, pad], -1)
        output, rest = self.pad_signal(input)
        enc_output = self.encoder(output[:, :1])  # B, N, L
        mag = mag.view(enc_output.size(0), self.n_mic, -1, enc_output.size(-1))
        ph = ph.view(enc_output.size(0), self.n_mic, -1, enc_output.size(-1))

        # Complex spectrogram per microphone.
        spec = mag * torch.exp(ph * 1j)

        # Inter-channel phase differences for every configured microphone pair.
        ipd_list = []
        for m in self.pairs:
            com_u1 = spec[:, m[0]]
            com_u2 = spec[:, m[1]]
            ipd = torch.angle(com_u1 * torch.conj(com_u2))
            ipd /= (self.frequency_vector + 1.0)[:, None]
            ipd = ipd % torch.pi
            ipd_list.append(ipd.unsqueeze(dim=1))
        IPD = torch.cat(ipd_list, dim=1)

        # Angle feature: steering vector weighted IPD, normalised over pairs.
        steering_vector = self.__get_steering_vector(angle, self.pairs)
        steering_vector = steering_vector.unsqueeze(dim=-1)
        AF = steering_vector * IPD
        AF = AF/AF.sum(dim=1, keepdims=True).real

        # Directional power ratio. The contraction over the microphone dim
        # replaces a triple Python loop with hard-coded 36/602/97 bounds.
        # Assumes w is (B, n_mic, n_grid, F) and spec is (B, n_mic, F, L),
        # matching the indexing used by that loop -- TODO confirm.
        w = self.w.unsqueeze(dim=0).expand(AF.size()[0], -1, -1, -1)
        dpr = torch.einsum(
            "bmgf,bmft->bgft",
            w.to(torch.complex128),
            spec.to(torch.complex128),
        )
        power = dpr * torch.conj(dpr)
        dpr = power / torch.sum(power, dim=1, keepdim=True)

        feature_list = [enc_output.unsqueeze(dim=1), AF, dpr, torch.cos(IPD)]
        fusion = torch.cat(feature_list, dim=1).float()

        batch_size = output.size(0)
        fusion = fusion.view(batch_size, -1, fusion.size()[-1])

        # Mask estimation.
        masks = torch.sigmoid(self.TCN(fusion)).view(batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
        masked_output = enc_output.unsqueeze(1) * masks  # B, C, N, L

        # waveform decoder
        output = self.decoder(masked_output.view(batch_size*self.num_spk, self.enc_dim, -1))  # B*C, 1, L
        # Strip the padding added by pad_signal.
        output = output[:,:,self.stride:-(rest+self.stride)].contiguous()  # B*C, 1, L
        output = output.view(batch_size, self.num_spk, -1)  # B, C, T

        return output
Пример #5
0
def entropy_loss(ent_out, ent_gt):
    """Element-wise squared-error loss that also supports complex tensors.

    Parameters
    ----------
    ent_out : Tensor
        Predicted values.
    ent_gt : Tensor
        Ground-truth values.

    Returns
    -------
    Tensor
        For a complex ``ent_out``: ``|d**2| + angle(d**2)`` with
        ``d = ent_gt - ent_out`` (a real tensor). Otherwise simply ``d**2``.
    """
    loss = torch.square(ent_gt - ent_out)
    # `is_complex` is a method: without the call the bound method object is
    # always truthy and the real-valued branch would never be taken.
    if ent_out.is_complex():
        return torch.abs(loss) + torch.angle(loss)
    return loss
    def forward(self, x):
        """U-Net style pass over the STFT magnitude of ``x``.

        The magnitude spectrogram (``n_fft=128``) is processed by the
        encoder / decoder with skip connections; the network output is
        recombined with the input phase and inverted with ``torch.istft``.

        Returns
        -------
        tuple
            ``(d1, out)``: the squeezed network output and the resynthesised
            time-domain signal.
        """
        # Analysis: keep the phase aside, feed the magnitude as one channel.
        spectrum = torch.stft(x, n_fft=128, return_complex=True)
        magnitude = torch.abs(spectrum)
        phase = torch.angle(spectrum)
        magnitude = torch.unsqueeze(magnitude, dim=1)

        # Encoder: each stage after the first is max-pooled before its conv.
        enc1 = self.Conv1(magnitude)
        enc2 = self.Conv2(self.Maxpool(enc1))
        enc3 = self.Conv3(self.Maxpool(enc2))
        enc4 = self.Conv4(self.Maxpool(enc3))
        enc5 = self.Conv5(self.Maxpool(enc4))

        # Decoder: upsample, concatenate the matching encoder output, convolve.
        dec5 = self.Up5(enc5)
        dec5 = self.Up_conv5(torch.cat((enc4, dec5), dim=1))

        dec4 = self.Up4(dec5)
        dec4 = self.Up_conv4(torch.cat((enc3, dec4), dim=1))

        dec3 = self.Up3(dec4)
        dec3 = self.Up_conv3(torch.cat((enc2, dec3), dim=1))

        dec2 = self.Up2(dec3)
        dec2 = self.Up_conv2(torch.cat((enc1, dec2), dim=1))

        d1 = self.Conv_1x1(dec2)
        # NOTE(review): squeeze() without a dim removes every size-1 dim,
        # which would include a batch dimension of 1 -- confirm intended.
        d1 = torch.squeeze(d1)

        # Synthesis: reuse the input phase for the estimated magnitude.
        # The istft is not the same as that of librosa!
        resynth_spec = d1 * torch.exp(1j * phase)
        out = torch.istft(resynth_spec, n_fft=128)

        return d1, out
Пример #7
0
    def test_angle(self):
        """Check ``ht.angle`` for replicated, split, degree and real inputs."""

        def check_meta(result, dtype, shape):
            # Device, dtype and shape checks shared by every case below.
            self.assertIs(result.device, self.device)
            self.assertIs(result.dtype, dtype)
            self.assertEqual(result.shape, shape)

        values = [1.0, 1.0j, 1 + 1j, -2 + 2j, 3 - 3j]

        # Replicated first, then split along axis 0; both are compared
        # against torch.angle on the local array.
        for array_kwargs in ({}, {"split": 0}):
            a = ht.array(values, **array_kwargs)
            angle = ht.angle(a)
            res = torch.angle(a.larray)

            check_meta(angle, ht.float, (5, ))
            self.assertTrue(torch.equal(angle.larray, res))

        # Degrees, split along axis 1.
        a = ht.array([[1.0, 1.0j], [1 + 1j, -2 + 2j], [3 - 3j, -4 - 4j]],
                     split=1)
        angle = ht.angle(a, deg=True)
        res = ht.array(
            [[0.0, 90.0], [45.0, 135.0], [-45.0, -135.0]],
            dtype=ht.float32,
            device=self.device,
            split=1,
        )

        check_meta(angle, ht.float32, (3, 2))
        self.assertTrue(ht.equal(angle, res))

        # Not complex
        a = ht.ones((4, 4), split=1)
        angle = ht.angle(a)
        res = ht.zeros((4, 4), split=1)

        check_meta(angle, ht.float32, (4, 4))
        self.assertTrue(ht.equal(angle, res))
Пример #8
0
def sumofsq(image_in, keep_dims=False, axis=-1, name="sumofsq", type="mag"):
    """Compute square root of sum of squares.

    Parameters
    ----------
    image_in : Tensor
        Input tensor (typically complex).
    keep_dims : bool, optional
        Keep the reduced dimension with size 1.
    axis : int, optional
        Dimension to reduce over.
    name : str, optional
        Unused; kept for interface compatibility.
    type : str, optional
        ``"mag"`` combines magnitudes; any other value combines phases.

    Returns
    -------
    Tensor
        ``sqrt(sum(v ** 2))`` along ``axis``, where ``v`` is the magnitude
        or the phase of ``image_in``.
    """
    if type == "mag":
        squared = torch.square(torch.abs(image_in))
    else:
        squared = torch.square(torch.angle(image_in))
    # torch.sum takes `dim` / `keepdim`; the TensorFlow-style `keep_dims`
    # keyword is rejected with a TypeError.
    total = torch.sum(squared, dim=axis, keepdim=keep_dims)
    return torch.sqrt(total)
Пример #9
0
def gcc_features(complex_specs: torch.Tensor, n_mels: int) -> torch.Tensor:
    """Phase-transform cross-correlation features for every channel pair.

    Based on the codes from DCASE2020 SELDnet cls_feature_class.py.

    ``complex_specs`` is indexed as [chan, freq, time]; a non-complex input
    is first reinterpreted with ``torch.view_as_complex``. For each pair
    ``m < n`` the unit-magnitude cross-spectrum is inverse-transformed along
    dim 0 and the last ``n_mels // 2`` / first ``(n_mels + 1) // 2`` entries
    are concatenated. The per-pair results are stacked along a new dim 0.
    """
    specs = complex_specs
    if not torch.is_complex(specs):
        specs = torch.view_as_complex(specs)

    num_channels = specs.size(0)
    # Unary minus binds tighter than //, so this is floor(-n_mels / 2).
    tail_start = -n_mels // 2
    head_stop = (n_mels + 1) // 2

    def _pair_feature(first, second):
        cross = torch.conj(specs[first]) * specs[second]
        # Keep only the phase of the cross-spectrum.
        whitened = torch.exp(1.j * torch.angle(cross))
        corr = torch.fft.irfft(whitened, dim=0)
        return torch.cat([corr[tail_start:], corr[:head_stop]], axis=0)

    pair_features = [
        _pair_feature(first, second)
        for first in range(num_channels)
        for second in range(first + 1, num_channels)
    ]
    return torch.stack(pair_features, axis=0)
Пример #10
0
def get_phase_stft_magnitude(
    raw_data: torch.Tensor,
    sampling_rate_in_hz: int,
    window_length_in_s: float,
    window_shift_in_s: float,
    num_fft_points: int,
    window_type: str,
) -> torch.Tensor:
    """Return STFT phase and magnitude concatenated into one feature tensor.

    The STFT comes from ``_get_stft``. Phase comes first and magnitude
    second along dim 1; dims 0 and 1 of the result are then swapped.
    """
    spectrum = _get_stft(
        raw_data,
        sampling_rate_in_hz,
        window_length_in_s,
        window_shift_in_s,
        num_fft_points,
        window_type=window_type,
    )
    magnitude = torch.abs(spectrum)
    angle = torch.angle(spectrum)
    return torch.cat([angle, magnitude], dim=1).transpose(0, 1)
Пример #11
0
    def stft(self, x):
        """Compute magnitude and phase spectra of a signal.

        Args:
            x (Tensor): Input signal tensor (B, T).

        Returns:
            Tensor: x_mag, x_phs
                Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
        """
        spectrum = torch.stft(
            x,
            n_fft=self.fft_size,
            hop_length=self.hop_size,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        return torch.abs(spectrum), torch.angle(spectrum)
Пример #12
0
    def update_variables(self):
        """Recompute the reconstruction buffers and probe-derived quantities.

        Clears the three reciprocal-space buffers, runs
        ``single_sideband_reconstruction`` with them, transforms the results
        into the ``Psi_Rp*`` buffers, refreshes ``Gamma`` and finally
        regenerates the probe (``Psi``, ``phases``, ``Psi_shifted``, ``psi``).
        """
        # Reset the buffers in place so existing references stay valid.
        self.Psi_Qp[:] = 0
        self.Psi_Qp_left_sb[:] = 0
        self.Psi_Qp_right_sb[:] = 0

        eps = 1e-3
        # Presumably fills the three Psi_Qp* buffers in place (the return
        # value is not used) -- TODO confirm against the helper.
        single_sideband_reconstruction(
            self.G,
            self.Qx1d,
            self.Qy1d,
            self.Kx,
            self.Ky,
            self.C,
            np.deg2rad(self.rotation_deg),
            self.meta.alpha_rad,
            self.Psi_Qp,
            self.Psi_Qp_left_sb,
            self.Psi_Qp_right_sb,
            eps,
            self.meta.wavelength,
        )

        # Orthonormal inverse 2-D FFT of each buffer, written in place.
        self.Psi_Rp[:] = fft.ifft2(self.Psi_Qp, norm="ortho")
        self.Psi_Rp_left_sb[:] = fft.ifft2(self.Psi_Qp_left_sb, norm="ortho")
        self.Psi_Rp_right_sb[:] = fft.ifft2(self.Psi_Qp_right_sb, norm="ortho")

        self.Gamma = disk_overlap_function(self.Qx_max1d, self.Qy_max1d,
                                           self.Kx, self.Ky, self.C,
                                           np.deg2rad(self.rotation_deg),
                                           self.meta.alpha_rad,
                                           self.meta.wavelength)

        # Probe generated on the GPU from the current C and A values.
        Psi = self.probe_gen(
            th.tensor(self.C.get()).cuda(),
            th.tensor(self.A).cuda())
        # NOTE(review): probe_gen is called a second time with the same
        # arguments; if it is deterministic, `Psi` could be reused here.
        self.phases = th.angle(
            th.fft.fftshift(
                self.probe_gen(
                    th.tensor(self.C.get()).cuda(),
                    th.tensor(self.A).cuda())))
        self.Psi_shifted = th.fft.fftshift(Psi)
        self.psi = th.fft.fftshift(th.fft.ifft2(Psi))
Пример #13
0
def calculate_phase(field, deg=False):
    """
    Definition to calculate phase of a single or multiple given electric field(s).

    Parameters
    ----------
    field        : torch.cfloat
                   Electric fields or an electric field.
    deg          : bool
                   If set True, the angles will be returned in degrees.

    Returns
    -------
    phase        : torch.float
                   Phase or phases of electric field(s), in radians by
                   default or in degrees when `deg` is set.
    """
    phase = torch.angle(field)
    # Plain truthiness test instead of `== True`, so any truthy flag works.
    if deg:
        # In-place scaling is safe: `phase` is a fresh tensor from torch.angle.
        phase *= 180. / np.pi
    return phase
Пример #14
0
def phase_harmonics(z, k):
    """
    Apply phase harmonics to the input tensor.

    Entries are selected along dimension -3 according to ``k``: where
    ``k == 0`` the value is replaced by its modulus, where ``k == 1`` it is
    left as is, and where ``k >= 2`` the phase is multiplied by ``k`` while
    the modulus is kept.

    Parameters
    ----------
    z : tensor
        Input.
    k : tensor
        Exponents.

    Returns
    -------
    result : tensor
        Output (a new tensor; ``z`` is not modified).

    """
    zero_idx = torch.where(k == 0)[0]
    high_idx = torch.where(k >= 2)[0]

    out = z.clone()

    # k == 0: keep only the modulus.
    modulus = torch.abs(out.index_select(-3, zero_idx))
    out[..., zero_idx, :, :] = modulus.to(out.dtype)

    # k == 1 needs no work.

    # k >= 2: r * exp(i * k * theta), written with cos / sin.
    exponents = k[high_idx].unsqueeze(-1).unsqueeze(-1)
    selected = out.index_select(-3, high_idx)
    radius = torch.abs(selected)
    scaled_phase = exponents * torch.angle(selected)
    out[..., high_idx, :, :] = radius * (torch.cos(scaled_phase) +
                                         1j * torch.sin(scaled_phase))

    return out
Пример #15
0
def spectralResidueSaliency(image):
    """
    Compute a visual saliency map for ``image`` with the spectral residual
    method of Xiaodi Hou and Liqing Zhang ("Saliency detection: a spectral
    residual approach").

    The image is downscaled, the log-amplitude spectrum is compared with its
    box-filtered version, the residual is recombined with the original phase
    and transformed back; the squared magnitude is blurred, normalised and
    resized to the input size. The constants below may need tuning.

    NOTE(review): cv2 and torch calls are mixed on the same values here;
    confirm the array / tensor types actually passed between them.
    """
    scale = 0.25  # constant
    box_size = 3  # constant
    blur_sigma = 3.8  # constant
    blur_size = 9  # constant

    # The built-in rounding goes to the nearest even value on exact halves;
    # nudging the value upwards first makes halves round up, as in MATLAB.
    def _round(a):
        nudged = torch.nextafter(a, a + 1)
        return int(torch.rint(nudged))

    small_width = _round(scale * image.shape[1])
    small_height = _round(scale * image.shape[0])
    small = cv2.resize(image, (small_width, small_height),
                       interpolation=cv2.INTER_CUBIC)

    spectrum = fft2(small)
    log_amplitude = torch.log(torch.abs(spectrum))
    phase = torch.angle(spectrum)
    smoothed = cv2.boxFilter(log_amplitude, -1, (box_size, box_size),
                             cv2.BORDER_REPLICATE)
    residual = log_amplitude - smoothed
    saliency = torch.abs(ifft2(torch.exp(residual + 1j * phase)))**2

    blurred = cv2.GaussianBlur(saliency, (blur_size, blur_size), blur_sigma,
                               blur_sigma)
    saliency = torch.nn.functional.normalize(blurred)
    return cv2.resize(saliency, (image.shape[1], image.shape[0]))
Пример #16
0
def cartesian_to_polar(input):
    """Stack modulus and angle of ``input`` along a new dimension 3."""
    modulus = torch.abs(input)
    phase = torch.angle(input)
    return torch.stack((modulus, phase), dim=3)
Пример #17
0
def GW_loss_prep(temp_index,
                 data,
                 y_pred,
                 temp_mean,
                 temp_sd,
                 gw_mean,
                 gw_std,
                 num_task,
                 type='fft'):
    """Prepare observed and predicted groundwater metrics for a loss term.

    From the predicted temperature series and the air temperature series the
    function derives the amplitude ratio (Ar), the phase lag (delPhi) and the
    mean temperature (Tmean), scales the predictions with ``gw_mean`` /
    ``gw_std`` and returns them next to the observed values read from
    ``data``.

    Parameters
    ----------
    temp_index : index (cast to int) of the temperature column in the last
        axis of ``data`` and ``y_pred``.
    data : observations; columns from ``num_task`` onwards hold the
        groundwater data.
    y_pred : model predictions.
    temp_mean, temp_sd : used to unscale the temperatures.
    gw_mean, gw_std : indexable scaling statistics, used at index 0 (Ar),
        1 (delPhi) and 2 (Tmean).
    num_task : number of leading columns of ``data`` that are not
        groundwater data.
    type : calculation method; the assertion below only admits ``'fft'``.

    Returns
    -------
    tuple
        ``(Ar_obs, Ar_pred, delPhi_obs, delPhi_pred, Tmean_obs, Tmean_pred)``.
    """
    # assumes that axis 0 of data and y_pred are the reaches and axis 1 are daily values
    # assumes the first two columns of data are the observed flow and temperature, and the remaining
    # ones (extracted here) are the data for gw analysis

    # NOTE(review): this assertion makes the "linalg" branch further down
    # unreachable (and it is stripped when Python runs with -O).
    assert type == 'fft', "the groundwater loss calculation method must be fft"

    y_true = data[:, :, num_task:]
    y_true_temp = data[:, :, int(temp_index):(int(temp_index) + 1)]

    y_pred_temp = y_pred[:, :, int(temp_index):(
        int(temp_index) + 1)]  # extract just the predicted temperature
    # unscale the predicted temps prior to calculating the amplitude and phase
    y_pred_temp = y_pred_temp * temp_sd + temp_mean
    y_true_temp = y_true_temp * temp_sd + temp_mean

    #set temps < 1 to 1
    y_pred_temp[y_pred_temp < 1] = 1
    y_true_temp[y_true_temp < 1] = 1

    # Observed values are read from the first time step of each reach.
    Ar_obs = y_true[:, 0, 0]
    delPhi_obs = y_true[:, 0, 1]
    Tmean_obs = y_true[:, 0, 2]
    if type == 'fft':
        # NOTE(review): squeeze() without a dim also drops a reach axis of
        # size 1 -- confirm single-reach batches are not expected.
        y_pred_temp = torch.squeeze(y_pred_temp)
        y_pred_mean = torch.mean(y_pred_temp, 1, keepdims=True)
        temp_demean = y_pred_temp - y_pred_mean
        fft_torch = torch.fft.rfft(temp_demean)
        Phiw = torch.angle(fft_torch)
        # NOTE(review): phiIndex is computed but not used; the phase is
        # always taken from frequency bin 1.
        phiIndex = torch.argmax(torch.abs(fft_torch), 1)
        Phiw_out = Phiw[:, 1]

        Aw = torch.max(torch.abs(fft_torch), 1).values / fft_torch.shape[
            1]  # tf.shape(fft_tf, out_type=tf.dtypes.float32)[1]

        #get the air signal properties
        y_true_air = y_true[:, :, -1]
        y_true_air_mean = torch.mean(y_true_air, 1, keepdims=True)
        air_demean = y_true_air - y_true_air_mean
        fft_torch_air = torch.fft.rfft(air_demean)
        Phia = torch.angle(fft_torch_air)

        # NOTE(review): phiIndex_air is unused as well.
        phiIndex_air = torch.argmax(torch.abs(fft_torch_air), 1)
        Phia_out = Phia[:, 1]

        # NOTE(review): normalised by the water spectrum length
        # (fft_torch), not fft_torch_air -- presumably the same size.
        Aa = torch.max(torch.abs(fft_torch_air), 1).values / fft_torch.shape[
            1]  # tf.shape(fft_tf_air, out_type=tf.dtypes.float32)[1]

        # calculate and scale predicted values
        # delPhi_pred = the difference in phase between the water temp and air temp sinusoids, in days
        delPhi_pred = (Phia_out - Phiw_out)
        delPhi_pred = (delPhi_pred * 365 / (2 * m.pi) - gw_mean[1]) / gw_std[1]

        # Ar_pred = the ratio of the water temp and air temp amplitudes
        Ar_pred = (Aw / Aa - gw_mean[0]) / gw_std[0]

    elif type == "linalg":
        x_lm = y_true[:, :, -3:-1]  #extract the sin(wt) and cos(wt)

        #a tensor of the sin(wt) and cos(wt) for each reach x day, the 1's are for the intercept of the linear regression
        # T(t) = T_mean + a*sin(wt) + b*cos(wt)
        # Johnson, Z.C., Johnson, B.G., Briggs, M.A., Snyder, C.D., Hitt, N.P., and Devine, W.D., 2021, Heed the data gap: Guidelines for
        #using incomplete datasets in annual stream temperature analyses: Ecological Indicators, v. 122, p. 107229,
        #http://www.sciencedirect.com/science/article/pii/S1470160X20311687.

        # `device` is not a parameter; it must come from the enclosing module.
        X_mat = torch.stack((torch.ones(
            y_pred_temp.shape[0:2]).to(device), x_lm[:, :, 0], x_lm[:, :, 1]),
                            axis=1)
        #getting the coefficients using a 3-d version of the normal equation:
        #https://cmdlinetips.com/2020/03/linear-regression-using-matrix-multiplication-in-python-using-numpy/
        #http://mlwiki.org/index.php/Normal_Equation
        X_mat_T = torch.permute(X_mat, dims=(0, 2, 1))
        X_mat_T_dot = torch.einsum(
            'bij,bjk->bik', X_mat_T, X_mat
        )  #einsum is used instead of dot products because we want the dot products of axis 1 and 2, not 0
        X_mat_inv = torch.linalg.pinv(X_mat_T_dot)
        X_mat_inv_dot = torch.einsum(
            'bij,bjk->bik', X_mat_inv, X_mat_T
        )  #einsum is used instead of dot products because we want the dot products of axis 1 and 2, not 0
        a_b = torch.einsum(
            'bij,bik->bjk', X_mat_inv_dot.float(), y_pred_temp.float()
        )  #einsum is used instead of dot products because we want the dot products of axis 1 and 2, not 0
        #the tensor a_b has the coefficients from the regression (reach x [[intercept],[a],[b]])
        #Aw = amplitude of the water temp sinusoid (deg C)
        #A = sqrt (a^2 + b^2)
        Aw = torch.sqrt(a_b[:, 1, 0]**2 + a_b[:, 2, 0]**2)
        #Phiw = phase of the water temp sinusoid (radians)
        #Phi = atan (b/a) - in radians
        Phiw = torch.atan(a_b[:, 2, 0] / a_b[:, 1, 0])

        #calculate the air properties
        y_true_air = y_true[:, :, -1:]
        a_b_air = torch.einsum('bij,bik->bjk', X_mat_inv_dot, y_true_air)
        A_air = torch.sqrt(a_b_air[:, 1, 0]**2 + a_b_air[:, 2, 0]**2)
        Phi_air = torch.atan(a_b_air[:, 2, 0] / a_b_air[:, 1, 0])

        #calculate and scale predicted values
        #delPhi_pred = the difference in phase between the water temp and air temp sinusoids, in days
        delPhi_pred = Phi_air - Phiw
        delPhi_pred = (delPhi_pred * 365 / (2 * m.pi) - gw_mean[1]) / gw_std[1]

        #Ar_pred = the ratio of the water temp and air temp amplitudes
        Ar_pred = (Aw / A_air - gw_mean[0]) / gw_std[0]
        y_pred_temp = torch.squeeze(y_pred_temp)
        y_pred_mean = torch.mean(y_pred_temp, 1, keepdims=True)

    #scale the predicted mean temp
    Tmean_pred = torch.squeeze((y_pred_mean - gw_mean[2]) / gw_std[2])

    return Ar_obs, Ar_pred, delPhi_obs, delPhi_pred, Tmean_obs, Tmean_pred
Пример #18
0
    # Compare the functional FFT convolution (ts.fftconv1) with the FFTConv1
    # layer on a small example.
    shape = 'same'
    ftshift = False
    x_np = np.array([1, 2, 3, 4, 5])
    h_np = np.array([1 + 2j, 2, 3, 4, 5, 6, 7])

    x_th = th.tensor(x_np)
    h_th = th.tensor(h_np)
    # Represent both signals as (..., 2) real / imaginary pairs; x gets a
    # zero imaginary part.
    x_th = th.stack([x_th, th.zeros(x_th.size())], dim=-1)
    h_th = th.stack([h_th.real, h_th.imag], dim=-1)

    # Reference result from the functional implementation.
    y1 = ts.fftconv1(x_th,
                     h_th,
                     axis=0,
                     nfft=None,
                     shape=shape,
                     ftshift=ftshift)

    # The layer is initialised with the same filter h.
    fftconv1layer = FFTConv1(h_th.size(0), h=h_th, nfft=None, shape=shape)

    for p in fftconv1layer.parameters():
        print(p)

    y2 = fftconv1layer.forward(x_th)
    # y2 = th.view_as_complex(y2)
    y2 = y2.cpu().detach()

    # print(y1)
    # print(y2)
    # Summed absolute difference and summed phase difference of the outputs.
    print(th.sum(th.abs(y1 - y2)), th.sum(th.angle(y1) - th.angle(y2)))
Пример #19
0
def meta_angle_out(self, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.angle(self))
Пример #20
0
def vaerecon(ksp,
             coilmaps,
             mode,
             vae_model,
             gt,
             logdir,
             device,
             writer=False,
             norm=1,
             nsampl=100,
             boot_samples=500,
             k=1,
             patchsize=28,
             parfact=25,
             num_iter=200,
             stepsize=5e-4,
             lmb=0.01,
             num_priors=1,
             use_momentum=True):
    """Iterative reconstruction from undersampled multi-coil k-space.

    Each iteration applies a prior gradient step on the image (``'TV'``, or
    the VAE prior for ``'DDP'`` / ``'JDDP'``), an optional phase projection,
    an optional coil-map re-estimation (``'JDDP'`` only) and a data
    consistency projection. Quality metrics against ``gt`` are printed and,
    when ``writer`` is given, logged together with images.

    Parameters
    ----------
    ksp : k-space data; its shape is unpacked as (coils, rows, cols).
    coilmaps : coil sensitivity maps (replaced by an estimate in 'JDDP').
    mode : 'TV', 'DDP' or 'JDDP'; anything else prints an error and exits.
    vae_model : model handed to ``prior_value`` / ``prior_gradient``.
    gt : reference image used for the SSIM / NMSE / PSNR metrics.
    logdir : not used in this function.
    device : device that the data and the model are moved to.
    writer : logger exposing ``log`` and ``Image`` (presumably a wandb-like
        object -- TODO confirm); falsy disables logging.
    norm : divisor applied to ``rss`` for metrics and for the return value.
    nsampl, boot_samples, patchsize, parfact : forwarded to the prior
        helpers.
    k, use_momentum : not used in this function.
    num_iter : upper bound of the iteration counter, which advances by 2.
    stepsize : step size of the prior gradient step.
    lmb : phase regularisation weight; the phase step is skipped if <= 0.
    num_priors : number of prior steps per iteration.

    Returns
    -------
    ``rss / norm`` after the last iteration.
    """
    # Init data
    imcoils, imsizer, imsizec = ksp.shape
    ksp = ksp.to(device)
    coilmaps = coilmaps.to(device)
    vae_model = vae_model.to(device)
    # Sampling mask: k-space positions of the first coil that are non-zero.
    uspat = (torch.abs(ksp[0]) > 0).type(torch.uint8).to(device)
    recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)
    rss = rss_pytorch(ksp)

    # Init coilmaps estimation with JSENSE
    if mode == 'JDDP':
        # Polynomial order
        max_basis_order = 6
        num_coeffs = (max_basis_order + 1)**2

        # Create the basis functions for the sense estimation
        basis_funct = create_basis_functions(imsizer,
                                             imsizec,
                                             max_basis_order,
                                             show_plot=False)
        # Debug switch: set to True to log every basis function.
        plot_basis = False
        if plot_basis:
            for i in range(num_coeffs):
                writer.log({
                    "Basis funcs": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.from_numpy(basis_funct[i, :, :]))),
                                     caption="")
                    ]
                })

        # Repeat the basis once per coil.
        basis_funct = torch.from_numpy(
            np.tile(basis_funct[np.newaxis, :, :, :],
                    [coilmaps.shape[0], 1, 1, 1])).to(device)
        coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct, uspat)

        # Coil maps as the coefficient-weighted sum of the basis functions.
        coilmaps = torch.sum(
            coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct,
            1).to(device)

        recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)

        if writer:
            for i in range(coilmaps.shape[0]):
                writer.log(
                    {
                        "abs Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.abs(coilmaps[i, :, :]))),
                                caption="")
                        ]
                    },
                    step=0)
                writer.log(
                    {
                        "phase Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.angle(coilmaps[i, :, :]))),
                                caption="")
                        ]
                    },
                    step=0)
        print("Coilmaps init done")

    # Log the initial state (step 0)
    if writer:
        writer.log(
            {
                "Gt rss": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(gt)),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored rss": [
                    writer.Image(transforms.ToPILImage()(
                        normalize_tensor(rss)),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored abs": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        torch.abs(recs_gpu))),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored Phase": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        torch.angle(recs_gpu))),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "diff rss": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        (rss.detach().cpu() / norm - gt.detach().cpu()))),
                                 caption="")
                ]
            },
            step=0)
        # Metrics are computed on rows 160:-160 only.
        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        nmse_v = nmse(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)
        writer.log({"SSIM": ssim_v, "NMSE": nmse_v, "PSNR": psnr_v}, step=0)

        lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize, parfact,
                              nsampl, vae_model)
        writer.log({"ELBO": lik}, step=0)
        writer.log({"DC err": dc}, step=0)

    # NOTE(review): `t` is assigned but not used below.
    t = 1
    for it in range(0, num_iter, 2):
        print('Itr: ', it)

        # Magnitude prior projection step
        for _ in range(num_priors):
            # Gradient descent of Prior
            if mode == 'TV':
                tvnorm, abstvgrad = tv_norm(torch.abs(rss))
                # Apply the magnitude gradient along the current phase.
                priorgrad = abstvgrad * recs_gpu / (torch.abs(recs_gpu))
                recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:  #and it%10 == 0:
                    writer.log(
                        {
                            "TVgrad": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(abstvgrad)),
                                             caption="")
                            ]
                        },
                        step=it + 1)
                    writer.log(
                        {
                            "TV": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(tvnorm)),
                                             caption="")
                            ]
                        },
                        step=it + 1)

            elif mode == 'DDP' or mode == 'JDDP':
                g_abs_lik, est_uncert, g_dc = prior_gradient(
                    rss, ksp, uspat, coilmaps, patchsize, parfact, nsampl,
                    vae_model, boot_samples, mode)
                # Apply the magnitude gradient along the current phase.
                priorgrad = g_abs_lik * recs_gpu / (torch.abs(recs_gpu))

                # Always true for the values `it` takes in this loop.
                if it > -1:
                    recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:  # Log
                    writer.log(
                        {
                            "VAEgrad abs": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(torch.abs(g_abs_lik))),
                                             caption="")
                            ]
                        },
                        step=it + 1)
                    writer.log({"STD": torch.mean(torch.abs(est_uncert))},
                               step=it + 1)

                    # NOTE(review): `rss` is reassigned inside this logging
                    # branch, so with num_priors > 1 the next prior step sees
                    # a different `rss` depending on whether a writer is
                    # set -- confirm this is intended.
                    tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
                    tmp2 = ksp * uspat.unsqueeze(0)
                    tmp = tmp1 + tmp2
                    rss = rss_pytorch(tmp)
                    nmse_v = nmse(
                        (rss[160:-160].detach().cpu().numpy() / norm),
                        gt[160:-160].detach().cpu().numpy())
                    ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ',
                          psnr_v)
                    writer.log({
                        "SSIM": ssim_v,
                        "NMSE": nmse_v,
                        "PSNR": psnr_v
                    },
                               step=it + 1)
            else:
                # NOTE(review): an unknown mode terminates the whole process
                # here instead of raising an exception.
                print("Error: Prior method does not exists.")
                exit()

        # Phase projection step
        if lmb > 0:
            tmpa = torch.abs(recs_gpu)
            tmpp = torch.angle(recs_gpu)

            # We apply phase regularization to prefer smooth phase images
            #tmpptv = reg2_proj(tmpp, imsizer, imsizec, alpha=lmb, niter=2)  # 0.1, 15
            tmpptv = tv_proj(tmpp, mu=0.125, lmb=lmb, IT=50)  # 0.1, 15
            # We combine back the phase and the magnitude
            recs_gpu = tmpa * torch.exp(1j * tmpptv)

        # Coilmaps estimation step (if JSENSE)
        if mode == 'JDDP':
            # NOTE(review): the original remark here said this is "computed
            # on cpu" because of complex-number support on the GPU; nothing
            # in this block moves data to the CPU -- confirm and update.
            coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct,
                                               uspat)
            coilmaps = torch.sum(
                coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct,
                1).to(device)

            if writer:
                writer.log(
                    {
                        "abs Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.abs(coilmaps[0, :, :]))),
                                caption="")
                        ]
                    },
                    step=it + 1)
                writer.log(
                    {
                        "phase Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.angle(coilmaps[0, :, :]))),
                                caption="")
                        ]
                    },
                    step=it + 1)

        # Data consistency projection: keep the measured samples and take
        # the current estimate at the unsampled positions.
        tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
        tmp2 = ksp * uspat.unsqueeze(0)
        tmp = tmp1 + tmp2
        recs_gpu = tFT_pytorch(tmp, coilmaps)
        # recs[it + 2] = recs_gpu.detach().cpu().numpy()
        rss = rss_pytorch(tmp)

        # Log
        nmse_v = nmse((rss[160:-160].detach().cpu().numpy() / norm),
                      gt[160:-160].detach().cpu().numpy())
        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)

        if writer:
            writer.log({
                "SSIM": ssim_v,
                "NMSE": nmse_v,
                "PSNR": psnr_v
            },
                       step=it + 1)
            writer.log(
                {
                    "Restored rss": [
                        writer.Image(transforms.ToPILImage()(
                            normalize_tensor(rss)),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "Restored Phase": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.angle(recs_gpu))),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "diff rss": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            (rss.detach().cpu() / norm - gt.detach().cpu()))),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "Restored 1ch kspace": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.log(torch.abs(tmp[0])))),
                                     caption="")
                    ]
                },
                step=it + 1)
            lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize,
                                  parfact, nsampl, vae_model)
            writer.log({"ELBO": lik}, step=it + 1)
            writer.log({"DC err": dc}, step=it + 1)

    return rss / norm
Пример #21
0
def L2_phase(x):
    """Return a smoothness penalty on the phase map ``x``.

    Neighbouring entries are differenced along dim 1 and along dim 2. Each
    difference is passed through ``exp(1j * .)`` followed by ``torch.angle``
    (which wraps it to the principal value), squared and averaged; the two
    means are added together.
    """

    def wrapped_sq_mean(lead, lag):
        # exp/angle round trip folds the raw difference back into (-pi, pi].
        delta = torch.angle(torch.exp(1j * lead - 1j * lag))
        return (delta**2).mean()

    along_dim1 = wrapped_sq_mean(x[:, :-1], x[:, 1:])
    along_dim2 = wrapped_sq_mean(x[:, :, :-1], x[:, :, 1:])
    return along_dim1 + along_dim2
Пример #22
0
    # ---Received signal
    # NOTE(review): this Sr is reassigned by chirp.recv(...) a few lines below
    # before it is used anywhere in this fragment, so the call looks
    # redundant -- confirm before removing.
    Sr = chirp_recv(t, Tp, K, Fc, a=1., g=G, r=R)

    # Same chirp parameters, wrapped in a Chirp object.
    chirp = Chirp(Tp=Tp, K=K, Fc=Fc, a=1.)

    # Signals produced by the Chirp object at the time samples t
    # (tran is presumably the transmitted signal, recv the received one).
    St = chirp.tran(t)
    Sr = chirp.recv(t, g=G, r=R)

    # 2x2 figure: top row shows St, bottom row shows Sr.  Left column plots
    # the real and imaginary parts, right column the phase (th.angle).
    # The time axis is scaled by 1e6 and labelled in microseconds.
    plt.figure()
    plt.subplot(221)
    plt.plot(t * 1e6, th.real(St))
    plt.plot(t * 1e6, th.imag(St))
    plt.xlabel('Time/us')
    plt.legend(['real', 'imag'])
    plt.subplot(222)
    plt.plot(t * 1e6, th.angle(St))
    plt.xlabel('Time/us')
    plt.subplot(223)
    plt.plot(t * 1e6, th.real(Sr))
    plt.plot(t * 1e6, th.imag(Sr))
    plt.xlabel('Time/us')
    plt.legend(['real', 'imag'])
    plt.subplot(224)
    plt.plot(t * 1e6, th.angle(Sr))
    plt.xlabel('Time/us')
    plt.show()

    # ---Frequency domain
    # FFT along dim 0, with fftshift applied both before and after the
    # transform.
    Yt = fftshift(fft(fftshift(St, dim=0), dim=0), dim=0)
    Yr = fftshift(fft(fftshift(Sr, dim=0), dim=0), dim=0)
Пример #23
0
        axs[i, j].set_title(idx_to_class[tgt], fontsize=40)
        fig.colorbar(imshow, ax=axs[i, j])
        ind += 1
# -
# ## Phase analysis

# +
# %matplotlib inline
fig, axs = plt.subplots(row_nb, col_nb, figsize=(row_nb * 10, col_nb * 10))

# One subplot per grid cell, each showing the phase of the 2-D FFT of the
# next sample drawn from `subset`.
ind = 0
for i in range(row_nb):
    for j in range(col_nb):
        img, tgt = subset[ind]
        fft_img = torch.fft.fft2(img).squeeze()
        freq_img = torch.angle(fft_img)
        # NOTE(review): the tick labels below come from fftshift-ed frequency
        # axes while freq_img itself is not shifted -- confirm this is
        # intended.
        freq_x = torch.fft.fftshift(torch.fft.fftfreq(img.shape[2])).numpy()
        freq_y = torch.fft.fftshift(torch.fft.fftfreq(img.shape[1])).numpy()
        # Tick positions: every 50 pixels, plus the last pixel and the centre.
        x_range = np.hstack((np.arange(0, img.shape[2],
                                       50), np.array([img.shape[2] - 1]),
                             np.array([img.shape[2] // 2])))
        # The y-axis centre tick must be derived from the height (shape[1])
        # so that it is a valid index into freq_y for non-square images.
        y_range = np.hstack((np.arange(0, img.shape[1],
                                       50), np.array([img.shape[1] - 1]),
                             np.array([img.shape[1] // 2])))
        axs[i, j].set_xticks(x_range)
        axs[i, j].set_xticklabels(freq_x[x_range])
        axs[i, j].set_yticks(y_range)
        axs[i, j].set_yticklabels(freq_y[y_range])
        imshow = axs[i, j].imshow(freq_img)
        axs[i, j].set_title(idx_to_class[tgt], fontsize=40)
        fig.colorbar(imshow, ax=axs[i, j])
        # Advance to the next sample; otherwise every cell shows subset[0].
        ind += 1
Пример #24
0
 def pointwise_ops(self):
     """Call a broad set of torch pointwise math operators once each.

     A handful of small sample tensors are built first; the method then
     returns a single tuple holding the result of every operator call, in
     the order the calls are listed below.

     Several inputs come from ``torch.randn`` / ``torch.rand``, so the
     returned values differ between calls unless the RNG is seeded by the
     caller.  The tuple elements are evaluated in order and some of them
     modify ``a`` in place (``a.uniform_``), so the order of the entries
     matters.
     """
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     # NOTE(review): `t` is assigned a second time with the same value; this
     # assignment is redundant.
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         # uniform_ refills `a` in place, so every later use of `a` sees
         # values drawn from [1, 2) until the next uniform_ call.
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         # `a` is refilled in place again here, this time from [-1, 1).
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         # NOTE(review): this bare tensor is returned as-is; it looks like it
         # was meant to be the nan_to_num input, while `w` (used below)
         # contains no nan/inf entries -- confirm.
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         # NOTE(review): identical to the previous entry.
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
Пример #25
0
def Prep(data):
    """Build one fused feature tensor per speaker from a multi-mic mixture.

    Parameters
    ----------
    data : dict or indexable
        If a dict, the keys ``"angle"``, ``"mix"`` and ``"R"`` are read.
        Otherwise items ``0`` (mixture), ``2`` (angles) and ``5`` (``R``)
        are read.  ``R`` is presumably the microphone position matrix with
        3 rows -- TODO confirm against the caller.

    Returns
    -------
    list or None
        For non-dict input, a list with one fused feature tensor per
        ``i in range(n_sp)``.  For dict input nothing is returned: the
        features are written to ``data[i]`` and ``data["mix"]`` is
        overwritten with the transposed, squeezed mixture tensor.

    Notes
    -----
    Relies on names defined outside this function: ``n_mic``, ``n_grid``,
    ``m_total``, ``m_data``, ``V``, ``stft``, ``n_sp``, ``Variable`` and
    ``__get_steering_vector``.

    NOTE(review): the locals ``input`` and ``complex`` shadow Python
    builtins inside this function.
    """
    dic_data = isinstance(data, dict)
    if dic_data:
        angle = torch.tensor(data["angle"]).unsqueeze(dim=0)
        input = torch.from_numpy(data["mix"]).unsqueeze(dim=0).float()
        # NOTE(review): only the dict branch swaps dims 1 and 2 of the
        # mixture; the other branch uses data[0] as-is -- confirm the two
        # sources really come in different layouts.
        input = torch.transpose(input, 2, 1)
        data["mix"] = input.squeeze()
        R = data["R"]
    else:
        input = torch.from_numpy(data[0]).unsqueeze(dim=0).float()
        angle = torch.tensor(data[2]).unsqueeze(dim=0)
        R = data[5]
        return_list = []

    # Express every column of R relative to its first column.
    mic_array_layout = R - np.tile(R[:, 0].reshape((3, 1)), (1, n_mic))
    # Channel pairs used for the inter-channel phase differences below.
    pairs = ((0, 3), (1, 4), (2, 5), (0, 1), (2, 3), (4, 5))
    # Pairs (0, k): every channel measured against channel 0 for the delays.
    ori_pairs = ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5))
    delay = np.zeros((n_mic, n_grid))
    for h, m in enumerate(ori_pairs):
        dx = mic_array_layout[0, m[1]] - mic_array_layout[0, m[0]]
        dy = mic_array_layout[1, m[1]] - mic_array_layout[1, m[0]]
        # Grid index i maps to an angle of i * pi / 18 (10-degree steps);
        # the offset (dx, dy) is projected onto that direction.
        for i in range(n_grid):
            delay[
                h,
                i] = dx * np.cos(i * np.pi / 18) + dy * np.sin(i * np.pi / 18)
    delay = torch.from_numpy(delay).unsqueeze(dim=-1).expand(-1, -1, m_total)
    # Complex weights built from the delays (m_data and V come from module
    # scope; presumably frequency terms and a normaliser -- TODO confirm).
    w = torch.exp(-2j * np.pi * m_data * delay) / V
    batch_size = input.size(0)
    # NOTE(review): `real` and `image` are unpacked but never used below.
    mag, ph, real, image = stft.transform(input.reshape(-1, input.size()[-1]))

    # Append one zero entry along the last axis of both magnitude and phase.
    pad = Variable(torch.zeros(mag.size()[0],
                               mag.size()[1], 1)).type(input.type())
    mag = torch.cat([mag, pad], -1)
    ph = torch.cat([ph, pad], -1)
    channel = mag.size()[-1]
    mag = mag.view(batch_size, n_mic, -1, channel)
    ph = ph.view(batch_size, n_mic, -1, channel)
    #LPS = 10 * torch.log10(mag ** 2 + 10e-20)
    # Rebuild the complex spectrogram from magnitude and phase.
    complex = (mag * torch.exp(ph * 1j))
    # Phase difference between the two channels of each pair, stacked along
    # a new dim 1.
    IPD_list = []
    for m in pairs:
        com_u1 = complex[:, m[0]]
        com_u2 = complex[:, m[1]]
        IPD = torch.angle(com_u1) - torch.angle(com_u2)
        #IPD /= (frequency_vector + 1.0)[:, None]
        #IPD = IPD % (2 * np.pi)
        IPD = IPD.unsqueeze(dim=1)
        IPD_list.append(IPD)
    IPD = torch.cat(IPD_list, dim=1)
    complex = complex.unsqueeze(dim=2).expand(-1, -1, n_grid, -1, -1)
    for i in range(n_sp):
        ang = angle[:, i]
        steering_vector = __get_steering_vector(ang, pairs, mic_array_layout)
        steering_vector = steering_vector.unsqueeze(dim=-1)
        # Steering vector times exp(1j * IPD), normalised to unit modulus
        # (10e-20 guards against division by zero), then summed over dim 1.
        AF = steering_vector * torch.exp(1j * IPD)
        AF = AF / (torch.sqrt(AF.real**2 + AF.imag**2) + 10e-20)
        AF = AF.sum(dim=1)
        # NOTE(review): w_, mod_w_com and the normalised dpr map below appear
        # not to depend on the speaker index i (only on AF's leading size);
        # they could likely be hoisted out of the loop -- confirm.
        w_ = w.unsqueeze(dim=0).expand(AF.size()[0], -1, -1,
                                       -1).unsqueeze(-1).expand(
                                           -1, -1, -1, -1, channel)
        # Squared modulus of the weighted spectrogram.
        mod_w_com = (w_ * complex) * torch.conj(w_ * complex)
        # Sum over dim 1, then normalise by the total over the next dim.
        dpr = mod_w_com.sum(dim=1) / (
            (mod_w_com).sum(dim=1).sum(dim=1, keepdims=True) + 10e-20)
        # Convert the angle to a grid index by truncation (presumably `ang`
        # is in radians, matching the i * pi / 18 grid above).
        p = (ang / np.pi * 18).type(torch.long)
        dpr = dpr[range(batch_size), p]
        feature_IPD = IPD.reshape(batch_size,
                                  IPD.size()[1] * IPD.size()[2], IPD.size(-1))
        # Concatenate the three feature groups along dim 1; keep the real
        # part only, as float.
        feature_list = [AF, torch.cos(feature_IPD), dpr]
        fusion = torch.cat(feature_list, dim=1).real.float()
        if dic_data:
            data[i] = fusion.squeeze()
        else:
            return_list.append(fusion)
    if not dic_data:
        return return_list
Пример #26
0
    plt.title('Convolution matched filter')
    plt.xlabel(r'Time/$\mu s$')
    plt.ylabel('Amplitude')
    # Panel 222: imaginary part and magnitude of Sm1 against time (t is
    # scaled by 1e6 to match the microsecond axis label).
    plt.subplot(222)
    plt.plot(t * 1e6, th.imag(Sm1))
    plt.plot(t * 1e6, th.abs(Sm1))
    plt.grid()
    plt.legend(['Imaginary part', 'Amplitude'])
    plt.title('Convolution matched filter')
    plt.xlabel(r'Time/$\mu s$')
    plt.ylabel('Amplitude')
    # Panels 223 / 224: magnitude and phase of Ym1 against f.
    plt.subplot(223)
    plt.plot(f, th.abs(Ym1))
    plt.grid()
    plt.subplot(224)
    plt.plot(f, th.angle(Ym1))
    plt.grid()

    # Second figure: the same time-domain views for Sm2.
    plt.figure(2)
    # Panel 221: real part and magnitude of Sm2.
    plt.subplot(221)
    plt.plot(t * 1e6, th.real(Sm2))
    plt.plot(t * 1e6, th.abs(Sm2))
    plt.grid()
    plt.legend(['Real part', 'Amplitude'])
    plt.title('Correlation matched filter')
    plt.xlabel(r'Time/$\mu s$')
    plt.ylabel('Amplitude')
    # Panel 222: imaginary part and magnitude of Sm2.
    plt.subplot(222)
    plt.plot(t * 1e6, th.imag(Sm2))
    plt.plot(t * 1e6, th.abs(Sm2))
    plt.grid()
Пример #27
0
class TorchBox(qml.math.TensorBox):
    """Implements the :class:`~.TensorBox` API for Torch tensors.

    For more details, please refer to the :class:`~.TensorBox` documentation.
    """

    # Thin element-wise wrappers around the wrapped tensor (``self.data``).
    abs = wrap_output(lambda self: torch.abs(self.data))
    angle = wrap_output(lambda self: torch.angle(self.data))
    arcsin = wrap_output(lambda self: torch.asin(self.data))
    expand_dims = wrap_output(
        lambda self, axis: torch.unsqueeze(self.data, dim=axis))
    ones_like = wrap_output(lambda self: torch.ones_like(self.data))
    # int32/int64 tensors are cast to float64 before taking the square root.
    sqrt = wrap_output(lambda self: torch.sqrt(
        self.data.to(torch.float64)
        if self.data.dtype in (torch.int64, torch.int32) else self.data))
    T = wrap_output(lambda self: self.data.T)

    @staticmethod
    def astensor(tensor):
        """Convert ``tensor`` to a Torch tensor via ``torch.as_tensor``."""
        return torch.as_tensor(tensor)

    @wrap_output
    def cast(self, dtype):
        """Return the data converted to ``dtype``.

        ``dtype`` may be a ``torch.dtype`` or anything ``np.dtype`` accepts;
        in the latter case the NumPy dtype name is looked up on ``torch``.

        Raises:
            ValueError: if no Torch dtype with the matching name exists.
        """
        if isinstance(dtype, torch.dtype):
            return self.data.to(dtype)

        dtype_name = np.dtype(dtype).name
        torch_dtype = getattr(torch, dtype_name, None)

        if torch_dtype is None:
            raise ValueError(f"Unable to convert {dtype} to a Torch dtype")

        return self.data.to(torch_dtype)

    @staticmethod
    def _coerce_types(tensors):
        """Cast ``tensors`` to one common dtype when their dtypes differ.

        Complex dtypes win over floating ones, which win over integer ones;
        inside the winning category the widest dtype present is chosen.
        """
        dtypes = {i.dtype for i in tensors}

        if len(dtypes) == 1:
            return tensors

        complex_priority = [torch.complex64, torch.complex128]
        float_priority = [torch.float16, torch.float32, torch.float64]
        int_priority = [torch.int8, torch.int16, torch.int32, torch.int64]

        complex_type = [i for i in complex_priority if i in dtypes]
        float_type = [i for i in float_priority if i in dtypes]
        int_type = [i for i in int_priority if i in dtypes]

        # Priority lists are ordered narrow -> wide, so the last hit is the
        # widest dtype of the selected category.
        cast_type = complex_type or float_type or int_type
        cast_type = list(cast_type)[-1]

        return [t.to(cast_type) for t in tensors]

    @staticmethod
    @wrap_output
    def concatenate(values, axis=0):
        """Concatenate ``values`` along ``axis``.

        With ``axis=None`` every input is flattened first, reproducing
        NumPy's behaviour.
        """
        if axis is None:
            # flatten and then concatenate zero'th dimension
            # to reproduce numpy's behaviour
            tensors = [
                TorchBox.astensor(t).flatten()
                for t in TorchBox.unbox_list(values)
            ]
            return torch.cat(tensors, dim=0)

        tensors = [TorchBox.astensor(t) for t in TorchBox.unbox_list(values)]
        return torch.cat(tensors, dim=axis)

    @staticmethod
    @wrap_output
    def dot(x, y):
        """Dot product of ``x`` and ``y`` after coercing to a common dtype.

        Two scalars are multiplied, inputs with at most two dimensions use
        ``@``, and anything larger contracts the last axis of ``x`` with the
        second-to-last axis of ``y``.
        """
        x, y = [TorchBox.astensor(t) for t in TorchBox.unbox_list([x, y])]
        x, y = TorchBox._coerce_types([x, y])

        if x.ndim == 0 and y.ndim == 0:
            return x * y

        if x.ndim <= 2 and y.ndim <= 2:
            return x @ y

        return torch.tensordot(x, y, dims=[[-1], [-2]])

    @property
    def interface(self):
        """Name of the interface handled by this box."""
        return "torch"

    def numpy(self):
        """Return the data as a NumPy array (detached and moved to CPU)."""
        return self.data.detach().cpu().numpy()

    @property
    def requires_grad(self):
        """Whether the wrapped tensor has ``requires_grad`` set."""
        return self.data.requires_grad

    @property
    def shape(self):
        """Shape of the wrapped tensor as a plain tuple."""
        return tuple(self.data.shape)

    @staticmethod
    @wrap_output
    def stack(values, axis=0):
        """Stack ``values`` along a new dimension ``axis``."""
        tensors = [TorchBox.astensor(t) for t in TorchBox.unbox_list(values)]
        res = torch.stack(tensors, axis=axis)
        return res

    @wrap_output
    def sum(self, axis=None, keepdims=False):
        """Sum the data over ``axis`` (over everything when ``axis`` is None).

        Note that ``keepdims`` is only honoured when ``axis`` is given.
        """
        if axis is None:
            return torch.sum(self.data)

        return torch.sum(self.data, dim=axis, keepdim=keepdims)

    @wrap_output
    def take(self, indices, axis=None):
        """Select entries of the data at ``indices``.

        Args:
            indices: integer indices; converted to a tensor if needed.
            axis: axis to index along.  ``None`` indexes the flattened data.
                Negative values count from the last axis.

        Returns:
            The selected entries.
        """
        if not isinstance(indices, torch.Tensor):
            indices = self.astensor(indices)

        if axis is None:
            return self.data.flatten()[indices]

        if indices.ndim == 1:
            if (indices < 0).any():
                # index_select doesn't allow negative indices
                dim_length = self.shape[axis]

                indices = qml.math.where(indices >= 0, indices,
                                         indices + dim_length)

            return torch.index_select(self.data, dim=axis, index=indices)

        # Normalise a negative axis; `[slice(None)] * axis` would otherwise
        # be empty and the indices would silently be applied to axis 0.
        if axis < 0:
            axis += self.data.ndim

        fancy_indices = [slice(None)] * axis + [indices]
        # Multidimensional indexing must use a tuple; indexing with a plain
        # list of slices/tensors is deprecated.
        return self.data[tuple(fancy_indices)]

    @staticmethod
    @wrap_output
    def where(condition, x, y):
        """Element-wise choice between ``x`` and ``y`` based on ``condition``."""
        return torch.where(TorchBox.astensor(condition),
                           *TorchBox.unbox_list([x, y]))
Пример #28
0
def mag_phase(complex_tensor):
    """Split a complex tensor into its magnitude and phase.

    Returns a ``(magnitude, phase)`` pair computed with ``torch.abs`` and
    ``torch.angle`` respectively.
    """
    magnitude = torch.abs(complex_tensor)
    phase = torch.angle(complex_tensor)
    return magnitude, phase
Пример #29
0
        # Per-batch step: evaluate the loss on the selected batch entries,
        # accumulate its mean into the epoch total and backpropagate.
        loss = smooth_amplitude_loss(a_model, indices_target[take_ind],
                                     counts_target[take_ind])
        loss_sum = loss.mean()
        sum_loss += loss_sum.item()
        loss_sum.backward()

        # if i > probe_start:
        #     plotAbsAngle(psi_model.grad[0].cpu().detach().numpy(),'psi_model.grad')
        # plotAbsAngle(S_model[0].cpu().detach().numpy(), 'S_model')

        optimizer.step()
        optimizer.zero_grad()

    # Relative error between the model and the target T after removing a
    # global phase factor, evaluated on the region selected by `slic`.
    # NOTE(review): if th.vdot conjugates its first argument (as torch.vdot
    # does), angle(c) is the phase of S_model relative to T, and aligning T
    # to S_model would need exp(+1j * angle(c)) -- confirm the sign.
    c = th.vdot(T[slic].ravel(), S_model[slic].ravel())
    T_hat = T * th.exp(-1j * th.angle(c))
    dist = th.norm(S_model[slic] - T_hat[slic])
    # The distance over `slic` is normalised by the norm of the full T.
    x_norm = th.norm(T)
    err = dist / x_norm

    errs.append(err)
    # Turn the accumulated loss into a per-batch average before logging.
    sum_loss /= n_batches
    losses.append(sum_loss)
    print(f'{i:3d}  loss: {sum_loss}    err: {err}')

# print(f'i {i} loss {sum_loss}, C_model = {C_model[0]} , C_target = {C_target[0]}')
# %%
d = margin + M[0] // 2
d = 1
plotAbsAngle(S_model[0, d:-d, d:-d].cpu().detach().numpy(),
             'Reconstruction',
Пример #30
0
 def forward(self, x):
     """Return the phase of ``x`` twice: via ``Tensor.angle`` and ``torch.angle``."""
     return x.angle(), torch.angle(x)