示例#1
0
def test_dataset_2ddata_loading(nt_dataset_doubledot, tmp_path):
    """Check that a 2D double-dot run loads with the expected data and labels.

    ``nt_dataset_doubledot`` is requested only for its side effect; it
    presumably stores run 1 in ``temp.db`` inside ``tmp_path`` (confirm
    against the fixture definition).
    """
    dataset = Dataset(1, db_name="temp.db", db_folder=str(tmp_path))

    assert dataset.exp_id == 1
    for readout in ("dc_current", "dc_sensor"):
        assert dataset.dimensions[readout] == 2
    assert len(dataset.raw_data) == 2
    assert len(dataset.data) == 2

    raw_current = dataset.raw_data["current"]
    loaded_vx = raw_current["v_x"].values
    loaded_vy = raw_current["v_y"].values
    loaded_current = raw_current.values
    loaded_sensor = dataset.raw_data["sensor"].values

    # Reference data: the stored axes are the unique sweep values and the
    # stored signals are compared against the transposed generated arrays.
    grid_x, grid_y, expected_current, expected_sensor = generate_doubledot_data()
    checks = (
        (loaded_vx, np.unique(grid_x)),
        (loaded_vy, np.unique(grid_y)),
        (loaded_current, expected_current.T),
        (loaded_sensor, expected_sensor.T),
    )
    for actual, expected in checks:
        assert np.allclose(actual, expected)

    expected_labels = (
        ("dc_current", 0, "voltage x [V]"),
        ("dc_sensor", 0, "voltage x [V]"),
        ("dc_current", 1, "voltage y [V]"),
        ("dc_sensor", 1, "voltage y [V]"),
        ("dc_current", 2, "dc current [A]"),
        ("dc_sensor", 2, "dc sensor [A]"),
    )
    for readout, dim, label in expected_labels:
        assert dataset.get_plot_label(readout, dim) == label
示例#2
0
def test_dataset_1ddata_loading(nt_dataset_pinchoff, tmp_path):
    """Check that a 1D pinch-off run loads with the expected data and labels.

    ``nt_dataset_pinchoff`` is requested only for its side effect; it
    presumably stores run 1 in ``temp.db`` inside ``tmp_path`` (confirm
    against the fixture definition).
    """
    dataset = Dataset(1, db_name="temp.db", db_folder=str(tmp_path))

    assert dataset.exp_id == 1
    for readout in ("dc_current", "dc_sensor"):
        assert dataset.dimensions[readout] == 1

    assert len(dataset.raw_data) == 2
    assert len(dataset.data) == 2

    # Reference curve: 120 voltage points on [-0.1, 0] with a tanh step.
    expected_voltage = np.linspace(-0.1, 0, 120)
    raw_current = dataset.raw_data["current"]
    assert np.allclose(raw_current["voltage"].values, expected_voltage)

    expected_signal = 0.6 * (1 + np.tanh(1000 * expected_voltage + 50))
    assert np.allclose(raw_current.values, expected_signal)

    expected_labels = (
        ("dc_current", 0, "voltage x [V]"),
        ("dc_sensor", 0, "voltage x [V]"),
        ("dc_current", 1, "dc current [A]"),
        ("dc_sensor", 1, "dc sensor [A]"),
    )
    for readout, dim, label in expected_labels:
        assert dataset.get_plot_label(readout, dim) == label

    # Index 2 does not exist for 1D data and is expected to be rejected.
    with pytest.raises(AssertionError):
        dataset.get_plot_label("dc_sensor", 2)
示例#3
0
def plot_dataset(
    qc_run_id: int,
    db_name: str,
    save_figures: bool = True,
    db_folder: Optional[str] = None,
    plot_filtered_data: bool = False,
    plot_params: Optional[plot_params_type] = None,
    ax: Optional[matplotlib.axes.Axes] = None,
    colorbar: Optional[matplotlib.colorbar.Colorbar] = None,
    filename: Optional[str] = None,
    file_location: Optional[str] = None,
) -> AxesTuple:
    """Plot every readout method of a dataset, one subplot row per readout.

    1D readouts are drawn as line plots; 2D readouts as ``pcolormesh`` color
    maps with a labelled colorbar. 1D rows get an invisible placeholder axis
    on the right so their width matches 2D rows that carry a colorbar.

    Args:
        qc_run_id: Run id of the dataset to load (passed to ``Dataset``).
        db_name: Database file name.
        save_figures: If True, save the figure as a png file.
        db_folder: Folder containing the database. Defaults to the folder
            returned by ``nt.get_database()``.
        plot_filtered_data: Plot ``dataset.filtered_data`` instead of
            ``dataset.data``.
        plot_params: Matplotlib rcParams to apply before plotting. Defaults
            to ``default_plot_params``.
        ax: Axes to draw into, indexed as ``ax[row, 0]`` with one row per
            readout method (e.g. the result of ``plt.subplots(...,
            squeeze=False)``). A new figure is created if None.
        colorbar: Existing colorbar whose axis is reused for every 2D
            readout. If None, a new colorbar axis is appended next to each
            2D plot.
        filename: Base name of the saved file; any extension is stripped.
            Defaults to ``"dataset_<guid>"``.
        file_location: Folder to save into. Defaults to
            ``os.path.join(nt.config['db_folder'], 'tuning_results',
            dataset.device_name)``; created if missing.

    Returns:
        The axes array and a list with one entry per readout method: the
        colorbar for 2D readouts, None for 1D readouts.

    Raises:
        NotImplementedError: If a readout method has more than 2 dimensions.
    """
    if plot_params is None:
        plot_params = default_plot_params
    matplotlib.rcParams.update(plot_params)
    if db_folder is None:
        _, db_folder = nt.get_database()

    dataset = Dataset(qc_run_id, db_name, db_folder=db_folder)
    data = dataset.filtered_data if plot_filtered_data else dataset.data

    if ax is None:
        # list() copies (so default_plot_params is never mutated) and also
        # accepts a tuple figsize, which cannot be scaled in place.
        fig_size = list(
            plot_params.get("figure.figsize",
                            matplotlib.rcParams["figure.figsize"])
        )
        fig_size[1] *= len(dataset.data) * 0.8
        _, ax = plt.subplots(  # type: ignore
            len(dataset.data),
            1,
            squeeze=False,
            figsize=fig_size,
        )

    # Created unconditionally: the caller may supply ``ax``.
    colorbars: List[Optional[matplotlib.colorbar.Colorbar]] = []
    fig_title = str(dataset.guid)
    x_name = default_coord_names['voltage'][0]

    for r_i, read_meth in enumerate(dataset.readout_methods):
        axis = ax[r_i, 0]
        voltage_x = data[read_meth][x_name].values
        signal = data[read_meth].values.T
        n_dims = dataset.dimensions[read_meth]

        if n_dims == 1:
            axis.plot(voltage_x, signal, zorder=6)

            # Invisible placeholder axis to keep row widths aligned with
            # 2D rows that have a colorbar.
            divider = make_axes_locatable(axis)
            cbar_ax = divider.append_axes("right", size="5%", pad=0.06)
            cbar_ax.set_facecolor("none")
            for side in ["top", "bottom", "left", "right"]:
                cbar_ax.spines[side].set_linewidth(0)
            cbar_ax.set_xticks([])
            cbar_ax.set_yticks([])
            colorbars.append(None)

        elif n_dims == 2:
            y_name = default_coord_names['voltage'][1]
            voltage_y = data[read_meth][y_name].values
            colormesh = axis.pcolormesh(
                voltage_x,
                voltage_y,
                signal,
                shading="auto",
            )

            if colorbar is not None:
                cbar_ax = colorbar.ax
            else:
                divider = make_axes_locatable(axis)
                cbar_ax = divider.append_axes("right", size="5%", pad=0.06)
            # Use the axis' own figure: ``fig`` does not exist when the
            # caller passes ``ax``.
            cbar = axis.figure.colorbar(colormesh, ax=axis, cax=cbar_ax)
            cbar.set_label(
                dataset.get_plot_label(read_meth, 2),
                rotation=-270,
            )
            colorbars.append(cbar)

        else:
            raise NotImplementedError

        axis.set_xlabel(dataset.get_plot_label(read_meth, 0))
        axis.set_ylabel(dataset.get_plot_label(read_meth, 1))
        axis.set_title(fig_title)
        axis.figure.tight_layout()

    if save_figures:
        if file_location is None:
            file_location = os.path.join(nt.config["db_folder"],
                                         "tuning_results", dataset.device_name)
        os.makedirs(file_location, exist_ok=True)

        if filename is None:
            filename = "dataset_" + str(dataset.guid)
        else:
            filename = os.path.splitext(filename)[0]

        path = os.path.join(file_location, filename + ".png")
        # Save the figure that was drawn into, not pyplot's "current"
        # figure, which can differ when ``ax`` is supplied.
        ax[0, 0].figure.savefig(path, format="png", dpi=600,
                                bbox_inches="tight")
    return ax, colorbars