Example #1
import numpy as np
from PIL import Image
from skimage.color import rgb2gray  # assumed imports for the names used below (np, rgb2gray, gu)

import bcdi.graph.graph_utils as gu


def image_to_ndarray(filename, convert_grey=True, cmap=None, debug=False):
    """
    Convert an image to a numpy array using Pillow (matplotlib natively supports only PNG).

    :param filename: absolute path of the image to open
    :param convert_grey: if True and the image has 3 channels, convert it to a single
     grey channel
    :param cmap: colormap for the plots
    :param debug: True to show plots
    :return: the image as a 2D (grey) or 3D (colour) numpy array
    """
    if cmap is None:
        cmap = gu.Colormap(bad_color='1.0').cmap

    im = Image.open(filename)

    array = np.asarray(im)
    if array.ndim == 3 and convert_grey:
        print('converting image to grey')
        array = rgb2gray(array)

    print(f'Image shape after conversion to ndarray: {array.shape}')
    if debug:
        gu.imshow_plot(array,
                       sum_axis=2,
                       plot_colorbar=True,
                       cmap=cmap,
                       reciprocal_space=False)
    return array
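
A minimal usage sketch; the path and the printed comment below are hypothetical and only illustrate how the function above might be invoked:

# hypothetical call: open an RGB image and reduce it to a single grey channel
array = image_to_ndarray('/path/to/image.png', convert_grey=True, debug=False)
print(array.ndim)  # 2 after the grey conversion, 3 for an RGB image with convert_grey=False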
Example #2
#######################
# Initialize detector #
#######################
detector = exp.Detector(name=detector, binning=binning, roi=roi_detector)

nbz, nby, nbx = int(np.floor((detector.roi[3] - detector.roi[2]) / detector.binning[2])), \
                int(np.floor((detector.roi[1] - detector.roi[0]) / detector.binning[1])), \
                int(np.floor((detector.roi[3] - detector.roi[2]) / detector.binning[2]))
# for P10 data the rotation is around the vertical y axis, hence the gridded data
# range and binning in z and x are identical

###################
# define colormap #
###################
bad_color = '1.0'  # white background
colormap = gu.Colormap(bad_color=bad_color)
my_cmap = colormap.cmap
plt.ion()

######################
# create the lattice #
######################
pivot, _, q_values, lattice, peaks = simu.lattice(
    energy=energy,
    sdd=sdd,
    direct_beam=direct_beam,
    detector=detector,
    unitcell=unitcell,
    unitcell_param=unitcell_param,
    euler_angles=angles,
    offset_indices=False)
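
The shape arithmetic near the top of this example can be checked standalone; the roi and binning values here are made up, and the (vertical, horizontal) ordering of the roi entries is an assumption:

import numpy as np

roi = [0, 516, 0, 516]  # assumed ordering: [y_start, y_stop, x_start, x_stop]
binning = (1, 2, 2)     # (z, y, x) binning factors
nbz = int(np.floor((roi[3] - roi[2]) / binning[2]))  # 258
nby = int(np.floor((roi[1] - roi[0]) / binning[1]))  # 258
nbx = int(np.floor((roi[3] - roi[2]) / binning[2]))  # 258, identical to nbz by construction
print(nbz, nby, nbx)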
Example #3
def main(calc_self, user_comment):
    """
    Protection for multiprocessing: the script body is wrapped in a function so
    that child processes can import the module safely.

    :param calc_self: if True, the cross-correlation will be calculated between same q-values
    :param user_comment: comment to include in the filename when saving results
    """
    ##########################
    # check input parameters #
    ##########################
    global corr_count, current_point
    assert len(origin_qspace) == 3, \
        "origin_qspace should be a tuple of 3 integer pixel values"
    assert isinstance(calc_self, bool), "unexpected type for calc_self"
    assert len(q_range) > 1, "at least 2 values are needed for q_range"

    print('the CCF map will be calculated for {:d} q values:'.format(len(q_range)))
    for idx in range(len(q_range)):
        if calc_self:
            print('q1 = {:.3f}  q2 = {:.3f}'.format(q_range[idx], q_range[idx]))
        else:
            print('q1 = {:.3f}  q2 = {:.3f}'.format(q_range[0], q_range[idx]))
    warnings.filterwarnings("ignore")

    ###################
    # define colormap #
    ###################
    bad_color = '1.0'  # white background
    colormap = gu.Colormap(bad_color=bad_color)
    my_cmap = colormap.cmap
    plt.ion()

    ###################################
    # load experimental data and mask #
    ###################################
    plt.ion()
    root = tk.Tk()
    root.withdraw()
    file_path = filedialog.askopenfilename(
        initialdir=datadir,
        title="Select the 3D reciprocal space map",
        filetypes=[("NPZ", "*.npz")])
    data = np.load(file_path)['data']

    file_path = filedialog.askopenfilename(initialdir=datadir,
                                           title="Select the 3D mask",
                                           filetypes=[("NPZ", "*.npz")])
    mask = np.load(file_path)['mask']

    print((data > hotpix_threshold).sum(), 'hot pixels masked')
    mask[data > hotpix_threshold] = 1
    data[np.nonzero(mask)] = np.nan
    del mask
    gc.collect()

    file_path = filedialog.askopenfilename(initialdir=datadir,
                                           title="Select q values",
                                           filetypes=[("NPZ", "*.npz")])
    qvalues = np.load(file_path)
    qx = qvalues['qx']
    qz = qvalues['qz']
    qy = qvalues['qy']

    del qvalues
    gc.collect()

    ##############################################################
    # calculate the angular average using mean and median values #
    ##############################################################
    if plot_meandata:
        q_axis, y_mean_masked, y_median_masked = xcca.angular_avg(
            data=data,
            q_values=(qx, qz, qy),
            origin=origin_qspace,
            nb_bins=250,
            debugging=debug)
        fig, ax = plt.subplots(1, 1)
        ax.plot(q_axis, np.log10(y_mean_masked), 'r', label='mean')
        ax.plot(q_axis, np.log10(y_median_masked), 'b', label='median')
        ax.axvline(x=q_range[0],
                   ymin=0,
                   ymax=1,
                   color='g',
                   linestyle='--',
                   label='q_start')
        ax.axvline(x=q_range[-1],
                   ymin=0,
                   ymax=1,
                   color='r',
                   linestyle=':',
                   label='q_stop')
        ax.set_xlabel('q (1/nm)')
        ax.set_ylabel('Angular average (A.U.)')
        ax.legend()
        plt.pause(0.1)
        fig.savefig(savedir + '1D_average.png')

        del q_axis, y_median_masked, y_mean_masked

    ##############################################################
    # interpolate the data onto spheres at user-defined q values #
    ##############################################################
    # calculate the matrix of distances from the origin of reciprocal space
    distances = np.sqrt(
        (qx[:, np.newaxis, np.newaxis] - qx[origin_qspace[0]])**2 +
        (qz[np.newaxis, :, np.newaxis] - qz[origin_qspace[1]])**2 +
        (qy[np.newaxis, np.newaxis, :] - qy[origin_qspace[2]])**2)
    dq = min(qx[1] - qx[0], qz[1] - qz[0], qy[1] - qy[0])

    q_int = dict()  # dictionary of voxel coordinates and intensities per q value
    dict_fields = ['q' + str(idx + 1)
                   for idx in range(len(q_range))]  # ['q1', 'q2', 'q3', ...]
    nb_points = []

    for counter, q_value in enumerate(q_range):
        indices = np.nonzero((np.logical_and((distances < q_value + dq),
                                             (distances > q_value - dq))))
        nb_voxels = indices[0].size
        print('\nNumber of voxels for the sphere of radius q={:.3f} 1/nm:'
              .format(q_value), nb_voxels)

        qx_voxels = qx[indices[0]]  # qx downstream, axis 0
        qz_voxels = qz[indices[1]]  # qz vertical up, axis 1
        qy_voxels = qy[indices[2]]  # qy outboard, axis 2
        int_voxels = data[indices]

        if debug:
            # calculate the stereographic projection
            stereo_proj, uv_labels = fu.calc_stereoproj_facet(
                projection_axis=1,
                radius_mean=q_value,
                stereo_center=0,
                vectors=np.concatenate(
                    (qx_voxels[:, np.newaxis], qz_voxels[:, np.newaxis],
                     qy_voxels[:, np.newaxis]),
                    axis=1))
            # plot the projection from the South pole
            fig, _ = gu.scatter_stereographic(
                euclidian_u=stereo_proj[:, 0],
                euclidian_v=stereo_proj[:, 1],
                color=int_voxels,
                title='Projection from the South pole'
                ' at q={:.3f} (1/nm)'.format(q_value),
                uv_labels=uv_labels,
                cmap=my_cmap)
            fig.savefig(savedir + 'South pole_q={:.3f}.png'.format(q_value))
            plt.close(fig)

            # plot the projection from the North pole
            fig, _ = gu.scatter_stereographic(
                euclidian_u=stereo_proj[:, 2],
                euclidian_v=stereo_proj[:, 3],
                color=int_voxels,
                title='Projection from the North pole'
                ' at q={:.3f} (1/nm)'.format(q_value),
                uv_labels=uv_labels,
                cmap=my_cmap)
            fig.savefig(savedir + 'North pole_q={:.3f}.png'.format(q_value))
            plt.close(fig)

        # look for nan values
        nan_indices = np.argwhere(np.isnan(int_voxels))

        # remove nan values before calculating the cross-correlation function
        qx_voxels = np.delete(qx_voxels, nan_indices)
        qz_voxels = np.delete(qz_voxels, nan_indices)
        qy_voxels = np.delete(qy_voxels, nan_indices)
        int_voxels = np.delete(int_voxels, nan_indices)

        # normalize the intensity by the median value (remove the influence of the form factor)
        print('q={:.3f}:'.format(q_value), ' normalizing by the median value',
              np.median(int_voxels))
        int_voxels = int_voxels / np.median(int_voxels)

        q_int[dict_fields[counter]] = np.concatenate(
            (qx_voxels[:, np.newaxis], qz_voxels[:, np.newaxis],
             qy_voxels[:, np.newaxis], int_voxels[:, np.newaxis]),
            axis=1)
        # update the number of points without nan
        nb_points.append(len(qx_voxels))
        print('q={:.3f}:'.format(q_value), ' removing', nan_indices.size,
              'nan values,', nb_points[counter], 'remain')

        del qx_voxels, qz_voxels, qy_voxels, int_voxels, indices, nan_indices
        gc.collect()
    del qx, qy, qz, distances, data
    gc.collect()

    ############################################
    # calculate the cross-correlation function #
    ############################################
    cross_corr = np.empty((len(q_range), int(180 / angular_resolution), 2))
    angular_bins = np.linspace(start=0,
                               stop=np.pi,
                               num=corr_count.shape[0],
                               endpoint=False)

    start = time.time()
    print("\nNumber of processors: ", mp.cpu_count())
    mp.freeze_support()

    for ind_q in range(len(q_range)):
        pool = mp.Pool(mp.cpu_count())  # use this number of processes
        if calc_self:
            key_q1 = 'q' + str(ind_q + 1)
            key_q2 = key_q1
            print('\n' + key_q2 +
                  ': the CCF will be calculated over {:d} * {:d}'
                  ' points and {:d} angular bins'.format(
                      nb_points[ind_q], nb_points[ind_q], corr_count.shape[0]))
            for ind_point in range(nb_points[ind_q]):
                pool.apply_async(xcca.calc_ccf_rect,
                                 args=(ind_point, key_q1, key_q2, angular_bins,
                                       q_int),
                                 callback=collect_result,
                                 error_callback=util.catch_error)
        else:
            key_q1 = 'q1'
            key_q2 = 'q' + str(ind_q + 1)
            print('\n' + key_q2 +
                  ': the CCF will be calculated over {:d} * {:d}'
                  ' points and {:d} angular bins'.format(
                      nb_points[0], nb_points[ind_q], corr_count.shape[0]))
            for ind_point in range(nb_points[0]):
                pool.apply_async(xcca.calc_ccf_rect,
                                 args=(ind_point, key_q1, key_q2, angular_bins,
                                       q_int),
                                 callback=collect_result,
                                 error_callback=util.catch_error)

        # close the pool and let all the processes complete
        pool.close()
        pool.join()  # block until all processes in the queue are done

        # normalize the cross-correlation by the counter
        indices = np.nonzero(corr_count[:, 1])
        corr_count[indices, 0] = corr_count[indices, 0] / corr_count[indices, 1]
        cross_corr[ind_q, :, :] = corr_count

        # reinitialize the globals for the next q value
        corr_count = np.zeros((int(180 / angular_resolution), 2))
        current_point = 0

    end = time.time()
    print('\nTime elapsed for the calculation of the CCF map:',
          str(datetime.timedelta(seconds=int(end - start))))

    #######################################
    # save the cross-correlation function #
    #######################################
    if calc_self:
        user_comment = user_comment + '_self'
    else:
        user_comment = user_comment + '_cross'
    filename = 'CCFmap_qstart={:.3f}_qstop={:.3f}'.format(q_range[0], q_range[-1]) +\
               '_res{:.3f}'.format(angular_resolution) + user_comment
    np.savez_compressed(savedir + filename + '.npz',
                        angles=180 * angular_bins / np.pi,
                        q_range=q_range,
                        ccf=cross_corr[:, :, 0],
                        points=cross_corr[:, :, 1])

    #######################################
    # plot the cross-correlation function #
    #######################################
    # find the y limit excluding the peaks at 0 and 180 degrees
    indices = np.argwhere(
        np.logical_and((angular_bins >= 20 * np.pi / 180),
                       (angular_bins <= 160 * np.pi / 180)))
    vmax = 1.2 * cross_corr[:, indices, 0].max()
    print('Discarding CCF values with a zero counter:',
          (cross_corr[:, :, 1] == 0).sum(), 'points masked')
    cross_corr[(cross_corr[:, :, 1] == 0), 0] = np.nan  # discard these CCF values

    dq = q_range[1] - q_range[0]
    fig, ax = plt.subplots()
    plt0 = ax.imshow(
        cross_corr[:, :, 0],
        cmap=my_cmap,
        vmin=0,
        vmax=vmax,
        extent=[0, 180, q_range[-1] + dq / 2,
                q_range[0] - dq / 2])  # extent (left, right, bottom, top)
    ax.set_xlabel('Angle (deg)')
    ax.set_ylabel('q (nm$^{-1}$)')
    ax.set_xticks(np.arange(0, 181, 30))
    ax.set_yticks(q_range)
    ax.set_aspect('auto')
    if calc_self:
        ax.set_title('self CCF from q={:.3f} 1/nm  to q={:.3f} 1/nm'.format(
            q_range[0], q_range[-1]))
    else:
        ax.set_title('cross CCF from q={:.3f} 1/nm  to q={:.3f} 1/nm'.format(
            q_range[0], q_range[-1]))
    gu.colorbar(plt0, scale='linear', numticks=5)
    fig.savefig(savedir + filename + '.png')

    plt.ioff()
    plt.show()
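
The core of the interpolation step in this example (and in Example #5 below) is selecting every voxel whose distance from the origin of reciprocal space lies within dq of a target q. A self-contained numpy sketch with made-up axes; all shapes and values here are arbitrary:

import numpy as np

qx = np.linspace(-1.0, 1.0, 64)  # hypothetical q axes in 1/nm
qz = np.linspace(-1.0, 1.0, 64)
qy = np.linspace(-1.0, 1.0, 64)
origin = (32, 32, 32)            # origin of reciprocal space in pixels
q_value = 0.5                    # target sphere radius in 1/nm
dq = min(qx[1] - qx[0], qz[1] - qz[0], qy[1] - qy[0])

# matrix of distances from the origin, shape (64, 64, 64)
distances = np.sqrt((qx[:, np.newaxis, np.newaxis] - qx[origin[0]]) ** 2
                    + (qz[np.newaxis, :, np.newaxis] - qz[origin[1]]) ** 2
                    + (qy[np.newaxis, np.newaxis, :] - qy[origin[2]]) ** 2)
indices = np.nonzero(np.logical_and(distances < q_value + dq,
                                    distances > q_value - dq))
print(indices[0].size, 'voxels in the spherical shell')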
Example #4
]  # plot vertical dashed lines at these q values, leave [] otherwise
# position in pixels of the origin of the angular average in the array.
# if a nan value is used, the origin will be set at the middle of the array in the corresponding dimension.
threshold = 0  # data < threshold will be set to 0
debug = False  # True to show more plots
xlim = None  # limits used for the horizontal axis of the angular plot, leave None otherwise
ylim = None  # limits used for the vertical axis of the plots, leave None otherwise
save_txt = True  # True to save q values and the average in .txt format
##########################
# end of user parameters #
##########################

###################
# define colormap #
###################
colormap = gu.Colormap()
my_cmap = colormap.cmap

##############################
# load reciprocal space data #
##############################
plt.ion()
root = tk.Tk()
root.withdraw()
file_path = filedialog.askopenfilename(initialdir=root_folder,
                                       title="Select the diffraction pattern",
                                       filetypes=[("NPZ", "*.npz")])
npzfile = np.load(file_path)
diff_pattern = pu.bin_data(npzfile[list(npzfile.files)[0]].astype(int),
                           (bin_factor, bin_factor, bin_factor),
                           debugging=False)
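
pu.bin_data comes from bcdi; a plain-numpy sketch of the same binning idea is given below. Whether bcdi sums or averages the blocks is not visible here, so the block sum is an assumption:

import numpy as np

def bin_3d(array, factors):
    # bin a 3D array by integer factors along each axis (non-overlapping block sum)
    nz, ny, nx = (dim // f for dim, f in zip(array.shape, factors))
    trimmed = array[:nz * factors[0], :ny * factors[1], :nx * factors[2]]
    return trimmed.reshape(nz, factors[0], ny, factors[1],
                           nx, factors[2]).sum(axis=(1, 3, 5))

binned = bin_3d(np.ones((10, 12, 14)), (2, 2, 2))
print(binned.shape)  # (5, 6, 7), each element equal to 8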
Example #5
def main(user_comment):
    """
    Protection for multiprocessing: the script body is wrapped in a function so
    that child processes can import the module safely.

    :param user_comment: comment to include in the filename when saving results
    """
    ##########################
    # check input parameters #
    ##########################
    global corr_count

    if len(q_xcca) != 2:
        raise ValueError("Two q values should be provided (it can be the same value)")
    if len(origin_qspace) != 3:
        raise ValueError("origin_qspace should be a tuple of 3 integer pixel values")
    q_xcca.sort()
    same_q = q_xcca[0] == q_xcca[1]
    warnings.filterwarnings("ignore")

    ###################
    # define colormap #
    ###################
    bad_color = "1.0"  # white background
    colormap = gu.Colormap(bad_color=bad_color)
    my_cmap = colormap.cmap
    plt.ion()

    ###################################
    # load experimental data and mask #
    ###################################
    plt.ion()
    root = tk.Tk()
    root.withdraw()
    file_path = filedialog.askopenfilename(
        initialdir=datadir,
        title="Select the 3D reciprocal space map",
        filetypes=[("NPZ", "*.npz")],
    )
    data = np.load(file_path)["data"]

    file_path = filedialog.askopenfilename(
        initialdir=datadir, title="Select the 3D mask", filetypes=[("NPZ", "*.npz")]
    )
    mask = np.load(file_path)["mask"]

    print((data > hotpix_threshold).sum(), "hot pixels masked")
    mask[data > hotpix_threshold] = 1
    data[np.nonzero(mask)] = np.nan
    del mask
    gc.collect()

    file_path = filedialog.askopenfilename(
        initialdir=datadir, title="Select q values", filetypes=[("NPZ", "*.npz")]
    )
    qvalues = np.load(file_path)
    qx = qvalues["qx"]
    qz = qvalues["qz"]
    qy = qvalues["qy"]

    del qvalues
    gc.collect()

    ##############################################################
    # calculate the angular average using mean and median values #
    ##############################################################
    if plot_meandata:
        q_axis, y_mean_masked, y_median_masked = xcca.angular_avg(
            data=data,
            q_values=(qx, qz, qy),
            origin=origin_qspace,
            nb_bins=250,
            debugging=debug,
        )
        fig, ax = plt.subplots(1, 1)
        ax.plot(q_axis, np.log10(y_mean_masked), "r", label="mean")
        ax.plot(q_axis, np.log10(y_median_masked), "b", label="median")
        ax.axvline(x=q_xcca[0], ymin=0, ymax=1, color="g", linestyle="--", label="q1")
        ax.axvline(x=q_xcca[1], ymin=0, ymax=1, color="r", linestyle=":", label="q2")
        ax.set_xlabel("q (1/nm)")
        ax.set_ylabel("Angular average (A.U.)")
        ax.legend()
        plt.pause(0.1)
        fig.savefig(savedir + "1D_average.png")

        del q_axis, y_median_masked, y_mean_masked

    ##############################################################
    # interpolate the data onto spheres at user-defined q values #
    ##############################################################
    # calculate the matrix of distances from the origin of reciprocal space
    distances = np.sqrt(
        (qx[:, np.newaxis, np.newaxis] - qx[origin_qspace[0]]) ** 2
        + (qz[np.newaxis, :, np.newaxis] - qz[origin_qspace[1]]) ** 2
        + (qy[np.newaxis, np.newaxis, :] - qy[origin_qspace[2]]) ** 2
    )
    dq = min(qx[1] - qx[0], qz[1] - qz[0], qy[1] - qy[0])

    q_int = {}  # create dictionary
    dict_fields = ["q1", "q2"]
    nb_points = []

    for counter, q_value in enumerate(q_xcca):
        if (counter == 0) or ((counter == 1) and not same_q):
            indices = np.nonzero(
                (np.logical_and((distances < q_value + dq), (distances > q_value - dq)))
            )
            nb_voxels = indices[0].shape
            print(
                "\nNumber of voxels for the sphere of radius q ={:.3f} 1/nm:".format(
                    q_value
                ),
                nb_voxels,
            )

            qx_voxels = qx[indices[0]]  # qx downstream, axis 0
            qz_voxels = qz[indices[1]]  # qz vertical up, axis 1
            qy_voxels = qy[indices[2]]  # qy outboard, axis 2
            int_voxels = data[indices]

            if debug:
                # calculate the stereographic projection
                stereo_proj, uv_labels = fu.calc_stereoproj_facet(
                    projection_axis=1,
                    radius_mean=q_value,
                    stereo_center=0,
                    vectors=np.concatenate(
                        (
                            qx_voxels[:, np.newaxis],
                            qz_voxels[:, np.newaxis],
                            qy_voxels[:, np.newaxis],
                        ),
                        axis=1,
                    ),
                )
                # plot the projection from the South pole
                fig, _ = gu.scatter_stereographic(
                    euclidian_u=stereo_proj[:, 0],
                    euclidian_v=stereo_proj[:, 1],
                    color=int_voxels,
                    title="Projection from the South pole"
                    " at q={:.3f} (1/nm)".format(q_value),
                    uv_labels=uv_labels,
                    cmap=my_cmap,
                )
                fig.savefig(savedir + "South pole_q={:.3f}.png".format(q_value))
                plt.close(fig)

                # plot the projection from the North pole
                fig, _ = gu.scatter_stereographic(
                    euclidian_u=stereo_proj[:, 2],
                    euclidian_v=stereo_proj[:, 3],
                    color=int_voxels,
                    title="Projection from the North pole"
                    " at q={:.3f} (1/nm)".format(q_value),
                    uv_labels=uv_labels,
                    cmap=my_cmap,
                )
                fig.savefig(savedir + "North pole_q={:.3f}.png".format(q_value))
                plt.close(fig)

            # look for nan values
            nan_indices = np.argwhere(np.isnan(int_voxels))

            #  remove nan values before calculating the cross-correlation function
            qx_voxels = np.delete(qx_voxels, nan_indices)
            qz_voxels = np.delete(qz_voxels, nan_indices)
            qy_voxels = np.delete(qy_voxels, nan_indices)
            int_voxels = np.delete(int_voxels, nan_indices)

            # normalize the intensity by the median value (remove the influence of
            # the form factor)
            print(
                "q={:.3f}:".format(q_value),
                " normalizing by the median value",
                np.median(int_voxels),
            )
            int_voxels = int_voxels / np.median(int_voxels)

            q_int[dict_fields[counter]] = np.concatenate(
                (
                    qx_voxels[:, np.newaxis],
                    qz_voxels[:, np.newaxis],
                    qy_voxels[:, np.newaxis],
                    int_voxels[:, np.newaxis],
                ),
                axis=1,
            )
            # update the number of points without nan
            nb_points.append(len(qx_voxels))
            print(
                "q={:.3f}:".format(q_value),
                " removing",
                nan_indices.size,
                "nan values,",
                nb_points[counter],
                "remain",
            )

            del qx_voxels, qz_voxels, qy_voxels, int_voxels, indices, nan_indices
            gc.collect()
    del qx, qy, qz, distances, data
    gc.collect()

    ############################################
    # calculate the cross-correlation function #
    ############################################
    if same_q:
        key_q2 = "q1"
        print(
            "\nThe CCF will be calculated over {:d} * {:d}"
            " points and {:d} angular bins".format(
                nb_points[0], nb_points[0], corr_count.shape[0]
            )
        )
    else:
        key_q2 = "q2"
        print(
            "\nThe CCF will be calculated over {:d} * {:d}"
            " points and {:d} angular bins".format(
                nb_points[0], nb_points[1], corr_count.shape[0]
            )
        )

    angular_bins = np.linspace(
        start=0, stop=np.pi, num=corr_count.shape[0], endpoint=False
    )

    start = time.time()
    if single_proc:
        for idx in range(nb_points[0]):
            ccf_uniq_val, counter_val, counter_indices = xcca.calc_ccf_rect(
                point=idx,
                q1_name="q1",
                q2_name=key_q2,
                bin_values=angular_bins,
                q_int=q_int,
            )
            collect_result_debug(ccf_uniq_val, counter_val, counter_indices)
    else:
        print("\nNumber of processors: ", mp.cpu_count())
        mp.freeze_support()
        pool = mp.Pool(mp.cpu_count())  # use this number of processes
        for idx in range(nb_points[0]):
            pool.apply_async(
                xcca.calc_ccf_rect,
                args=(idx, "q1", key_q2, angular_bins, q_int),
                callback=collect_result,
                error_callback=util.catch_error,
            )
        # close the pool and let all the processes complete
        pool.close()
        pool.join()  # block until all processes in the queue are done
    end = time.time()
    print(
        "\nTime ellapsed for the calculation of the CCF:",
        str(datetime.timedelta(seconds=int(end - start))),
    )

    # normalize the cross-correlation by the counter
    indices = np.nonzero(corr_count[:, 1])
    corr_count[indices, 0] = corr_count[indices, 0] / corr_count[indices, 1]

    #######################################
    # save the cross-correlation function #
    #######################################
    filename = (
        "CCF_q1={:.3f}_q2={:.3f}".format(q_xcca[0], q_xcca[1])
        + "_points{:d}_res{:.3f}".format(nb_points[0], angular_resolution)
        + user_comment
    )
    np.savez_compressed(
        savedir + filename + ".npz",
        angles=180 * angular_bins / np.pi,
        ccf=corr_count[:, 0],
        points=corr_count[:, 1],
    )

    #######################################
    # plot the cross-correlation function #
    #######################################
    # find the y limit excluding the peaks at 0 and 180 degrees
    indices = np.argwhere(
        np.logical_and(
            (angular_bins >= 5 * np.pi / 180), (angular_bins <= 175 * np.pi / 180)
        )
    )
    ymax = 1.2 * corr_count[indices, 0].max()
    print(
        "Discarding CCF values with a zero counter:",
        (corr_count[:, 1] == 0).sum(),
        "points masked",
    )
    corr_count[(corr_count[:, 1] == 0), 0] = np.nan  # discard these values of the CCF

    fig, ax = plt.subplots()
    ax.plot(
        180 * angular_bins / np.pi,
        corr_count[:, 0],
        color="red",
        linestyle="-",
        markerfacecolor="blue",
        marker=".",
    )
    ax.set_xlim(0, 180)
    ax.set_ylim(0, ymax)
    ax.set_xlabel("Angle (deg)")
    ax.set_ylabel("Cross-correlation")
    ax.set_xticks(np.arange(0, 181, 30))
    ax.set_title(
        "CCF at q1={:.3f} 1/nm  and q2={:.3f} 1/nm".format(q_xcca[0], q_xcca[1])
    )
    fig.savefig(savedir + filename + ".png")

    _, ax = plt.subplots()
    ax.plot(
        180 * angular_bins / np.pi,
        corr_count[:, 1],
        linestyle="None",
        markerfacecolor="blue",
        marker=".",
    )
    ax.set_xlim(0, 180)
    ax.set_xlabel("Angle (deg)")
    ax.set_ylabel("Number of points")
    ax.set_xticks(np.arange(0, 181, 30))
    ax.set_title("Points per angular bin")
    plt.ioff()
    plt.show()
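
The multiprocessing pattern shared by this example and Example #3 (apply_async with a callback that accumulates results in the parent process) can be reduced to a standalone sketch; the worker and the accumulator below are hypothetical stand-ins for xcca.calc_ccf_rect and the global corr_count:

import multiprocessing as mp

results = []  # stand-in for the global accumulator updated by the callback

def worker(point):
    # stand-in for xcca.calc_ccf_rect, which returns per-point CCF contributions
    return point * point

def collect(value):
    # runs in the parent process, like collect_result in the examples above
    results.append(value)

if __name__ == '__main__':
    mp.freeze_support()
    pool = mp.Pool(mp.cpu_count())
    for idx in range(8):
        pool.apply_async(worker, args=(idx,), callback=collect)
    pool.close()
    pool.join()  # block until every queued task has completed
    print(sorted(results))  # [0, 1, 4, 9, 16, 25, 36, 49]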
Example #6
def run(prm):
    """
    Run the postprocessing.

    :param prm: the parsed parameters
    """
    pretty = pprint.PrettyPrinter(indent=4)

    ################################
    # assign often used parameters #
    ################################
    bragg_peak = prm.get("bragg_peak")
    debug = prm.get("debug", False)
    comment = prm.get("comment", "")
    centering_method = prm.get("centering_method", "max_com")
    original_size = prm.get("original_size")
    phasing_binning = prm.get("phasing_binning", [1, 1, 1])
    preprocessing_binning = prm.get("preprocessing_binning", [1, 1, 1])
    ref_axis_q = prm.get("ref_axis_q", "y")
    fix_voxel = prm.get("fix_voxel")
    save = prm.get("save", True)
    tick_spacing = prm.get("tick_spacing", 50)
    tick_direction = prm.get("tick_direction", "inout")
    tick_length = prm.get("tick_length", 10)
    tick_width = prm.get("tick_width", 2)
    invert_phase = prm.get("invert_phase", True)
    correct_refraction = prm.get("correct_refraction", False)
    threshold_unwrap_refraction = prm.get("threshold_unwrap_refraction", 0.05)
    threshold_gradient = prm.get("threshold_gradient", 1.0)
    offset_method = prm.get("offset_method", "mean")
    phase_offset = prm.get("phase_offset", 0)
    offset_origin = prm.get("phase_offset_origin")
    sort_method = prm.get("sort_method", "variance/mean")
    correlation_threshold = prm.get("correlation_threshold", 0.90)
    roi_detector = create_roi(dic=prm)

    # parameters below must be provided
    try:
        detector_name = prm["detector"]
        beamline_name = prm["beamline"]
        rocking_angle = prm["rocking_angle"]
        isosurface_strain = prm["isosurface_strain"]
        output_size = prm["output_size"]
        save_frame = prm["save_frame"]
        data_frame = prm["data_frame"]
        scan = prm["scan"]
        sample_name = prm["sample_name"]
        root_folder = prm["root_folder"]
    except KeyError as ex:
        print(f"Required parameter not defined: {ex}")
        raise

    prm["sample"] = (f"{sample_name}+{scan}",)
    #########################
    # Check some parameters #
    #########################
    if not prm.get("backend"):
        prm["backend"] = "Qt5Agg"
    matplotlib.use(prm["backend"])
    if prm["simulation"]:
        invert_phase = False
        correct_refraction = 0
    if invert_phase:
        phase_fieldname = "disp"
    else:
        phase_fieldname = "phase"

    if data_frame == "detector":
        is_orthogonal = False
    else:
        is_orthogonal = True

    if data_frame == "crystal" and save_frame != "crystal":
        print(
            "data already in the crystal frame before phase retrieval;"
            " it is impossible to go back to the laboratory frame,"
            " parameter 'save_frame' defaulted to 'crystal'"
        )
        save_frame = "crystal"

    axis_to_array_xyz = {
        "x": np.array([1, 0, 0]),
        "y": np.array([0, 1, 0]),
        "z": np.array([0, 0, 1]),
    }  # in xyz order

    ###############
    # Set backend #
    ###############
    if prm.get("backend") is not None:
        try:
            plt.switch_backend(prm["backend"])
        except ModuleNotFoundError:
            print(f"{prm['backend']} backend is not supported.")

    ###################
    # define colormap #
    ###################
    if prm.get("grey_background"):
        bad_color = "0.7"
    else:
        bad_color = "1.0"  # white background
    colormap = gu.Colormap(bad_color=bad_color)
    my_cmap = colormap.cmap

    #######################
    # Initialize detector #
    #######################
    detector = create_detector(
        name=detector_name,
        template_imagefile=prm.get("template_imagefile"),
        roi=roi_detector,
        binning=phasing_binning,
        preprocessing_binning=preprocessing_binning,
        pixel_size=prm.get("pixel_size"),
    )

    ####################################
    # define the experimental geometry #
    ####################################
    setup = Setup(
        beamline=beamline_name,
        detector=detector,
        energy=prm.get("energy"),
        outofplane_angle=prm.get("outofplane_angle"),
        inplane_angle=prm.get("inplane_angle"),
        tilt_angle=prm.get("tilt_angle"),
        rocking_angle=rocking_angle,
        distance=prm.get("sdd"),
        sample_offsets=prm.get("sample_offsets"),
        actuators=prm.get("actuators"),
        custom_scan=prm.get("custom_scan", False),
        custom_motors=prm.get("custom_motors"),
        dirbeam_detector_angles=prm.get("dirbeam_detector_angles"),
        direct_beam=prm.get("direct_beam"),
        is_series=prm.get("is_series", False),
    )

    ########################################
    # Initialize the paths and the logfile #
    ########################################
    setup.init_paths(
        sample_name=sample_name,
        scan_number=scan,
        root_folder=root_folder,
        data_dir=prm.get("data_dir"),
        save_dir=prm.get("save_dir"),
        specfile_name=prm.get("specfile_name"),
        template_imagefile=prm.get("template_imagefile"),
    )

    setup.create_logfile(
        scan_number=scan, root_folder=root_folder, filename=detector.specfile
    )

    # load the goniometer positions needed in the calculation
    # of the transformation matrix
    setup.read_logfile(scan_number=scan)

    ###################
    # print instances #
    ###################
    print(f'{"#"*(5+len(str(scan)))}\nScan {scan}\n{"#"*(5+len(str(scan)))}')
    print("\n##############\nSetup instance\n##############")
    pretty.pprint(setup.params)
    print("\n#################\nDetector instance\n#################")
    pretty.pprint(detector.params)

    ################
    # preload data #
    ################
    if prm.get("reconstruction_file") is not None:
        file_path = (prm["reconstruction_file"],)
    else:
        root = tk.Tk()
        root.withdraw()
        file_path = filedialog.askopenfilenames(
            initialdir=detector.scandir
            if prm.get("data_dir") is None
            else detector.datadir,
            filetypes=[
                ("NPZ", "*.npz"),
                ("NPY", "*.npy"),
                ("CXI", "*.cxi"),
                ("HDF5", "*.h5"),
            ],
        )

    nbfiles = len(file_path)
    plt.ion()

    obj, extension = util.load_file(file_path[0])
    if extension == ".h5":
        comment = comment + "_mode"

    print("\n###############\nProcessing data\n###############")
    nz, ny, nx = obj.shape
    print("Initial data size: (", nz, ",", ny, ",", nx, ")")
    if not original_size:
        original_size = obj.shape
    print("FFT size before accounting for phasing_binning", original_size)
    original_size = tuple(
        original_size[index] // phasing_binning[index]
        for index in range(len(phasing_binning))
    )
    print("Binning used during phasing:", detector.binning)
    print("Padding back to original FFT size", original_size)
    obj = util.crop_pad(array=obj, output_shape=original_size)

    ###########################################################################
    # define range for orthogonalization and plotting - speed up calculations #
    ###########################################################################
    zrange, yrange, xrange = pu.find_datarange(
        array=obj, amplitude_threshold=0.05, keep_size=prm.get("keep_size", False)
    )

    numz = zrange * 2
    numy = yrange * 2
    numx = xrange * 2
    print(
        f"Data shape used for orthogonalization and plotting: ({numz}, {numy}, {numx})"
    )

    ####################################################################################
    # find the best reconstruction from the list, based on mean amplitude and variance #
    ####################################################################################
    if nbfiles > 1:
        print("\nTrying to find the best reconstruction\nSorting by ", sort_method)
        sorted_obj = pu.sort_reconstruction(
            file_path=file_path,
            amplitude_threshold=isosurface_strain,
            data_range=(zrange, yrange, xrange),
            sort_method=sort_method,
        )
    else:
        sorted_obj = [0]

    #########################################
    # load reconstructions and average them #
    #########################################
    avg_obj = np.zeros((numz, numy, numx))
    ref_obj = np.zeros((numz, numy, numx))
    avg_counter = 1
    print("\nAveraging using", nbfiles, "candidate reconstructions")
    for counter, value in enumerate(sorted_obj):
        obj, extension = util.load_file(file_path[value])
        print("\nOpening ", file_path[value])
        prm[f"from_file_{counter}"] = file_path[value]

        if prm.get("flip_reconstruction", False):
            obj = pu.flip_reconstruction(obj, debugging=True)

        if extension == ".h5":
            centering_method = "do_nothing"  # do not center, data is already cropped
            # just on support for mode decomposition
            # correct a roll after the decomposition into modes in PyNX
            obj = np.roll(obj, prm.get("roll_modes", [0, 0, 0]), axis=(0, 1, 2))
            fig, _, _ = gu.multislices_plot(
                abs(obj),
                sum_frames=True,
                plot_colorbar=True,
                title="1st mode after centering",
            )

        # use the range of interest defined above
        obj = util.crop_pad(obj, [2 * zrange, 2 * yrange, 2 * xrange], debugging=False)

        # align with average reconstruction
        if counter == 0:  # the first array loaded will serve as reference object
            print("This reconstruction will be used as reference.")
            ref_obj = obj

        avg_obj, flag_avg = reg.average_arrays(
            avg_obj=avg_obj,
            ref_obj=ref_obj,
            obj=obj,
            support_threshold=0.25,
            correlation_threshold=correlation_threshold,
            aligning_option="dft",
            space=prm.get("averaging_space", "reciprocal_space"),
            reciprocal_space=False,
            is_orthogonal=is_orthogonal,
            debugging=debug,
        )
        avg_counter = avg_counter + flag_avg

    avg_obj = avg_obj / avg_counter
    if avg_counter > 1:
        print("\nAverage performed over ", avg_counter, "reconstructions\n")
    del obj, ref_obj
    gc.collect()

    ################
    # unwrap phase #
    ################
    phase, extent_phase = pu.unwrap(
        avg_obj,
        support_threshold=threshold_unwrap_refraction,
        debugging=debug,
        reciprocal_space=False,
        is_orthogonal=is_orthogonal,
    )

    print(
        "Extent of the phase over an extended support (ceil(phase range)) ~ ",
        int(extent_phase),
        "(rad)",
    )
    phase = util.wrap(phase, start_angle=-extent_phase / 2, range_angle=extent_phase)
    if debug:
        gu.multislices_plot(
            phase,
            width_z=2 * zrange,
            width_y=2 * yrange,
            width_x=2 * xrange,
            plot_colorbar=True,
            title="Phase after unwrap + wrap",
            reciprocal_space=False,
            is_orthogonal=is_orthogonal,
        )

    #############################################
    # phase ramp removal before phase filtering #
    #############################################
    amp, phase, rampz, rampy, rampx = pu.remove_ramp(
        amp=abs(avg_obj),
        phase=phase,
        initial_shape=original_size,
        method="gradient",
        amplitude_threshold=isosurface_strain,
        threshold_gradient=threshold_gradient,
    )
    del avg_obj
    gc.collect()

    if debug:
        gu.multislices_plot(
            phase,
            width_z=2 * zrange,
            width_y=2 * yrange,
            width_x=2 * xrange,
            plot_colorbar=True,
            title="Phase after ramp removal",
            reciprocal_space=False,
            is_orthogonal=is_orthogonal,
        )

    ########################
    # phase offset removal #
    ########################
    support = np.zeros(amp.shape)
    support[amp > isosurface_strain * amp.max()] = 1
    phase = pu.remove_offset(
        array=phase,
        support=support,
        offset_method=offset_method,
        phase_offset=phase_offset,
        offset_origin=offset_origin,
        title="Phase",
        debugging=debug,
    )
    del support
    gc.collect()

    phase = util.wrap(
        obj=phase, start_angle=-extent_phase / 2, range_angle=extent_phase
    )

    ##############################################################################
    # average the phase over a window or apodize to reduce noise in strain plots #
    ##############################################################################
    half_width_avg_phase = prm.get("half_width_avg_phase", 0)
    if half_width_avg_phase != 0:
        bulk = pu.find_bulk(
            amp=amp, support_threshold=isosurface_strain, method="threshold"
        )
        # the phase should be averaged only in the support defined by the isosurface
        phase = pu.mean_filter(
            array=phase, support=bulk, half_width=half_width_avg_phase
        )
        del bulk
        gc.collect()

    if half_width_avg_phase != 0:
        comment = comment + "_avg" + str(2 * half_width_avg_phase + 1)

    gridz, gridy, gridx = np.meshgrid(
        np.arange(0, numz, 1),
        np.arange(0, numy, 1),
        np.arange(0, numx, 1),
        indexing="ij",
    )

    # put back the phase ramp, otherwise the diffraction pattern will be shifted
    # and the PRTF calculation spoiled
    phase = phase + gridz * rampz + gridy * rampy + gridx * rampx

    if prm.get("apodize", False):
        amp, phase = pu.apodize(
            amp=amp,
            phase=phase,
            initial_shape=original_size,
            window_type=prm.get("apodization_window", "blackman"),
            sigma=prm.get("apodization_sigma", [0.30, 0.30, 0.30]),
            mu=prm.get("apodization_mu", [0.0, 0.0, 0.0]),
            alpha=prm.get("apodization_alpha", [1.0, 1.0, 1.0]),
            is_orthogonal=is_orthogonal,
            debugging=True,
        )
        comment = comment + "_apodize_" + prm.get("apodization_window", "blackman")

    ################################################################
    # save the phase with the ramp for PRTF calculations,          #
    # otherwise the object will be misaligned with the measurement #
    ################################################################
    np.savez_compressed(
        detector.savedir + "S" + str(scan) + "_avg_obj_prtf" + comment,
        obj=amp * np.exp(1j * phase),
    )

    ####################################################
    # remove again phase ramp before orthogonalization #
    ####################################################
    phase = phase - gridz * rampz - gridy * rampy - gridx * rampx

    avg_obj = amp * np.exp(1j * phase)  # here the phase is again wrapped in [-pi pi[

    del amp, phase, gridz, gridy, gridx, rampz, rampy, rampx
    gc.collect()

    ######################
    # centering of array #
    ######################
    if centering_method == "max":
        avg_obj = pu.center_max(avg_obj)
        # shift based on max value,
        # required if it spans across the edge of the array before COM
    elif centering_method == "com":
        avg_obj = pu.center_com(avg_obj)
    elif centering_method == "max_com":
        avg_obj = pu.center_max(avg_obj)
        avg_obj = pu.center_com(avg_obj)

    #######################
    #  save support & vti #
    #######################
    if prm.get("save_support", False):
        # to be used as starting support in phasing, hence still in the detector frame
        support = np.zeros((numz, numy, numx))
        support[abs(avg_obj) / abs(avg_obj).max() > 0.01] = 1
        # low threshold because support will be cropped by shrinkwrap during phasing
        np.savez_compressed(
            detector.savedir + "S" + str(scan) + "_support" + comment, obj=support
        )
        del support
        gc.collect()

    if prm.get("save_rawdata", False):
        np.savez_compressed(
            detector.savedir + "S" + str(scan) + "_raw_amp-phase" + comment,
            amp=abs(avg_obj),
            phase=np.angle(avg_obj),
        )

        # voxel sizes in the detector frame
        voxel_z, voxel_y, voxel_x = setup.voxel_sizes_detector(
            array_shape=original_size,
            tilt_angle=(
                prm.get("tilt_angle")
                * detector.preprocessing_binning[0]
                * detector.binning[0]
            ),
            pixel_x=detector.pixelsize_x,
            pixel_y=detector.pixelsize_y,
            verbose=True,
        )
        # save raw amp & phase to VTK
        # in VTK, x is downstream, y vertical, z inboard,
        # thus need to flip the last axis
        gu.save_to_vti(
            filename=os.path.join(
                detector.savedir, "S" + str(scan) + "_raw_amp-phase" + comment + ".vti"
            ),
            voxel_size=(voxel_z, voxel_y, voxel_x),
            tuple_array=(abs(avg_obj), np.angle(avg_obj)),
            tuple_fieldnames=("amp", "phase"),
            amplitude_threshold=0.01,
        )

    #########################################################
    # calculate q of the Bragg peak in the laboratory frame #
    #########################################################
    # q_lab is in 1/A, in the laboratory frame z downstream, y vertical, x outboard
    q_lab = setup.q_laboratory
    qnorm = np.linalg.norm(q_lab)
    q_lab = q_lab / qnorm

    angle = simu.angle_vectors(
        ref_vector=[q_lab[2], q_lab[1], q_lab[0]],
        test_vector=axis_to_array_xyz[ref_axis_q],
    )
    print(
        "\nNormalized scattering vector in the laboratory frame (z*, y*, x*): "
        f"({q_lab[0]:.4f}, {q_lab[1]:.4f}, {q_lab[2]:.4f})"
    )

    planar_dist = 2 * np.pi / qnorm  # qnorm is in 1/A, so planar_dist is in A
    print(f"Wavevector transfer: {qnorm:.4f} 1/A")
    print(f"Atomic planar distance: {planar_dist:.4f} A")
    print(f"\nAngle between q_lab and {ref_axis_q} = {angle:.2f} deg")
    if debug:
        print(
            "Angle with y in zy plane = "
            f"{np.arctan(q_lab[0]/q_lab[1])*180/np.pi:.2f} deg"
        )
        print(
            "Angle with y in xy plane = "
            f"{np.arctan(-q_lab[2]/q_lab[1])*180/np.pi:.2f} deg"
        )
        print(
            "Angle with z in xz plane = "
            f"{180+np.arctan(q_lab[2]/q_lab[0])*180/np.pi:.2f} deg\n"
        )

    planar_dist = planar_dist / 10  # switch to nm

    #######################
    #  orthogonalize data #
    #######################
    print("\nShape before orthogonalization", avg_obj.shape, "\n")
    if data_frame == "detector":
        if debug:
            phase, _ = pu.unwrap(
                avg_obj,
                support_threshold=threshold_unwrap_refraction,
                debugging=True,
                reciprocal_space=False,
                is_orthogonal=False,
            )
            gu.multislices_plot(
                phase,
                width_z=2 * zrange,
                width_y=2 * yrange,
                width_x=2 * xrange,
                sum_frames=False,
                plot_colorbar=True,
                reciprocal_space=False,
                is_orthogonal=False,
                title="unwrapped phase before orthogonalization",
            )
            del phase
            gc.collect()

        if not prm.get("outofplane_angle") and not prm.get("inplane_angle"):
            print("Trying to correct detector angles using the direct beam")
            # corrected detector angles not provided
            if bragg_peak is None and detector.template_imagefile is not None:
                # Bragg peak position not provided, find it from the data
                data, _, _, _ = setup.diffractometer.load_check_dataset(
                    scan_number=scan,
                    detector=detector,
                    setup=setup,
                    frames_pattern=prm.get("frames_pattern"),
                    bin_during_loading=False,
                    flatfield=prm.get("flatfield"),
                    hotpixels=prm.get("hotpix_array"),
                    background=prm.get("background"),
                    normalize=prm.get("normalize_flux", "skip"),
                )
                bragg_peak = bu.find_bragg(
                    data=data,
                    peak_method="maxcom",
                    roi=detector.roi,
                    binning=None,
                )
                roi_center = (
                    bragg_peak[0],
                    bragg_peak[1] - detector.roi[0],  # no binning as in bu.find_bragg
                    bragg_peak[2] - detector.roi[2],  # no binning as in bu.find_bragg
                )
                bu.show_rocking_curve(
                    data,
                    roi_center=roi_center,
                    tilt_values=setup.incident_angles,
                    savedir=detector.savedir,
                )
            setup.correct_detector_angles(bragg_peak_position=bragg_peak)
            prm["outofplane_angle"] = setup.outofplane_angle
            prm["inplane_angle"] = setup.inplane_angle

        obj_ortho, voxel_size, transfer_matrix = setup.ortho_directspace(
            arrays=avg_obj,
            q_com=np.array([q_lab[2], q_lab[1], q_lab[0]]),
            initial_shape=original_size,
            voxel_size=fix_voxel,
            reference_axis=axis_to_array_xyz[ref_axis_q],
            fill_value=0,
            debugging=True,
            title="amplitude",
        )
        prm["transformation_matrix"] = transfer_matrix
    else:  # data already orthogonalized using xrayutilities
        # or the linearized transformation matrix
        obj_ortho = avg_obj
        try:
            print("Select the file containing QxQzQy")
            file_path = filedialog.askopenfilename(
                title="Select the file containing QxQzQy",
                initialdir=detector.savedir,
                filetypes=[("NPZ", "*.npz")],
            )
            npzfile = np.load(file_path)
            qx = npzfile["qx"]
            qy = npzfile["qy"]
            qz = npzfile["qz"]
        except FileNotFoundError:
            raise FileNotFoundError(
                "q values not provided, the voxel size cannot be calculated"
            )
        dy_real = (
            2 * np.pi / abs(qz.max() - qz.min()) / 10
        )  # in nm qz=y in nexus convention
        dx_real = (
            2 * np.pi / abs(qy.max() - qy.min()) / 10
        )  # in nm qy=x in nexus convention
        dz_real = (
            2 * np.pi / abs(qx.max() - qx.min()) / 10
        )  # in nm qx=z in nexus convention
        print(
            f"direct space voxel size from q values: ({dz_real:.2f} nm,"
            f" {dy_real:.2f} nm, {dx_real:.2f} nm)"
        )
        if fix_voxel:
            voxel_size = fix_voxel
            print(f"Direct space pixel size for the interpolation: {voxel_size} (nm)")
            print("Interpolating...\n")
            obj_ortho = pu.regrid(
                array=obj_ortho,
                old_voxelsize=(dz_real, dy_real, dx_real),
                new_voxelsize=voxel_size,
            )
        else:
            # no need to interpolate
            voxel_size = dz_real, dy_real, dx_real  # in nm

        if data_frame == "laboratory":
            # the object must be rotated into the crystal frame
            # before the strain calculation
            print("Rotating the object in the crystal frame for the strain calculation")

            amp, phase = util.rotate_crystal(
                arrays=(abs(obj_ortho), np.angle(obj_ortho)),
                is_orthogonal=True,
                reciprocal_space=False,
                voxel_size=voxel_size,
                debugging=(True, False),
                axis_to_align=q_lab[::-1],
                reference_axis=axis_to_array_xyz[ref_axis_q],
                title=("amp", "phase"),
            )

            obj_ortho = amp * np.exp(
                1j * phase
            )  # here the phase is again wrapped in [-pi pi[
            del amp, phase

    del avg_obj
    gc.collect()

    ######################################################
    # center the object (centering based on the modulus) #
    ######################################################
    print("\nCentering the crystal")
    obj_ortho = pu.center_com(obj_ortho)

    ####################
    # Phase unwrapping #
    ####################
    print("\nPhase unwrapping")
    phase, extent_phase = pu.unwrap(
        obj_ortho,
        support_threshold=threshold_unwrap_refraction,
        debugging=True,
        reciprocal_space=False,
        is_orthogonal=True,
    )
    amp = abs(obj_ortho)
    del obj_ortho
    gc.collect()

    #############################################
    # invert phase: -1*phase = displacement * q #
    #############################################
    if invert_phase:
        phase = -1 * phase

    ########################################
    # refraction and absorption correction #
    ########################################
    if correct_refraction:  # or correct_absorption:
        bulk = pu.find_bulk(
            amp=amp,
            support_threshold=threshold_unwrap_refraction,
            method=prm.get("optical_path_method", "threshold"),
            debugging=debug,
        )

        kin = setup.incident_wavevector
        kout = setup.exit_wavevector
        # kin and kout were calculated in the laboratory frame,
        # but after the geometric transformation of the crystal, this
        # latter is always in the crystal frame (for simpler strain calculation).
        # We need to transform kin and kout back
        # into the crystal frame (also, xrayutilities output is in crystal frame)
        kin = util.rotate_vector(
            vectors=[kin[2], kin[1], kin[0]],
            axis_to_align=axis_to_array_xyz[ref_axis_q],
            reference_axis=[q_lab[2], q_lab[1], q_lab[0]],
        )
        kout = util.rotate_vector(
            vectors=[kout[2], kout[1], kout[0]],
            axis_to_align=axis_to_array_xyz[ref_axis_q],
            reference_axis=[q_lab[2], q_lab[1], q_lab[0]],
        )

        # calculate the optical path of the incoming wavevector
        path_in = pu.get_opticalpath(
            support=bulk, direction="in", k=kin, debugging=debug
        )  # path_in already in nm

        # calculate the optical path of the outgoing wavevector
        path_out = pu.get_opticalpath(
            support=bulk, direction="out", k=kout, debugging=debug
        )  # path_out already in nm

        optical_path = path_in + path_out
        del path_in, path_out
        gc.collect()

        if correct_refraction:
            phase_correction = (
                2 * np.pi / (1e9 * setup.wavelength) * prm["dispersion"] * optical_path
            )
            phase = phase + phase_correction

            gu.multislices_plot(
                np.multiply(phase_correction, bulk),
                width_z=2 * zrange,
                width_y=2 * yrange,
                width_x=2 * xrange,
                sum_frames=False,
                plot_colorbar=True,
                vmin=0,
                vmax=np.nan,
                title="Refraction correction on the support",
                is_orthogonal=True,
                reciprocal_space=False,
            )
        correct_absorption = False
        if correct_absorption:
            amp_correction = np.exp(
                2 * np.pi / (1e9 * setup.wavelength) * prm["absorption"] * optical_path
            )
            amp = amp * amp_correction

            gu.multislices_plot(
                np.multiply(amp_correction, bulk),
                width_z=2 * zrange,
                width_y=2 * yrange,
                width_x=2 * xrange,
                sum_frames=False,
                plot_colorbar=True,
                vmin=1,
                vmax=1.1,
                title="Absorption correction on the support",
                is_orthogonal=True,
                reciprocal_space=False,
            )

        del bulk, optical_path
        gc.collect()

    ##############################################
    # phase ramp and offset removal (mean value) #
    ##############################################
    print("\nPhase ramp removal")
    amp, phase, _, _, _ = pu.remove_ramp(
        amp=amp,
        phase=phase,
        initial_shape=original_size,
        method=prm.get("phase_ramp_removal", "gradient"),
        amplitude_threshold=isosurface_strain,
        threshold_gradient=threshold_gradient,
        debugging=debug,
    )

    ########################
    # phase offset removal #
    ########################
    print("\nPhase offset removal")
    support = np.zeros(amp.shape)
    support[amp > isosurface_strain * amp.max()] = 1
    phase = pu.remove_offset(
        array=phase,
        support=support,
        offset_method=offset_method,
        phase_offset=phase_offset,
        offset_origin=offset_origin,
        title="Orthogonal phase",
        debugging=debug,
        reciprocal_space=False,
        is_orthogonal=True,
    )
    del support
    gc.collect()
    # Wrap the phase around 0 (no more offset)
    phase = util.wrap(
        obj=phase, start_angle=-extent_phase / 2, range_angle=extent_phase
    )
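
    # util.wrap is assumed to be a modulo mapping folding the phase into
    # [start_angle, start_angle + range_angle), i.e. equivalent to:
    #   wrapped = (obj - start_angle) % range_angle + start_angle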

    ################################################################
    # calculate the strain depending on which axis q is aligned on #
    ################################################################
    print(f"\nCalculation of the strain along {ref_axis_q}")
    strain = pu.get_strain(
        phase=phase,
        planar_distance=planar_dist,
        voxel_size=voxel_size,
        reference_axis=ref_axis_q,
        extent_phase=extent_phase,
        method=prm.get("strain_method", "default"),
        debugging=debug,
    )
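
    # Sketch of the assumed "default" method: the displacement along q is
    # u = phase * d_hkl / (2*pi), so the strain is the normalized phase
    # gradient, e.g. for ref_axis_q along array axis 0:
    #   strain_sketch = (np.gradient(phase, voxel_size[0], axis=0)
    #                    * planar_dist / (2 * np.pi))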

    ################################################
    # optionally rotates back the crystal into the #
    # laboratory frame (for debugging purpose)     #
    ################################################
    q_final = None
    if save_frame in {"laboratory", "lab_flat_sample"}:
        comment = comment + "_labframe"
        print("\nRotating back the crystal in laboratory frame")
        amp, phase, strain = util.rotate_crystal(
            arrays=(amp, phase, strain),
            axis_to_align=axis_to_array_xyz[ref_axis_q],
            voxel_size=voxel_size,
            is_orthogonal=True,
            reciprocal_space=False,
            reference_axis=[q_lab[2], q_lab[1], q_lab[0]],
            debugging=(True, False, False),
            title=("amp", "phase", "strain"),
        )
        # q_lab is already in the laboratory frame
        q_final = q_lab

    if save_frame == "lab_flat_sample":
        comment = comment + "_flat"
        print("\nSending sample stage circles to 0")
        (amp, phase, strain), q_final = setup.diffractometer.flatten_sample(
            arrays=(amp, phase, strain),
            voxel_size=voxel_size,
            q_com=q_lab[::-1],  # q_com needs to be in xyz order
            is_orthogonal=True,
            reciprocal_space=False,
            rocking_angle=setup.rocking_angle,
            debugging=(True, False, False),
            title=("amp", "phase", "strain"),
        )
    if save_frame == "crystal":
        # also rotate q_lab to have it along ref_axis_q, as a cross-check;
        # vectors need to be in xyz order
        comment = comment + "_crystalframe"
        q_final = util.rotate_vector(
            vectors=q_lab[::-1],
            axis_to_align=axis_to_array_xyz[ref_axis_q],
            reference_axis=q_lab[::-1],
        )

    ###############################################
    # rotates the crystal e.g. for easier slicing #
    # of the result along a particular direction  #
    ###############################################
    # typically this is an inplane rotation, q should stay aligned with the axis
    # along which the strain was calculated
    if prm.get("align_axis", False):
        print("\nRotating arrays for visualization")
        amp, phase, strain = util.rotate_crystal(
            arrays=(amp, phase, strain),
            reference_axis=axis_to_array_xyz[prm["ref_axis"]],
            axis_to_align=prm["axis_to_align"],
            voxel_size=voxel_size,
            debugging=(True, False, False),
            is_orthogonal=True,
            reciprocal_space=False,
            title=("amp", "phase", "strain"),
        )
        # rotate q accordingly, vectors needs to be in xyz order
        q_final = util.rotate_vector(
            vectors=q_final[::-1],
            axis_to_align=axis_to_array_xyz[prm["ref_axis"]],
            reference_axis=prm["axis_to_align"],
        )
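
    # q_final is assumed to have been set by one of the save_frame branches
    # above (save_frame is validated upstream); otherwise the multiplication
    # by qnorm below would fail on None.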

    q_final = q_final * qnorm
    print(
        f"\nq_final = ({q_final[0]:.4f} 1/A,"
        f" {q_final[1]:.4f} 1/A, {q_final[2]:.4f} 1/A)"
    )

    ##############################################
    # pad array to fit the output_size parameter #
    ##############################################
    if output_size is not None:
        amp = util.crop_pad(array=amp, output_shape=output_size)
        phase = util.crop_pad(array=phase, output_shape=output_size)
        strain = util.crop_pad(array=strain, output_shape=output_size)
    print(f"\nFinal data shape: {amp.shape}")

    ######################
    # save result to vtk #
    ######################
    print(
        f"\nVoxel size: ({voxel_size[0]:.2f} nm, {voxel_size[1]:.2f} nm,"
        f" {voxel_size[2]:.2f} nm)"
    )
    bulk = pu.find_bulk(
        amp=amp, support_threshold=isosurface_strain, method="threshold"
    )
    if save:
        prm["comment"] = comment
        np.savez_compressed(
            f"{detector.savedir}S{scan}_amp{phase_fieldname}strain{comment}",
            amp=amp,
            phase=phase,
            bulk=bulk,
            strain=strain,
            q_com=q_final,
            voxel_sizes=voxel_size,
            detector=detector.params,
            setup=setup.params,
            params=prm,
        )

        # save results in hdf5 file
        with h5py.File(
            f"{detector.savedir}S{scan}_amp{phase_fieldname}strain{comment}.h5", "w"
        ) as hf:
            out = hf.create_group("output")
            par = hf.create_group("params")
            out.create_dataset("amp", data=amp)
            out.create_dataset("bulk", data=bulk)
            out.create_dataset("phase", data=phase)
            out.create_dataset("strain", data=strain)
            out.create_dataset("q_com", data=q_final)
            out.create_dataset("voxel_sizes", data=voxel_size)
            par.create_dataset("detector", data=str(detector.params))
            par.create_dataset("setup", data=str(setup.params))
            par.create_dataset("parameters", data=str(prm))

        # save amp & phase to VTK
        # in VTK, x is downstream, y vertical, z inboard,
        # thus need to flip the last axis
        gu.save_to_vti(
            filename=os.path.join(
                detector.savedir,
                "S"
                + str(scan)
                + "_amp-"
                + phase_fieldname
                + "-strain"
                + comment
                + ".vti",
            ),
            voxel_size=voxel_size,
            tuple_array=(amp, bulk, phase, strain),
            tuple_fieldnames=("amp", "bulk", phase_fieldname, "strain"),
            amplitude_threshold=0.01,
        )

    ######################################
    # estimate the volume of the crystal #
    ######################################
    amp = amp / amp.max()
    temp_amp = np.copy(amp)
    temp_amp[amp < isosurface_strain] = 0
    temp_amp[np.nonzero(temp_amp)] = 1
    volume = temp_amp.sum() * reduce(lambda x, y: x * y, voxel_size)  # in nm3
    del temp_amp
    gc.collect()
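
    # The reduce(...) above is simply the product of the three voxel
    # dimensions, i.e. the voxel volume in nm3; an equivalent one-liner would
    # be temp_amp.sum() * np.prod(voxel_size).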

    ##############################
    # plot slices of the results #
    ##############################
    pixel_spacing = [tick_spacing / vox for vox in voxel_size]
    print(
        "\nPhase extent without / with thresholding the modulus "
        f"(threshold={isosurface_strain}): {phase.max()-phase.min():.2f} rad, "
        f"{phase[np.nonzero(bulk)].max()-phase[np.nonzero(bulk)].min():.2f} rad"
    )
    piz, piy, pix = np.unravel_index(phase.argmax(), phase.shape)
    print(
        f"phase.max() = {phase[np.nonzero(bulk)].max():.2f} "
        f"at voxel ({piz}, {piy}, {pix})"
    )
    strain[bulk == 0] = np.nan
    phase[bulk == 0] = np.nan

    # plot the slice at the maximum phase
    gu.combined_plots(
        (phase[piz, :, :], phase[:, piy, :], phase[:, :, pix]),
        tuple_sum_frames=False,
        tuple_sum_axis=0,
        tuple_width_v=None,
        tuple_width_h=None,
        tuple_colorbar=True,
        tuple_vmin=np.nan,
        tuple_vmax=np.nan,
        tuple_title=("phase at max in xy", "phase at max in xz", "phase at max in yz"),
        tuple_scale="linear",
        cmap=my_cmap,
        is_orthogonal=True,
        reciprocal_space=False,
    )

    # bulk support
    fig, _, _ = gu.multislices_plot(
        bulk,
        sum_frames=False,
        title="Orthogonal bulk",
        vmin=0,
        vmax=1,
        is_orthogonal=True,
        reciprocal_space=False,
    )
    fig.text(0.60, 0.45, f"Scan {scan}", size=20)
    fig.text(
        0.60,
        0.40,
        f"Bulk - isosurface={isosurface_strain:.2f}",
        size=20,
    )
    plt.pause(0.1)
    if save:
        plt.savefig(detector.savedir + "S" + str(scan) + "_bulk" + comment + ".png")

    # amplitude
    fig, _, _ = gu.multislices_plot(
        amp,
        sum_frames=False,
        title="Normalized orthogonal amp",
        vmin=0,
        vmax=1,
        tick_direction=tick_direction,
        tick_width=tick_width,
        tick_length=tick_length,
        pixel_spacing=pixel_spacing,
        plot_colorbar=True,
        is_orthogonal=True,
        reciprocal_space=False,
    )
    fig.text(0.60, 0.45, f"Scan {scan}", size=20)
    fig.text(
        0.60,
        0.40,
        f"Voxel size=({voxel_size[0]:.1f}, {voxel_size[1]:.1f}, "
        f"{voxel_size[2]:.1f}) (nm)",
        size=20,
    )
    fig.text(0.60, 0.35, f"Ticks spacing={tick_spacing} nm", size=20)
    fig.text(0.60, 0.30, f"Volume={int(volume)} nm3", size=20)
    fig.text(0.60, 0.25, "Sorted by " + sort_method, size=20)
    fig.text(0.60, 0.20, f"correlation threshold={correlation_threshold}", size=20)
    fig.text(0.60, 0.15, f"average over {avg_counter} reconstruction(s)", size=20)
    fig.text(0.60, 0.10, f"Planar distance={planar_dist:.5f} nm", size=20)
    if prm.get("get_temperature", False):
        temperature = pu.bragg_temperature(
            spacing=planar_dist * 10,
            reflection=prm["reflection"],
            spacing_ref=prm.get("reference_spacing"),
            temperature_ref=prm.get("reference_temperature"),
            use_q=False,
            material="Pt",
        )
        fig.text(0.60, 0.05, f"Estimated T={temperature} C", size=20)
    if save:
        plt.savefig(detector.savedir + f"S{scan}_amp" + comment + ".png")

    # amplitude histogram
    fig, ax = plt.subplots(1, 1)
    ax.hist(amp[amp > 0.05 * amp.max()].flatten(), bins=250)
    ax.set_ylim(bottom=1)
    ax.tick_params(
        labelbottom=True,
        labelleft=True,
        direction="out",
        length=tick_length,
        width=tick_width,
    )
    ax.spines["right"].set_linewidth(1.5)
    ax.spines["left"].set_linewidth(1.5)
    ax.spines["top"].set_linewidth(1.5)
    ax.spines["bottom"].set_linewidth(1.5)
    fig.savefig(detector.savedir + f"S{scan}_histo_amp" + comment + ".png")

    # phase
    fig, _, _ = gu.multislices_plot(
        phase,
        sum_frames=False,
        title="Orthogonal displacement",
        vmin=-prm.get("phase_range", np.pi / 2),
        vmax=prm.get("phase_range", np.pi / 2),
        tick_direction=tick_direction,
        cmap=my_cmap,
        tick_width=tick_width,
        tick_length=tick_length,
        pixel_spacing=pixel_spacing,
        plot_colorbar=True,
        is_orthogonal=True,
        reciprocal_space=False,
    )
    fig.text(0.60, 0.30, f"Scan {scan}", size=20)
    fig.text(
        0.60,
        0.25,
        f"Voxel size=({voxel_size[0]:.1f}, {voxel_size[1]:.1f}, "
        f"{voxel_size[2]:.1f}) (nm)",
        size=20,
    )
    fig.text(0.60, 0.20, f"Ticks spacing={tick_spacing} nm", size=20)
    fig.text(0.60, 0.15, f"average over {avg_counter} reconstruction(s)", size=20)
    if half_width_avg_phase > 0:
        fig.text(
            0.60, 0.10, f"Averaging over {2*half_width_avg_phase+1} pixels", size=20
        )
    else:
        fig.text(0.60, 0.10, "No phase averaging", size=20)
    if save:
        plt.savefig(detector.savedir + f"S{scan}_displacement" + comment + ".png")

    # strain
    fig, _, _ = gu.multislices_plot(
        strain,
        sum_frames=False,
        title="Orthogonal strain",
        vmin=-prm.get("strain_range", 0.002),
        vmax=prm.get("strain_range", 0.002),
        tick_direction=tick_direction,
        tick_width=tick_width,
        tick_length=tick_length,
        plot_colorbar=True,
        cmap=my_cmap,
        pixel_spacing=pixel_spacing,
        is_orthogonal=True,
        reciprocal_space=False,
    )
    fig.text(0.60, 0.30, f"Scan {scan}", size=20)
    fig.text(
        0.60,
        0.25,
        f"Voxel size=({voxel_size[0]:.1f}, "
        f"{voxel_size[1]:.1f}, {voxel_size[2]:.1f}) (nm)",
        size=20,
    )
    fig.text(0.60, 0.20, f"Ticks spacing={tick_spacing} nm", size=20)
    fig.text(0.60, 0.15, f"average over {avg_counter} reconstruction(s)", size=20)
    if half_width_avg_phase > 0:
        fig.text(
            0.60, 0.10, f"Averaging over {2*half_width_avg_phase+1} pixels", size=20
        )
    else:
        fig.text(0.60, 0.10, "No phase averaging", size=20)
    if save:
        plt.savefig(detector.savedir + f"S{scan}_strain" + comment + ".png")
Ejemplo n.º 7
0
def main(parameters):
    """
    Protection for multiprocessing.

    :param parameters: dictionary containing input parameters
    """

    def collect_result(result):
        """
        Callback processing the result after asynchronous multiprocessing.
        Update the nonlocal arrays.

        :param result: the output of load_p10_file, containing the 2d data, 2d mask, counter for each frame, and the
         file index
        """
        nonlocal sumdata, mask, counter, nb_files, current_point
        # result is a tuple: data, mask, counter, file_index
        current_point += 1
        sumdata = sumdata + result[0]
        mask[np.nonzero(result[1])] = 1
        counter.append(result[2])

        sys.stdout.write('\rFile {:d} / {:d}'.format(current_point, nb_files))
        sys.stdout.flush()
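
    # Note: Pool.apply_async callbacks are all executed one at a time in the
    # parent process's result-handler thread, so the in-place updates of
    # sumdata, mask and counter above never interleave; completion order is
    # not guaranteed, though, hence the later sort on the file index.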

    #####################################
    # load the dictionary of parameters #
    #####################################
    scan = parameters['scan']
    samplename = parameters['sample_name']
    rootdir = parameters['rootdir']
    image_nb = parameters['file_list']
    counterroi = parameters['counter_roi']
    savedir = parameters['savedir']
    load_scan = parameters['is_scan']
    compare_end = parameters['compare_ends']
    savemask = parameters['save_mask']
    multiproc = parameters['multiprocessing']
    threshold = parameters['threshold']
    cb_min = parameters['cb_min']
    cb_max = parameters['cb_max']
    grey_bckg = parameters['grey_bckg']
    ###################
    # define colormap #
    ###################
    if grey_bckg:
        bad_color = '0.7'
    else:
        bad_color = '1.0'  # white background
    colormap = gu.Colormap(bad_color=bad_color)
    my_cmap = colormap.cmap

    #######################
    # Initialize detector #
    #######################
    detector = exp.Detector(name=parameters['detector'])
    nb_pixel_y, nb_pixel_x = detector.nb_pixel_y, detector.nb_pixel_x
    sumdata = np.zeros((nb_pixel_y, nb_pixel_x))
    mask = np.zeros((nb_pixel_y, nb_pixel_x))
    counter = []

    ####################
    # Initialize paths #
    ####################
    if isinstance(image_nb, int):
        image_nb = [image_nb]
    if len(counterroi) == 0:
        counterroi = [0, nb_pixel_y, 0, nb_pixel_x]

    assert (counterroi[0] >= 0
            and counterroi[1] <= nb_pixel_y
            and counterroi[2] >= 0
            and counterroi[3] <= nb_pixel_x), 'counter_roi setting does not match the detector size'
    nb_files = len(image_nb)
    if nb_files == 1:
        multiproc = False

    if load_scan:  # scan or time series
        detector.datadir = rootdir + samplename + '_' + str('{:05d}'.format(scan)) + '/e4m/'
        template_file = detector.datadir + samplename + '_' + str('{:05d}'.format(scan)) + "_data_"
    else:  # single image
        detector.datadir = rootdir + samplename + '/e4m/'
        template_file = detector.datadir + samplename + '_take_' + str('{:05d}'.format(scan)) + "_data_"
        compare_end = False

    detector.savedir = savedir or os.path.abspath(os.path.join(detector.datadir, os.pardir)) + '/'
    print(f'datadir: {detector.datadir}')
    print(f'savedir: {detector.savedir}')

    #############
    # Load data #
    #############
    plt.ion()
    filenames = [template_file + '{:06d}.h5'.format(image_nb[idx]) for idx in range(nb_files)]
    roi_counter = None
    current_point = 0
    start = time.time()

    if multiproc:
        print("\nNumber of processors used: ", min(mp.cpu_count(), len(filenames)))
        mp.freeze_support()
        pool = mp.Pool(processes=min(mp.cpu_count(), len(filenames)))  # use this number of processes

        for file in range(nb_files):
            pool.apply_async(load_p10_file, args=(detector, filenames[file], file, counterroi, threshold),
                             callback=collect_result, error_callback=util.catch_error)

        pool.close()
        pool.join()  # block until all processes in the queue are done

        # sort out counter values (we are using asynchronous multiprocessing, order is not preserved)
        roi_counter = sorted(counter, key=lambda x: x[1])

    else:
        for idx in range(nb_files):
            sys.stdout.write('\rLoading file {:d}'.format(idx + 1) + ' / {:d}'.format(nb_files))
            sys.stdout.flush()
            h5file = h5py.File(filenames[idx], 'r')
            data = h5file['entry']['data']['data'][:]
            data[data <= threshold] = 0
            nbz, nby, nbx = data.shape
            for index in range(nbz):
                counter.append(data[index, counterroi[0]:counterroi[1],
                                    counterroi[2]:counterroi[3]].sum())
            if compare_end and nb_files == 1:
                data_start, _ = detector.mask_detector(data=data[0, :, :], mask=mask)
                data_start = data_start.astype(float)
                data_stop, _ = detector.mask_detector(data=data[-1, :, :], mask=mask)
                data_stop = data_stop.astype(float)

                fig, _, _ = gu.imshow_plot(data_stop - data_start, plot_colorbar=True, scale='log',
                                           title='difference between the last frame and the first frame of the series')
            nb_frames = data.shape[0]  # number of frames in the series, if any
            data, mask = detector.mask_detector(data=data.sum(axis=0), mask=mask, nb_img=nb_frames)
            sumdata = sumdata + data
            roi_counter = [[counter, idx]]

    end = time.time()
    print('\nTime elapsed for loading data:', str(datetime.timedelta(seconds=int(end - start))))

    frame_per_series = int(len(counter) / nb_files)

    print('')
    if load_scan:
        if nb_files > 1:
            plot_title = 'masked data - sum of ' + str(nb_files)\
                         + ' points with {:d} frames each'.format(frame_per_series)
        else:
            plot_title = 'masked data - sum of ' + str(frame_per_series) + ' frames'
        filename = 'S' + str(scan) + '_scan.png'
    else:  # single image
        plot_title = 'masked data'
        filename = 'S' + str(scan) + '_image_' + str(image_nb[0]) + '.png'

    if savemask:
        fig, _, _ = gu.imshow_plot(mask, plot_colorbar=False, title='mask')
        np.savez_compressed(detector.savedir+'hotpixels.npz', mask=mask)
        fig.savefig(detector.savedir + 'mask.png')

    y0, x0 = np.unravel_index(abs(sumdata).argmax(), sumdata.shape)
    print("Max at (y, x): ", y0, x0, ' Max = ', int(sumdata[y0, x0]))

    np.savez_compressed(detector.savedir + f'{samplename}_{scan:05d}_sumdata.npz', data=sumdata)
    if save_to_mat:
        savemat(detector.savedir + f'{samplename}_{scan:05d}_sumdata.mat', {'data': sumdata})

    if len(roi_counter[0][0]) > 1:  # roi_counter[0][0] is the list of counter intensities in a series
        int_roi = [val[0][idx] for val in roi_counter
                   for idx in range(frame_per_series)]
        plt.figure()
        plt.plot(np.asarray(int_roi))
        plt.title('Integrated intensity in counter_roi')
        plt.pause(0.1)

    cb_min = cb_min or sumdata.min()
    cb_max = cb_max or sumdata.max()
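    # caveat: `or` treats a user-supplied 0 as unset, so cb_min=0 or cb_max=0
    # silently falls back to the data minimum/maximum here.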

    fig, _, _ = gu.imshow_plot(sumdata, plot_colorbar=True, title=plot_title, vmin=cb_min, vmax=cb_max, scale='log',
                               cmap=my_cmap)
    np.savez_compressed(detector.savedir + 'hotpixels.npz', mask=mask)
    fig.savefig(detector.savedir + filename)
    plt.show()
Ejemplo n.º 8
0
def main(user_comment):
    """
    Protection for multiprocessing.

    :param user_comment: comment to include in the filename when saving results
    """
    ##########################
    # check input parameters #
    ##########################
    global corr_count

    assert len(
        q_xcca
    ) == 2, "Two q values should be provided (it can be the same value)"
    assert len(
        origin_qspace
    ) == 3, "origin_qspace should be a tuple of 3 integer pixel values"
    q_xcca.sort()
    same_q = q_xcca[0] == q_xcca[1]
    warnings.filterwarnings("ignore")

    ###################
    # define colormap #
    ###################
    bad_color = '1.0'  # white background
    colormap = gu.Colormap(bad_color=bad_color)
    my_cmap = colormap.cmap
    plt.ion()

    ###################################
    # load experimental data and mask #
    ###################################
    plt.ion()
    root = tk.Tk()
    root.withdraw()
    file_path = filedialog.askopenfilename(
        initialdir=datadir,
        title="Select the 3D reciprocal space map",
        filetypes=[("NPZ", "*.npz")])
    data = np.load(file_path)['data']

    file_path = filedialog.askopenfilename(initialdir=datadir,
                                           title="Select the 3D mask",
                                           filetypes=[("NPZ", "*.npz")])
    mask = np.load(file_path)['mask']

    print((data > hotpix_threshold).sum(), ' hotpixels masked')
    mask[data > hotpix_threshold] = 1
    data[np.nonzero(mask)] = np.nan
    del mask
    gc.collect()

    file_path = filedialog.askopenfilename(initialdir=datadir,
                                           title="Select q values",
                                           filetypes=[("NPZ", "*.npz")])
    qvalues = np.load(file_path)
    qx = qvalues['qx']
    qz = qvalues['qz']
    qy = qvalues['qy']

    del qvalues
    gc.collect()

    ##############################################################
    # calculate the angular average using mean and median values #
    ##############################################################
    if plot_meandata:
        q_axis, y_mean_masked, y_median_masked = xcca.angular_avg(
            data=data,
            q_values=(qx, qz, qy),
            origin=origin_qspace,
            nb_bins=250,
            debugging=debug)
        fig, ax = plt.subplots(1, 1)
        ax.plot(q_axis, np.log10(y_mean_masked), 'r', label='mean')
        ax.plot(q_axis, np.log10(y_median_masked), 'b', label='median')
        ax.axvline(x=q_xcca[0],
                   ymin=0,
                   ymax=1,
                   color='g',
                   linestyle='--',
                   label='q1')
        ax.axvline(x=q_xcca[1],
                   ymin=0,
                   ymax=1,
                   color='r',
                   linestyle=':',
                   label='q2')
        ax.set_xlabel('q (1/nm)')
        ax.set_ylabel('Angular average (A.U.)')
        ax.legend()
        plt.pause(0.1)
        fig.savefig(savedir + '1D_average.png')

        del q_axis, y_median_masked, y_mean_masked

    ##############################################################
    # interpolate the data onto spheres at user-defined q values #
    ##############################################################
    # calculate the matrix of distances from the origin of reciprocal space
    distances = np.sqrt(
        (qx[:, np.newaxis, np.newaxis] - qx[origin_qspace[0]])**2 +
        (qz[np.newaxis, :, np.newaxis] - qz[origin_qspace[1]])**2 +
        (qy[np.newaxis, np.newaxis, :] - qy[origin_qspace[2]])**2)
    dq = min(qx[1] - qx[0], qz[1] - qz[0], qy[1] - qy[0])

    theta_phi_int = dict()  # create dictionary
    dict_fields = ['q1', 'q2']
    nb_points = []

    for counter, q_value in enumerate(q_xcca):
        if (counter == 0) or ((counter == 1) and not same_q):
            nb_pixels = (np.logical_and((distances < q_value + dq),
                                        (distances > q_value - dq))).sum()

            print(
                '\nNumber of voxels for the sphere of radius q={:.3f} 1/nm:'.
                format(q_value), nb_pixels)

            nb_pixels = int(nb_pixels / interp_factor)
            print(
                'Dividing the number of voxels by interp_factor: {:d} voxels remaining'
                .format(nb_pixels))

            indices = np.arange(0, nb_pixels, dtype=float) + 0.5

            # angles for interpolation are chosen using the 'golden spiral method', so that the corresponding points
            # are evenly distributed on the sphere
            theta = np.arccos(
                1 - 2 * indices / nb_pixels
            )  # theta is the polar angle of the spherical coordinates
            phi = np.pi * (
                1 + np.sqrt(5)
            ) * indices  # phi is the azimuthal angle of the spherical coordinates

            qx_sphere = q_value * np.cos(phi) * np.sin(theta)
            qz_sphere = q_value * np.cos(theta)
            qy_sphere = q_value * np.sin(phi) * np.sin(theta)
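
            # Quasi-uniformity note: the azimuthal step pi * (1 + sqrt(5))
            # used above is congruent modulo 2*pi to 2*pi - pi * (3 - sqrt(5)),
            # i.e. to the golden angle up to sign, which is what spreads
            # successive points evenly over the sphere.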

            # interpolate the data onto the new points
            rgi = RegularGridInterpolator((qx, qz, qy),
                                          data,
                                          method='linear',
                                          bounds_error=False,
                                          fill_value=np.nan)
            sphere_int = rgi(
                np.concatenate((qx_sphere.reshape(
                    (1, nb_pixels)), qz_sphere.reshape(
                        (1, nb_pixels)), qy_sphere.reshape(
                            (1, nb_pixels)))).transpose())
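
            # With bounds_error=False and fill_value=np.nan, sphere points
            # falling outside the measured q-grid (or whose neighborhood
            # contains masked voxels) interpolate to nan, hence the filtering
            # below before computing the CCF.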

            # look for nan values
            nan_indices = np.argwhere(np.isnan(sphere_int))
            if debug:
                # keep a copy so that the nan values also appear
                # in the debugging plot
                sphere_debug = np.copy(sphere_int)

            #  remove nan values before calculating the cross-correlation function
            theta = np.delete(theta, nan_indices)
            phi = np.delete(phi, nan_indices)
            sphere_int = np.delete(sphere_int, nan_indices)

            # normalize the intensity by the median value (remove the influence of the form factor)
            print('q={:.3f}:'.format(q_value),
                  ' normalizing by the median value', np.median(sphere_int))
            sphere_int = sphere_int / np.median(sphere_int)

            theta_phi_int[dict_fields[counter]] = np.concatenate(
                (theta[:, np.newaxis], phi[:, np.newaxis],
                 sphere_int[:, np.newaxis]),
                axis=1)
            # update the number of points without nan
            nb_points.append(len(theta))
            print('q={:.3f}:'.format(q_value), ' removing', nan_indices.size,
                  'nan values,', nb_points[counter], 'remain')

            if debug:
                # calculate the stereographic projection
                stereo_proj, uv_labels = fu.calc_stereoproj_facet(
                    projection_axis=1,
                    radius_mean=q_value,
                    stereo_center=0,
                    vectors=np.concatenate(
                        (qx_sphere[:, np.newaxis], qz_sphere[:, np.newaxis],
                         qy_sphere[:, np.newaxis]),
                        axis=1))
                # plot the projection from the South pole
                fig, _ = gu.scatter_stereographic(
                    euclidian_u=stereo_proj[:, 0],
                    euclidian_v=stereo_proj[:, 1],
                    color=sphere_debug,
                    title='Projection from the South pole'
                    ' at q={:.3f} (1/nm)'.format(q_value),
                    uv_labels=uv_labels,
                    cmap=my_cmap)
                fig.savefig(savedir +
                            'South pole_q={:.3f}.png'.format(q_value))
                plt.close(fig)

                # plot the projection from the North pole
                fig, _ = gu.scatter_stereographic(
                    euclidian_u=stereo_proj[:, 2],
                    euclidian_v=stereo_proj[:, 3],
                    color=sphere_debug,
                    title='Projection from the North pole'
                    ' at q={:.3f} (1/nm)'.format(q_value),
                    uv_labels=uv_labels,
                    cmap=my_cmap)
                fig.savefig(savedir +
                            'North pole_q={:.3f}.png'.format(q_value))
                plt.close(fig)
                del sphere_debug

            del qx_sphere, qz_sphere, qy_sphere, theta, phi, sphere_int, indices, nan_indices
            gc.collect()
    del qx, qy, qz, distances, data
    gc.collect()

    ############################################
    # calculate the cross-correlation function #
    ############################################
    if same_q:
        key_q2 = 'q1'
        print('\nThe CCF will be calculated over {:d} * {:d}'
              ' points and {:d} angular bins'.format(nb_points[0],
                                                     nb_points[0],
                                                     corr_count.shape[0]))
    else:
        key_q2 = 'q2'
        print('\nThe CCF will be calculated over {:d} * {:d}'
              ' points and {:d} angular bins'.format(nb_points[0],
                                                     nb_points[1],
                                                     corr_count.shape[0]))

    angular_bins = np.linspace(start=0,
                               stop=np.pi,
                               num=corr_count.shape[0],
                               endpoint=False)

    start = time.time()
    if single_proc:
        for idx in range(nb_points[0]):
            ccf_uniq_val, counter_val, counter_indices = xcca.calc_ccf_polar(
                point=idx, q1_name='q1', q2_name=key_q2,
                bin_values=angular_bins, polar_azi_int=theta_phi_int)
            collect_result_debug(ccf_uniq_val, counter_val, counter_indices)
    else:
        print("\nNumber of processors: ", mp.cpu_count())
        mp.freeze_support()
        pool = mp.Pool(mp.cpu_count())  # use this number of processes
        for idx in range(nb_points[0]):
            pool.apply_async(xcca.calc_ccf_polar,
                             args=(idx, 'q1', key_q2, angular_bins,
                                   theta_phi_int),
                             callback=collect_result,
                             error_callback=util.catch_error)
        # close the pool and let all the processes complete
        pool.close()
        pool.join()  # block until all processes in the queue are done
    end = time.time()
    print('\nTime elapsed for the calculation of the CCF:',
          str(datetime.timedelta(seconds=int(end - start))))

    # normalize the cross-correlation by the counter
    indices = np.nonzero(corr_count[:, 1])
    corr_count[indices, 0] = corr_count[indices, 0] / corr_count[indices, 1]
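
    # each angular bin of corr_count holds (sum of CCF contributions, number
    # of contributing pairs): dividing by the nonzero counts turns the sum
    # into a mean; bins with an empty counter are masked further below.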

    #######################################
    # save the cross-correlation function #
    #######################################
    filename = 'CCF_q1={:.3f}_q2={:.3f}'.format(q_xcca[0], q_xcca[1]) +\
               '_points{:d}_interp{:d}_res{:.3f}'.format(nb_points[0], interp_factor, angular_resolution) + user_comment
    np.savez_compressed(savedir + filename + '.npz',
                        angles=180 * angular_bins / np.pi,
                        ccf=corr_count[:, 0],
                        points=corr_count[:, 1])

    #######################################
    # plot the cross-correlation function #
    #######################################
    # find the y limit excluding the peaks at 0 and 180 degrees
    indices = np.argwhere(
        np.logical_and((angular_bins >= 5 * np.pi / 180),
                       (angular_bins <= 175 * np.pi / 180)))
    ymax = 1.2 * corr_count[indices, 0].max()
    print('Discarding CCF values with a zero counter:',
          (corr_count[:, 1] == 0).sum(), 'points masked')
    corr_count[(corr_count[:, 1] == 0),
               0] = np.nan  # discard these values of the CCF

    fig, ax = plt.subplots()
    ax.plot(180 * angular_bins / np.pi,
            corr_count[:, 0],
            color='red',
            linestyle="-",
            markerfacecolor='blue',
            marker='.')
    ax.set_xlim(0, 180)
    ax.set_ylim(0, ymax)
    ax.set_xlabel('Angle (deg)')
    ax.set_ylabel('Cross-correlation')
    ax.set_xticks(np.arange(0, 181, 30))
    ax.set_title('CCF at q1={:.3f} 1/nm  and q2={:.3f} 1/nm'.format(
        q_xcca[0], q_xcca[1]))
    fig.savefig(savedir + filename + '.png')

    _, ax = plt.subplots()
    ax.plot(180 * angular_bins / np.pi,
            corr_count[:, 1],
            linestyle="None",
            markerfacecolor='blue',
            marker='.')
    ax.set_xlim(0, 180)
    ax.set_xlabel('Angle (deg)')
    ax.set_ylabel('Number of points')
    ax.set_xticks(np.arange(0, 181, 30))
    ax.set_title('Points per angular bin')
    plt.ioff()
    plt.show()
Ejemplo n.º 9
0
def main(parameters):
    """
    Protection for multiprocessing.

    :param parameters: dictionary containing input parameters
    """
    def collect_result(result):
        """
        Callback processing the result after asynchronous multiprocessing.

        Update the nonlocal arrays.

        :param result: the output of load_p10_file, containing the 2d data, 2d mask,
         counter for each frame, and the file index
        """
        nonlocal sumdata, mask, counter, nb_files, current_point
        # result is a tuple: data, mask, counter, file_index
        current_point += 1
        sumdata = sumdata + result[0]
        mask[np.nonzero(result[1])] = 1
        counter.append(result[2])

        sys.stdout.write("\rFile {:d} / {:d}".format(current_point, nb_files))
        sys.stdout.flush()

    #####################################
    # load the dictionary of parameters #
    #####################################
    scan = parameters["scan"]
    samplename = parameters["sample_name"]
    rootdir = parameters["rootdir"]
    image_nb = parameters["file_list"]
    counterroi = parameters["counter_roi"]
    savedir = parameters["savedir"]
    load_scan = parameters["is_scan"]
    compare_end = parameters["compare_ends"]
    savemask = parameters["save_mask"]
    multiproc = parameters["multiprocessing"]
    threshold = parameters["threshold"]
    cb_min = parameters["cb_min"]
    cb_max = parameters["cb_max"]
    grey_bckg = parameters["grey_bckg"]
    ###################
    # define colormap #
    ###################
    if grey_bckg:
        bad_color = "0.7"
    else:
        bad_color = "1.0"  # white background
    colormap = gu.Colormap(bad_color=bad_color)
    my_cmap = colormap.cmap

    #######################
    # Initialize detector #
    #######################
    detector = create_detector(name=parameters["detector"])
    nb_pixel_y, nb_pixel_x = detector.nb_pixel_y, detector.nb_pixel_x
    sumdata = np.zeros((nb_pixel_y, nb_pixel_x))
    mask = np.zeros((nb_pixel_y, nb_pixel_x))
    counter = []

    ####################
    # Initialize paths #
    ####################
    if isinstance(image_nb, int):
        image_nb = [image_nb]
    if len(counterroi) == 0:
        counterroi = [0, nb_pixel_y, 0, nb_pixel_x]

    if not (counterroi[0] >= 0 and counterroi[1] <= nb_pixel_y
            and counterroi[2] >= 0 and counterroi[3] <= nb_pixel_x):
        raise ValueError(
            "counter_roi setting does not match the detector size")

    nb_files = len(image_nb)
    if nb_files == 1:
        multiproc = False

    if load_scan:  # scan or time series
        detector.datadir = (rootdir + samplename + "_" +
                            str("{:05d}".format(scan)) + "/e4m/")
        template_file = (detector.datadir + samplename + "_" +
                         str("{:05d}".format(scan)) + "_data_")
    else:  # single image
        detector.datadir = rootdir + samplename + "/e4m/"
        template_file = (detector.datadir + samplename + "_take_" +
                         str("{:05d}".format(scan)) + "_data_")
        compare_end = False

    detector.savedir = (
        savedir
        or os.path.abspath(os.path.join(detector.datadir, os.pardir)) + "/")
    print(f"datadir: {detector.datadir}")
    print(f"savedir: {detector.savedir}")

    #############
    # Load data #
    #############
    plt.ion()
    filenames = [
        template_file + "{:06d}.h5".format(image_nb[idx])
        for idx in range(nb_files)
    ]
    roi_counter = None
    current_point = 0
    start = time.time()

    if multiproc:
        print("\nNumber of processors used: ",
              min(mp.cpu_count(), len(filenames)))
        mp.freeze_support()
        pool = mp.Pool(processes=min(
            mp.cpu_count(), len(filenames)))  # use this number of processes

        for file in range(nb_files):
            pool.apply_async(
                load_p10_file,
                args=(detector, filenames[file], file, counterroi, threshold),
                callback=collect_result,
                error_callback=util.catch_error,
            )

        pool.close()
        pool.join()  # block until all processes in the queue are done

        # sort out counter values
        # (we are using asynchronous multiprocessing, order is not preserved)
        roi_counter = sorted(counter, key=lambda x: x[1])

    else:
        for idx in range(nb_files):
            sys.stdout.write("\rLoading file {:d}".format(idx + 1) +
                             " / {:d}".format(nb_files))
            sys.stdout.flush()
            h5file = h5py.File(filenames[idx], "r")
            data = h5file["entry"]["data"]["data"][:]
            data[data <= threshold] = 0
            nbz, nby, nbx = data.shape
            for index in range(nbz):
                counter.append(data[index, counterroi[0]:counterroi[1],
                                    counterroi[2]:counterroi[3]].sum())

            if compare_end and nb_files == 1:
                data_start, _ = detector.mask_detector(data=data[0, :, :],
                                                       mask=mask)
                data_start = data_start.astype(float)
                data_stop, _ = detector.mask_detector(data=data[-1, :, :],
                                                      mask=mask)
                data_stop = data_stop.astype(float)

                fig, _, _ = gu.imshow_plot(
                    data_stop - data_start,
                    plot_colorbar=True,
                    scale="log",
                    title="""difference between the last frame and
                    the first frame of the series""",
                )
            nb_frames = data.shape[0]  # number of frames in the series, if any
            data, mask = detector.mask_detector(data=data.sum(axis=0),
                                                mask=mask,
                                                nb_frames=nb_frames)
            sumdata = sumdata + data
            roi_counter = [[counter, idx]]

    end = time.time()
    print(
        "\nTime ellapsed for loading data:",
        str(datetime.timedelta(seconds=int(end - start))),
    )

    frame_per_series = int(len(counter) / nb_files)

    print("")
    if load_scan:
        if nb_files > 1:
            plot_title = (
                "masked data - sum of " + str(nb_files) +
                " points with {:d} frames each".format(frame_per_series))
        else:
            plot_title = "masked data - sum of " + str(
                frame_per_series) + " frames"
        filename = "S" + str(scan) + "_scan.png"
    else:  # single image
        plot_title = "masked data"
        filename = "S" + str(scan) + "_image_" + str(image_nb[0]) + ".png"

    if savemask:
        fig, _, _ = gu.imshow_plot(mask, plot_colorbar=False, title="mask")
        np.savez_compressed(detector.savedir + "hotpixels.npz", mask=mask)
        fig.savefig(detector.savedir + "mask.png")

    y0, x0 = np.unravel_index(abs(sumdata).argmax(), sumdata.shape)
    print("Max at (y, x): ", y0, x0, " Max = ", int(sumdata[y0, x0]))

    np.savez_compressed(detector.savedir +
                        f"{samplename}_{scan:05d}_sumdata.npz",
                        data=sumdata)
    if save_to_mat:
        savemat(
            detector.savedir + f"{samplename}_{scan:05d}_sumdata.mat",
            {"data": sumdata},
        )

    if len(roi_counter[0][0]) > 1:
        # roi_counter[0][0] is the list of counter intensities in a series
        int_roi = [
            val[0][idx] for val in roi_counter
            for idx in range(frame_per_series)
        ]
        plt.figure()
        plt.plot(np.asarray(int_roi))
        plt.title("Integrated intensity in counter_roi")
        plt.pause(0.1)

    cb_min = cb_min or sumdata.min()
    cb_max = cb_max or sumdata.max()

    fig, _, _ = gu.imshow_plot(
        sumdata,
        plot_colorbar=True,
        title=plot_title,
        vmin=cb_min,
        vmax=cb_max,
        scale="log",
        cmap=my_cmap,
    )
    np.savez_compressed(detector.savedir + "hotpixels.npz", mask=mask)
    fig.savefig(detector.savedir + filename)
    plt.show()