def A(self, hf: TensorList):
    """Apply the left-hand-side operator of the filter equations: correlate the
    training samples with the filter, weight each sample, multiply back with the
    transposed samples, and add the spatial regularization term."""
    # Classify
    sh = complex.mtimes(self.training_samples, hf.permute(2, 3, 1, 0, 4))  # (h, w, num_samp, num_filt, 2)
    sh = complex.mult(self.sample_weights.view(1, 1, -1, 1), sh)

    # Multiply with transpose
    hf_out = complex.mtimes(sh.permute(0, 1, 3, 2, 4), self.training_samples, conj_b=True).permute(2, 3, 0, 1, 4)

    # Add regularization
    for hfe, hfe_out, reg_filter in zip(hf, hf_out, self.reg_filter):
        reg_pad1 = min(reg_filter.shape[-2] - 1, hfe.shape[-3] - 1)
        reg_pad2 = min(reg_filter.shape[-1] - 1, 2 * hfe.shape[-2] - 2)

        # Add part needed for convolution
        if reg_pad2 > 0:
            hfe_conv = torch.cat([complex.conj(hfe[..., 1:reg_pad2 + 1, :].flip((2, 3))), hfe], -2)
        else:
            hfe_conv = hfe.clone()

        # Shift data to batch dimension
        hfe_conv = hfe_conv.permute(0, 1, 4, 2, 3).reshape(-1, 1, hfe_conv.shape[-3], hfe_conv.shape[-2])

        # Do first convolution
        hfe_conv = F.conv2d(hfe_conv, reg_filter, padding=(reg_pad1, reg_pad2))

        # Do second convolution
        remove_size = min(reg_pad2, hfe.shape[-2] - 1)
        hfe_conv = F.conv2d(hfe_conv[..., remove_size:], reg_filter)

        # Reshape back and add
        hfe_out += hfe_conv.reshape(hfe.shape[0], hfe.shape[1], 2, hfe.shape[2], hfe.shape[3]).permute(0, 1, 3, 4, 2)

    return hf_out
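# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original class). A hedged reading of
# A(): it applies the weighted Gram operator of the training samples plus a
# regularization term, roughly A(h) = X^H diag(w) X h + reg, evaluated above
# in the Fourier domain with the (..., 2) real/imaginary layout. The toy
# function below shows the same operator on ordinary real tensors; the names
# (toy_normal_equations_op, X, w, h) are made up for the example, and a plain
# L2 term stands in for the spatial regularization convolution.
import torch


def toy_normal_equations_op(X: torch.Tensor, w: torch.Tensor, h: torch.Tensor,
                            reg: float = 0.1) -> torch.Tensor:
    """Toy A(h): weighted Gram matrix applied to h, plus a Tikhonov term."""
    scores = X @ h          # sample responses, analogous to `sh` above
    weighted = w * scores   # multiply with the sample weights
    out = X.t() @ weighted  # multiply with the transposed samples
    return out + reg * h    # regularization term

# Usage sketch:
#   X, w, h = torch.randn(10, 4), torch.rand(10), torch.randn(4)
#   toy_normal_equations_op(X, w, h)   # same shape as h, as CG requires
# ---------------------------------------------------------------------------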
def __call__(self, x: TensorList):
    """
    Compute residuals
    :param x: [filters, projection_matrices]
    :return: [data_terms, filter_regularizations, proj_mat_regularizations]
    """
    hf = x[:len(x) // 2]
    P = x[len(x) // 2:]

    compressed_samples = complex.mtimes(self.training_samples, P)
    residuals = complex.mtimes(compressed_samples, hf.permute(2, 3, 1, 0, 4))  # (h, w, num_samp, num_filt, 2)
    residuals = residuals - self.yf

    if self.sample_weights_sqrt is not None:
        residuals = complex.mult(self.sample_weights_sqrt.view(1, 1, -1, 1), residuals)

    # Add spatial regularization
    for hfe, reg_filter in zip(hf, self.reg_filter):
        reg_pad1 = min(reg_filter.shape[-2] - 1, hfe.shape[-3] - 1)
        reg_pad2 = min(reg_filter.shape[-1] - 1, hfe.shape[-2] - 1)

        # Add part needed for convolution
        if reg_pad2 > 0:
            hfe_left_padd = complex.conj(hfe[..., 1:reg_pad2 + 1, :].clone().detach().flip((2, 3)))
            hfe_conv = torch.cat([hfe_left_padd, hfe], -2)
        else:
            hfe_conv = hfe.clone()

        # Shift data to batch dimension
        hfe_conv = hfe_conv.permute(0, 1, 4, 2, 3).reshape(-1, 1, hfe_conv.shape[-3], hfe_conv.shape[-2])

        # Do first convolution
        hfe_conv = F.conv2d(hfe_conv, reg_filter, padding=(reg_pad1, reg_pad2))

        residuals.append(hfe_conv)

    # Add regularization for projection matrix
    residuals.extend(math.sqrt(self.params.projection_reg) * P)

    return residuals
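# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original class). __call__ returns a
# TensorList of residual blocks: the data terms, one spatial-regularization
# map per filter, and sqrt(projection_reg) * P for the projection matrices.
# A Gauss-Newton / CG style solver only needs this residual vector; the toy
# helper below just shows the squared-norm loss such a residual list
# implicitly encodes. `toy_residual_loss` and the dummy shapes are
# assumptions made for the example, not part of the tracker code.
import torch


def toy_residual_loss(residuals):
    """Total loss as the sum of squared energies of all residual blocks."""
    return sum((r ** 2).sum() for r in residuals)

# Usage sketch with arbitrary dummy blocks:
#   residuals = [torch.randn(5, 5, 3, 2, 2), torch.randn(6, 1, 7, 4), torch.randn(16, 8)]
#   toy_residual_loss(residuals)   # scalar the optimizer drives down
# ---------------------------------------------------------------------------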
def symmetrize_filter(self):
    for hf in self.filter:
        hf[:, :, :, 0, :] /= 2
        hf[:, :, :, 0, :] += complex.conj(hf[:, :, :, 0, :].flip((2,)))
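# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original class). The spectrum of a
# real signal is conjugate-symmetric, F[k] == conj(F[-k]), and the averaging
# above projects the stored first column back onto that symmetry. The helper
# below shows the same projection on a 1-D torch.complex column; it assumes a
# centered (fftshift-ed) spectrum of odd length, for which a flip maps
# k -> -k. `symmetrize_column` is a made-up name for the illustration.
import torch


def symmetrize_column(col: torch.Tensor) -> torch.Tensor:
    """Average a complex column with its conjugate flip: enforces F[k] == conj(F[-k])."""
    return 0.5 * (col + torch.conj(torch.flip(col, dims=(0,))))

# Usage sketch: the centered spectrum of a real signal is left unchanged.
#   x = torch.randn(9)                            # odd length
#   X = torch.fft.fftshift(torch.fft.fft(x))      # centered spectrum
#   torch.allclose(symmetrize_column(X), X, atol=1e-5)   # -> True
# ---------------------------------------------------------------------------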