Пример #1
0
    def test_update_A(self):
        """Check that gmca_core.update_A produces the expected columns of A.

        With S and R_i filled with ones, the test expects column i of A to be
        n_wavs * ones(n_freqs) + lam_p * A_p[:, i], clipped at zero when
        non-negativity is enforced, and then normalized to unit L2 norm.
        """
        # Scope the warning suppression to this test; calling
        # warnings.filterwarnings here would change the global filter for
        # every test that runs afterwards.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            n_wavs = 1000
            n_freqs = 8
            n_sources = 5

            # Start by generating our A and S matrices.
            S = np.ones((n_sources, n_wavs))
            A = np.ones((n_freqs, n_sources)) / np.sqrt(n_freqs)

            # Create the remainder matrix.
            R_i = np.ones((n_freqs, n_wavs))

            # Create a normalized random prior A_p to use for testing.
            A_p = np.random.randn(n_freqs, n_sources)
            helpers.A_norm(A_p)

            # Test results for various values of lam_p, first without and
            # then with non-negativity enforced on the mixing matrix.
            lam_p_tests = [0.0, 0.1, 1, 2, 100, 200, 1000, 1e8]
            for enforce_nn_A in (False, True):
                for lam_p_val in lam_p_tests:
                    lam_p = [lam_p_val] * n_sources
                    for i in range(n_sources):
                        gmca_core.update_A(S, A, R_i, lam_p, A_p,
                                           enforce_nn_A, i)

                        # Make sure that the column has the correct value.
                        check_A = np.ones(n_freqs) * n_wavs
                        check_A += lam_p_val * A_p[:, i]
                        if enforce_nn_A:
                            check_A[check_A < 0] = 0
                        # An all-zero column (only possible after clipping)
                        # is expected to stay unnormalized.
                        if not enforce_nn_A or np.sum(check_A) > 0:
                            check_A /= np.linalg.norm(check_A)
                        self.assertAlmostEqual(
                            np.max(np.abs(A[:, i] - check_A)), 0)

            # Check that 0s in the remainder leave the mixing matrix prior
            # dominated: row 0 of A should be proportional to row 0 of A_p.
            enforce_nn_A = True
            R_i[0, :] = 0
            A_p = np.random.rand(n_freqs, n_sources)
            helpers.A_norm(A_p)
            lam_p = [1] * n_sources
            for i in range(n_sources):
                gmca_core.update_A(S, A, R_i, lam_p, A_p, enforce_nn_A, i)
            self.assertAlmostEqual(np.std(A[0] / A_p[0]), 0)
Пример #2
0
    def test_wrapper(self):
        """Check that gmca returns the same results as gmca_numba.

        Both are run from identical starting points and the same seed for
        every combination of lam_s, lam_p, min_rmse_rate and ret_min_rmse.
        """
        freq_dim = 10
        pix_dim = 100
        n_iterations = 10
        n_sources = 5

        X = np.random.normal(loc=1000, size=(freq_dim, pix_dim))

        # gmca_numba updates A_numba and S_numba, so each combination starts
        # from the state the previous one left behind; gmca is handed copies
        # of that same state so both paths start identically.
        A_numba = np.random.normal(size=(freq_dim, n_sources))
        helpers.A_norm(A_numba)
        S_numba = np.ones((n_sources, pix_dim))

        A_p = np.random.normal(size=(freq_dim, n_sources))
        helpers.A_norm(A_p)

        lam_s_vals = [0, 10]
        lam_p_vals = [0, 1000]
        min_rmse_rates = [0, 2]
        ret_min_rmse_vals = [True, False]
        enforce_nn_A = True

        for lam_s in lam_s_vals:
            for lam_p_val in lam_p_vals:
                lam_p = [lam_p_val] * n_sources
                for min_rmse_rate in min_rmse_rates:
                    for ret_min_rmse in ret_min_rmse_vals:
                        # subTest reports which combination failed.
                        with self.subTest(lam_s=lam_s,
                                          lam_p=lam_p_val,
                                          min_rmse_rate=min_rmse_rate,
                                          ret_min_rmse=ret_min_rmse):
                            A_init = np.copy(A_numba)
                            S_init = np.copy(S_numba)
                            A, S = gmca_core.gmca(X,
                                                  n_sources,
                                                  n_iterations,
                                                  A_init,
                                                  S_init,
                                                  A_p,
                                                  lam_p,
                                                  enforce_nn_A,
                                                  lam_s,
                                                  ret_min_rmse,
                                                  min_rmse_rate,
                                                  seed=2)
                            gmca_core.gmca_numba(X,
                                                 n_sources,
                                                 n_iterations,
                                                 A_numba,
                                                 S_numba,
                                                 A_p,
                                                 lam_p,
                                                 enforce_nn_A,
                                                 lam_s,
                                                 ret_min_rmse,
                                                 min_rmse_rate,
                                                 seed=2)
                            self.assertAlmostEqual(
                                np.max(np.abs(A_numba - A)), 0)
                            self.assertAlmostEqual(
                                np.max(np.abs(S_numba - S)), 0)
Пример #3
0
    def test_min_rmse_rate(self):
        """Check that min_rmse_rate controls whether S ends at pinv(A) X.

        With min_rmse_rate == n_iterations the returned S is expected to be
        the minimum rmse solution for the returned A; with
        min_rmse_rate == n_iterations - 1 it is expected to differ from it.
        """
        # Scope the warning suppression to this test instead of changing the
        # process-wide warnings filter.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            n_freqs = 10
            n_wavs = 100
            n_iterations = 50
            n_sources = 5
            lam_p = [0.0] * n_sources

            # Generate ground truth A and S.
            A_org = np.random.normal(size=(n_freqs, n_sources))
            helpers.A_norm(A_org)
            S_org = np.random.normal(size=(n_sources, n_wavs))
            X = np.dot(A_org, S_org)

            # Prior matrix (given zero weight through lam_p).
            A_p = np.ones(A_org.shape)
            helpers.A_norm(A_p)

            def run_from_flat_start(min_rmse_rate):
                # Run GMCA from a flat normalized A and S = 1 and return the
                # A and S that gmca_numba leaves behind.
                A = np.ones(A_org.shape)
                helpers.A_norm(A)
                S = np.ones(S_org.shape)
                gmca_core.gmca_numba(X,
                                     n_sources,
                                     n_iterations,
                                     A,
                                     S,
                                     A_p,
                                     lam_p,
                                     ret_min_rmse=False,
                                     min_rmse_rate=min_rmse_rate)
                return A, S

            # With the rate equal to n_iterations, GMCA should return the
            # minimum RMSE solution.
            A, S = run_from_flat_start(n_iterations)
            np.testing.assert_almost_equal(S, np.dot(np.linalg.pinv(A), X))

            # With the rate one lower (ret_min_rmse still False), the
            # returned S should no longer be the minimum RMSE solution.
            A, S = run_from_flat_start(n_iterations - 1)
            self.assertGreater(
                np.mean(np.abs(S - np.dot(np.linalg.pinv(A), X))), 1e-4)
Пример #4
0
	def test_A_norm(self):
		"""Check that A_norm gives every column of random matrices unit norm."""
		n_freqs = 8
		n_sources = 5
		n_trials = 10

		for _ in range(n_trials):
			# Draw a fresh Gaussian matrix and normalize it in place.
			A = np.random.randn(n_freqs, n_sources)
			helpers.A_norm(A)
			# The squared L2 norm of each column should now be 1.
			for sq_norm in np.square(A).sum(axis=0):
				self.assertAlmostEqual(sq_norm, 1)
Пример #5
0
def gmca(X,
         n_sources,
         n_iterations,
         A_init=None,
         S_init=None,
         A_p=None,
         lam_p=None,
         enforce_nn_A=True,
         lam_s=1,
         ret_min_rmse=True,
         min_rmse_rate=0,
         seed=0):
    """Runs the base gmca algorithm on X.

    Parameters:
        X (np.array): A numpy array with dimensions number_of_maps (or
            frequencies) x number of data points (wavelet coefficients) per
            map.
        n_sources (int): the number of sources to attempt to extract from the
            data.
        n_iterations (int): the number of iterations of coordinate descent to
            conduct.
        A_init (np.array): An initial value for the mixing matrix with
            dimensions (X.shape[0],n_sources). If set to None, columns with a
            non-zero lam_p start from A_p, the remaining columns start from
            the leading eigenvectors of X X^T (an uncentered PCA), and the
            result is normalized. A_init itself is never modified.
        S_init (np.array): An initial value for the source matrix with
            dimensions (n_sources,X.shape[1]). If set to None, S starts from
            the minimum rmse solution pinv(A) X. S_init itself is never
            modified.
        A_p (np.array): A matrix prior for the gmca calculation.
        lam_p ([float,...]): A n_sources long array of prior for each of
            the columns of A_p. This allows for a different lam_p to be
            applied to different columns of A_p.
        enforce_nn_A (bool): a boolean that determines if the mixing matrix
            will be forced to only have non-negative values.
        lam_s (float): The lambda parameter for the sparsity l1 norm.
        ret_min_rmse (bool): A boolean parameter that decides if the minimum
            rmse error solution for S will be returned. This will give best
            CMB reconstruction but will not return the minimum of the loss
            function.
        min_rmse_rate (int): How often the source matrix will be set to the
            minimum rmse solution. 0 will never return min_rmse within the
            gmca optimization.
        seed (int): An integer to seed the random number generator.

    Returns:
        (np.array,np.array): Returns the mixing matrix A and the
        source matrix S.

    Notes:
        gmca_numba requires A and S to be contiguous arrays; this wrapper
        hands it contiguous copies of A_init and S_init.
    """
    # Deal with the case where no lam_p or A_p is passed in: a zero prior
    # weight for every column makes the (zero) prior matrix irrelevant.
    if A_p is None or lam_p is None:
        A_p = np.zeros((len(X), n_sources))
        lam_p = np.zeros(n_sources)

    # Set up A.
    if A_init is not None:
        # Copy so the caller's array is not modified by gmca_numba.
        A = np.ascontiguousarray(np.copy(A_init))
    else:
        # Initialize A to the prior matrix A_p.
        A = np.copy(A_p)

        # Find which sources have no prior.
        non_priors = np.where(np.array(lam_p) == 0)[0]

        # For sources with no prior, initialize with PCA. X X^T is symmetric,
        # so use eigh: it returns real eigenvectors with eigenvalues in
        # ascending order (np.linalg.eig makes no ordering guarantee).
        # Reversing the columns puts the largest-variance directions first.
        X_sq = np.matmul(X, X.T)
        _, eig_v = np.linalg.eigh(X_sq)
        A[:, non_priors] = eig_v[:, ::-1][:, :len(non_priors)]

        # Ensure that A is stored contiguously in memory.
        A = np.ascontiguousarray(np.real(A))

        # Eigenvectors are only defined up to sign, so a PCA column can come
        # out entirely negative; flip those so non-negativity is reachable.
        if enforce_nn_A:
            all_negative = np.max(A, axis=0) < 0
            A[:, all_negative] = -A[:, all_negative]
        helpers.A_norm(A)

    # We can now initialize S using our initial value for A unless an S value
    # is provided. Like A_init, S_init is copied so it is left untouched.
    if S_init is None:
        S = np.matmul(np.linalg.pinv(A), X)
    else:
        S = np.ascontiguousarray(np.copy(S_init))

    # Call gmca_numba, which updates A and S.
    gmca_numba(X,
               n_sources,
               n_iterations,
               A,
               S,
               A_p,
               lam_p,
               enforce_nn_A=enforce_nn_A,
               lam_s=lam_s,
               ret_min_rmse=ret_min_rmse,
               min_rmse_rate=min_rmse_rate,
               seed=seed)

    # Return the mixing matrix and the source.
    return A, S
Пример #6
0
    def test_gmca_end_to_end(self):
        """Test that gmca works end to end, returning reasonable results.

        Covers: the fit error decreasing as iterations continue, convergence
        to X, a large lam_s promoting sparsity, a large lam_p pulling A
        toward A_p, and NaNs in the data not leaking into the reconstruction.
        """
        # Scope the warning suppression to this test instead of changing the
        # process-wide warnings filter.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            rseed = 5
            n_freqs = 10
            n_wavs = int(1e3)
            n_iterations = 50
            n_sources = 5
            s_mag = 1e3

            # Generate ground truth A and S.
            A_org = np.random.rand(n_freqs, n_sources)
            helpers.A_norm(A_org)
            S_org = np.random.rand(n_sources, n_wavs) * s_mag
            X = np.dot(A_org, S_org)

            # Prior shared by every run below.
            A_p = np.ones(A_org.shape)
            helpers.A_norm(A_p)

            def fresh_A_S():
                # Flat, normalized starting point for a new GMCA run.
                A = np.ones(A_org.shape)
                helpers.A_norm(A)
                return A, np.ones(S_org.shape)

            def run_gmca(X_in, A, S, lam_p, lam_s, n_iter=n_iterations):
                # gmca_numba with the settings shared by every call in this
                # test; it updates A and S.
                gmca_core.gmca_numba(X_in,
                                     n_sources,
                                     n_iter,
                                     A,
                                     S,
                                     A_p,
                                     lam_p,
                                     lam_s=lam_s,
                                     ret_min_rmse=False,
                                     min_rmse_rate=2 * n_iterations,
                                     enforce_nn_A=True,
                                     seed=rseed)

            lam_s = 1e-5
            lam_p = [0.0] * n_sources

            # Run GMCA.
            A, S = fresh_A_S()
            run_gmca(X, A, S, lam_p, lam_s)

            # Save sparsity of S and the fit error for later tests.
            sparsity_1 = np.sum(np.abs(S))
            err1 = np.sum(np.abs(np.dot(A, S) - X))

            # Continuing GMCA should reduce the error further.
            run_gmca(X, A, S, lam_p, lam_s)
            err2 = np.sum(np.abs(np.dot(A, S) - X))
            self.assertGreater(err1, err2)

            # Many more iterations should reproduce the data.
            run_gmca(X, A, S, lam_p, lam_s, n_iter=200)
            np.testing.assert_almost_equal(np.dot(A, S), X, decimal=4)

            # Test that lam_s enforces sparsity end to end.
            lam_s = 10
            A, S = fresh_A_S()
            run_gmca(X, A, S, lam_p, lam_s)

            # Save closeness to prior for the lam_p test below.
            A_p_val = np.sum(np.abs(A - A_p))

            self.assertLess(np.sum(np.abs(S)), sparsity_1)

            # Test that lam_p enforces the prior end to end.
            A, S = fresh_A_S()
            lam_p = [1e8] * n_sources
            run_gmca(X, A, S, lam_p, lam_s)
            self.assertLess(np.sum(np.abs(A - A_p)), A_p_val)

            # Finally, test that data with nans does not impose constraining
            # power where there are nans.
            X_copy = np.copy(X)
            for i in range(n_freqs):
                X_copy[i, np.random.randint(0, X.shape[1], 10)] = np.nan
            A, S = fresh_A_S()
            lam_p = [0.0] * n_sources
            lam_s = 1e-5
            run_gmca(X_copy, A, S, lam_p, lam_s, n_iter=200)
            self.assertEqual(np.sum(np.isnan(np.dot(A, S))), 0)
            self.assertLess(np.mean(np.abs(X - np.dot(A, S))) / s_mag, 0.1)
Пример #7
0
def hgmca_opt(wav_analysis_maps,
              n_sources,
              n_epochs,
              lam_hier,
              lam_s,
              n_iterations,
              A_init=None,
              A_global=None,
              lam_global=None,
              seed=0,
              enforce_nn_A=True,
              min_rmse_rate=0,
              save_dict=None,
              verbose=False):
    """Runs the Hierarchical GMCA algorithm on a dictionary of input maps.

    Parameters:
        wav_analysis_maps (dict): A dictionary containing the wavelet maps that
            we will run HGMCA on.
        n_sources (int): The number of sources.
        n_epochs (int): The number of times the algorithm should pass
            through all the levels.
        lam_hier (np.array): A n_sources long array of the prior for each of
            the columns of the mixing matrices in A_hier. This allows for a
            source dependent prior.
        lam_s (float): The lambda parameter for the sparsity l1 norm.
        n_iterations (int): The number of iterations of coordinate descent
            to conduct per epoch.
        A_init (np.array): A value at which to initialize all of the
            matrices in the hierarchy. If None then matrices will be
            initialized to 1/np.sqrt(n_freqs)
        A_global (np.array): A global mixing matrix prior that will be
            applied to all matrices in the hierarchy. If None no global
            prior will be enforced.
        lam_global (np.array): A n_sources long array of the prior for each
            of the columns of the global prior. Must be set if A_global is
            set.
        seed (int): An integer to seed the random number generator.
        enforce_nn_A (bool): A boolean that determines if the mixing matrix
            will be forced to only have non-negative values.
        min_rmse_rate (int): How often the source matrix will be set to the
            minimum rmse solution. 0 will never return min_rmse within the
            gmca optimization.
        save_dict (dict): A dictionary containing two entries, save_rate
            which is how often (in epochs) the results will be saved,
            and save_path, a folder the results will be saved to. If
            save_dict is provided the algorithm will try to initialize from
            the last save state. If n_epochs is not a multiple of save_rate
            the leftover epochs are run, and saved, as a final shorter chunk.
        verbose (bool): If set to true, some timing statistics will be
            outputted for the epochs.

    Returns:
        (dict): Returns a dict with the mixing matrix and the source matrix at
        each level of analysis.

    Raises:
        ValueError: If wav_analysis_maps was not generated with the hgmca
            analysis type, if only one of A_global / lam_global is given, or
            if save_dict['save_rate'] is smaller than 1.
    """
    if wav_analysis_maps['analysis_type'] != 'hgmca':
        raise ValueError('These wavelet functions were not generated using ' +
                         'the hgmca analysis type')

    if (A_global is None) != (lam_global is None):
        raise ValueError('Either both A_global and lam_global should be ' +
                         'passed in or neither should be passed in.')

    # Fail fast before any (potentially expensive) setup; a non-positive
    # save_rate would otherwise divide by zero or silently run nothing.
    if save_dict and save_dict['save_rate'] < 1:
        raise ValueError("save_dict['save_rate'] must be a positive integer.")

    # Copy over the information we need from the wav_analysis_maps dict.
    hgmca_analysis_maps = {
        'input_maps_dict': wav_analysis_maps['input_maps_dict'],
        'analysis_type': 'hgmca',
        'scale_int': wav_analysis_maps['scale_int'],
        'j_min': wav_analysis_maps['j_min'],
        'j_max': wav_analysis_maps['j_max'],
        'band_lim': wav_analysis_maps['band_lim'],
        'target_fwhm': wav_analysis_maps['target_fwhm'],
        'output_nside': wav_analysis_maps['output_nside'],
        'm_level': wav_analysis_maps['m_level']
    }

    # Get all the pieces we need to pass into the numba code
    m_level = wav_analysis_maps['m_level']
    X_level = convert_wav_to_X_level(wav_analysis_maps)
    n_freqs = wav_analysis_maps['n_freqs']
    A_shape = (n_freqs, n_sources)

    # If a save dict is provided that already has values in it, load them.
    # Otherwise initialize our A_hier_list and S_level
    if save_dict is None or not os.path.isdir(
            os.path.join(save_dict['save_path'], 'hgmca_save')):
        A_hier_list = allocate_A_hier(m_level, A_shape, A_init=A_init)
        S_level = allocate_S_level(m_level, X_level, n_sources)
        # Initialize S to the minimum rmse solution
        init_min_rmse(X_level, A_hier_list, S_level)
    else:
        if verbose:
            print('Loading previous values from %s' % (save_dict['save_path']))
        A_hier_list, S_level = load_numba_hier_list(save_dict['save_path'],
                                                    m_level)

    # The numba code is not built to accept None arguments, so modify things
    # accordingly.
    if A_global is None:
        # Setting lam_global to 0 ensures no effect from A_global
        lam_global = np.zeros(n_sources)
        A_global = np.ones((n_freqs, n_sources))
        helpers.A_norm(A_global)

    if verbose:
        print('Running HGMCA with the following parameters:')
        print('\tmaximum level: %d' % (m_level))
        print('\tnumber of epochs: %d' % (n_epochs))
        print('\titerations per epoch: %d' % (n_iterations))
        print('\tenforce non-negativity of mixing matrix: %s' % (enforce_nn_A))
        print('\tminimum rmse rate: %d' % (min_rmse_rate))
        print('Running main HGMCA loop')

    if save_dict:
        save_rate = save_dict['save_rate']
        if verbose:
            print('Saving results to %s every %d epochs' %
                  (save_dict['save_path'], save_rate))
        # Run the epochs in chunks of save_rate, saving after each chunk.
        # Leftover epochs (n_epochs not a multiple of save_rate) run as a
        # final shorter chunk instead of being silently dropped.
        n_full_chunks, n_leftover = divmod(max(n_epochs, 0), save_rate)
        chunk_sizes = [save_rate] * n_full_chunks
        if n_leftover:
            chunk_sizes.append(n_leftover)
        epoch_start = 0
        for chunk_epochs in tqdm(chunk_sizes,
                                 desc='hgmca epochs',
                                 unit='%d epochs' % (save_dict['save_rate']),
                                 unit_scale=save_rate):
            # Pass the chunk length, not n_epochs: hgmca_epoch_numba runs
            # epochs epoch_start .. epoch_start + chunk_epochs - 1.
            hgmca_epoch_numba(X_level,
                              A_hier_list,
                              lam_hier,
                              A_global,
                              lam_global,
                              S_level,
                              chunk_epochs,
                              m_level,
                              n_iterations,
                              lam_s,
                              seed,
                              enforce_nn_A,
                              min_rmse_rate,
                              epoch_start=epoch_start)
            save_numba_hier_lists(A_hier_list, S_level, save_dict['save_path'])
            epoch_start += chunk_epochs
            # We want reproducible behavior, but we don't want the same seed
            # for each set of epochs. This is the best quick fix.
            if seed > 0:
                seed += 1
    else:
        if verbose:
            start = time.time()
        hgmca_epoch_numba(X_level, A_hier_list, lam_hier, A_global, lam_global,
                          S_level, n_epochs, m_level, n_iterations, lam_s,
                          seed, enforce_nn_A, min_rmse_rate)
        if verbose:
            print('HGMCA loop took %f seconds' % (time.time() - start))

    # Package the output into the hgmca_analysis_maps
    for level in range(m_level + 1):
        hgmca_analysis_maps[str(level)] = {
            'A': A_hier_list[level],
            'S': S_level[level]
        }

    return hgmca_analysis_maps
Пример #8
0
def hgmca_epoch_numba(X_level,
                      A_hier_list,
                      lam_hier,
                      A_global,
                      lam_global,
                      S_level,
                      n_epochs,
                      m_level,
                      n_iterations,
                      lam_s,
                      seed,
                      enforce_nn_A,
                      min_rmse_rate,
                      epoch_start=0):
    """ Runs the epoch loop for a given level of hgmca using numba optimization.

    For an in depth description of the algorithm see arxiv 1910.08077

    Each epoch visits the levels 0..m_level in a random order. For every
    patch of a level, a mixing matrix prior is built from the hierarchy
    (get_A_prior with lam_hier) plus lam_global * A_global. A level whose
    data array is empty takes the normalized prior as its mixing matrix.
    Otherwise gmca_core.gmca_numba is run on the patch with that prior.

    Parameters:
        X_level (np.array): A numba.typed.List of numpy arrays corresponding to
            the data for each level of analysis. X_level[level][patch] is the
            data of one patch. An array of size 0 marks a level with no data.
        A_hier_list ([np.array]): A numba.typed.List of numpy arrays containing
            the mixing matrix hierarchy for each level of analysis.
        lam_hier (np.array): A n_sources long array of the prior for each of
            the columns of the mixing matrices in A_hier_list. This allows
            for a source dependent prior.
        A_global (np.array): A global mixing matrix prior that will be
            applied to all matrices in the hierarchy
        lam_global (np.array): A n_sources long array of the prior for each
            of the columns of the global prior.
        S_level ([np.array,...]):A numba.typed.List of numpy arrays containing
            the source matrices for each level of analysis.
        n_epochs (int): The number of epochs to run in this call. Epochs
            epoch_start .. epoch_start + n_epochs - 1 are run.
        m_level (int): The maximum level of analysis.
        n_iterations (int): The number of iterations of coordinate descent
            to conduct per epoch.
        lam_s (float): The lambda parameter for the sparsity l1 norm.
        seed (int): Used once, to seed numpy's global random number generator
            at the start of the call. It is not forwarded to gmca_numba.
        enforce_nn_A (bool): A boolean that determines if the mixing matrix
            will be forced to only have non-negative values.
        min_rmse_rate (int): How often (in epochs) gmca_numba is asked to
            return the minimum rmse solution, namely when (epoch + 1) is a
            multiple of min_rmse_rate. 0 never asks for it. It is not
            forwarded to gmca_numba itself.
        epoch_start (int): What epoch the code is starting at. Important
            for min_rmse_rate.

    Notes:
        A_hier_list and S_level will be updated in place.
    """
    # We allocate the arrays that gmca will use here to avoid them being
    # reallocated. One set of buffers per level is shared by all patches of
    # that level; they are sized from the level's first patch.
    # NOTE(review): for an empty level (e.g. shape (0, 0, 0)),
    # X_level[level][0] indexes past the end. This presumably only works
    # because this function is numba-compiled without bounds checking —
    # confirm.
    R_i_level = numba.typed.List()
    AS_level = numba.typed.List()
    A_R_level = numba.typed.List()
    A_i_level = numba.typed.List()
    for level in range(m_level + 1):
        R_i_level.append(np.zeros(X_level[level][0].shape))
        AS_level.append(np.zeros(X_level[level][0].shape))
        A_R_level.append(np.zeros((1, X_level[level].shape[2])))
        A_i_level.append(np.zeros((X_level[level].shape[1], 1)))
    # Set the random seed.
    np.random.seed(seed)
    # Now we iterate through our graphical model using our approximate closed
    # form solution for the desired number of epochs.
    for epoch in range(epoch_start, epoch_start + n_epochs):
        # We want to iterate through the levels in random order. This should
        # theoretically speed up convergence.
        level_perm = np.random.permutation(m_level + 1)
        for level in level_perm:
            npatches = wavelets_hgmca.level_to_npatches(level)
            for patch in range(npatches):
                # Calculate the mixing matrix prior and the number of matrices
                # used to construct that prior.
                A_prior = get_A_prior(A_hier_list, level, patch, lam_hier)
                # lam_global broadcasts across rows, weighting each column
                # of A_global by its own source's lam_global entry.
                A_prior += lam_global * A_global
                # First we deal with the relatively simple case where there are
                # no sources at this level
                if X_level[level].size == 0:
                    A_hier_list[level][patch] = A_prior
                    helpers.A_norm(A_hier_list[level][patch])
                # If there are sources at this level we must run GMCA
                else:
                    # Extract the data for the patch.
                    X_p = X_level[level][patch]
                    S_p = S_level[level][patch]
                    # For HGMCA we want to store the source signal that
                    # includes the lasso shooting subtraction of lam_s for
                    # stability of the loss function. Only in the last step do
                    # we want gmca to return the min_rmse solution.
                    if min_rmse_rate == 0:
                        ret_min_rmse = False
                    else:
                        ret_min_rmse = (((epoch + 1) % min_rmse_rate) == 0)

                    # Call gmca for the patch. Note the lam_p has already been
                    # accounted for: the weights are folded into A_prior, so
                    # gmca_numba gets a unit weight for every source.
                    n_sources = len(S_p)
                    gmca_core.gmca_numba(X_p,
                                         n_sources,
                                         n_iterations,
                                         A_hier_list[level][patch],
                                         S_p,
                                         A_p=A_prior,
                                         lam_p=np.ones(n_sources),
                                         enforce_nn_A=enforce_nn_A,
                                         lam_s=lam_s,
                                         ret_min_rmse=ret_min_rmse,
                                         R_i_init=R_i_level[level],
                                         AS_init=AS_level[level],
                                         A_R_init=A_R_level[level],
                                         A_i_init=A_i_level[level])
Пример #9
0
	def test_hgmca_epoch_numba(self):
		"""Check hgmca_epoch_numba fits the data and honors strong priors.

		Level 0 is given no data, and levels 1 and 2 each hold the same
		A_org S_org mixture in every patch. The test first checks that, with
		a tiny lam_s and no priors, each patch reconstructs its data. It then
		checks that a huge lam_global pins every mixing matrix to A_global.
		"""
		# Test that the hgmca optimization returns roughly what we would
		# expect for a small lam_s
		n_freqs = 8
		n_sources = 5
		n_wavs = 256
		m_level = 2
		lam_s = 1e-6
		lam_hier = np.zeros(n_sources)
		lam_global = np.zeros(n_sources)

		# The true values
		s_mag = 100
		A_org = np.random.rand(n_freqs*n_sources).reshape((n_freqs,n_sources))
		helpers.A_norm(A_org)
		S_org = np.random.rand(n_sources*n_wavs).reshape((n_sources,n_wavs))
		S_org *= s_mag

		# Allocate what we need. Level 0 gets an empty array, i.e. no data.
		X_level = numba.typed.List()
		X_level.append(np.empty((0,0,0)))

		# Level 1: shape (npatches, n_freqs, n_wavs), every patch = A_org S_org
		npatches = wavelets_hgmca.level_to_npatches(1)
		X_level.append(np.zeros((npatches,n_freqs,n_wavs)))
		X_level[1][:] += np.dot(A_org,S_org)

		# Level 2: same construction as level 1
		npatches = wavelets_hgmca.level_to_npatches(2)
		X_level.append(np.zeros((npatches,n_freqs,n_wavs)))
		X_level[2][:] += np.dot(A_org,S_org)

		# The rest
		A_hier_list = hgmca_core.allocate_A_hier(m_level,(n_freqs,n_sources))
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)
		A_global = np.random.rand(n_freqs*n_sources).reshape(
			n_freqs,n_sources)
		helpers.A_norm(A_global)

		# Run hgmca
		min_rmse_rate = 5
		enforce_nn_A = True
		seed = 5
		n_epochs = 5
		n_iterations = 30
		hgmca_core.hgmca_epoch_numba(X_level,A_hier_list,lam_hier,A_global,
			lam_global,S_level,n_epochs,m_level,n_iterations,lam_s,seed,
			enforce_nn_A,min_rmse_rate)

		# Only levels 1 and 2 have data to reconstruct.
		for level in range(1,3):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(X_level[level][patch],
					np.dot(A_hier_list[level][patch],S_level[level][patch]),
					decimal=1)

		# Repeat the same but with strong priors. Start with the global prior
		lam_global = np.ones(n_sources)*1e12
		n_epochs = 2
		n_iterations = 5
		A_hier_list = hgmca_core.allocate_A_hier(m_level,(n_freqs,n_sources))
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)
		hgmca_core.hgmca_epoch_numba(X_level,A_hier_list,lam_hier,A_global,
			lam_global,S_level,n_epochs,m_level,n_iterations,lam_s,seed,
			enforce_nn_A,min_rmse_rate)
		# Every level, including the data-free level 0, should match A_global.
		for level in range(3):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(A_hier_list[level][patch],
					A_global,decimal=4)

		# The same test, but now for the hierarchy
		# NOTE(review): as visible here, this section only checks that
		# allocate_A_hier initializes every matrix to A_init; the test may
		# continue past this excerpt — confirm.
		lam_hier = np.ones(n_sources)*1e12
		lam_global = np.zeros(n_sources)
		n_epochs = 2
		n_iterations = 5
		A_init = np.random.rand(n_freqs*n_sources).reshape((n_freqs,n_sources))
		A_hier_list = hgmca_core.allocate_A_hier(m_level,(n_freqs,n_sources),
			A_init=A_init)
		S_level = hgmca_core.allocate_S_level(m_level,X_level,n_sources)
		for level in range(3):
			for patch in range(wavelets_hgmca.level_to_npatches(level)):
				np.testing.assert_almost_equal(A_hier_list[level][patch],
					A_init,decimal=4)