def batch_fft(data, normalize=False):
    """Return the 2-D discrete Fourier transform of every image in a batch.

    Args:
        data: input tensor laid out as (N, H, W); the leading axis is the batch.
        normalize: when True use orthonormal ('ortho') scaling, otherwise
            torch's default 'backward' scaling.

    Returns:
        Complex tensor of the same shape holding the per-image spectra.

    Raises:
        AttributeError: if ``data`` does not have exactly two non-batch axes.
    """
    spatial_rank = data.ndim - 1  # the first axis is the batch axis
    if spatial_rank != 2:
        raise AttributeError(f'Data must be 2d but it is {spatial_rank}d.')
    spatial_axes = tuple(range(1, 1 + spatial_rank))  # skip the batch axis
    norm_mode = 'ortho' if normalize else 'backward'
    if not torch.is_complex(data):
        # promote a real tensor to complex with a zero imaginary part
        data = torch.complex(data, torch.zeros_like(data))
    return fftn(data, dim=spatial_axes, norm=norm_mode)
def kspace_downsample_patch(hr_patch: torch.Tensor, center_crop: int = 25, end_crop: int = 6) -> torch.Tensor:
    """Down-sample a high resolution patch by truncating its k-space.

    The 2-D spectrum of every image in ``hr_patch`` (taken over the last two
    axes) has some coefficients zeroed, and the magnitude of the inverse
    transform is returned.

    Note: end_crop worsens the picture quality much more than center_crop.
    The spectrum is NOT fftshift-ed, so the trailing rows/cols zeroed by
    ``end_crop`` hold small negative frequencies, while the "center" zeroed by
    ``center_crop`` holds the highest frequencies.

    Args:
        hr_patch: original high resolution patch, (channel, patch_dim, patch_dim);
            any tensor with at least two dims is accepted and the last two are
            treated as the image plane.
        center_crop: size of the central square of the (unshifted) spectrum to
            zero; the square actually zeroed is ``2 * max(center_crop // 2, 1)``
            wide.
        end_crop: number of final rows/cols of the spectrum to zero.

    Returns:
        lr_patch with the same shape as ``hr_patch`` (real, non-negative).
    """
    image_axes = (-2, -1)
    # Transform only the image plane. Transforming the channel axis as well
    # (fftn over all axes) is wasted work: the mask is constant along it, so
    # it cancels in the inverse transform.
    lr_patch = fftn(hr_patch, dim=image_axes)

    # remove last n cols and rows
    if end_crop > 0:
        lr_patch[..., -end_crop:, :] = 0
        lr_patch[..., :, -end_crop:] = 0

    # remove square center
    if center_crop > 0:
        half = max(center_crop // 2, 1)
        x_center = lr_patch.shape[-2] // 2
        y_center = lr_patch.shape[-1] // 2
        lr_patch[..., x_center - half:x_center + half, y_center - half:y_center + half] = 0

    return torch.abs(ifftn(lr_patch, dim=image_axes))
def apply(self, x: Tensor) -> Tensor:
    """Sample a subset of orthonormal 2-D Fourier coefficients of ``x``.

    The last axis of ``x`` is unflattened to ``self.size``, transformed with an
    orthonormal 2-D FFT, flattened again, and ``self.K`` coefficients are then
    picked along that flattened axis at the positions in ``self.index``.
    """
    image = x.unflatten(dim=-1, sizes=self.size)
    coeffs = fftn(image, dim=(-2, -1), norm='ortho').flatten(start_dim=-2)
    gather_shape = (*coeffs.shape[:-1], self.K)
    return coeffs.gather(-1, self.index.expand(gather_shape))
def fft(input, inverse=False):
    """
    Interface with torch FFT routines for 2D signals.

    Example
    -------
    x = torch.randn(128, 32, 32, 2)
    x_fft = fft(x, inverse=True)

    Parameters
    ----------
    input : tensor
        complex input for the FFT, stored as (real, imaginary) pairs in a last
        dimension of size 2; must be contiguous
    inverse : bool
        True for computing the inverse FFT (torch default 'backward' scaling:
        the inverse divides by the number of points, the forward does not).

    Returns
    -------
    tensor in the same (..., 2) real/imaginary layout as ``input``; the
    transform runs over the two dimensions that precede that last axis.

    Raises
    ------
    TypeError if ``iscomplex(input)`` is false; RuntimeError if ``input`` is
    not contiguous.
    """
    if not iscomplex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')
    if not input.is_contiguous():
        raise RuntimeError('Tensors must be contiguous!')
    # rebuild a native complex tensor from the (real, imag) pair
    signal = input[..., 0] + 1j * input[..., 1]
    transform = ifftn if inverse else fftn
    # s=(-1, -1): transform the last two dims at their original size
    spectrum = transform(signal, s=(-1, -1))
    return torch.stack((spectrum.real, spectrum.imag), dim=-1)
def generate_single_frame(self, output_num, input_image=None):
    """Simulate one frame of randomly placed fluorophores seen by the camera.

    Fluorophores are placed on an ``image_size`` x ``image_size`` grid
    (optionally weighted by ``input_image``), filtered with the OTF in the
    centred frequency domain, down-sampled by average pooling with
    ``self.downsample_rate``, and their ground-truth pixel coordinates are
    written to ``<output_folder>/<output_num>_label.txt`` as "row col" lines.

    Args:
        output_num: file-name stem of the label file (concatenated with
            '_label.txt', so it must be a str).
        input_image: optional tensor that multiplies the uniform random field;
            defaults to all ones.

    Returns:
        The simulated camera image (squeezed tensor).
    """
    os.makedirs(self.output_folder, exist_ok=True)

    padding_size = 20
    _, _, fx, fy = self.GridGenerate(self.image_size + 2 * padding_size, grid_mode='real')
    f_grid = pow((fx**2 + fy**2), 1 / 2)  # radial spatial frequency fr = sqrt(fx^2 + fy^2)
    OTF_padding = self.OTF_form(f_grid)

    random_distribution = torch.rand([self.image_size, self.image_size])
    if input_image is None:
        input_image = torch.ones_like(random_distribution)
    random_distribution = random_distribution * input_image

    threshold = self.fluorophore_density
    fluorophore_num = round(self.fluorophore_density * self.image_size**2)
    fluorophore_loc = random_distribution < threshold
    # Raise the threshold until enough pixels are selected. Stop once every
    # pixel is selected; otherwise a density > 1 would loop forever.
    while fluorophore_loc.sum() < fluorophore_num and not fluorophore_loc.all():
        threshold *= 2
        fluorophore_loc = random_distribution < threshold

    fluorophore_GT = torch.zeros_like(random_distribution)
    fluorophore_GT[fluorophore_loc] = 1

    # Filtering directly in the frequency domain lets content from opposite
    # edges leak into each other; zero padding first removes that cross-talk.
    ZeroPad_operation = nn.ZeroPad2d(padding_size)
    fluorophore_GT_padding = ZeroPad_operation(fluorophore_GT)
    fluorophore_padding_diffractive_spectrum = torch_2d_fftshift(
        fft.fftn(fluorophore_GT_padding, dim=[0, 1])) * OTF_padding
    fluorophore_padding_diffractive_image = abs(
        fft.ifftn(torch_2d_ifftshift(fluorophore_padding_diffractive_spectrum), dim=[1, 0]))
    # Crop the padding away again.
    fluorophore_diffractive_image = fluorophore_padding_diffractive_image[
        padding_size:-padding_size, padding_size:-padding_size]

    AvgPool_operation = nn.AvgPool2d(kernel_size=self.downsample_rate, stride=self.downsample_rate)
    fluorophore_image_in_camera = AvgPool_operation(
        fluorophore_diffractive_image.unsqueeze(0).unsqueeze(0)).squeeze()

    # Ground-truth (row, col) coordinates of every fluorophore.
    rows, cols = np.where(fluorophore_GT.numpy() == 1)
    label_file_dir = os.path.join(self.output_folder, output_num + '_label.txt')
    with open(label_file_dir, 'w') as label_file:
        label_file.writelines(
            '{} {}\n'.format(row, col)
            for row, col in zip(rows.astype(np.int32), cols.astype(np.int32)))

    return fluorophore_image_in_camera
def forward(self, add, model, data):
    """Forward N-D FFT operator: data = [data +] FFT(model).

    The orthonormal FFT of ``model.getNdArray()`` over ``self.axes`` (with
    output sizes ``self.nfft``) is accumulated into ``data``. When ``add`` is
    false, ``data`` is zeroed first.
    """
    self.checkDomainRange(model, data)
    if not add:
        data.zero()
    data[:] += fft.fftn(
        model.getNdArray(),
        s=self.nfft,
        dim=self.axes,
        norm='ortho',
    )
    return None
def forward(self, inputs):
    """Multiply a volume's 3-D spectrum by a kernel and transform back.

    Args:
        inputs: sequence whose item 0 (``suscp``) is transformed over its last
            three dims and whose item 1 (``kernel``) multiplies that spectrum
            (presumably already in k-space — TODO confirm at the call site).

    Returns:
        Real part of the inverse FFT of the product.
    """
    suscp, kernel = inputs[0], inputs[1]
    volume_axes = [-3, -2, -1]
    kspace = fft.fftn(suscp, dim=volume_axes) * kernel
    return torch.real(fft.ifftn(kspace, dim=volume_axes))
def __init__(self, D, m, b, wG, device, lambda_TV, P=1, alpha=0.5, rho=10):
    """Store the problem data and solver hyper-parameters.

    Every argument is stored as an attribute of the same name. ``D`` is used
    by ``Dconv`` as a multiplier in the Fourier domain of dims 0-2.
    """
    self.D = D
    self.m = m
    self.b = b
    self.wG = wG
    self.device = device
    self.lambda_TV = lambda_TV
    self.P = P
    self.alpha = alpha
    self.rho = rho
    # Convolution with D. The inverse FFT now uses the same axes (0, 1, 2) as
    # the forward one. Previously it ran over *all* axes, which is identical
    # for 3-D inputs but mixes any extra trailing axis of higher-rank inputs.
    fft_axes = [0, 1, 2]
    self.Dconv = lambda x: torch.real(fft.ifftn(self.D * fft.fftn(x, dim=fft_axes), dim=fft_axes))
def __call__(self, vis, u, v):
    """Grid visibilities, FFT the grid and crop its central nx x ny region.

    Args:
        vis, u, v: forwarded unchanged to ``self.grid_2d``.

    Returns:
        The fftshift-ed spectrum cropped to (nx, ny) and divided by ``self.gc``.
    """
    input_grid = self.grid_2d(vis, u, v)
    input_grid = fftshift(input_grid, axes=None)
    out = fftshift(fftn(input_grid))

    alpha = self.config['alpha']
    # Offsets of the central crop; the grid is presumably oversampled by
    # ``alpha`` along each axis (TODO confirm in grid_2d).
    xl = int(0.5 * self.nx * (alpha - 1))
    yl = int(0.5 * self.ny * (alpha - 1))  # was self.nx: wrong for non-square images
    return out[xl:xl + self.nx, yl:yl + self.ny] / self.gc
def closure():
    """Optimizer closure: recompute the penalised fidelity loss and its gradients.

    Reads ``optimizer``, ``resnet``, ``inputs_cat``, ``D``, ``rdfs``,
    ``weights``, ``alpha``, ``rho``, ``x`` and ``mu`` from the enclosing scope.
    The loss is

        (1 - alpha) / 2 * sum((weights * |rdfs - real(ifft(fft(out) * D))|)^2)
        + rho / 2 * sum((x - out[0, 0] + mu)^2)

    with the FFTs taken over dims 2, 3 and 4.
    """
    optimizer.zero_grad()
    prediction = resnet(inputs_cat)
    prediction_c = prediction.type(torch.complex64)
    fft_dims = [2, 3, 4]
    forward_model = torch.real(
        fft.ifftn(fft.fftn(prediction_c, dim=fft_dims) * D, dim=fft_dims))
    residual = torch.abs(rdfs - forward_model)
    fidelity = (1 - alpha) * 0.5 * torch.sum((weights * residual) ** 2)
    penalty = rho * 0.5 * torch.sum((x - prediction[0, 0, ...] + mu) ** 2)
    total_loss = fidelity + penalty
    total_loss.backward()
    return total_loss
def forward(self, x, up_feat_in):
    """Split ``x`` into low/high-frequency parts and fuse them back with ``x``.

    The centred 2-D spectrum of ``x`` (last two dims) is split by
    ``self.guassian_low_high_pass_filter``. Each part is transformed back to a
    magnitude image and projected. The projections are concatenated with ``x``
    along dim 1, projected again, and added to ``x`` as a residual.

    Args:
        x: input feature map.
        up_feat_in: optional skip feature merged through ``self.upsample_add``.

    Returns:
        ``(up_feature, smooth_feature)`` when both ``up_flag`` and
        ``smf_flag`` are set, ``up_feature`` when only ``up_flag`` is set,
        ``smooth_feature`` when only ``smf_flag`` is set, otherwise None.
    """
    axes = (-2, -1)
    centred = fft.fftshift(fft.fftn(x, dim=axes), dim=axes)
    low_shift, high_shift = self.guassian_low_high_pass_filter(centred)
    low_spatial = torch.abs(fft.ifftn(fft.ifftshift(low_shift, dim=axes), dim=axes))
    high_spatial = torch.abs(fft.ifftn(fft.ifftshift(high_shift, dim=axes), dim=axes))
    fused = torch.cat(
        [x, self.low_project(low_spatial), self.high_project(high_spatial)], dim=1)
    fuse_feature = self.out_project(fused) + x

    if not (self.up_flag or self.smf_flag):
        return None

    # Optional skip connection, shared by every output mode.
    if up_feat_in is not None:
        fuse_feature = self.upsample_add(up_feat_in, fuse_feature)

    if self.up_flag and self.smf_flag:
        return self.up(fuse_feature), self.smooth(fuse_feature)
    if self.up_flag:
        return self.up(fuse_feature)
    return self.smooth(fuse_feature)
def forward(self, x):
    """Performs a forward pass over the data.

    Each frame ``x[:, fr]`` is turned into a standardised Fourier-magnitude
    feature vector and passed through ``self.hidden_sampling``; the per-frame
    hidden outputs are stacked along dim 1.

    Args:
        x (torch.Tensor): An input tensor indexed as (batch, frame, ...).

    Returns:
        A detached tensor of shape (batch, frames, n_hidden) containing the DBN's outputs.
    """
    frames = x.size(1)
    ds = torch.zeros((x.size(0), frames, self.n_hidden))

    # Checking whether GPU is available and if it should be used
    if self.device == 'cuda':
        x = x.cuda()
        ds = ds.cuda()

    for fr in range(frames):
        sps = x[:, fr, :, :].squeeze()

        # Creating the Fourier Spectrum
        # NOTE(review): fftn/fftshift are called without ``dim``; if these are
        # the torch.fft functions the transform also runs over the batch axis.
        # fit()/reconstruct() use the same expression, so it is kept unchanged
        # here to keep the features consistent — confirm intent.
        spec_data = fftshift(fftn(sps))[:, :, :, 0]
        spec_data = torch.abs(spec_data.squeeze())

        # Flattening the samples' batch
        spec_data = spec_data.reshape(spec_data.size(0), self.n_visible)

        # Standardising every feature over the batch
        spec_data = ((spec_data - torch.mean(spec_data, 0, True)) /
                     (torch.std(spec_data, 0, True) + c.EPSILON)).detach()

        spec_data, _ = self.hidden_sampling(spec_data)
        ds[:, fr, :] = spec_data.reshape((spec_data.size(0), self.n_hidden))

    return ds.detach()
def forward(self, x):
    """Fuse high-frequency spatial features with features of ``x`` and upsample.

    The centred spectrum of ``x`` is split by
    ``self.guassian_low_high_pass_filter``. Only the high-frequency part is
    used: its inverse-transform magnitude goes through ``self.sp``, ``x``
    itself goes through ``self.cp``, and both are concatenated on dim 1. The
    result is mapped to logits by ``self.head`` and bilinearly upsampled by
    ``self.block_size``.

    NOTE(review): fftn/fftshift/ifftshift/ifftn are called without ``dim``, so
    they act on *all* dims of ``x`` (batch and channel included) — confirm
    this is intended.
    """
    centred_spectrum = fft.fftshift(fft.fftn(x))
    _, high_shift = self.guassian_low_high_pass_filter(centred_spectrum)
    high_spatial = torch.abs(fft.ifftn(fft.ifftshift(high_shift)))
    fused = torch.cat((self.sp(high_spatial), self.cp(x)), dim=1)
    logits = self.head(fused)
    return F.interpolate(logits, scale_factor=self.block_size,
                         mode='bilinear', align_corners=True)
def batch_image_OTF_filter(self, batch_image):
    """Filter a batch of images with the OTF stored in ``self.OTF_padding``.

    Each image is zero-padded by ``padding_size`` pixels, its fftshift-ed 2-D
    spectrum is multiplied by the OTF, it is transformed back, and the padding
    is cropped away again.

    Args:
        batch_image: tensor whose squeezed shape is (batch, H, W) or (H, W).

    Returns:
        Real, non-negative tensor of shape (batch, H, W).
    """
    padding_size = 20
    batch_image = batch_image.squeeze()
    if batch_image.dim() == 2:
        # squeeze() also drops a batch axis of size 1; restore it so the FFT
        # axes [1, 2] used below exist.
        batch_image = batch_image.unsqueeze(0)

    # NOTE(review): the grid returned here is unused (the OTF comes from
    # self.OTF_padding); the call is kept in case GridGenerate has side effects.
    _, _, fx, fy = self.GridGenerate(batch_image.shape[-1] + 2 * padding_size, grid_mode='real')
    OTF_padding = self.OTF_padding.to(batch_image.device)

    # Filtering directly in the frequency domain lets content from opposite
    # edges leak into each other; zero padding first removes that cross-talk.
    batch_image_padding = nn.ZeroPad2d(padding_size)(batch_image)
    spectrum = torch_2d_fftshift(
        fft.fftn(batch_image_padding, dim=[1, 2])) * OTF_padding.unsqueeze(0)
    diffractive = abs(fft.ifftn(torch_2d_ifftshift(spectrum), dim=[2, 1]))
    return diffractive[:, padding_size:-padding_size, padding_size:-padding_size]
def generate_frame_batch(self):
    """Simulate ``self.parallel_frames`` frames of random fluorophores at once.

    Returns:
        (fluorophore_GT, fluorophore_image_in_camera): the binary ground-truth
        maps of shape (parallel_frames, image_size, image_size) and the
        OTF-filtered, average-pooled camera images (squeezed).
    """
    _xx, _yy, _, _ = self.GridGenerate(grid_mode='real')  # grid itself unused below

    frame_shape = [self.parallel_frames, self.image_size, self.image_size]
    uniform = torch.rand(frame_shape)
    ground_truth = torch.zeros_like(uniform)
    ground_truth[uniform < self.fluorophore_density] = 1

    spectrum = torch_2d_fftshift(fft.fftn(ground_truth, dim=[1, 2])) * self.OTF.unsqueeze(0)
    blurred = abs(fft.ifftn(torch_2d_ifftshift(spectrum), dim=[2, 1]))
    # NOTE(review): plots the first frame on every call (debug leftover?).
    common_utils.plot_single_tensor_image(blurred[0, :, :])

    pool = nn.AvgPool2d(kernel_size=self.downsample_rate, stride=self.downsample_rate)
    camera = pool(blurred.unsqueeze(0)).squeeze()
    return ground_truth, camera
def _fft_convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor],
                stride: Tuple[int], padding: Tuple[int],
                dilation: Tuple[int], groups: int) -> Tensor:
    """N-d convolution evaluated in the Fourier domain.

    Intended to reproduce PyTorch ``convNd`` (cross-correlation) results,
    including stride, dilation and groups.

    Args:
        input: tensor laid out as (batch, channels, *spatial).
        weight: kernel laid out as (out_channels, ?, *kernel_size); the channel
            contraction is done by ``_complex_matmul`` — TODO confirm layout.
        bias: optional per-output-channel bias (broadcast along dim 1).
        stride, padding, dilation: one entry per spatial dimension.
        groups: number of channel groups, forwarded to ``_complex_matmul``.

    Returns:
        Real tensor whose spatial size is ``_conv_shape(...)``.
    """
    # Spatial size of the final output.
    output_size = _conv_shape(input.shape[2:], weight.shape[2:], stride, padding, dilation)
    # Apply the zero padding explicitly (F.pad takes last-dim-first pairs).
    reversed_padding_repeated_twice = _reverse_repeat_tuple(padding, 2)
    padded_input = F.pad(input, reversed_padding_repeated_twice)

    # Per-axis FFT lengths: ``s`` for the input, ``weight_s`` for the undilated kernel.
    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_input.shape[2:], weight.shape[2:], dilation, stride)):
        # long enough for the padded input and for the dilated kernel extent
        s_size = max(x_size, w_size * d)
        # find s size that can be divided by stride and dilation
        # (on the last, rfft axis s_size / st and s_size / d are also kept even)
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size)
        # The kernel is transformed at 1/d of the length; tiling its spectrum
        # d times (below) is equivalent to dilating the kernel by d in space.
        weight_s.append(s_size // d)

    # Real FFT of the padded input over all spatial axes.
    X = rfftn(padded_input, s=s)

    # Real FFT of the kernel along the last axis only; other axes follow below.
    W = rfft(weight, n=weight_s[-1])
    # handle dilation
    # handle dilation for last dim
    if dilation[-1] > 1:
        # Tiling the full spectrum d times, written in one-sided (rfft) form:
        # copies alternate between the mirrored, conjugated negative-frequency
        # half and the positive half, each without its duplicated endpoint.
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(weight_s) > 1:
        # Full FFT over the remaining spatial axes, then tile the spectrum
        # ``dilation`` times along each of them (only if some d > 1).
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + dilation[:-1] + (1, )
        # Conjugating W turns the spectral product into a cross-correlation.
        W.imag.mul_(-1)
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)
    else:
        W.imag.mul_(-1)

    # Channel / group contraction of the two spectra (helper defined elsewhere).
    Y = _complex_matmul(X, W, groups)

    # handle stride
    # Down-sampling by ``st`` in space equals averaging ``st`` consecutive
    # equal-length blocks of the spectrum (aliasing): reshape + mean.
    if len(stride) > 1:
        for i, st in enumerate(stride[:-1]):
            if st > 1:
                Y = Y.reshape(*Y.shape[:i + 2], st, -1, *Y.shape[i + 3:]).mean(i + 2)

            # Back to space along this axis, then keep the first output_size[i]
            # samples (as_strided with unchanged strides acts as a crop view).
            Y = ifft(Y, dim=i + 2)
            Y = Y.as_strided(
                Y.shape[:i + 2] + output_size[i:i + 1] + Y.shape[i + 3:], Y.stride())

    if stride[-1] > 1:
        # Same aliasing for the one-sided last axis: the half-spectrum is cut
        # into overlapping windows; even windows are summed as-is, odd windows
        # are mirrored (imaginary part subtracted, i.e. conjugated), and the
        # total is divided by the stride.
        n_fft = Y.size(-1) * 2 - 2
        new_n_fft = n_fft // stride[-1]
        step_size = new_n_fft // 2
        strided_Y_size = step_size + 1

        unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
        unfolded_Y_imag = Y.imag[..., 1:].unfold(-1, strided_Y_size - 2, step_size)
        Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(-2), unfolded_Y_imag[..., ::2, :].sum(-2)
        Y_neg_real, Y_neg_imag = unfolded_Y_real[..., 1::2, :].sum(-2).flip(-1), unfolded_Y_imag[..., 1::2, :].sum(-2).flip(-1)

        Y_real = Y_pos_real.add_(Y_neg_real)
        Y_imag = Y_pos_imag.add_(Y_neg_imag, alpha=-1)
        # DC and Nyquist bins of a real signal have zero imaginary part.
        Y_imag = F.pad(Y_imag, [1, 1])

        Y = torch.view_as_complex(torch.stack((Y_real, Y_imag), -1)).div_(stride[-1])

    output = irfft(Y)

    # Remove extra padded values
    output = output[..., :output_size[-1]].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
def _fft_conv_transposend(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    stride: Tuple[int],
    padding: Tuple[int],
    output_padding: Tuple[int],
    groups: int,
    dilation: Tuple[int],
) -> Tensor:
    """Transposed N-d convolution evaluated in the Fourier domain.

    Intended to reproduce PyTorch ``conv_transposeNd``. Stride (zero-insertion
    up-sampling of the input) and dilation (of the kernel) are both realised by
    tiling the corresponding spectrum.

    Args:
        input: tensor laid out as (batch, channels, *spatial).
        weight: kernel tensor; channel handling is delegated to
            ``_complex_matmul(..., True)`` — TODO confirm layout.
        bias: optional per-output-channel bias (broadcast along dim 1).
        stride, padding, output_padding, dilation: one entry per spatial dim.
        groups: number of channel groups.

    Returns:
        Real tensor whose spatial size is ``_conv_transpose_shape(...)``.
    """
    output_size = _conv_transpose_shape(input.shape[2:], weight.shape[2:], stride, padding, output_padding, dilation)
    # Full (uncropped) result size; the padding is cropped off at the end.
    padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_output_size, weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)
        # find s size that can be divided by stride and dilation
        # (on the last, rfft axis s_size / st and s_size / d are also kept even)
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        # The input is transformed at 1/st of the final length and its
        # spectrum tiled st times, which equals zero-insertion up-sampling.
        s.append(s_size // st)
        weight_s.append(s_size // d)

    # Real FFTs along the last axis; the other axes follow below.
    X = rfft(input, n=s[-1])
    W = rfft(weight, n=weight_s[-1])

    if stride[-1] > 1:
        # One-sided form of tiling the full spectrum ``stride`` times:
        # alternate mirrored conjugated negative half / positive half.
        X_neg_freq = X.flip(-1)[..., 1:]
        X_neg_freq.imag.mul_(-1)

        tmp = [X]
        for i in range(1, stride[-1]):
            if i % 2:
                tmp.append(X_neg_freq)
            else:
                tmp.append(X[..., 1:])

        X = torch.cat(tmp, -1)

    if dilation[-1] > 1:
        # Same tiling for the kernel, realising the dilation.
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(s) > 1:
        # Full FFTs over the remaining spatial axes, then tile the spectra by
        # stride / dilation along each of them (only if some factor > 1).
        X = fftn(X, s=s[:-1], dim=tuple(range(2, X.ndim - 1)))
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + stride[:-1] + (1, )
        if sum(repeats) > X.ndim:
            X = X.repeat(*repeats)

        repeats = (1, 1) + dilation[:-1] + (1, )
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)

    # W is not conjugated here: a plain spectral product (true convolution).
    # The trailing True presumably selects the transposed channel contraction.
    Y = _complex_matmul(X, W, groups, True)

    output = irfftn(Y, dim=tuple(range(2, Y.ndim)))

    # Remove extra padded values
    index = (slice(None), ) * 2 + tuple(
        slice(p, o + p) for p, o in zip(padding, output_size))
    output = output[index].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
def fit(self, dataset, batch_size=128, epochs=10, frames=6):
    """Fits a new MultFRRBM model.

    Two RBMs are trained frame by frame: ``self.models[0]`` on the
    standardised Fourier-magnitude spectrum of each frame and
    ``self.models[1]`` on the standardised raw pixels of the same frame. Each
    model takes one optimiser step per frame.

    Args:
        dataset (torch.utils.data.Dataset | Dataset): A Dataset object containing the training data.
        batch_size (int): Amount of samples per batch.
        epochs (int): Number of training epochs.
        frames (int): Number of leading frames of every sample to train on.

    Returns:
        MSE (mean squared error), log pseudo-likelihood and cost of the last
        epoch (summed over both models, averaged over frames and batches).
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                         num_workers=workers, collate_fn=collate_fn)

    for ep in range(epochs):
        logger.info(f'Epoch {ep+1}/{epochs}')

        # Resetting epoch's MSE, pseudo-likelihood and cost to zero
        mse, pl, cst = 0, 0, 0
        inner_trans = tqdm.tqdm(total=len(batches), desc='Batch', position=1)
        start = time.time()

        for ii, batch in enumerate(batches):
            x, _ = batch

            # Checking whether GPU is available and if it should be used
            if self.device == 'cuda':
                x = x.cuda()

            # Per-batch accumulators (summed over frames, averaged below)
            mse2, pl2, cst2 = 0, 0, 0

            for fr in range(frames):
                x_ = x[:, fr, :, :].squeeze()

                # Standardised Fourier-magnitude features of the frame
                spec_data = fftshift(fftn(x_))[:, :, :, 0]
                spec_data = torch.abs(spec_data.squeeze())
                spec_data = spec_data.reshape(spec_data.size(0), self.n_visible)
                spec_data = ((spec_data - torch.mean(spec_data, 0, True)) /
                             (torch.std(spec_data, 0, True) + c.EPSILON)).detach()

                # Standardised raw pixels of the same frame
                x_ = x_.reshape(x.size(0), self.n_visible)
                x_ = ((x_ - torch.mean(x_, 0, True)) /
                      (torch.std(x_, 0, True) + c.EPSILON)).detach()

                # Performs the Gibbs sampling procedure
                _, _, _, _, visible_states = self.models[0].gibbs_sampling(spec_data)
                _, _, _, _, visible_states2 = self.models[1].gibbs_sampling(x_)

                # Energy difference between the data and the Gibbs samples
                cost = torch.mean(self.models[0].energy(spec_data)) - \
                    torch.mean(self.models[0].energy(visible_states))
                cost2 = torch.mean(self.models[1].energy(x_)) - \
                    torch.mean(self.models[1].energy(visible_states2))

                # One optimisation step per frame for each model
                self.models[0].optimizer.zero_grad()
                self.models[1].optimizer.zero_grad()
                cost.backward()
                cost2.backward()
                self.models[0].optimizer.step()
                self.models[1].optimizer.step()

                visible_states = visible_states.detach()
                visible_states2 = visible_states2.detach()

                # Per-sample MSE. Divide by the actual number of samples,
                # because the last batch can be smaller than ``batch_size``.
                n_samples = x.size(0)
                batch_mse1 = torch.div(
                    torch.sum(torch.pow(spec_data - visible_states, 2)), n_samples).detach()
                batch_mse2 = torch.div(
                    torch.sum(torch.pow(x_ - visible_states2, 2)), n_samples).detach()

                batch_pl1 = self.models[0].pseudo_likelihood(spec_data).detach()
                batch_pl2 = self.models[1].pseudo_likelihood(x_).detach()

                mse2 += (batch_mse1 + batch_mse2)
                pl2 += (batch_pl1 + batch_pl2)
                cst2 += (cost.detach() + cost2.detach())

            mse2 /= frames
            pl2 /= frames
            cst2 /= frames

            mse += mse2
            pl += pl2
            cst += cst2

            if ii % 100 == 99:
                # ``ii`` is 0-based, so ii + 1 batches have been accumulated.
                seen = ii + 1
                print('MSE:', (mse / seen).item(), 'Cost:', (cst / seen).item())
                for model, path in ((self.models[0], 'w8_spec.png'),
                                    (self.models[1], 'w8_gauss.png')):
                    w8 = model.W.cpu().detach().numpy()
                    img = _rasterize(w8.T, img_shape=(72, 96), tile_shape=(30, 30),
                                     tile_spacing=(1, 1))
                    Image.fromarray(img).save(path)
                for states, path in ((visible_states, 'spectral.png'),
                                     (visible_states2, 'sample.png')):
                    tiles = states[:100].cpu().detach().reshape((100, 6912)).numpy()
                    tiles = _rasterize(tiles, img_shape=(72, 96), tile_shape=(10, 10),
                                       tile_spacing=(1, 1))
                    Image.fromarray(tiles).convert("LA").save(path)

            inner_trans.update(1)

        inner_trans.close()

        mse /= len(batches)
        pl /= len(batches)
        cst /= len(batches)

        logger.info(
            f'MSE: {mse.item()} | log-PL: {pl.item()} | Cost: {cst.item()}'
        )

        end = time.time()
        self.dump(mse=mse.item(), pl=pl.item(), fe=cst.item(), time=end - start)

    return mse, pl, cst
def fit(self, dataset, batch_size=128, epochs=10, frames=6):
    """Fits a new RBM model on the Fourier-magnitude spectra of frames.

    For every batch the costs of all ``frames`` frames are summed, averaged,
    and used for a single optimiser step.

    Args:
        dataset (torch.utils.data.Dataset): A Dataset object containing the training data.
        batch_size (int): Amount of samples per batch.
        epochs (int): Number of training epochs.
        frames (int): Number of leading frames of every sample to train on.

    Returns:
        MSE (mean squared error), log pseudo-likelihood and cost of the last
        epoch (averaged over frames and batches).
    """
    # Transforming the dataset into training batches
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                         num_workers=workers, collate_fn=collate_fn)

    for e in range(epochs):
        logger.info(f'Epoch {e+1}/{epochs}')

        start = time.time()

        # Resetting epoch's MSE, pseudo-likelihood and cost to zero
        mse, pl, cst = 0, 0, 0

        inner = tqdm.tqdm(total=len(batches), desc='Batch', position=1)
        for ii, batch in enumerate(batches):
            samples, _ = batch

            # Checking whether GPU is available and if it should be used
            if self.device == 'cuda':
                samples = samples.cuda()

            mse2, pl2, cst2 = 0, 0, 0
            cost = 0

            # Initializing the gradient
            self.optimizer.zero_grad()

            for fr in range(frames):
                sps = samples[:, fr, :, :].squeeze()

                # Creating the Fourier Spectrum
                spec_data = fftshift(fftn(sps))[:, :, :, 0]
                spec_data = torch.abs(spec_data.squeeze())

                # Flattening the samples' batch
                spec_data = spec_data.view(spec_data.size(0), self.n_visible)

                # Standardising every feature over the batch
                spec_data = ((spec_data - torch.mean(spec_data, 0, True)) /
                             (torch.std(spec_data, 0, True) + c.EPSILON)).detach()

                # Performs the Gibbs sampling procedure
                _, _, _, _, visible_states = self.gibbs_sampling(spec_data)

                # Energy difference between the data and the Gibbs samples
                frame_cost = torch.mean(self.energy(spec_data)) - \
                    torch.mean(self.energy(visible_states))
                cost = cost + frame_cost

                visible_states = visible_states.detach()

                # Per-sample MSE of this frame
                batch_size2 = sps.size(0)
                batch_mse = torch.div(
                    torch.sum(torch.pow(spec_data - visible_states, 2)), batch_size2).detach()
                batch_pl = self.pseudo_likelihood(spec_data).detach()

                mse2 += batch_mse
                pl2 += batch_pl
                # Accumulate this frame's cost only. Adding the running total
                # ``cost`` here counted early frames several times.
                cst2 += frame_cost.detach()

            # One gradient step on the frame-averaged cost
            cost = cost / frames
            cost.backward()
            self.optimizer.step()

            mse2 /= frames
            pl2 /= frames
            cst2 /= frames

            mse += mse2
            pl += pl2
            cst += cst2

            if ii % 100 == 99:
                # ``ii`` is 0-based, so ii + 1 batches have been accumulated.
                seen = ii + 1
                print('MSE:', (mse / seen).item(), 'Cost:', (cst / seen).item())
                w8 = self.W.cpu().detach().numpy()
                img = _rasterize(w8.T, img_shape=(72, 96), tile_shape=(30, 30),
                                 tile_spacing=(1, 1))
                Image.fromarray(img).save('w8_spec.png')
                tiles = visible_states[:100].cpu().detach().reshape((100, 6912)).numpy()
                tiles = _rasterize(tiles, img_shape=(72, 96), tile_shape=(10, 10),
                                   tile_spacing=(1, 1))
                Image.fromarray(tiles).convert("LA").save('spectral.png')

            inner.update(1)

        inner.close()

        # Normalizing the MSE, pseudo-likelihood and cost with the number of batches
        mse /= len(batches)
        pl /= len(batches)
        cst /= len(batches)

        end = time.time()

        # Dumps the desired variables to the model's history
        self.dump(mse=mse.item(), pl=pl.item(), fe=cst.item(), time=end - start)

        logger.info(f'MSE: {mse} | log-PL: {pl} | Cost: {cst}')

    self.p = 0

    return mse, pl, cst
def reconstruct(self, dataset, bs=2**7):
    """Reconstructs batches of new samples.

    Each frame of a batch is turned into standardised Fourier-magnitude
    features and passed through ``self.gibbs_sampling``; the resulting visible
    states are stored per frame.

    NOTE(review): the batch loop ends with ``break``, so only the first
    (shuffled) batch is processed, yet the MSE is divided by ``len(batches)``
    — confirm this is intended.

    Args:
        dataset (torch.utils.data.Dataset): A Dataset object containing the testing data.
        bs (int): Batch size used by the DataLoader.

    Returns:
        Reconstruction error, the processed input batch ``x`` and the per-frame
        visible states ``reconstructed`` of shape (bs, frames, n_visible).
    """
    logger.info(f'Reconstructing new samples ...')

    # Resetting MSE to zero
    mse = 0

    # Batch size used to build the DataLoader
    batch_size = bs

    # Transforming the dataset into batches
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn)

    # For every batch
    inner = tqdm.tqdm(total=len(batches), desc='Batch', position=1)
    for _, batch in enumerate(batches):
        x, _ = batch
        frames = x.size(1) #frames
        dy, dx = x.size(2), x.size(3)
        # NOTE(review): sized with ``bs`` rather than x.size(0); a batch smaller
        # than ``bs`` makes the row assignment below fail to broadcast.
        reconstructed = torch.zeros((bs, frames, self.n_visible))

        # Checking whether GPU is available and if it should be used
        if self.device == 'cuda':
            # Applies the GPU usage to the data
            x = x.cuda()
            reconstructed = reconstructed.cuda()

        for fr in range(frames):
            sps = x[:, fr, :, :].squeeze()

            # Creating the Fourier Spectrum
            spec_data = fftshift(fftn(sps))[:,:,:,0]
            spec_data = torch.abs(spec_data.squeeze())
            spec_data.detach()  # NOTE(review): no-op, the returned tensor is discarded

            # Flattening the samples' batch
            spec_data = spec_data.view(spec_data.size(0),self.n_visible)

            # Standardising every feature over the batch
            spec_data = ((spec_data - torch.mean(spec_data, 0, True)) / (torch.std(spec_data, 0, True) + c.EPSILON)).detach()

            # Performs the Gibbs sampling procedure
            _, _, _, _, visible_states = self.gibbs_sampling(spec_data)
            visible_states = visible_states.detach()

            # Passing reconstructed data to a tensor
            reconstructed[:, fr, :] = visible_states

        # Calculating current's batch reconstruction MSE
        # NOTE(review): compares the raw pixels of every frame with the spectral
        # visible states of the *last* frame only; (B, frames, dy*dx) minus
        # (B, n_visible) only broadcasts when B == frames (or either is 1).
        # Was ``reconstructed`` intended here?
        batch_mse = torch.div(
            torch.sum(torch.pow(x.reshape((len(x), frames, dy*dx)) - visible_states, 2)), bs).detach()

        # Summing up the reconstruction's MSE
        mse += batch_mse
        inner.update(1)
        break

    # Normalizing the MSE with the number of batches
    mse /= len(batches)
    logger.info(f'MSE: {mse}')

    return mse, x, reconstructed
def ufft(data: Tensor, mask: Tensor, signal_ndim: int, normalized: bool = False) -> Tensor:
    """Undersampled fast Fourier transform.

    Applies an orthonormal FFT along dimension 2 of ``data`` and multiplies
    the result element-wise by ``mask``.

    NOTE(review): ``signal_ndim`` and ``normalized`` are accepted but ignored:
    the transform is always 1-D over dim 2 with 'ortho' scaling. Confirm this
    is intended before relying on either argument.
    """
    spectrum = fftn(data, dim=2, norm='ortho')
    return mask * spectrum
(weights * diff)**2) loss_l2 = rho * 0.5 * torch.sum( (x - outputs[0, 0, ...] + mu)**2) loss = loss_fidelity + loss_l2 # loss = loss_fidelity loss.backward() return loss optimizer.step(closure) # forward again to compute fidelity loss outputs = resnet(inputs_cat) outputs_cplx = outputs.type(torch.complex64) # loss RDFs_outputs = torch.real( fft.ifftn((fft.fftn(outputs_cplx, dim=[2, 3, 4]) * D), dim=[2, 3, 4])) diff = torch.abs(rdfs - RDFs_outputs) loss_fidelity = torch.sum((weights * diff)**2) fidelity_fine = 'epochs: [%d/%d], Ks: [%d/%d], time: %ds, Fidelity loss: %f' % ( epoch, niter, k + 1, K, time.time() - t0, loss_fidelity.item()) print(fidelity_fine) if k == K - 1: file.write(fidelity_fine) file.write('\n') # dual update with torch.no_grad(): mu = mu + x - outputs[0, 0, ...] # # metrics
def _compute_variances(self, Phi: Projection, alpha: Tensor, kmasks: Tensor,
                       num_probes: int, num_cg_iters: int = 32,
                       cg_tol: float = 1e-10) -> Tensor:
    """Stochastically estimate per-pixel variances from random probes.

    Random probes ``z`` are pushed through: mask of unsampled k-space ->
    pseudo-inverse of the circular finite-difference operator(s) -> conjugate-
    gradient solve with ``A = alpha0 * Phi^T Phi + alpha`` -> difference
    operator(s) again -> mask. The estimate is ``mean_probes(z * out)``,
    clamped at 0. This is the usual probe-based diagonal estimator pattern;
    the statistical interpretation should be confirmed against the method's
    derivation.

    Args:
        Phi: projection operator; ``Phi.T(Phi(x))`` is used inside ``A``.
        alpha: per-contrast weights (unsqueezed to broadcast over probes).
        kmasks: boolean mask of shape (num_contrasts, size_x, size_y); True
            entries are excluded (the code uses ``~kmasks``).
        num_probes: number of random probes drawn by ``self._samp_probes``.
        num_cg_iters, cg_tol: forwarded to ``conjugate_gradient``.

    Returns:
        Non-negative tensor of shape (num_contrasts, size_x, size_y).
    """
    num_contrasts, size_x, size_y = kmasks.size()
    # Keep only the coefficients where kmasks is False; leading axis = probes.
    masks = ~kmasks.unsqueeze(dim=0)

    z = self._samp_probes((num_probes, num_contrasts, size_x, size_y))
    b = fftn(z, dim=(-2, -1), norm='ortho')
    b = masks * b
    b_lst = []
    norm = torch.tensor(0., device=self.device)
    # (1 - exp(-2*pi*i*k/N)) is the DFT multiplier of a circular first
    # difference along that axis; ``norm`` accumulates sum |multiplier|^2.
    if 'x' in self.grad_dim:
        kx = torch.arange(size_x, device=self.device).view(1, 1, -1, 1)
        kfactor_x = (1 - torch.exp(-2 * np.pi * 1j * kx / size_x))
        b_x = kfactor_x * b
        b_lst.append(b_x)
        norm = norm + torch.abs(kfactor_x)**2
    if 'y' in self.grad_dim:
        ky = torch.arange(size_y, device=self.device).view(1, 1, 1, -1)
        kfactor_y = (1 - torch.exp(-2 * np.pi * 1j * ky / size_y))
        b_y = kfactor_y * b
        b_lst.append(b_y)
        norm = norm + torch.abs(kfactor_y)**2
    # Put 1 where ``norm`` is 0 (the zero-frequency row/column/bin) to avoid
    # dividing by zero. NOTE(review): only 'x', 'y' and 'xy' are handled; any
    # other grad_dim value (e.g. 'yx') leaves a zero in ``norm``.
    corr = torch.zeros((1, 1, size_x, size_y), device=self.device)
    if self.grad_dim == 'x':
        corr[0, 0, 0, :] = 1
    elif self.grad_dim == 'y':
        corr[0, 0, :, 0] = 1
    elif self.grad_dim == 'xy':
        corr[0, 0, 0, 0] = 1
    norm = norm + corr
    # Stack one right-hand side per gradient direction (new dim 1).
    b = torch.stack(b_lst, dim=1) / norm.unsqueeze(dim=1)
    b = ifftn(b, dim=(-2, -1), norm='ortho').real
    b = b.flatten(start_dim=-2)

    # Solve A out = b with conjugate gradients on flattened images.
    alpha = alpha.unsqueeze(dim=1).unsqueeze(dim=0)
    A = lambda x: self.alpha0 * (Phi.T(Phi(x))) + alpha * x
    out, _ = conjugate_gradient(A, b, -1, num_cg_iters, cg_tol)
    out = out.unflatten(dim=-1, sizes=(size_x, size_y))

    # Apply the adjoint difference multipliers and recombine the directions.
    out = fftn(out, dim=(-2, -1), norm='ortho')
    if 'x' in self.grad_dim:
        out[:, 0] = torch.conj(kfactor_x) * out[:, 0] / norm
    if 'y' in self.grad_dim:
        out[:, -1] = torch.conj(kfactor_y) * out[:, -1] / norm
    if self.grad_dim == 'xy':
        out = out[:, 0] + out[:, -1]
    else:
        out = out.squeeze(dim=1)
    out = masks * out
    out = ifftn(out, dim=(-2, -1), norm='ortho').real

    # Average z * out over probes; clamp tiny negative estimates to 0.
    var = (z * out).mean(dim=0).clamp(min=0)
    return var
def apply(self, x: Tensor) -> Tensor:
    """Return the orthonormal 2-D FFT of ``x`` with unselected coefficients zeroed.

    ``self.mask`` is a boolean mask aligned with the trailing dimension(s) of
    the spectrum; coefficients where it is False are set to zero.
    """
    spectrum = fftn(x, dim=(-2, -1), norm='ortho')
    return spectrum.masked_fill(~self.mask, 0)
def convolve_fft(array, kernel, axes=None):
    """Circularly convolve ``array`` with a centred ``kernel`` using the FFT.

    Args:
        array: input tensor.
        kernel: kernel tensor with its origin at the centre (it is
            ``ifftshift``-ed before transforming); it must have the same size
            as ``array`` along the transformed axes and broadcast against it.
        axes: axes to transform; ``None`` means all axes.

    Returns:
        Real part of the circular convolution.
    """
    array_fft = fftn(array, dim=axes)
    # torch.fft.ifftshift takes ``dim``; numpy's ``axes`` keyword raised TypeError.
    kernel_fft = fftn(ifftshift(kernel, dim=axes), dim=axes)
    return torch.real(ifftn(kernel_fft * array_fft, dim=axes))
match_dict['feature.2.weight'] = 'conv2_w' match_dict['feature.2.bias'] = 'conv2_b' for var_name in net.state_dict().keys(): print(var_name) key_in_model = match_dict[var_name] param_in_model = var_name.rsplit('.', 1)[1] if 'weight' in var_name: pth_state_dict[var_name] = torch.Tensor( np.transpose(p[key_in_model], (3, 2, 0, 1))) elif 'bias' in var_name: pth_state_dict[var_name] = torch.Tensor(np.squeeze( p[key_in_model])) if var_name == 'feature.0.weight': weight = pth_state_dict[var_name].data.numpy() weight = weight[:, ::-1, :, :].copy() # cv2 bgr input pth_state_dict[var_name] = torch.Tensor(weight) torch.save(pth_state_dict, 'param.pth') net.load_state_dict(torch.load('param.pth')) x_t = torch.Tensor(np.expand_dims(np.transpose(x, (2, 0, 1)), axis=0)) x_pred = net(x_t).data.numpy() pred_error = np.sum( np.abs( np.transpose(x_pred, (0, 2, 3, 1)).reshape(-1) - x_out.reshape(-1))) x_fft = fft.fftn(x_t, dim=[-2, -1]) print('model_transfer_error:{:.5f}'.format(pred_error))
t = cn.mulconj(xf, zf) kxzf = torch.sum(t, dim=1, keepdim=True) # [batch, 1, 121, 61, 2] alphaf = label.to(device=z.device) / (kzzf + lambda0) # [batch, 1, 121, 121] return torch.irfft(cn.mul(kxzf, alphaf), signal_ndim=2) ############################################## x = torch.rand((42, 32, 121, 121)) a = torch.rfft(x, signal_ndim=2, onesided=False) b = fft.fftn(x, dim=[-2, -1]) ca = torch.view_as_complex(a) print(a.shape) print(b.shape) print(torch.allclose(ca, b)) u = ca - b v = u.abs() h = torch.histc(v) import matplotlib.pyplot as plt plt.hist(v.flatten().numpy(), bins=500, log=True) plt.show() exit()
output = _fft(input, 1, normalized=(norm == 'ortho')) if norm == 'forward': output /= float(n) # Make complex and move back dimension to its original position if _torch_has_complex: output = torch.view_as_complex(output) output = utils.movedim(output, -1, dim) else: output = utils.movedim(output, -2, dim if dim >= 0 else dim - 1) return output if _torch_has_fft_module: fftn = lambda *a, real=None, **k: fft_mod.fftn(*a, **k) else: def fftn(input, s=None, dim=None, norm='backward', real=None): """N-dimensional discrete Fourier transform. Parameters ---------- input : tensor Input signal. If torch <= 1.5, the last dimension must be of length 2 and contain the real and imaginary parts of the signal, unless `real is True`. s : sequence[int], optional Signal size in the transformed dimensions. If given, each dimension dim[i] will either be zero-padded or
def _compute_imgs(self, grad: Tensor, kspaces: Tensor) -> Tensor:
    """Rebuild images from finite-difference gradients and k-space data.

    In the Fourier domain, the gradient spectra G_d are combined as
    sum_d conj(k_d) * G_d / sum_d |k_d|^2, the least-squares inverse of the
    circular first-difference multipliers k_d. This estimate is used only
    where ``self.kspaces == 0``; elsewhere the values of ``self.kspaces`` are
    kept. The result is transformed back with an orthonormal inverse FFT and
    optionally rescaled by ``self.scale`` / ``self.bias``.

    NOTE(review): the ``kspaces`` argument is only sliced (and then never
    used); the sizes and the data-consistency step read ``self.kspaces``
    instead. ``grad_old`` is unused, and the ``num_grads`` read from
    ``grad.size()`` is ignored in favour of ``self.num_grads``. Confirm which
    k-space is intended.

    Args:
        grad: tensor of shape (num_grads, num_contrasts, size_x * size_y)
            (viewed below as (self.num_grads, num_contrasts, size_x, size_y)).
        kspaces: see the review note above.

    Returns:
        Complex images if ``self.complex_imgs``, else their real part.
    """
    _, size_x, size_y = self.kspaces.size()
    num_grads, num_contrasts, _ = grad.size()
    grad = grad.view((self.num_grads, num_contrasts, size_x, size_y))
    grad_old = grad

    # Complex images are stored as separate real/imag channels: either split
    # along the contrast axis (tie_real_imag) or along the gradient axis.
    if self.complex_imgs and self.tie_real_imag:
        num_contrasts //= 2
        grad = grad[:, :num_contrasts] + 1j * grad[:, num_contrasts:]
        kspaces = kspaces[:num_contrasts]
    elif self.complex_imgs and not self.tie_real_imag:
        num_grads = self.num_grads // 2
        grad = grad[:num_grads] + 1j * grad[num_grads:]

    img_fft = torch.tensor(0., device=self.device)
    norm = torch.tensor(0., device=self.device)
    # (1 - exp(-2*pi*i*k/N)) is the DFT multiplier of a circular first difference.
    if 'x' in self.grad_dim:
        kx = torch.arange(size_x, device=self.device).view(1, -1, 1)
        kfactor_x = (1 - torch.exp(-2 * np.pi * 1j * kx / size_x))
        grad_x = grad[0]
        grad_x_fft = fftn(grad_x, dim=(-2, -1), norm='ortho')
        img_fft = img_fft + torch.conj(kfactor_x) * grad_x_fft
        norm = norm + torch.abs(kfactor_x)**2
    if 'y' in self.grad_dim:
        ky = torch.arange(size_y, device=self.device).view(1, 1, -1)
        kfactor_y = (1 - torch.exp(-2 * np.pi * 1j * ky / size_y))
        grad_y = grad[-1]
        grad_y_fft = fftn(grad_y, dim=(-2, -1), norm='ortho')
        img_fft = img_fft + torch.conj(kfactor_y) * grad_y_fft
        norm = norm + torch.abs(kfactor_y)**2
    # Put 1 where ``norm`` is 0 (zero-frequency row/column/bin) to avoid a
    # division by zero; those entries are unrecoverable from gradients alone.
    corr = torch.zeros((1, size_x, size_y), device=self.device)
    if self.grad_dim == 'x':
        corr[0, 0, :] = 1
    elif self.grad_dim == 'y':
        corr[0, :, 0] = 1
    elif self.grad_dim == 'xy':
        corr[0, 0, 0] = 1
    norm = norm + corr
    # Use the gradient-based estimate only where self.kspaces is zero; keep
    # self.kspaces values elsewhere.
    img_fft = img_fft / norm * (self.kspaces == 0) + self.kspaces
    img = ifftn(img_fft, dim=(-2, -1), norm='ortho')
    # img.real = img.real.clamp(min=0, max=1)
    # img.imag = img.imag.clamp(min=0, max=1)

    if self.normalize:
        img = img * self.scale + self.bias

    if not self.complex_imgs:
        img = img.real

    # if self.complex_imgs:
    #     num_contrasts //= 2
    #     img_real = img[:num_contrasts]
    #     img_imag = img[num_contrasts:]
    #     img = img_real + 1j * img_imag

    return img