Example #1
 def update(self, dict_in):
     if self.data == []:
         if self.transform is not None:
             #this transform is useful if you want to compute the isnr
             #on, say, a masked version of the image, where the transform
             #is a mask operator
             self.x = (self.transform * dict_in['x']).flatten()
         else:
             if dict_in[self.y_key].shape != dict_in['x'].shape:
                 self.x = crop_center(dict_in['x'],
                                      dict_in[self.y_key].shape).flatten()
             else:
                 self.x = dict_in['x'].flatten()
         self.y = dict_in[self.y_key].flatten()
     if self.transform is not None:
         x_n = (self.transform * dict_in['x_n']).flatten()
     else:
         if dict_in[self.y_key].shape != dict_in['x_n'].shape:
             x_n = crop_center(dict_in['x_n'],
                               dict_in[self.y_key].shape).flatten()
         else:
             x_n = dict_in['x_n'].flatten()
     value = 10 * log10(
         (norm(self.y - self.x, 2)**2) / (norm(x_n - self.x, 2)**2))
     super(ISNR, self).update(value)
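
All of the metric examples on this page reconcile mismatched array shapes with a crop_center helper. A minimal sketch of what such a helper might look like, assuming a NumPy-style centered crop (crop_center_sketch is a hypothetical stand-in, not the library's implementation):

 import numpy as np

 def crop_center_sketch(ary, shape):
     #extract a window of the requested shape, centered in ary
     starts = [(dim - s) // 2 for dim, s in zip(ary.shape, shape)]
     slices = tuple(slice(st, st + s) for st, s in zip(starts, shape))
     return ary[slices]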
Example #2
 def solve(self, dict_in):
     super(RichardsonLucy, self).solve()
     H = self.H
     #input data
     x_n = dict_in['x_0'].copy()
     b = dict_in['b']  #background
     sigma_sq = dict_in['noisevariance']
     #dummy multiply to initialize H
     H * x_n
     dict_in['x_n'] = su.crop_center(x_n, dict_in['y'].shape)
     gamma = (~H) * np.ones(dict_in['y'].shape)
     #begin iterations here
     self.results.update(dict_in)
     print('Finished itn: n=' + str(0))
     if self.profile:
         dict_profile = {}
         dict_profile['twoft_time'] = []
         dict_profile['other_time'] = []
         dict_profile['ht_time'] = []
         dict_in['profiling'] = dict_profile
     for n in np.arange(self.int_iterations):
         #save current iterate
         twoft_0 = time.time()
         div = (H * x_n + b)
         twoft_1 = time.time()
         if self.profile:
             dict_profile['twoft_time'].append(twoft_1 - twoft_0)
         other_time_0 = time.time()
         div = dict_in['y'] / div
         div[np.isnan(div)] = 0.0  #np.nan never compares equal; use isnan
         other_time_1 = time.time()
         if self.profile:
             dict_profile['other_time'].append(other_time_1 - other_time_0)
         twoft_2 = time.time()
         x_n = ((~H) * div) * x_n / gamma
         twoft_3 = time.time()
         if self.profile:
             dict_profile['ht_time'].append(twoft_3 - twoft_2)
         x_n = su.crop_center(x_n, dict_in['y'].shape)
         dict_in['x_n'] = x_n
         x_n = su.pad_center(x_n, dict_in['x_0'].shape)
         #update results
         self.results.update(dict_in)
         print('Finished itn: n=' + str(n + 1))
     return dict_in
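
The solve() above implements Richardson-Lucy deconvolution with a constant background b: each pass applies the multiplicative update x <- x * H^T(y / (H x + b)) / H^T(1). A minimal self-contained sketch, assuming H is a circular convolution given by the FFT of a PSF (richardson_lucy_sketch and psf_f are illustrative names, not the library's API):

 import numpy as np

 def richardson_lucy_sketch(y, psf_f, b=0.0, n_iter=10):
     #H and its adjoint as circular convolutions via the FFT
     H = lambda v: np.real(np.fft.ifftn(np.fft.fftn(v) * psf_f))
     Ht = lambda v: np.real(np.fft.ifftn(np.fft.fftn(v) * np.conj(psf_f)))
     y = np.asarray(y, dtype=float)
     gamma = Ht(np.ones_like(y))    #H^T(1), the 'gamma' normalizer above
     x = np.full_like(y, y.mean())  #flat initial estimate
     for _ in range(n_iter):
         ratio = y / (H(x) + b)
         ratio[np.isnan(ratio)] = 0.0  #guard 0/0, as in the solver
         x = x * Ht(ratio) / gamma
     return x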
Example #3
 def update(self, dict_in):
     """
     Expects a single value or array. If array, store the whole vector and stop.
     """
     if self.data == []:
         if dict_in['fb'].shape != dict_in['x'].shape:
             self.x = crop_center(dict_in['x'],
                                  dict_in['fb'].shape).flatten()
             self.fb = dict_in['fb'].flatten()
         else:
             self.x = dict_in['x'].flatten()
             self.fb = dict_in['fb'].flatten()
     if dict_in['fb'].shape != dict_in['x_n'].shape:
         x_n = crop_center(dict_in['x_n'], dict_in['fb'].shape).flatten()
     else:
         x_n = dict_in['x_n'].flatten()
     value = mean(((x_n - self.x)**2) / self.fb)
     self.data.append(value)
     super(NMISE, self).update()
Example #4
 def update(self, dict_in):
     if self.data == []:
         if dict_in['y'].shape != dict_in['x'].shape:
             self.x = crop_center(dict_in['x'], dict_in['y'].shape)
             self.y = dict_in['y']
         else:
             self.x = dict_in['x']
             self.y = dict_in['y']
         self.L = np.max(self.x) - np.min(self.x)
     if dict_in['y'].shape != dict_in['x_n'].shape:
         x_n = crop_center(dict_in['x_n'], dict_in['y'].shape)
     else:
         x_n = dict_in['x_n']
     if dict_in['x_n'].ndim == 2:
         value = self.compute_ssim(x_n, self.x)
     elif dict_in['x_n'].ndim == 3:
         value = self.compute_ssim_3d(x_n, self.x)
     else:
         raise ValueError('unsupported number of dimensions in x_n')
     self.data.append(value)
     super(SSIM, self).update()
Example #5
    def update(self, dict_in):
        """
        Expects a single value or array. If array, store the whole vector and stop.
        """
        if self.data == []:
            self.xshape = dict_in['x'].shape
            self.x = dict_in['x'].flatten()
            if self.peak == 0:
                self.peak = nmax(self.x)
            if self.bordercrop != 0:
                # self.slices=tuple([slice(self.bordercrop,-self.bordercrop)
                #                    for i in xrange(len(self.xshape))])
                self.crop_center_size = tuple(
                    [el - 2 * self.bordercrop for el in self.xshape])
                self.x = crop_center(dict_in['x'], self.crop_center_size)
                self.peak = nmax(self.x)
            if self.bytecompare:
                self.x = np.asarray(self.x / self.peak * 255.0, dtype='uint8')
                self.peak = nmax(self.x)
            if self.get_val('peak', True) > 0:
                self.peak = self.get_val('peak', True)
            self.x = self.x.flatten()

        if dict_in['x_n'].shape != self.xshape:
            x_n = crop_center(dict_in['x_n'], self.xshape)
        else:
            x_n = dict_in['x_n']
        #apply the same border crop/byte conversion used to build self.x,
        #so the two vectors have matching shapes in the mse below
        if self.bordercrop != 0:
            x_n = crop_center(x_n, self.crop_center_size)
        if self.bytecompare:
            x_n = np.asarray(x_n, dtype='uint8')
        x_n = x_n.flatten()
        mse = mean((x_n - self.x)**2)
        if mse == 0:
            snr_db = np.inf
        else:
            snr_db = 10 * log10((self.peak**2) / mse)
        value = snr_db
        self.data.append(value)
        super(PSNR, self).update()
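
Stripped of the border-cropping and byte-compare options, the metric above is the standard PSNR, 10*log10(peak^2 / MSE). A minimal sketch under that assumption (psnr_sketch is a hypothetical helper):

 import numpy as np

 def psnr_sketch(x, x_n, peak=None):
     #peak defaults to the maximum of the reference signal
     peak = np.max(x) if peak is None else peak
     mse = np.mean((np.asarray(x_n, dtype=float) -
                    np.asarray(x, dtype=float)) ** 2)
     return np.inf if mse == 0 else 10 * np.log10(peak ** 2 / mse)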
Example #6
 def __mul__(self,ary_mcand):
     """Perform forward or adjoint sampled FFT. If no mask is present
     simply perform a fully sampled forward or adjoint FFT.
     
     """
     if not self.lgc_adjoint:
         ary_mcand = 1/sqrt(ary_mcand.size)*fftshift(fftn(ifftshift(ary_mcand)))
         if self.mask is not None and (ary_mcand.shape != self.mask.shape):
             ary_mcand = crop_center(ary_mcand, self.mask.shape)
         if self.mask is not None:    
             ary_mcand *= self.mask
     else:    
         if self.mask is not None and (ary_mcand.shape != self.mask.shape):
             ary_mcand = pad_center(ary_mcand,self.mask.shape)
         ary_mcand = sqrt(ary_mcand.size)*ifftshift(ifftn(fftshift(ary_mcand)))
         
     return super(SampledFT,self).__mul__(ary_mcand)
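
The 1/sqrt(N) and sqrt(N) factors above make the forward and adjoint transforms a unitary pair, so without a mask the adjoint is also the inverse. A stand-alone sketch of the same normalization, omitting the crop/pad bookkeeping (sampled_ft_sketch is an illustrative name):

 import numpy as np
 from numpy.fft import fftn, ifftn, fftshift, ifftshift

 def sampled_ft_sketch(ary, mask=None, adjoint=False):
     if not adjoint:
         ary = fftshift(fftn(ifftshift(ary))) / np.sqrt(ary.size)
         return ary * mask if mask is not None else ary
     #adjoint: conjugate transpose of the unitary, centered FFT
     return np.sqrt(ary.size) * ifftshift(ifftn(fftshift(ary)))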
Example #7
 def update(self, dict_in):
     """
     Expects a single value or array. If array, store the whole vector and stop.
     """
     if self.data == []:
         self.fmetrics.compute_support(dict_in)
         #must use fortran ordering, since this is what matlab uses
         #and we've computed the fourier shell indices assuming this.
         self.x_f = np.ravel(self.fmetrics.x_f, order='F')
         self.x_f_shape = self.fmetrics.x_f.shape
     x_n_f = dict_in['x_n']
     if x_n_f.shape != self.x_f_shape:  #self.x_f is flattened; compare the stored shape
         x_n_f = crop_center(x_n_f, self.x_f_shape)
     x_n_f = np.ravel(fftn(x_n_f), order='F')
     value = tuple(np.real([np.vdot(np.take(x_n_f,self.fmetrics.s_indices[k]),
                     np.take(self.x_f,self.fmetrics.s_indices[k])) / \
              norm(np.take(self.x_f,self.fmetrics.s_indices[k]).flatten(),2) / \
              norm(np.take(x_n_f,self.fmetrics.s_indices[k]).flatten(),2) \
              for k in xrange(self.fmetrics.K)]))
     self.data.append(value)
     super(FourierCorrelation, self).update()
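
Each entry of the tuple above is a normalized correlation over one Fourier shell: Re<X_n, X> / (||X|| * ||X_n||), restricted to that shell's indices. A minimal per-shell sketch (shell_correlation_sketch is a hypothetical helper):

 import numpy as np

 def shell_correlation_sketch(x_f, x_n_f, shell_indices):
     #x_f, x_n_f: flattened Fourier transforms of reference and estimate
     a = np.take(x_n_f, shell_indices)
     b = np.take(x_f, shell_indices)
     return np.real(np.vdot(a, b)) / (np.linalg.norm(a) * np.linalg.norm(b))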
Example #8
 def update(self, dict_in):
     """
     Expects a single value or array. If array, store the whole vector and stop.
     """
     if self.data == []:
         self.fmetrics.compute_support(dict_in)
         self.x_f = np.ravel(self.fmetrics.x_f, order='F')
         self.x_f_shape = self.fmetrics.x_f.shape
     x_n_f = dict_in['x_n']
     if x_n_f.shape != self.x_f_shape:  #self.x_f is flattened; compare the stored shape
         x_n_f = crop_center(x_n_f, self.x_f_shape)
     x_n_f = np.ravel(fftn(x_n_f), order='F')
     d_e_bar = self.x_f - x_n_f
     d_e_bar = conj(d_e_bar) * d_e_bar
     e_bar = conj(self.x_f) * self.x_f
     G = [
         nsum(np.take(e_bar, self.fmetrics.s_indices[k]))
         for k in xrange(self.fmetrics.K)
     ]
     value = tuple(np.real([(G[k] - nsum(np.take(d_e_bar,self.fmetrics.s_indices[k])))/G[k] \
                    for k in arange(self.fmetrics.K)]))
     self.data.append(value)
     super(RER, self).update()
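
Here each shell's value is the relative energy regained, (G_k - sum|X - X_n|^2) / G_k with G_k = sum|X|^2 taken over the shell. A minimal single-shell sketch (rer_sketch is a hypothetical helper):

 import numpy as np

 def rer_sketch(x_f, x_n_f, shell_indices):
     e_bar = np.abs(np.take(x_f, shell_indices)) ** 2            #reference energy
     d_e_bar = np.abs(np.take(x_f - x_n_f, shell_indices)) ** 2  #error energy
     return (e_bar.sum() - d_e_bar.sum()) / e_bar.sum()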
Example #9
    def observe(self, dict_in):
        """
        Loads observation model parameters into a dictionary, 
        performs the forward model and provides an initial solution.

        Args:
        dict_in (dict): Dictionary which will be overwritten with 
        all of the observation model parameters, forward model 
        observation 'y', and initial estimate 'x_0'.
        """
        warnings.simplefilter("ignore", np.ComplexWarning)
        #########################################
        #fetch observation model parameters here#
        #########################################

        if (self.str_type[:11] == 'convolution'
                or self.str_type == 'compressed_sensing'):
            wrf = self.get_val('wienerfactor', True)
            str_domain = self.get_val('domain', False)
            noise_pars = defaultdict(int)  #build a dict to generate the noise
            noise_pars['seed'] = self.get_val('seed', True)
            noise_pars['variance'] = self.get_val('noisevariance', True)
            noise_pars['distribution'] = self.get_val('noisedistribution',
                                                      False)
            noise_pars['mean'] = self.get_val('noisemean', True)
            noise_pars['interval'] = self.get_val('noiseinterval',
                                                  True)  #uniform
            noise_pars['size'] = dict_in['x'].shape
            dict_in['noisevariance'] = noise_pars['variance']

            if self.str_type == 'compressed_sensing':
                noise_pars['complex_noise'] = 1
            if dict_in['noisevariance'] > 0:
                dict_in['n'] = noise_gen(noise_pars)
            else:
                dict_in['n'] = 0

        elif self.str_type == 'classification':
            #partition the classification dataset into an 'observed' training set
            #and an unobserved evaluation/test set, and generate features
            dict_in['x_train'] = {}
            dict_in['x_test'] = {}
            dict_in['y_label'] = {}
            dict_in['x_feature'] = {}
            dict_in['n_training_samples'] = 0
            dict_in['n_testing_samples'] = 0
            shuffle = self.get_val('shuffle', True)
            if shuffle:
                shuffleseed = self.get_val('shuffleseed', True)
            training_proportion = self.get_val('trainingproportion', True)
            classes = dict_in['x'].keys()
            #partition and generate numeric class labels
            for _class_index, _class in enumerate(classes):
                class_size = len(dict_in['x'][_class])
                training_size = int(training_proportion * class_size)
                dict_in['n_training_samples'] += training_size
                dict_in['n_testing_samples'] += class_size - training_size
                if shuffle:
                    np.random.seed(shuffleseed)
                    indices = np.random.permutation(class_size)
                else:
                    indices = np.array(range(class_size), dtype='uint16')
                dict_in['x_train'][_class] = indices[:training_size]
                dict_in['x_test'][_class] = indices[training_size:]
                dict_in['y_label'][_class] = _class_index
        else:
            raise ValueError('unsupported observation model')
        ################################################
        #compute the forward model and initial estimate#
        ################################################
        if self.str_type == 'convolution':
            H = self.Phi

            H.set_output_fourier(False)
            dict_in['Hx'] = H * dict_in['x']
            dict_in['y'] = dict_in['Hx'] + dict_in['n']
            #regularized Wiener filtering in Fourier domain
            H.set_output_fourier(True)
            dict_in['x_0'] = real(
                ifftn(~H * dict_in['y'] /
                      (H.get_spectrum_sq() + wrf * noise_pars['variance'])))
            # dict_in['x_0'] = real(ifftn(~H * dict_in['y'])) %testing only
            H.set_output_fourier(False)
            #compute bsnr
            self.compute_bsnr(dict_in, noise_pars)
        elif self.str_type == 'convolution_downsample':
            Phi = self.Phi
            #this order is important in the config file
            D = Phi.ls_ops[1]
            H = Phi.ls_ops[0]
            H.set_output_fourier(False)
            if self.get_val('spatialblur', True):
                dict_in['Phix'] = D * convolve(dict_in['x'], H.kernel, 'same')
                dict_in['Hxpn'] = convolve(dict_in['x'], H.kernel,
                                           'same') + dict_in['n']
            else:
                dict_in['Phix'] = Phi * dict_in['x']
                dict_in['Hxpn'] = H * dict_in['x'] + dict_in['n']
            dict_in['Hx'] = dict_in['Phix']
            #the version of y without downsampling
            dict_in['DHxpn'] = np.zeros((D * dict_in['Hxpn']).shape)
            if dict_in['n'].__class__.__name__ == 'ndarray':
                dict_in['n'] = D * dict_in['n']
            dict_in['y'] = dict_in['Hx'] + dict_in['n']
            DH = fftn(Phi * nd_impulse(dict_in['x'].shape))
            DHt = conj(DH)
            Hty = fftn(D * (~Phi * dict_in['y']))
            HtDtDH = np.real(DHt * DH)
            # dict_in['x_0'] = ~D*real(ifftn(Hty /
            #                                (HtDtDH +
            #                                 wrf * noise_pars['variance'])))
            dict_in['x_0'] = ~D * dict_in['y']
            #optional interpolation
            xdim = dict_in['x'].ndim
            xshp = dict_in['x'].shape
            if self.get_val('interpinitialsolution', True):
                if xdim == 2:
                    if self.get_val('useimresize', True):
                        interp_vals = imresize(
                            dict_in['y'],
                            tuple(D.ds_factor *
                                  np.asarray(dict_in['y'].shape)),
                            interp='bicubic')
                    else:
                        grids = np.mgrid[[
                            slice(0, xshp[j]) for j in xrange(xdim)
                        ]]
                        grids = tuple(
                            [grids[i] for i in xrange(grids.shape[0])])
                        sampled_coords = np.mgrid[[
                            slice(D.offset[j], xshp[j], D.ds_factor[j])
                            for j in xrange(xdim)
                        ]]
                        values = dict_in['x_0'][[
                            coord.flatten() for coord in sampled_coords
                        ]]
                        points = np.vstack([
                            sampled_coords[i, Ellipsis].flatten()
                            for i in xrange(sampled_coords.shape[0])
                        ]).transpose()  #pts to interp
                        interp_vals = griddata(points,
                                               values,
                                               grids,
                                               method='cubic',
                                               fill_value=0.0)
                else:
                    values = dict_in['y']  #not using blank values; different interpolation scheme
                    dsfactors = np.asarray(
                        [int(D.ds_factor[j]) for j in xrange(values.ndim)])
                    valshpcorrect = (
                        np.asarray(values.shape) -
                        np.asarray(xshp, dtype='uint16') / dsfactors)
                    valshpcorrect = valshpcorrect / np.asarray(dsfactors,
                                                               dtype='float32')
                    interp_coords = iprod(*[
                        np.arange(0, values.shape[j] - valshpcorrect[j], 1.0 /
                                  D.ds_factor[j]) for j in xrange(values.ndim)
                    ])
                    interp_coords = np.array([el for el in interp_coords
                                              ]).transpose()
                    interp_vals = map_coordinates(values,
                                                  interp_coords,
                                                  order=3,
                                                  mode='nearest').reshape(xshp)
                    # interp_vals = map_coordinates(values,interp_coords,order=3,mode='nearest')
                    # cut off the edges
                    # if xdim == 2:
                    # interp_vals = interp_vals[0:xshp[0],0:xshp[1]]
                    # else:
                    interp_vals = interp_vals[0:xshp[0], 0:xshp[1], 0:xshp[2]]
                dict_in['x_0'] = interp_vals
            elif self.get_val('inputinitialsoln', False) != '':
                init_soln_inputsec = Input(
                    self.ps_parameters, self.get_val('inputinitialsoln',
                                                     False))
                dict_in['x_0'] = init_soln_inputsec.read({}, True)
            self.compute_bsnr(dict_in, noise_pars)

        elif self.str_type == 'convolution_poisson':
            dict_in['mp'] = self.get_val('maximumphotonspervoxel', True)
            dict_in['b'] = self.get_val('background', True)
            H = self.Phi
            if str_domain == 'fourier':
                H.set_output_fourier(False)  #return spatial domain object
                orig_shape = dict_in['x'].shape
                Hspec = np.zeros(orig_shape)
                dict_in['r'] = H * dict_in['x']
                k = dict_in['mp'] / nmax(dict_in['r'])
                dict_in['r'] = k * dict_in['r']
                #scale the ground truth by the same factor, so it matches
                #the photon-count normalization of the output image
                dict_in['x'] = k * dict_in['x']
                dict_in['x'] = crop_center(
                    dict_in['x'], dict_in['r'].shape).astype('float32')
                #the spatial domain measurements, before photon counts
                dict_in['fb'] = dict_in['r'] + dict_in['b']
                #lambda of the poisson distn
                noise_pars['ary_mean'] = dict_in['fb']
                #specifying the poisson distn
                noise_distn2 = self.get_val('noisedistribution2', False)
                noise_pars['distribution'] = noise_distn2
                #generating quantized (uint16) poisson measurements
                # dict_in['y'] = (noise_gen(noise_pars)+dict_in['n']).astype('uint16').astype('int32')
                dict_in['y'] = noise_gen(noise_pars) + crop_center(
                    dict_in['n'], dict_in['fb'].shape)
                dict_in['y'][dict_in['y'] < 0] = 0
            elif str_domain == 'evaluation':  #we're given the observation, stored in 'x'
                dict_in['y'] = dict_in.pop('x')
            else:
                raise ValueError('domain not supported: ' + str_domain)
            dict_in['x_0'] = ((~H) * (dict_in['y'])).astype(dtype='float32')
            dict_in['y_padded'] = pad_center(dict_in['y'],
                                             dict_in['x_0'].shape)

        elif self.str_type == 'compressed_sensing':
            Fu = self.Phi
            dict_in['Hx'] = Fu * dict_in['x']
            dict_in['y'] = dict_in['Hx'] + dict_in['n']
            dict_in['x_0'] = (~Fu) * dict_in['y']
            dict_in['theta_0'] = angle(dict_in['x_0'])
            dict_in['theta_0'] = su.phase_unwrap(dict_in['theta_0'],
                                                 dict_in['dict_global_lims'],
                                                 dict_in['ls_local_lim_secs'])
            dict_in['magnitude_0'] = nabs(dict_in['x_0'])
            if self.get_val('maskinitialsoln', True):
                dict_in['theta_0'] *= dict_in['mask']
                dict_in['magnitude_0'] *= dict_in['mask']
            dict_in['x_0'] = dict_in['magnitude_0'] * exp(
                1j * dict_in['theta_0'])
            self.compute_bsnr(dict_in, noise_pars)
        #store the wavelet domain version of the ground truth
        if np.iscomplexobj(dict_in['x']):
            dict_in['w'] = [
                self.W * dict_in['x'].real, self.W * dict_in['x'].imag
            ]
        else:
            dict_in['w'] = [self.W * dict_in['x']]
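
For the 'convolution' case above, the initial estimate x_0 is a regularized Wiener filter computed in the Fourier domain. A compact sketch of just that step, assuming h_f holds the blur's transfer function (wiener_init_sketch and h_f are illustrative names):

 import numpy as np
 from numpy.fft import fftn, ifftn

 def wiener_init_sketch(y, h_f, noise_var, wrf=1.0):
     #x_0 = Re ifftn( conj(H) * Y / (|H|^2 + wrf * sigma^2) )
     return np.real(ifftn(np.conj(h_f) * fftn(y) /
                          (np.abs(h_f) ** 2 + wrf * noise_var)))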
Example #10
    def solve(self, dict_in):
        super(MSIST, self).solve()

        ##################################
        ### Transforms and Modalities ####
        ##################################
        H = self.H  #mapping from solution domain to observation domain
        dict_in['H'] = H
        W = self.W  #sparsifying transform
        dict_in['W'] = W
        # precision = 'float32'
        # if W.output_dtype!='':
        #     precision = W.output_dtype

        if self.alpha.__class__.__name__ != 'ndarray':
            self.alpha = su.spectral_radius(
                self.W, self.H, dict_in['x_0'].shape,
                self.get_val('alphamethod', False, 'spectrum'))
            # self.alpha = su.spectral_radius(self.W, self.H, (64,64,64),
            #                                 self.get_val('alphamethod', False, 'spectrum'))
        alpha = self.alpha  #Lambda_alpha main diagonal (B-sized vector of subband gains)
        dict_in['alpha'] = alpha
        ############
        #Input Data#
        ############

        if H.output_fourier:
            y_hat = dict_in['y']
        else:
            #do an extra FFT to do deconvolution in fourier domain
            y_hat = fftn(dict_in['y'])

        x_n = dict_in['x_0'].copy()  #seed current solution
        # The famous Joan Lasenby "residuals"
        dict_in['x_n'] = x_n
        x = dict_in['x'].copy()
        dict_in['resid_n'] = x - x_n
        x_max = np.max(x)
        x_min = np.min(x)
        # dict_in['resid_range'] = np.array([x_min - x_max, x_max + x_max])
        dict_in['resid_range'] = np.array([-255.0 / 2, 255.0 / 2])
        #######################
        #Common Initialization#
        #######################

        #determine whether/not we need double the wavelet transforms on
        #each iteration for a complex-valued input signal
        self.input_complex = np.iscomplexobj(x_n)

        #initialize current solution in sparse domain
        #g_i is the element group size (2 for CWT, 4 for CWT and input_complex)
        if self.input_complex:
            if self.input_phase_encoded:
                theta_n = su.phase_unwrap(angle(x_n),
                                          dict_in['dict_global_lims'],
                                          dict_in['ls_local_lim_secs'])
            else:
                theta_n = angle(x_n)
            dict_in['theta_n'] = theta_n
            dict_in['magnitude_n'] = nabs(x_n)

            w_n = [W * x_n.real, W * x_n.imag]
            g_i = 2 * (w_n[0].is_wavelet_complex() + 1)
        else:
            w_n = [W * x_n]
            g_i = (w_n[0].is_wavelet_complex() + 1)
        w_n_len = len(w_n)
        w_n_it = xrange(w_n_len)  #iterator for w_n
        dict_in['w_n'] = w_n

        #initialize the precision matrix with zeros
        S_n = w_n[0] * 0
        dict_in['S_n'] = S_n
        #initialize continuation parameters
        epsilon, nu = self.get_epsilon_nu()
        if self.ordepsilon:
            # self.ordepsilonpercstart = 8.0/9.0*(.55**2)
            epsilon = np.zeros(self.int_iterations + 1, )
            self.percentiles = np.arange(30, self.ordepsilonpercstop,
                                         -1.0 / self.int_iterations)
            epsilon[0] = self.get_ord_epsilon(w_n[0], np.inf,
                                              self.percentiles[0])
            dict_in['epsilon_sq'] = epsilon[0]**2
        else:
            dict_in['epsilon_sq'] = epsilon**2
        if self.convexnu:
            nu = np.zeros(self.int_iterations + 1, )
            nu[0] = self.get_convex_nu(w_n[0], epsilon[0]**2, np.min(alpha))
            dict_in['nu_sq'] = nu[0]**2
        else:
            dict_in['nu_sq'] = nu**2

        #wavelet domain variance used for poisson deblurring
        ary_p_var = 0

        ########################################
        #Sparse penalty-specific initialization#
        ########################################

        if self.str_sparse_pen == 'l0rl2_bivar':
            w_tilde = w_n[0] * 0
            sqrt3 = sqrt(3.0)
            sigsq_n = self.get_val('nustop', True)**2
            sig_n = sqrt(sigsq_n)

        if self.str_sparse_pen == 'l0rl2_group':
            tau = self.get_val('tau', True)
            tau_rate = self.get_val('taurate', True)
            tau_start = self.get_val('taustart', True)
            if np.all(tau_start != 0) and tau_rate != 0:
                tau_end = tau
                tau = tau_start
            A = sf.create_section(self.ps_parameters,
                                  self.get_val('clusteraverage',
                                               False))  #cluster
            G = sf.create_section(self.ps_parameters,
                                  self.get_val('groupaverage', False))  #group
            #initialize A and G with parameters of the master vector
            A.init_csr_avg(w_n[0])
            G.init_csr_avg(w_n[0])
            dup_it = xrange(
                A.duplicates)  # iterator for duplicate variable space

            #initialize non-overlapping space (list of ws objects ls_w_hat_n)
            ls_w_hat_n = [[w_n[ix_] * 1 for j in dup_it] for ix_ in w_n_it]

            #initialize non-overlapping space precision
            ls_S_hat_n = [
                ((sum([w_n[ix].energy()
                       for ix in w_n_it]) / g_i) + epsilon[0]**2).invert()
                for int_dup in dup_it
            ]
            w_bar_n = [w_n[ix_] * 1 for ix_ in w_n_it]

            #using the structure of A, initialize the support of Shat, what
            A_row_ix = np.nonzero(A.csr_avg)[0]
            A_col_ix = np.nonzero(A.csr_avg)[1]
            D = csr_matrix((np.ones(A_col_ix.size, ), (A_row_ix, A_col_ix)),
                           shape=A.csr_avg.shape)

            #compute the support of Shat
            ls_S_hat_sup = unflat_list(
                D.transpose() * ((w_n[0] * 0 + 1).flatten()), A.duplicates)

            #load this vector into each new wavelet subband object
            ls_S_hat_sup = [(w_n[0] * 0).unflatten(S_sup)
                            for S_sup in ls_S_hat_sup]
            ls_S_hat_sup = [
                S_hat_n_sup.nonzero() for S_hat_n_sup in ls_S_hat_sup
            ]
            del S_sup
            del S_hat_n_sup
            #precompute AtA (doesn't change from one iteration to the next)
            AtA = (A.csr_avg.transpose() * A.csr_avg).tocsr()

            #convert tau**2 to csr format to allow for subband-adaptive constraint
            if tau.__class__.__name__ != 'ndarray':
                tau_sq = np.ones(w_n[0].int_subbands) * tau**2
            else:
                tau_sq = tau**2
            tau_sq_dia = [((w_n[0] * 0 + 1).cast(A.dtype)) * tau_sq
                          for j in dup_it]
            # tau_sq_dia = [((w_n[0]*0+1))*tau_sq for j in dup_it]
            tau_sq_dia = su.flatten_list(tau_sq_dia)
            offsets = np.array([0])
            tau_sz = tau_sq_dia.size
            tau_sq_dia = dia_matrix((tau_sq_dia, offsets),
                                    shape=(tau_sz, tau_sz))

            #initialize S_hat_bar parameters for efficient matrix inverses
            Shatbar_p_filename = A.file_path.split('.pkl')[0] + 'Shatbar.pkl'
            if not os.path.isfile(Shatbar_p_filename):
                dict_in['col_offset'] = A.int_size
                S_hat_n_csr = su.flatten_list_to_csr(ls_S_hat_sup)
                su.inv_block_diag((tau_sq_dia) * AtA + S_hat_n_csr, dict_in)
                filehandler = open(Shatbar_p_filename, 'wb')
                cPickle.dump(dict_in['dict_bdiag'], filehandler, -1)
                del S_hat_n_csr
            else:
                filehandler = open(Shatbar_p_filename, 'rb')
                dict_in['dict_bdiag'] = cPickle.load(filehandler)
            filehandler.close()

            #store all of the l0rl2_group specific variables in the solver dict_in
            dict_in['ls_S_hat_n'] = ls_S_hat_n
            dict_in['ls_w_hat_n'] = ls_w_hat_n
            dict_in['w_bar_n'] = w_bar_n
            dict_in['G'] = G
            dict_in['A'] = A
            dict_in['W'] = W
            dict_in['AtA'] = AtA
            dict_in['ls_S_hat_sup'] = ls_S_hat_sup
            dict_in['w_n_it'] = w_n_it
            dict_in['dup_it'] = dup_it
            dict_in['ws_dummy'] = w_n[0] * 0
            dict_in['g_i'] = g_i

            # self.update_duplicates(dict_in,nu[0],epsilon[0],tau_sq, tau_sq_dia)

            w_bar_n = dict_in['w_bar_n']
            ls_w_hat_n = dict_in['ls_w_hat_n']
            ls_S_hat_n = dict_in['ls_S_hat_n']
            del D  #iterations need A and G only, not D

        if (self.str_sparse_pen == 'vbmm' or  #vbmm
                self.str_sparse_pen == 'vbmm_hmt'):
            p_a = self.get_val('p_a', True)
            p_b_0 = self.get_val('p_b_0', True)
            p_k = self.get_val('p_k', True)
            p_theta = self.get_val('p_theta', True)
            p_c = self.get_val('p_c', True)
            p_d = self.get_val('p_d', True)
            b_n = w_n[0] * 0
            sigma_n = 0
            if self.str_sparse_pen == 'vbmm_hmt':
                ary_a = self.get_gamma_shapes(W * dict_in['x_0'])
                b_n = w_n[0] * p_b_0

        #poisson + gaussian noise,
        #using the scaling coefficients in the regularization (MSIST-P)
        if self.input_poisson_corrupted:
            #need a 0-padded y to get the right size for the scaling coefficients
            b = dict_in['b']
            if not H.output_fourier:
                y_hat = fftn(dict_in['y'] - b)
            else:
                y_hat = fftn(ifftn(dict_in['y']) - b)
            w_y = (W * dict_in['y_padded'])
            dict_in['x_n'] = su.crop_center(x_n, dict_in['y'].shape)
            w_y_scaling_coeffs = w_y.downsample_scaling()

        self.results.update(dict_in)
        print('Finished itn: n=' + str(0))
        #begin iterations here for the MSIST(-X) algorithm, add some profiling info here
        if self.profile:
            dict_profile = {}
            dict_profile['twoft_time'] = []
            dict_profile['wht_time'] = []
            dict_profile['other_time'] = []
            dict_profile['reproj_time_inv'] = []
            dict_profile['reproj_time_for'] = []
            dict_in['profiling'] = dict_profile
            t0 = time.time()
        ####################
        ##Begin Iterations##
        ####################
        for n in np.arange(self.int_iterations):
            ####################
            ###Landweber Step###
            ####################
            twoft_0 = time.time()
            H.set_output_fourier(True)  #force Fourier output to reduce ffts
            if self.input_complex:
                f_resid = y_hat - H * x_n  #Landweber difference
            else:
                f_resid = ifftn(y_hat - H * x_n)
                H.set_output_fourier(False)
            twoft_1 = time.time()
            if self.input_complex:
                HtHf = (~H) * f_resid
                w_resid = [W * (HtHf).real, W * (HtHf).imag]
            else:
                w_resid = [W * ((~H) * f_resid)]
            wht = time.time()
            if self.profile:
                dict_profile['twoft_time'].append(twoft_1 - twoft_0)
                dict_profile['wht_time'].append(wht - twoft_1)
            ########################
            ######Convex Nu#########
            #####Ord/HMT Epsilon####
            ########################
            if self.ordepsilon:
                if n == 0:
                    prevepsilon = epsilon[0]
                else:
                    prevepsilon = epsilon[n - 1]
                epsilon[n] = self.get_ord_epsilon(w_n[0], prevepsilon,
                                                  self.percentiles[n])
                dict_in['epsilon_sq'] = epsilon[n]**2
            if self.convexnu:
                nu[n] = self.get_convex_nu(w_n[0], epsilon[n]**2,
                                           np.min(self.alpha))
                dict_in['nu_sq'] = nu[n]**2

            ###############################################
            ###Sparse Penalty-Specific Thresholding Step###
            ###############################################
            if self.str_sparse_pen == 'l0rl2_group':
                #S_hat_n, w_hat_n, and wb_bar (eqs 11, 19, and 13)
                self.update_duplicates(dict_in, nu[n], epsilon[n], tau_sq,
                                       tau_sq_dia)
                w_bar_n = dict_in['w_bar_n']
                ls_w_hat_n = dict_in['ls_w_hat_n']

            #####################################################
            #Subband-adaptive subband update of precision matrix#
            #####################################################
            if (self.str_sparse_pen[0:5] == 'l0rl2'
                    and self.str_sparse_pen[-5:] != 'bivar'):
                if self.str_sparse_pen == 'l0rl2_group':
                    S0_n = nsum(
                        [nabs(w_n[ix].ary_lowpass)**2 for ix in w_n_it],
                        axis=0) / g_i + epsilon[n]**2
                    S0_n = 1.0 / S0_n
                else:
                    if self.hmt:
                        S_n_prev = S_n * 1.0
                        S_n.set_subband(
                            0, (1.0 /
                                ((1.0 / g_i) * nabs(w_n[0].get_subband(0))**2 +
                                 (epsilon[n]**2))))

                        for s in xrange(w_n[0].int_subbands - 1, 0, -1):
                            sigma_sq_parent_us = nabs(
                                w_n[0].get_upsampled_parent(s))**2
                            s_parent_sq = 1.0 / (
                                (2.0**(-2.25)) *
                                (1.0 / g_i * sigma_sq_parent_us))
                            S_n.set_subband(s, s_parent_sq)

                    else:
                        S_n = (sum([w_n[ix_].energy()
                                    for ix_ in w_n_it]) / g_i +
                               epsilon[n]**2).invert()
            elif (self.str_sparse_pen[0:5] == 'vbmm'
                  and self.str_sparse_pen[-5:] != 'hmt'):
                cplx_norm = 1.0 + self.input_complex
                S_n = ((g_i + 2.0 * p_a) *
                       (sum([w_n[ix_].energy()
                             for ix_ in w_n_it]) / cplx_norm + sigma_n +
                        2.0 * b_n).invert())
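                #NOTE: 's' and 'alpha[s]' below use a subband index that is
                #never assigned in this branch; this looks like a latent bug
                #inherited from the subband loop of the other variants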
                b_n = (p_k + p_a) * (S_n.get_subband(s) + p_theta).invert()
                sigma_n = (1.0 / nu[n]**2 * alpha[s] + S_n).invert()

            else:
                #iterating through subbands is necessary, coarse to fine
                for s in xrange(w_n[0].int_subbands - 1, -1, -1):
                    #Sendur Selesnick BSWLVE paper
                    if self.str_sparse_pen == 'l0rl2_bivar':
                        if s > 0:
                            s_parent_us = nabs(
                                w_n[0].get_upsampled_parent(s))**2
                            s_child = nabs(w_n[0].get_subband(s))**2
                            yi, yi_mask = su.get_neighborhoods(s_child,
                                                               1)  #eq 8
                            s_child_norm = sqrt(s_parent_us + s_child)
                            sigsq_y = np.sum(yi, axis=yi.ndim - 1) / np.sum(
                                yi_mask, axis=yi.ndim - 1)  #still eq 8...
                            sig = sqrt(np.maximum(sigsq_y - sigsq_n, 0))
                            w_tilde.set_subband(s, sqrt3 * sigsq_n /
                                                sig)  #the thresholding fn
                            thresh = np.maximum(
                                s_child_norm - w_tilde.get_subband(s),
                                0) / s_child_norm  #eq 5
                            #update with the bivariate thresholded
                            #coefficients on every other iteration
                            if np.mod(n, 2) == 0:
                                S_n.set_subband(
                                    s,
                                    (1.0 /
                                     ((1.0 / g_i) *
                                      nabs(thresh * w_n[0].get_subband(s))**2 +
                                      (epsilon[n]**2))))
                            else:
                                S_n.set_subband(
                                    s, (1.0 / ((1.0 / g_i) *
                                               nabs(w_n[0].get_subband(s))**2 +
                                               (epsilon[n]**2))))
                        else:
                            S_n.set_subband(s, (1.0 / (
                                (1.0 / g_i) * nabs(w_n[0].get_subband(s))**2 +
                                epsilon[n]**2)))

                    elif self.str_sparse_pen == 'vbmm_hmt':  #vbmm
                        if n == 0:
                            sigma_n = 0
                        else:
                            sigma_n = (1.0 / nu[n]**2 * alpha[s] +
                                       S_n.get_subband(s))**(-1)
                        if s > 0:
                            w_parent_us = w_n[0].get_upsampled_parent(s)
                            alpha_dec = 2.25
                            if s > S_n.int_orientations:
                                s_child = S_n.subband_group_sum(
                                    s - S_n.int_orientations, 'children')
                                b_child = b_n.subband_group_sum(
                                    s - S_n.int_orientations, 'children')
                            else:
                                s_child = 0
                                b_child = 0
                            if s < S_n.int_subbands - S_n.int_orientations:
                                ap = ary_a[s + S_n.int_orientations]
                            else:
                                ap = .5
                            w_en_avg = w_n[0].subband_group_sum(
                                s, 'parent_children')
                            S_n.set_subband(
                                s, (g_i + 2.0 * ary_a[s]) /
                                (nabs(w_n[0].get_subband(s))**2 + sigma_n +
                                 2.0 * b_n.get_subband(s)))
                            b_n.set_subband(s, ary_a[s] * w_en_avg)
                        else:  #no parents, so generate fixed-param gammas
                            S_n.set_subband(
                                s, (g_i + 2.0 * ary_a[s]) /
                                (nabs(w_n[0].get_subband(s))**2 + sigma_n +
                                 2.0 * b_n.get_subband(s)))
                            b_n.set_subband(s, (p_k + ary_a[s]) /
                                            (S_n.get_subband(s) + p_theta))
                    else:
                        raise ValueError('no such solver variant')
            #########################
            #Update current solution#
            #########################
            for s in xrange(w_n[0].int_subbands - 1, -1, -1):
                if self.input_poisson_corrupted:
                    if s == 0:
                        ary_p_var = w_y.ary_lowpass
                    else:
                        int_lev, int_ori = w_n[0].lev_ori_from_subband(s)
                        ary_p_var = w_y_scaling_coeffs[int_lev]
                        ary_p_var[ary_p_var <= 0] = 0
                if (self.str_sparse_pen == 'l0rl2_group'):
                    if s > 0:
                        for ix_ in w_n_it:
                            w_n[ix_].set_subband(
                                s,
                                (alpha[s] * w_n[ix_].get_subband(s) +
                                 w_resid[ix_].get_subband(s) +
                                 (tau_sq[s]) * w_bar_n[ix_].get_subband(s)) /
                                (alpha[s] + tau_sq[s]))
                    else:  #a standard msist update for the lowpass coeffs
                        for ix_ in w_n_it:
                            w_n[ix_].set_subband(s, \
                                                 (alpha[s] * w_n[ix_].get_subband(s) + w_resid[ix_].get_subband(s)) /
                                                 (alpha[s] + (nu[n]**2) * S0_n))
                else:
                    for ix_ in w_n_it:
                        w_n[ix_].set_subband(
                            s, (alpha[s] * w_n[ix_].get_subband(s) +
                                w_resid[ix_].get_subband(s)) /
                            (alpha[s] +
                             (nu[n]**2 + self.sc_factor * ary_p_var) *
                             S_n.get_subband(s)))
                #end updating subbands

            #############################################
            ##Solution Domain Projection and Operations##
            #############################################
            tother = time.time()
            if self.input_complex:
                x_n = np.asfarray(~W * w_n[0], 'complex128')
                x_n += 1j * np.asfarray(~W * w_n[1], 'complex128')
                m_n = nabs(x_n)
                theta_n = angle(x_n)
                if self.input_phase_encoded:  #need to apply boundary conditions for phase encoded velocity
                    #the following isn't part of the documented algorithm
                    #it only needs to be executed at the end to fix
                    #phase wrapping in very high dynamic-phase regions
                    theta_n = su.phase_unwrap(angle(x_n),
                                              dict_in['dict_global_lims'],
                                              dict_in['ls_local_lim_secs'])
                    #apply boundary conditions for phase encoded velocity
                    if self.get_val('iterationmask', True):
                        theta_n *= dict_in['mask']
                        if self.get_val('magnitudemask', True, 1):
                            m_n *= dict_in['mask']  #uncomment this for 'total' masking
                    x_n = m_n * exp(1j * theta_n)
                dict_in['theta_n'] = theta_n
                dict_in['magnitude_n'] = m_n
            else:
                x_n = ~W * w_n[0]
            tinvdwt = time.time()
            #implicit convolution operator is used, so crop and repad
            if H.str_object_name == 'Blur' and H.lgc_even_fft:
                x_n = su.crop_center(x_n, dict_in['y'].shape)
            if self.input_poisson_corrupted and self.spatial_threshold:
                x_n[x_n < self.spatial_threshold_val] = 0.0

            #finished spatial domain operations on this iteration, store
            dict_in['x_n'] = x_n
            # store "residuals"
            dict_in['resid_n'] = x - x_n
            # dict_in['resid_range'] = np.array([np.min(dict_in['resid_n']), np.max(dict_in['resid_n'])])
            print "resid min " + str(np.round(np.min(dict_in['resid_n']), 2))
            print "resid max " + str(np.round(np.max(dict_in['resid_n']), 2))

            if H.str_object_name == 'Blur' and H.lgc_even_fft:
                x_n = su.pad_center(x_n, dict_in['x_0'].shape)

            #############################
            #Wavelet Domain Reprojection#
            #############################
            if self.profile:
                dict_profile['other_time'].append(tother - wht)
            if self.input_complex:
                w_n = [W * x_n.real, W * x_n.imag]
            else:
                w_n = [W * x_n]
            tforwardwt = time.time()
            if self.profile:
                dict_profile['reproj_time_inv'].append(tinvdwt - tother)
                dict_profile['reproj_time_for'].append(tforwardwt - tinvdwt)
            if self.str_sparse_pen[:11] == 'l0rl2_group':
                ls_w_hat_n = [[
                    ls_w_hat_n[ix_][j] * ls_S_hat_sup[j] + w_bar_n[ix_] *
                    ((ls_S_hat_sup[j] + (-1)) * (-1)) for j in dup_it
                ] for ix_ in w_n_it]  #fill in the gaps with w_bar_n
                w_bar_n = [W * ((~W) * w_bar_n[ix_]) for ix_ in w_n_it]
                ls_w_hat_n = [[
                    W * ((~W) * w_hat_n) for w_hat_n in ls_w_hat_n[ix_]
                ] for ix_ in w_n_it]
                dict_in['w_bar_n'] = w_bar_n
                dict_in['ls_w_hat_n'] = ls_w_hat_n
                if tau_rate != 0 and not np.any(tau > tau_end):
                    tau_sq_dia = tau_rate * tau_sq_dia
                    tau = np.sqrt(tau_rate) * tau
            dict_in['w_n'] = w_n
            dict_in['S_n'] = S_n
            ################
            #Update Results#
            ################
            self.results.update(dict_in)
            print('Finished itn: n=' + str(n + 1))
            # if self.str_sparse_pen[:11] == 'l0rl2_group' and n==150: #an interesting experiment for cs..
            #     self.str_sparse_pen = 'l0rl2'

        return dict_in
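
Beneath the penalty-specific bookkeeping, every MSIST iteration applies the same per-subband update: a Landweber correction shrunk by the current precision estimate, w_s <- (alpha_s * w_s + resid_s) / (alpha_s + nu^2 * S_s). A minimal sketch of that core step on plain arrays (msist_subband_update_sketch is a hypothetical helper):

 import numpy as np

 def msist_subband_update_sketch(w_s, resid_s, alpha_s, nu_sq, S_s):
     #w_s: current coefficients; resid_s: subband of W*(H^T residual)
     #S_s: diagonal precision; larger entries shrink coefficients harder
     return (alpha_s * w_s + resid_s) / (alpha_s + nu_sq * S_s)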
Example #11
    def preprocess(self, dict_in):
        """Loads observation model parameters into a dictionary, 
        performs the forward model and provides an initial solution.

        Args:
        dict_in (dict): Dictionary which must include the following members:
            'x' (ndarray): The 'ground truth' input signal to be modified.
        """
        #build the preprocessing parameters
        if (self.str_type == 'brainwebmri'):
            #need to pad/crop the input data for wavelet processing
            swap_axes = self.get_val('swapaxes', True)
            if swap_axes.__class__.__name__ == 'ndarray':
                dict_in['x'] = dict_in['x'].swapaxes(swap_axes[0],
                                                     swap_axes[1])
            input_shape = dict_in['x'].shape

            #cropping
            new_shape = self.get_val('newshape', True)
            if new_shape.__class__.__name__ == 'ndarray':
                new_shape = tuple(new_shape)
            #figure out what to crop, if anything
            #compare shapes elementwise (tuple < tuple would be lexicographic)
            if np.any(np.asarray(new_shape) < np.asarray(input_shape)):
                crop_shape = np.min(np.vstack((new_shape, input_shape)),
                                    axis=0)
                dict_in['x'] = crop_center(dict_in['x'], crop_shape)
            else:
                crop_shape = input_shape
            #padding
            if np.any(np.asarray(new_shape) > np.asarray(crop_shape)):
                pad_shape = np.max(np.vstack((new_shape, crop_shape)), axis=0)
                dict_in['x'] = pad_center(dict_in['x'], pad_shape)
        # elif (self.str_type == 'superresolution'):
        #     #need to crop edge of image to make results compatible with the literature

        elif (self.str_type == 'phasevelocity'):
            mask_sec_in = self.get_val('masksectioninput', False)
            bmask_sec_in = self.get_val('boundarymasksectioninput', False)
            ls_local_lim_sec_in = self.get_val('vcorrects', False)
            if ls_local_lim_sec_in.__class__.__name__ == 'str' and ls_local_lim_sec_in:
                ls_local_lim_sec_in = [ls_local_lim_sec_in]
            ls_local_lim_secs = []
            if ls_local_lim_sec_in:
                ls_local_lim_secs = [
                    sf.create_section(self.get_params(), local_lim_sec_in)
                    for local_lim_sec_in in ls_local_lim_sec_in
                ]
                ls_local_lim_secs = [{
                    'phaselowerlimit':
                    local_lim.get_val('phaselowerlimit', True),
                    'phaseupperlimit':
                    local_lim.get_val('phaseupperlimit', True),
                    'regionupperleft':
                    local_lim.get_val('regionupperleft', True),
                    'regionlowerright':
                    local_lim.get_val('regionlowerright', True)
                } for local_lim in ls_local_lim_secs]
            #load the mask
            if mask_sec_in != '':
                sec_mask_in = sf.create_section(self.get_params(), mask_sec_in)
                dict_in['mask'] = np.asarray(sec_mask_in.read(dict_in, True),
                                             dtype='bool')
            else:
                dict_in['mask'] = True

            if bmask_sec_in != '':
                sec_bmask_in = sf.create_section(self.get_params(),
                                                 bmask_sec_in)
                dict_in['boundarymask'] = np.asarray(sec_bmask_in.read(
                    dict_in, True),
                                                     dtype='bool')
            else:
                dict_in['boundarymask'] = np.asarray(np.zeros(
                    dict_in['x'][:, :, 0].shape),
                                                     dtype='bool')

            if self.get_val('nmracquisition',
                            True):  #compute phase from lab measurement
                #The frame ordering determines in which direction to compute the
                #phase differences to obtain positive velocities

                frame_order = [0, 1]
                if self.get_val('reverseframeorder'):
                    frame_order = [1, 0]
                #Fully sampled fourier transform in order to extract phase data
                for frame in xrange(2):
                    dict_in['x'][:, :, frame] = fftn(
                        fftshift(dict_in['x'][:, :, frame]))
                if self.get_val('extrafftshift', True):
                    for frame in xrange(2):
                        dict_in['x'][:, :,
                                     frame] = fftshift(dict_in['x'][:, :,
                                                                    frame])

                #Compute phase differences between the two frames
                diff_method = self.get_val('phasedifferencemethod')
                if diff_method == 'conjugateproduct':
                    new_x = (dict_in['x'][:, :, frame_order[1]] *
                             conj(dict_in['x'][:, :, frame_order[0]]))
                    theta = angle(new_x)
                    theta += np.max(np.abs(theta))
                    # theta /= np.max(np.abs(theta))
                    # theata *= np.pi*2
                    magnitude = sqrt(abs(new_x))

                elif diff_method == 'subtraction':
                    theta = (angle(dict_in['x'][:, :, frame_order[1]]) -
                             angle(dict_in['x'][:, :, frame_order[0]]))
                    magnitude = 0.5 * (
                        np.abs(dict_in['x'][:, :, frame_order[0]]) +
                        np.abs(dict_in['x'][:, :, frame_order[1]]))
                # if self.get_val('reverseframeorder'):
                #     theta = -theta
                #     theta+=np.abs(np.min(theta))
                new_x = magnitude * exp(1j * theta)

            else:  #synthetic data
                theta = angle(dict_in['x'])
                magnitude = nabs(dict_in['x'])

            #Do phase unwrapping. This works almost everywhere, except
            #in certain areas where the range of phases exceeds 2*pi.
            #These areas must also be unwrapped with special limits
            #which are determined from the data.
            dict_global_lims = {}
            dict_global_lims['lowerlimit'] = self.get_val(
                'phaselowerlimit', True)
            dict_global_lims['upperlimit'] = self.get_val(
                'phaseupperlimit', True)
            dict_global_lims['boundary_mask'] = dict_in['boundarymask']
            dict_global_lims['boundary_upperlimit'] = self.get_val(
                'boundaryphaseupperlimit', True)
            dict_global_lims['boundaryoverlapvcorrects'] = self.get_val(
                'boundaryoverlapvcorrects', True)

            theta = phase_unwrap(theta, dict_global_lims, ls_local_lim_secs)
            magnitude /= np.max(nabs(magnitude))
            dict_in['x'] = magnitude * exp(1j * theta)
            dict_in['theta'] = dict_in['mask'] * theta
            dict_in['magnitude'] = magnitude
            dict_in['dict_global_lims'] = dict_global_lims
            dict_in['ls_local_lim_secs'] = ls_local_lim_secs
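
The two phase-difference methods handled above reduce to a little complex arithmetic: the conjugate product takes the angle of frame1 * conj(frame0), while subtraction differences the individual phases and averages the magnitudes. A minimal sketch (phase_difference_sketch is a hypothetical helper):

 import numpy as np

 def phase_difference_sketch(frame0, frame1, method='conjugateproduct'):
     #returns (theta, magnitude) for a two-frame phase-contrast pair
     if method == 'conjugateproduct':
         prod = frame1 * np.conj(frame0)
         return np.angle(prod), np.sqrt(np.abs(prod))
     #'subtraction'
     theta = np.angle(frame1) - np.angle(frame0)
     return theta, 0.5 * (np.abs(frame0) + np.abs(frame1))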