示例#1
0
def test_map(engine):
    """Exercise ``plotter.map`` on a builder and on a finalized system.

    Checks that:

    * mapping ``site.tag[0]`` over ``syst_2d()`` with ``good_transform``
      succeeds with explicit interpolation options;
    * the same call with ``bad_transform`` raises ``ValueError``;
    * a plain sequence of per-site values is accepted for the finalized
      system;
    * the same sequence passed together with the unfinalized builder
      raises ``ValueError``.

    Output is written to a temporary file whose suffix matches ``engine``.
    """
    plotter.set_engine(engine)
    builder = syst_2d()

    def first_tag(site):
        return site.tag[0]

    suffix = plotter_file_suffix(engine)
    with tempfile.NamedTemporaryFile('w+b', suffix=suffix) as tmp:
        target = tmp.name

        # Valid position transform plus explicit interpolation settings.
        plotter.map(builder, first_tag, pos_transform=good_transform,
                    file=target, method='linear', a=4, oversampling=4,
                    cmap='flag', show=False)

        # An invalid position transform must be rejected.
        with pytest.raises(ValueError):
            plotter.map(builder, first_tag, pos_transform=bad_transform,
                        file=target)

        # Per-site values given as a sequence work for a finalized system.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            finalized = builder.finalized()
            values = range(len(builder.sites()))
            plotter.map(finalized, values, file=target, show=False)

        # The same sequence is refused when paired with a bare builder.
        with pytest.raises(ValueError):
            plotter.map(builder, values, file=target)
示例#2
0
def test_plotly_plot():
    """Smoke-test ``plotter.plot`` with the plotly engine.

    Plots the 2D and 3D test systems with several ``site_color``
    specifications. It then plots the 2D system after clearing its leads
    and after deleting its hoppings, the 3D system, the finalized 2D
    system, and a 2D projection of the 3D system. Every plot goes to a
    temporary file, and ``show=False`` keeps the test non-interactive.
    """
    plotter.set_engine('plotly')
    plot = plotter.plot
    syst2d = syst_2d()
    syst3d = syst_3d()
    # Site colors: a named color, a scalar per site (used with
    # ``cmap='binary'``) and an RGB tuple per site.
    color_opts = ['black', (lambda site: site.tag[0]),
                  lambda site: (abs(site.tag[0] / 100),
                                abs(site.tag[1] / 100), 0)]
    engine = plotter.get_engine()
    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        for color in color_opts:
            for syst in (syst2d, syst3d):
                plot(syst, site_color=color, cmap='binary', file=out_filename, show=False)

        # TODO(review): per-hopping ``hop_color`` callables are not exercised
        # for the plotly engine. Confirm plotly supports them before adding a
        # loop like the one in test_matplotlib_plot.

        syst2d.leads = []
        plot(syst2d, file=out_filename, show=False)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out_filename, show=False)

        plot(syst3d, file=out_filename, show=False)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out_filename, show=False)

        # test 2D projections of 3D systems
        plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2], show=False)
示例#3
0
def test_matplotlib_plot():
    """Exercise ``plotter.plot`` with the matplotlib engine.

    Covers per-site colors (``site_color``) and per-hopping colors
    (``hop_color``) on both the 2D and the 3D test systems. It checks that
    a 3D system is drawn on an ``Axes3D``, then plots the 2D system without
    leads and without hoppings, the finalized 2D system, and a 2D
    projection of the 3D system.
    """

    plotter.set_engine('matplotlib')
    plot = plotter.plot
    syst2d = syst_2d()
    syst3d = syst_3d()
    # Site colors: a matplotlib color name, a scalar per site (used with
    # ``cmap='binary'``) and an RGB tuple per site.
    color_opts = [
        'k', (lambda site: site.tag[0]), lambda site:
        (abs(site.tag[0] / 100), abs(site.tag[1] / 100), 0)
    ]
    engine = plotter.get_engine()
    with tempfile.NamedTemporaryFile(
            'w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        for color in color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst,
                           site_color=color,
                           cmap='binary',
                           file=out_filename)
                # NOTE(review): site tags look like integer tuples, so none of
                # these callables may ever return a ``float``. In that case the
                # ``get_array`` assertion never runs; confirm, and consider
                # checking against ``numbers.Real`` instead.
                if (color != 'k' and isinstance(
                        color(next(iter(syst.sites()))), float)):
                    assert fig.axes[0].collections[0].get_array() is not None
                assert len(fig.axes[0].collections) == 6
        # Hopping colors take (site1, site2); same three flavours as above.
        color_opts = [
            'k', (lambda site, site2: site.tag[0]), lambda site, site2:
            (abs(site.tag[0] / 100), abs(site.tag[1] / 100), 0)
        ]
        for color in color_opts:
            for syst in (syst2d, syst3d):
                # Plot the loop's own system so the 3D system is covered too.
                fig = plot(syst,
                           hop_color=color,
                           cmap='binary',
                           file=out_filename,
                           fig_size=(2, 10),
                           dpi=30)
                if color != 'k' and isinstance(
                        color(next(iter(syst.sites())), None), float):
                    assert fig.axes[0].collections[1].get_array() is not None

        assert isinstance(
            plot(syst3d, file=out_filename).axes[0], mplot3d.axes3d.Axes3D)

        syst2d.leads = []
        plot(syst2d, file=out_filename)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out_filename)

        plot(syst3d, file=out_filename)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out_filename)

        # test 2D projections of 3D systems
        plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2])
示例#4
0
def test_spectrum(engine):
    """Exercise ``plotter.spectrum`` over one and two swept parameters.

    Three kinds of "Hamiltonian" are used: a scalar-valued callable
    (``ham_1d``), a 2x2-matrix-valued callable (``ham_2d``) and a finalized
    kwant system (``fsyst``). Each is swept over ``a`` alone and over
    ``a`` and ``b`` together. For the matplotlib engine the test also
    passes an explicit ``fig_size`` and explicit axes. A 3D axes is
    accepted for a two-parameter sweep; a 2D axes must raise ``TypeError``.
    Finally, a ``mask`` callable is applied to a two-parameter sweep of
    ``fsyst``.
    """

    plotter.set_engine(engine)

    def ham_1d(a, b, c):
        return a**2 + b**2 + c**2

    def ham_2d(a, b, c):
        return np.eye(2) * (a**2 + b**2 + c**2)

    lat = kwant.lattice.chain(norbs=1)
    syst = kwant.Builder()
    syst[(lat(i) for i in range(3))] = lambda site, a, b: a + b
    syst[lat.neighbors()] = lambda site1, site2, c: c
    fsyst = syst.finalized()

    vals = np.linspace(0, 1, 3)

    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name

        for ham in (ham_1d, ham_2d, fsyst):
            plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out_filename, show=False)
            if engine == 'matplotlib':
                # test with explicit figsize
                plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
                                 fig_size=(10, 10), file=out_filename, show=False)

        for ham in (ham_1d, ham_2d, fsyst):
            plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
                             params=dict(c=1), file=out_filename, show=False)
            if engine == 'matplotlib':
                # test with explicit figsize
                plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
                                 params=dict(c=1), fig_size=(10, 10), file=out_filename, show=False)

        if engine == 'matplotlib':
            # ``pyplot.figure`` registers the figure with pyplot's global
            # state, so close it afterwards to avoid leaking it across tests.
            fig = pyplot.figure()
            try:
                # test 2D plot and explicitly passing axis
                ax = fig.add_subplot(1, 1, 1, projection='3d')
                plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                                 params=dict(c=1), ax=ax, file=out_filename, show=False)
                # explicitly pass axis without 3D support
                ax = fig.add_subplot(1, 1, 1)
                with pytest.raises(TypeError):
                    plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                                     params=dict(c=1), ax=ax, file=out_filename, show=False)
            finally:
                pyplot.close(fig)

    def mask(a, b):
        return a > 0.5

    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        # Name the finalized system explicitly rather than relying on the
        # loop variable ``ham`` leaking out of the sweep loops above.
        plotter.spectrum(fsyst, ('a', vals), ('b', 2 * vals), params=dict(c=1),
                         mask=mask, file=out_filename, show=False)
示例#5
0
def test_plot_raises_on_bad_site_spec(engine):
    """Per-site arrays for site properties are rejected for a Builder.

    Builds a 5x5 square-lattice ``Builder`` (25 sites) and checks that
    passing ``site_size`` or ``site_symbol`` as a 25-element list raises
    ``TypeError``.
    """
    syst = kwant.Builder()
    lat = kwant.lattice.square(norbs=1)
    syst[(lat(i, j) for i in range(5) for j in range(5))] = None

    # Cannot provide site_size as an array when syst is a Builder.
    # ``show=False`` keeps a regression (no exception) from opening a window.
    plotter.set_engine(engine)
    with pytest.raises(TypeError):
        plotter.plot(syst, site_size=[1] * 25, show=False)

    # Cannot provide site_symbol as an array when syst is a Builder.
    with pytest.raises(TypeError):
        plotter.plot(syst, site_symbol=['o'] * 25, show=False)
示例#6
0
def test_current():
    """Smoke-test ``plotter.current`` with the matplotlib engine.

    The current operator of the finalized ``syst_2d()`` system is applied
    to wave function ``[0]`` from lead ``1`` at energy 1. The result is
    plotted twice into an anonymous temporary file: once without an ``ax``
    argument, and once onto axes of a freshly created ``Figure``.
    """
    plotter.set_engine('matplotlib')
    fsyst = syst_2d().finalized()
    current_op = kwant.operator.Current(fsyst)
    scattering = kwant.wave_function(fsyst, energy=1)
    psi = scattering(1)[0]
    flow = current_op(psi)

    with tempfile.NamedTemporaryFile('w+b') as sink:
        # No explicit axes supplied.
        plotter.current(fsyst, flow, file=sink)

        # Explicit axes supplied.
        axes = pyplot.Figure().add_subplot(1, 1, 1)
        plotter.current(fsyst, flow, ax=axes, file=sink)
示例#7
0
def test_plot_more_site_families_than_colors(engine):
    """Plot a system with more site families than cycle colors.

    Regression test for
    https://gitlab.kwant-project.org/kwant/kwant/issues/257. Let ``ncolors``
    be the length of matplotlib's ``axes.prop_cycle``. The test builds
    ``ncolors + 1`` square lattices, each contributing a single site, and
    plots the resulting builder to a temporary file.
    """
    ncolors = len(pyplot.rcParams['axes.prop_cycle'])
    syst = kwant.Builder()
    for idx in range(ncolors + 1):
        family = kwant.lattice.square(name=idx, norbs=1)
        syst[family(idx, 0)] = None

    plotter.set_engine(engine)
    suffix = plotter_file_suffix(engine)
    with tempfile.NamedTemporaryFile('w+b', suffix=suffix) as tmp:
        plotter.plot(syst, file=tmp.name, show=False)
示例#8
0
def test_bands(engine):
    """Smoke-test ``plotter.bands`` on lead 0 of the finalized ``syst_2d()``.

    The bands are plotted with the default momenta and with an explicit
    ``np.linspace(0, 2 * np.pi)`` grid. For the matplotlib engine the test
    also plots with an explicit ``fig_size`` and onto axes of a freshly
    created ``Figure``. All output goes to a temporary file.
    """
    plotter.set_engine(engine)

    lead = syst_2d().finalized().leads[0]

    suffix = plotter_file_suffix(engine)
    with tempfile.NamedTemporaryFile('w+b', suffix=suffix) as tmp:
        common = dict(show=False, file=tmp.name)
        plotter.bands(lead, **common)
        plotter.bands(lead, momenta=np.linspace(0, 2 * np.pi), **common)

        if engine != 'matplotlib':
            return

        # fig_size and explicit axes are only exercised for matplotlib.
        plotter.bands(lead, fig_size=(10, 10), **common)
        axes = pyplot.Figure().add_subplot(1, 1, 1)
        plotter.bands(lead, ax=axes, **common)