cyldir_2))) * 180 / np.pi

            # sig_fasc1 = util.rotate_atom(dic_sing_fasc[:, ID_1],
            #                              sch_mat_b0, refdir, cyldir_1,
            #                              WM_DIFF, S0_fasc[:, ID_1])
            sig_fasc1 = dic_sing_fasc[:, ID_1]
            sig_fasc2 = util.rotate_atom(dic_sing_fasc[:, ID_2], sch_mat_b0,
                                         refdir, cyldir_2, WM_DIFF,
                                         S0_fasc[:, ID_2])

            DW_image = nu1 * sig_fasc1 + nu2 * sig_fasc2

            # Simulate noise and MRI scanner scaling
            DW_image_store[:, i] = DW_image

            DW_image_noisy = util.gen_SoS_MRI(DW_image, sigma_g, num_coils)
            DW_image_noisy = M0 * DW_image_noisy

            DW_noisy_store[:, i] = DW_image_noisy

            # Estimate peak directions from noisy signal
            peaks = get_csd_peaks(DW_image_noisy, sch_mat_b0, num_fasc)

            # Analyze result of CSD (just for displaying progress). You only need to
            # store the groundtruth and estimated directions to compute all these
            # metrics afterwards.
            num_pk_detected = np.sum(np.sum(np.abs(peaks), axis=1) > 0)

            if num_pk_detected < num_fasc:
                # There should always be at least 1 detected peak because the ODF
                # always has a max.
def gen_test_data(num_samples, use_noise=False, SNR_min=80, SNR_max=100,
                  SNR_dist='uniform'):
    """Generate synthetic two-fascicle DW-MRI test samples.

    For each sample, the first fascicle points along the module-level
    ``refdir``. The second points along a random unit direction whose crossing
    angle with ``refdir`` is at least ``crossangle_min`` (radians). Fascicle
    orientations are treated as axial (sign-invariant), so a direction close
    to ``-refdir`` is rejected just like one close to ``refdir``.

    The noiseless signal is the volume-fraction-weighted sum of one atom of
    the fixed dictionary and one atom of the dictionary rotated to the second
    direction. ``util.gen_SoS_MRI`` is then applied with
    ``sigma_g = S0_max / SNR``.

    Relies on module-level globals: ``num_fasc`` (must be 2, since two
    fascicles are stored per sample), ``num_atoms``, ``num_mris``,
    ``dic_sing_fasc``, ``nu_min``, ``nu_max``, ``S0_max``, ``refdir``,
    ``crossangle_min``, ``sch_mat_b0``, ``WM_DIFF``, ``S0_fasc``,
    ``num_coils`` and the ``util`` module.

    Parameters
    ----------
    num_samples : int
        Number of samples to generate.
    use_noise : bool, optional
        Currently ignored: ``util.gen_SoS_MRI`` is always applied. Kept for
        backward compatibility.
    SNR_min, SNR_max : float, optional
        Bounds of the SNR distribution.
    SNR_dist : {'uniform', 'triangular'}, optional
        Either uniform on ``[SNR_min, SNR_max]``, or triangular with its mode
        at ``SNR_min``.

    Returns
    -------
    DW_image_store : ndarray, shape (num_mris, num_samples)
        Noiseless signals, one column per sample.
    DW_noisy_store : ndarray, shape (num_mris, num_samples)
        Output of ``util.gen_SoS_MRI`` for each column (no M0 scaling).
    dic_compile : ndarray, shape (num_samples, num_mris, num_fasc * num_atoms)
        Per-sample dictionary. The first ``num_atoms`` columns are
        ``dic_sing_fasc``; the rest are ``dic_sing_fasc`` rotated to that
        sample's second direction.
    time_elapsed : float
        Wall-clock seconds spent generating.
    sch_mat_b0
        The module-level ``sch_mat_b0``, returned unchanged.
    IDs : ndarray of int32, shape (num_samples, num_fasc)
        Atom indices used for each fascicle.
    nus : ndarray, shape (num_samples, num_fasc)
        Volume fractions ``(nu1, 1 - nu1)``.

    Raises
    ------
    ValueError
        If ``SNR_dist`` is unknown, or if ``crossangle_min >= pi/2``. Axial
        crossing angles cannot exceed 90 degrees, so such a value would make
        the direction sampling loop forever.
    """
    if SNR_dist not in ('uniform', 'triangular'):
        raise ValueError("Unknown SNR distribution %s" % SNR_dist)
    if crossangle_min >= np.pi / 2:
        raise ValueError("crossangle_min must be < pi/2 for axial "
                         "orientations, got %s" % crossangle_min)

    starttime = time.time()

    # Prepare memory
    IDs = np.zeros((num_samples, num_fasc), dtype=np.int32)
    nus = np.zeros((num_samples, num_fasc))

    # One row per MRI measurement, matching the dictionary's row count.
    DW_image_store = np.zeros((num_mris, num_samples))
    DW_noisy_store = np.zeros((num_mris, num_samples))

    dic_compile = np.zeros((num_samples, num_mris, num_fasc * num_atoms),
                           dtype=np.float64)

    dictionary = np.zeros((num_mris, num_fasc * num_atoms), dtype=np.float64)
    dictionary[:, :num_atoms] = dic_sing_fasc  # first direction fixed

    # Largest |cos| allowed between the two fascicle directions.
    cos_crossangle_min = np.cos(crossangle_min)

    for i in range(num_samples):
        if i % 100 == 0:
            print(i)

        nu1 = nu_min + (nu_max - nu_min) * np.random.rand()
        nu2 = 1 - nu1
        ID_1 = np.random.randint(0, num_atoms)
        ID_2 = np.random.randint(0, num_atoms)
        if SNR_dist == 'triangular':
            SNR = np.random.triangular(SNR_min, SNR_min, SNR_max, 1)
        else:
            SNR = np.random.uniform(SNR_min, SNR_max, 1)

        sigma_g = S0_max / SNR

        # First fascicle direction fixed (refdir). The second is a normalized
        # Gaussian draw (uniform on the sphere), rejected until its crossing
        # angle with refdir is at least crossangle_min. abs() because a
        # fascicle along -refdir has the same orientation as one along refdir.
        cyldir_2 = refdir.copy()
        while np.abs(np.dot(refdir, cyldir_2)) > cos_crossangle_min:
            cyldir_2 = np.random.randn(3)
            norm_2 = np.sqrt(np.sum(cyldir_2**2))
            if norm_2 < 1e-11:
                # Degenerate draw: reset so the loop samples again.
                cyldir_2 = refdir
            else:
                cyldir_2 = cyldir_2 / norm_2

        dic_sing_fasc_2 = util.rotate_atom(dic_sing_fasc, sch_mat_b0, refdir,
                                           cyldir_2, WM_DIFF, S0_fasc)
        dictionary[:, num_atoms:] = dic_sing_fasc_2

        # Slice assignment copies, so reusing `dictionary` is safe.
        dic_compile[i, :, :] = dictionary

        # Assemble synthetic DWI
        DW_image = (nu1 * dic_sing_fasc[:, ID_1] +
                    nu2 * dic_sing_fasc_2[:, ID_2])
        DW_image_store[:, i] = DW_image

        # Simulate noise; scanner scaling by M0 is intentionally not applied.
        DW_noisy_store[:, i] = util.gen_SoS_MRI(DW_image, sigma_g, num_coils)

        # Store
        IDs[i, :] = (ID_1, ID_2)
        nus[i, :] = (nu1, nu2)

    time_elapsed = time.time() - starttime

    return (DW_image_store, DW_noisy_store, dic_compile, time_elapsed,
            sch_mat_b0, IDs, nus)