def __init__(self, datadir=None, window_title='Data browser'):
        """Create the data-browser main window.

        Args:
            datadir (str or None): directory added as the first data directory.
                Defaults to ``DataSet.default_io.base_location``.
            window_title (str): title shown in the window's title bar.

        If no Qt application exists yet, one is created here and its event loop
        is started at the end. In that case this constructor blocks until the
        application exits.
        """
        # Reuse a running Qt application if there is one; otherwise create one
        # and remember that this constructor is responsible for exec().
        instance_ready = True
        self.app = QtCore.QCoreApplication.instance()
        if self.app is None:
            instance_ready = False
            self.app = QtWidgets.QApplication([])

        # NOTE(review): super(QtWidgets.QMainWindow, self) begins the MRO lookup
        # *after* QMainWindow, i.e. QMainWindow.__init__ itself is skipped;
        # confirm this is intended rather than a plain super().__init__().
        super(QtWidgets.QMainWindow, self).__init__()
        self.app.setStyleSheet(qdarkstyle.load_stylesheet_pyqt5())
        self.setupUi(self)
        # Apply the requested title (the argument was previously ignored).
        self.setWindowTitle(window_title)

        if datadir is None:
            datadir = DataSet.default_io.base_location

        # set-up tree view for data
        self._treemodel = QtGui.QStandardItemModel()
        self.data_view.setModel(self._treemodel)
        self.tabWidget.addTab(QtWidgets.QWidget(), 'metadata')

        # Local (non-remote) QtPlot; its pyqtgraph window is embedded below.
        self.qplot = QtPlot(remote=False)
        self.plotwindow = self.qplot.win
        self.qplot.max_len = 10

        self.horizontalLayout_4.addWidget(self.plotwindow)

        # splitter_2 gets two equal halves of the window height; splitter is
        # divided 1:2 over the window width.
        self.splitter_2.setSizes([int(self.height() / 2)] * 2)
        self.splitter.setSizes([int(self.width() / 3), int(2 * self.width() / 3)])

        # connect callbacks
        self.data_view.doubleClicked.connect(self.logCallback)
        self.actionReload_data.triggered.connect(self.updateLogs)
        self.actionPreload_all_info.triggered.connect(self.loadInfo)
        self.actionAdd_directory.triggered.connect(self.selectDirectory)
        self.filter_button.clicked.connect(lambda: self.updateLogs(filter_str=self.filter_input.text()))
        self.send_ppt.clicked.connect(self.pptCallback)
        self.copy_im.clicked.connect(self.clipboardCallback)
        self.split_data.clicked.connect(self.split_dataset)

        # initialize defaults
        self.extensions = ['dat', 'hdf5']
        self.dataset = None
        self.datatag = None
        self.datadirlist = []
        self.datadirindex = 0
        self.color_list = [pg.mkColor(cl) for cl in qcodes.plots.pyqtgraph.color_cycle]
        self.current_params = []
        self.subpl_ind = dict()
        self.splitted = False

        # add default directory
        self.addDirectory(datadir)
        self.updateLogs()

        # Show the window; only run the event loop if this constructor created
        # the application (an existing app already has its own loop).
        self.show()
        if not instance_ready:
            self.app.exec()
Example #2
0
    def test_simple_plot(self):
        """Draw the same 2 MHz cosine into four subplots of a fresh QtPlot."""
        main_QtPlot = QtPlot(
            window_title='Main plotmon of TestQtPlot',
            figsize=(600, 400))

        # 10 us of a 2 MHz cosine sampled every 1 ns.
        frequency = 2e6
        sample_times = np.arange(0, 10e-6, 1e-9)
        amplitude = np.cos(2 * np.pi * frequency * sample_times)

        for subplot_index in range(1, 5):
            main_QtPlot.add(
                x=sample_times, y=amplitude,
                xlabel='Time', xunit='s',
                ylabel='Amplitude', yunit='V',
                subplot=subplot_index,
                symbol='o', symbolSize=5)
Example #3
0
def plot_wave_raw(wave_raw, samplerate=None, station=None):
    ''' Plot the raw wave against time.

    Arguments:
        wave_raw (array): raw data which represents the waveform
        samplerate (float or None): spacing between consecutive samples on the
            time axis. Despite its name, it is used as a sample *interval*:
            it is the step of the time axis. If None, it is computed as
            ``1 / station.awg.getattr('AWG_clock')``.
        station (QCoDeS station or None): station whose ``awg`` provides
            ``AWG_clock``; only used when ``samplerate`` is None.

    Returns:
        plot (QtPlot): the plot showing the data

    Raises:
        ValueError: if neither ``samplerate`` nor ``station`` is given.
    '''
    if samplerate is None:
        if station is None:
            raise ValueError('There is no station')
        samplerate = 1 / station.awg.getattr('AWG_clock')
    # Build the time axis from integer sample indices. np.arange with a float
    # step can return one element too many due to rounding (e.g.
    # arange(0, 3*0.1, 0.1) has 4 elements), which would make the setpoint
    # array longer than the data.
    horz_var = np.arange(len(wave_raw)) * samplerate
    x = DataArray(name='time(s)',
                  label='time (s)',
                  preset_data=horz_var,
                  is_setpoint=True)
    y = DataArray(label='sweep value (mV)',
                  preset_data=wave_raw,
                  set_arrays=(x, ))
    plot = QtPlot(x, y)

    return plot
    def render_wave(self, wave_name, show=True, QtPlot_win=None):
        '''
        Plot the stored waveform ``wave_name`` against time.

        Args:
            wave_name (str): key into ``self._wave_dict``.
            show (bool): accepted for interface compatibility; not used here.
            QtPlot_win (QtPlot or None): plot to draw into. When None, a new
                600x400 QtPlot titled ``wave_name`` is created.

        Returns:
            The QtPlot the wave was added to.
        '''
        samples = self._wave_dict[wave_name]
        # Sample index -> seconds with a fixed 1e-9 s spacing.
        times = np.arange(len(samples)) * 1e-9

        if QtPlot_win is None:
            QtPlot_win = QtPlot(window_title=wave_name, figsize=(600, 400))

        QtPlot_win.add(
            x=times,
            y=samples,
            name=wave_name,
            symbol='o',
            symbolSize=5,
            xlabel='Time',
            xunit='s',
            ylabel='Amplitude',
            yunit='V')
        return QtPlot_win
    def __init__(self, name,
                 plotting_interval=0.25,
                 live_plot_enabled=True, verbose=True):
        """Create the measurement-control instrument and its parameters.

        Args:
            name (str): instrument name.
            plotting_interval (float): plot update interval in seconds
                (validated >= 0.001). Applied when the live plot windows are
                created.
            live_plot_enabled (bool): if True, create the main and secondary
                QtPlot windows now.
            verbose (bool): initial value of the ``verbose`` parameter.
        """
        super().__init__(name=name, server_name=None)
        # Soft average is currently only available for "hard"
        # measurements. It does not work with adaptive measurements.
        self.add_parameter('soft_avg',
                           label='Number of soft averages',
                           parameter_class=ManualParameter,
                           vals=vals.Ints(1, int(1e8)),
                           initial_value=1)
        self.add_parameter('verbose',
                           parameter_class=ManualParameter,
                           vals=vals.Bool(),
                           initial_value=verbose)
        self.add_parameter('live_plot_enabled',
                           parameter_class=ManualParameter,
                           vals=vals.Bool(),
                           initial_value=live_plot_enabled)
        self.add_parameter('plotting_interval',
                           units='s',
                           vals=vals.Numbers(min_value=0.001),
                           set_cmd=self._set_plotting_interval,
                           get_cmd=self._get_plotting_interval)
        self.add_parameter('persist_mode',
                           vals=vals.Bool(),
                           parameter_class=ManualParameter,
                           initial_value=True)

        # pyqtgraph plotting process is reused for different measurements.
        if self.live_plot_enabled():
            self.main_QtPlot = QtPlot(
                windowTitle='Main plotmon of {}'.format(self.name),
                figsize=(600, 400))
            self.secondary_QtPlot = QtPlot(
                windowTitle='Secondary plotmon of {}'.format(self.name),
                figsize=(600, 400))
            # The constructor argument was previously accepted but never used.
            # It is applied only here because the setter presumably writes to
            # the plot windows created just above.
            # TODO confirm against _set_plotting_interval.
            self.plotting_interval(plotting_interval)

        self.soft_iteration = 0  # used as a counter for soft_avg
        self._persist_dat = None
        self._persist_xlabs = None
        self._persist_ylabs = None
Example #6
0
    def __init__(self, datadir=None, window_title='Data browser', default_parameter='amlitude'):
        """Build the data-browser widget and populate it from disk.

        Args:
            datadir (str or None): directory to scan for experiments. Defaults
                to ``qcodes.DataSet.default_io.base_location``. As a side
                effect, ``qcodes.DataSet.default_io`` is replaced by a
                ``qcodes.DiskIO`` for this directory.
            window_title (str): title of the widget window.
            default_parameter (str): name of the array to plot preferentially.
                NOTE(review): the default 'amlitude' looks like a typo of
                'amplitude'.
        """
        super(DataViewer, self).__init__()

        self.default_parameter = default_parameter
        self.datadir = qcodes.DataSet.default_io.base_location if datadir is None else datadir

        qcodes.DataSet.default_io = qcodes.DiskIO(self.datadir)
        logging.info('DataViewer: data directory %s', self.datadir)

        # Header label plus a read-only tree of measurements.
        self.text = QtWidgets.QLabel()
        self.text.setText('Log files at %s' % self.datadir)
        self.logtree = QtWidgets.QTreeView()
        self.logtree.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
        self._treemodel = QtGui.QStandardItemModel()
        self.logtree.setModel(self._treemodel)
        self.__debug = {}

        # The QtPlot object itself is placed in the layout, so it is assumed
        # to be a QWidget here. TODO confirm for the QtPlot version in use.
        self.qplot = QtPlot()
        self.plotwindow = self.qplot

        layout = QtWidgets.QVBoxLayout()
        for child in (self.text, self.logtree, self.plotwindow):
            layout.addWidget(child)
        self.setLayout(layout)

        self._treemodel.setHorizontalHeaderLabels(['Log', 'Comments'])
        self.setWindowTitle(window_title)
        self.logtree.header().resizeSection(0, 240)

        # Entries are read-only; double-clicking one loads and plots it.
        self.logtree.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)
        self.logtree.doubleClicked.connect(self.logCallback)

        # Populate the tree from the data directory.
        self.updateLogs()
def show_analysis_values_pyqt(a_obj, QtPlot_win=None):
    """Plot every measured-value trace of an analysis object in its own subplot.

    Args:
        a_obj: analysis object providing ``measured_values`` (iterated, one y
            series per trace), ``sweep_points`` (shared x values), ``ylabels``
            (indexed per trace) and ``xlabel``.
        QtPlot_win (QtPlot or None): plot to draw into. When None (the
            default), a new 600x400 'Analysis viewer' QtPlot is created.

    Returns:
        The QtPlot containing the traces. Every subplot's x-axis is linked
        to the first subplot.
    """
    if QtPlot_win is None:
        QtPlot_win = QtPlot(window_title='Analysis viewer', figsize=(600, 400))
    for i, y in enumerate(a_obj.measured_values):
        # Trace i goes to subplot i + 1; start a new layout row when that
        # subplot does not exist yet.
        if i + 1 > len(QtPlot_win.subplots):
            QtPlot_win.win.nextRow()
        QtPlot_win.add(x=a_obj.sweep_points,
                       y=y,
                       name=a_obj.ylabels[i],
                       subplot=i + 1,
                       symbol='o',
                       symbolSize=5,
                       xlabel=a_obj.xlabel,
                       ylabel=a_obj.ylabels[i])
    # Link all x-axes to the first subplot so they pan/zoom together.
    first = QtPlot_win.subplots[0]
    for subplot in QtPlot_win.subplots[1:]:
        subplot.setXLink(first)
    return QtPlot_win
    def render_wave(
        self,
        wave_name,
        time_units="s",
        reload_pulses: bool = True,
        render_distorted_wave: bool = True,
        QtPlot_win=None,
    ):
        """
        Renders a waveform.

        Args:
            wave_name (str): key into ``self._wave_dict``. For the distorted
                version, also a key into ``self._wave_dict_dist``.
            time_units (str): "s" plots against sample index divided by
                ``self.sampling_rate()``; "lut_index" plots against the raw
                sample index.
            reload_pulses (bool): if True, call
                ``self.generate_standard_waveforms()`` before reading the wave.
            render_distorted_wave (bool): if True, also plot
                ``self._wave_dict_dist[wave_name]``. A warning is logged if
                that entry is missing.
            QtPlot_win (QtPlot or None): plot to draw into; a new 600x400
                QtPlot titled ``wave_name`` is created when None.

        Returns:
            The QtPlot the wave(s) were added to.

        Raises:
            ValueError: if ``time_units`` is not "s" or "lut_index".
        """
        # Validate before any side effect. Previously an unknown unit only
        # failed later, with an UnboundLocalError on ``xlab``.
        if time_units == "lut_index":
            xlab = ("Lookuptable index", "i")
        elif time_units == "s":
            xlab = ("Time", "s")
        else:
            raise ValueError(
                "time_units must be 's' or 'lut_index', got {!r}".format(time_units)
            )

        if reload_pulses:
            self.generate_standard_waveforms()

        x = np.arange(len(self._wave_dict[wave_name]))
        y = self._wave_dict[wave_name]
        if time_units == "s":
            x = x / self.sampling_rate()

        if QtPlot_win is None:
            QtPlot_win = QtPlot(window_title=wave_name, figsize=(600, 400))

        if render_distorted_wave:
            if wave_name in self._wave_dict_dist:
                y2 = self._wave_dict_dist[wave_name]
                x2 = np.arange(len(y2))
                if time_units == "s":
                    x2 = x2 / self.sampling_rate()
                QtPlot_win.add(
                    x=x2,
                    y=y2,
                    name=wave_name + " distorted",
                    symbol="o",
                    symbolSize=5,
                    xlabel=xlab[0],
                    xunit=xlab[1],
                    ylabel="Amplitude",
                    yunit="dac val.",
                )
            else:
                log.warning("Wave not in distorted wave dict")
        # Plotting the normal one second ensures it is on top.
        QtPlot_win.add(
            x=x,
            y=y,
            name=wave_name,
            symbol="o",
            symbolSize=5,
            xlabel=xlab[0],
            xunit=xlab[1],
            ylabel="Amplitude",
            yunit="V",
        )

        return QtPlot_win
    def set_sweep_with_calibration(self,
                                   repetition=False,
                                   plot_average=False,
                                   **kw):
        """Prepare plots and sweeps, and optionally a calibrated repetition loop.

        Args:
            repetition (bool): only when this is exactly ``True``,
                ``self.Loop`` is built as a ``Loop`` over ``Count`` values
                1..count that measures ``calibrated_parameter`` and
                ``self.dig``.
            plot_average (bool): stored on ``self.plot_average``. When True,
                a MatPlot per qubit is also appended to ``self.average_plot``.
            **kw: ``count`` (default 1) is popped when ``repetition`` is True.

        Returns:
            True

        Raises:
            TypeError: if ``repetition`` is True and ``self.Loop`` is not None.
        """

        self.plot_average = plot_average

        # One raw-data plot (and optionally one average plot) per qubit.
        for i in range(self.qubit_number):
            # NOTE(review): maps i == 1 -> 1 and every other i -> 2 (so i == 0
            # -> 2, and i >= 2 also -> 2); confirm this labelling is intended.
            d = i if i == 1 else 2
            self.plot.append(
                QtPlot(window_title='raw plotting qubit_%d' % d, remote=False))
            #            self.plot.append(MatPlot(figsize = (8,5)))
            if plot_average:
                #                self.average_plot.append(QtPlot(window_title = 'average data qubit_%d'%d, remote = False))
                self.average_plot.append(MatPlot(figsize=(8, 5)))

        # Configure every sequencer's sweep and build its segment list.
        for seq in range(len(self.sequencer)):
            self.sequencer[seq].set_sweep()
            self.make_all_segment_list(seq)

        if repetition is True:
            count = kw.pop('count', 1)

            # Presumably a single-point sweep (value 1) whose setter runs
            # self.function.
            Count_Calibration = StandardParameter(name='Count',
                                                  set_cmd=self.function)
            Sweep_Count_Calibration = Count_Calibration[1:2:1]

            # Repetition counter 1..count; each set calls self.delete.
            Count = StandardParameter(name='Count', set_cmd=self.delete)
            Sweep_Count = Count[1:count + 1:1]

            # set_digitizer returns the (possibly reconfigured) digitizer and
            # the acquisition parameter used in the loops below.
            self.digitizer, self.dig = set_digitizer(self.digitizer,
                                                     len(self.X_sweep_array),
                                                     self.qubit_number,
                                                     self.seq_repetition)

            loop1 = Loop(sweep_values=Sweep_Count_Calibration).each(self.dig)

            # NOTE(review): Loop_Calibration is built here but never used
            # later in this method.
            Loop_Calibration = loop1.with_bg_task(
                bg_final_task=self.update_calibration, )
            # NOTE(review): this reads a bare name 'update_calibration' (module
            # scope), not self.update_calibration as used just above; verify
            # that such a module-level name exists.
            calibrated_parameter = update_calibration
            if self.Loop is None:
                self.Loop = Loop(sweep_values=Sweep_Count).each(
                    calibrated_parameter, self.dig)
            else:
                raise TypeError('calibration not set')

        return True
Example #10
0
def show_element_pyqt(element,
                      QtPlot_win=None,
                      color_idx=None,
                      channels=('ch1', 'ch2', 'ch3', 'ch4')):
    """Plot the waveforms of ``element``, one channel per subplot.

    Args:
        element: object whose ``waveforms()`` returns ``(t_vals, outputs_dict)``,
            both indexed by channel name.
        QtPlot_win (QtPlot or None): plot to draw into; a new 600x400
            'Seq_plot' QtPlot is created when None.
        color_idx (int or None): if given, every channel uses
            ``color_cycle[color_idx]``. Otherwise channel i uses
            ``color_cycle[i % len(color_cycle)]``.
        channels (str or sequence of str): channel name(s) to plot. A single
            string is treated as one channel. The default is an immutable
            tuple, so it cannot be shared and mutated across calls.

    Returns:
        The QtPlot. Every subplot's x-axis is linked to the first subplot.
    """
    if QtPlot_win is None:
        QtPlot_win = QtPlot(windowTitle='Seq_plot', figsize=(600, 400))
    # FIXME: Add a legend
    t_vals, outputs_dict = element.waveforms()
    if isinstance(channels, str):
        channels = [channels]

    xlabel = 'Time'
    xunit = 's'
    yunit = 'V'
    for i, ch in enumerate(channels):
        if color_idx is None:
            color = color_cycle[i % len(color_cycle)]
        else:
            color = color_cycle[color_idx]
        # Channel i goes to subplot i + 1; start a new layout row when that
        # subplot does not exist yet.
        if i + 1 > len(QtPlot_win.subplots):
            QtPlot_win.win.nextRow()
        QtPlot_win.add(x=t_vals[ch],
                       y=outputs_dict[ch],
                       name=ch,
                       color=color,
                       subplot=i + 1,
                       symbol='o',
                       symbolSize=5,
                       xlabel=xlabel,
                       xunit=xunit,
                       ylabel='Output ch {}'.format(ch),
                       yunit=yunit)
    # links all the x-axes
    first = QtPlot_win.subplots[0]
    for subplot in QtPlot_win.subplots[1:]:
        subplot.setXLink(first)
    return QtPlot_win
Example #11
0
def show_num(id, useQT=False, **kwargs):
    """
    Show and return plot and data for id in current instrument.

    Args:
        id (int): counter of the dataset to load. It is formatted as a
            zero-padded three-digit string into the location provider's format.
        useQT (bool): use pyqtgraph (QtPlot) as an alternative to Matplotlib.
        **kwargs: Are passed to the plot constructor.

    Returns:
        tuple: ``(data, plots)``, the loaded dataset and a list with one plot
        per array whose name does not contain "set".

    Raises:
        RuntimeError: if the experiment has not been initialized.
    """
    # CURRENT_EXPERIMENT is a dict (it is subscripted below), so the flag must
    # be read with .get(). getattr() on a dict always returned the default,
    # so this check never fired.
    if not CURRENT_EXPERIMENT.get("init", True):
        raise RuntimeError("Experiment not initalized. "
                           "use qc.Init(mainfolder, samplename)")

    str_id = '{0:03d}'.format(id)

    t = qc.DataSet.location_provider.fmt.format(counter=str_id)
    data = qc.load_data(t)

    plots = []
    for value in data.arrays.keys():
        # Setpoint arrays are not plotted on their own.
        if "set" in value:
            continue
        title = "{} #{}".format(CURRENT_EXPERIMENT["sample_name"], str_id)
        if useQT:
            plot = QtPlot(
                getattr(data, value),
                fig_x_position=CURRENT_EXPERIMENT['plot_x_position'],
                **kwargs)
            plot.subplots[0].setTitle(title)
            plot.subplots[0].showGrid(True, True)
        else:
            plot = MatPlot(getattr(data, value), **kwargs)
            plot.subplots[0].set_title(title)
            plot.subplots[0].grid()
        plots.append(plot)
    return data, plots
Example #12
0
def show_num(id, useQT=False, do_plots=True):
    """
    Show and return plot and data for id in current instrument.

    Args:
        id (int): counter of the dataset to load. It is formatted as a
            zero-padded three-digit string into the location provider's format.
        useQT (bool): use pyqtgraph (QtPlot) instead of Matplotlib (MatPlot).
        do_plots (bool): default True. If False, no plots are produced and
            ``plots`` is None.

    Returns:
        tuple: ``(data, plots)``, the loaded dataset and either a list with one
        plot per array whose name does not contain "set", or None.

    Raises:
        RuntimeError: if the experiment has not been initialized.
    """
    # CURRENT_EXPERIMENT is a dict (it is subscripted below), so the flag must
    # be read with .get(). getattr() on a dict always returned the default,
    # so this check never fired.
    if not CURRENT_EXPERIMENT.get("init", True):
        raise RuntimeError("Experiment not initalized. "
                           "use qc.Init(mainfolder, samplename)")

    str_id = '{0:03d}'.format(id)

    t = qc.DataSet.location_provider.fmt.format(counter=str_id)
    data = qc.load_data(t)

    if not do_plots:
        return data, None

    plots = []
    for value in data.arrays.keys():
        # Setpoint arrays are not plotted on their own.
        if "set" in value:
            continue
        title = "{} #{}".format(CURRENT_EXPERIMENT["sample_name"], str_id)
        if useQT:
            plot = QtPlot(
                getattr(data, value),
                fig_x_position=CURRENT_EXPERIMENT['plot_x_position'])
            plot.subplots[0].setTitle(title)
            plot.subplots[0].showGrid(True, True)
        else:
            plot = MatPlot(getattr(data, value))
            plot.subplots[0].set_title(title)
            plot.subplots[0].grid()
        plots.append(plot)
    return data, plots
Example #13
0
    def render_wave(self, wave_name, show=True, time_units='s',
                    reload_pulses: bool=True, render_distorted_wave: bool=True,
                    QtPlot_win=None):
        """
        Renders a waveform.

        Args:
            wave_name (str): key into ``self._wave_dict``. For the distorted
                version, also a key into ``self._wave_dict_dist``.
            show (bool): accepted for interface compatibility; not used here.
            time_units (str): 's' plots against sample index divided by
                ``self.sampling_rate()``; 'lut_index' plots against the raw
                sample index.
            reload_pulses (bool): if True, call
                ``self.generate_standard_waveforms()`` first.
            render_distorted_wave (bool): if True, also plot
                ``self._wave_dict_dist[wave_name]``. A warning is logged if
                that entry is missing.
            QtPlot_win (QtPlot or None): plot to draw into; a new 600x400
                QtPlot titled ``wave_name`` is created when None.

        Returns:
            The QtPlot the wave(s) were added to.

        Raises:
            ValueError: if ``time_units`` is not 's' or 'lut_index'.
        """
        # Validate before any side effect. Previously an unknown unit only
        # failed later, with an UnboundLocalError on ``xlab``.
        if time_units == 'lut_index':
            xlab = ('Lookuptable index', 'i')
        elif time_units == 's':
            xlab = ('Time', 's')
        else:
            raise ValueError(
                "time_units must be 's' or 'lut_index', got {!r}".format(time_units))

        if reload_pulses:
            self.generate_standard_waveforms()

        x = np.arange(len(self._wave_dict[wave_name]))
        y = self._wave_dict[wave_name]
        if time_units == 's':
            x = x/self.sampling_rate()

        if QtPlot_win is None:
            QtPlot_win = QtPlot(window_title=wave_name,
                                figsize=(600, 400))

        if render_distorted_wave:
            if wave_name in self._wave_dict_dist:
                y2 = self._wave_dict_dist[wave_name]
                x2 = np.arange(len(y2))
                if time_units == 's':
                    x2 = x2/self.sampling_rate()
                QtPlot_win.add(
                    x=x2, y=y2, name=wave_name+' distorted',
                    symbol='o', symbolSize=5,
                    xlabel=xlab[0], xunit=xlab[1], ylabel='Amplitude',
                    yunit='dac val.')
            else:
                logging.warning('Wave not in distorted wave dict')
        # Plotting the normal one second ensures it is on top.
        QtPlot_win.add(
            x=x, y=y, name=wave_name,
            symbol='o', symbolSize=5,
            xlabel=xlab[0], xunit=xlab[1], ylabel='Amplitude', yunit='V')

        return QtPlot_win
Example #14
0
    def setupMeasurementWindows(station=None,
                                create_parameter_widget=True,
                                ilist=None,
                                qtplot_remote=True):
        """
        Create liveplot window and parameter widget (optional)

        Args:
            station (QCoDeS station): station with instruments
            create_parameter_widget (bool): if True create ParameterWidget
            ilist (None or list): list of instruments to add to ParameterWidget.
                Defaults to ``[station.gates]``.
            qtplot_remote (bool): If True, then use remote plotting
        Returns:
            dict: created gui objects: 'plotwindow' always, and
            'parameterviewer' when a parameter widget was created.
        """
        windows = {}

        # Windows are placed relative to the last entry of monitorSizes().
        # vv is used as (x, y, width, height); presumably the last monitor.
        ms = monitorSizes()
        vv = ms[-1]

        if create_parameter_widget and any([station, ilist]):
            if ilist is None:
                ilist = [station.gates]
            w = qtt.createParameterWidget(ilist)
            w.setGeometry(vv[0] + vv[2] - 400 - 300, vv[1], 300, 600)
            windows['parameterviewer'] = w

        plotQ = QtPlot(window_title='Live plot',
                       interval=.5,
                       remote=qtplot_remote)
        plotQ.setGeometry(vv[0] + vv[2] - 600, vv[1] + vv[3] - 400, 600, 400)
        plotQ.update()

        # Register as the module-wide live plot window.
        qtt.gui.live_plotting.liveplotwindow = plotQ

        windows['plotwindow'] = plotQ

        # With remote plotting there may be no QApplication in this process;
        # only pump pending events when one exists.
        app = QtWidgets.QApplication.instance()
        if app is not None:
            app.processEvents()

        return windows
Example #15
0
step_end = 110
baseline = 0
# Call kf.get_all_sampled_vector with the step parameters
# (baseline, step_start, step_end) and keep the result in output_dict.
# NOTE(review): convolution, step_width_ns, points_per_ns and step_start are
# not defined in the lines visible here; presumably set earlier in the script.
output_dict = kf.get_all_sampled_vector(convolution,
                                        step_width_ns,
                                        points_per_ns,
                                        step_params={
                                            'baseline': baseline,
                                            'step_start': step_start,
                                            'step_end': step_end
                                        })

# Reuse the plot window `vw` if clearing it succeeds. On any exception
# (e.g. NameError when `vw` does not exist yet on the first run), create a
# new QtPlot instead.
try:
    vw.clear()
except Exception:
    from qcodes.plots.pyqtgraph import QtPlot
    vw = QtPlot(windowTitle='Seq_plot', figsize=(600, 400))


def contains_nan(array):
    """Return whether ``array`` holds at least one NaN (as a NumPy bool)."""
    nan_mask = np.isnan(array)
    return np.any(nan_mask)


def load_exp_model():
    """
    Prepares a triple exponential model and its parameters
    Output:
            model
            parameters
    """
    triple_pole_mod = fit_mods.TripleExpDecayModel
    triple_pole_mod.set_param_hint('amp1', value=0.05, vary=True)
Example #16
0
 def test_return_handle(self):
     """add() must return the item it created in the first subplot."""
     plot = QtPlot(remote=False)
     handle = plot.add([1, 2, 3])
     first_item = plot.subplots[0].items[0]
     self.assertIs(handle, first_item)
Example #17
0
class DataViewer(QtWidgets.QWidget):

    ''' Simple viewer for Qcodes data

    Arguments
    ---------

        datadir (string or None): directory to scan for experiments
        window_title (string): title of the widget window
        default_parameter (string): name of default parameter to plot
    '''

    def __init__(self, datadir=None, window_title='Data browser', default_parameter='amlitude'):
        ''' Build the widget, point qcodes' default IO at ``datadir`` and scan it.

        NOTE(review): the default 'amlitude' looks like a typo of 'amplitude'.
        It is kept unchanged so existing callers see the same default.
        '''
        super(DataViewer, self).__init__()

        self.default_parameter = default_parameter

        if datadir is None:
            datadir = qcodes.DataSet.default_io.base_location
        self.datadir = datadir

        # Replace qcodes' global default IO. logCallback later passes relative
        # tags to qcodes.load_data, presumably resolved against this directory.
        qcodes.DataSet.default_io = qcodes.DiskIO(datadir)
        logging.info('DataViewer: data directory %s' % datadir)

        # setup GUI
        self.text = QtWidgets.QLabel()
        self.text.setText('Log files at %s' %
                          self.datadir)
        self.logtree = QtWidgets.QTreeView()  # QTreeWidget
        self.logtree.setSelectionBehavior(
            QtWidgets.QAbstractItemView.SelectRows)
        self._treemodel = QtGui.QStandardItemModel()
        self.logtree.setModel(self._treemodel)
        self.__debug = dict()
        self.qplot = QtPlot()
        self.plotwindow = self.qplot

        vertLayout = QtWidgets.QVBoxLayout()
        vertLayout.addWidget(self.text)
        vertLayout.addWidget(self.logtree)
        vertLayout.addWidget(self.plotwindow)
        self.setLayout(vertLayout)

        self._treemodel.setHorizontalHeaderLabels(['Log', 'Comments'])
        self.setWindowTitle(window_title)
        self.logtree.header().resizeSection(0, 240)

        # disable edit
        self.logtree.setEditTriggers(
            QtWidgets.QAbstractItemView.NoEditTriggers)
        self.logtree.doubleClicked.connect(self.logCallback)

        # get logs from disk
        self.updateLogs()

    def updateLogs(self):
        ''' Update the list of measurements.

        Rescans ``self.datadir`` with ``findfilesR(..., '.*dat')`` and groups
        the results into ``self.logs[datetag][logtag] = path``. The two tags
        are the two path components directly above the file. The tree model
        is then rebuilt: newest date tag first, each with one child row per
        log tag. Existing rows are removed first, so repeated calls neither
        duplicate entries nor span the wrong rows.
        '''
        model = self._treemodel
        dd = findfilesR(self.datadir, '.*dat')
        print('found %d files' % (len(dd)))

        logs = dict()
        for d in dd:
            tags = d.split(os.sep)[-3:-1]
            # Paths too shallow to yield a (date, log) pair are skipped.
            if len(tags) != 2:
                continue
            datetag, logtag = tags
            logs.setdefault(datetag, dict())[logtag] = d
        self.logs = logs

        # Drop rows from a previous scan; header labels are kept.
        model.removeRows(0, model.rowCount())

        for i, datetag in enumerate(sorted(logs, reverse=True)):
            parent1 = QtGui.QStandardItem(datetag)
            for logtag in sorted(logs[datetag]):
                child1 = QtGui.QStandardItem(logtag)
                child2 = QtGui.QStandardItem('info about plot')
                # Column 2 holds the relative tag that logCallback loads.
                child3 = QtGui.QStandardItem(os.path.join(datetag, logtag))
                parent1.appendRow([child1, child2, child3])
            model.appendRow(parent1)
            # span container columns
            self.logtree.setFirstColumnSpanned(
                i, self.logtree.rootIndex(), True)

    def plot_parameter(self, data):
        ''' Return the name of the array to plot from ``data``.

        Preference order:
        1. an array named exactly ``self.default_parameter``;
        2. the first array whose name ends with ``self.default_parameter``;
        3. the first array whose name ends with 'amplitude';
        4. the first array of the dataset.

        Returns None if the dataset has no arrays.
        '''
        arraynames = list(data.arrays.keys())
        if self.default_parameter in arraynames:
            return self.default_parameter
        # Previously the suffix test used the literal string
        # 'default_parameter' instead of the attribute's value.
        for suffix in (self.default_parameter, 'amplitude'):
            matches = [v for v in arraynames if v.endswith(suffix)]
            if matches:
                return matches[0]
        return arraynames[0] if arraynames else None

    def logCallback(self, index):
        ''' Load and plot the dataset of a double-clicked log entry.

        Reads the tag stored in column 2 of the clicked row. If there is one,
        it loads that dataset, writes the list of its arrays into column 1 and
        plots the array chosen by ``plot_parameter``. Errors while loading or
        plotting are printed and logged rather than raised, since this runs
        as a Qt slot.
        '''
        logging.info('logCallback!')
        logging.debug('logCallback: index %s' % str(index))
        self.__debug['last'] = index
        pp = index.parent()
        row = index.row()

        tag = pp.child(row, 2).data()

        # Rows without a tag (presumably the top-level date rows) are ignored.
        if tag is None:
            return

        print('logCallback! tag %s' % tag)
        try:
            logging.debug('load tag %s' % tag)
            data = qcodes.load_data(tag)

            self.qplot.clear()

            infotxt = 'arrays: ' + ', '.join(list(data.arrays.keys()))
            q = pp.child(row, 1).model()
            q.setData(pp.child(row, 1), infotxt)

            param_name = self.plot_parameter(data)

            if param_name is not None:
                logging.info(
                    'using parameter %s for plotting' % param_name)
                self.qplot.add(getattr(data, param_name))
            else:
                logging.info('could not find parameter for DataSet')
        except Exception as e:
            # Deliberate best effort: a broken dataset must not kill the GUI.
            print('logCallback! error ...')
            print(e)
            logging.warning(e)
Example #18
0
    def __init__(self,
                 data_directory=None,
                 window_title='Data browser',
                 default_parameter='amplitude',
                 extensions=None,
                 verbose=1):
        """ Constructs a simple viewer for Qcodes data.

        Args:
            data_directory (string or None): The directory to scan for experiments.
                Defaults to ``qcodes.DataSet.default_io.base_location``.
            window_title (string): The title of the main window.
            default_parameter (string): A name of default parameter to plot.
            extensions (list): A list with the data file extensions to filter.
                Defaults to ['dat', 'hdf5'].
            verbose (int): The logging verbosity level. At >= 2 a message is
                printed once the GUI has been built.
        """
        super(DataViewer, self).__init__()
        if extensions is None:
            extensions = ['dat', 'hdf5']

        self.verbose = verbose
        self.default_parameter = default_parameter
        # Two directory slots (Select Dir1 / Dir2 in the Folder menu);
        # directory_index presumably selects the active one.
        self.data_directories = [None] * 2
        self.directory_index = 0
        if data_directory is None:
            data_directory = qcodes.DataSet.default_io.base_location

        self.extensions = extensions

        # setup GUI
        self.dataset = None
        self.text = QtWidgets.QLabel()

        # logtree
        self.logtree = QtWidgets.QTreeView()
        self.logtree.setSelectionBehavior(
            QtWidgets.QAbstractItemView.SelectRows)
        self._treemodel = QtGui.QStandardItemModel()
        self.logtree.setModel(self._treemodel)

        # metatabs
        self.meta_tabs = QtWidgets.QTabWidget()
        self.meta_tabs.addTab(QtWidgets.QWidget(), 'metadata')

        self.__debug = dict()
        # QtPlot is either a QWidget subclass, which is embedded directly, or
        # a qcodes-style plot that exposes its pyqtgraph window as ``.win``.
        # QtPlot is a class, so it must be tested with issubclass; the former
        # isinstance(QtPlot, QWidget) check was always False.
        if isinstance(QtPlot, type) and issubclass(QtPlot, QWidget):
            self.qplot = QtPlot()
        else:
            self.qplot = QtPlot(remote=False)
        if isinstance(self.qplot, QWidget):
            self.plotwindow = self.qplot
        else:
            self.plotwindow = self.qplot.win

        topLayout = QtWidgets.QHBoxLayout()

        self.filterbutton = QtWidgets.QPushButton()
        self.filterbutton.setText('Filter data')
        self.filtertext = QtWidgets.QLineEdit()
        self.outCombo = QtWidgets.QComboBox()

        topLayout.addWidget(self.text)
        topLayout.addWidget(self.filterbutton)
        topLayout.addWidget(self.filtertext)

        treesLayout = QtWidgets.QHBoxLayout()
        treesLayout.addWidget(self.logtree)
        treesLayout.addWidget(self.meta_tabs)

        vertLayout = QtWidgets.QVBoxLayout()

        vertLayout.addItem(topLayout)
        vertLayout.addItem(treesLayout)
        vertLayout.addWidget(self.plotwindow)

        self.pptbutton = QtWidgets.QPushButton()
        self.pptbutton.setText('Send data to powerpoint')
        self.clipboardbutton = QtWidgets.QPushButton()
        self.clipboardbutton.setText('Copy image to clipboard')

        bLayout = QtWidgets.QHBoxLayout()
        bLayout.addWidget(self.outCombo)
        bLayout.addWidget(self.pptbutton)
        bLayout.addWidget(self.clipboardbutton)

        vertLayout.addItem(bLayout)
        widget = QtWidgets.QWidget()
        widget.setLayout(vertLayout)
        self.setCentralWidget(widget)

        self.setWindowTitle(window_title)
        self.logtree.header().resizeSection(0, 280)

        # disable edit
        self.logtree.setEditTriggers(
            QtWidgets.QAbstractItemView.NoEditTriggers)

        self.set_data_directory(data_directory)
        self.logtree.doubleClicked.connect(self.log_callback)
        self.outCombo.currentIndexChanged.connect(self.combobox_callback)
        self.filterbutton.clicked.connect(
            lambda: self.update_logs(filter_str=self.filtertext.text()))
        self.pptbutton.clicked.connect(self.ppt_callback)
        self.clipboardbutton.clicked.connect(self.clipboard_callback)

        menuBar = self.menuBar()

        # NOTE(review): QAction.triggered emits a 'checked' bool. Bound methods
        # below (e.g. update_logs) may receive it as their first argument;
        # confirm their signatures tolerate that.
        menuDict = {
            '&Data': {
                '&Reload Data': self.update_logs,
                '&Preload all Info': self.load_info,
                '&Quit': self.close
            },
            '&Folder': {
                '&Select Dir1': lambda: self.select_directory(index=0),
                'Select &Dir2': lambda: self.select_directory(index=1),
                '&Toggle Dirs': self.toggle_data_directory
            },
            '&Help': {
                '&Info': self.show_help
            }
        }
        for (k, menu) in menuDict.items():
            mb = menuBar.addMenu(k)
            for (kk, action) in menu.items():
                act = QtWidgets.QAction(kk, self)
                mb.addAction(act)
                act.triggered.connect(action)

        if self.verbose >= 2:
            print('created gui...')

        # get logs from disk
        self.update_logs()
        self.datatag = None

        self.logtree.setColumnHidden(2, True)
        self.logtree.setColumnHidden(3, True)
        self.show()
Example #19
0
def show_num(ids, samplefolder=None, useQT=False, avg_sub='', do_plots=True, savepng=True,
             fig_size=(6, 4), clim=None, dataname=None, xlim=None, ylim=None, **kwargs):
    """
    Show and return plot and data.

    Args:
        ids (number, list): id or list of ids of dataset(s)
        samplefolder (str): Sample folder if loading data from a different sample than the
            initialized one. If None, the folder of the current experiment is used.
        useQT (boolean): If true plots with QtPlot instead of matplotlib. Only a single
            dataset can be plotted with QtPlot.
        avg_sub (str: 'col' or 'row'): Subtracts the average from either each column ('col')
            or each row ('row'). The loaded data arrays are modified in place.
        do_plots (boolean): if false no plots are produced.
        dataname (str): If given only plots dataset with that name
        savepng (boolean): If true saves matplotlib figure as png
        fig_size [6, 4]: Figure size in inches
        clim [cmin, cmax]: Set min and max of colorbar to cmin and cmax respectively
        xlim [xmin, xmax]: Set limits on x axis
        ylim [ymin, ymax]: Set limits on y axis
        **kwargs: Are passed to plot function

    Returns:
        data, plots : the list of loaded datasets and the list of created plots
            (plots is None when do_plots is False)

    Raises:
        ValueError: if useQT is True and more than one id is plotted.
        RuntimeError: if dataname is not a (non-setpoint) array of a loaded dataset.
    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc
    from collections.abc import Iterable

    if not isinstance(ids, Iterable):
        ids = (ids,)
    # materialize so len() and indexing (ids[0], ids[-1]) work for any iterable
    ids = list(ids)

    data_list = []
    keys_list = []

    # Define samplefolder
    if samplefolder is None:
        check_experiment_is_initialized()
        samplefolder = qc.DataSet.location_provider.fmt.format(counter='')

    if do_plots and useQT and len(ids) > 1:
        raise ValueError('qcodes.QtPlot does not support multigraph plotting. Set useQT=False to plot multiple datasets.')

    # Load all datasets into list
    for dataset_id in ids:
        path = samplefolder + '{0:03d}'.format(dataset_id)
        data = qc.load_data(path)
        data_list.append(data)

        # find datanames to be plotted (setpoint arrays are never plotted)
        if do_plots:
            plottable = [key for key in data.arrays.keys() if "_set" not in key]
            if dataname is not None:
                if dataname not in plottable:
                    raise RuntimeError(
                        "Dataname not in dataset. Input dataname was: '{}', "
                        "while dataname(s) in dataset are: {}.".format(
                            dataname, ', '.join(plottable)))
                keys = [dataname]
            else:
                keys = plottable
            keys_list.append(keys)

    if not do_plots:
        return data_list, None

    # order-preserving dedupe: a set would give a hash-seed dependent order and
    # therefore non-reproducible png file numbering between runs
    unique_keys = list(dict.fromkeys(key for keys in keys_list for key in keys))
    plots = []
    num = ''
    n_keys = len(unique_keys)

    for j, key in enumerate(unique_keys):
        array_list = []
        xlims = [[], []]
        ylims = [[], []]
        clims = [[], []]
        # Find datasets containing data with dataname == key
        for data, keys in zip(data_list, keys_list):
            if key not in keys:
                continue
            arrays = getattr(data, key)
            if avg_sub == 'row':
                for i in range(np.shape(arrays)[0]):
                    arrays[i, :] -= np.nanmean(arrays[i, :])
            if avg_sub == 'col':
                for i in range(np.shape(arrays)[1]):
                    arrays[:, i] -= np.nanmean(arrays[:, i])
            array_list.append(arrays)

            # Find axis limits for dataset
            if len(arrays.set_arrays) == 2:
                xlims[0].append(np.nanmin(arrays.set_arrays[1]))
                xlims[1].append(np.nanmax(arrays.set_arrays[1]))
                ylims[0].append(np.nanmin(arrays.set_arrays[0]))
                ylims[1].append(np.nanmax(arrays.set_arrays[0]))
                clims[0].append(np.nanmin(arrays.ndarray))
                clims[1].append(np.nanmax(arrays.ndarray))
            else:
                xlims[0].append(np.nanmin(arrays.set_arrays[0]))
                xlims[1].append(np.nanmax(arrays.set_arrays[0]))
                ylims[0].append(np.nanmin(arrays.ndarray))
                ylims[1].append(np.nanmax(arrays.ndarray))

        if useQT:
            plot = QtPlot(array_list[0],
                          fig_x_position=CURRENT_EXPERIMENT['plot_x_position'],
                          **kwargs)
            title = "{} #{}".format(CURRENT_EXPERIMENT["sample_name"],
                                    '{}'.format(ids[0]))
            plot.subplots[0].setTitle(title)
            plot.subplots[0].showGrid(True, True)
            if savepng:
                print('Save plot only working for matplotlib figure.',
                      'Set useQT=False to save png.')
        else:
            plot = MatPlot(array_list, **kwargs)
            plot.rescale_axis()
            plot.fig.tight_layout(pad=3)
            plot.fig.set_size_inches(fig_size)
            # Set axis limits
            if xlim is None:
                plot[0].axes.set_xlim([np.nanmin(xlims[0]), np.nanmax(xlims[1])])
            else:
                plot[0].axes.set_xlim(xlim)
            if ylim is None:
                plot[0].axes.set_ylim([np.nanmin(ylims[0]), np.nanmax(ylims[1])])
            else:
                plot[0].axes.set_ylim(ylim)
            # colour limits only apply to 2D data; the last dataset added decides
            if len(array_list[-1].set_arrays) == 2:
                for i in range(len(array_list)):
                    if clim is None:
                        plot[0].get_children()[i].set_clim(np.nanmin(clims[0]), np.nanmax(clims[1]))
                    else:
                        plot[0].get_children()[i].set_clim(clim)

            # Set figure titles
            plot.fig.suptitle(samplefolder)
            if len(ids) < 6:
                plot.subplots[0].set_title(', '.join(map(str, ids)))
            else:
                plot.subplots[0].set_title(' - '.join(map(str, [ids[0], ids[-1]])))
            plt.draw()

            # Save figure
            if savepng:
                if len(ids) == 1:
                    title_png = samplefolder + CURRENT_EXPERIMENT['png_subfolder'] + sep + '{}'.format(ids[0])
                else:
                    title_png = samplefolder + CURRENT_EXPERIMENT['png_subfolder'] + sep + '{}-{}'.format(ids[0], ids[-1])
                if n_keys > 1:
                    num = '{}'.format(j + 1)
                plt.savefig(title_png + '_{}_{}.png'.format(num, avg_sub), dpi=500)
        plots.append(plot)
    return data_list, plots
Example #20
0
    def addPPT_dataset(dataset,
                       title=None,
                       notes=None,
                       show=False,
                       verbose=1,
                       paramname='measured',
                       printformat='fancy',
                       customfig=None,
                       extranotes=None,
                       **kwargs):
        """ Add slide based on dataset to current active Powerpoint presentation

        Args:
            dataset (DataSet): data and metadata from DataSet added to slide
            title (string or None): slide title; if None it is generated from
                                    the plotted parameter name(s)
            notes (string or None): notes added to slide; if None the notes are
                                    built from the dataset metadata
            show (boolean): shows the powerpoint application
            verbose (int): print additional information
            paramname (None, str or list of str): passed to
                                    dataset.default_parameter_array; a list of
                                    names plots all of them in one figure
            printformat (string): 'fancy' for nice formatting or 'dict'
                                  for easy copy to python
            customfig (QtPlot): custom QtPlot object to be added to
                                slide (for dataviewer); paramname is then ignored
            extranotes: passed on to addPPTslide
            **kwargs: passed on to addPPTslide
        Returns:
            ppt: PowerPoint presentation
            slide: PowerPoint slide
        Raises:
            IndexError: if the dataset contains less than two data arrays
            ValueError: if paramname is an empty sequence of names

        Example
        -------
        >>> notes = 'some additional information'
        >>> addPPT_dataset(dataset, notes=notes)
        """
        if len(dataset.arrays) < 2:
            raise IndexError('The dataset contains less than two data arrays')

        if customfig is None:
            # None is a documented value and selects the default parameter, so it
            # is handled like a single name instead of being iterated over
            if paramname is None or isinstance(paramname, str):
                if title is None:
                    parameter_name = dataset.default_parameter_name(
                        paramname=paramname)
                    title = 'Parameter: %s' % parameter_name
                temp_fig = QtPlot(
                    dataset.default_parameter_array(paramname=paramname),
                    show_window=False)
            else:
                if title is None:
                    title = 'Parameter: %s' % (str(paramname), )
                parameter_names = list(paramname)
                if not parameter_names:
                    raise ValueError('paramname does not contain any parameter name')
                # first parameter creates the figure, the others are overlaid
                temp_fig = QtPlot(dataset.default_parameter_array(
                    paramname=parameter_names[0]),
                                  show_window=False)
                for parameter_name in parameter_names[1:]:
                    temp_fig.add(
                        dataset.default_parameter_array(
                            paramname=parameter_name))

        else:
            temp_fig = customfig

        text = 'Dataset location: %s' % dataset.location
        if notes is None:
            try:
                metastring = reshape_metadata(dataset, printformat=printformat)
            except Exception as ex:
                metastring = 'Could not read metadata: %s' % str(ex)
            notes = 'Dataset %s metadata:\n\n%s' % (dataset.location,
                                                    metastring)
            scanjob = dataset.metadata.get('scanjob', None)
            if scanjob is not None:
                s = pprint.pformat(scanjob)
                notes = 'scanjob: ' + str(s) + '\n\n' + notes

            gatevalues = dataset.metadata.get('allgatevalues', None)
            if gatevalues is not None:
                notes = 'gates: ' + str(gatevalues) + '\n\n' + notes

        ppt, slide = addPPTslide(title=title,
                                 fig=temp_fig,
                                 subtitle=text,
                                 notes=notes,
                                 show=show,
                                 verbose=verbose,
                                 extranotes=extranotes,
                                 **kwargs)
        return ppt, slide
Example #21
0
# Scratch script: loads a stored dataset, converts it with convert_to_probability,
# creates a new dataset and constructs MatPlot / QtPlot / DataArray objects.
# NOTE(review): `formatter` is not used anywhere in this snippet.
formatter = HDF5FormatMetadata()
try_location = '2017-09-04/17-23-05Finding_ResonanceRabi_Sweep'

DS = load_data(location = try_location, io = NewIO,)

# NOTE(review): 0.025 is presumably a threshold for the conversion - confirm
# against the convert_to_probability signature.
DS_P = convert_to_probability(DS, 0.025)

# NOTE(review): the new dataset uses the same location as the dataset loaded
# above - confirm this cannot overwrite the existing data on disk.
DS_new = new_data(location = try_location, io = NewIO,)


x_data = np.linspace(1,10,10)
y_data = np.linspace(11,20,10)
#z_data = np.linspace(101,201,101)

Mplot = MatPlot(x_data,y_data)
Qplot = QtPlot(x_data,y_data)

# Rebinding Mplot drops the reference to the MatPlot created above.
Mplot = MatPlot()

# Only consumed by the commented-out trace-config line below.
config = {
        'x': np.linspace(1,20,20),
        'y': np.linspace(11,30,20)
        }
#Mplot.traces[0]['config'] = config

data = np.array([1,2,3])
# NOTE(review): data1 is not used in this snippet.
data1 = np.array([[1,2,33,5],[5,232,7,3],[1,2,3,4]])

# Two DataArrays wrapping the same numpy data: one flagged as a setpoint array,
# one as a regular array.
data_array1 = DataArray(preset_data = data, name = 'digitizer', is_setpoint = True)

data_array2 = DataArray(preset_data = data, name = 'digitizer2')
Example #22
0
def _plot_setup(data,
                inst_meas,
                useQT=True,
                startranges=None,
                auto_color_scale=None,
                cutoff_percentile=None):
    """
    Create a plot with one subplot per measured (sub-)parameter.

    Args:
        data: The DataSet of the current measurement
        inst_meas: The parameters being measured; a parameter with a truthy
            ``names`` attribute contributes one subplot per name
        useQT (bool): Use QtPlot if True, MatPlot otherwise
        startranges: Passed to ``plot.fixUnitScaling`` (QtPlot only)
        auto_color_scale: Passed to ``auto_color_scale_from_config`` for
            2D MatPlot subplots
        cutoff_percentile: Passed to ``auto_color_scale_from_config`` for
            2D MatPlot subplots

    Returns:
        (plot, num_subplots): the plot object and the number of subplots
    """
    title = "{} #{:03d}".format(CURRENT_EXPERIMENT["sample_name"],
                                data.location_provider.counter)
    rasterized_note = " rasterized plot"
    num_subplots = sum(len(i.names) if getattr(i, "names", False) else 1
                       for i in inst_meas)
    counter_two = 0
    if useQT:
        plot = QtPlot(fig_x_position=CURRENT_EXPERIMENT['plot_x_position'])
    else:
        plot = MatPlot(subplots=(1, num_subplots))

    def _create_plot(plot, i, name, data, counter_two, j, k):
        """
        Args:
            plot: The plot object, either QtPlot() or MatPlot()
            i: The parameter to measure
            name: -
            data: The DataSet of the current measurement
            counter_two: The sub-measurement counter. Each measurement has a
                number and each sub-measurement has a counter.
            j: The current sub-measurement
            k: -
        """

        color = 'C' + str(counter_two)
        if issubclass(
                i.__class__,
                MultiChannelInstrumentParameter) or i._instrument is None:
            # no instrument prefix; must still be defined because the
            # AttributeError fallback below builds a name from it
            parent_instr_name = ''
            inst_meas_name = name
        else:
            parent_instr_name = (i._instrument.name +
                                 '_') if i._instrument else ''
            inst_meas_name = "{}{}".format(parent_instr_name, name)
        try:
            inst_meas_data = getattr(data, inst_meas_name)
        except AttributeError:
            inst_meas_name = "{}{}_0_0".format(parent_instr_name, name)
            inst_meas_data = getattr(data, inst_meas_name)

        inst_meta_data = __get_plot_type(inst_meas_data, plot)
        if useQT:
            plot.add(inst_meas_data, subplot=j + k + 1)
            plot.subplots[j + k].showGrid(True, True)
            if j == 0:
                plot.subplots[0].setTitle(title)
            else:
                plot.subplots[j + k].setTitle("")

            plot.fixUnitScaling(startranges)
            QtPlot.qc_helpers.foreground_qt_window(plot.win)

        else:
            if 'z' in inst_meta_data:
                # large 2D plots are rasterized to keep figures responsive
                xlen, ylen = inst_meta_data['z'].shape
                rasterized = xlen * ylen > 5000
                po = plot.add(inst_meas_data,
                              subplot=j + k + 1,
                              rasterized=rasterized)

                auto_color_scale_from_config(po.colorbar, auto_color_scale,
                                             inst_meta_data['z'],
                                             cutoff_percentile)
            else:
                rasterized = False
                plot.add(inst_meas_data, subplot=j + k + 1, color=color)
                plot.subplots[j + k].grid()
            if j == 0:
                if rasterized:
                    fulltitle = title + rasterized_note
                else:
                    fulltitle = title
                plot.subplots[0].set_title(fulltitle)
            else:
                if rasterized:
                    fulltitle = rasterized_note
                else:
                    fulltitle = ""
                plot.subplots[j + k].set_title(fulltitle)

    subplot_index = 0
    for measurement in inst_meas:
        if getattr(measurement, "names", False):
            # deal with multidimensional parameter
            for name in measurement.names:
                _create_plot(plot, measurement, name, data, counter_two,
                             subplot_index, 0)
                subplot_index += 1
                counter_two += 1
        else:
            # simple_parameters
            _create_plot(plot, measurement, measurement.name, data,
                         counter_two, subplot_index, 0)
            subplot_index += 1
            counter_two += 1
    return plot, num_subplots
Example #23
0
class DataViewer(QtWidgets.QWidget):
    """ Simple data browser for Qcodes datasets """

    def __init__(self,
                 datadir=None,
                 window_title='Data browser',
                 default_parameter='amplitude',
                 extensions=('dat', 'hdf5'),
                 verbose=1):
        """ Simple viewer for Qcodes data

        Args:

            datadir (string or None): directory to scan for experiments; if None
                the base location of qcodes.DataSet.default_io is used
            window_title (string): title of the window
            default_parameter (string): name of default parameter to plot
            extensions (sequence of str): extensions of the data files to list
            verbose (int): verbosity level; >= 2 prints progress messages
        """
        super(DataViewer, self).__init__()
        self.verbose = verbose
        self.default_parameter = default_parameter
        if datadir is None:
            datadir = qcodes.DataSet.default_io.base_location

        # copy, so instances never share one (mutable) list of extensions
        self.extensions = list(extensions)

        # state read by the callbacks; set before any signal is connected
        self.dataset = None
        self.datatag = None
        self.filename = None
        self._update_plot_ = True

        # setup GUI
        self.text = QtWidgets.QLabel()

        # logtree
        self.logtree = QtWidgets.QTreeView()
        self.logtree.setSelectionBehavior(
            QtWidgets.QAbstractItemView.SelectRows)
        self._treemodel = QtGui.QStandardItemModel()
        self.logtree.setModel(self._treemodel)

        # metatabs
        self.meta_tabs = QtWidgets.QTabWidget()
        self.meta_tabs.addTab(QtWidgets.QWidget(), 'metadata')

        self.__debug = dict()
        # QtPlot is a class, so it must be tested with issubclass: isinstance
        # on a class object is always False and made the first branch dead
        if issubclass(QtPlot, QWidget):
            self.qplot = QtPlot()
        else:
            self.qplot = QtPlot(remote=False)
        if isinstance(self.qplot, QWidget):
            self.plotwindow = self.qplot
        else:
            self.plotwindow = self.qplot.win

        topLayout = QtWidgets.QHBoxLayout()
        self.select_dir = QtWidgets.QPushButton()
        self.select_dir.setText('Select directory')

        self.reloadbutton = QtWidgets.QPushButton()
        self.reloadbutton.setText('Reload data')

        self.loadinfobutton = QtWidgets.QPushButton()
        self.loadinfobutton.setText('Preload info')

        self.outCombo = QtWidgets.QComboBox()

        topLayout.addWidget(self.text)
        topLayout.addWidget(self.select_dir)
        topLayout.addWidget(self.reloadbutton)
        topLayout.addWidget(self.loadinfobutton)

        treesLayout = QtWidgets.QHBoxLayout()
        treesLayout.addWidget(self.logtree)
        treesLayout.addWidget(self.meta_tabs)

        vertLayout = QtWidgets.QVBoxLayout()

        vertLayout.addItem(topLayout)
        vertLayout.addItem(treesLayout)
        vertLayout.addWidget(self.plotwindow)

        self.pptbutton = QtWidgets.QPushButton()
        self.pptbutton.setText('Send data to powerpoint')
        self.clipboardbutton = QtWidgets.QPushButton()
        self.clipboardbutton.setText('Copy image to clipboard')

        bLayout = QtWidgets.QHBoxLayout()
        bLayout.addWidget(self.outCombo)
        bLayout.addWidget(self.pptbutton)
        bLayout.addWidget(self.clipboardbutton)

        vertLayout.addItem(bLayout)

        self.setLayout(vertLayout)

        self.setWindowTitle(window_title)
        self.logtree.header().resizeSection(0, 280)

        # disable edit
        self.logtree.setEditTriggers(
            QtWidgets.QAbstractItemView.NoEditTriggers)

        self.setDatadir(datadir)

        self.logtree.doubleClicked.connect(self.logCallback)
        self.outCombo.currentIndexChanged.connect(self.comboCallback)
        self.select_dir.clicked.connect(self.selectDirectory)
        self.reloadbutton.clicked.connect(self.updateLogs)
        self.loadinfobutton.clicked.connect(self.loadInfo)
        self.pptbutton.clicked.connect(self.pptCallback)
        self.clipboardbutton.clicked.connect(self.clipboardCallback)
        if self.verbose >= 2:
            print('created gui...')
        # get logs from disk
        self.updateLogs()

        self.logtree.setColumnHidden(2, True)
        self.logtree.setColumnHidden(3, True)

        self.show()

    def setDatadir(self, datadir):
        """ Set the directory that is scanned for data files """
        self.datadir = datadir
        self.io = qcodes.DiskIO(datadir)
        logging.info('DataViewer: data directory %s' % datadir)
        self.text.setText('Log files at %s' % self.datadir)

    def pptCallback(self):
        """ Send the selected dataset and the current plot to PowerPoint """
        if self.dataset is None:
            print('no data selected')
            return
        qtt.utilities.tools.addPPT_dataset(self.dataset, customfig=self.qplot)

    def clipboardCallback(self):
        """ Copy the current plot to the clipboard """
        self.qplot.copyToClipboard()

    def getArrayStr(self, metadata):
        """ Return a one-line summary of the sweeps and measured arrays

        Understands metadata with a 'loop' entry (sweep_values / actions /
        arrays) or a 'scanjob' entry (sweepdata / stepdata / minstrument).
        Returns 'info about plot' when neither is present or the metadata does
        not have the expected structure.
        """
        params = []
        try:
            if 'loop' in metadata.keys():
                sv = metadata['loop']['sweep_values']
                params.append(
                    '%s [%.2f to %.2f %s]' %
                    (sv['parameter']['label'], sv['values'][0]['first'],
                     sv['values'][0]['last'], sv['parameter']['unit']))

                for act in metadata['loop']['actions']:
                    if 'sweep_values' in act.keys():
                        sv = act['sweep_values']
                        params.append(
                            '%s [%.2f - %.2f %s]' %
                            (sv['parameter']['label'],
                             sv['values'][0]['first'], sv['values'][0]['last'],
                             sv['parameter']['unit']))
                infotxt = ', '.join(params)
                infotxt = infotxt + '  |  ' + ', '.join(
                    [('%s' % (v['label']))
                     for (k, v) in metadata['arrays'].items()
                     if not v['is_setpoint']])

            elif 'scanjob' in metadata.keys():
                sd = metadata['scanjob']['sweepdata']
                params.append('%s [%.2f to %.2f]' %
                              (sd['param'], sd['start'], sd['end']))
                if 'stepdata' in metadata['scanjob']:
                    sd = metadata['scanjob']['stepdata']
                    params.append('%s [%.2f to %.2f]' %
                                  (sd['param'], sd['start'], sd['end']))
                infotxt = ', '.join(params)
                infotxt = infotxt + '  |  ' + \
                    ', '.join(metadata['scanjob']['minstrument'])
            else:
                infotxt = 'info about plot'

        except Exception:
            # malformed metadata must not break the browser, but Ctrl-C
            # (KeyboardInterrupt) should still propagate
            infotxt = 'info about plot'

        return infotxt

    def loadInfo(self):
        """ Read the metadata of all listed datasets and fill the info and comment columns """
        model = self._treemodel
        try:
            for row in range(model.rowCount()):
                index = model.index(row, 0)
                i = 0
                while index.child(i, 0).data() is not None:
                    filename = index.child(i, 3).data()
                    try:
                        # platform independent: the old split on '\\' only worked on Windows
                        loc = os.path.dirname(filename)
                        tempdata = qcodes.DataSet(loc)
                        tempdata.read_metadata()
                        infotxt = self.getArrayStr(tempdata.metadata)
                        model.setData(index.child(i, 1), infotxt)
                        if 'comment' in tempdata.metadata.keys():
                            model.setData(index.child(i, 4),
                                          tempdata.metadata['comment'])
                    except Exception as e:
                        # one unreadable dataset should not stop the others
                        print(e)
                    i = i + 1
        except Exception as e:
            print(e)

    def selectDirectory(self):
        """ Let the user pick a data directory and reload the list of logs """
        from qtpy.QtWidgets import QFileDialog
        d = QtWidgets.QFileDialog(caption='Select data directory')
        d.setFileMode(QFileDialog.Directory)
        if d.exec():
            datadir = d.selectedFiles()[0]
            self.setDatadir(datadir)
            print('update logs')
            self.updateLogs()

    @staticmethod
    def find_datafiles(datadir,
                       extensions=('dat', 'hdf5'),
                       show_progress=True):
        """ Find all datasets in a directory with a given extension

        Returns:
            list: sorted list of the files found by qtt.pgeometry.findfilesR
        """
        dd = []
        for e in extensions:
            dd += qtt.pgeometry.findfilesR(datadir,
                                           '.*%s' % e,
                                           show_progress=show_progress)

        return sorted(dd)

    def updateLogs(self):
        ''' Update the list of measurements '''
        model = self._treemodel

        self.datafiles = self.find_datafiles(self.datadir, self.extensions)

        if self.verbose:
            print('DataViewer: found %d files' % (len(self.datafiles)))

        model.clear()
        model.setHorizontalHeaderLabels(
            ['Log', 'Arrays', 'location', 'filename', 'Comments'])

        # group files by the two directory levels above them: <datetag>/<logtag>/<file>
        logs = dict()
        for filename in self.datafiles:
            try:
                datetag, logtag = filename.split(os.sep)[-3:-1]
            except ValueError:
                # path too short to contain both a date and a log directory
                continue
            logs.setdefault(datetag, dict())[logtag] = filename
        self.logs = logs

        if self.verbose >= 2:
            print('DataViewer: create gui elements')
        for datetag in sorted(logs.keys(), reverse=True):
            if self.verbose >= 2:
                print('DataViewer: datetag %s ' % datetag)

            parent1 = QtGui.QStandardItem(datetag)
            for logtag in sorted(logs[datetag]):
                filename = logs[datetag][logtag]
                child1 = QtGui.QStandardItem(logtag)
                child2 = QtGui.QStandardItem('info about plot')
                if self.verbose >= 2:
                    print('datetag %s, logtag %s' % (datetag, logtag))
                child3 = QtGui.QStandardItem(os.path.join(datetag, logtag))
                child4 = QtGui.QStandardItem(filename)
                # 'Comments' cell; without it the comment column has no item to set
                child5 = QtGui.QStandardItem('')
                parent1.appendRow([child1, child2, child3, child4, child5])
            model.appendRow(parent1)

        # view settings do not depend on the rows, apply them once
        self.logtree.setColumnWidth(0, 240)
        self.logtree.setColumnHidden(2, True)
        self.logtree.setColumnHidden(3, True)

        if self.verbose >= 2:
            print('DataViewer: updateLogs done')

    def _create_meta_tree(self, meta_dict):
        """ Return a read-only tree view showing the contents of meta_dict """
        metatree = QtWidgets.QTreeView()
        _metamodel = QtGui.QStandardItemModel()
        metatree.setModel(_metamodel)
        metatree.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)

        _metamodel.setHorizontalHeaderLabels(['metadata', 'value'])

        try:
            self.fill_item(_metamodel, meta_dict)
        except Exception as ex:
            # keep the partially filled tree; returning None would hand an
            # invalid page to addTab in updateMetaTabs
            print(ex)
        return metatree

    def updateMetaTabs(self):
        ''' Update metadata tree '''
        meta = self.dataset.metadata
        instruments = meta.get('station', dict()).get('instruments', None)

        self.meta_tabs.clear()
        if 'gates' in meta.keys():
            self.meta_tabs.addTab(self._create_meta_tree(meta['gates']),
                                  'gates')
        elif instruments is not None and instruments.get('gates', None) is not None:
            self.meta_tabs.addTab(
                self._create_meta_tree(instruments['gates']), 'gates')
        if instruments is not None:
            self.meta_tabs.addTab(self._create_meta_tree(instruments),
                                  'instruments')

        self.meta_tabs.addTab(self._create_meta_tree(meta), 'metadata')

    def fill_item(self, item, value):
        ''' recursive population of tree structure with a dict '''
        def new_item(parent, text, val=None):
            child = QtGui.QStandardItem(text)
            self.fill_item(child, val)
            parent.appendRow(child)

        if value is None:
            return
        elif isinstance(value, dict):
            for key, val in sorted(value.items()):
                if type(val) in [str, float, int]:
                    # leaf: key and value side by side
                    child = [
                        QtGui.QStandardItem(str(key)),
                        QtGui.QStandardItem(str(val))
                    ]
                    item.appendRow(child)
                else:
                    new_item(item, str(key), val)
        else:
            new_item(item, str(value))

    def getPlotParameter(self):
        ''' Return parameter to be plotted '''
        param_name = self.outCombo.currentText()
        if param_name != '':
            return param_name
        parameters = self.dataset.arrays.keys()
        if self.default_parameter in parameters:
            return self.default_parameter
        return self.dataset.default_parameter_name()

    def selectedDatafile(self):
        """ Return currently selected data file """
        return self.datatag

    def comboCallback(self, index):
        """ Replot when the user selects another parameter in the combo box """
        if not self._update_plot_:
            return
        param_name = self.getPlotParameter()
        if self.dataset is not None:
            self.updatePlot(param_name)

    def logCallback(self, index):
        """ Function called when a log entry is selected """
        logging.info('logCallback: index %s' % str(index))
        self.__debug['last'] = index
        pp = index.parent()
        row = index.row()
        tag = pp.child(row, 2).data()
        filename = pp.child(row, 3).data()
        self.filename = filename
        self.datatag = tag
        if tag is None:
            return
        if self.verbose >= 2:
            print('DataViewer logCallback: tag %s, filename %s' %
                  (tag, filename))
        try:
            logging.debug('DataViewer: load tag %s' % tag)
            data = self.loadData(filename, tag)
            if not data:
                raise ValueError('File invalid (%s) ...' % filename)
            self.dataset = data
            self.updateMetaTabs()

            data_keys = data.arrays.keys()
            infotxt = self.getArrayStr(data.metadata)
            q = pp.child(row, 1).model()
            q.setData(pp.child(row, 1), infotxt)
            if 'comment' in data.metadata.keys():
                # column 4 is 'Comments'; column 2 holds the location/tag and
                # must not be overwritten
                q.setData(pp.child(row, 4), data.metadata['comment'])
            self.resetComboItems(data, data_keys)
            param_name = self.getPlotParameter()
            self.updatePlot(param_name)
        except Exception as e:
            print('logCallback! error: %s' % str(e))
            logging.exception(e)
        return

    def resetComboItems(self, data, keys):
        """ Fill the combo box with the non-setpoint arrays, keeping the old selection """
        old_key = self.outCombo.currentText()
        # suppress replotting while the combo box is rebuilt
        self._update_plot_ = False
        try:
            self.outCombo.clear()
            for key in keys:
                if not getattr(data, key).is_setpoint:
                    self.outCombo.addItem(key)
            if old_key in keys:
                self.outCombo.setCurrentIndex(self.outCombo.findText(old_key))
        finally:
            self._update_plot_ = True
        return

    def loadData(self, filename, tag):
        """ Load the dataset stored in the directory containing filename """
        location = os.path.split(filename)[0]
        data = qtt.data.load_dataset(location)
        return data

    def updatePlot(self, parameter):
        """ Clear the plot and show the given parameter of the current dataset """
        self.qplot.clear()
        if parameter is None:
            logging.info('could not find parameter for DataSet')
            return
        else:
            logging.info('using plotting parameter %s' % parameter)
            self.qplot.add(getattr(self.dataset, parameter))
Example #24
0
def _plot_setup(data, inst_meas, useQT=True, startranges=None):
    """
    Create a plot with one subplot per measured quantity.

    A parameter with a (truthy) ``names`` attribute contributes one subplot
    per entry of ``names``; any other parameter contributes one subplot for
    ``parameter.name``. The plotted array is ``data.<instrument>_<name>``.

    Args:
        data: The DataSet of the current measurement.
        inst_meas: The parameters being measured.
        useQT: If True plot with QtPlot, otherwise with MatPlot.
        startranges: Passed to ``plot.fixUnitScaling`` (QtPlot only).

    Returns:
        tuple: ``(plot, num_subplots)``.
    """
    title = "{} #{:03d}".format(CURRENT_EXPERIMENT["sample_name"],
                                data.location_provider.counter)
    rasterized_note = " rasterized plot"

    num_subplots = 0
    for i in inst_meas:
        if getattr(i, "names", False):
            num_subplots += len(i.names)
        else:
            num_subplots += 1

    if useQT:
        plot = QtPlot(fig_x_position=CURRENT_EXPERIMENT['plot_x_position'])
    else:
        plot = MatPlot(subplots=(1, num_subplots))

    def _create_plot(i, name, index):
        """
        Add one measured quantity to the plot.

        Args:
            i: The measured parameter; ``i._instrument.name`` prefixes the
                name of the array in ``data``.
            name: The name of the measured quantity.
            index: 0-based position of this quantity among all subplots. It
                selects both the subplot and the matplotlib colour ``C<index>``.
        """
        color = 'C' + str(index)
        inst_meas_name = "{}_{}".format(i._instrument.name, name)
        inst_meas_data = getattr(data, inst_meas_name)
        inst_meta_data = __get_plot_type(inst_meas_data, plot)
        if useQT:
            plot.add(inst_meas_data, subplot=index + 1)
            plot.subplots[index].showGrid(True, True)
            if index == 0:
                plot.subplots[0].setTitle(title)
            else:
                plot.subplots[index].setTitle("")

            plot.fixUnitScaling(startranges)
            QtPlot.qc_helpers.foreground_qt_window(plot.win)

        else:
            if 'z' in inst_meta_data:
                # rasterize large 2D plots to keep the figure responsive
                xlen, ylen = inst_meta_data['z'].shape
                rasterized = xlen * ylen > 5000
                plot.add(inst_meas_data,
                         subplot=index + 1,
                         rasterized=rasterized)
            else:
                rasterized = False
                plot.add(inst_meas_data, subplot=index + 1, color=color)
                plot.subplots[index].grid()
            if index == 0:
                fulltitle = title + rasterized_note if rasterized else title
            else:
                fulltitle = rasterized_note if rasterized else ""
            plot.subplots[index].set_title(fulltitle)

    # Subplot positions are assigned sequentially over *all* quantities, so a
    # multidimensional parameter followed by further parameters never shares
    # a subplot with them.
    index = 0
    for i in inst_meas:
        if getattr(i, "names", False):
            # deal with multidimensional parameter
            for name in i.names:
                _create_plot(i, name, index)
                index += 1
        else:
            # simple_parameters
            _create_plot(i, i.name, index)
            index += 1
    return plot, num_subplots
    def __init__(self,
                 name,
                 nr_plot_points: int = 1000,
                 sampling_rate: float = 2.4e9,
                 auto_save_plots: bool = True,
                 **kw):
        '''
        Instantiates an object.

        Args:
            name (str):
                    Name of the instrument; also used as the window title of
                    the plot window.
            nr_plot_points (int):
                    Number of points of the waveform that are plotted. Can be
                    changed in self.cfg_nr_plot_points().
            sampling_rate (float):
                    Sampling rate (presumably in Hz — confirm). Stored in
                    self.sampling_rate and used as the initial value of
                    self.cfg_sampling_rate().
            auto_save_plots (bool):
                    Stored in self.auto_save_plots.
            **kw:
                    Passed on to the parent class constructor.

        An instrument reference can be set afterwards through the
        ``instr_dist_kern`` parameter (presumably the distortion-kernel
        instrument — confirm against callers).
        '''
        super().__init__(name, **kw)
        # Initialize instance variables
        # Plotting
        self._y_min = 0
        self._y_max = 1
        self._stop_idx = -1
        self._start_idx = 0
        self._t_start_loop = 0  # sets x range for plotting during loop
        self._t_stop_loop = 30e-6
        self.add_parameter('cfg_nr_plot_points',
                           initial_value=nr_plot_points,
                           parameter_class=ManualParameter)
        self.sampling_rate = sampling_rate
        self.add_parameter('cfg_sampling_rate',
                           initial_value=sampling_rate,
                           parameter_class=ManualParameter)
        self.add_parameter('instr_dist_kern',
                           parameter_class=InstrumentRefParameter)

        # Files
        self.filename = ''
        # where traces and plots are saved
        # self.data_dir = self.kernel_object.kernel_dir()
        self._iteration = 0
        self.auto_save_plots = auto_save_plots

        # Data
        self.waveform = []
        self.time_pts = []
        self.new_step = []

        # Fitting
        self.known_fit_models = ['exponential', 'high-pass', 'spline']
        self.fit_model = None
        self.edge_idx = None
        self.fit_res = None
        self.predicted_waveform = None

        # Default fit model used in the interactive loop
        self._fit_model_loop = 'exponential'

        # Help text of the interactive loop; every entry must end in '\n' so
        # that the next command starts on its own line.
        self._loop_helpstring = (
            'h:      Print this help.\n'
            'q:      Quit the loop.\n'
            'm:      Remeasures the trace. \n'
            'p <pars>:\n'
            '        Print the parameters of the last fit if pars not given.\n'
            '        If pars are given in the form of JSON string, \n'
            '        e.g., {"parA": a, "parB": b} the parameters of the last\n'
            '        fit are updated with those provided.\n'
            's <filename>:\n'
            '        Save the current plot to "filename.png".\n'
            'model <name>:\n'
            '        Choose the fit model that is used.\n'
            '        Available models:\n'
            '           ' + str(self.known_fit_models) + '\n'
            'xrange <min> <max>:\n'
            '        Set the x-range of the plot to (min, max). The points\n'
            '        outside this range are not plotted. The number of\n'
            '        points plotted in the given interval is fixed to\n'
            '        self.cfg_nr_plot_points() (default=1000).\n'
            'square_amp <amp> \n'
            '        Set the square_amp used to normalize measured waveforms.\n'
            '        If amp = "?" the current square_amp is printed.')

        # Make window for plots
        self.vw = QtPlot(window_title=name, figsize=(600, 400))
Example #26
0
def show_num(ids,
             samplefolder=None,
             useQT=False,
             avg_sub='',
             do_plots=True,
             savepng=True,
             fig_size=(6, 4),
             clim=None,
             dataname=None,
             xlim=None,
             ylim=None,
             transpose=False,
             auto_color_scale: Optional[bool] = None,
             cutoff_percentile: Optional[Union[Tuple[Number, Number],
                                               Number]] = None,
             **kwargs):
    """
    Show and return plot and data.
    Args:
        ids (number, list): id or list of ids of dataset(s)
        samplefolder (str): Sample folder if loading data from different sample than the initialized.
        useQT (boolean): If true plots with QTplot instead of matplotlib
        avg_sub (str: 'col' or 'row'): Subtracts average from either each column ('col') or each row ('row').
            The loaded data arrays are modified in place.
        do_plots: (boolean): if false no plots are produced.
        dataname (str): If given only plots dataset with that name
        savepng (boolean): If true saves matplotlib figure as png
        fig_size (6, 4): Figure size in inches
        clim [cmin,cmax]: Set min and max of colorbar to cmin and cmax respectively
        xlim [xmin,xmax]: Set limits on x axis
        ylim [ymin,ymax]: set limits on y axis
        transpose (boolean): Transpose data to be plotted (only works for 2D scans and qc.MatPlot).
            The loaded data arrays are modified in place.
        auto_color_scale: if True, the colorscale of heatmap plots will be
            automatically adjusted to disregard outliers.
        cutoff_percentile: percentile of data that may maximally be clipped
            on both sides of the distribution.
            If given a tuple (a,b) the percentile limits will be a and 100-b.
            See also the plotting tutorial notebook.
        **kwargs: Are passed to plot function

    Returns:
        data_list, plots: the loaded datasets and a list with one plot per
            plotted dataname (plots is None if do_plots is False)

    """
    # collections.Iterable was removed in Python 3.10
    from collections.abc import Iterable

    # default values
    if auto_color_scale is None:
        auto_color_scale = qcodes.config.plotting.auto_color_scale.enabled
    if cutoff_percentile is None:
        cutoff_percentile = cast(
            Tuple[Number, Number],
            tuple(qcodes.config.plotting.auto_color_scale.cutoff_percentile))

    if not isinstance(ids, Iterable):
        ids = (ids, )

    data_list = []
    keys_list = []

    # Define samplefolder
    if samplefolder is None:
        check_experiment_is_initialized()
        samplefolder = qc.DataSet.location_provider.fmt.format(counter='')

    # Load all datasets into list
    for data_id in ids:
        path = samplefolder + '{0:03d}'.format(data_id)
        data = qc.load_data(path)
        data_list.append(data)

        # find datanames to be plotted
        if do_plots:
            if useQT and len(ids) != 1:
                raise ValueError(
                    'qcodes.QtPlot does not support multigraph plotting. Set useQT=False to plot multiple datasets.'
                )
            measured_keys = [
                key for key in data.arrays.keys() if "_set" not in key
            ]
            if dataname is not None:
                if dataname not in measured_keys:
                    raise RuntimeError(
                        'Dataname not in dataset. Input dataname was: \'{}\' '
                        'while dataname(s) in dataset are: \'{}\'.'.format(
                            dataname, '\', \''.join(data.arrays.keys())))
                keys = [dataname]
            else:
                keys = measured_keys
            keys_list.append(keys)

    if not do_plots:
        return data_list, None

    # Keep datanames in first-seen order so that plot order and png
    # numbering are reproducible (set iteration order is not).
    unique_keys = list(
        dict.fromkeys(key for keys in keys_list for key in keys))
    plots = []
    num = ''
    n_keys = len(unique_keys)
    for j, key in enumerate(unique_keys):
        array_list = []
        xlims = [[], []]
        ylims = [[], []]
        clims = [[], []]
        # Find datasets containing data with dataname == key
        for data, keys in zip(data_list, keys_list):
            if key not in keys:
                continue
            arrays = getattr(data, key)
            if transpose and len(arrays.set_arrays) == 2:
                if useQT:
                    raise AttributeError(
                        'Transpose only works for qc.MatPlot.')
                if dataname is None and n_keys != 1:
                    raise ValueError(
                        'Dataname has to be provided to plot data transposed for dataset with more '
                        'than 1 measurement. Datanames in dataset are: \'{}\'.'
                        .format('\', \''.join(unique_keys)))
                arrays.ndarray = arrays.ndarray.T
                set0_temp = arrays.set_arrays[0]
                set1_temp = arrays.set_arrays[1]
                set0_temp.ndarray = set0_temp.ndarray.T
                set1_temp.ndarray = set1_temp.ndarray.T
                arrays.set_arrays = (
                    set1_temp,
                    set0_temp,
                )
            # subtract the NaN-ignoring mean of every row / column in place
            if avg_sub == 'row':
                arrays.ndarray -= np.nanmean(arrays.ndarray, axis=1,
                                             keepdims=True)
            if avg_sub == 'col':
                arrays.ndarray -= np.nanmean(arrays.ndarray, axis=0,
                                             keepdims=True)
            array_list.append(arrays)

            # Find axis limits for dataset
            if len(arrays.set_arrays) == 2:
                xlims[0].append(np.nanmin(arrays.set_arrays[1]))
                xlims[1].append(np.nanmax(arrays.set_arrays[1]))
                ylims[0].append(np.nanmin(arrays.set_arrays[0]))
                ylims[1].append(np.nanmax(arrays.set_arrays[0]))
                if auto_color_scale:
                    vlims = auto_range_iqr(arrays.ndarray,
                                           cutoff_percentile=cutoff_percentile)
                else:
                    vlims = (np.nanmin(arrays.ndarray),
                             np.nanmax(arrays.ndarray))
                clims[0].append(vlims[0])
                clims[1].append(vlims[1])
            else:
                xlims[0].append(np.nanmin(arrays.set_arrays[0]))
                xlims[1].append(np.nanmax(arrays.set_arrays[0]))
                ylims[0].append(np.nanmin(arrays.ndarray))
                ylims[1].append(np.nanmax(arrays.ndarray))

        if useQT:
            plot = QtPlot(
                array_list[0],
                fig_x_position=CURRENT_EXPERIMENT['plot_x_position'],
                **kwargs)
            title = "{} #{}".format(CURRENT_EXPERIMENT["sample_name"],
                                    '{}'.format(ids[0]))
            plot.subplots[0].setTitle(title)
            plot.subplots[0].showGrid(True, True)
            if savepng:
                print('Save plot only working for matplotlib figure.',
                      'Set useQT=False to save png.')
        else:
            plot = MatPlot(array_list, **kwargs)
            plot.rescale_axis()
            plot.fig.tight_layout(pad=3)
            plot.fig.set_size_inches(fig_size)
            # Set axis limits
            if xlim is None:
                plot[0].axes.set_xlim(
                    [np.nanmin(xlims[0]),
                     np.nanmax(xlims[1])])
            else:
                plot[0].axes.set_xlim(xlim)
            if ylim is None:
                plot[0].axes.set_ylim(
                    [np.nanmin(ylims[0]),
                     np.nanmax(ylims[1])])
            else:
                plot[0].axes.set_ylim(ylim)
            if len(array_list[-1].set_arrays) == 2:
                # TODO(DV): get colorbar from plot children (should be ax.qcodes_colorbar)
                if clim is None:
                    total_clim = (np.nanmin(clims[0]), np.nanmax(clims[1]))
                else:
                    total_clim = (clim[0], clim[1])
                colorbar = plot[0].qcodes_colorbar
                apply_color_scale_limits(colorbar, new_lim=total_clim)

            # Set figure titles
            plot.fig.suptitle(samplefolder)
            if len(ids) < 6:
                plot.subplots[0].set_title(', '.join(map(str, ids)))
            else:
                plot.subplots[0].set_title(' - '.join(
                    map(str, [ids[0], ids[-1]])))
            plt.draw()

            # Save figure
            if savepng:
                if len(ids) == 1:
                    title_png = samplefolder + CURRENT_EXPERIMENT[
                        'png_subfolder'] + sep + '{}'.format(ids[0])
                else:
                    title_png = samplefolder + CURRENT_EXPERIMENT[
                        'png_subfolder'] + sep + '{}-{}'.format(
                            ids[0], ids[-1])
                if n_keys > 1:
                    num = '{}'.format(j + 1)
                plt.savefig(title_png + '_{}_{}.png'.format(num, avg_sub),
                            dpi=500)
        plots.append(plot)
    return data_list, plots
Example #27
0
class DataViewer(QtWidgets.QMainWindow):
    """ Main window that lists Qcodes data files found on disk and plots the selected dataset.

    The window holds a tree of data files grouped by date tag, tabs with the
    metadata of the selected dataset, a plot window and a combo box to select
    the plotted array.
    """

    def __init__(self,
                 data_directory=None,
                 window_title='Data browser',
                 default_parameter='amplitude',
                 extensions=None,
                 verbose=1):
        """ Constructs a simple viewer for Qcodes data.

        Args:
            data_directory (string or None): The directory to scan for experiments.
                If None, ``qcodes.DataSet.default_io.base_location`` is used.
            window_title (string): The title of the main window.
            default_parameter (string): A name of default parameter to plot.
            extensions (list or None): A list with the data file extensions to
                filter. If None, ['dat', 'hdf5'] is used.
            verbose (int): The logging verbosity level.
        """
        super(DataViewer, self).__init__()
        if extensions is None:
            extensions = ['dat', 'hdf5']

        self.verbose = verbose
        self.default_parameter = default_parameter
        # two directory slots, selectable from the Folder menu
        self.data_directories = [None] * 2
        self.directory_index = 0
        if data_directory is None:
            data_directory = qcodes.DataSet.default_io.base_location

        self.extensions = extensions

        # setup GUI
        self.dataset = None
        # combobox_callback only replots while this flag is set; it is
        # cleared temporarily by reset_combo_items. Initialised before any
        # signal is connected so the callback never sees a missing attribute.
        self._update_plot_ = True
        self.text = QtWidgets.QLabel()

        # logtree
        self.logtree = QtWidgets.QTreeView()
        self.logtree.setSelectionBehavior(
            QtWidgets.QAbstractItemView.SelectRows)
        self._treemodel = QtGui.QStandardItemModel()
        self.logtree.setModel(self._treemodel)

        # metatabs
        self.meta_tabs = QtWidgets.QTabWidget()
        self.meta_tabs.addTab(QtWidgets.QWidget(), 'metadata')

        self.__debug = dict()
        # QtPlot is a class: check whether it *derives* from QWidget
        # (isinstance() on the class object itself is always False).
        if isinstance(QtPlot, type) and issubclass(QtPlot, QWidget):
            self.qplot = QtPlot()
        else:
            self.qplot = QtPlot(remote=False)
        if isinstance(self.qplot, QWidget):
            self.plotwindow = self.qplot
        else:
            self.plotwindow = self.qplot.win

        topLayout = QtWidgets.QHBoxLayout()

        self.filterbutton = QtWidgets.QPushButton()
        self.filterbutton.setText('Filter data')
        self.filtertext = QtWidgets.QLineEdit()
        self.outCombo = QtWidgets.QComboBox()

        topLayout.addWidget(self.text)
        topLayout.addWidget(self.filterbutton)
        topLayout.addWidget(self.filtertext)

        treesLayout = QtWidgets.QHBoxLayout()
        treesLayout.addWidget(self.logtree)
        treesLayout.addWidget(self.meta_tabs)

        vertLayout = QtWidgets.QVBoxLayout()

        vertLayout.addItem(topLayout)
        vertLayout.addItem(treesLayout)
        vertLayout.addWidget(self.plotwindow)

        self.pptbutton = QtWidgets.QPushButton()
        self.pptbutton.setText('Send data to powerpoint')
        self.clipboardbutton = QtWidgets.QPushButton()
        self.clipboardbutton.setText('Copy image to clipboard')

        bLayout = QtWidgets.QHBoxLayout()
        bLayout.addWidget(self.outCombo)
        bLayout.addWidget(self.pptbutton)
        bLayout.addWidget(self.clipboardbutton)

        vertLayout.addItem(bLayout)
        widget = QtWidgets.QWidget()
        widget.setLayout(vertLayout)
        self.setCentralWidget(widget)

        self.setWindowTitle(window_title)
        self.logtree.header().resizeSection(0, 280)

        # disable edit
        self.logtree.setEditTriggers(
            QtWidgets.QAbstractItemView.NoEditTriggers)

        self.set_data_directory(data_directory)
        self.logtree.doubleClicked.connect(self.log_callback)
        self.outCombo.currentIndexChanged.connect(self.combobox_callback)
        self.filterbutton.clicked.connect(
            lambda: self.update_logs(filter_str=self.filtertext.text()))
        self.pptbutton.clicked.connect(self.ppt_callback)
        self.clipboardbutton.clicked.connect(self.clipboard_callback)

        menuBar = self.menuBar()

        menuDict = {
            '&Data': {
                '&Reload Data': self.update_logs,
                '&Preload all Info': self.load_info,
                '&Quit': self.close
            },
            '&Folder': {
                '&Select Dir1': lambda: self.select_directory(index=0),
                'Select &Dir2': lambda: self.select_directory(index=1),
                '&Toggle Dirs': self.toggle_data_directory
            },
            '&Help': {
                '&Info': self.show_help
            }
        }
        for (k, menu) in menuDict.items():
            mb = menuBar.addMenu(k)
            for (kk, action) in menu.items():
                act = QtWidgets.QAction(kk, self)
                mb.addAction(act)
                act.triggered.connect(action)

        if self.verbose >= 2:
            print('created gui...')

        # get logs from disk
        self.update_logs()
        self.datatag = None

        self.logtree.setColumnHidden(2, True)
        self.logtree.setColumnHidden(3, True)
        self.show()

    def set_data_directory(self, data_directory, index=0):
        """ Store ``data_directory`` in slot ``index`` and make it the active directory.

        Also creates a new ``qcodes.DiskIO`` for it and updates the label text.
        The list of logs is not refreshed here.
        """
        self.data_directories[index] = data_directory
        self.data_directory = data_directory
        self.disk_io = qcodes.DiskIO(data_directory)
        logging.info('DataViewer: data directory %s' % data_directory)
        self.text.setText('Log files at %s' % self.data_directory)

    def show_help(self):
        """ Show help dialog """
        self.infotext = "Dataviewer for qcodes datasets"
        QtWidgets.QMessageBox.information(self, 'qtt dataviwer control info',
                                          self.infotext)

    def toggle_data_directory(self):
        """ Switch to the next directory slot and reload the list of logs. """
        index = (self.directory_index + 1) % len(self.data_directories)
        self.directory_index = index
        self.data_directory = self.data_directories[index]
        self.disk_io = qcodes.DiskIO(self.data_directory)
        logging.info('DataViewer: data directory %s' % self.data_directory)
        self.text.setText('Log files at %s' % self.data_directory)
        self.update_logs()

    def ppt_callback(self):
        """ Send the current dataset and plot to PowerPoint (no-op without a dataset). """
        if self.dataset is None:
            print('no data selected')
            return
        qtt.utilities.tools.addPPT_dataset(self.dataset, customfig=self.qplot)

    def clipboard_callback(self):
        """ Copy the current plot image to the clipboard. """
        self.qplot.copyToClipboard()

    @staticmethod
    def get_data_info(metadata):
        """ Return a one-line description of a dataset built from its metadata.

        Metadata with a 'loop' entry (swept parameters plus non-setpoint array
        labels) and with a 'scanjob' entry (sweep/step ranges plus measured
        instruments) are supported. In all other cases, or when an expected
        key is missing, 'info about plot' is returned.
        """
        params = []
        try:
            if 'loop' in metadata.keys():
                sv = metadata['loop']['sweep_values']
                params.append(
                    '%s [%.2f to %.2f %s]' %
                    (sv['parameter']['label'], sv['values'][0]['first'],
                     sv['values'][0]['last'], sv['parameter']['unit']))

                for act in metadata['loop']['actions']:
                    if 'sweep_values' in act.keys():
                        sv = act['sweep_values']
                        params.append(
                            '%s [%.2f - %.2f %s]' %
                            (sv['parameter']['label'],
                             sv['values'][0]['first'], sv['values'][0]['last'],
                             sv['parameter']['unit']))
                infotxt = ' ,'.join(params)
                infotxt = infotxt + '  |  ' + ', '.join(
                    [('%s' % (v['label']))
                     for (k, v) in metadata['arrays'].items()
                     if not v['is_setpoint']])

            elif 'scanjob' in metadata.keys():
                sd = metadata['scanjob']['sweepdata']
                params.append('%s [%.2f to %.2f]' %
                              (sd['param'], sd['start'], sd['end']))
                if 'stepdata' in metadata['scanjob']:
                    sd = metadata['scanjob']['stepdata']
                    params.append('%s [%.2f to %.2f]' %
                                  (sd['param'], sd['start'], sd['end']))
                infotxt = ' ,'.join(params)
                infotxt = infotxt + '  |  ' + \
                          ', '.join(metadata['scanjob']['minstrument'])
            else:
                infotxt = 'info about plot'

        except Exception:
            # best effort: malformed metadata falls back to the placeholder,
            # but KeyboardInterrupt/SystemExit are no longer swallowed
            infotxt = 'info about plot'

        return infotxt

    def load_info(self):
        """ Read the metadata of every listed data file and fill the 'Arrays' and 'Comments' columns.

        Best effort: the first error is printed and stops the remaining updates.
        """
        try:
            for row in range(self._treemodel.rowCount()):
                index = self._treemodel.index(row, 0)
                i = 0
                while (index.child(i, 0).data() is not None):
                    filename = index.child(i, 3).data()
                    # the dataset location is the directory holding the data
                    # file; os.path.dirname works for both '/' and '\\' paths
                    loc = os.path.dirname(filename)
                    tempdata = qcodes.DataSet(loc)
                    tempdata.read_metadata()
                    infotxt = DataViewer.get_data_info(tempdata.metadata)
                    self._treemodel.setData(index.child(i, 1), infotxt)
                    if 'comment' in tempdata.metadata.keys():
                        self._treemodel.setData(index.child(i, 4),
                                                tempdata.metadata['comment'])
                    i = i + 1
        except Exception as e:
            print(e)

    def select_directory(self, index=0):
        """ Ask the user for a data directory, store it in slot ``index`` and reload the logs. """
        d = QtWidgets.QFileDialog(caption='Select data directory')
        d.setFileMode(QFileDialog.Directory)
        if d.exec():
            datadir = d.selectedFiles()[0]
            self.set_data_directory(datadir, index)
            print('update logs')
            self.update_logs()

    @staticmethod
    def find_datafiles(datadir, extensions=None, show_progress=True):
        """ Find all datasets in a directory with a given extension """
        if extensions is None:
            extensions = ['dat', 'hdf5']
        dd = []
        for e in extensions:
            dd += qtt.pgeometry.findfilesR(datadir,
                                           '.*%s' % e,
                                           show_progress=show_progress)

        datafiles = sorted(dd)
        return datafiles

    @staticmethod
    def _filename2datetag(filename):
        """ Parse a filename to a date tag and base filename """
        if filename.endswith('.json'):
            datetag = filename.split(os.sep)[-1].split('_')[0]
            logtag = filename.split(os.sep)[-1][:-5]
        else:
            # other formats, assumed to be in normal form
            datetag, logtag = filename.split(os.sep)[-3:-1]
        return datetag, logtag

    def update_logs(self, filter_str=None):
        ''' Update the list of measurements

        Args:
            filter_str (str or None): if given, only data files whose path
                contains this substring are listed.
        '''
        model = self._treemodel

        self.datafiles = self.find_datafiles(self.data_directory,
                                             self.extensions)
        dd = self.datafiles

        if filter_str:
            dd = [s for s in dd if filter_str in s]

        if self.verbose:
            print('DataViewer: found %d files' % (len(dd)))

        model.clear()
        model.setHorizontalHeaderLabels(
            ['Log', 'Arrays', 'location', 'filename', 'Comments'])

        logs = dict()
        for filename in dd:
            try:
                datetag, logtag = self._filename2datetag(filename)
            except Exception:
                # skip files whose path does not match the expected layout
                continue
            logs.setdefault(datetag, dict())[logtag] = filename
        self.logs = logs

        if self.verbose >= 2:
            print('DataViewer: create gui elements')
        for datetag in sorted(logs.keys(), reverse=True):
            if self.verbose >= 2:
                print('DataViewer: datetag %s ' % datetag)

            parent1 = QtGui.QStandardItem(datetag)
            for logtag in sorted(logs[datetag]):
                filename = logs[datetag][logtag]
                child1 = QtGui.QStandardItem(logtag)
                child2 = QtGui.QStandardItem('info about plot')
                if self.verbose >= 2:
                    print('datetag %s, logtag %s' % (datetag, logtag))
                child3 = QtGui.QStandardItem(os.path.join(datetag, logtag))
                child4 = QtGui.QStandardItem(filename)
                # empty 'Comments' cell, so that comments can be stored in column 4
                child5 = QtGui.QStandardItem('')
                parent1.appendRow([child1, child2, child3, child4, child5])
            model.appendRow(parent1)

        if logs:
            self.logtree.setColumnWidth(0, 240)
            self.logtree.setColumnHidden(2, True)
            self.logtree.setColumnHidden(3, True)

        if self.verbose >= 2:
            print('DataViewer: update_logs done')

    def _create_meta_tree(self, meta_dict):
        """ Create a read-only tree view showing the (nested) dictionary ``meta_dict``.

        If filling the tree fails the error is printed and the (partially
        filled) tree is still returned, so callers can always add it as a tab.
        """
        metatree = QtWidgets.QTreeView()
        _metamodel = QtGui.QStandardItemModel()
        metatree.setModel(_metamodel)
        metatree.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)

        _metamodel.setHorizontalHeaderLabels(['metadata', 'value'])

        try:
            self.fill_item(_metamodel, meta_dict)
        except Exception as ex:
            print(ex)
        return metatree

    def update_meta_tabs(self):
        ''' Update metadata tree

        Shows a 'gates' tab (from ``metadata['gates']`` or from the station's
        'gates' instrument), an 'instruments' tab when the station lists
        instruments, and always a 'metadata' tab with the full metadata.
        '''
        meta = self.dataset.metadata
        instruments = meta.get('station', dict()).get('instruments', None)

        self.meta_tabs.clear()
        if 'gates' in meta.keys():
            self.meta_tabs.addTab(self._create_meta_tree(meta['gates']),
                                  'gates')
        elif instruments is not None and instruments.get('gates',
                                                         None) is not None:
            self.meta_tabs.addTab(
                self._create_meta_tree(instruments['gates']), 'gates')
        if instruments is not None:
            self.meta_tabs.addTab(self._create_meta_tree(instruments),
                                  'instruments')

        self.meta_tabs.addTab(self._create_meta_tree(meta), 'metadata')

    def fill_item(self, item, value):
        ''' recursive population of tree structure with a dict

        Dictionary entries whose value is exactly a str, float or int become a
        (key, value) row; other values become a child item filled recursively.
        A non-dict ``value`` is added as a single child item.
        '''
        def new_item(parent, text, val=None):
            child = QtGui.QStandardItem(text)
            self.fill_item(child, val)
            parent.appendRow(child)

        if value is None:
            return
        elif isinstance(value, dict):
            for key, val in sorted(value.items()):
                if type(val) in [str, float, int]:
                    child = [
                        QtGui.QStandardItem(str(key)),
                        QtGui.QStandardItem(str(val))
                    ]
                    item.appendRow(child)
                else:
                    new_item(item, str(key), val)
        else:
            new_item(item, str(value))

    def get_plot_parameter(self):
        ''' Return parameter to be plotted

        The combo box selection wins; otherwise ``self.default_parameter`` if
        the dataset has it, else the dataset's default parameter name.
        '''
        param_name = self.outCombo.currentText()
        # compare by value: ``is not ''`` relied on string interning
        if param_name:
            return param_name
        parameters = self.dataset.arrays.keys()
        if self.default_parameter in parameters:
            return self.default_parameter
        return self.dataset.default_parameter_name()

    def selected_data_file(self):
        """ Return currently selected data file """
        return self.datatag

    def combobox_callback(self, index):
        """ Replot the current dataset when the combo box selection changes. """
        if not self._update_plot_:
            return
        param_name = self.get_plot_parameter()
        if self.dataset is not None:
            self.update_plot(param_name)

    def log_callback(self, index):
        """ Function called when a log entry is selected """
        logging.info('logCallback: index %s' % str(index))
        self.__debug['last'] = index
        pp = index.parent()
        row = index.row()
        tag = pp.child(row, 2).data()
        filename = pp.child(row, 3).data()
        self.filename = filename
        self.datatag = tag
        if tag is None:
            return
        if self.verbose >= 2:
            print('DataViewer logCallback: tag %s, filename %s' %
                  (tag, filename))
        try:
            logging.debug('DataViewer: load tag %s' % tag)
            data = DataViewer.load_data(filename, tag)
            if not data:
                raise ValueError('File invalid (%s) ...' % filename)
            self.dataset = data
            self.update_meta_tabs()

            data_keys = data.arrays.keys()
            infotxt = DataViewer.get_data_info(data.metadata)
            q = pp.child(row, 1).model()
            q.setData(pp.child(row, 1), infotxt)
            if 'comment' in data.metadata.keys():
                # column 4 is 'Comments'; column 2 holds the location tag read above
                q.setData(pp.child(row, 4), data.metadata['comment'])
            self.reset_combo_items(data, data_keys)
            param_name = self.get_plot_parameter()
            self.update_plot(param_name)
        except Exception as e:
            print('logCallback! error: %s' % str(e))
            logging.exception(e)
        return

    def reset_combo_items(self, data, keys):
        """ Repopulate the output combo box with the non-setpoint arrays of ``data``.

        Combo-box driven replotting is suspended while the items change and is
        always re-enabled afterwards, even if a key lookup fails.
        """
        old_key = self.outCombo.currentText()
        self._update_plot_ = False
        try:
            self.outCombo.clear()
            for key in keys:
                if not getattr(data, key).is_setpoint:
                    self.outCombo.addItem(key)
            if old_key in keys:
                self.outCombo.setCurrentIndex(self.outCombo.findText(old_key))
        finally:
            self._update_plot_ = True
        return

    @staticmethod
    def load_data(filename, tag):
        """ Load a dataset: '.json' files by filename, other formats by their directory. """
        if filename.endswith('.json'):
            location = filename
        else:
            # qcodes datasets are found by filename, but should be loaded by directory...
            location = os.path.split(filename)[0]
        data = qtt.data.load_dataset(location)
        return data

    def update_plot(self, parameter):
        """ Clear the plot and plot array ``parameter`` of the current dataset (None: only clear). """
        self.qplot.clear()
        if parameter is None:
            logging.info('could not find parameter for DataSet')
            return
        else:
            logging.info('using plotting parameter %s' % parameter)
            self.qplot.add(getattr(self.dataset, parameter))
class MeasurementControl(Instrument):

    '''
    New version of Measurement Control that allows for adaptively determining
    data points.
    '''

    def __init__(self, name,
                 plotting_interval=0.25,
                 live_plot_enabled=True, verbose=True):
        '''
        Create the measurement control instrument.

        Args:
            name (str): name of the instrument.
            plotting_interval (float): minimum time in seconds between live
                plot refreshes. Applied to the plot windows when live
                plotting is enabled (it was previously ignored).
            live_plot_enabled (bool): if True, create the main and the
                secondary pyqtgraph plot monitors.
            verbose (bool): print start messages and progress.
        '''
        super().__init__(name=name, server_name=None)
        # Soft average is currently only available for "hard"
        # measurements. It does not work with adaptive measurements.
        self.add_parameter('soft_avg',
                           label='Number of soft averages',
                           parameter_class=ManualParameter,
                           vals=vals.Ints(1, int(1e8)),
                           initial_value=1)
        self.add_parameter('verbose',
                           parameter_class=ManualParameter,
                           vals=vals.Bool(),
                           initial_value=verbose)
        self.add_parameter('live_plot_enabled',
                           parameter_class=ManualParameter,
                           vals=vals.Bool(),
                           initial_value=live_plot_enabled)
        self.add_parameter('plotting_interval',
                           units='s',
                           vals=vals.Numbers(min_value=0.001),
                           set_cmd=self._set_plotting_interval,
                           get_cmd=self._get_plotting_interval)
        self.add_parameter('persist_mode',
                           vals=vals.Bool(),
                           parameter_class=ManualParameter,
                           initial_value=True)

        # pyqtgraph plotting process is reused for different measurements.
        if self.live_plot_enabled():
            self.main_QtPlot = QtPlot(
                windowTitle='Main plotmon of {}'.format(self.name),
                figsize=(600, 400))
            self.secondary_QtPlot = QtPlot(
                windowTitle='Secondary plotmon of {}'.format(self.name),
                figsize=(600, 400))
            # The set/get commands of plotting_interval act on the plot
            # windows, so the constructor argument can only be applied once
            # they exist.
            self.plotting_interval(plotting_interval)

        self.soft_iteration = 0  # used as a counter for soft_avg
        # data and labels of the previous measurement, used by persist_mode
        self._persist_dat = None
        self._persist_xlabs = None
        self._persist_ylabs = None

    ##############################################
    # Functions used to control the measurements #
    ##############################################

    def run(self, name=None, mode='1D', **kw):
        '''
        Core of the Measurement control.

        Opens a new h5d data file named after the measurement, stores the
        instrument settings and the dataset layout, dispatches to the
        measurement routine for `mode`, saves the MC metadata and, on
        success, cleans up via finish().

        Args:
            name (str): measurement name.
            mode (str): '1D', '2D' or 'adaptive'.
            **kw: accepted but not used.
        Returns:
            the full contents of the data set (self.dset[()]).
        Raises:
            ValueError: if `mode` is not recognized.
        '''
        # Setting to zero at the start of every run, used in soft avg
        self.soft_iteration = 0
        self.set_measurement_name(name)
        self.print_measurement_start_msg()
        self.mode = mode
        self.iteration = 0  # used in determining data writing indices
        with h5d.Data(name=self.get_measurement_name()) as self.data_object:
            self.get_measurement_begintime()
            # The git hash is not saved: that requires git shell interaction
            # from python.
            # Settings are saved before measuring so that they are also
            # stored if the measurement fails.
            self.save_instrument_settings(self.data_object)
            self.create_experimentaldata_dataset()
            # Compare by value; `is not` on a str literal tests identity.
            if mode != 'adaptive':
                self.xlen = len(self.get_sweep_points())
            if self.mode == '1D':
                self.measure()
            elif self.mode == '2D':
                self.measure_2D()
            elif self.mode == 'adaptive':
                self.measure_soft_adaptive()
            else:
                raise ValueError('mode %s not recognized' % self.mode)
            result = self.dset[()]
            self.save_MC_metadata(self.data_object)  # timing labels etc
        self.finish(result)
        return result

    def measure(self, *args, **kw):
        '''
        Execute the measurement loop for the configured sweep and detector
        functions (used directly for 1D and, after tiling, for 2D).

        - soft sweep + soft detector: every sweep point is measured
          individually by measure_soft_static.
        - hard detector: data is acquired in chunks by measure_hard; soft
          sweep functions are stepped between chunks until all points (and
          all soft averages) have been acquired.

        Positional and keyword arguments are accepted but ignored; keyword
        arguments must be accepted because measure_2D forwards its **kw.

        Raises:
            Exception: if the sweep and detector control types are not
                compatible.
        '''
        if self.live_plot_enabled():
            self.initialize_plot_monitor()

        for sweep_function in self.sweep_functions:
            sweep_function.prepare()

        single_sweep = len(self.sweep_functions) == 1
        sweep_control = self.sweep_functions[0].sweep_control
        detector_control = self.detector_function.detector_control

        if sweep_control == 'soft' and detector_control == 'soft':
            self.detector_function.prepare()
            self.get_measurement_preparetime()
            self.measure_soft_static()

        elif detector_control == 'hard':
            sweep_points = self.get_sweep_points()
            self.get_measurement_preparetime()
            if single_sweep:
                self.detector_function.prepare(sweep_points=sweep_points)
            else:
                # Do one iteration to see how many points per data point we
                # get: set every sweep function to its first point and
                # prepare the detector for the first inner-loop chunk.
                for i, sweep_function in enumerate(self.sweep_functions):
                    sweep_function.set_parameter(sweep_points[:, i][0])
                self.detector_function.prepare(
                    sweep_points=sweep_points[:self.xlen, 0])
            self.measure_hard()

            # Not complete yet for a 2D loop, soft averaging or many shots.
            if not self.is_complete():
                pts_per_iter = self.dset.shape[0]
                swp_len = np.shape(sweep_points)[0]
                req_nr_iterations = int(swp_len/pts_per_iter)
                total_iterations = req_nr_iterations * self.soft_avg()
                if single_sweep:
                    sweep_points_0 = sweep_points
                else:
                    sweep_points_0 = sweep_points[:, 0]

                # the first chunk has already been measured above
                for _ in range(total_iterations-1):
                    start_idx, stop_idx = self.get_datawriting_indices(
                        pts_per_iter=pts_per_iter)
                    if start_idx == 0:
                        self.soft_iteration += 1
                    for i, sweep_function in enumerate(self.sweep_functions):
                        if single_sweep:
                            swf_sweep_points = sweep_points
                        else:
                            swf_sweep_points = sweep_points[:, i]
                        val = swf_sweep_points[start_idx]
                        if sweep_function.sweep_control == 'soft':
                            sweep_function.set_parameter(val)
                    self.detector_function.prepare(
                        sweep_points=sweep_points_0[start_idx:stop_idx])
                    self.measure_hard()
        else:
            raise Exception(
                'Sweep and Detector functions not of the same type. '
                '\nAborting measurement (sweep_control: %s, '
                'detector_control: %s)' % (sweep_control, detector_control))

        self.update_plotmon(force_update=True)
        for sweep_function in self.sweep_functions:
            sweep_function.finish()
        self.detector_function.finish()
        self.get_measurement_endtime()

    def measure_soft_static(self):
        '''
        Measure every sweep point with measurement_function, repeating the
        whole sweep soft_avg() times.

        self.soft_iteration holds the index of the current pass; it is used
        by measurement_function to weight the running average.
        '''
        for avg_pass in range(self.soft_avg()):
            self.soft_iteration = avg_pass
            for point in self.sweep_points:
                self.measurement_function(point)

    def measure_soft_adaptive(self, method=None):
        '''
        Run an adaptive (optimizer-driven) soft measurement.

        The optimizer and its keyword arguments come from self.af_pars.
        af_pars['adaptive_function'] selects the optimizer: the string
        'Powell' maps to fmin_powell; otherwise any callable is used
        directly. Every other entry is passed to it as a keyword argument.
        The optimizer evaluates self.optimization_function; a StopIteration
        raised when f_termination is reached ends the optimization early.

        self.af_pars itself is not modified, so the measurement can be run
        again with the same settings.

        Args:
            method: not used; kept for backward compatibility.
        Raises:
            Exception: if the adaptive function is neither 'Powell' nor a
                callable.
        '''
        self.save_optimization_settings()
        # Pop from a copy: popping from self.af_pars made a second run fail
        # with a KeyError on 'adaptive_function'.
        af_pars = dict(self.af_pars)
        adaptive_function = af_pars.pop('adaptive_function')
        if self.live_plot_enabled():
            self.initialize_plot_monitor()
            self.initialize_plot_monitor_adaptive()
        for sweep_function in self.sweep_functions:
            sweep_function.prepare()
        self.detector_function.prepare()
        self.get_measurement_preparetime()

        if adaptive_function == 'Powell':
            adaptive_function = fmin_powell
        if callable(adaptive_function):
            try:
                adaptive_function(self.optimization_function, **af_pars)
            except StopIteration:
                print('Reached f_termination: %s' % (self.f_termination))
        else:
            raise Exception('optimization function: "%s" not recognized'
                            % adaptive_function)

        for sweep_function in self.sweep_functions:
            sweep_function.finish()
        self.detector_function.finish()
        self.update_plotmon(force_update=True)
        self.update_plotmon_adaptive(force_update=True)
        self.get_measurement_endtime()
        if self.verbose():
            print('Optimization completed in {:.4g}s'.format(
                self.endtime-self.begintime))

    def measure_hard(self):
        '''
        Acquire one chunk of data from a hard detector and store it in dset.

        The dataset is grown if needed. The new values are averaged with the
        values already present in the same rows, weighted by
        self.soft_iteration (soft averaging). The matching sweep points are
        written to the sweep-point columns of the same rows.

        Returns:
            the transposed data array returned by the detector function.
        '''
        new_data = np.array(self.detector_function.get_values()).T
        nr_sweep_funcs = len(self.sweep_functions)

        ###########################
        # Shape determining block #
        ###########################

        datasetshape = self.dset.shape
        start_idx, stop_idx = self.get_datawriting_indices(new_data)

        new_datasetshape = (np.max([datasetshape[0], stop_idx]),
                            datasetshape[1])
        self.dset.resize(new_datasetshape)
        len_new_data = stop_idx-start_idx
        if len(np.shape(new_data)) == 1:
            # a single detector value -> exactly one value column
            value_cols = nr_sweep_funcs
        else:
            value_cols = slice(nr_sweep_funcs, None)
        old_vals = self.dset[start_idx:stop_idx, value_cols]
        new_vals = ((new_data + old_vals*self.soft_iteration) /
                    (1+self.soft_iteration))
        self.dset[start_idx:stop_idx, value_cols] = new_vals

        ######################
        # DATA STORING BLOCK #
        ######################
        sweep_points = self.get_sweep_points()
        if nr_sweep_funcs == 1 and np.shape(sweep_points)[0] == len_new_data:
            # 1D sweep acquired in a single shot
            self.dset[:, 0] = sweep_points.T
        else:
            # Write only the rows of this chunk; the previous open-ended
            # `start_idx:` slice had the wrong length whenever the dataset
            # was already longer than stop_idx (e.g. later soft-avg passes).
            try:
                if nr_sweep_funcs != 1:
                    self.dset[start_idx:stop_idx, 0:nr_sweep_funcs] = \
                        sweep_points[start_idx:stop_idx]
                else:
                    self.dset[start_idx:stop_idx, 0] = \
                        sweep_points[start_idx:stop_idx].T
            except Exception:
                # Deliberate best effort: there are cases where the sweep
                # points do not match the acquired data that should not
                # crash the measurement (e.g. on-off sequences).
                pass

        self.update_plotmon()
        if self.mode == '2D':
            self.update_plotmon_2D_hard()
        self.print_progress(stop_idx)
        self.iteration += 1
        return new_data

    def measurement_function(self, x):
        '''
        Measure a single point; the core measurement function for soft sweeps.

        Sets every sweep function to its coordinate in x, acquires one data
        point from the detector, soft-averages [x, values] into the dataset
        row for the current iteration and refreshes the plot monitors.

        Args:
            x: the sweep point; a scalar for a single sweep function or a
                sequence with one value per sweep function.
        Returns:
            the values returned by detector_function.acquire_data_point().
        Raises:
            ValueError: if the size of x differs from the number of sweep
                functions.
        '''
        if np.size(x) == 1:
            x = [x]
        nr_sweep_funcs = len(self.sweep_functions)
        if np.size(x) != nr_sweep_funcs:
            raise ValueError(
                'size of x "%s" not equal to # sweep functions' % x)

        # Set the outer sweep point first and the inner one last. This is
        # usually irrelevant, but e.g. the phase of an agilent generator is
        # reset to 0 when its frequency is set.
        reversed_x = x[::-1]
        for idx, sweep_function in enumerate(self.sweep_functions[::-1]):
            sweep_function.set_parameter(reversed_x[idx])

        nr_rows, nr_cols = self.dset.shape
        start_idx, stop_idx = self.get_datawriting_indices(pts_per_iter=1)
        vals = self.detector_function.acquire_data_point()

        # grow the dataset if needed and store the running average
        self.dset.resize((np.max([nr_rows, stop_idx]), nr_cols))
        new_data = np.append(x, vals)
        weight = self.soft_iteration
        previous = self.dset[start_idx:stop_idx, :]
        self.dset[start_idx:stop_idx, :] = \
            (new_data + previous*weight) / (1+weight)

        self.update_plotmon()
        if self.mode == '2D':
            self.update_plotmon_2D()
        elif self.mode == 'adaptive':
            self.update_plotmon_adaptive()
        self.iteration += 1
        if self.mode != 'adaptive':
            self.print_progress(stop_idx)
        return vals

    def optimization_function(self, x):
        '''
        A wrapper around the measurement function used by the optimizer.

        Based on the attributes configured for adaptive measurements it:
        - rescales x by dividing by x_scale (a scalar, or one value per
          parameter; no rescaling when x_scale == 1),
        - measures the point via measurement_function,
        - selects the value at index par_idx when the detector returns
          several values,
        - raises StopIteration once the value passes f_termination (smaller
          when minimizing, larger when maximizing); this is caught outside
          of the optimization loop,
        - negates the value when maximizing, so the optimizer can always
          minimize.

        Args:
            x: the point proposed by the optimizer (sequence of numbers).
                It is never modified in place; rescaling works on a copy.
        Returns:
            the (possibly negated) measured value.
        Raises:
            StopIteration: when f_termination is reached.
        '''
        if hasattr(self.x_scale, '__iter__'):
            # one scale factor per parameter
            x = np.array([float(x_i) / float(self.x_scale[i])
                          for i, x_i in enumerate(x)])
        elif self.x_scale != 1:
            # a single scale factor for all parameters (previously this
            # branch indexed the scalar and raised a TypeError)
            x = np.array([float(x_i) / float(self.x_scale) for x_i in x])

        vals = self.measurement_function(x)
        # Select the optimized value first, so that detectors returning
        # several values do not break the f_termination comparison.
        if hasattr(vals, '__iter__') and len(vals) > 1:
            vals = vals[self.par_idx]

        if self.minimize_optimization:
            if self.f_termination is not None and vals < self.f_termination:
                raise StopIteration()
        else:
            # when maximizing, check the termination condition before
            # inverting
            if self.f_termination is not None and vals > self.f_termination:
                raise StopIteration()
            vals = np.multiply(-1, vals)
        return vals

    def finish(self, result):
        '''
        Clean up after a measurement.

        Keeps result and its x/y column labels so that persist_mode can
        overlay them in the next plot monitor. Then deletes the
        per-measurement attributes (those that exist) to free memory and
        avoid accidentally reusing stale state.
        '''
        # this data can be plotted by enabling persist_mode
        self._persist_dat = result
        nr_sweep = len(self.sweep_function_names)
        self._persist_xlabs = self.column_names[:nr_sweep]
        self._persist_ylabs = self.column_names[nr_sweep:]

        per_run_attrs = ('TwoD_array', 'dset', 'sweep_points',
                         'sweep_points_2D', 'sweep_functions', 'xlen', 'ylen',
                         'iteration', 'soft_iteration')
        for attr_name in per_run_attrs:
            try:
                delattr(self, attr_name)
            except AttributeError:
                continue

    ###################
    # 2D-measurements #
    ###################

    def run_2D(self, name=None, **kw):
        '''
        Run a 2D measurement; shorthand for run(name, mode='2D', **kw).

        Returns:
            the acquired data, as returned by run (previously discarded).
        '''
        return self.run(name=name, mode='2D', **kw)

    def tile_sweep_pts_for_2D(self):
        '''
        Combine the inner (1D) sweep points and the outer sweep_points_2D
        into the full list of (x, y) points, inner loop fastest.

        Always sets xlen and ylen. Only when the current sweep points are
        1D does it store sweep_pts_x / sweep_pts_y, replace the sweep points
        by the tiled (xlen*ylen, 2) array and initialize the 2D plot monitor.
        '''
        inner_pts = self.get_sweep_points()
        self.xlen = len(inner_pts)
        self.ylen = len(self.sweep_points_2D)
        if np.size(inner_pts[0]) != 1:
            # already multi-dimensional: leave the sweep points untouched
            return
        self.sweep_pts_x = inner_pts
        self.sweep_pts_y = self.sweep_points_2D
        grid = np.column_stack((np.tile(self.sweep_pts_x, self.ylen),
                                np.repeat(self.sweep_pts_y, self.xlen)))
        self.set_sweep_points(grid)
        self.initialize_plot_monitor_2D()

    def measure_2D(self, **kw):
        '''
        Measure a 2D sweep: the outer loop is driven by the 2D sweep
        function, the inner loop by the first sweep function.

        The sweep points are first tiled into (x, y) pairs, after which the
        regular measure() loop is used. Soft(ware) controlled sweep functions
        require soft detectors; hard(ware) controlled sweep functions require
        hard detectors.
        '''
        self.tile_sweep_pts_for_2D()
        self.measure(**kw)

    def set_sweep_function_2D(self, sweep_function):
        '''
        Add the outer-loop (2D) sweep function.

        Anything that is not a swf.Sweep_function is assumed to be a qcodes
        parameter and is converted with wrap_par_to_swf first.

        Raises:
            KeyError: if there is not exactly one (1D) sweep function yet.
        '''
        if not isinstance(sweep_function, swf.Sweep_function):
            sweep_function = wrap_par_to_swf(sweep_function)

        if len(self.sweep_functions) == 1:
            self.sweep_functions.append(sweep_function)
            self.sweep_function_names.append(
                str(sweep_function.__class__.__name__))
            return
        raise KeyError(
            'Specify sweepfunction 1D before specifying sweep_function 2D')

    def set_sweep_points_2D(self, sweep_points_2D):
        '''
        Set the outer-loop sweep points, both on the second sweep function
        and on MC itself (read by tile_sweep_pts_for_2D).
        '''
        outer_sweep_function = self.sweep_functions[1]
        outer_sweep_function.sweep_points = sweep_points_2D
        self.sweep_points_2D = sweep_points_2D

    ###########
    # Plotmon #
    ###########
    '''
    There are (will be) three kinds of plotmons, the regular plotmon,
    the 2D plotmon (which does a heatmap) and the adaptive plotmon.
    '''

    def initialize_plot_monitor(self):
        '''
        Set up the main plot window with one subplot per
        (sweep column, detector value) pair.

        Existing traces are cleared first. When persist_mode() is on and the
        previous measurement (stored by finish) had the same x and y labels,
        its data is drawn first in gray, so the new data ends up on top.
        The new traces (initially at [0], [0]) are collected in self.curves,
        detector value outer and sweep column inner, which is the order in
        which update_plotmon fills them.
        '''
        plot = self.main_QtPlot
        if plot.traces:
            plot.clear()
        self.curves = []
        nr_sweep_funcs = len(self.sweep_function_names)
        xlabels = self.column_names[:nr_sweep_funcs]
        ylabels = self.column_names[nr_sweep_funcs:]
        persist = bool(self._persist_ylabs == ylabels and
                       self._persist_xlabs == xlabels and
                       self.persist_mode())

        plot_nr = 0
        for y_col, ylab in enumerate(ylabels):
            for x_col, xlab in enumerate(xlabels):
                subplot = plot_nr + 1
                if persist:
                    # previous data first so that the new data is on top
                    plot.add(x=self._persist_dat[:, x_col],
                             y=self._persist_dat[:, y_col + nr_sweep_funcs],
                             subplot=subplot,
                             color=0.75,  # a grayscale value
                             symbol='o', symbolSize=5)
                plot.add(x=[0], y=[0],
                         xlabel=xlab, ylabel=ylab,
                         subplot=subplot,
                         color=color_cycle[plot_nr % len(color_cycle)],
                         symbol='o', symbolSize=5)
                self.curves.append(plot.traces[-1])
                plot_nr += 1
            plot.win.nextRow()

    def update_plotmon(self, force_update=False):
        '''
        Push the current dataset columns to the curves of the main plot.

        Refreshes when the dataset has fewer than 20 rows, when more than
        plotting_interval() seconds have passed since the last refresh, or
        when force_update is True. Does nothing if live plotting is disabled.
        '''
        if not self.live_plot_enabled():
            return
        last_update = getattr(self, '_mon_upd_time', None)
        if last_update is None:
            # first call: remember the time and force a refresh
            # (replaces a bare `except:` that also swallowed unrelated errors)
            self._mon_upd_time = time.time()
            time_since_last_mon_update = 1e9
        else:
            time_since_last_mon_update = time.time() - last_update
        # Update always if there are very few points
        if (self.dset.shape[0] < 20 or time_since_last_mon_update >
                self.plotting_interval() or force_update):
            nr_sweep_funcs = len(self.sweep_function_names)
            i = 0
            for y_ind in range(len(self.detector_function.value_names)):
                for x_ind in range(nr_sweep_funcs):
                    self.curves[i]['config']['x'] = self.dset[:, x_ind]
                    self.curves[i]['config']['y'] = \
                        self.dset[:, nr_sweep_funcs+y_ind]
                    i += 1
            self._mon_upd_time = time.time()
            self.main_QtPlot.update_plot()

    def initialize_plot_monitor_2D(self):
        '''
        Preallocate the data array used by update_plotmon_2D and set up one
        heatmap per detector value in the secondary plot window.

        self.TwoD_array has shape
        (len(sweep_pts_y), len(sweep_pts_x), number of detector values) and
        starts filled with NaN, which marks points not measured yet. Does
        nothing when live plotting is disabled.
        '''
        if self.live_plot_enabled():
            self.time_last_2Dplot_update = time.time()
            n = len(self.sweep_pts_y)
            m = len(self.sweep_pts_x)
            nr_values = len(self.detector_function.value_names)
            # np.nan rather than the np.NAN alias, which NumPy 2.0 removed
            self.TwoD_array = np.full([n, m, nr_values], np.nan)
            self.secondary_QtPlot.clear()
            for j in range(nr_values):
                self.secondary_QtPlot.add(x=self.sweep_pts_x,
                                          y=self.sweep_pts_y,
                                          z=self.TwoD_array[:, :, j],
                                          xlabel=self.column_names[0],
                                          ylabel=self.column_names[1],
                                          zlabel=self.column_names[2+j],
                                          subplot=j+1,
                                          cmap='viridis')

    def update_plotmon_2D(self, force_update=False):
        '''
        Add the latest soft-measured point to TwoD_array, update the heatmap
        of every detector value and redraw the secondary plot when due.

        Redraws when plotting_interval() seconds have passed since the last
        redraw, when the last point of a full sweep pass has just been
        written, or when force_update is True. Does nothing when live
        plotting is disabled.
        '''
        if self.live_plot_enabled():
            nr_points = self.xlen*self.ylen
            i = int(self.iteration % nr_points)
            x_ind = int(i % self.xlen)
            y_ind = int(i // self.xlen)
            for j in range(len(self.detector_function.value_names)):
                z_ind = len(self.sweep_functions) + j
                self.TwoD_array[y_ind, x_ind, j] = self.dset[i, z_ind]
                # inside the loop, so every trace is updated, not only the
                # last one
                self.secondary_QtPlot.traces[j]['config']['z'] = \
                    self.TwoD_array[:, :, j]
            # self.iteration is only incremented after this call (see
            # measurement_function), so the last point of a pass has
            # index nr_points - 1.
            if (time.time() - self.time_last_2Dplot_update >
                    self.plotting_interval()
                    or i == nr_points - 1 or force_update):
                self.time_last_2Dplot_update = time.time()
                self.secondary_QtPlot.update_plot()

    def initialize_plot_monitor_adaptive(self):
        '''
        Prepare the secondary Qcodes plot window for adaptive measurements:
        one subplot per detector value, plotted against the iteration number.
        '''
        self.time_last_ad_plot_update = time.time()
        secondary = self.secondary_QtPlot
        secondary.clear()
        value_names = self.detector_function.value_names
        for subplot_nr, value_name in enumerate(value_names, start=1):
            secondary.add(x=[0],
                          y=[0],
                          xlabel='iteration',
                          ylabel=value_name,
                          subplot=subplot_nr,
                          symbol='o', symbolSize=5)

    def update_plotmon_adaptive(self, force_update=False):
        '''
        Plot every detector value column against the iteration number in
        the secondary plot window.

        Refreshes when plotting_interval() seconds have passed since the last
        refresh or when force_update is True; the window is redrawn once
        after all traces are updated. Does nothing when live plotting is
        disabled.
        '''
        if self.live_plot_enabled():
            if (time.time() - self.time_last_ad_plot_update >
                    self.plotting_interval() or force_update):
                for j in range(len(self.detector_function.value_names)):
                    y_ind = len(self.sweep_functions) + j
                    y = self.dset[:, y_ind]
                    x = range(len(y))
                    self.secondary_QtPlot.traces[j]['config']['x'] = x
                    self.secondary_QtPlot.traces[j]['config']['y'] = y
                # redraw once for all traces instead of once per trace
                self.time_last_ad_plot_update = time.time()
                self.secondary_QtPlot.update_plot()

    def update_plotmon_2D_hard(self):
        '''
        Add the latest hard-acquired data row (one full inner sweep) to
        TwoD_array, update the heatmap traces and redraw when due.

        Note that the plotmon only supports evenly spaced lattices.
        Redraws when plotting_interval() seconds have passed since the last
        redraw or when the last row of the sweep has just been written.
        Does nothing when live plotting is disabled.
        '''
        if self.live_plot_enabled():
            row = int(self.iteration % self.ylen)
            first, last = row*self.xlen, (row+1)*self.xlen
            for j in range(len(self.detector_function.value_names)):
                z_ind = len(self.sweep_functions) + j
                self.TwoD_array[row, :, j] = self.dset[first:last, z_ind]
                self.secondary_QtPlot.traces[j]['config']['z'] = \
                    self.TwoD_array[:, :, j]

            # self.iteration is only incremented after this call (see
            # measure_hard), so the last row has index ylen - 1.
            if (time.time() - self.time_last_2Dplot_update >
                    self.plotting_interval()
                    or row == self.ylen - 1):
                self.time_last_2Dplot_update = time.time()
                self.secondary_QtPlot.update_plot()

    def _set_plotting_interval(self, plotting_interval):
        self.main_QtPlot.interval = plotting_interval
        self.secondary_QtPlot.interval = plotting_interval

    def _get_plotting_interval(self):
        return self.main_QtPlot.interval

    def clear_persitent_plot(self):
        '''
        Forget the data and labels of the previous measurement, so that no
        persist-mode overlay is drawn by the next plot monitor.
        '''
        for persist_attr in ('_persist_dat', '_persist_xlabs',
                             '_persist_ylabs'):
            setattr(self, persist_attr, None)

    ##################################
    # Small helper/utility functions #
    ##################################

    def get_data_object(self):
        '''
        Return the data file object of the current measurement, so external
        functions can write to it.

        Used as a hack in time_domain_measurement; not recommended.
        '''
        data_object = self.data_object
        return data_object

    def get_column_names(self):
        '''
        Build the dataset column labels, 'name (unit)': first one per sweep
        function, then one per detector value.

        Also stores the sweep parameter names and units separately in
        self.sweep_par_names / self.sweep_par_units.

        Returns:
            list of str: the column names (also stored in self.column_names).
        '''
        columns = []
        par_names = []
        par_units = []
        self.column_names = columns
        self.sweep_par_names = par_names
        self.sweep_par_units = par_units

        for sweep_func in self.sweep_functions:
            par_name = sweep_func.parameter_name
            par_unit = sweep_func.unit
            columns.append(par_name + ' (' + par_unit + ')')
            par_names.append(par_name)
            par_units.append(par_unit)

        for idx, val_name in enumerate(self.detector_function.value_names):
            val_unit = self.detector_function.value_units[idx]
            columns.append(val_name + ' (' + val_unit + ')')
        return columns

    def create_experimentaldata_dataset(self):
        '''
        Create the 'Experimental Data' group and its resizable 'Data'
        dataset.

        The dataset starts with 0 rows and has one column per sweep function
        followed by one per detector value. Column names, sweep parameter
        names/units, value names/units and the datasaving format version are
        stored as utf-8 encoded attributes.
        '''
        data_group = self.data_object.create_group('Experimental Data')
        nr_columns = (len(self.sweep_functions) +
                      len(self.detector_function.value_names))
        self.dset = data_group.create_dataset(
            'Data', (0, nr_columns), maxshape=(None, nr_columns))
        self.get_column_names()
        self.dset.attrs['column_names'] = h5d.encode_to_utf8(self.column_names)
        # 'datasaving_format' tells analysis how to extract the data
        group_attrs = (
            ('datasaving_format', 'Version 2'),
            ('sweep_parameter_names', self.sweep_par_names),
            ('sweep_parameter_units', self.sweep_par_units),
            ('value_names', self.detector_function.value_names),
            ('value_units', self.detector_function.value_units),
        )
        for attr_name, attr_value in group_attrs:
            data_group.attrs[attr_name] = h5d.encode_to_utf8(attr_value)

    def save_optimization_settings(self):
        '''
        Store every entry of self.af_pars, as a string attribute, in an
        'Optimization settings' group of the data file.
        '''
        settings_group = self.data_object.create_group('Optimization settings')
        for par_name, par_value in dict_to_ordered_tuples(self.af_pars):
            settings_group.attrs[par_name] = str(par_value)

    def save_instrument_settings(self, data_object=None, *args):
        '''
        Use the QCodes station snapshot to save the last known value of
        every parameter of every station component.

        Only the value is saved, not the update time (which is known in the
        snapshot). Snapshot entries without a 'value' are stored as ''.

        Args:
            data_object: data file/group to write to; defaults to
                self.data_object.
            *args: ignored.
        '''
        if data_object is None:
            data_object = self.data_object
        if not hasattr(self, 'station'):
            # One message string: the second string used to be passed as a
            # %-format argument, which broke the log record formatting.
            logging.warning('No station object specified, could not save '
                            'instrument settings')
            return
        set_grp = data_object.create_group('Instrument settings')
        for iname, ins in dict_to_ordered_tuples(self.station.components):
            instrument_grp = set_grp.create_group(iname)
            par_snap = ins.snapshot()['parameters']
            for p_name, p in dict_to_ordered_tuples(par_snap):
                try:
                    val = p['value']
                except KeyError:
                    val = ''
                instrument_grp.attrs[p_name] = str(val)

    def save_MC_metadata(self, data_object=None, *args):
        '''
        Save metadata on the MC run in an 'MC settings' group.

        The begin, prepare and end times are stored as 9x1 datasets that
        hold the fields of time.localtime(); mode, measurement name and the
        live plotting flag are stored as attributes.

        Args:
            data_object: data file/group to write to; defaults to
                self.data_object, consistent with save_instrument_settings
                (the default None used to crash).
            *args: ignored.
        '''
        if data_object is None:
            data_object = self.data_object
        set_grp = data_object.create_group('MC settings')

        for label in ('begintime', 'preparetime', 'endtime'):
            timing = set_grp.create_dataset(label, (9, 1))
            timing[:, 0] = np.array(time.localtime(getattr(self, label)))

        set_grp.attrs['mode'] = self.mode
        set_grp.attrs['measurement_name'] = self.measurement_name
        set_grp.attrs['live_plot_enabled'] = self.live_plot_enabled()

    def print_progress(self, stop_idx=None):
        '''
        Print a one-line progress message (percentage, elapsed time and
        estimated time left) when verbose() is on.

        With soft averaging the percentage counts completed passes plus the
        fraction stop_idx / total points of the current pass (a whole pass
        when stop_idx is None). Otherwise it is the number of acquired rows
        over the number of sweep points. The line only ends with a newline
        at exactly 100%.
        '''
        if not self.verbose():
            return
        acquired_points = self.dset.shape[0]
        total_nr_pts = len(self.get_sweep_points())
        if self.soft_avg() != 1:
            progr = 1 if stop_idx is None else stop_idx/total_nr_pts
            percdone = (self.soft_iteration+progr)/self.soft_avg()*100
        else:
            percdone = acquired_points*1./total_nr_pts*100
        elapsed_time = time.time() - self.begintime
        if percdone != 0:
            t_left = round((100.-percdone)/(percdone) * elapsed_time, 1)
        else:
            t_left = ''
        progress_message = "\r {percdone}% completed \telapsed time: "\
            "{t_elapsed}s \ttime left: {t_left}s".format(
                percdone=int(percdone),
                t_elapsed=round(elapsed_time, 1),
                t_left=t_left)

        end_char = '' if percdone != 100 else '\n'
        print('\r', progress_message, end=end_char)

    def is_complete(self):
        """
        Return True once enough data has been acquired.

        That requires at least as many dataset rows as sweep points and,
        when soft averaging is enabled (soft_avg() != 1), that a second pass
        has been started (soft_iteration != 0).
        """
        total_nr_pts = np.shape(self.get_sweep_points())[0]
        if self.dset.shape[0] < total_nr_pts:
            return False
        averaging_pending = self.soft_avg() != 1 and self.soft_iteration == 0
        return not averaging_pending

    def print_measurement_start_msg(self):
        '''
        When verbose() is on, print the measurement name, the sweep
        function name(s) (numbered when there are several) and the detector
        function name.
        '''
        if not self.verbose():
            return
        if len(self.sweep_functions) == 1:
            sweep_lines = ['Sweep function: %s' %
                           self.get_sweep_function_names()[0]]
        else:
            sweep_lines = ['Sweep function %d: %s' % (
                idx, self.sweep_function_names[idx])
                for idx, _ in enumerate(self.sweep_functions)]
        print('Starting measurement: %s' % self.get_measurement_name())
        for line in sweep_lines:
            print(line)
        print('Detector function: %s'
              % self.get_detector_function_name())

    def get_datetimestamp(self):
        '''Return the current local time formatted as YYYYmmdd_HHMMSS.'''
        # time.strftime formats time.localtime() when no time is passed
        return time.strftime('%Y%m%d_%H%M%S')

    def get_datawriting_indices(self, new_data=None, pts_per_iter=None):
        """
        Calculate the dataset rows [start_idx, stop_idx) for the next chunk.

        The chunk length is the first dimension of new_data when it is
        given, otherwise pts_per_iter. The start index is
        chunk_length * self.iteration, wrapped around the number of sweep
        points (no wrap-around in 'adaptive' mode). Repeated passes for soft
        averaging therefore map onto the same rows.

        Args:
            new_data: array-like of newly acquired data (takes precedence
                over pts_per_iter).
            pts_per_iter (int): number of points per chunk.
        Returns:
            tuple (start_idx, stop_idx).
        Raises:
            ValueError: if neither new_data nor pts_per_iter is given.
        """
        if new_data is not None:
            # rows of new_data == number of points in this chunk
            xlen = np.shape(new_data)[0]
        elif pts_per_iter is not None:
            xlen = pts_per_iter
        else:
            raise ValueError(
                'Either new_data or pts_per_iter must be specified')
        if self.mode == 'adaptive':
            max_sweep_points = np.inf
        else:
            max_sweep_points = np.shape(self.get_sweep_points())[0]
        start_idx = int((xlen*(self.iteration)) % max_sweep_points)
        stop_idx = start_idx + xlen
        return start_idx, stop_idx

    ####################################
    # Non-parameter get/set functions  #
    ####################################

    def set_sweep_function(self, sweep_function):
        '''
        Register a single sweep function.

        Anything that is not a ``swf.Sweep_function`` is assumed to be a
        qcodes parameter and is wrapped with ``wrap_par_to_swf``. The
        sweep function's ``name`` becomes the (only) sweep function name.
        '''
        if isinstance(sweep_function, swf.Sweep_function):
            wrapped = sweep_function
        else:
            wrapped = wrap_par_to_swf(sweep_function)
        self.sweep_functions = [wrapped]
        self.set_sweep_function_names([str(wrapped.name)])

    def get_sweep_function(self):
        """Return the first registered sweep function."""
        registered = self.sweep_functions
        return registered[0]

    def set_sweep_functions(self, sweep_functions):
        '''
        Used to set an arbitrary number of sweep functions.

        Args:
            sweep_functions (iterable): sweep functions, or qcodes
                parameters. Any entry without a ``sweep_control`` attribute
                is assumed to be a parameter and is wrapped with
                ``wrap_par_to_swf``.

        The (possibly wrapped) sweep functions are stored as a new list in
        ``self.sweep_functions``; the caller's sequence is left untouched,
        so tuples are accepted too. The class name of each sweep function
        is stored as its sweep function name.
        '''
        wrapped = []
        for sweep_func in sweep_functions:
            if not hasattr(sweep_func, 'sweep_control'):
                sweep_func = wrap_par_to_swf(sweep_func)
            wrapped.append(sweep_func)
        self.sweep_functions = wrapped
        # Name each entry after its own class (not the ``swf`` module).
        self.set_sweep_function_names(
            [str(sweep_func.__class__.__name__) for sweep_func in wrapped])

    def get_sweep_functions(self):
        """Return the list of registered sweep functions."""
        registered = self.sweep_functions
        return registered

    def set_sweep_function_names(self, swfname):
        """Store ``swfname`` as the list of sweep function names."""
        names = swfname
        self.sweep_function_names = names

    def get_sweep_function_names(self):
        """Return the stored sweep function names."""
        names = self.sweep_function_names
        return names

    def set_detector_function(self, detector_function,
                              wrapped_det_control='soft'):
        """
        Register the detector function.

        Objects without a ``detector_control`` attribute are treated as
        parameters and wrapped with ``wrap_par_to_det`` using
        ``wrapped_det_control``. The detector's ``name`` is stored as the
        detector function name.
        """
        if hasattr(detector_function, 'detector_control'):
            detector = detector_function
        else:
            detector = wrap_par_to_det(detector_function,
                                       wrapped_det_control)
        self.detector_function = detector
        self.set_detector_function_name(detector.name)

    def get_detector_function(self):
        """Return the registered detector function."""
        detector = self.detector_function
        return detector

    def set_detector_function_name(self, dfname):
        """Store ``dfname`` as the detector function name."""
        name = dfname
        self._dfname = name

    def get_detector_function_name(self):
        """Return the stored detector function name."""
        name = self._dfname
        return name

    ################################
    # Parameter get/set functions  #
    ################################

    def get_git_hash(self):
        """Look up the git revision hash, cache it on ``self.git_hash`` and
        return it."""
        revision = general.get_git_revision_hash()
        self.git_hash = revision
        return revision

    def get_measurement_begintime(self):
        """Record the epoch time in ``self.begintime`` and return the current
        local time as ``'%Y-%m-%d %H:%M:%S'``."""
        now = time.time()
        self.begintime = now
        return time.strftime('%Y-%m-%d %H:%M:%S')

    def get_measurement_endtime(self):
        """Record the epoch time in ``self.endtime`` and return the current
        local time as ``'%Y-%m-%d %H:%M:%S'``."""
        now = time.time()
        self.endtime = now
        return time.strftime('%Y-%m-%d %H:%M:%S')

    def get_measurement_preparetime(self):
        """Record the epoch time in ``self.preparetime`` and return the
        current local time as ``'%Y-%m-%d %H:%M:%S'``."""
        now = time.time()
        self.preparetime = now
        return time.strftime('%Y-%m-%d %H:%M:%S')

    def set_sweep_points(self, sweep_points):
        """
        Store ``sweep_points`` as a numpy array on the measurement.

        For one-dimensional sweep points, a separate array copy is also
        attached to the first sweep function, because some sweep functions
        carry their own sweep points.
        """
        self.sweep_points = np.array(sweep_points)
        # HACK: sweep points really belong on the individual sweep
        # functions instead of being pushed onto the first one here.
        is_one_dimensional = len(np.shape(sweep_points)) == 1
        if is_one_dimensional:
            first_swf = self.sweep_functions[0]
            first_swf.sweep_points = np.array(sweep_points)

    def get_sweep_points(self):
        """Return the measurement's sweep points, falling back to those of
        the first sweep function when none were set explicitly."""
        try:
            return self.sweep_points
        except AttributeError:
            return self.sweep_functions[0].sweep_points

    def set_adaptive_function_parameters(self, adaptive_function_parameters):
        """
        adaptive_function_parameters: Dictionary containing options for
            running adaptive mode.

        The following arguments are reserved keywords. All other entries in
        the dictionary get passed to the adaptive function in the measurement
        loop.

        Reserved keywords:
            "adaptive_function":    function
            "x_scale": 1            float rescales values for adaptive function
            "minimize": True        Bool, inverts value to allow minimizing
                                    or maximizing
            "f_termination" None    terminates the loop if the measured value
                                    is smaller than this value
            "par_idx": 0            If a parameter returns multiple values,
                                    specifies which one to use.
        Common keywords (used in python nelder_mead implementation):
            "x0":                   list of initial values
            "initial_step"
            "no_improv_break"
            "maxiter"

        The reserved keys "x_scale", "par_idx", "minimize" and
        "f_termination" are removed from a copy of the dictionary and stored
        as attributes. The caller's dictionary is not modified, so it can be
        passed in again unchanged.
        """
        # Work on a copy so popping reserved keys does not alter the
        # caller's dictionary.
        self.af_pars = dict(adaptive_function_parameters)

        # scaling should not be used if a "direc" argument is available
        # in the adaptive function itself, if not specified equals 1
        self.x_scale = self.af_pars.pop('x_scale', 1)
        self.par_idx = self.af_pars.pop('par_idx', 0)
        # Determines if the optimization will minimize or maximize
        self.minimize_optimization = self.af_pars.pop('minimize', True)
        self.f_termination = self.af_pars.pop('f_termination', None)

    def get_adaptive_function_parameters(self):
        """Return the non-reserved adaptive function parameters."""
        pars = self.af_pars
        return pars

    def set_measurement_name(self, measurement_name):
        """Set the measurement name; None selects ``'Measurement'``."""
        self.measurement_name = (
            'Measurement' if measurement_name is None else measurement_name)

    def get_measurement_name(self):
        """Return the current measurement name."""
        name = self.measurement_name
        return name

    def set_optimization_method(self, optimization_method):
        """Store the optimization method to use."""
        method = optimization_method
        self.optimization_method = method

    def get_optimization_method(self):
        """Return the stored optimization method."""
        method = self.optimization_method
        return method

    ################################
    # Actual parameters            #
    ################################

    def get_idn(self):
        """
        Return the identification dictionary expected of QCoDeS instruments.
        """
        idn = dict(vendor='PycQED', model='MeasurementControl',
                   serial='', firmware='2.0')
        return idn
Example #29
0
 def test_creation(self):
     ''' Simple test function which creates a hidden QtPlot window
     with one subplot. '''
     plot_window = QtPlot(remote=False, show_window=False, interval=0)
     plot_window.add_subplot()
Example #30
0
    def __init__(self,
                 datadir=None,
                 window_title='Data browser',
                 default_parameter='amplitude',
                 extensions=None,
                 verbose=1):
        """ Simple viewer for Qcodes data

        Args:

            datadir (string or None): directory to scan for experiments;
                defaults to ``qcodes.DataSet.default_io.base_location``.
            window_title (string): title of the window.
            default_parameter (string): name of default parameter to plot
            extensions (list of str or None): file extensions to scan for;
                None selects ``['dat', 'hdf5']``.
            verbose (int): verbosity level; at 2 or higher a message is
                printed once the GUI has been created.
        """
        super(DataViewer, self).__init__()
        self.verbose = verbose
        self.default_parameter = default_parameter
        if datadir is None:
            datadir = qcodes.DataSet.default_io.base_location

        # Avoid a shared mutable default: every instance gets its own list.
        if extensions is None:
            extensions = ['dat', 'hdf5']
        self.extensions = list(extensions)

        # setup GUI
        self.dataset = None
        self.text = QtWidgets.QLabel()

        # logtree
        self.logtree = QtWidgets.QTreeView()
        self.logtree.setSelectionBehavior(
            QtWidgets.QAbstractItemView.SelectRows)
        self._treemodel = QtGui.QStandardItemModel()
        self.logtree.setModel(self._treemodel)

        # metatabs
        self.meta_tabs = QtWidgets.QTabWidget()
        self.meta_tabs.addTab(QtWidgets.QWidget(), 'metadata')

        self.__debug = dict()
        # QtPlot is a class, so it can never be an *instance* of QWidget;
        # test whether the class itself derives from QWidget instead.
        if isinstance(QtPlot, type) and issubclass(QtPlot, QWidget):
            self.qplot = QtPlot()  # remote=False, interval=0)
        else:
            self.qplot = QtPlot(remote=False)  # remote=False, interval=0)
        if isinstance(self.qplot, QWidget):
            self.plotwindow = self.qplot
        else:
            self.plotwindow = self.qplot.win

        topLayout = QtWidgets.QHBoxLayout()
        self.select_dir = QtWidgets.QPushButton()
        self.select_dir.setText('Select directory')

        self.reloadbutton = QtWidgets.QPushButton()
        self.reloadbutton.setText('Reload data')

        self.loadinfobutton = QtWidgets.QPushButton()
        self.loadinfobutton.setText('Preload info')

        self.outCombo = QtWidgets.QComboBox()

        topLayout.addWidget(self.text)
        topLayout.addWidget(self.select_dir)
        topLayout.addWidget(self.reloadbutton)
        topLayout.addWidget(self.loadinfobutton)

        treesLayout = QtWidgets.QHBoxLayout()
        treesLayout.addWidget(self.logtree)
        treesLayout.addWidget(self.meta_tabs)

        vertLayout = QtWidgets.QVBoxLayout()

        vertLayout.addItem(topLayout)
        vertLayout.addItem(treesLayout)
        vertLayout.addWidget(self.plotwindow)

        self.pptbutton = QtWidgets.QPushButton()
        self.pptbutton.setText('Send data to powerpoint')
        self.clipboardbutton = QtWidgets.QPushButton()
        self.clipboardbutton.setText('Copy image to clipboard')

        bLayout = QtWidgets.QHBoxLayout()
        bLayout.addWidget(self.outCombo)
        bLayout.addWidget(self.pptbutton)
        bLayout.addWidget(self.clipboardbutton)

        vertLayout.addItem(bLayout)

        self.setLayout(vertLayout)

        self.setWindowTitle(window_title)
        self.logtree.header().resizeSection(0, 280)

        # disable edit
        self.logtree.setEditTriggers(
            QtWidgets.QAbstractItemView.NoEditTriggers)

        self.setDatadir(datadir)

        self.logtree.doubleClicked.connect(self.logCallback)
        self.outCombo.currentIndexChanged.connect(self.comboCallback)
        self.select_dir.clicked.connect(self.selectDirectory)
        self.reloadbutton.clicked.connect(self.updateLogs)
        self.loadinfobutton.clicked.connect(self.loadInfo)
        self.pptbutton.clicked.connect(self.pptCallback)
        self.clipboardbutton.clicked.connect(self.clipboardCallback)
        if self.verbose >= 2:
            print('created gui...')
        # get logs from disk
        self.updateLogs()
        self.datatag = None

        # columns 2 and 3 presumably hold location/filename bookkeeping
        # (set up in updateLogs, not visible here) -- hidden from the user.
        self.logtree.setColumnHidden(2, True)
        self.logtree.setColumnHidden(3, True)

        self.show()
Example #31
0
class data_viewer(QtWidgets.QMainWindow, Ui_dataviewer):
    """Main window for browsing and plotting qcodes datasets stored on disk.

    Data files found below the registered data directories are listed in a
    tree grouped by date tag and log tag. Double-clicking an entry loads
    the dataset, shows its metadata in tabs and plots the selected
    (non-setpoint) arrays with a qcodes ``QtPlot``.
    """

    def __init__(self, datadir=None, window_title='Data browser'):
        """Build the window, register ``datadir`` and show the window.

        If no Qt application instance exists yet, one is created here and
        its event loop is started with ``exec()`` at the end of this call.

        Args:
            datadir (str or None): initial data directory. Defaults to
                ``DataSet.default_io.base_location``.
            window_title (str): not used by this implementation; kept for
                API compatibility.
        """
        # set graphical user interface
        instance_ready = True
        self.app = QtCore.QCoreApplication.instance()
        if self.app is None:
            # No Qt application yet: create one and run its loop below.
            instance_ready = False
            self.app = QtWidgets.QApplication([])

        # NOTE(review): super(QtWidgets.QMainWindow, self) starts the MRO
        # lookup after QMainWindow; kept as-is, confirm this is intended.
        super(QtWidgets.QMainWindow, self).__init__()
        self.app.setStyleSheet(qdarkstyle.load_stylesheet_pyqt5())
        self.setupUi(self)

        if datadir is None:
            datadir = DataSet.default_io.base_location

        # set-up tree view for data
        self._treemodel = QtGui.QStandardItemModel()
        self.data_view.setModel(self._treemodel)
        self.tabWidget.addTab(QtWidgets.QWidget(), 'metadata')

        self.qplot = QtPlot(remote=False)
        self.plotwindow = self.qplot.win
        self.qplot.max_len = 10

        self.horizontalLayout_4.addWidget(self.plotwindow)

        # Fix some initializations in the window
        self.splitter_2.setSizes([int(self.height() / 2)] * 2)
        self.splitter.setSizes(
            [int(self.width() / 3),
             int(2 * self.width() / 3)])

        # connect callbacks
        self.data_view.doubleClicked.connect(self.logCallback)
        self.actionReload_data.triggered.connect(self.updateLogs)
        self.actionPreload_all_info.triggered.connect(self.loadInfo)
        self.actionAdd_directory.triggered.connect(self.selectDirectory)
        self.filter_button.clicked.connect(
            lambda: self.updateLogs(filter_str=self.filter_input.text()))
        self.send_ppt.clicked.connect(self.pptCallback)
        self.copy_im.clicked.connect(self.clipboardCallback)
        self.split_data.clicked.connect(self.split_dataset)

        # initialize defaults
        self.extensions = ['dat', 'hdf5']
        self.dataset = None
        self.datatag = None
        self.datadirlist = []
        self.datadirindex = 0
        self.color_list = [
            pg.mkColor(cl) for cl in qcodes.plots.pyqtgraph.color_cycle
        ]
        self.current_params = []
        self.subpl_ind = dict()
        self.splitted = False

        # add default directory
        self.addDirectory(datadir)
        self.updateLogs()

        # Launch app
        self.show()
        if not instance_ready:
            self.app.exec()

    def find_datafiles(self,
                       datadir,
                       extensions=('dat', 'hdf5'),
                       show_progress=True):
        """ Find all datasets in a directory with a given extension

        Args:
            datadir (str): directory that is searched recursively.
            extensions (iterable of str): file extensions without the dot.
            show_progress (bool): accepted for API compatibility; unused.

        Returns:
            list of str: sorted full paths of the matching files.
        """
        datafiles = []
        for ext in extensions:
            datafiles += self.findfiles(datadir, ext)
        return sorted(datafiles)

    def loadInfo(self):
        """Fill the 'Arrays' column of every log entry with a summary of
        its dataset metadata. Errors are logged as warnings."""
        logging.debug('loading info')
        try:
            for row in range(self._treemodel.rowCount()):
                index = self._treemodel.index(row, 0)
                i = 0
                while (index.child(i, 0).data() is not None):
                    filename = index.child(i, 3).data()
                    loc = os.path.dirname(filename)
                    tempdata = qcodes.DataSet(loc)
                    tempdata.read_metadata()
                    infotxt = self.getArrayStr(tempdata.metadata)
                    self._treemodel.setData(index.child(i, 1), infotxt)
                    i += 1
        except Exception as e:
            logging.warning(e)

    def setDatadir(self, newindex):
        """Make ``self.datadirlist[newindex]`` the active data directory.

        Moves the '>>' marker in the Folder menu from the previously active
        entry to the new one and reloads the log tree.
        """
        logging.info(f'Setting datadir with index: {newindex}')
        oldindex = self.datadirindex
        self.datadirindex = newindex
        datadir = self.datadirlist[newindex]

        self.io = DiskIO(datadir)
        logging.info('DataViewer: data directory %s' % datadir)
        self.logfile.setText('Log files at %s' % datadir)

        # Directory actions start at position 1 of menuFolder (presumably
        # position 0 is a fixed menu entry -- confirm against the .ui file).
        actions = self.menuFolder.actions()
        old_action = actions[oldindex + 1]
        old_text = old_action.text()
        # Only strip the marker if it is actually present.
        if old_text.startswith('>>'):
            old_action.setText(old_text[2:])
        new_action = actions[newindex + 1]
        new_action.setText('>>' + new_action.text())
        self.updateLogs()

    def selectDirectory(self):
        """Ask the user for a directory and add it to the directory list."""
        d = QtWidgets.QFileDialog(caption='Select data directory')
        d.setFileMode(QtWidgets.QFileDialog.Directory)
        if d.exec():
            datadir = d.selectedFiles()[0]
            self.addDirectory(datadir)

    def addDirectory(self, datadir):
        """Append ``datadir`` to the directory list, add a menu action that
        selects it, and make it the active directory."""
        newindex = len(self.datadirlist)
        self.datadirlist.append(datadir)
        if len(self.datadirlist) == 1:
            datadir = '>>' + datadir
        new_act = QtWidgets.QAction(datadir, self)
        new_act.triggered.connect(partial(self.setDatadir, newindex))
        self.menuFolder.addAction(new_act)
        self.setDatadir(newindex)

    def updateLogs(self, filter_str=None):
        ''' Update the list of measurements

        Scans the active data directory and rebuilds the tree model with one
        parent row per date tag (newest first) and one child row per log
        tag. When ``filter_str`` is truthy, only paths containing it are
        listed.
        '''
        logging.info('updating logs')
        model = self._treemodel

        self.datafiles = self.find_datafiles(
            self.datadirlist[self.datadirindex], self.extensions)
        dd = self.datafiles

        if filter_str:
            dd = [s for s in dd if filter_str in s]

        logging.info(f'DataViewer: found {len(dd)} files')

        model.clear()
        model.setHorizontalHeaderLabels(
            ['Log', 'Arrays', 'location', 'filename'])

        # Group files by their two enclosing directory names:
        # .../<datetag>/<logtag>/<file>
        logs = dict()
        for d in dd:
            try:
                datetag, logtag = d.split(os.sep)[-3:-1]
            except ValueError:
                logging.warning(f'DataViewer: cannot parse data path {d}')
                continue
            logs.setdefault(datetag, dict())[logtag] = d
        self.logs = logs

        logging.debug('DataViewer: create gui elements')
        for datetag in sorted(logs.keys(), reverse=True):
            logging.debug(f'DataViewer: datetag {datetag}')

            parent1 = QtGui.QStandardItem(datetag)
            for logtag in sorted(logs[datetag]):
                filename = logs[datetag][logtag]
                child1 = QtGui.QStandardItem(logtag)
                child2 = QtGui.QStandardItem('info about plot')
                logging.debug(f'datetag: {datetag}, logtag: {logtag}')
                child3 = QtGui.QStandardItem(os.path.join(datetag, logtag))
                child4 = QtGui.QStandardItem(filename)
                parent1.appendRow([child1, child2, child3, child4])
            model.appendRow(parent1)

        # Column layout only needs to be applied once, after the rows exist.
        self.data_view.setColumnWidth(0, 240)
        self.data_view.setColumnHidden(2, True)
        self.data_view.setColumnHidden(3, True)

        logging.debug('DataViewer: updateLogs done')

    def logCallback(self, index):
        """ Function called when a log entry is selected

        Loads the dataset of the double-clicked row, refreshes the metadata
        tabs, writes an array summary into the row and rebuilds the
        parameter selection.
        """
        logging.info('logCallback: index %s' % str(index))
        oldtab_index = self.tabWidget.currentIndex()
        pp = index.parent()
        row = index.row()
        tag = pp.child(row, 2).data()
        filename = pp.child(row, 3).data()
        self.filename = filename
        self.datatag = tag
        if tag is None:
            # Date-tag rows carry no location; nothing to load.
            return
        logging.debug(
            f'DataViewer logCallback: tag {tag}, filename {filename}')

        try:
            logging.debug('DataViewer: load tag %s' % tag)
            data = self.loadData(filename, tag)
            if not data:
                raise ValueError('File invalid (%s) ...' % filename)
            self.dataset = data
            self.updateMetaTabs()
            try:
                self.tabWidget.setCurrentIndex(oldtab_index)
            except Exception:
                pass
            data_keys = data.arrays.keys()
            infotxt = self.getArrayStr(data.metadata)
            q = pp.child(row, 1).model()
            q.setData(pp.child(row, 1), infotxt)
            self.resetComboItems(data, data_keys)
        except Exception as e:
            print('logCallback! error: %s' % str(e))
            logging.exception(e)

    def resetComboItems(self, data, keys):
        """Rebuild the parameter checkboxes for ``data`` and re-plot.

        Parameters that were plotted before stay checked; if none of them
        exist in the new dataset, the first non-setpoint array is plotted.
        """
        # Clearing old stuff
        self.clearLayout(self.data_select_lay)
        self.qplot.clear()
        self.boxes = dict()
        self.box_labels = dict()
        self.param_keys = list()
        to_plot = list()

        # Loop through keys and add graphics items
        for key in keys:
            if not getattr(data, key).is_setpoint:
                box = QtWidgets.QCheckBox()
                box.clicked.connect(self.checkbox_callback)
                box.setText(key)
                label = QtWidgets.QLabel()
                self.data_select_lay.addRow(box, label)
                self.boxes[key] = box
                self.box_labels[key] = label
                self.param_keys.append(key)
                if key in self.subpl_ind.keys():
                    self.boxes[key].setChecked(True)
                    to_plot.append(key)
        self.data_select_lay.setLabelAlignment(QtCore.Qt.AlignLeft)

        # If no old parameters can be plotted, default to the first one
        # (if the dataset has any non-setpoint arrays at all).
        if not to_plot and self.param_keys:
            def_key = self.param_keys[0]
            to_plot.append(def_key)
            self.boxes[def_key].setChecked(True)

        # Update the parameter plots
        self.subpl_ind = dict()
        self.current_params = list()
        self.updatePlots(to_plot)
        if self.splitted:
            self.split_dataset()

    def clearPlots(self):
        """Clear all subplots and forget the parameters they showed."""
        self.qplot.clear()
        self.current_params = list(
            set(self.current_params) - set(self.subpl_ind.keys()))
        self.subpl_ind = dict()

    def clearLayout(self, layout):
        """Recursively remove and schedule deletion of all widgets in
        ``layout``."""
        if layout is not None:
            while layout.count():
                item = layout.takeAt(0)
                widget = item.widget()
                if widget is not None:
                    widget.deleteLater()
                else:
                    self.clearLayout(item.layout())

    def checkbox_callback(self, state):
        """Re-plot the currently checked parameters (undoing any split)."""
        if self.splitted:
            self.clearPlots()
            self.splitted = False
        to_plot = []
        for param in self.param_keys:
            box = self.boxes[param]
            if box.isChecked():
                to_plot.append(box.text())
        self.updatePlots(to_plot)

    def updatePlots(self, param_names):
        """Add plots/values for newly selected parameters and remove those
        that are no longer selected.

        Single-element arrays are shown as a label value; arrays with fewer
        than 3 dimensions get a subplot; higher-dimensional arrays are
        tracked but not displayed.
        """
        for param_name in set(param_names + self.current_params):
            param = getattr(self.dataset, param_name)
            if param_name not in param_names:
                self.current_params.remove(param_name)
                if param.shape == (1, ):
                    self.removeValue(param_name)
                elif len(param.shape) < 3:
                    self.removePlot(param_name)
            elif param_name not in self.current_params:
                self.current_params.append(param_name)
                if param.shape == (1, ):
                    self.addValue(param_name)
                elif len(param.shape) < 3:
                    self.addPlot(param_name)

    def addPlot(self, plot):
        """Add a new subplot showing dataset array ``plot``."""
        logging.info(f'adding param {plot}')
        self.subpl_ind[plot] = len(self.subpl_ind)
        # QtPlot subplot numbers are 1-based, hence len() after insertion.
        self.qplot.add(getattr(self.dataset, plot),
                       subplot=len(self.subpl_ind),
                       color=self.color_list[0])

    def removePlot(self, plot):
        """Remove the subplot (and histogram, if any) showing ``plot``."""
        logging.info(f'removing param {plot}')
        # Deleting graphics items
        plot_index = self.subpl_ind[plot]
        subplot = self.qplot.subplots[plot_index]
        subplot.clear()
        self.qplot.win.removeItem(subplot)
        subplot.deleteLater()

        try:
            # Only image traces have a histogram; 1D traces raise here.
            hist = self.qplot.traces[plot_index]['plot_object']['hist']
            self.qplot.win.removeItem(hist)
            hist.deleteLater()
        except Exception:
            pass

        # Own bookkeeping
        self.subpl_ind.pop(plot)

        # Removing from qcodes qplot (does not have proper function for this)
        self.qplot.traces.pop(plot_index)
        self.qplot.subplots.remove(subplot)
        for (key, val) in self.subpl_ind.items():
            if val > plot_index:
                self.subpl_ind[key] = val - 1

    def addValue(self, plot):
        """Show the single value of array ``plot`` next to its checkbox."""
        val = getattr(self.dataset, plot).ndarray[0]
        self.box_labels[plot].setText(str(val))

    def removeValue(self, plot):
        """Clear the value label of array ``plot``."""
        self.box_labels[plot].setText('')

    def loadData(self, filename, tag):
        """Load the qcodes dataset stored in the directory of ``filename``.
        ``tag`` is unused."""
        location = os.path.split(filename)[0]
        data = qcodes.data.data_set.load_data(location)
        return data

    def _create_meta_tree(self, meta_dict):
        """Return a read-only tree view populated from ``meta_dict``.

        If populating fails, the error is logged and the (partially filled)
        tree is still returned so that callers never receive None.
        """
        metatree = QtWidgets.QTreeView()
        _metamodel = QtGui.QStandardItemModel()
        metatree.setModel(_metamodel)
        metatree.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)

        _metamodel.setHorizontalHeaderLabels(['metadata', 'value'])

        try:
            self.fill_item(_metamodel, meta_dict)
        except Exception as ex:
            logging.warning(f'DataViewer: could not fill metadata tree: {ex}')
        return metatree

    def fill_item(self, item, value):
        ''' recursive population of tree structure with a dict

        Scalars (str, float, int) and flat lists become key/value rows;
        other values become a child node that is filled recursively.
        '''
        def new_item(parent, text, val=None):
            child = QtGui.QStandardItem(text)
            self.fill_item(child, val)
            parent.appendRow(child)

        if value is None:
            return
        elif isinstance(value, dict):
            for key, val in sorted(value.items()):
                if type(val) in [
                        str, float, int
                ] or (type(val) is list
                      and not any(isinstance(el, list) for el in val)):
                    child = [
                        QtGui.QStandardItem(str(key)),
                        QtGui.QStandardItem(str(val))
                    ]
                    item.appendRow(child)
                else:
                    new_item(item, str(key), val)
        else:
            new_item(item, str(value))

    def parse_gates(self, gates_obj):
        """Return ``{gate_name: value}`` from a gates instrument snapshot,
        skipping the 'IDN' parameter."""
        gate_dict = dict()

        for (gate, val) in gates_obj['parameters'].items():
            if gate != 'IDN':
                gate_dict[gate] = val['value']

        return gate_dict

    @staticmethod
    def _baseband_trace(segments, end_time):
        """Return ``(x, y)`` points of a piecewise-linear baseband waveform.

        ``segments`` maps names to dicts with 'start', 'stop', 'v_start' and
        'v_stop'. At every boundary time two y values are emitted (value
        just before and just after it) so steps are drawn vertically.
        """
        timepoints = {seg[key] for seg in segments.values()
                      for key in ('start', 'stop')}
        timepoints.add(end_time)
        x_plot, y_plot = [], []
        for tp in sorted(timepoints):
            before = 0
            after = 0
            for seg in segments.values():
                if seg['start'] < tp < seg['stop']:
                    # tp lies inside this (active) segment: add its linearly
                    # interpolated voltage on both sides of the boundary.
                    frac = (tp - seg['start']) / (seg['stop'] - seg['start'])
                    value = seg['v_start'] + frac * (
                        seg['v_stop'] - seg['v_start'])
                    before += value
                    after += value
                elif seg['start'] == tp:
                    after += seg['v_start']
                elif seg['stop'] == tp:
                    before += seg['v_stop']
            x_plot += [tp, tp]
            y_plot += [before, after]
        return x_plot, y_plot

    @staticmethod
    def _pulse_trace(segments, baseband):
        """Return ``(x, y)`` points of sinusoidal pulse segments, each
        followed by a zero sample at its stop time.

        NOTE(review): frequencies are presumably in Hz and times in ns
        (hence the /1e9) -- confirm against the pulse metadata.
        """
        x, y = [], []
        for seg in segments.values():
            x_ar = np.arange(seg['start'], seg['stop'])
            xx_ar = x_ar - seg['start']
            f_rl = (seg['frequency'] - baseband) / 1e9
            y_ar = np.sin(2 * np.pi * f_rl * xx_ar +
                          seg['start_phase']) * seg['amplitude']
            x = x + list(x_ar) + [seg['stop']]
            y = y + list(y_ar) + [0]
        return x, y

    def updateMetaTabs(self):
        ''' Update metadata tree

        Rebuilds the tabs for ``self.dataset``: a 'gates' tab when the
        station snapshot has a parsable gates instrument, a 'metadata' tab,
        and an 'AWG Pulses' plot when the metadata contains 'pc0'.
        '''
        meta = self.dataset.metadata
        self.tabWidget.clear()

        try:
            gate_tree = self.parse_gates(
                meta['station']['instruments']['gates'])
        except (KeyError, TypeError, AttributeError):
            # No (parsable) gates instrument in this dataset.
            pass
        else:
            self.tabWidget.addTab(self._create_meta_tree(gate_tree), 'gates')

        self.tabWidget.addTab(self._create_meta_tree(meta), 'metadata')

        if 'pc0' in meta.keys():
            self.pulse_plot = pg.PlotWidget()
            self.pulse_plot.addLegend()
            self.pulse_plot.setLabel('left', 'Voltage', 'mV')
            self.pulse_plot.setLabel('bottom', 'Time', 'ns')

            # Local oscillator frequencies per channel, if recorded.
            baseband_freqs = meta.get('LOs', {})
            end_time = 0
            for name, pdict in meta['pc0'].items():
                if 'baseband' in name:
                    end_time = max([end_time] +
                                   [pulse['stop'] for pulse in pdict.values()])
            for (j, (name, pdict)) in enumerate(meta['pc0'].items()):
                legend_name = name.replace('_baseband',
                                           '').replace('_pulses', '')
                x_plot = list()
                y_plot = list()
                if 'baseband' in name:
                    x_plot, y_plot = self._baseband_trace(pdict, end_time)
                elif 'pulses' in name:
                    try:
                        baseband = baseband_freqs[name.replace('_pulses', '')]
                    except (KeyError, TypeError):
                        logging.warning(
                            'No baseband frequency found, assuming 0')
                        baseband = 0
                    x_plot, y_plot = self._pulse_trace(pdict, baseband)
                self.pulse_plot.plot(x_plot,
                                     y_plot,
                                     pen=self.color_list[j %
                                                         len(self.color_list)],
                                     name=legend_name)

            self.tabWidget.addTab(self.pulse_plot, 'AWG Pulses')

    def pptCallback(self):
        """Send the current dataset and plot to PowerPoint."""
        if self.dataset is None:
            print('no data selected')
            return
        addPPT_dataset(self.dataset, customfig=self.qplot)

    def clipboardCallback(self):
        """Copy the current plot image to the clipboard."""
        self.qplot.copyToClipboard()

    def getArrayStr(self, metadata):
        """Return a one-line summary of the sweeps and measured arrays in
        ``metadata``, or 'info about plot' if it cannot be parsed."""
        params = []
        infotxt = ''
        try:
            if 'loop' in metadata.keys():
                sv = metadata['loop']['sweep_values']
                params.append(
                    '%s [%.2f to %.2f %s]' %
                    (sv['parameter']['label'], sv['values'][0]['first'],
                     sv['values'][0]['last'], sv['parameter']['unit']))

                for act in metadata['loop']['actions']:
                    if 'sweep_values' in act.keys():
                        sv = act['sweep_values']
                        params.append(
                            '%s [%.2f - %.2f %s]' %
                            (sv['parameter']['label'],
                             sv['values'][0]['first'], sv['values'][0]['last'],
                             sv['parameter']['unit']))
                infotxt = ', '.join(params) + ' | '
            infotxt = infotxt + ', '.join([('%s' % (v['label'])) for (
                k, v) in metadata['arrays'].items() if not v['is_setpoint']])

        except Exception:
            infotxt = 'info about plot'

        return infotxt

    def split_dataset(self):
        """Re-plot every selected 2D array as two subplots holding its odd
        and even rows respectively (up to the first all-NaN-terminated
        row)."""
        to_split = []
        for bp in self.subpl_ind.keys():
            plot_shape = getattr(self.dataset, bp).shape
            if len(plot_shape) == 2:
                to_split.append(bp)
        self.clearPlots()
        for (i, zname) in enumerate(to_split):
            tmp = getattr(self.dataset, zname)
            yname = tmp.set_arrays[0].array_id
            xname = tmp.set_arrays[1].array_id

            try:
                # First row whose last column is NaN (i.e. not yet measured).
                ii = np.where(
                    np.isnan(self.dataset.arrays[zname][:, -1]))[0][0]
            except (IndexError, TypeError):
                ii = len(self.dataset.arrays[yname][:])
            even = list(range(0, ii, 2))
            odd = list(range(1, ii, 2))

            self.qplot.add(self.dataset.arrays[xname][0],
                           self.dataset.arrays[yname][odd],
                           self.dataset.arrays[zname][odd],
                           ylabel=self.dataset.arrays[yname].label,
                           xlabel=self.dataset.arrays[xname].label,
                           zlabel=self.dataset.arrays[zname].label + '_odd',
                           yunit=self.dataset.arrays[yname].unit,
                           zunit=self.dataset.arrays[zname].unit,
                           subplot=2 * i + 1)
            self.qplot.add(self.dataset.arrays[xname][0],
                           self.dataset.arrays[yname][even],
                           self.dataset.arrays[zname][even],
                           ylabel=self.dataset.arrays[yname].label,
                           xlabel=self.dataset.arrays[xname].label,
                           zlabel=self.dataset.arrays[zname].label + '_even',
                           yunit=self.dataset.arrays[yname].unit,
                           zunit=self.dataset.arrays[zname].unit,
                           subplot=2 * i + 2)
            self.splitted = True

    def findfiles(self, path, extension):
        """Return full paths of all files below ``path`` (recursively) whose
        name ends with ``'.' + extension``."""
        suffix = '.' + extension
        filelist = list()
        for dirname, _, filenames in os.walk(path):
            for filename in filenames:
                if filename.endswith(suffix):
                    filelist.append(os.path.join(dirname, filename))
        return filelist
class Distortion_corrector(Instrument):
    def __init__(self,
                 name,
                 nr_plot_points: int = 1000,
                 sampling_rate: float = 2.4e9,
                 auto_save_plots: bool = True,
                 **kw):
        '''
        Instantiates an object.

        Args:
            name (str):
                    Name of the instrument.
            nr_plot_points (int):
                    Number of points of the waveform that are plotted. Can be
                    changed in self.cfg_nr_plot_points().
            sampling_rate (float):
                    Sampling rate (presumably in Hz). Stored as
                    self.sampling_rate and as the initial value of
                    self.cfg_sampling_rate().
            auto_save_plots (bool):
                    If True, the interactive loop calls self.save_plot after a
                    fit is accepted and after a correction is accepted.
            **kw:
                    Passed on to the Instrument base class.

        The distortion-kernel instrument that receives the corrections is
        selected afterwards through the instr_dist_kern parameter.
        '''
        super().__init__(name, **kw)
        # Initialize instance variables
        # Plotting
        self._y_min = 0
        self._y_max = 1
        self._stop_idx = -1
        self._start_idx = 0
        self._t_start_loop = 0  # sets x range for plotting during loop
        self._t_stop_loop = 30e-6
        self.add_parameter('cfg_nr_plot_points',
                           initial_value=nr_plot_points,
                           parameter_class=ManualParameter)
        # NOTE(review): the fit methods use self.scope_sampling_rate, not this
        # value; presumably a subclass sets that attribute — confirm.
        self.sampling_rate = sampling_rate
        self.add_parameter('cfg_sampling_rate',
                           initial_value=sampling_rate,
                           parameter_class=ManualParameter)
        self.add_parameter('instr_dist_kern',
                           parameter_class=InstrumentRefParameter)

        # Files
        self.filename = ''
        # where traces and plots are saved
        # self.data_dir = self.kernel_object.kernel_dir()
        self._iteration = 0
        self.auto_save_plots = auto_save_plots

        # Data
        self.waveform = []
        self.time_pts = []
        self.new_step = []

        # Fitting
        self.known_fit_models = ['exponential', 'high-pass', 'spline']
        self.fit_model = None
        self.edge_idx = None
        self.fit_res = None
        self.predicted_waveform = None

        # Default fit model used in the interactive loop
        self._fit_model_loop = 'exponential'

        # Help text printed by the interactive loop; every entry must end in
        # '\n' because the literals are concatenated.
        self._loop_helpstring = str(
            'h:      Print this help.\n'
            'q:      Quit the loop.\n'
            'm:      Remeasures the trace. \n'
            'p <pars>:\n'
            '        Print the parameters of the last fit if pars not given.\n'
            '        If pars are given in the form of JSON string, \n'
            '        e.g., {"parA": a, "parB": b} the parameters of the last\n'
            '        fit are updated with those provided.\n'
            's <filename>:\n'
            '        Save the current plot to "filename.png".\n'
            'model <name>:\n'
            '        Choose the fit model that is used.\n'
            '        Available models:\n'
            '           ' + str(self.known_fit_models) + '\n'
            'xrange <min> <max>:\n'
            '        Set the x-range of the plot to (min, max). The points\n'
            '        outside this range are not plotted. The number of\n'
            '        points plotted in the given interval is fixed to\n'
            '        self.cfg_nr_plot_points() (default=1000).\n'
            'square_amp <amp> \n'
            '        Set the square_amp used to normalize measured waveforms.\n'
            '        If amp = "?" the current square_amp is printed.')

        # Make window for plots
        self.vw = QtPlot(window_title=name, figsize=(600, 400))

    # def load_kernel_file(self, filename):
    #     '''
    #     Loads kernel dictionary (containing kernel and metadata) from a JSON
    #     file. This function looks only in the directory
    #     self.kernel_object.kernel_dir() for the file.
    #     Returns a dictionary of the kernel and metadata.
    #     '''
    #     with open(os.path.join(self.kernel_object.kernel_dir(),
    #                            filename)) as infile:
    #         data = json.load(infile)
    #     return data

    # def save_kernel_file(self, kernel_dict, filename):
    #     '''
    #     Saves kernel dictionary (containing kernel and metadata) to a JSON
    #     file in the directory self.kernel_object.kernel_dir().
    #     '''
    #     directory = self.kernel_object.kernel_dir()
    #     if not os.path.exists(directory):
    #         os.makedirs(directory)
    #     with open(os.path.join(directory, filename),
    #               'w') as outfile:
    #         json.dump(kernel_dict, outfile, indent=True, sort_keys=True)

    def save_plot(self, filename):
        '''
        Best-effort saving of the current plot window as ``filename``.

        Currently this only makes sure the kernel directory
        (self.kernel_object.kernel_dir()) exists. Writing the image itself is
        disabled (see FIXME below). Any error is logged as a warning that
        includes its cause, rather than raised, so callers such as the
        interactive loop keep running.

        NOTE(review): self.kernel_object is not set in __init__ (that line is
        commented out), so unless a subclass provides it this always ends in
        the warning branch — confirm.
        '''
        try:
            directory = self.kernel_object.kernel_dir()
            os.makedirs(directory, exist_ok=True)
            # FIXME: saving disabled as it is currently broken.
            # self.vw.save(os.path.join(self.kernel_object.kernel_dir(),
            #                           filename))
        except Exception as e:
            # Deliberately broad (saving is optional), but report the reason.
            logging.warning('Could not save plot: {}'.format(e))

    # def open_new_correction(self, kernel_length, AWG_sampling_rate, name):
    #     '''
    #     Opens a new correction with name 'filename', i.e. initializes the
    #     combined kernel to a Dirac delta and empties kernel_list of the
    #     kernel object associated with self.

    #     Args:
    #         kernel_length (float):
    #                 Length of the corrections kernel in s.
    #         AWG_sampling_rate (float):
    #                 Sampling rate of the AWG generating the flux pulses in Hz.
    #         name (string):
    #                 Name for the new kernel. The files will be named after
    #                 this, but with different suffixes (e.g. '_combined.json').
    #     '''
    #     self.kernel_length = int(kernel_length * AWG_sampling_rate)
    #     self.filename = name
    #     self._iteration = 0

    #     # Initialize kernel to Dirac delta
    #     init_ker = np.zeros(self.kernel_length)
    #     init_ker[0] = 1
    #     self.kernel_combined_dict = {
    #         'metadata': {},  # dictionary of kernel dictionaries
    #         'kernel': list(init_ker),
    #         'iteration': 0
    #     }
    #     self.save_kernel_file(self.kernel_combined_dict,
    #                           '{}_combined.json'.format(self.filename))

    #     # Configure kernel object
    #     self.kernel_object.add_kernel_to_kernel_list(
    #         '{}_combined.json'.format(self.filename))

    # def resume_correction(self, filename):
    #     '''
    #     Loads combined kernel from the specified file and prepares for adding
    #     new corrections to that kernel.
    #     '''
    #     # Remove '_combined.json' from filename
    #     self.filename = '_'.join(filename.split('_')[:-1])
    #     self.kernel_combined_dict = self.load_kernel_file(filename)
    #     self._iteration = self.kernel_combined_dict['iteration']
    #     self.kernel_length = len(self.kernel_combined_dict['kernel'])

    #     # Configure kernel object
    #     self.kernel_object.kernel_list([])
    #     self.kernel_object.add_kernel_to_kernel_list(filename)

    # def empty_kernel_list(self):
    #     self.kernel_object.kernel_list([])

    def measure_trace(self, verbose=True):
        '''
        Measure a new trace.

        The base class has no hardware attached, so this always raises.
        Judging by interactive_loop and plot_trace, subclasses are expected
        to fill self.time_pts and self.waveform here — confirm per subclass.

        Raises:
            NotImplementedError: always, in this base class.
        '''
        msg = ('Base class is not attached to physical instruments and does '
               'not implement measurements.')
        raise NotImplementedError(msg)

    def fit_exp_model(self, start_time_fit, end_time_fit):
        '''
        Fit an exponential decay to the last measured trace (self.waveform).

        The model is fm.gain_corr_ExpDecayFunc with the free parameters
        'gc', 'amp' and 'tau'. The original description of its form is
        A * exp(-t/tau) + offset.

        Side effects:
            self._start_idx / self._stop_idx: sample indices closest to the
                requested start / end time (the fit uses [start, stop)).
            self.fit_model, self.fit_res: the lmfit model and its result.
            self.fitted_waveform: the fit evaluated on the fitted interval.
            self.predicted_waveform: kf.exponential_decay_correction applied
                to self.waveform with the fitted tau and amp.

        Prints a warning if the fitted tau is negative.

        NOTE(review): uses self.scope_sampling_rate, which this class does
        not set itself; presumably provided by a subclass.

        Args:
            start_time_fit (float): start of the fitted interval
            end_time_fit (float):   end of the fitted interval
        '''
        t_pts = self.time_pts
        wave = self.waveform
        self._start_idx = np.argmin(np.abs(t_pts - start_time_fit))
        self._stop_idx = np.argmin(np.abs(t_pts - end_time_fit))
        i_first, i_last = self._start_idx, self._stop_idx
        fit_window = slice(i_first, i_last)

        # Initial guesses: gain from the end of the window, amplitude from
        # the drop over the window, tau from the window length.
        self.fit_model = lmfit.Model(fm.gain_corr_ExpDecayFunc)
        model = self.fit_model
        model.set_param_hint('gc', value=wave[i_last], vary=True)
        model.set_param_hint('amp', value=(wave[i_first] - wave[i_last]),
                             vary=True)
        model.set_param_hint('tau', value=end_time_fit - start_time_fit,
                             vary=True)
        start_params = model.make_params()

        result = model.fit(data=wave[fit_window],
                           t=t_pts[fit_window],
                           params=start_params)
        self.fitted_waveform = result.eval(t=t_pts[fit_window])

        best = result.best_values
        fitted_amp = best['amp']
        fitted_tau = best['tau']
        if fitted_tau < 0:
            print('Warning: unphysical tau = {} (expect tau > 0).'.format(
                fitted_tau))

        self.fit_res = result
        self.predicted_waveform = kf.exponential_decay_correction(
            wave,
            tau=fitted_tau,
            amp=fitted_amp,
            sampling_rate=self.scope_sampling_rate)

    def fit_high_pass(self, start_time_fit, end_time_fit):
        '''
        Fit a simple RC high-pass model, exp(-t/tau) with tau = RC, to the
        last measured trace (self.waveform).

        The model is fm.ExpDecayFunc with 'offset' fixed to 0 and 'n' fixed
        to 1. Only 'tau' and 'amplitude' are varied.

        Side effects:
            self._start_idx / self._stop_idx: sample indices closest to the
                requested start / end time (the fit uses [start, stop)).
            self.fit_model, self.fit_res: the lmfit model and its result.
            self.fitted_waveform: the fit evaluated on the fitted interval.
            self.predicted_waveform: kf.bias_tee_correction applied to
                self.waveform with the fitted tau.

        Prints a warning if the fitted tau is negative.

        NOTE(review): uses self.scope_sampling_rate, which this class does
        not set itself; presumably provided by a subclass.

        Args:
            start_time_fit (float): start of the fitted interval
            end_time_fit (float):   end of the fitted interval
        '''
        t_pts = self.time_pts
        self._start_idx = np.argmin(np.abs(t_pts - start_time_fit))
        self._stop_idx = np.argmin(np.abs(t_pts - end_time_fit))
        fit_window = slice(self._start_idx, self._stop_idx)

        self.fit_model = lmfit.Model(fm.ExpDecayFunc)
        # (name, initial value, varied?) in the order they are registered.
        hints = (
            ('tau', end_time_fit - start_time_fit, True),
            ('offset', 0, False),
            ('amplitude', 1, True),
            ('n', 1, False),
        )
        for par_name, init_val, is_free in hints:
            self.fit_model.set_param_hint(par_name, value=init_val,
                                          vary=is_free)
        start_params = self.fit_model.make_params()

        result = self.fit_model.fit(data=self.waveform[fit_window],
                                    t=t_pts[fit_window],
                                    params=start_params)
        self.fitted_waveform = result.eval(t=t_pts[fit_window])

        fitted_tau = result.best_values['tau']
        if fitted_tau < 0:
            print('Warning: unphysical tau = {} (expect tau > 0).'.format(
                fitted_tau))

        self.fit_res = result
        self.predicted_waveform = kf.bias_tee_correction(
            self.waveform, tau=fitted_tau,
            sampling_rate=self.scope_sampling_rate)

    def fit_spline(self,
                   start_time_fit,
                   end_time_fit,
                   s=0.001,
                   weight_tau='inf'):
        '''
        Fit the data using a spline interpolation and compute the kernel
        that inverts the fitted step response.

        Side effects:
            self._start_idx / self._stop_idx: sample indices closest to the
                requested start / end time.
            self.fit_model, self.fit_res: set to None (no lmfit model).
            self.fitted_waveform: the spline on the fitted interval.
            self.new_step: step response of the new kernel.
            self.new_kernel_dict: the new kernel plus the fit settings.

        Args:
            start_time_fit (float):
                    Start of the fitted interval.
            end_time_fit (float):
                    End of the fitted interval.
            s (float):
                    Smoothing condition for the spline. See documentation on
                    scipy.interpolate.splrep for more information.
            weight_tau (float, 'auto' or 'inf'):
                    The points are weighted by a decaying exponential with
                    time constant weight_tau.
                    If this is 'auto' the time constant is chosen to be
                    end_time_fit.
                    If this is 'inf' (default) all weights are set to 1.
                    Smaller weight means the spline can have a larger
                    distance from this point. See documentation on
                    scipy.interpolate.splrep for more information.
        '''
        self._start_idx = np.argmin(np.abs(self.time_pts - start_time_fit))
        self._stop_idx = np.argmin(np.abs(self.time_pts - end_time_fit))

        if weight_tau == 'auto':
            weight_tau = end_time_fit

        if weight_tau == 'inf':
            splWeights = np.ones(self._stop_idx - self._start_idx)
        else:
            splWeights = np.exp(
                -self.time_pts[self._start_idx:self._stop_idx] / weight_tau)

        splTuple = sc_intpl.splrep(
            x=self.time_pts[self._start_idx:self._stop_idx],
            y=self.waveform[self._start_idx:self._stop_idx],
            w=splWeights,
            s=s)
        splStep = sc_intpl.splev(self.time_pts[self._start_idx:self._stop_idx],
                                 splTuple,
                                 ext=3)

        # Pad step response with avg of last 10 points (assuming the user has
        # chosen the range such that the response has become flat)
        # NOTE(review): self.kernel_length is not assigned anywhere in this
        # class (open_new_correction, which set it, is commented out);
        # presumably a subclass or the caller provides it — confirm.
        splStep = np.concatenate(
            (splStep, np.ones(self.kernel_length - len(splStep)) *
             np.mean(splStep[-10:])))

        self.fit_res = None
        self.fit_model = None
        self.fitted_waveform = splStep[:self._stop_idx - self._start_idx]

        # Impulse response h of the distortion: discrete derivative of the
        # (padded) step response.
        h = np.empty_like(splStep)
        h[0] = splStep[0]
        h[1:] = splStep[1:] - splStep[:-1]

        # Convolution with h is multiplication by the lower-triangular
        # Toeplitz matrix F[n, m] = h[n - m] (m <= n). The correction kernel
        # is the impulse response of the inverse filter, i.e. the first
        # column of F^-1, which solves F @ k = e_0. Building F with toeplitz
        # and solving the triangular system avoids a Python double loop and
        # an explicit matrix inverse.
        first_row = np.zeros_like(h)
        first_row[0] = h[0]
        filterMatrix = scipy.linalg.toeplitz(h, first_row)
        unit_impulse = np.zeros_like(h)
        unit_impulse[0] = 1
        new_ker = scipy.linalg.solve_triangular(filterMatrix, unit_impulse,
                                                lower=True)

        # Step response of the new kernel = running sum of its impulse
        # response (same as convolving with a step of length len(splStep)).
        self.new_step = np.cumsum(new_ker)
        self.new_kernel_dict = {
            'name': self.filename + '_' + str(self._iteration),
            'filter_params': {},
            'fit': {
                'model': 'spline',
                's': s,
                'weight_tau': weight_tau
            },
            'kernel': list(new_ker)
        }

    def plot_trace(self,
                   start_time=-.5e-6,
                   stop_time=10e-6,
                   nr_plot_pts=4000,
                   save_y_range=True):
        '''
        Plot the last measured trace (self.waveform) in self.vw.

        Also plots self.predicted_waveform (if set) and three grey horizontal
        reference lines at +waveform[stop], 0 and -waveform[stop], where
        "stop" is the sample closest to stop_time.

        Args:
            start_time (float):  Start of the plotted interval.
            stop_time (float):   End of the plotted interval.
            nr_plot_pts (int):   The data are decimated so that roughly this
                                 many points per curve are drawn.
            save_y_range (bool): Keep the current y-range of the plot.
        '''
        t_all = self.time_pts
        i_start = np.argmin(np.abs(t_all - start_time))
        i_stop = np.argmin(np.abs(t_all - stop_time))
        stride = max(int(len(t_all[i_start:i_stop]) // nr_plot_pts), 1)
        window = slice(i_start, i_stop, stride)

        # Remember the current view range (if a window is open) so that it
        # can be restored after clearing.
        try:
            _, y_range = self.vw.subplots[0].getViewBox().viewRange()
        except Exception as exc:
            print(exc)
            restore_y = False
        else:
            restore_y = save_y_range

        measured_t = t_all[:len(self.waveform)]

        self.vw.clear()
        self.vw.add(x=measured_t[window],
                    y=self.waveform[window],
                    symbol='o',
                    symbolSize=5,
                    name='Measured waveform')

        if self.predicted_waveform is not None:
            self.vw.add(x=t_all[window],
                        y=self.predicted_waveform[window],
                        name='Predicted waveform')

        ref_level = self.waveform[i_stop]
        for level in (ref_level, 0, -ref_level):
            self.vw.add(x=[start_time, stop_time],
                        y=[level] * 2,
                        color=(150, 150, 150))

        if restore_y:
            self.vw.subplots[0].setYRange(y_range[0], y_range[1])

        # Axis labels go last; set earlier they sometimes do not show.
        axes_plot = self.vw.subplots[0]
        axes_plot.getAxis('bottom').setLabel('t', 's')
        axes_plot.getAxis('left').setLabel('Amplitude', 'V')

    def plot_fit(self,
                 start_time=0,
                 stop_time=10e-6,
                 save_y_range=True,
                 nr_plot_pts=4000):
        '''
        Plot last trace that was measured (self.waveform) and the latest fit.

        The trace is drawn by self.plot_trace. The fit (self.fitted_waveform)
        is then overlaid on the fitted interval
        self.time_pts[self._start_idx:self._stop_idx].

        Args:
            start_time (float): Start of the plotted interval.
            stop_time (float):  End of the plotted interval.
            save_y_range (bool):
                                Keep the current y-range of the plot.
            nr_plot_pts (int):  Passed to plot_trace; the trace is decimated
                                to roughly this many points.
        '''
        self.plot_trace(start_time=start_time,
                        stop_time=stop_time,
                        save_y_range=save_y_range,
                        nr_plot_pts=nr_plot_pts)

        self.vw.add(x=self.time_pts[self._start_idx:self._stop_idx],
                    y=self.fitted_waveform,
                    color='#2ca02c',
                    name='Fit')

        # Labels need to be set in the end, else they don't show sometimes.
        # Use the same labels as plot_trace so the axis does not change name
        # when the fit is added.
        self.vw.subplots[0].getAxis('bottom').setLabel('t', 's')
        self.vw.subplots[0].getAxis('left').setLabel('Amplitude', 'V')

    def test_new_kernel(self):
        '''
        Write the latest fit as a filter model to the distortion-kernel
        instrument (self.instr_dist_kern).

        The model goes into the slot 'filter_model_XX', where XX is the
        current correction number self._iteration. The model is built from
        self.fit_res.best_values according to self._fit_model_loop:
            'high-pass':   {'model': 'high-pass', 'params': {'tau': tau}}
            'exponential': {'model': 'exponential',
                            'params': {'tau': tau, 'amp': amp}}

        Raises:
            NotImplementedError: for any other fit model (e.g. 'spline').
        '''
        dist_kern = self.instr_dist_kern.get_instr()

        # best_values is only read inside the supported branches: for the
        # spline model self.fit_res is None.
        if self._fit_model_loop == 'high-pass':
            tau = self.fit_res.best_values['tau']
            model = {'model': 'high-pass', 'params': {'tau': tau}}
        elif self._fit_model_loop == 'exponential':
            best_values = self.fit_res.best_values
            model = {
                'model': 'exponential',
                'params': {
                    'tau': best_values['tau'],
                    'amp': best_values['amp']
                }
            }
        else:
            raise NotImplementedError(
                'Fit model "{}" cannot be converted to a filter model; only '
                '"high-pass" and "exponential" are supported.'.format(
                    self._fit_model_loop))

        dist_kern.set('filter_model_{:02}'.format(self._iteration), model)

    def apply_new_kernel(self):
        '''
        Mark the current correction as completed.

        This only advances the correction counter self._iteration. The next
        correction is therefore written to the next 'filter_model_XX' slot
        and does not overwrite the one just accepted.
        '''
        self._iteration = self._iteration + 1

    def discard_new_kernel(self):
        '''
        Remove the correction that was last added for testing.

        The current 'filter_model_XX' slot (XX = self._iteration) of the
        distortion-kernel instrument is set to an empty dict.
        '''
        kernel_instr = self.instr_dist_kern.get_instr()
        slot_name = 'filter_model_{:02}'.format(self._iteration)
        kernel_instr.set(slot_name, {})

    def interactive_loop(self):
        '''
        Starts interactive loop to iteratively add corrections.

        Loop:
            1. Measure trace and plot
            2. Fit and plot
            3. Test correction and plot
               -> discard: back to 2.
               -> approve: continue with 4.
            4. Apply correction
               -> quit?
               -> back to 2.

        At every prompt the commands of self._loop_helpstring are available
        (handled by self._handle_interactive_input). 'q' exits the loop.
        '''
        print('********\n'
              'Interactive room-temperature distortion corrections\n'
              '********\n'
              'At any prompts you may use these commands:\n' +
              self._loop_helpstring)

        # Empty input keeps the existing filters (non-destructive default),
        # so the prompt advertises 'n' as the default.
        while True:
            inp = input('New kernel? (y/[n]) ')
            if inp in ['y', 'n', '']:
                break
        if inp == 'y':
            print('Resetting all kernels in kernel object')
            self.instr_dist_kern.get_instr().reset_kernels()
            self._iteration = 0
        else:
            # Continue working with current kernel; determine how many filters
            # already exist.
            self._iteration = self.instr_dist_kern.get_instr(
            ).get_first_empty_filter()
            print('Starting from iteration {}'.format(self._iteration))

        # 1. Measure trace and plot
        self.measure_trace()

        # Set up initial plot range
        self._t_start_loop = 0
        self._t_stop_loop = self.time_pts[-1]

        self.plot_trace(self._t_start_loop,
                        self._t_stop_loop,
                        save_y_range=False,
                        nr_plot_pts=self.cfg_nr_plot_points())

        # LOOP STARTS HERE
        # Default fit model used, high-pass is typically the first model
        self._fit_model_loop = 'high-pass'
        while True:
            print('\n-- Correction number {} --'.format(self._iteration))
            print('Current fit model: {}'.format(self._fit_model_loop))
            # 2. Fit and plot
            repeat = True
            while repeat:
                inp = input('Fit range: ')
                repeat, quit_loop = self._handle_interactive_input(inp, 'any')
                if not quit_loop and not repeat:
                    # split() without argument tolerates repeated whitespace.
                    fit_range = inp.split()
                    try:
                        fit_start = float(fit_range[0])
                        fit_stop = float(fit_range[1])
                    except (IndexError, ValueError):
                        print('input format: "t_start t_stop"')
                        repeat = True
            if quit_loop:
                # Exit loop
                break

            if self._fit_model_loop == 'exponential':
                self.fit_exp_model(fit_start, fit_stop)
            elif self._fit_model_loop == 'high-pass':
                self.fit_high_pass(fit_start, fit_stop)
            elif self._fit_model_loop == 'spline':
                self.fit_spline(fit_start, fit_stop)

            self.plot_fit(self._t_start_loop,
                          self._t_stop_loop,
                          nr_plot_pts=self.cfg_nr_plot_points())

            repeat = True
            while repeat:
                inp = input('Accept? ([y]/n) ').strip()
                repeat, quit_loop = self._handle_interactive_input(
                    inp, ['y', 'n', ''])
            if quit_loop:
                # Exit loop
                break
            elif inp != 'y' and inp != '':
                # Go back to 2.
                continue

            # Fit was accepted -> save plot
            if self.auto_save_plots:
                self.save_plot('fit_{}.png'.format(self._iteration))

            # 3. Test correction and plot
            # Save last data, in case new distortion is rejected.
            previous_t = self.time_pts
            previous_wave = self.waveform

            print('Testing new correction.')
            try:
                self.test_new_kernel()
            except NotImplementedError:
                # e.g. the 'spline' model has no filter representation; go
                # back to 2. instead of aborting the whole session.
                print('Fit model "{}" cannot be tested as a correction; '
                      'choose another model.'.format(self._fit_model_loop))
                continue
            self.measure_trace()
            self.plot_trace(self._t_start_loop,
                            self._t_stop_loop,
                            nr_plot_pts=self.cfg_nr_plot_points())

            repeat = True
            while repeat:
                inp = input('Accept? ([y]/n) ').strip()
                repeat, quit_loop = self._handle_interactive_input(
                    inp, ['y', 'n', ''])
            if quit_loop:
                # Exit loop
                break
            elif inp != 'y' and inp != '':
                print('Discarding new correction.')
                self.discard_new_kernel()
                self.time_pts = previous_t
                self.waveform = previous_wave
                self.plot_trace(self._t_start_loop,
                                self._t_stop_loop,
                                nr_plot_pts=self.cfg_nr_plot_points())
                # Go back to 2.
                continue

            # Correction was accepted -> save plot
            if self.auto_save_plots:
                self.save_plot('trace_{}.png'.format(self._iteration))

            # 4. Apply correction
            print('Applying new correction.')
            self.apply_new_kernel()

    def _handle_interactive_input(self, inp, valid_inputs):
        '''
        Handles input from user in an interactive loop session. Takes
        action in special cases (the commands listed in
        self._loop_helpstring).

        Args:
            inp (string):   Input given by the user.
            valid_inputs (list of strings or 'any'):
                            List of inputs that are accepted. Any input is
                            accepted if this is 'any'.

        Returns:
            repeat (bool):  Should the input prompt be repeated.
            quit (bool):    Should the loop be exited.
        '''
        repeat = True
        quit_loop = False

        inp_elements = inp.split(' ')
        command = inp_elements[0]

        if command.lower() == 'xrange' and len(inp_elements) == 3:
            # Parse both bounds before changing state; a typo must not
            # abort the interactive session.
            try:
                t_start = float(inp_elements[1])
                t_stop = float(inp_elements[2])
            except ValueError:
                print('input format: "xrange <min> <max>"')
            else:
                self._t_start_loop = t_start
                self._t_stop_loop = t_stop
                if len(self.vw.traces) == 4:  # 3 grey lines + 1 data trace
                    # Only data plotted
                    self.plot_trace(self._t_start_loop,
                                    self._t_stop_loop,
                                    nr_plot_pts=self.cfg_nr_plot_points())
                else:
                    # Fit also plotted
                    self.plot_fit(self._t_start_loop,
                                  self._t_stop_loop,
                                  nr_plot_pts=self.cfg_nr_plot_points())

        elif command == 'm':
            # Remeasures the trace
            print('Remeasuring trace')
            self.measure_trace()
            self.plot_trace(self._t_start_loop,
                            self._t_stop_loop,
                            save_y_range=False,
                            nr_plot_pts=self.cfg_nr_plot_points())

        elif command == 'h':
            print(self._loop_helpstring)

        elif command == 'q':
            self.print_summary()
            quit_loop = True
            repeat = False

        elif command == 'p':
            if len(inp_elements) == 1:
                # fit_res is None before the first fit and after a spline
                # fit; accessing best_values then raises AttributeError.
                if self.fit_res is None:
                    print('No fit has been done yet!')
                else:
                    print(self.fit_res.best_values)
            else:
                self._update_latest_params(json_string=inp[1:])

        elif command == 's' and len(inp_elements) == 2:
            self.save_plot('{}.png'.format(inp_elements[1]))
            print('Current plot saved.')

        elif command == 'model' and len(inp_elements) == 2:
            if inp_elements[1] in self.known_fit_models:
                self._fit_model_loop = str(inp_elements[1])
                print('Using fit model "{}".'.format(self._fit_model_loop))
            else:
                print('Model "{}" unknown. Please choose from {}.'.format(
                    inp_elements[1], self.known_fit_models))

        elif valid_inputs != 'any':
            if inp not in valid_inputs:
                print('Valid inputs: {}'.format(valid_inputs))
            else:
                repeat = False

        else:
            # Any input ok
            repeat = False

        return repeat, quit_loop

    def _update_latest_params(self, json_string):
        """
        Uses a JSON formatted string to update the parameters of the
        latest fit and re-plots.

        Steps:
            1. merge the parsed values into self.fit_res.best_values,
            2. re-evaluate the fit curve (self.fitted_waveform) with all
               updated parameter values,
            3. recompute self.predicted_waveform (high-pass and exponential
               models only),
            4. plot the new "fit".

        Args:
            json_string (str): JSON object, e.g. '{"tau": 1e-6}'.

        Invalid input, or calling this before any lmfit fit was done, only
        prints a message.
        """
        if self.fit_res is None:
            print('No fit has been done yet!')
            return

        try:
            par_dict = json.loads(json_string)
        except (TypeError, ValueError) as e:
            print(e)
            return
        if not isinstance(par_dict, dict):
            print('Parameters must be given as a JSON object, '
                  'e.g. {"tau": 1e-6}')
            return

        # 1. update the 'fit' dict
        best_values = self.fit_res.best_values
        best_values.update(par_dict)

        # 2. Pass every fit parameter (not only tau) to eval so that updates
        # of e.g. 'amp' show up in the plotted fit. Keys that are not fit
        # parameters are dropped so they cannot clash with eval's own args.
        param_overrides = {par_name: val
                           for par_name, val in best_values.items()
                           if par_name in self.fit_res.params}
        self.fitted_waveform = self.fit_res.eval(
            t=self.time_pts[self._start_idx:self._stop_idx],
            **param_overrides)

        # 3. Recompute the predicted (corrected) waveform
        if self._fit_model_loop == 'high-pass':
            self.predicted_waveform = kf.bias_tee_correction(
                self.waveform,
                tau=best_values['tau'],
                sampling_rate=self.scope_sampling_rate)

        elif self._fit_model_loop == 'exponential':
            self.predicted_waveform = kf.exponential_decay_correction(
                self.waveform,
                tau=best_values['tau'],
                amp=best_values['amp'],
                sampling_rate=self.scope_sampling_rate)

        # 4. The fit results still have to be updated
        self.plot_fit(self._t_start_loop,
                      self._t_stop_loop,
                      nr_plot_pts=self.cfg_nr_plot_points())

    def print_summary(self):
        '''
        Print an overview of the corrections.

        The printing is delegated to the print_overview method of the
        distortion-kernel instrument referenced by self.instr_dist_kern.
        '''
        kernel_instr = self.instr_dist_kern.get_instr()
        kernel_instr.print_overview()

    def _set_square_amp(self, square_amp: float):
        '''
        Replace self.square_amp by ``square_amp`` and re-plot.

        If a waveform has been measured, it is rescaled by old / new
        amplitude. The help text describes square_amp as the amplitude used
        to normalize measured waveforms, so this keeps the stored waveform
        consistent with the new normalization. The trace is then re-plotted
        over the current loop range and the change is printed.

        NOTE(review): self.square_amp is not initialized in this class's
        __init__; presumably a subclass sets it before this is called —
        confirm.
        '''
        old_square_amp = self.square_amp
        self.square_amp = square_amp
        # Undo the old normalization and apply the new one.
        if len(self.waveform) > 0:
            self.waveform = self.waveform * old_square_amp / self.square_amp
        self.plot_trace(self._t_start_loop,
                        self._t_stop_loop,
                        nr_plot_pts=self.cfg_nr_plot_points())
        print('Updated square amp from {} to {}'.format(
            old_square_amp, square_amp))