Beispiel #1
0
def galvo_images(data_path, target_path=None):
    """
    Load the data of a galvo scan and save each of its images as a PNG file.

    Args:
        data_path: path to the image data, passed to ``Script.load_data``
        target_path: directory the images are saved to; defaults to ``data_path``

    Returns:
        None. One file ``image_<n>.png`` is written to ``target_path`` per image.
    """
    import os

    if target_path is None:
        target_path = data_path

    data = Script.load_data(data_path)
    # every key containing 'image' counts as one image; the images themselves
    # are then read from the keys image_0, image_1, ...
    number_of_images = sum(1 for key in data.keys() if 'image' in key)

    for c in range(number_of_images):
        k = 'image_{:d}'.format(c)
        fig = plt.figure()
        ax = plt.subplot(111)
        plot_fluorescence(data[k], extent=[0.02, 0.17, 0.05, -0.10], axes=ax)
        fig.savefig(os.path.join(target_path, '{:s}.png'.format(k)))
        # matplotlib Figure objects have no close() method; release it via pyplot
        plt.close(fig)
def visualize_magnetic_fields(src_folder, target_folder, manual=True):
    """
    Plot the NV classification map and the ESR splittings of a data set and save the figure as a JPG.

    Args:
        src_folder: folder of the raw data set; the part of its parent directory after './' selects the
            output sub-folder below target_folder, its basename names the input and output files
        target_folder: root folder of the processed data
        manual: if True, read '<basename>/data-manual.csv' and apply the corrections stored in its
            'manual_nv_type' / 'manual_B_field' columns; otherwise read '<basename>.csv'

    Returns:
        None. The figure is saved as '<basename>.jpg' next to the csv data.
    """
    filepath = os.path.join(target_folder, os.path.dirname(src_folder).split('./')[1])

    # load the fit_data
    if manual:
        filename = os.path.join(os.path.basename(src_folder), 'data-manual.csv')
    else:
        filename = '{:s}.csv'.format(os.path.basename(src_folder))
    df = pd.read_csv(os.path.join(filepath, filename), index_col='id', skipinitialspace=True)
    # drop the unnamed column produced by a saved DataFrame index (axis must be given by keyword in pandas 2)
    df = df.drop(columns='Unnamed: 0')

    # include manual corrections; iterate by row label so that every correction is applied to the
    # row it was read from (the index holds the NV ids, not row positions)
    if manual:
        for label, nv_type in df['manual_nv_type'].items():
            if pd.isnull(nv_type):
                continue
            df.at[label, 'NV type'] = nv_type
            if nv_type == 'split':
                df.at[label, 'B-field (gauss)'] = df.at[label, 'manual_B_field']

    # load the image data
    select_points_data = Script.load_data(get_select_points(src_folder))
    image_data = select_points_data['image_data']
    extent = select_points_data['extent']

    # prepare figure: map on the left, splittings on the right
    f = plt.figure(figsize=(15, 8))
    # one row and two columns, so the ratios describe column widths
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])

    # plot the map
    ax0.imshow(image_data, extent=extent, interpolation='nearest', cmap='pink')

    for nv_type, color in zip(['split', 'bad', 'no_peak', 'single'], ['g', 'k', 'b', 'r']):
        subset = df[df['NV type'] == nv_type]
        ax0.scatter(subset['xo'], subset['yo'], c=color, label=nv_type)

    ax0.legend(bbox_to_anchor=(2.3, 1))

    split = df[df['NV type'] == 'split']
    for x, y, n in zip(split['xo'], split['yo'], split.index):
        corr = +0.005  # adds a little correction to put annotation in marker's centrum
        ax0.annotate(str(n), xy=(x + corr, y + corr), fontsize=8, color='w')
    ax0.set_xlabel('x (V)')
    ax0.set_ylabel('y (V)')

    # plot the fields on a 1D plot; freq_to_mag is a module-level constant (not visible here) used to
    # convert the field into the splitting shown in MHz, per the axis labels
    splitting = split['B-field (gauss)'] / freq_to_mag * 1e-6
    ax1.plot(split['xo'], splitting, 'o')
    ax1.set_title('ESR Splittings')
    ax1.set_xlabel('x coordinate (V)')
    ax1.set_ylabel('Splitting (MHz)')
    ax1.set_xlim([extent[0], extent[1]])

    # secondary axis showing the same data converted back to field
    ax2 = ax1.twinx()
    mn, mx = ax1.get_ylim()
    ax2.set_ylim(mn * freq_to_mag * 1e6, mx * freq_to_mag * 1e6)
    ax2.set_ylabel('Magnetic Field Projection (Gauss)')

    for x, y, n in zip(split['xo'], splitting, split.index):
        corr = 0.0005  # adds a little correction to put annotation in marker's centrum
        ax1.annotate(str(n), xy=(x + corr, y + corr))

    f.savefig(os.path.join(filepath, '{:s}.jpg'.format(os.path.basename(src_folder))),
              bbox_inches='tight',
              pad_inches=0)
def manual_correction(folder, target_folder, fit_data_set, nv_type_manual, b_field_manual, queue, current_id_queue, lower_peak_widget, upper_peak_widget, lower_fit_widget, upper_fit_widget):
    """
    Backend code to display and fit ESRs, then once input has been received from the front-end, incorporate the
    data into the current data set.

    Args:
        folder: folder containing the data to analyze; its first two characters (e.g. a './' prefix) are
            stripped to build the output path below target_folder
        target_folder: target for processed data
        fit_data_set: starting point data set containing automatically fitted esrs; columns 2:8 of a row are
            used as that NV's stored fit parameters
        nv_type_manual: pointer to array (one entry per NV) holding the manual nv type corrections; filled from
            a previous data-manual.csv if present, and read after each NV is accepted
        b_field_manual: pointer to array (one entry per NV) to populate with b field manual corrections
        queue: queue for communication with front-end in separate thread; -1 refits with the current peak
            widgets, -2 shows the data without a fit, any other value accepts the current NV
        current_id_queue: its first element (peeked through ``.queue[0]``, not consumed) is the index of the
            NV to show; processing ends once that index reaches the number of ESR data sets
        lower_peak_widget: widget containing frequency value of lower peak
        upper_peak_widget: widget containing frequency value of upper peak (0 requests a single-peak fit)
        lower_fit_widget: widget containing the lower frequency bound of the fit window
        upper_fit_widget: widget containing the upper frequency bound of the fit window

    Poststate: populates fit_data_set with manual corrections and writes it to
    <target_folder>/<folder[2:]>/data-manual.csv after every accepted NV and once more at the end.
    """

    lower_peak_manual = [np.nan] * len(fit_data_set)
    upper_peak_manual = [np.nan] * len(fit_data_set)

    filepath = os.path.join(target_folder, folder[2:])
    data_filepath = os.path.join(filepath, 'data-manual.csv')
    # resume from the results of a previous manual session, if any were saved
    if os.path.exists(data_filepath):
        prev_data = pd.read_csv(data_filepath)
        if 'manual_peak_1' in list(prev_data.keys()):
            for i in range(0, len(prev_data['manual_B_field'])):
                b_field_manual[i] = prev_data['manual_B_field'][i]
                nv_type_manual[i] = prev_data['manual_nv_type'][i]
            lower_peak_manual = prev_data['manual_peak_1']
            upper_peak_manual = prev_data['manual_peak_2']

    #TODO: Add saving as you go, add ability to start at arbitrary NV, add ability to specify a next NV number, eliminate peak/correct -> only have 'accept fit'

    try:
        print('STARTING')

        # plain array for positional access to the stored fit parameters (as_matrix() was removed in pandas 1.0)
        fit_data_set_array = fit_data_set.values

        w = widgets.HTML("Event information appears here when you click on the figure")
        display(w)

        # loop over all the folders in the data_subscripts subfolder and retrieve fitparameters and position of NV
        esr_folders = glob.glob(os.path.join(folder, '.\\data_subscripts\\*esr*'))

        # create folders to save images to
        image_folder = os.path.join(target_folder, '{:s}\\images'.format(folder[2:]))
        bad_data_folder = os.path.join(image_folder, 'bad_data')
        os.makedirs(bad_data_folder, exist_ok=True)  # also creates image_folder

        f = plt.figure(figsize=(12, 6))

        def onclick(event):
            # left click sets the lower value, right click the upper one; with 'control' held the fit window
            # bounds are set instead of the peak positions
            if event.button == 1:
                if event.key == 'control':
                    lower_fit_widget.value = event.xdata
                else:
                    lower_peak_widget.value = event.xdata
            elif event.button == 3:
                if event.key == 'control':
                    upper_fit_widget.value = event.xdata
                else:
                    upper_peak_widget.value = event.xdata

        cid = f.canvas.mpl_connect('button_press_event', onclick)

        data_array = []
        data_pos_array = []
        for esr_folder in esr_folders:
            print(esr_folder)
            sys.stdout.flush()
            data = Script.load_data(esr_folder)
            data_array.append(data)
            print('looping')
            sys.stdout.flush()

        nv_folders = glob.glob(folder + '\\data_subscripts\\*find_nv*pt_*')
        for nv_folder in nv_folders:
            data_pos_array.append(Script.load_data(nv_folder))

        def display_data(pt_id, lower_peak_widget=None, upper_peak_widget=None, display_fit=True):
            """
            Redraw figure f for NV pt_id and return (fit_params, pt_id).

            Without peak widgets the stored fit (columns 2:8 of fit_data_set) is used; with them the data inside
            the fit window is refitted: one Lorentzian if the upper peak is 0, two otherwise. A failed fit
            yields NaNs. Parameter sets whose fifth entry is NaN are cut to their first four entries.
            """
            f.clf()
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
            ax0 = plt.subplot(gs[0])
            ax1 = plt.subplot(gs[1])
            plt.suptitle('NV #{:d}'.format(pt_id), fontsize=16)

            # load data, restricted to the fit window when one has been chosen
            data = data_array[pt_id]
            if lower_fit_widget.value == 0 and upper_fit_widget.value == 10e9:
                freq = data['frequency']
                ampl = data['data']
            else:
                in_window = np.logical_and(data['frequency'] > lower_fit_widget.value,
                                           data['frequency'] < upper_fit_widget.value)
                freq = data['frequency'][in_window]
                ampl = data['data'][in_window]

            if lower_peak_widget is None:
                fit_params = fit_data_set_array[pt_id, 2:8]
            else:
                lower_peak = lower_peak_widget.value
                upper_peak = upper_peak_widget.value
                if upper_peak == 0:
                    start_vals = get_lorentzian_fit_starting_values(freq, ampl)
                    start_vals[2] = lower_peak
                    start_vals[1] = ampl[np.argmin(np.abs(freq - lower_peak))] - start_vals[0]
                    try:
                        fit_params = fit_lorentzian(freq, ampl, starting_params=start_vals,
                                                    bounds=[(0, -np.inf, 0, 0), (np.inf, 0, np.inf, np.inf)])
                    except Exception:
                        # ESR fit failed!
                        fit_params = [np.nan] * 6
                else:
                    center_freq = np.mean(freq)
                    below = freq < center_freq
                    above = freq > center_freq
                    lower_start = get_lorentzian_fit_starting_values(freq[below], ampl[below])
                    upper_start = get_lorentzian_fit_starting_values(freq[above], ampl[above])
                    start_vals = [
                        np.mean([lower_start[0], upper_start[0]]),  # offset
                        np.sum([lower_start[3], upper_start[3]]),  # FWHM
                        ampl[np.argmin(np.abs(freq - lower_peak))] - lower_start[0],  # amplitudes
                        ampl[np.argmin(np.abs(freq - upper_peak))] - upper_start[0],
                        lower_peak, upper_peak  # centers
                    ]
                    try:
                        fit_params = fit_double_lorentzian(
                            freq, ampl, starting_params=start_vals,
                            bounds=[(0, 0, -np.inf, -np.inf, min(freq), min(freq)),
                                    (np.inf, np.inf, 0, 0, max(freq), max(freq))])
                    except Exception:
                        # ESR fit failed!
                        fit_params = [np.nan] * 6

            if len(fit_params) == 4 or np.isnan(fit_params[4]):
                fit_params = fit_params[0:4]

            # plot NV image
            data_pos = data_pos_array[pt_id]
            FindNV.plot_data([ax1], data_pos)

            sys.stdout.flush()

            # plot data and fits
            if display_fit:
                plot_esr(ax0, data['frequency'], data['data'], fit_params=fit_params)
            else:
                plot_esr(ax0, data['frequency'], data['data'], fit_params=[np.nan] * 6)

            plt.tight_layout()
            plt.subplots_adjust(top=0.85)  # Makes room at top of plot for figure suptitle

            plt.draw()
            plt.show()

            return fit_params, pt_id

        def update_peak_widgets(fit_params):
            """Copy the fitted peak centre(s) into the peak widgets (upper = 0 for a single peak)."""
            if len(fit_params) == 6:
                lower_peak_widget.value = fit_params[4]
                upper_peak_widget.value = fit_params[5]
            elif len(fit_params) == 4:
                lower_peak_widget.value = fit_params[2]
                upper_peak_widget.value = 0

        def save_manual_data():
            """Store the manual columns in fit_data_set and write it to data_filepath."""
            fit_data_set['manual_B_field'] = b_field_manual
            fit_data_set['manual_nv_type'] = nv_type_manual
            fit_data_set['manual_peak_1'] = lower_peak_manual
            fit_data_set['manual_peak_2'] = upper_peak_manual
            if not os.path.exists(filepath):
                os.makedirs(filepath)
            fit_data_set.to_csv(data_filepath)

        while True:
            # the front-end decides which NV is shown; peek at it without consuming it
            i = current_id_queue.queue[0]
            if i >= len(data_array):
                break

            # reset fit window and peak guesses for the new NV
            lower_fit_widget.value = 0
            upper_fit_widget.value = 10e9

            lower_peak_widget.value = 2.87e9
            upper_peak_widget.value = 0

            fit_params, pt_id = display_data(i)
            update_peak_widgets(fit_params)

            # wait for front-end commands until the current NV is accepted
            while True:
                if queue.empty():
                    time.sleep(.5)
                    continue
                value = queue.get()
                if value == -1:
                    fit_params, _ = display_data(i, lower_peak_widget=lower_peak_widget,
                                                 upper_peak_widget=upper_peak_widget)
                    update_peak_widgets(fit_params)
                elif value == -2:
                    display_data(i, display_fit=False)
                    fit_params = [np.nan] * 6
                    lower_fit_widget.value = 0
                    upper_fit_widget.value = 10e9
                else:
                    break

            # NOTE(review): 2.87e9, 2.8e6 and 5.6e6 presumably are the zero-field splitting (Hz) and the
            # Hz-per-gauss factors for one peak / a peak pair -- confirm
            if nv_type_manual[i] == 'split':
                if np.isnan(fit_params[0]):
                    # no usable fit: take the peak positions directly from the widgets
                    lower_peak_manual[i] = lower_peak_widget.value
                    upper_peak_manual[i] = upper_peak_widget.value
                    b_field_manual[i] = ((upper_peak_widget.value - lower_peak_widget.value) / 5.6e6)
                elif len(fit_params) == 4:
                    lower_peak_manual[i] = fit_params[2]
                    b_field_manual[i] = (np.abs(2.87e9 - fit_params[2]) / 2.8e6)
                else:
                    lower_peak_manual[i] = fit_params[4]
                    upper_peak_manual[i] = fit_params[5]
                    b_field_manual[i] = ((fit_params[5] - fit_params[4]) / 5.6e6)
            elif nv_type_manual[i] == 'single':
                if np.isnan(fit_params[0]):
                    lower_peak_manual[i] = lower_peak_widget.value
                else:
                    lower_peak_manual[i] = fit_params[2]
                b_field_manual[i] = 0

            if nv_type_manual[i] == '':
                # a failed fit is represented by NaN parameters (fit_params is never None here)
                is_bad = fit_params is None or np.isnan(fit_params[0])
            else:
                # NOTE(review): visualize_magnetic_fields uses the category 'no_peak'; confirm whether the
                # front-end sends 'no_split'
                is_bad = nv_type_manual[i] in ['bad', 'no_split']
            image_target = bad_data_folder if is_bad else image_folder
            f.savefig(os.path.join(image_target, 'esr_pt_{:02d}.jpg'.format(pt_id)))

            # save after every NV so that progress is not lost
            save_manual_data()

        f.canvas.mpl_disconnect(cid)
        save_manual_data()

        create_shortcut(os.path.abspath(os.path.join(filepath, 'to_data.lnk')), os.path.abspath(folder))
        create_shortcut(os.path.join(os.path.abspath(folder), 'to_processed.lnk'), os.path.abspath(filepath))

        print('DONE!')

    except Exception as e:
        print(e)
        raise
def autofit_esrs(folder):
    """
    Fit the ESR data of every NV found in a data set and collect the results in one table.

    Args:
        folder: source folder with esr data; it should contain a subfolder data_subscripts holding the *esr*
            subfolders (named ..._pt_<id>) and the matching *find_nv*pt_<id> subfolders

    Returns:
        pandas DataFrame with one row per esr folder (columns fit_<k>, 'id', 'NV type', 'x', 'y', 'xo', 'yo',
        'B-field (gauss)'), or None when no esr folder was found.
    """
    # loop over all the folders in the data_subscripts subfolder and retrieve fitparameters and position of NV
    esr_folders = glob.glob(os.path.join(folder, './data_subscripts/*esr*'))

    rows = []
    for esr_folder in esr_folders:

        # the NV index is the number after the last 'pt_' in the folder name
        pt_id = int(os.path.basename(esr_folder).split('pt_')[-1])

        # the glob also matches ids that merely end in the same digits (1 -> 11, 21), so keep only
        # folders whose own number equals pt_id
        candidates = glob.glob(folder + '/data_subscripts/*find_nv*pt_*{:d}'.format(pt_id))
        matching = []
        for candidate in candidates:
            suffix = os.path.basename(candidate).split('pt_')[-1]
            if suffix.isdigit() and int(suffix) == pt_id:
                matching.append(candidate)
        findnv_folder = sorted(matching)[0]

        # load data
        data = Script.load_data(esr_folder)
        fit_params = fit_esr(data['frequency'], data['data'])
        nv_type = get_nv_type(fit_params)

        # get nv positions
        data_pos = Script.load_data(findnv_folder)
        pos = data_pos['maximum_point']
        pos_init = data_pos['initial_point']

        if fit_params is None:
            row = {}
        else:
            row = {'fit_{:d}'.format(k): value for k, value in enumerate(fit_params)}

        row.update({'id': pt_id, 'NV type': nv_type})
        row.update({'x': pos['x'][0], 'y': pos['y'][0]})
        row.update({'xo': pos_init['x'][0], 'yo': pos_init['y'][0]})
        row.update({'B-field (gauss)': get_B_field(nv_type, fit_params)})
        rows.append(row)

    if not rows:
        return None

    # build the frame once (DataFrame.append was removed in pandas 2.0 and re-copied the data per row);
    # columns appear in order of first occurrence, as with append(sort=False)
    return pd.DataFrame(rows)
Beispiel #5
0
        def run(self):
            """
            Code to run fitting routine. Should be run in a separate thread from the gui.

            Loads the ESR data sets below self.filepath and shows them one at a time, reacting to the commands
            the gui puts on self.queue:
                'next'  -- keep the current fit of this NV (if any), save, and go to the next NV
                'clear' -- forget the remembered peak guesses and redraw the raw data
                'fit'   -- fit Lorentzians starting from the selected (or remembered) peaks
                'prev'  -- go back one NV (not below the first)
                'skip'  -- go to the next NV without storing a new fit
                int     -- jump to that NV index (not below the first)
            """
            esr_folders = glob.glob(os.path.join(self.filepath, './data_subscripts/*esr*'))

            self.status.emit('loading data')
            # NOTE(review): the last esr folder is not loaded -- confirm this is intended
            data_array = [Script.load_data(esr_folder) for esr_folder in esr_folders[:-1]]

            self.fits = self.load_fitdata()

            self.status.emit('executing manual fitting')
            index = 0
            self.last_good = []
            self.initial_fit = False
            while index < len(data_array):
                data = data_array[index]
                # a fit made for a previously shown NV must never be stored for this one
                self.single_fit = None
                # NOTE(review): original comment says the status emit must come after the draw command,
                # but it is executed before it -- confirm the required order
                self.status.emit('executing manual fitting NV #' + str(index))
                self.plotwidget.axes.clear()
                self.plotwidget.axes.plot(data['frequency'], data['data'])
                if index in self.fits['nv_id'].values:
                    # overlay the fit stored for this NV
                    fitdf = self.fits.loc[self.fits['nv_id'] == index]
                    offset = fitdf['offset'].values[0]
                    stored_params = np.concatenate(([offset],
                                                    fitdf['fit_width'].values,
                                                    fitdf['fit_amplitude'].values,
                                                    fitdf['fit_center'].values))
                    self.plotwidget.axes.plot(data['frequency'], self.n_lorentzian(data['frequency'], *stored_params))
                self.plotwidget.draw()

                # reuse the peaks accepted for the previous NV as a starting guess
                if self.last_good:
                    self.initial_fit = True
                    self.queue.put('fit')

                while True:
                    if self.queue.empty():
                        time.sleep(.5)
                        continue
                    value = self.queue.get()
                    if value == 'next':
                        while self.peak_vals:
                            self.last_good.append(self.peak_vals.pop(-1))
                        if self.single_fit:
                            # replace all stored peaks of this NV; a boolean mask is used because the appended
                            # rows share index labels, so dropping by label could remove other NVs' rows
                            self.fits = self.fits[self.fits['nv_id'].values != index]
                            self.fits = pd.concat([self.fits] + [pd.DataFrame(peak) for peak in self.single_fit])

                        index += 1
                        self.status.emit('saving')
                        self.save()
                        break
                    elif value == 'clear':
                        self.last_good = []
                        self.plotwidget.axes.clear()
                        self.plotwidget.axes.plot(data['frequency'], data['data'])
                        self.plotwidget.draw()
                    elif value == 'fit':
                        peaks = self.last_good if self.initial_fit else self.peak_vals
                        if len(peaks) == 0:
                            self.single_fit = None
                            self.peak_locs.setText('No Peak')
                            self.plotwidget.axes.plot(data['frequency'],np.repeat(np.mean(data['data']), len(data['frequency'])))
                            self.plotwidget.draw()
                            continue
                        offset = np.mean(data['data'])
                        if len(peaks) > 1:
                            centers, heights = list(zip(*peaks))
                            widths = 1e7 * np.ones(len(heights))
                            amplitudes = offset - np.array(heights)
                            fit_start_params = [offset] + list(np.concatenate((widths, amplitudes, centers)))
                        else:
                            centers, heights = peaks[0]
                            widths = 1e7
                            amplitudes = offset - np.array(heights)
                            fit_start_params = [offset, widths, amplitudes, centers]
                        try:
                            popt = self.fit_n_lorentzian(data['frequency'], data['data'], fit_start_params = fit_start_params)
                        except RuntimeError:
                            print('fit failed, optimal parameters not found')
                            # stay on this NV and wait for new input; leaving the loop would re-queue the same
                            # automatic fit from last_good and fail again indefinitely
                            self.initial_fit = False
                            continue
                        self.plotwidget.axes.clear()
                        self.plotwidget.axes.plot(data['frequency'], data['data'])
                        self.plotwidget.axes.plot(data['frequency'], self.n_lorentzian(data['frequency'], *popt))
                        self.plotwidget.draw()
                        # popt = [offset, widths..., amplitudes..., centers...]; integer division keeps the
                        # slice bounds integral on Python 3
                        params = popt[1:]
                        n_peaks = len(params) // 3
                        widths_array = params[:n_peaks]
                        amplitude_array = params[n_peaks:2 * n_peaks]
                        center_array = params[2 * n_peaks:]
                        self.single_fit = [
                            {'nv_id': [index], 'peak_id': [peak_index], 'offset': [popt[0]],
                             'fit_center': [center], 'fit_amplitude': [amplitude], 'fit_width': [width],
                             'manual_center': [peaks[peak_index][0]], 'manual_height': [peaks[peak_index][1]]}
                            for peak_index, (center, amplitude, width)
                            in enumerate(zip(center_array, amplitude_array, widths_array))
                        ]
                        self.peak_locs.setText('Peak Positions: ' + str(center_array))
                        self.initial_fit = False
                    elif value == 'prev':
                        # a negative index would silently wrap around to the last NV
                        index = max(index - 1, 0)
                        break
                    elif value == 'skip':
                        index += 1
                        break
                    elif type(value) is int:
                        index = max(value, 0)
                        break

            self.finished.emit()
            self.status.emit('dataset finished')
Beispiel #6
0
        def run(self):
            """
            Code to run fitting routine. Should be run in a separate thread from the gui.
            """
            esr_folders = glob.glob(
                os.path.join(self.filepath, './data_subscripts/*esr*'))

            data_array = []
            self.status.emit('loading data')
            for esr_folder in esr_folders[:-1]:
                data = Script.load_data(esr_folder)
                data_array.append(data)

            self.fits = self.load_fitdata()

            self.status.emit('executing manual fitting')
            index = 0
            self.last_good = []
            self.initial_fit = False
            # for data in data_array:
            while index < len(data_array):
                data = data_array[index]
                #this must be after the draw command, otherwise plot doesn't display for some reason
                self.status.emit('executing manual fitting NV #' + str(index))
                self.plotwidget.axes.clear()
                self.plotwidget.axes.plot(data['frequency'], data['data'])
                if index in self.fits['nv_id'].values:
                    fitdf = self.fits.loc[(self.fits['nv_id'] == index)]
                    offset = fitdf['offset'].as_matrix()[0]
                    centers = fitdf['fit_center'].as_matrix()
                    amplitudes = fitdf['fit_amplitude'].as_matrix()
                    widths = fitdf['fit_width'].as_matrix()
                    fit_params = np.concatenate((np.concatenate(
                        (widths, amplitudes)), centers))
                    self.plotwidget.axes.plot(
                        data['frequency'],
                        self.n_lorentzian(
                            data['frequency'],
                            *np.concatenate(([offset], fit_params))))
                self.plotwidget.draw()

                if not self.last_good == []:
                    self.initial_fit = True
                    self.queue.put('fit')

                while (True):
                    if self.queue.empty():
                        time.sleep(.5)
                    else:
                        value = self.queue.get()
                        if value == 'next':
                            while not self.peak_vals == []:
                                self.last_good.append(self.peak_vals.pop(-1))
                            if self.single_fit:
                                to_delete = np.where(
                                    self.fits['nv_id'].values == index)
                                # print(self.fits[to_delete])
                                self.fits = self.fits.drop(
                                    self.fits.index[to_delete])
                                # for val in to_delete[0][::-1]:
                                #     for key in self.fits.keys():
                                #         del self.fits[key][val]
                                for peak in self.single_fit:
                                    self.fits = self.fits.append(
                                        pd.DataFrame(peak))

                            index += 1
                            self.status.emit('saving')
                            self.save()
                            break
                        elif value == 'clear':
                            self.last_good = []
                            self.plotwidget.axes.clear()
                            self.plotwidget.axes.plot(data['frequency'],
                                                      data['data'])
                            self.plotwidget.draw()
                        elif value == 'fit':
                            if self.initial_fit:
                                input = self.last_good
                            else:
                                input = self.peak_vals
                            if len(input) > 1:
                                centers, heights = list(zip(*input))
                                widths = 1e7 * np.ones(len(heights))
                            elif len(input) == 1:
                                centers, heights = input[0]
                                widths = 1e7
                            elif len(input) == 0:
                                self.single_fit = None
                                self.peak_locs.setText('No Peak')
                                self.plotwidget.axes.plot(
                                    data['frequency'],
                                    np.repeat(np.mean(data['data']),
                                              len(data['frequency'])))
                                self.plotwidget.draw()
                                continue
                            offset = np.mean(data['data'])
                            amplitudes = offset - np.array(heights)
                            if len(input) > 1:
                                fit_start_params = [[offset],
                                                    np.concatenate(
                                                        (widths, amplitudes,
                                                         centers))]
                                fit_start_params = [
                                    y for x in fit_start_params for y in x
                                ]
                            elif len(input) == 1:
                                fit_start_params = [
                                    offset, widths, amplitudes, centers
                                ]
                            try:
                                popt = self.fit_n_lorentzian(
                                    data['frequency'],
                                    data['data'],
                                    fit_start_params=fit_start_params)
                            except RuntimeError:
                                print(
                                    'fit failed, optimal parameters not found')
                                break
                            self.plotwidget.axes.clear()
                            self.plotwidget.axes.plot(data['frequency'],
                                                      data['data'])
                            self.plotwidget.axes.plot(
                                data['frequency'],
                                self.n_lorentzian(data['frequency'], *popt))
                            self.plotwidget.draw()
                            params = popt[1:]
                            widths_array = params[:len(params) / 3]
                            amplitude_array = params[len(params) / 3:2 *
                                                     len(params) / 3]
                            center_array = params[2 * len(params) / 3:]
                            positions = list(
                                zip(center_array, amplitude_array,
                                    widths_array))
                            self.single_fit = []
                            peak_index = 0
                            for position in positions:
                                self.single_fit.append({
                                    'nv_id': [index],
                                    'peak_id': [peak_index],
                                    'offset': [popt[0]],
                                    'fit_center': [position[0]],
                                    'fit_amplitude': [position[1]],
                                    'fit_width': [position[2]],
                                    'manual_center': [input[peak_index][0]],
                                    'manual_height': [input[peak_index][1]]
                                })
                                peak_index += 1
                            self.peak_locs.setText('Peak Positions: ' +
                                                   str(center_array))
                            self.initial_fit = False
                        elif value == 'prev':
                            index -= 1
                            break
                        elif value == 'skip':
                            index += 1
                            break
                        elif type(value) is int:
                            index = int(value)
                            break

            self.finished.emit()
            self.status.emit('dataset finished')
        def run(self):
            """Interactively trace ESR lines across the stacked ESR map and save them.

            Loading:
              * if ``self.filepath`` ends with ``.h5``: the ``esr_map`` and
                ``frequency`` datasets are read into memory and the file is
                closed;
              * otherwise every file in ``<filepath>/data_subscripts/`` is loaded
                with ``Script.load_data`` and each trace is divided by its 0.75
                quantile.

            The map is drawn with ``pcolor`` (frequency on the horizontal axis,
            row index on the vertical axis). For each of ``self.NUM_ESR_LINES``
            lines the method polls ``self.queue`` for commands:

              * ``'fit'``   - fit a ``UnivariateSpline`` (row index -> frequency)
                through the points in ``self.peak_vals`` and draw it;
              * ``'next'``  - accept the pending fit (if one was made), empty
                ``self.peak_vals`` in place and advance to the next line;
              * ``'clear'`` - redraw the map and accepted splines, discarding the
                pending fit;
              * ``'prev'`` / ``'skip'`` / an ``int`` - change the line index.

            Afterwards the map plus all accepted splines (``self.interps``) is
            drawn and ``self.save()`` is called.

            Raises:
                FileNotFoundError: if a non-``.h5`` path has no data_subscripts.
            """
            print('_____>>>', self.filepath)
            if str(self.filepath).endswith('.h5'):
                print('loading from .h5')
                # Read the datasets into memory so the HDF5 handle can be closed
                # right away instead of being leaked for the life of the process.
                with h5py.File(self.filepath, 'r') as file:
                    print('UUUU', file.keys())
                    data_esr_norm = file['esr_map'][()]
                    self.frequencies = file['frequency'][()]
            else:
                print('loading from data_subscripts')
                data_esr = []
                data = None
                for subscript_path in sorted(
                        glob.glob(
                            os.path.join(self.filepath,
                                         './data_subscripts/*'))):
                    data = Script.load_data(subscript_path)
                    data_esr.append(data['data'])
                if data is None:
                    raise FileNotFoundError(
                        'no data_subscripts found in {}'.format(self.filepath))
                self.frequencies = data['frequency']

                # normalize each trace
                norm = 'quantile'
                norm_parameter = 0.75
                if norm == 'mean':
                    norm_value = [np.mean(d) for d in data_esr]
                elif norm == 'border':
                    # norm_parameter is used as a slice bound here, so it must
                    # be an int for this mode.
                    if norm_parameter > 0:
                        norm_value = [
                            np.mean(d[0:norm_parameter]) for d in data_esr
                        ]
                    elif norm_parameter < 0:
                        norm_value = [
                            np.mean(d[norm_parameter:]) for d in data_esr
                        ]
                elif norm == 'quantile':
                    norm_value = [
                        np.quantile(d, norm_parameter) for d in data_esr
                    ]

                data_esr_norm = np.array([
                    d / n for d, n in zip(data_esr, norm_value)
                ])  # normalize and convert to numpy array

            angle = np.arange(len(data_esr_norm))
            print('<<<<<<<', self.frequencies.shape, angle.shape,
                  data_esr_norm.shape)

            self.x_range = list(range(0, len(data_esr_norm)))

            def draw_map_with_fits():
                """Clear the axes and draw the map plus every accepted spline.

                Used for both the initial display and 'clear' so that the
                splines (plotted in frequency units) always sit on a map drawn
                in the same frequency coordinates.
                """
                axes = self.plotwidget.axes
                axes.clear()
                axes.pcolor(self.frequencies, angle, data_esr_norm)
                for spline in self.interps:
                    axes.plot(spline(self.x_range), self.x_range)

            self.status.emit('executing manual fitting')
            index = 0
            while index < self.NUM_ESR_LINES:
                #this must be after the draw command, otherwise plot doesn't display for some reason
                self.status.emit('executing manual fitting NV #' + str(index))
                draw_map_with_fits()
                self.plotwidget.draw()

                # Spline fitted for the current line. Only this object may be
                # appended to self.interps; reusing a loop variable here would
                # append stale splines or even file-path strings.
                current_fit = None

                while True:
                    if self.queue.empty():
                        time.sleep(.5)
                        continue
                    value = self.queue.get()
                    if value == 'next':
                        # empty in place: the list object is shared with the
                        # code that fills it
                        while self.peak_vals:
                            self.peak_vals.pop(-1)
                        if current_fit is not None:
                            self.interps.append(current_fit)
                        index += 1
                        break
                    elif value == 'clear':
                        current_fit = None
                        draw_map_with_fits()
                        self.plotwidget.draw()
                    elif value == 'fit':
                        # assumes each entry of peak_vals is a clicked point with
                        # column 0 = horizontal (frequency) and column 1 =
                        # vertical (row index) coordinate - TODO confirm against
                        # the click handler.
                        peak_vals = np.array(self.peak_vals)
                        print('ggggg', peak_vals.shape)
                        # a cubic UnivariateSpline needs at least 4 points
                        if len(peak_vals) < 4:
                            self.status.emit(
                                'select at least 4 points before fitting')
                            continue
                        # sort by row so the spline abscissa is increasing (in
                        # case a point was clicked out of order)
                        order = np.argsort(peak_vals[:, 1], kind='stable')
                        rows = peak_vals[order, 1]
                        freqs = peak_vals[order, 0]
                        current_fit = UnivariateSpline(rows, freqs)
                        self.plotwidget.axes.plot(current_fit(self.x_range),
                                                  self.x_range)
                        self.plotwidget.draw()
                    elif value == 'prev':
                        index -= 1
                        break
                    elif value == 'skip':
                        index += 1
                        break
                    elif type(value) is int:
                        index = int(value)
                        break

            self.finished.emit()
            self.status.emit('saving')
            draw_map_with_fits()
            self.save()
            self.status.emit('saving finished')
        def run(self):
            """Interactively trace ESR lines across the stacked ESR map and save them.

            Every file in ``<filepath>/data_subscripts/`` is loaded with
            ``Script.load_data``. Each trace is divided by its mean and the
            result is shown with ``imshow`` (``origin='lower'``). For each of
            ``self.NUM_ESR_LINES`` lines the method polls ``self.queue`` for
            commands:

              * ``'fit'``   - fit a ``UnivariateSpline`` through the points in
                ``self.peak_vals`` (sorted by their second coordinate) and draw
                it;
              * ``'next'``  - accept the pending fit (if one was made), empty
                ``self.peak_vals`` in place and advance to the next line;
              * ``'clear'`` - redraw the map and accepted splines, discarding the
                pending fit;
              * ``'prev'`` / ``'skip'`` / an ``int`` - change the line index.

            Afterwards the map plus all accepted splines (``self.interps``) is
            drawn and ``self.save()`` is called.

            Raises:
                FileNotFoundError: if no data_subscripts files are found.
            """
            data_esr = []
            data = None
            for subscript_path in sorted(
                    glob.glob(
                        os.path.join(self.filepath, './data_subscripts/*'))):
                data = Script.load_data(subscript_path)
                data_esr.append(data['data'])
            if data is None:
                raise FileNotFoundError(
                    'no data_subscripts found in {}'.format(self.filepath))
            self.frequencies = data['frequency']

            # normalize each trace by its own mean
            data_esr_norm = [d / np.mean(d) for d in data_esr]

            self.x_range = list(range(0, len(data_esr_norm)))

            def draw_map_with_fits():
                """Clear the axes and draw the map plus every accepted spline."""
                axes = self.plotwidget.axes
                axes.clear()
                axes.imshow(data_esr_norm, aspect='auto', origin='lower')
                for spline in self.interps:
                    axes.plot(spline(self.x_range), self.x_range)

            self.status.emit('executing manual fitting')
            index = 0
            while index < self.NUM_ESR_LINES:
                #this must be after the draw command, otherwise plot doesn't display for some reason
                self.status.emit('executing manual fitting NV #' + str(index))
                draw_map_with_fits()
                self.plotwidget.draw()

                # Spline fitted for the current line. Only this object may be
                # appended to self.interps; reusing a loop variable here would
                # append stale splines or even file-path strings.
                current_fit = None

                while True:
                    if self.queue.empty():
                        time.sleep(.5)
                        continue
                    value = self.queue.get()
                    if value == 'next':
                        # empty in place: the list object is shared with the
                        # code that fills it
                        while self.peak_vals:
                            self.peak_vals.pop(-1)
                        if current_fit is not None:
                            self.interps.append(current_fit)
                        index += 1
                        break
                    elif value == 'clear':
                        current_fit = None
                        draw_map_with_fits()
                        self.plotwidget.draw()
                    elif value == 'fit':
                        # a cubic UnivariateSpline needs at least 4 points
                        if len(self.peak_vals) < 4:
                            self.status.emit(
                                'select at least 4 points before fitting')
                            continue
                        # assumes each entry is a clicked (horizontal, vertical)
                        # point - TODO confirm against the click handler. Sort by
                        # the vertical coordinate so the spline abscissa increases.
                        peak_vals = sorted(self.peak_vals,
                                           key=lambda tup: tup[1])
                        y, x = list(zip(*peak_vals))
                        current_fit = UnivariateSpline(np.array(x), np.array(y))
                        self.plotwidget.axes.plot(current_fit(self.x_range),
                                                  self.x_range)
                        self.plotwidget.draw()
                    elif value == 'prev':
                        index -= 1
                        break
                    elif value == 'skip':
                        index += 1
                        break
                    elif type(value) is int:
                        index = int(value)
                        break

            self.finished.emit()
            self.status.emit('saving')
            draw_map_with_fits()
            self.save()
            self.status.emit('saving finished')
        def run(self):
            """Interactively trace ESR lines across the stacked ESR map and save them.

            Loading:
              * if ``self.filepath`` ends with ``.h5``: the ``esr_map`` and
                ``frequency`` datasets are read into memory and the file is
                closed;
              * otherwise every file in ``<filepath>/data_subscripts/`` is loaded
                with ``Script.load_data`` and each trace is divided by its 0.75
                quantile.

            The map is drawn with ``pcolor`` (frequency on the horizontal axis,
            row index on the vertical axis). For each of ``self.NUM_ESR_LINES``
            lines the method polls ``self.queue`` for commands:

              * ``'fit'``   - fit a ``UnivariateSpline`` (row index -> frequency)
                through the points in ``self.peak_vals`` and draw it;
              * ``'next'``  - accept the pending fit (if one was made), empty
                ``self.peak_vals`` in place and advance to the next line;
              * ``'clear'`` - redraw the map and accepted splines, discarding the
                pending fit;
              * ``'prev'`` / ``'skip'`` / an ``int`` - change the line index.

            Afterwards the map plus all accepted splines (``self.interps``) is
            drawn and ``self.save()`` is called.

            Raises:
                FileNotFoundError: if a non-``.h5`` path has no data_subscripts.
            """
            print('_____>>>', self.filepath)
            if str(self.filepath).endswith('.h5'):
                print('loading from .h5')
                # Read the datasets into memory so the HDF5 handle can be closed
                # right away instead of being leaked for the life of the process.
                with h5py.File(self.filepath, 'r') as file:
                    print('UUUU', file.keys())
                    data_esr_norm = file['esr_map'][()]
                    self.frequencies = file['frequency'][()]
            else:
                print('loading from data_subscripts')
                data_esr = []
                data = None
                for subscript_path in sorted(
                        glob.glob(
                            os.path.join(self.filepath,
                                         './data_subscripts/*'))):
                    data = Script.load_data(subscript_path)
                    data_esr.append(data['data'])
                if data is None:
                    raise FileNotFoundError(
                        'no data_subscripts found in {}'.format(self.filepath))
                self.frequencies = data['frequency']

                # normalize each trace
                norm = 'quantile'
                norm_parameter = 0.75
                if norm == 'mean':
                    norm_value = [np.mean(d) for d in data_esr]
                elif norm == 'border':
                    # norm_parameter is used as a slice bound here, so it must
                    # be an int for this mode.
                    if norm_parameter > 0:
                        norm_value = [
                            np.mean(d[0:norm_parameter]) for d in data_esr
                        ]
                    elif norm_parameter < 0:
                        norm_value = [
                            np.mean(d[norm_parameter:]) for d in data_esr
                        ]
                elif norm == 'quantile':
                    norm_value = [
                        np.quantile(d, norm_parameter) for d in data_esr
                    ]

                data_esr_norm = np.array([
                    d / n for d, n in zip(data_esr, norm_value)
                ])  # normalize and convert to numpy array

            angle = np.arange(len(data_esr_norm))
            print('<<<<<<<', self.frequencies.shape, angle.shape,
                  data_esr_norm.shape)

            self.x_range = list(range(0, len(data_esr_norm)))

            def draw_map_with_fits():
                """Clear the axes and draw the map plus every accepted spline.

                Used for both the initial display and 'clear' so that the
                splines (plotted in frequency units) always sit on a map drawn
                in the same frequency coordinates.
                """
                axes = self.plotwidget.axes
                axes.clear()
                axes.pcolor(self.frequencies, angle, data_esr_norm)
                for spline in self.interps:
                    axes.plot(spline(self.x_range), self.x_range)

            self.status.emit('executing manual fitting')
            index = 0
            while index < self.NUM_ESR_LINES:
                #this must be after the draw command, otherwise plot doesn't display for some reason
                self.status.emit('executing manual fitting NV #' + str(index))
                draw_map_with_fits()
                self.plotwidget.draw()

                # Spline fitted for the current line. Only this object may be
                # appended to self.interps; reusing a loop variable here would
                # append stale splines or even file-path strings.
                current_fit = None

                while True:
                    if self.queue.empty():
                        time.sleep(.5)
                        continue
                    value = self.queue.get()
                    if value == 'next':
                        # empty in place: the list object is shared with the
                        # code that fills it
                        while self.peak_vals:
                            self.peak_vals.pop(-1)
                        if current_fit is not None:
                            self.interps.append(current_fit)
                        index += 1
                        break
                    elif value == 'clear':
                        current_fit = None
                        draw_map_with_fits()
                        self.plotwidget.draw()
                    elif value == 'fit':
                        # assumes each entry of peak_vals is a clicked point with
                        # column 0 = horizontal (frequency) and column 1 =
                        # vertical (row index) coordinate - TODO confirm against
                        # the click handler.
                        peak_vals = np.array(self.peak_vals)
                        print('ggggg', peak_vals.shape)
                        # a cubic UnivariateSpline needs at least 4 points
                        if len(peak_vals) < 4:
                            self.status.emit(
                                'select at least 4 points before fitting')
                            continue
                        # sort by row so the spline abscissa is increasing (in
                        # case a point was clicked out of order)
                        order = np.argsort(peak_vals[:, 1], kind='stable')
                        rows = peak_vals[order, 1]
                        freqs = peak_vals[order, 0]
                        current_fit = UnivariateSpline(rows, freqs)
                        self.plotwidget.axes.plot(current_fit(self.x_range),
                                                  self.x_range)
                        self.plotwidget.draw()
                    elif value == 'prev':
                        index -= 1
                        break
                    elif value == 'skip':
                        index += 1
                        break
                    elif type(value) is int:
                        index = int(value)
                        break

            self.finished.emit()
            self.status.emit('saving')
            draw_map_with_fits()
            self.save()
            self.status.emit('saving finished')
def visualize_magnetic_fields(src_folder, target_folder, manual=True):
    """Plot the fitted NV magnetic fields over the scan image and save a .jpg.

    Args:
        src_folder: data folder. ``os.path.dirname(src_folder)`` must contain
            ``'./'``; the part after it is the sub-path of ``target_folder``
            where the fit results are stored.
        target_folder: root folder of the processed data.
        manual: if True, read ``<basename>/data-manual.csv`` and apply the
            manual corrections from its ``manual_nv_type`` / ``manual_B_field``
            columns; otherwise read ``<basename>.csv``.

    The left panel shows the scan image with NVs colored by their ``NV type``
    (split NVs are annotated with their id). The right panel plots the
    splittings of the split NVs against x, with a twin axis in Gauss. The
    figure is saved as ``<filepath>/<basename(src_folder)>.jpg``.
    """
    filepath = os.path.join(target_folder,
                            os.path.dirname(src_folder).split('./')[1])
    basename = os.path.basename(src_folder)

    # load the fit data
    if manual:
        # os.path.join instead of a hard-coded '\\' so this also works on POSIX
        filename = os.path.join(basename, 'data-manual.csv')
    else:
        filename = '{:s}.csv'.format(basename)
    df = pd.read_csv(os.path.join(filepath, filename),
                     index_col='id',
                     skipinitialspace=True)
    # presumably a leftover index column from an earlier DataFrame.to_csv
    df = df.drop(columns='Unnamed: 0')

    # include manual corrections; iterate by index label (the 'id' column),
    # not by position, so non-contiguous ids update the right rows
    if manual:
        for nv_id, nv_type in df['manual_nv_type'].items():
            if not pd.isnull(nv_type):
                df.at[nv_id, 'NV type'] = nv_type
                if nv_type == 'split':
                    df.at[nv_id, 'B-field (gauss)'] = df.at[nv_id,
                                                            'manual_B_field']

    # load the image data
    select_points_data = Script.load_data(get_select_points(src_folder))
    image_data = select_points_data['image_data']
    extent = select_points_data['extent']

    # prepare figure: one row, two equally wide panels
    fig = plt.figure(figsize=(15, 8))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])

    # plot the map with the NVs colored by type
    ax0.imshow(image_data, extent=extent, interpolation='nearest', cmap='pink')
    for nv_type, color in zip(['split', 'bad', 'no_peak', 'single'],
                              ['g', 'k', 'b', 'r']):
        subset = df[df['NV type'] == nv_type]
        ax0.scatter(subset['xo'], subset['yo'], c=color, label=nv_type)
    ax0.legend(bbox_to_anchor=(2.3, 1))

    split = df[df['NV type'] == 'split']
    corr = +0.005  # small offset so the label sits roughly on the marker
    for x, y, nv_id in zip(split['xo'], split['yo'], split.index):
        ax0.annotate(str(nv_id), xy=(x + corr, y + corr), fontsize=8,
                     color='w')
    ax0.set_xlabel('x (V)')
    ax0.set_ylabel('y (V)')

    # plot the splittings of the split NVs on a 1D plot (field -> MHz)
    splitting_mhz = split['B-field (gauss)'] / freq_to_mag * 1e-6
    ax1.plot(split['xo'], splitting_mhz, 'o')
    ax1.set_title('ESR Splittings')
    ax1.set_xlabel('x coordinate (V)')
    ax1.set_ylabel('Splitting (MHz)')
    ax1.set_xlim([extent[0], extent[1]])

    # twin axis showing the same range converted back to Gauss
    ax2 = ax1.twinx()
    mn, mx = ax1.get_ylim()
    ax2.set_ylim(mn * freq_to_mag * 1e6, mx * freq_to_mag * 1e6)
    ax2.set_ylabel('Magnetic Field Projection (Gauss)')

    corr = 0.0005  # small offset so the label sits roughly on the marker
    for x, y, nv_id in zip(split['xo'], splitting_mhz, split.index):
        ax1.annotate(str(nv_id), xy=(x + corr, y + corr))

    fig.savefig(os.path.join(filepath, '{:s}.jpg'.format(basename)),
                bbox_inches='tight',
                pad_inches=0)
def manual_correction(folder, target_folder, fit_data_set, nv_type_manual,
                      b_field_manual, queue, current_id_queue,
                      lower_peak_widget, upper_peak_widget, lower_fit_widget,
                      upper_fit_widget):
    """
    Backend code to display and fit ESRs, then once input has been received from front-end, incorporate the data into the
    current data set

    Args:
        folders: folder containing the data to analyze
        target_folder: target for processed data
        fit_data_set: starting point data set containing automatically fitted esrs
        nv_type_manual: pointer to empty array to populate with nv type manual corrections
        b_field_manual: pointer to empty array to populate with b field manual corrections
        queue: queue for communication with front-end in separate thread
        lower_peak_widget: widget containing frequency value of lower peak
        upper_peak_widget: widget containing frequency value of upper peak

    Poststate: populates fit_data_set with manual corrections
    """

    lower_peak_manual = [np.nan] * len(fit_data_set)
    upper_peak_manual = [np.nan] * len(fit_data_set)

    filepath = os.path.join(target_folder, folder[2:])
    data_filepath = os.path.join(filepath, 'data-manual.csv')
    if os.path.exists(data_filepath):
        prev_data = pd.read_csv(data_filepath)
        if 'manual_peak_1' in list(prev_data.keys()):
            for i in range(0, len(prev_data['manual_B_field'])):
                b_field_manual[i] = prev_data['manual_B_field'][i]
                nv_type_manual[i] = prev_data['manual_nv_type'][i]
            lower_peak_manual = prev_data['manual_peak_1']
            upper_peak_manual = prev_data['manual_peak_2']

    #TODO: Add saving as you go, add ability to start at arbitrary NV, add ability to specify a next NV number, eliminate peak/correct -> only have 'accept fit'

    try:

        print('STARTING')

        fit_data_set_array = fit_data_set.as_matrix()

        w = widgets.HTML(
            "Event information appears here when you click on the figure")
        display(w)

        # loop over all the folders in the data_subscripts subfolder and retrieve fitparameters and position of NV
        esr_folders = glob.glob(
            os.path.join(folder, '.\\data_subscripts\\*esr*'))

        # create folder to save images to
        # filepath_image = os.path.join(target_folder, os.path.dirname(folder).split('./')[1])
        # image_folder = os.path.join(filepath_image, '{:s}\\images'.format(os.path.basename(folder)))
        image_folder = os.path.join(target_folder,
                                    '{:s}\\images'.format(folder[2:]))
        # image_folder = os.path.normpath(
        #     os.path.abspath(os.path.join(os.path.join(target_folder, 'images'), os.path.basename(folders[0]))))
        if not os.path.exists(image_folder):
            os.makedirs(image_folder)
        if not os.path.exists(os.path.join(image_folder, 'bad_data')):
            os.makedirs(os.path.join(image_folder, 'bad_data'))

        f = plt.figure(figsize=(12, 6))

        def onclick(event):
            if event.button == 1:
                if event.key == 'control':
                    lower_fit_widget.value = event.xdata
                else:
                    lower_peak_widget.value = event.xdata
            elif event.button == 3:
                if event.key == 'control':
                    upper_fit_widget.value = event.xdata
                else:
                    upper_peak_widget.value = event.xdata

        cid = f.canvas.mpl_connect('button_press_event', onclick)

        data_array = []
        data_pos_array = []
        for esr_folder in esr_folders:
            print(esr_folder)
            sys.stdout.flush()
            data = Script.load_data(esr_folder)
            data_array.append(data)
            print('looping')
            sys.stdout.flush()

        nv_folders = glob.glob(folder + '\\data_subscripts\\*find_nv*pt_*')
        for nv_folder in nv_folders:
            data_pos_array.append(Script.load_data(nv_folder))

        while True:

            # for i, esr_folder in enumerate(esr_folders):

            i = current_id_queue.queue[0]
            if i >= len(data_array):
                break

            lower_fit_widget.value = 0
            upper_fit_widget.value = 10e9

            lower_peak_widget.value = 2.87e9
            upper_peak_widget.value = 0

            def display_data(pt_id,
                             lower_peak_widget=None,
                             upper_peak_widget=None,
                             display_fit=True):
                # find the NV index
                # pt_id = int(os.path.basename(esr_folder).split('pt_')[-1])

                # findnv_folder = glob.glob(folder + '\\data_subscripts\\*find_nv*pt_*{:d}'.format(pt_id))[0]

                f.clf()
                gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
                ax0 = plt.subplot(gs[0])
                ax1 = plt.subplot(gs[1])
                ax = [ax0, ax1]
                plt.suptitle('NV #{:d}'.format(pt_id), fontsize=16)

                # load data
                data = data_array[i]
                if lower_fit_widget.value == 0 and upper_fit_widget.value == 10e9:
                    freq = data['frequency']
                    ampl = data['data']
                else:
                    freq = data['frequency'][np.logical_and(
                        data['frequency'] > lower_fit_widget.value,
                        data['frequency'] < upper_fit_widget.value)]
                    ampl = data['data'][np.logical_and(
                        data['frequency'] > lower_fit_widget.value,
                        data['frequency'] < upper_fit_widget.value)]
                if lower_peak_widget is None:
                    fit_params = fit_data_set_array[pt_id, 2:8]
                else:
                    lower_peak = lower_peak_widget.value
                    upper_peak = upper_peak_widget.value
                    if upper_peak == 0:
                        start_vals = get_lorentzian_fit_starting_values(
                            freq, ampl)
                        start_vals[2] = lower_peak
                        start_vals[1] = ampl[np.argmin(
                            np.abs(freq - lower_peak))] - start_vals[0]
                        try:
                            fit_params = fit_lorentzian(
                                freq,
                                ampl,
                                starting_params=start_vals,
                                bounds=[(0, -np.inf, 0, 0),
                                        (np.inf, 0, np.inf, np.inf)])
                        except:
                            # ESR fit failed!
                            fit_params = [
                                np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
                            ]

                    else:
                        center_freq = np.mean(freq)
                        start_vals = []
                        start_vals.append(
                            get_lorentzian_fit_starting_values(
                                freq[freq < center_freq],
                                ampl[freq < center_freq]))
                        start_vals.append(
                            get_lorentzian_fit_starting_values(
                                freq[freq > center_freq],
                                ampl[freq > center_freq]))
                        start_vals = [
                            np.mean([start_vals[0][0],
                                     start_vals[1][0]]),  # offset
                            np.sum([start_vals[0][3],
                                    start_vals[1][3]]),  # FWHM
                            ampl[np.argmin(np.abs(freq - lower_peak))] -
                            start_vals[0][0],
                            ampl[np.argmin(np.abs(freq - upper_peak))] -
                            start_vals[1][0],  # amplitudes
                            lower_peak,
                            upper_peak  # centers
                        ]
                        try:
                            fit_params = fit_double_lorentzian(
                                freq,
                                ampl,
                                starting_params=start_vals,
                                bounds=[(0, 0, -np.inf, -np.inf, min(freq),
                                         min(freq)),
                                        (np.inf, np.inf, 0, 0, max(freq),
                                         max(freq))])
                        except:
                            fit_params = [
                                np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
                            ]

                if len(fit_params) == 4 or np.isnan(fit_params[4]):
                    fit_params = fit_params[0:4]

                # get nv positions
                #             data_pos = {'initial_point': [fit_data_set['xo'].values[pt_id]]}
                data_pos = data_pos_array[i]
                #             pos = data_pos['maximum_point']
                #             pos_init = data_pos['initial_point']

                # plot NV image
                FindNV.plot_data([ax[1]], data_pos)

                # plot data and fits
                # print("fit_params: ", fit_params)

                sys.stdout.flush()

                if display_fit:
                    plot_esr(ax[0],
                             data['frequency'],
                             data['data'],
                             fit_params=fit_params)
                else:
                    plot_esr(ax[0],
                             data['frequency'],
                             data['data'],
                             fit_params=[
                                 np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
                             ])

                plt.tight_layout()
                plt.subplots_adjust(
                    top=0.85)  # Makes room at top of plot for figure suptitle

                plt.draw()
                plt.show()

                return fit_params, pt_id

            fit_params, pt_id = display_data(i)
            if len(fit_params) == 6:
                lower_peak_widget.value = fit_params[4]
                upper_peak_widget.value = fit_params[5]
            elif len(fit_params) == 4:
                lower_peak_widget.value = fit_params[2]
                upper_peak_widget.value = 0

            while True:
                if queue.empty():
                    time.sleep(.5)
                else:
                    value = queue.get()
                    if value == -1:
                        fit_params, point_id = display_data(
                            i,
                            lower_peak_widget=lower_peak_widget,
                            upper_peak_widget=upper_peak_widget)
                        if len(fit_params) == 6:
                            lower_peak_widget.value = fit_params[4]
                            upper_peak_widget.value = fit_params[5]
                        elif len(fit_params) == 4:
                            lower_peak_widget.value = fit_params[2]
                            upper_peak_widget.value = 0
                        continue
                    elif value == -2:
                        display_data(display_fit=False)
                        fit_params = [
                            np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
                        ]
                        lower_fit_widget.value = 0
                        upper_fit_widget.value = 10e9
                    else:
                        break

            if nv_type_manual[i] == 'split':
                if np.isnan(fit_params[0]):
                    lower_peak_manual[i] = lower_peak_widget.value
                    upper_peak_manual[i] = upper_peak_widget.value
                    b_field_manual[i] = (
                        (upper_peak_widget.value - lower_peak_widget.value) /
                        5.6e6)
                elif len(fit_params) == 4:
                    lower_peak_manual[i] = fit_params[2]
                    b_field_manual[i] = (np.abs(2.87e9 - fit_params[2]) /
                                         2.8e6)
                else:
                    lower_peak_manual[i] = fit_params[4]
                    upper_peak_manual[i] = fit_params[5]
                    b_field_manual[i] = ((fit_params[5] - fit_params[4]) /
                                         5.6e6)
            elif nv_type_manual[i] == 'single':
                if np.isnan(fit_params[0]):
                    lower_peak_manual[i] = lower_peak_widget.value
                    b_field_manual[i] = 0
                else:
                    lower_peak_manual[i] = fit_params[2]
                    b_field_manual[i] = 0

            if nv_type_manual[i] == '':
                if fit_params is None:
                    f.savefig(
                        os.path.join(os.path.join(image_folder, 'bad_data'),
                                     'esr_pt_{:02d}.jpg'.format(pt_id)))
                else:
                    f.savefig(
                        os.path.join(image_folder,
                                     'esr_pt_{:02d}.jpg'.format(pt_id)))
            else:
                if nv_type_manual[i] in ['bad', 'no_split']:
                    f.savefig(
                        os.path.join(os.path.join(image_folder, 'bad_data'),
                                     'esr_pt_{:02d}.jpg'.format(pt_id)))
                else:
                    f.savefig(
                        os.path.join(image_folder,
                                     'esr_pt_{:02d}.jpg'.format(pt_id)))

            if not os.path.exists(filepath):
                os.makedirs(filepath)
            fit_data_set['manual_B_field'] = b_field_manual
            fit_data_set['manual_nv_type'] = nv_type_manual
            fit_data_set['manual_peak_1'] = lower_peak_manual
            fit_data_set['manual_peak_2'] = upper_peak_manual
            fit_data_set.to_csv(data_filepath)

        f.canvas.mpl_disconnect(cid)
        fit_data_set['manual_B_field'] = b_field_manual
        fit_data_set['manual_nv_type'] = nv_type_manual
        fit_data_set['manual_peak_1'] = lower_peak_manual
        fit_data_set['manual_peak_2'] = upper_peak_manual

        # filepath = os.path.join(target_folder, folder[2:])
        # data_filepath = os.path.join(filepath, 'data-manual.csv')
        # filename = '{:s}\\data-manual.csv'.format(os.path.basename(folder))
        # filepath = os.path.join(target_folder, os.path.dirname(folder).split('./')[1])
        # data_filepath = os.path.join(filepath, filename)

        if not os.path.exists(filepath):
            os.makedirs(filepath)

        fit_data_set.to_csv(data_filepath)

        create_shortcut(os.path.abspath(os.path.join(filepath, 'to_data.lnk')),
                        os.path.abspath(folder))
        create_shortcut(
            os.path.join(os.path.abspath(folder), 'to_processed.lnk'),
            os.path.abspath(filepath))

        print('DONE!')

    except Exception as e:
        print(e)
        raise
def autofit_esrs(folder):
    """
    Fit all ESR measurements of a scan and collect the results in a DataFrame.

    For every ``*esr*`` subfolder of ``<folder>/data_subscripts`` the ESR data is
    loaded and fitted with ``fit_esr``, the NV is classified with ``get_nv_type``
    and the B-field is computed with ``get_B_field``. The NV position is read from
    the matching ``*find_nv*pt_*<id>`` subfolder. No user interaction takes place.

    Args:
        folder: source folder with esr data; this folder should contain a subfolder
            data_subscripts which contains subfolders *esr* with the esr data and
            *find_nv* subfolders with the corresponding NV positions

    Returns:
        pandas DataFrame with one row per ESR folder (columns fit_0 ... fit_N,
        'id', 'NV type', 'x', 'y', 'xo', 'yo', 'B-field (gauss)'), or None if
        no ESR folder was found

    Raises:
        IndexError: if no find_nv folder exists for one of the ESR point ids
        ValueError: if an ESR folder name does not end in 'pt_<integer>'
    """
    subscripts = os.path.join(folder, 'data_subscripts')

    def _trailing_id(path):
        """Return the integer after the last 'pt_' in the folder name, or None."""
        try:
            return int(os.path.basename(path).split('pt_')[-1])
        except ValueError:
            return None

    # sorted so that the row order of the result does not depend on the file system
    esr_folders = sorted(glob.glob(os.path.join(subscripts, '*esr*')))

    frames = []
    for esr_folder in esr_folders:
        # find the NV index (raises ValueError for malformed folder names)
        pt_id = int(os.path.basename(esr_folder).split('pt_')[-1])

        # the wildcard between 'pt_' and the id also matches points whose index
        # merely ends in the same digits (e.g. pt_15 for id 5), so prefer folders
        # whose trailing number equals pt_id and only fall back to the raw matches
        candidates = sorted(glob.glob(
            os.path.join(subscripts, '*find_nv*pt_*{:d}'.format(pt_id))))
        exact = [c for c in candidates if _trailing_id(c) == pt_id]
        matches = exact or candidates
        if not matches:
            raise IndexError('no find_nv folder found for point {:d} in {:s}'.format(
                pt_id, subscripts))
        findnv_folder = matches[0]

        # load data and fit the ESR spectrum
        data = Script.load_data(esr_folder)
        fit_params = fit_esr(data['frequency'], data['data'])
        # per the original notes the categories are split / single / no_peak / no_nv / na
        nv_type = get_nv_type(fit_params)

        # get nv positions ('maximum_point' and 'initial_point' of the find_nv run)
        data_pos = Script.load_data(findnv_folder)
        pos = data_pos['maximum_point']
        pos_init = data_pos['initial_point']

        # fit_esr may return None (presumably a failed fit); then no fit_* columns
        if fit_params is None:
            row = {}
        else:
            row = {'fit_{:d}'.format(k): p for k, p in enumerate(fit_params)}
        row.update({
            'id': pt_id,
            'NV type': nv_type,
            'x': pos['x'][0],
            'y': pos['y'][0],
            'xo': pos_init['x'][0],
            'yo': pos_init['y'][0],
            'B-field (gauss)': get_B_field(nv_type, fit_params),
        })

        # one single-row frame per point; DataFrame.append was removed in pandas 2.0,
        # so all rows are concatenated once at the end (also avoids quadratic copying)
        frames.append(pd.DataFrame.from_dict({k: [v] for k, v in row.items()}))

    if not frames:
        return None
    return pd.concat(frames, ignore_index=True, sort=False)