def angle(input_, deg=False):
    """Elementwise phase angle; thin wrapper around ``torch.angle``.

    Parameters
    ----------
    input_ : DTensor
        Input dense tensor.
    deg : bool, optional
        When true the angle is converted to degrees, otherwise it is left
        in radians. By default False.

    Returns
    -------
    Tensor
        The phase angle of every element of ``input_``.
    """
    radians = torch.angle(input_)
    return radians * 180 / math.pi if deg else radians
def phase_comp(psi_comp, uwrap=False, dens=None):
    """Compute the phase (angle) of a single complex wavefunction component.

    Parameters
    ----------
    psi_comp : NumPy :obj:`array` or PyTorch :obj:`Tensor`
        A single wavefunction component.
    uwrap : bool, optional
        If True, unwrap the phase with ``rest.unwrap_phase``. Only supported
        for NumPy input. Default False.
    dens : array or Tensor, optional
        Density mask. Wherever ``dens`` is below ``1e-6 * dens.max()`` the
        returned phase is set to 0. Must be indexable-compatible with the
        result (same array type and shape as ``psi_comp``).

    Returns
    -------
    angle : NumPy :obj:`array` or PyTorch :obj:`Tensor`
        The phase (angle) of the component's wavefunction.

    Raises
    ------
    NotImplementedError
        If ``uwrap`` is True and ``psi_comp`` is a PyTorch tensor.
    TypeError
        If ``psi_comp`` is neither a NumPy array nor a PyTorch tensor.
    """
    if isinstance(psi_comp, np.ndarray):
        ang = np.angle(psi_comp)
        if uwrap:
            ang = rest.unwrap_phase(ang)
    elif isinstance(psi_comp, torch.Tensor):
        # Fail fast, before doing any work we would throw away.
        if uwrap:
            raise NotImplementedError("Unwrapping the complex phase is not "
                                      "implemented for PyTorch tensors.")
        ang = torch.angle(psi_comp)
    else:
        # Without this branch ``ang`` would be unbound below.
        raise TypeError(
            "psi_comp must be a NumPy array or a PyTorch tensor, got "
            f"{type(psi_comp).__name__}")

    if dens is not None:
        # Phase is meaningless where the density vanishes; zero it there.
        ang[dens < (dens.max() * 1e-6)] = 0
    return ang
def stft_to_phase_magn(
        complex_values: th.Tensor,
        nb_vec: int = constant.N_VEC) -> Tuple[th.Tensor, th.Tensor]:
    """Turn a complex STFT into stacked, [-1, 1]-scaled magnitude/phase chunks.

    The magnitude is Bark-scaled with ``bark_magn_scale``. The phase is
    unwrapped with ``unwrap`` and replaced by its first difference along
    dim 1. The first magnitude column is dropped so both tensors share the
    same width. Each tensor is min-max scaled to [-1, 1] independently, and
    leading columns are dropped so the width is a multiple of ``nb_vec``.
    Finally the tensor is split into ``nb_vec``-wide chunks stacked along a
    new dim 0.

    Parameters
    ----------
    complex_values : th.Tensor
        Complex STFT, indexed as ``[:, frame]`` along dim 1
        (NOTE(review): presumably ``(freq, frames)`` — confirm).
    nb_vec : int
        Width of each output chunk along dim 1.

    Returns
    -------
    Tuple[th.Tensor, th.Tensor]
        ``(magn, phase)``, each of shape ``(n_chunks, dim0, nb_vec)``.
        A constant tensor (zero value range) is mapped to -1 everywhere
        instead of NaN.
    """
    magn = bark_magn_scale(th.abs(complex_values), unscale=False)
    phase = unwrap(th.angle(complex_values))

    # Instantaneous-frequency-like feature: phase difference between
    # consecutive columns; align the magnitude to the same width.
    phase = phase[:, 1:] - phase[:, :-1]
    magn = magn[:, 1:]

    magn = _to_signed_unit_range(magn)
    phase = _to_signed_unit_range(phase)

    return _split_into_chunks(magn, nb_vec), _split_into_chunks(phase, nb_vec)


def _to_signed_unit_range(x: th.Tensor) -> th.Tensor:
    """Min-max scale ``x`` to [-1, 1]; a constant tensor maps to -1."""
    x_min = x.min()
    span = x.max() - x_min
    if span == 0:
        # (x - x_min) is all zeros here; avoid the 0/0 -> NaN division.
        return (x - x_min) * 2. - 1.
    return (x - x_min) / span * 2. - 1.


def _split_into_chunks(x: th.Tensor, nb_vec: int) -> th.Tensor:
    """Drop leading columns so dim 1 is a multiple of ``nb_vec``, then stack ``nb_vec``-wide chunks on dim 0."""
    x = x[:, x.size()[1] % nb_vec:]
    return th.stack(x.split(nb_vec, dim=1), dim=0)
def forward(self, input, angle):
    """Estimate ``self.num_spk`` source waveforms from a multi-channel mixture.

    Spatial features are computed from the per-microphone STFT: the angle
    feature ``AF``, the directional power ratio ``dpr`` and ``cos(IPD)``.
    They are concatenated with the waveform-encoder output and passed to
    ``self.TCN``, which predicts one mask per speaker. The masked encoder
    output is decoded back to waveforms.

    Parameters
    ----------
    input : Tensor
        Mixture waveform; flattened to ``(-1, samples)`` for the STFT.
        NOTE(review): presumably ``(batch, n_mic, samples)`` — confirm.
    angle : Tensor
        Target direction(s), forwarded to ``self.__get_steering_vector``.

    Returns
    -------
    Tensor
        Separated waveforms, ``(batch, num_spk, samples)``.
    """
    # Multi-channel STFT, padded with one zero frame so the frame count
    # matches the encoder output length used in the reshapes below.
    mag, ph, _, _ = self.stft.transform(input.reshape(-1, input.size()[-1]))
    pad = torch.zeros(mag.size()[0], mag.size()[1], 1,
                      dtype=input.dtype, device=input.device)
    mag = torch.cat([mag, pad], -1)
    ph = torch.cat([ph, pad], -1)

    output, rest = self.pad_signal(input)
    enc_output = self.encoder(output[:, :1])  # B, N, L

    mag = mag.view(enc_output.size(0), self.n_mic, -1, enc_output.size(-1))
    ph = ph.view(enc_output.size(0), self.n_mic, -1, enc_output.size(-1))
    spec = mag * torch.exp(ph * 1j)  # B, n_mic, F, T

    # Inter-channel phase difference for each configured microphone pair.
    ipd_list = []
    for m in self.pairs:
        ipd = torch.angle(spec[:, m[0]] * torch.conj(spec[:, m[1]]))
        ipd = ipd / (self.frequency_vector + 1.0)[:, None]
        ipd = ipd % torch.pi
        ipd_list.append(ipd.unsqueeze(dim=1))
    IPD = torch.cat(ipd_list, dim=1)

    steering_vector = self.__get_steering_vector(angle, self.pairs)
    steering_vector = steering_vector.unsqueeze(dim=-1)
    AF = steering_vector * IPD
    AF = AF / AF.sum(dim=1, keepdim=True).real

    # Directional power ratio: beamform with the fixed weights self.w
    # (indexed as [mic, grid, freq]) for every grid direction,
    #   dpr[b, g, f, t] = sum_m w[m, g, f] * spec[b, m, f, t],
    # then normalize the power over the grid axis. The einsum covers every
    # direction/frequency/frame and stays on the input device.
    dpr = torch.einsum('mgf,bmft->bgft',
                       self.w.to(torch.complex128),
                       spec.to(torch.complex128))
    power = dpr * torch.conj(dpr)
    dpr = power / torch.sum(power, dim=1, keepdim=True)

    feature_list = [enc_output.unsqueeze(dim=1), AF, dpr, torch.cos(IPD)]
    fusion = torch.cat(feature_list, dim=1).float()
    batch_size = output.size(0)
    fusion = fusion.view(batch_size, -1, fusion.size()[-1])

    # waveform encoder
    masks = torch.sigmoid(self.TCN(fusion)).view(
        batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
    masked_output = enc_output.unsqueeze(1) * masks  # B, C, N, L

    # waveform decoder
    output = self.decoder(masked_output.view(
        batch_size * self.num_spk, self.enc_dim, -1))  # B*C, 1, L
    output = output[:, :, self.stride:-(rest + self.stride)].contiguous()  # B*C, 1, L
    output = output.view(batch_size, self.num_spk, -1)  # B, C, T
    return output
def entropy_loss(ent_out, ent_gt):
    """Elementwise squared-error loss between ``ent_gt`` and ``ent_out``.

    For real inputs this returns ``(ent_gt - ent_out) ** 2``. If the
    difference is complex, the complex square is reduced to a real value as
    ``abs(d ** 2) + angle(d ** 2)``. No reduction over elements is performed.
    """
    loss = torch.square(ent_gt - ent_out)
    # Call is_complex() on the difference. A bare ``.is_complex`` attribute
    # is a bound method and always truthy, so it would never select the
    # real branch. Testing the difference also covers a complex target
    # paired with a real prediction.
    if loss.is_complex():
        return torch.abs(loss) + torch.angle(loss)
    return loss
def forward(self, x):
    """Enhance the STFT magnitude of ``x`` and resynthesize with the input phase.

    The magnitude spectrogram (``n_fft=128``) goes through an encoder of
    five conv stages separated by max-pooling. A decoder with skip
    connections follows: each up-sampled feature map is concatenated with
    the encoder output at the same depth. The 1x1-conv output is recombined
    with the original phase and inverted with ``torch.istft`` (which differs
    from librosa's istft).

    Returns
    -------
    tuple
        ``(d1, out)``: the squeezed predicted magnitude and the time-domain
        synthesis.
    """
    spectrum = torch.stft(x, n_fft=128, return_complex=True)
    magnitude = torch.unsqueeze(torch.abs(spectrum), dim=1)
    phase = torch.angle(spectrum)

    # Encoder: Conv1, then (Maxpool -> ConvN) for N = 2..5; keep every
    # stage's output for the skip connections.
    skips = [self.Conv1(magnitude)]
    for conv in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
        skips.append(conv(self.Maxpool(skips[-1])))

    # Decoder: start from the deepest feature map and pair each
    # up-sampling stage with the matching encoder output.
    feature = skips.pop()
    decoder_stages = ((self.Up5, self.Up_conv5), (self.Up4, self.Up_conv4),
                      (self.Up3, self.Up_conv3), (self.Up2, self.Up_conv2))
    for up, up_conv in decoder_stages:
        feature = up_conv(torch.cat((skips.pop(), up(feature)), dim=1))

    d1 = torch.squeeze(self.Conv_1x1(feature))
    # Re-attach the input phase; torch.istft is not identical to librosa's.
    out = torch.istft(d1 * torch.exp(1j * phase), n_fft=128)
    return d1, out
def test_angle(self):
    """Check ``ht.angle`` on complex (split/unsplit, radians/degrees) and real input."""
    values = [1.0, 1.0j, 1 + 1j, -2 + 2j, 3 - 3j]

    def check_meta(result, dtype, shape):
        # Device, dtype and global shape of a heat result.
        self.assertIs(result.device, self.device)
        self.assertIs(result.dtype, dtype)
        self.assertEqual(result.shape, shape)

    # 1-D complex input, unsplit and split along axis 0: the local result
    # must equal torch.angle of the local tensor.
    for kwargs in ({}, {"split": 0}):
        a = ht.array(values, **kwargs)
        result = ht.angle(a)
        expected = torch.angle(a.larray)
        check_meta(result, ht.float, (5, ))
        self.assertTrue(torch.equal(result.larray, expected))

    # 2-D complex input split along axis 1, angle in degrees.
    a = ht.array([[1.0, 1.0j], [1 + 1j, -2 + 2j], [3 - 3j, -4 - 4j]], split=1)
    result = ht.angle(a, deg=True)
    expected = ht.array(
        [[0.0, 90.0], [45.0, 135.0], [-45.0, -135.0]],
        dtype=ht.float32,
        device=self.device,
        split=1,
    )
    check_meta(result, ht.float32, (3, 2))
    self.assertTrue(ht.equal(result, expected))

    # Not complex: positive reals have zero angle.
    a = ht.ones((4, 4), split=1)
    result = ht.angle(a)
    expected = ht.zeros((4, 4), split=1)
    check_meta(result, ht.float32, (4, 4))
    self.assertTrue(ht.equal(result, expected))
def sumofsq(image_in, keep_dims=False, axis=-1, name="sumofsq", type="mag"):
    """Compute the square root of the sum of squares along ``axis``.

    Parameters
    ----------
    image_in : Tensor
        Input (typically complex) tensor.
    keep_dims : bool
        Keep the reduced axis with size 1.
    axis : int
        Axis to reduce over.
    name : str
        Unused in this body; kept for signature compatibility.
    type : str
        ``"mag"`` squares ``abs(image_in)``; any other value squares
        ``angle(image_in)``.

    Returns
    -------
    Tensor
        ``sqrt(sum(x ** 2, axis))`` where ``x`` is the magnitude or the phase.
    """
    if type == "mag":
        image_out = torch.square(torch.abs(image_in))
    else:
        image_out = torch.square(torch.angle(image_in))
    # torch uses dim/keepdim; the TF-style ``keep_dims`` keyword raises TypeError.
    image_out = torch.sum(image_out, dim=axis, keepdim=keep_dims)
    return torch.sqrt(image_out)
def gcc_features(complex_specs: torch.Tensor, n_mels: int) -> torch.Tensor:
    """Generalized cross-correlation (GCC-PHAT) features for every channel pair.

    Based on the code of DCASE2020 SELDnet ``cls_feature_class.py``.

    Parameters
    ----------
    complex_specs : torch.Tensor
        Spectrogram ``[chan, freq, time]``. If it is not complex it is
        interpreted via ``torch.view_as_complex`` (trailing dim of size 2).
    n_mels : int
        Number of lags kept per pair: the ``n_mels // 2`` most negative lags
        followed by the ``(n_mels + 1) // 2`` non-negative lags.

    Returns
    -------
    torch.Tensor
        ``[n_chan * (n_chan - 1) / 2, n_mels, time]``, pairs ordered
        ``(0, 1), (0, 2), ..., (1, 2), ...``.
    """
    if not torch.is_complex(complex_specs):
        complex_specs = torch.view_as_complex(complex_specs)
    n_chan = complex_specs.size(0)
    # Split sizes add up to exactly n_mels, for odd n_mels too.
    n_neg = n_mels // 2
    n_pos = (n_mels + 1) // 2

    gcc_feat = []
    for m in range(n_chan):
        for n in range(m + 1, n_chan):
            cross = torch.conj(complex_specs[m]) * complex_specs[n]
            # PHAT weighting: keep only the phase of the cross-spectrum.
            cc = torch.fft.irfft(torch.exp(1j * torch.angle(cross)), dim=0)
            n_lags = cc.size(0)
            # Negative lags sit at the end of the irfft output. Slice from
            # an explicit start index so n_neg == 0 yields an empty slice
            # rather than the whole tensor.
            cc = torch.cat([cc[n_lags - n_neg:], cc[:n_pos]], dim=0)
            gcc_feat.append(cc)

    if not gcc_feat:
        # Fewer than two channels: no pairs, return an empty feature stack.
        return torch.zeros((0, n_mels, complex_specs.size(-1)),
                           dtype=complex_specs.real.dtype,
                           device=complex_specs.device)
    return torch.stack(gcc_feat, dim=0)
def get_phase_stft_magnitude(
    raw_data: torch.Tensor,
    sampling_rate_in_hz: int,
    window_length_in_s: float,
    window_shift_in_s: float,
    num_fft_points: int,
    window_type: str,
) -> torch.Tensor:
    """Phase and magnitude of the STFT of ``raw_data``, stacked then transposed.

    The STFT is computed by ``_get_stft``. Phase comes first and magnitude
    second along dim 1, and dims 0 and 1 of the result are swapped.
    NOTE(review): the layout of ``_get_stft``'s output is not visible here;
    confirm which axis is frequency before relying on dim 1.
    """
    spectrum = _get_stft(raw_data, sampling_rate_in_hz, window_length_in_s,
                         window_shift_in_s, num_fft_points,
                         window_type=window_type)
    features = torch.cat((torch.angle(spectrum), torch.abs(spectrum)), dim=1)
    return features.transpose(0, 1)
def stft(self, x):
    """Magnitude and phase spectra of ``x``.

    Args:
        x (Tensor): Input signal tensor (B, T).

    Returns:
        Tensor: x_mag, x_phs
            Magnitude and phase spectra (B, fft_size // 2 + 1, frames),
            computed with ``self.fft_size``, ``self.hop_size``,
            ``self.win_length`` and ``self.window``.
    """
    spectrum = torch.stft(
        x,
        n_fft=self.fft_size,
        hop_length=self.hop_size,
        win_length=self.win_length,
        window=self.window,
        return_complex=True,
    )
    return spectrum.abs(), spectrum.angle()
def update_variables(self):
    """Recompute the single-sideband reconstruction and probe-derived quantities.

    Side effects (all on ``self``):

    * ``Psi_Qp``, ``Psi_Qp_left_sb`` and ``Psi_Qp_right_sb`` are zeroed in
      place, then filled by ``single_sideband_reconstruction``.
    * ``Psi_Rp*`` receive their orthonormal inverse 2-D FFTs (in place).
    * ``Gamma`` is set by ``disk_overlap_function``.
    * ``Psi_shifted``, ``phases`` and ``psi`` are derived from one
      evaluation of ``self.probe_gen`` on CUDA copies of ``self.C`` and
      ``self.A``.
    """
    self.Psi_Qp[:] = 0
    self.Psi_Qp_left_sb[:] = 0
    self.Psi_Qp_right_sb[:] = 0
    eps = 1e-3
    rotation_rad = np.deg2rad(self.rotation_deg)

    single_sideband_reconstruction(
        self.G,
        self.Qx1d,
        self.Qy1d,
        self.Kx,
        self.Ky,
        self.C,
        rotation_rad,
        self.meta.alpha_rad,
        self.Psi_Qp,
        self.Psi_Qp_left_sb,
        self.Psi_Qp_right_sb,
        eps,
        self.meta.wavelength,
    )

    self.Psi_Rp[:] = fft.ifft2(self.Psi_Qp, norm="ortho")
    self.Psi_Rp_left_sb[:] = fft.ifft2(self.Psi_Qp_left_sb, norm="ortho")
    self.Psi_Rp_right_sb[:] = fft.ifft2(self.Psi_Qp_right_sb, norm="ortho")

    self.Gamma = disk_overlap_function(self.Qx_max1d, self.Qy_max1d, self.Kx,
                                       self.Ky, self.C, rotation_rad,
                                       self.meta.alpha_rad,
                                       self.meta.wavelength)

    # Evaluate the probe once; the shifted probe and its phase both derive
    # from this single result (previously probe_gen ran twice with the same
    # inputs).
    Psi = self.probe_gen(
        th.tensor(self.C.get()).cuda(),
        th.tensor(self.A).cuda())
    self.Psi_shifted = th.fft.fftshift(Psi)
    self.phases = th.angle(self.Psi_shifted)
    self.psi = th.fft.fftshift(th.fft.ifft2(Psi))
def calculate_phase(field, deg=False):
    """
    Definition to calculate phase of a single or multiple given electric field(s).

    Parameters
    ----------
    field        : torch.cfloat
                   Electric fields or an electric field.
    deg          : bool
                   If set True, the angles will be returned in degrees.

    Returns
    ----------
    phase        : torch.float
                   Phase or phases of electric field(s), in radians by
                   default or in degrees when ``deg`` is True.
    """
    phase = torch.angle(field)
    if deg:
        phase *= 180. / np.pi
    return phase
def phase_harmonics(z, k):
    """
    Apply phase harmonics along the third-from-last axis of ``z``.

    For each index ``j`` of that axis, ``k[j]`` selects the transform:
    ``k == 0`` replaces the entry by its modulus (cast back to ``z``'s
    dtype), ``k >= 2`` multiplies the phase by ``k`` while keeping the
    modulus, and every other ``k`` (including 1) leaves the entry unchanged.
    The input is not modified.

    Parameters
    ----------
    z : tensor
        Input.
    k : tensor
        Exponents (1-D, one per entry of ``z``'s axis -3).

    Returns
    -------
    result : tensor
        Output.
    """
    zero_idx = torch.where(k == 0)[0]
    high_idx = torch.where(k >= 2)[0]
    out = z.clone()

    # k == 0: modulus only.
    modulus = torch.abs(torch.index_select(out, -3, zero_idx))
    out[..., zero_idx, :, :] = modulus.to(out.dtype)

    # k >= 2: r * exp(i * k * theta), written with cos/sin.
    selected = torch.index_select(out, -3, high_idx)
    powers = k[high_idx][:, None, None]
    theta = torch.angle(selected)
    out[..., high_idx, :, :] = torch.abs(selected) * (
        torch.cos(powers * theta) + 1j * torch.sin(powers * theta))
    return out
def spectralResidueSaliency(image):
    """Compute a visual saliency map with the spectral residual method.

    Method of Xiaodi Hou and Liqing Zhang, "Saliency detection: a spectral
    residual approach". Steps:

    1. Downscale the image by ``scale``.
    2. Subtract the box-filtered log amplitude spectrum from the log
       amplitude spectrum.
    3. Recombine this residual with the original phase and inverse-transform.
    4. Square the magnitude, Gaussian-blur it, and normalize it.
    5. Resize the result back to the input's width/height.

    The constants below are tuning parameters.

    NOTE(review): this body mixes OpenCV calls (``cv2.resize``,
    ``cv2.boxFilter``, ``cv2.GaussianBlur``), which operate on NumPy arrays,
    with torch ops (``torch.log``, ``torch.abs``, ``torch.angle``,
    ``torch.exp``, ``torch.nn.functional.normalize``) that expect tensors.
    It looks like a partial NumPy -> PyTorch port. Confirm which array type
    ``image`` and the opaque ``fft2``/``ifft2`` helpers use before relying
    on it.
    """
    scale = 0.25  # constant: downscaling factor applied before the FFT
    aveKernelSize = 3  # constant: box-filter size for the averaged log amplitude
    gauSigma = 3.8  # constant: sigma of the final Gaussian blur
    gauSize = 9  # constant: kernel size of the final Gaussian blur

    # correction of built-in round function which
    # "for values exactly halfway between rounded decimal values, rounds to the nearest even value"
    # as opposed to matlab which always rounds it up
    def _round(a):
        # NOTE(review): ``a`` is a Python float at the call sites below, but
        # torch.nextafter / torch.rint expect tensors — confirm this works.
        return int(torch.rint(torch.nextafter(a, a + 1)))

    inImg = cv2.resize(
        image, (_round(scale * image.shape[1]), _round(scale * image.shape[0])),
        interpolation=cv2.INTER_CUBIC)
    myFFT = fft2(inImg)
    myLogAmplitude = torch.log(torch.abs(myFFT))
    myPhase = torch.angle(myFFT)
    # Spectral residual: log amplitude minus its local (box-filtered) mean.
    # NOTE(review): the 4th positional parameter of cv2.boxFilter is ``dst``,
    # not the border type; ``borderType=cv2.BORDER_REPLICATE`` was probably
    # intended — confirm.
    mySpectralResidual = myLogAmplitude - cv2.boxFilter(
        myLogAmplitude, -1, (aveKernelSize, aveKernelSize), cv2.BORDER_REPLICATE)
    # Recombine the residual amplitude with the original phase and invert.
    saliencyMap = torch.abs(ifft2(torch.exp(mySpectralResidual + 1j * myPhase)))**2
    # NOTE(review): the 4th positional parameter of cv2.GaussianBlur is
    # ``dst``; the second ``gauSigma`` was probably meant as ``sigmaY=`` —
    # confirm.
    blurred = cv2.GaussianBlur(saliencyMap, (gauSize, gauSize), gauSigma, gauSigma)
    # NOTE(review): F.normalize with default arguments L2-normalizes along
    # dim 1; confirm this is the intended normalization.
    saliencyMap = torch.nn.functional.normalize(blurred)
    # Back to the input's (width, height).
    return cv2.resize(saliencyMap, (image.shape[1], image.shape[0]))
def cartesian_to_polar(input, dim=3):
    """Split a complex tensor into stacked ``(magnitude, phase)`` components.

    Parameters
    ----------
    input : Tensor
        Complex tensor.
    dim : int, optional
        Axis along which magnitude and phase are stacked. The default of 3
        preserves the previous hard-coded behavior (for a 3-D input this is
        the new last axis).

    Returns
    -------
    Tensor
        ``torch.stack((abs(input), angle(input)), dim=dim)``.
    """
    return torch.stack((torch.abs(input), torch.angle(input)), dim=dim)
def GW_loss_prep(temp_index, data, y_pred, temp_mean, temp_sd, gw_mean, gw_std, num_task, type='fft'):
    """Compute observed and predicted groundwater (GW) signature terms for a loss.

    Axis 0 of ``data`` and ``y_pred`` indexes reaches and axis 1 indexes
    daily values. ``data[:, :, num_task:]`` holds the GW-analysis columns.
    At day 0, column 0 is the observed amplitude ratio, column 1 the
    observed phase lag and column 2 the observed mean temperature. The last
    column is the air-temperature series.

    Parameters
    ----------
    temp_index : int-like
        Column of ``y_pred`` holding the (scaled) predicted water temperature.
    data : Tensor
        Observations, indexed ``[reach, day, feature]``.
    y_pred : Tensor
        Predictions, indexed ``[reach, day, output]``.
    temp_mean, temp_sd : float or Tensor
        Statistics used to unscale the predicted temperature.
    gw_mean, gw_std : sequence of length 3
        Scaling statistics for [amplitude ratio, phase lag, mean temperature].
    num_task : int
        Number of leading ``data`` columns that are not GW-analysis columns.
    type : str
        Method; only ``'fft'`` passes the assertion below.

    Returns
    -------
    tuple
        ``(Ar_obs, Ar_pred, delPhi_obs, delPhi_pred, Tmean_obs, Tmean_pred)``.
        The ``*_pred`` terms are scaled with ``gw_mean`` / ``gw_std``.
    """
    import math

    # Only 'fft' is accepted, so the 'linalg' branch below is unreachable
    # while this assertion stays in place.
    assert type == 'fft', "the groundwater loss calculation method must be fft"

    y_true = data[:, :, num_task:]
    col = int(temp_index)
    # Extract just the predicted temperature and unscale it prior to
    # calculating the amplitude and phase (this creates a new tensor, so
    # the in-place clamp below does not touch y_pred).
    y_pred_temp = y_pred[:, :, col:col + 1] * temp_sd + temp_mean
    # set temps < 1 to 1
    y_pred_temp[y_pred_temp < 1] = 1

    Ar_obs = y_true[:, 0, 0]
    delPhi_obs = y_true[:, 0, 1]
    Tmean_obs = y_true[:, 0, 2]

    if type == 'fft':
        # squeeze(-1) rather than squeeze(): keep the reach axis even when
        # there is a single reach.
        Aw, Phiw_out = _fft_amplitude_phase(y_pred_temp.squeeze(-1))
        # get the air signal properties
        Aa, Phia_out = _fft_amplitude_phase(y_true[:, :, -1])

        # delPhi_pred = the difference in phase between the water temp and
        # air temp sinusoids, in days
        delPhi_pred = Phia_out - Phiw_out
        delPhi_pred = (delPhi_pred * 365 / (2 * math.pi) - gw_mean[1]) / gw_std[1]
        # Ar_pred = the ratio of the water temp and air temp amplitudes
        Ar_pred = (Aw / Aa - gw_mean[0]) / gw_std[0]

    elif type == "linalg":
        x_lm = y_true[:, :, -3:-1]  # extract the sin(wt) and cos(wt)
        # Design matrix per reach x day; the 1's are the intercept of
        # T(t) = T_mean + a*sin(wt) + b*cos(wt)
        # (Johnson, Z.C., et al., 2021, Heed the data gap: Guidelines for
        # using incomplete datasets in annual stream temperature analyses:
        # Ecological Indicators, v. 122, p. 107229).
        X_mat = torch.stack((torch.ones(y_pred_temp.shape[0:2]).to(y_pred_temp.device),
                             x_lm[:, :, 0], x_lm[:, :, 1]), dim=1)
        # Batched least squares. With X = X_mat^T (days x 3), the product
        # X^T pinv(X X^T) equals pinv(X). einsum keeps the reach axis
        # (axis 0) as a batch axis.
        X_mat_T = torch.permute(X_mat, dims=(0, 2, 1))
        X_mat_T_dot = torch.einsum('bij,bjk->bik', X_mat_T, X_mat)
        X_mat_inv = torch.linalg.pinv(X_mat_T_dot)
        X_mat_inv_dot = torch.einsum('bij,bjk->bik', X_mat_inv, X_mat_T)
        # a_b holds the regression coefficients (reach x [[intercept],[a],[b]])
        a_b = torch.einsum('bij,bik->bjk', X_mat_inv_dot.float(), y_pred_temp.float())
        # Aw = amplitude of the water temp sinusoid (deg C): sqrt(a^2 + b^2)
        Aw = torch.sqrt(a_b[:, 1, 0]**2 + a_b[:, 2, 0]**2)
        # Phiw = phase of the water temp sinusoid (radians): atan(b / a)
        Phiw = torch.atan(a_b[:, 2, 0] / a_b[:, 1, 0])

        # calculate the air properties
        y_true_air = y_true[:, :, -1:]
        a_b_air = torch.einsum('bij,bik->bjk', X_mat_inv_dot, y_true_air)
        A_air = torch.sqrt(a_b_air[:, 1, 0]**2 + a_b_air[:, 2, 0]**2)
        Phi_air = torch.atan(a_b_air[:, 2, 0] / a_b_air[:, 1, 0])

        # delPhi_pred = the difference in phase between the water temp and
        # air temp sinusoids, in days
        delPhi_pred = Phi_air - Phiw
        delPhi_pred = (delPhi_pred * 365 / (2 * math.pi) - gw_mean[1]) / gw_std[1]
        # Ar_pred = the ratio of the water temp and air temp amplitudes
        Ar_pred = (Aw / A_air - gw_mean[0]) / gw_std[0]

    y_pred_temp = y_pred_temp.squeeze(-1)
    y_pred_mean = torch.mean(y_pred_temp, 1, keepdim=True)
    # scale the predicted mean temp
    Tmean_pred = torch.squeeze((y_pred_mean - gw_mean[2]) / gw_std[2])

    return Ar_obs, Ar_pred, delPhi_obs, delPhi_pred, Tmean_obs, Tmean_pred


def _fft_amplitude_phase(signal):
    """Per row of a 2-D (reach x day) signal: (peak |rFFT| / number of rFFT bins, phase of rFFT bin 1) of the de-meaned signal."""
    spectrum = torch.fft.rfft(signal - torch.mean(signal, 1, keepdim=True))
    amplitude = torch.max(torch.abs(spectrum), 1).values / spectrum.shape[1]
    phase = torch.angle(spectrum)[:, 1]
    return amplitude, phase
# Compare the functional FFT convolution (ts.fftconv1) with the FFTConv1
# layer on a small example. Complex values are represented by a trailing
# axis of size 2 holding (real, imag).
shape = 'same'
ftshift = False
x_np = np.array([1, 2, 3, 4, 5])
h_np = np.array([1 + 2j, 2, 3, 4, 5, 6, 7])
x_th = th.tensor(x_np)
h_th = th.tensor(h_np)
# Real signal -> (real, imag=0) pairs.
# NOTE(review): x_th is an integer tensor (from an int array) while th.zeros
# is floating point; this relies on th.stack type promotion — confirm.
x_th = th.stack([x_th, th.zeros(x_th.size())], dim=-1)
# Complex filter -> (real, imag) pairs.
h_th = th.stack([h_th.real, h_th.imag], dim=-1)
y1 = ts.fftconv1(x_th, h_th, axis=0, nfft=None, shape=shape, ftshift=ftshift)
fftconv1layer = FFTConv1(h_th.size(0), h=h_th, nfft=None, shape=shape)
for p in fftconv1layer.parameters():
    print(p)
y2 = fftconv1layer.forward(x_th)
# y2 = th.view_as_complex(y2)
y2 = y2.cpu().detach()
# print(y1)
# print(y2)
# Total absolute difference, and summed phase difference.
# NOTE(review): the phase term sums signed differences, which can cancel out;
# th.abs(...) may have been intended. Also, if y1/y2 are real tensors in the
# (real, imag) layout, th.angle yields only 0 or pi per element — confirm.
print(th.sum(th.abs(y1 - y2)), th.sum(th.angle(y1) - th.angle(y2)))
def meta_angle_out(self, out): torch._resize_output_(out, self.size(), self.device) return out.copy_(torch.angle(self))
def vaerecon(ksp, coilmaps, mode, vae_model, gt, logdir, device, writer=False,
             norm=1, nsampl=100, boot_samples=500, k=1, patchsize=28,
             parfact=25, num_iter=200, stepsize=5e-4, lmb=0.01, num_priors=1,
             use_momentum=True):
    """Iterative multi-coil reconstruction with a magnitude prior and phase regularization.

    Each outer iteration (``it`` in ``range(0, num_iter, 2)``) runs four steps:

    1. ``num_priors`` gradient steps on the magnitude prior: ``'TV'`` via
       ``tv_norm``, or ``'DDP'``/``'JDDP'`` via ``prior_gradient``.
    2. An optional TV projection of the phase (when ``lmb > 0``).
    3. Coil-map re-estimation (``'JDDP'`` only).
    4. A data-consistency projection that re-inserts the acquired k-space
       samples.

    Parameters
    ----------
    ksp : Tensor
        Undersampled k-space, unpacked as ``(coils, rows, cols)``. The
        sampling mask is ``abs(ksp[0]) > 0``.
    coilmaps : Tensor
        Initial coil maps. For ``'JDDP'`` they are replaced by a
        polynomial-basis estimate.
    mode : str
        ``'TV'``, ``'DDP'`` or ``'JDDP'``. Any other value prints an error
        and calls ``exit()`` inside the iteration loop.
    vae_model : module
        Moved to ``device``; passed to ``prior_gradient`` / ``prior_value``.
    gt : Tensor
        Reference image, used only for metrics.
    logdir, k, use_momentum
        Accepted but not used in this body.
    device : torch.device or str
        Device for ``ksp``, ``coilmaps``, ``vae_model`` and the mask.
    writer : object or False
        Logger exposing ``log`` and ``Image`` (presumably the ``wandb``
        module — confirm). A falsy value disables logging.
    norm : float
        Divides ``rss`` for metrics and for the returned image.
    nsampl, boot_samples, patchsize, parfact
        Forwarded to the prior helpers.
    num_iter : int
        Upper bound of the outer loop (stepped by 2).
    stepsize : float
        Prior gradient step size.
    lmb : float
        Phase TV weight; ``<= 0`` disables the phase step.
    num_priors : int
        Prior gradient steps per outer iteration.

    Returns
    -------
    Tensor
        ``rss_pytorch`` of the final k-space estimate divided by ``norm``
        (presumably a root-sum-of-squares image).

    NOTE(review): this body was recovered from a whitespace-flattened copy.
    Which statements sit inside the ``if writer:`` blocks was reconstructed
    and should be confirmed against the original source — notably the
    metric computation and the ``rss`` refresh in the DDP branch.
    """
    # Init data
    imcoils, imsizer, imsizec = ksp.shape
    ksp = ksp.to(device)
    coilmaps = coilmaps.to(device)
    vae_model = vae_model.to(device)
    # Sampling mask: k-space locations with a non-zero sample in coil 0.
    uspat = (torch.abs(ksp[0]) > 0).type(torch.uint8).to(device)
    recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)
    rss = rss_pytorch(ksp)

    # Init coilmaps estimation with JSENSE
    if mode == 'JDDP':
        # Polynomial order
        max_basis_order = 6
        num_coeffs = (max_basis_order + 1)**2

        # Create the basis functions for the sense estimation
        basis_funct = create_basis_functions(imsizer,
                                             imsizec,
                                             max_basis_order,
                                             show_plot=False)

        plot_basis = False
        if plot_basis:
            for i in range(num_coeffs):
                writer.log({
                    "Basis funcs": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.from_numpy(basis_funct[i, :, :]))),
                                     caption="")
                    ]
                })

        # One copy of the basis per coil.
        basis_funct = torch.from_numpy(
            np.tile(basis_funct[np.newaxis, :, :, :],
                    [coilmaps.shape[0], 1, 1, 1])).to(device)

        coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct, uspat)

        coilmaps = torch.sum(
            coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct,
            1).to(device)

        recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)

        if writer:
            for i in range(coilmaps.shape[0]):
                writer.log(
                    {
                        "abs Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.abs(coilmaps[i, :, :]))),
                                caption="")
                        ]
                    },
                    step=0)
                writer.log(
                    {
                        "phase Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.angle(coilmaps[i, :, :]))),
                                caption="")
                        ]
                    },
                    step=0)

        print("Coilmaps init done")

    # Log
    if writer:
        writer.log(
            {
                "Gt rss": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(gt)),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored rss": [
                    writer.Image(transforms.ToPILImage()(
                        normalize_tensor(rss)),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored abs": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        torch.abs(recs_gpu))),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "Restored Phase": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        torch.angle(recs_gpu))),
                                 caption="")
                ]
            },
            step=0)
        writer.log(
            {
                "diff rss": [
                    writer.Image(transforms.ToPILImage()(normalize_tensor(
                        (rss.detach().cpu() / norm - gt.detach().cpu()))),
                                 caption="")
                ]
            },
            step=0)

        # Metrics exclude the first and last 160 rows.
        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        nmse_v = nmse(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)
        writer.log({"SSIM": ssim_v, "NMSE": nmse_v, "PSNR": psnr_v}, step=0)

        lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize, parfact,
                              nsampl, vae_model)
        writer.log({"ELBO": lik}, step=0)
        writer.log({"DC err": dc}, step=0)

    t = 1  # NOTE(review): assigned but not used in this body
    for it in range(0, num_iter, 2):
        print('Itr: ', it)

        # Magnitude prior projection step
        for _ in range(num_priors):
            # Gradient descent of Prior
            if mode == 'TV':
                tvnorm, abstvgrad = tv_norm(torch.abs(rss))
                # Magnitude gradient applied along the current phase
                # direction (recs / |recs|).
                priorgrad = abstvgrad * recs_gpu / (torch.abs(recs_gpu))
                recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:  #and it%10 == 0:
                    writer.log(
                        {
                            "TVgrad": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(abstvgrad)),
                                             caption="")
                            ]
                        },
                        step=it + 1)
                    writer.log(
                        {
                            "TV": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(tvnorm)),
                                             caption="")
                            ]
                        },
                        step=it + 1)

            elif mode == 'DDP' or mode == 'JDDP':
                g_abs_lik, est_uncert, g_dc = prior_gradient(
                    rss, ksp, uspat, coilmaps, patchsize, parfact, nsampl,
                    vae_model, boot_samples, mode)

                # Magnitude gradient applied along the current phase
                # direction (recs / |recs|).
                priorgrad = g_abs_lik * recs_gpu / (torch.abs(recs_gpu))

                # NOTE(review): ``it`` starts at 0, so this is always true.
                if it > -1:
                    recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:
                    # Log
                    writer.log(
                        {
                            "VAEgrad abs": [
                                writer.Image(transforms.ToPILImage()(
                                    normalize_tensor(torch.abs(g_abs_lik))),
                                             caption="")
                            ]
                        },
                        step=it + 1)
                    writer.log({"STD": torch.mean(torch.abs(est_uncert))},
                               step=it + 1)

                    # Re-insert the acquired samples to compute the RSS used
                    # for the metrics (and, via ``rss``, the next prior step).
                    tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
                    tmp2 = ksp * uspat.unsqueeze(0)
                    tmp = tmp1 + tmp2
                    rss = rss_pytorch(tmp)

                    nmse_v = nmse(
                        (rss[160:-160].detach().cpu().numpy() / norm),
                        gt[160:-160].detach().cpu().numpy())
                    ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ',
                          psnr_v)
                    writer.log({
                        "SSIM": ssim_v,
                        "NMSE": nmse_v,
                        "PSNR": psnr_v
                    },
                               step=it + 1)
            else:
                # NOTE(review): terminates the whole process on an unknown
                # mode; raising ValueError up front would be friendlier.
                print("Error: Prior method does not exists.")
                exit()

        # Phase projection step
        if lmb > 0:
            tmpa = torch.abs(recs_gpu)
            tmpp = torch.angle(recs_gpu)
            # We apply phase regularization to prefer smooth phase images
            #tmpptv = reg2_proj(tmpp, imsizer, imsizec, alpha=lmb, niter=2)  # 0.1, 15
            tmpptv = tv_proj(tmpp, mu=0.125, lmb=lmb, IT=50)  # 0.1, 15
            # We combine back the phase and the magnitude
            recs_gpu = tmpa * torch.exp(1j * tmpptv)

        # Coilmaps estimation step (if JSENSE)
        if mode == 'JDDP':
            # NOTE(review): an earlier comment claimed this runs on the CPU
            # because of complex-number support on GPU; nothing here moves
            # tensors to the CPU — confirm inside sense_estimation_ls.
            coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct,
                                               uspat)
            coilmaps = torch.sum(
                coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct,
                1).to(device)

            if writer:
                writer.log(
                    {
                        "abs Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.abs(coilmaps[0, :, :]))),
                                caption="")
                        ]
                    },
                    step=it + 1)
                writer.log(
                    {
                        "phase Coilmaps": [
                            writer.Image(
                                transforms.ToPILImage()(normalize_tensor(
                                    torch.angle(coilmaps[0, :, :]))),
                                caption="")
                        ]
                    },
                    step=it + 1)

        # Data consistency projection: keep the acquired k-space samples and
        # use the current estimate only where uspat is 0.
        tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
        tmp2 = ksp * uspat.unsqueeze(0)
        tmp = tmp1 + tmp2
        recs_gpu = tFT_pytorch(tmp, coilmaps)
        # recs[it + 2] = recs_gpu.detach().cpu().numpy()
        rss = rss_pytorch(tmp)

        # Log
        nmse_v = nmse((rss[160:-160].detach().cpu().numpy() / norm),
                      gt[160:-160].detach().cpu().numpy())
        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)

        if writer:
            writer.log({
                "SSIM": ssim_v,
                "NMSE": nmse_v,
                "PSNR": psnr_v
            },
                       step=it + 1)
            writer.log(
                {
                    "Restored rss": [
                        writer.Image(transforms.ToPILImage()(
                            normalize_tensor(rss)),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "Restored Phase": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.angle(recs_gpu))),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "diff rss": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            (rss.detach().cpu() / norm - gt.detach().cpu()))),
                                     caption="")
                    ]
                },
                step=it + 1)
            writer.log(
                {
                    "Restored 1ch kspace": [
                        writer.Image(transforms.ToPILImage()(normalize_tensor(
                            torch.log(torch.abs(tmp[0])))),
                                     caption="")
                    ]
                },
                step=it + 1)

            lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize,
                                  parfact, nsampl, vae_model)
            writer.log({"ELBO": lik}, step=it + 1)
            writer.log({"DC err": dc}, step=it + 1)

    return rss / norm
def L2_phase(x):
    """Smoothness penalty on a phase map using wrapped finite differences.

    Differences between neighbours along dims 1 and 2 are wrapped into
    (-pi, pi] via ``angle(exp(i * d))``. The result is the mean of the
    squared wrapped differences along dim 1 plus the same quantity along
    dim 2.
    """
    def _wrapped_sq_mean(lead, trail):
        return (torch.angle(torch.exp(1j * lead - 1j * trail))**2).mean()

    along_dim1 = _wrapped_sq_mean(x[:, :-1], x[:, 1:])
    along_dim2 = _wrapped_sq_mean(x[:, :, :-1], x[:, :, 1:])
    return along_dim1 + along_dim2
# ---Recieved signal Sr = chirp_recv(t, Tp, K, Fc, a=1., g=G, r=R) chirp = Chirp(Tp=Tp, K=K, Fc=Fc, a=1.) St = chirp.tran(t) Sr = chirp.recv(t, g=G, r=R) plt.figure() plt.subplot(221) plt.plot(t * 1e6, th.real(St)) plt.plot(t * 1e6, th.imag(St)) plt.xlabel('Time/us') plt.legend(['real', 'imag']) plt.subplot(222) plt.plot(t * 1e6, th.angle(St)) plt.xlabel('Time/us') plt.subplot(223) plt.plot(t * 1e6, th.real(Sr)) plt.plot(t * 1e6, th.imag(Sr)) plt.xlabel('Time/us') plt.legend(['real', 'imag']) plt.subplot(224) plt.plot(t * 1e6, th.angle(Sr)) plt.xlabel('Time/us') plt.show() # ---Frequency domain Yt = fftshift(fft(fftshift(St, dim=0), dim=0), dim=0) Yr = fftshift(fft(fftshift(Sr, dim=0), dim=0), dim=0)
axs[i, j].set_title(idx_to_class[tgt], fontsize=40) fig.colorbar(imshow, ax=axs[i, j]) ind += 1 # - # ## Analyse de la phase # + # %matplotlib inline fig, axs = plt.subplots(row_nb, col_nb, figsize=(row_nb * 10, col_nb * 10)) ind = 0 for i in range(row_nb): for j in range(col_nb): img, tgt = subset[ind] fft_img = torch.fft.fft2(img).squeeze() freq_img = torch.angle(fft_img) freq_x = torch.fft.fftshift(torch.fft.fftfreq(img.shape[2])).numpy() freq_y = torch.fft.fftshift(torch.fft.fftfreq(img.shape[1])).numpy() x_range = np.hstack((np.arange(0, img.shape[2], 50), np.array([img.shape[2] - 1]), np.array([img.shape[2] // 2]))) y_range = np.hstack((np.arange(0, img.shape[1], 50), np.array([img.shape[1] - 1]), np.array([img.shape[2] // 2]))) axs[i, j].set_xticks(x_range) axs[i, j].set_xticklabels(freq_x[x_range]) axs[i, j].set_yticks(y_range) axs[i, j].set_yticklabels(freq_y[y_range]) imshow = axs[i, j].imshow(freq_img) axs[i, j].set_title(idx_to_class[tgt], fontsize=40) fig.colorbar(imshow, ax=axs[i, j])
def pointwise_ops(self):
    """Return a tuple of results from a broad sample of torch pointwise ops.

    ``self`` is not used. Many inputs come from the global RNG
    (``torch.randn``/``rand``/``randint``), so the values are random unless
    the caller seeds torch.

    NOTE: ``a`` is mutated in place by the ``a.uniform_(...)`` calls inside
    the returned tuple. Tuple elements are evaluated left to right, so every
    op listed after a ``uniform_`` call sees the resampled ``a``. The order
    of the entries is therefore load-bearing.
    """
    a = torch.randn(4)
    b = torch.randn(4)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)  # NOTE(review): duplicate of the assignment above
    s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
    f = torch.zeros(3)
    g = torch.tensor([-1, 0, 1])
    w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
    return (
        torch.abs(torch.tensor([-1, -2, 3])),
        torch.absolute(torch.tensor([-1, -2, 3])),
        torch.acos(a),
        torch.arccos(a),
        # Resamples ``a`` in place into [1, 2) so acosh is defined.
        torch.acosh(a.uniform_(1.0, 2.0)),
        torch.add(a, 20),
        torch.add(a, torch.randn(4, 1), alpha=10),
        torch.addcdiv(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.addcmul(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.angle(a),
        torch.asin(a),
        torch.arcsin(a),
        torch.asinh(a),
        torch.arcsinh(a),
        torch.atan(a),
        torch.arctan(a),
        # Resamples ``a`` in place into [-1, 1) so atanh is defined.
        torch.atanh(a.uniform_(-1.0, 1.0)),
        torch.arctanh(a.uniform_(-1.0, 1.0)),
        torch.atan2(a, a),
        torch.bitwise_not(t),
        torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.ceil(a),
        torch.clamp(a, min=-0.5, max=0.5),
        torch.clamp(a, min=0.5),
        torch.clamp(a, max=0.5),
        torch.clip(a, min=-0.5, max=0.5),
        torch.conj(a),
        torch.copysign(a, 1),
        torch.copysign(a, b),
        torch.cos(a),
        torch.cosh(a),
        torch.deg2rad(
            torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])),
        torch.div(a, b),
        torch.divide(a, b, rounding_mode="trunc"),
        torch.divide(a, b, rounding_mode="floor"),
        torch.digamma(torch.tensor([1.0, 0.5])),
        torch.erf(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
        torch.exp(torch.tensor([0.0, math.log(2.0)])),
        torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
        torch.expm1(torch.tensor([0.0, math.log(2.0)])),
        torch.fake_quantize_per_channel_affine(
            torch.randn(2, 2, 2),
            (torch.randn(2) + 1) * 0.05,
            torch.zeros(2),
            1,
            0,
            255,
        ),
        torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
        torch.float_power(torch.randint(10, (4, )), 2),
        torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
        torch.floor(a),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
        torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
        torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.frac(torch.tensor([1.0, 2.5, -3.2])),
        torch.randn(4, dtype=torch.cfloat).imag,
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
        torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
        torch.lerp(
            torch.arange(1.0, 5.0),
            torch.empty(4).fill_(10),
            torch.full_like(torch.arange(1.0, 5.0), 0.5),
        ),
        torch.lgamma(torch.arange(0.5, 2, 0.5)),
        torch.log(torch.arange(5) + 10),
        torch.log10(torch.rand(5)),
        torch.log1p(torch.randn(5)),
        torch.log2(torch.rand(5)),
        torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logical_and(r, s),
        torch.logical_and(r.double(), s.double()),
        torch.logical_and(r.double(), s),
        torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
        torch.logical_not(
            torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
        torch.logical_not(
            torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
            out=torch.empty(3, dtype=torch.int16),
        ),
        torch.logical_or(r, s),
        torch.logical_or(r.double(), s.double()),
        torch.logical_or(r.double(), s),
        torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_xor(r, s),
        torch.logical_xor(r.double(), s.double()),
        torch.logical_xor(r.double(), s),
        torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logit(torch.rand(5), eps=1e-6),
        torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
        torch.i0(torch.arange(5, dtype=torch.float32)),
        torch.igamma(a, b),
        torch.igammac(a, b),
        torch.mul(torch.randn(3), 100),
        torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
        torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
        # NOTE(review): a bare tensor with no op applied; presumably an input
        # for a removed op (e.g. nan_to_num) — confirm.
        torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
        torch.nan_to_num(w),
        torch.nan_to_num(w, nan=2.0),
        torch.nan_to_num(w, nan=2.0, posinf=1.0),
        torch.neg(torch.randn(5)),
        # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
        torch.polygamma(1, torch.tensor([1.0, 0.5])),
        torch.polygamma(2, torch.tensor([1.0, 0.5])),
        torch.polygamma(3, torch.tensor([1.0, 0.5])),
        torch.polygamma(4, torch.tensor([1.0, 0.5])),
        torch.pow(a, 2),
        torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
        torch.rad2deg(
            torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])),
        torch.randn(4, dtype=torch.cfloat).real,
        torch.reciprocal(a),
        torch.remainder(torch.tensor([-3.0, -2.0]), 2),
        torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.round(a),
        torch.rsqrt(a),
        torch.sigmoid(a),
        torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sgn(a),
        torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sin(a),
        torch.sinc(a),
        torch.sinh(a),
        torch.sqrt(a),
        torch.square(a),
        torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
        torch.tan(a),
        torch.tanh(a),
        torch.trunc(a),
        # NOTE(review): the next entry is repeated verbatim.
        torch.xlogy(f, g),
        torch.xlogy(f, g),
        torch.xlogy(f, 4),
        torch.xlogy(2, g),
    )
def Prep(data):
    """Compute per-speaker spatial features (AF, cos-IPD, DPR) for one sample.

    Parameters
    ----------
    data : dict or sequence
        Either a dict with keys ``"angle"``, ``"mix"`` and ``"R"``, or an
        indexable sample where ``data[0]`` is the mixture, ``data[2]`` the
        speaker angles and ``data[5]`` the microphone positions ``R``.
        ``R`` has 3 rows (x, y, ... per microphone column) and must broadcast
        against shape ``(3, n_mic)``.

    Returns
    -------
    list of torch.Tensor or None
        For sequence input, one fused feature tensor per speaker
        (``range(n_sp)``).  For dict input nothing is returned: the features
        are written to ``data[i]`` and ``data["mix"]`` is replaced by the
        transposed, squeezed mixture tensor.

    Notes
    -----
    Uses the module-level ``n_mic``, ``n_grid``, ``n_sp``, ``m_total``,
    ``m_data``, ``V``, ``stft`` and ``__get_steering_vector``.
    """
    dic_data = isinstance(data, dict)
    if dic_data:
        angle = torch.tensor(data["angle"]).unsqueeze(dim=0)
        mix = torch.from_numpy(data["mix"]).unsqueeze(dim=0).float()
        mix = torch.transpose(mix, 2, 1)
        data["mix"] = mix.squeeze()
        R = data["R"]
    else:
        mix = torch.from_numpy(data[0]).unsqueeze(dim=0).float()
        angle = torch.tensor(data[2]).unsqueeze(dim=0)
        R = data[5]

    # Microphone positions relative to microphone 0.
    mic_array_layout = R - np.tile(R[:, 0].reshape((3, 1)), (1, n_mic))
    pairs = ((0, 3), (1, 4), (2, 5), (0, 1), (2, 3), (4, 5))
    ori_pairs = ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5))

    # Projection of each microphone's (x, y) offset from mic 0 onto every
    # look direction; the direction grid has a step of pi / 18 (10 degrees).
    grid = np.arange(n_grid) * np.pi / 18
    delay = np.zeros((n_mic, n_grid))
    for h, (ref, mic) in enumerate(ori_pairs):
        dx = mic_array_layout[0, mic] - mic_array_layout[0, ref]
        dy = mic_array_layout[1, mic] - mic_array_layout[1, ref]
        delay[h] = dx * np.cos(grid) + dy * np.sin(grid)
    delay = torch.from_numpy(delay).unsqueeze(dim=-1).expand(-1, -1, m_total)
    # Fixed beamformer weights; 3-D, presumably (n_mic, n_grid, m_total).
    w = torch.exp(-2j * np.pi * m_data * delay) / V

    batch_size = mix.size(0)
    mag, ph, _, _ = stft.transform(mix.reshape(-1, mix.size()[-1]))
    # Append one zero-valued column on the last axis.  new_zeros takes the
    # mixture's dtype and device, replacing Variable(...).type(input.type()).
    pad = mix.new_zeros(mag.size()[0], mag.size()[1], 1)
    mag = torch.cat([mag, pad], -1)
    ph = torch.cat([ph, pad], -1)
    channel = mag.size()[-1]
    mag = mag.view(batch_size, n_mic, -1, channel)
    ph = ph.view(batch_size, n_mic, -1, channel)
    # Complex spectrum, (batch, n_mic, -1, channel).
    spec = mag * torch.exp(ph * 1j)

    # Phase differences per microphone pair, stacked on dim 1.  They are not
    # wrapped to [-pi, pi]; they are only consumed through exp() and cos(),
    # which are 2*pi-periodic.
    IPD = torch.cat(
        [(torch.angle(spec[:, a]) - torch.angle(spec[:, b])).unsqueeze(dim=1)
         for a, b in pairs],
        dim=1,
    )

    # The following features do not depend on the speaker, so they are
    # computed once here instead of once per speaker inside the loop.
    ipd_phasor = torch.exp(1j * IPD)
    cos_ipd = torch.cos(
        IPD.reshape(batch_size, IPD.size()[1] * IPD.size()[2], IPD.size(-1)))

    # Beamformer output power for every grid direction, normalised over the
    # directions (the ``dpr`` feature).  The weighted spectrum is built once.
    beam = w.unsqueeze(dim=0).unsqueeze(dim=-1) * spec.unsqueeze(dim=2)
    beam_power = (beam * torch.conj(beam)).sum(dim=1)
    dpr_all = beam_power / (beam_power.sum(dim=1, keepdim=True) + 10e-20)

    return_list = []
    for i in range(n_sp):
        ang = angle[:, i]
        steering_vector = __get_steering_vector(ang, pairs, mic_array_layout)
        steering_vector = steering_vector.unsqueeze(dim=-1)
        # Element-wise unit-modulus product of steering vector and observed
        # IPD phasors, summed over the microphone pairs.
        AF = steering_vector * ipd_phasor
        AF = AF / (torch.sqrt(AF.real ** 2 + AF.imag ** 2) + 10e-20)
        AF = AF.sum(dim=1)
        # Index of the speaker direction on the pi/18 grid (truncated).
        # NOTE(review): truncation can push an on-grid angle into the previous
        # cell through float error (e.g. 2.9999 -> 2); rounding may be intended.
        p = (ang / np.pi * 18).type(torch.long)
        dpr = dpr_all[range(batch_size), p]
        fusion = torch.cat([AF, cos_ipd, dpr], dim=1).real.float()
        if dic_data:
            data[i] = fusion.squeeze()
        else:
            return_list.append(fusion)
    if not dic_data:
        return return_list
# Title and axis labels for the current subplot (presumably subplot 221 of
# the first figure, whose plot calls precede this section -- confirm).
plt.title('Convolution matched filter')
plt.xlabel(r'Time/$\mu s$')
plt.ylabel('Amplitude')
# Subplot 222: imaginary part and magnitude of Sm1.  The time axis is
# t * 1e6 to match the microsecond x-label.
plt.subplot(222)
plt.plot(t * 1e6, th.imag(Sm1))
plt.plot(t * 1e6, th.abs(Sm1))
plt.grid()
plt.legend(['Imaginary part', 'Amplitude'])
plt.title('Convolution matched filter')
plt.xlabel(r'Time/$\mu s$')
plt.ylabel('Amplitude')
# Subplot 223: magnitude of Ym1 against f.
plt.subplot(223)
plt.plot(f, th.abs(Ym1))
plt.grid()
# Subplot 224: phase (th.angle) of Ym1 against f.
plt.subplot(224)
plt.plot(f, th.angle(Ym1))
plt.grid()
# Figure 2: the same kind of plots for Sm2 ('Correlation matched filter').
# pyplot draws into the "current" axes, so the statement order matters.
plt.figure(2)
# Subplot 221: real part and magnitude of Sm2.
plt.subplot(221)
plt.plot(t * 1e6, th.real(Sm2))
plt.plot(t * 1e6, th.abs(Sm2))
plt.grid()
plt.legend(['Real part', 'Amplitude'])
plt.title('Correlation matched filter')
plt.xlabel(r'Time/$\mu s$')
plt.ylabel('Amplitude')
# Subplot 222: imaginary part and magnitude of Sm2.
plt.subplot(222)
plt.plot(t * 1e6, th.imag(Sm2))
plt.plot(t * 1e6, th.abs(Sm2))
plt.grid()
class TorchBox(qml.math.TensorBox):
    """Implements the :class:`~.TensorBox` API for Torch tensors.

    For more details, please refer to the :class:`~.TensorBox` documentation.
    """

    abs = wrap_output(lambda self: torch.abs(self.data))
    angle = wrap_output(lambda self: torch.angle(self.data))
    arcsin = wrap_output(lambda self: torch.asin(self.data))
    expand_dims = wrap_output(
        lambda self, axis: torch.unsqueeze(self.data, dim=axis))
    ones_like = wrap_output(lambda self: torch.ones_like(self.data))
    # int32/int64 tensors are cast to float64 before taking the square root.
    sqrt = wrap_output(lambda self: torch.sqrt(
        self.data.to(torch.float64)
        if self.data.dtype in (torch.int64, torch.int32) else self.data))
    T = wrap_output(lambda self: self.data.T)

    @staticmethod
    def astensor(tensor):
        """Convert ``tensor`` to a ``torch.Tensor`` via ``torch.as_tensor``."""
        return torch.as_tensor(tensor)

    @wrap_output
    def cast(self, dtype):
        """Cast the wrapped tensor to ``dtype``.

        Args:
            dtype: a ``torch.dtype``, or anything accepted by ``np.dtype``; in
                the latter case the Torch dtype with the same name is used.

        Raises:
            ValueError: if no Torch dtype has the NumPy dtype's name.
        """
        if isinstance(dtype, torch.dtype):
            return self.data.to(dtype)

        dtype_name = np.dtype(dtype).name
        torch_dtype = getattr(torch, dtype_name, None)

        if torch_dtype is None:
            raise ValueError(f"Unable to convert {dtype} to a Torch dtype")

        return self.data.to(torch_dtype)

    @staticmethod
    def _coerce_types(tensors):
        """Cast ``tensors`` to one common dtype.

        The widest complex dtype present wins, else the widest float dtype,
        else the widest integer dtype.  When none of the present dtypes is in
        those lists (e.g. only ``bool``/``uint8``), ``torch.promote_types`` is
        used instead of failing.
        """
        dtypes = {i.dtype for i in tensors}

        if len(dtypes) <= 1:
            return tensors

        complex_priority = [torch.complex64, torch.complex128]
        float_priority = [torch.float16, torch.float32, torch.float64]
        int_priority = [torch.int8, torch.int16, torch.int32, torch.int64]

        complex_type = [i for i in complex_priority if i in dtypes]
        float_type = [i for i in float_priority if i in dtypes]
        int_type = [i for i in int_priority if i in dtypes]

        candidates = complex_type or float_type or int_type

        if candidates:
            cast_type = candidates[-1]
        else:
            # None of the dtypes is in the priority lists: defer to Torch's
            # own promotion rules rather than indexing an empty list.
            cast_type = tensors[0].dtype
            for t in tensors[1:]:
                cast_type = torch.promote_types(cast_type, t.dtype)

        return [t.to(cast_type) for t in tensors]

    @staticmethod
    @wrap_output
    def concatenate(values, axis=0):
        """Concatenate ``values`` along ``axis``; ``axis=None`` flattens first."""
        if axis is None:
            # flatten and then concatenate zero'th dimension
            # to reproduce numpy's behaviour
            tensors = [
                TorchBox.astensor(t).flatten() for t in TorchBox.unbox_list(values)
            ]
            return torch.cat(tensors, dim=0)

        tensors = [TorchBox.astensor(t) for t in TorchBox.unbox_list(values)]
        return torch.cat(tensors, dim=axis)

    @staticmethod
    @wrap_output
    def dot(x, y):
        """NumPy-style ``dot`` of ``x`` and ``y`` after dtype coercion.

        A 0-d operand is multiplied element-wise; operands of rank <= 2 use
        ``@``; otherwise the last axis of ``x`` is contracted with the
        second-to-last axis of ``y`` (or with its only axis when ``y`` is 1-D).
        """
        x, y = [TorchBox.astensor(t) for t in TorchBox.unbox_list([x, y])]
        x, y = TorchBox._coerce_types([x, y])

        if x.ndim == 0 or y.ndim == 0:
            return x * y

        if x.ndim <= 2 and y.ndim <= 2:
            return x @ y

        return torch.tensordot(x, y, dims=[[-1], [-2 if y.ndim > 1 else 0]])

    @property
    def interface(self):
        return "torch"

    def numpy(self):
        """Return a detached NumPy copy of the data (moved to CPU first)."""
        return self.data.detach().cpu().numpy()

    @property
    def requires_grad(self):
        return self.data.requires_grad

    @property
    def shape(self):
        return tuple(self.data.shape)

    @staticmethod
    @wrap_output
    def stack(values, axis=0):
        """Stack ``values`` along a new dimension ``axis``."""
        tensors = [TorchBox.astensor(t) for t in TorchBox.unbox_list(values)]
        return torch.stack(tensors, dim=axis)

    @wrap_output
    def sum(self, axis=None, keepdims=False):
        """Sum over ``axis`` (over all elements when ``None``).

        ``keepdims`` follows NumPy, including for ``axis=None`` where every
        dimension is kept with length 1.
        """
        if axis is None:
            total = torch.sum(self.data)
            if keepdims:
                return total.reshape([1] * self.data.ndim)
            return total

        return torch.sum(self.data, dim=axis, keepdim=keepdims)

    @wrap_output
    def take(self, indices, axis=None):
        """NumPy-style ``take``.

        ``axis=None`` indexes the flattened tensor.  1-D ``indices`` go
        through ``torch.index_select`` (negative indices are wrapped first);
        higher-rank ``indices`` use advanced indexing on ``axis``.
        """
        if not isinstance(indices, torch.Tensor):
            indices = self.astensor(indices)

        if axis is None:
            return self.data.flatten()[indices]

        if axis < 0:
            # Normalise so the slice padding below targets the requested
            # axis; ``[slice(None)] * -1`` would otherwise be empty.
            axis += self.data.ndim

        if indices.ndim == 1:
            if (indices < 0).any():
                # index_select doesn't allow negative indices
                dim_length = self.shape[axis]
                indices = torch.where(indices >= 0, indices, indices + dim_length)

            return torch.index_select(self.data, dim=axis, index=indices)

        # Index with a tuple; a list of slices/tensors is the legacy
        # non-tuple form that NumPy rejects and newer PyTorch deprecates.
        fancy_indices = (slice(None),) * axis + (indices,)
        return self.data[fancy_indices]

    @staticmethod
    @wrap_output
    def where(condition, x, y):
        """Element-wise ``torch.where(condition, x, y)``."""
        return torch.where(TorchBox.astensor(condition), *TorchBox.unbox_list([x, y]))
def mag_phase(complex_tensor):
    """Split a complex tensor into its polar components.

    Parameters
    ----------
    complex_tensor : torch.Tensor
        Complex-valued input.

    Returns
    -------
    tuple of torch.Tensor
        ``(magnitude, phase)``: the element-wise modulus (``torch.abs``) and
        argument in radians (``torch.angle``).
    """
    magnitude = torch.abs(complex_tensor)
    phase = torch.angle(complex_tensor)
    return magnitude, phase
loss = smooth_amplitude_loss(a_model, indices_target[take_ind], counts_target[take_ind]) loss_sum = loss.mean() sum_loss += loss_sum.item() loss_sum.backward() # if i > probe_start: # plotAbsAngle(psi_model.grad[0].cpu().detach().numpy(),'psi_model.grad') # plotAbsAngle(S_model[0].cpu().detach().numpy(), 'S_model') optimizer.step() optimizer.zero_grad() c = th.vdot(T[slic].ravel(), S_model[slic].ravel()) T_hat = T * th.exp(-1j * th.angle(c)) dist = th.norm(S_model[slic] - T_hat[slic]) x_norm = th.norm(T) err = dist / x_norm errs.append(err) sum_loss /= n_batches losses.append(sum_loss) print(f'{i:3d} loss: {sum_loss} err: {err}') # print(f'i {i} loss {sum_loss}, C_model = {C_model[0]} , C_target = {C_target[0]}') # %% d = margin + M[0] // 2 d = 1 plotAbsAngle(S_model[0, d:-d, d:-d].cpu().detach().numpy(), 'Reconstruction',
def forward(self, x):
    """Return the element-wise angle of ``x`` computed two ways.

    Returns:
        tuple: ``(x.angle(), torch.angle(x))`` -- the method form first,
        then the functional form.
    """
    return x.angle(), torch.angle(x)