コード例 #1
0
def calc_bands(glat, kxy=None, density=100, verbose=True):
    """Compute the band eigenvalues at an unstructured array of kxy points.

    Parameters
    ----------
    glat : GyroLattice class instance
        the gyro lattice on which to compute the bands
    kxy : N x 2 float array-like or None
        the wavevectors (each row is a wavevector) at which to compute the spectrum. If None, will generate a random
        collection in the BZ
    density : int
        the number of evaluation points per unit area of the BZ (only used when kxy is None)
    verbose : bool
        print an estimate of the remaining time to the command line

    Returns
    -------
    omegas : len(kxy) x #bands float array
        imaginary parts of the eigenvalues of the k-space matrix; row ii holds the values at kxy[ii], in the
        order returned by le.eig_vals_vects
    kx, ky : len(kxy) float arrays
        the x and y components of the wavevectors at which omegas were evaluated
    """
    matk = lambda_matrix_kspace(glat, eps=1e-10)
    bzvtcs = glat.lattice.get_bz(attribute=True)
    bzarea = dh.polygon_area(bzvtcs)

    # Make the kx, ky points at which to evaluate the spectrum
    if kxy is None:
        # density * bzarea is generally a float; the generator needs an integer number of points
        kxy = dh.generate_random_xy_in_polygon(int(density * bzarea),
                                               bzvtcs,
                                               sorted=True)
    # Accept any N x 2 array-like (e.g. a list of pairs): the 2D indexing below needs an ndarray
    kxy = np.asarray(kxy)

    npts = len(kxy)
    start_time = time.time()
    bands = []
    for ii in range(npts):
        # Periodically display an estimate of the remaining time
        if verbose and ii % 4000 == 1999:
            elapsed = time.time() - start_time
            remaining = elapsed / (ii + 1) * npts - elapsed
            print('Estimated time remaining: %0.2f s, ii = %d/%d' %
                  (remaining, ii + 1, npts))

        matii = matk([kxy[ii, 0], kxy[ii, 1]])
        eigval, _ = le.eig_vals_vects(matii)
        bands.append(np.imag(eigval))

    omegas = np.array(bands)
    return omegas, kxy[:, 0], kxy[:, 1]
コード例 #2
0
def c_3D_plot(x, y, z, vertex_points, ax, v_min, v_max):
    """Draw z(x, y) as a triangulated 3D surface clipped to a polygon.

    A Delaunay triangulation of the (x, y) points is built, every triangle whose centroid lies outside the
    polygon ``vertex_points`` is masked, and the remaining surface is drawn on ``ax`` with the 'cool' colormap.

    Parameters
    ----------
    x, y : sequence of float
        coordinates of the sample points
    z : sequence of float
        surface height at each (x, y) point
    vertex_points : N x 2 array-like
        vertices of the polygon to which the surface is clipped
    ax : axis supporting ``plot_trisurf`` (presumably an mplot3d Axes3D)
        axis to draw on
    v_min, v_max : float
        colormap limits

    Returns
    -------
    None
    """
    x = np.array(x)
    y = np.array(y)
    z = np.array(z)

    # No triangles are given, so a Delaunay triangulation of the points is created
    triang = mtri.Triangulation(x, y)

    # Mask off triangles whose centroid falls outside the polygon (one vectorized containment test)
    centroids = np.column_stack((x[triang.triangles].mean(axis=1),
                                 y[triang.triangles].mean(axis=1)))
    poly_path = Path(vertex_points)
    triang.set_mask(~poly_path.contains_points(centroids))

    ax.plot_trisurf(triang,
                    z,
                    cmap='cool',
                    vmin=v_min,
                    vmax=v_max,
                    linewidth=0.0,
                    alpha=1)
コード例 #3
0
def calc_chern(kchern, verbose=True):
    """Compute the chern number of each band using the wedge product of the projector.

    Band eigenvalues and traces are evaluated at random wavevectors inside the Brillouin zone (BZ) of
    ``kchern.gyro_lattice``. If a results file with a smaller density exists in ``kchern.cp['cpmeshfn']``, its
    points are loaded and only the missing points are computed. A file at exactly the requested density is not
    reused: the calculation then restarts from scratch. This function does not itself write results to disk.

    The chern number of band ii is (i / 2pi) * mean(traces[:, ii]) * area(BZ).

    Parameters
    ----------
    kchern : KChernGyro class instance
        kspace Chern calculation class; reads ``kchern.cp`` (keys 'cpmeshfn', 'density') and
        ``kchern.gyro_lattice``, and sets ``kchern.chern``
    verbose : bool
        print progress information to the command line

    Returns
    -------
    chern : dict
        keys 'kx', 'ky' (wavevector components), 'chern' (list of complex chern numbers, one per band),
        'bands', 'traces' (arrays with one row per wavevector) and 'bzvtcs' (BZ vertices); also stored as
        ``kchern.chern``
    """
    cp = kchern.cp
    # This is really the only line to change to treat a different kind of lattice.
    matk = glatkfns.lambda_matrix_kspace(kchern.gyro_lattice, eps=1e-10)
    bzvtcs = kchern.gyro_lattice.lattice.get_bz(attribute=True)
    bzarea = dh.polygon_area(bzvtcs)
    target_npts = int(cp['density'] * bzarea)

    # Directory for saved results (created if missing)
    savedir = cp['cpmeshfn']
    dio.ensure_dir(savedir)
    savefn = savedir + 'kspacechern_density{0:07d}.pkl'.format(cp['density'])

    # Find the saved results file with the largest density below the requested one, if any
    densities = [
        int(globfn.split('density')[-1].split('.pkl')[0])
        for globfn in glob.glob(savedir + 'kspacechern_density*.pkl')
    ]
    smaller = [dens for dens in densities if dens < cp['density']]
    if smaller:
        smallerfn = savedir + 'kspacechern_density{0:07d}.pkl'.format(
            max(smaller))
    else:
        smallerfn = False

    if verbose:
        print('kchern_gyro_fns: looking for saved file ' + savefn)

    # Accumulated results, possibly seeded from a lower-density file
    kkx, kky, bands, traces = [], [], [], []
    if os.path.isfile(savefn):
        if verbose:
            print('kchern_gyro_fns: chern file exists, overwriting: ' + savefn)
    elif smallerfn:
        if verbose:
            print('kchern_gyro_fns: chern file with smaller density exists, loading to append...')
        with open(smallerfn, 'rb') as of:
            data = pkl.load(of)
        # Convert contents to lists so that we can append to them
        bands = list(data['bands'])
        kkx = list(data['kx'])
        kky = list(data['ky'])
        traces = list(data['traces'])

    # Number of new points needed to reach the requested density
    npts = target_npts - len(kkx)
    if npts > 0:
        kxy = dh.generate_random_xy_in_polygon(npts, bzvtcs, sorted=True)
    else:
        kxy = np.zeros((0, 2))

    nkxy = len(kxy)
    start_time = time.time()
    for ii in range(nkxy):
        # Periodically display an estimate of the remaining time
        if verbose and ii % 4000 == 1999:
            elapsed = time.time() - start_time
            remaining = elapsed / (ii + 1) * nkxy - elapsed
            print('Estimated time remaining: %0.2f s, ii = %d/%d' %
                  (remaining, ii + 1, nkxy))

        # NOTE(review): expects a per-point calc_bands(matk, kx, ky) -> (eigval, trace); confirm its module
        eigval, tr = calc_bands(matk, kxy[ii, 0], kxy[ii, 1])
        kkx.append(kxy[ii, 0])
        kky.append(kxy[ii, 1])
        traces.append(tr)
        bands.append(eigval)

    bands = np.array(bands)
    traces = np.array(traces)
    kkx = np.array(kkx)
    kky = np.array(kky)

    if verbose:
        print('kchern_gyro_fns: np.shape(bands) = ' + str(np.shape(bands)))
        print('kchern_gyro_fns: np.shape(traces) = ' + str(np.shape(traces)))

    # Chern number per band; an empty sample yields no bands instead of an IndexError
    nbands = len(bands[0]) if len(bands) else 0
    bv = [(1j / (2. * np.pi)) * np.mean(traces[:, ii]) * bzarea
          for ii in range(nbands)]

    kchern.chern = {
        'kx': kkx,
        'ky': kky,
        'chern': bv,
        'bands': bands,
        'traces': traces,
        'bzvtcs': bzvtcs
    }
    return kchern.chern
コード例 #4
0
def _smaller_density_chern_fn(cp, savedir, search_density_fn, appendstr):
    """Return the saved chern filename with the largest density below cp['density'], or False if none exists."""
    densities = [
        int(globfn.split('density')[-1].split('.pkl')[0].split('_')[0])
        for globfn in glob.glob(search_density_fn)
    ]
    smaller = [dens for dens in densities if dens < cp['density']]
    if not smaller:
        return False
    smallerfn = savedir + 'kspacechern_density{0:07d}'.format(max(smaller))
    if cp['ortho']:
        smallerfn += '_ortho'
    if 'deriv_res' in cp and cp['deriv_res'] != 1e-5:
        # e.g. 2.000e-04 -> 2p000en04
        smallerfn += '_dres{0:0.3e}'.format(cp['deriv_res']).replace(
            '-', 'n').replace('.', 'p')
    # appendstr defaults to None; concatenating None used to raise a TypeError
    if appendstr:
        smallerfn += appendstr
    return smallerfn + '.pkl'


def _bz_too_large_chern(lattice, bzvtcs):
    """Return the placeholder result for a BZ too large to sample: zero chern numbers, NaN bands/traces."""
    nmodes = 2 * len(lattice.xy[:, 0])
    return {
        'kx': np.zeros(10),
        'ky': np.zeros(10),
        'chern': np.zeros(nmodes),
        'bands': np.nan * np.ones(nmodes),
        'traces': np.nan * np.ones((nmodes, 10)),
        'bzvtcs': bzvtcs
    }


def calc_chern(matk,
               cp,
               lattice,
               appendstr=None,
               verbose=True,
               bz_cutoff=1e14,
               kxy=None,
               overwrite=False,
               signed_norm=False):
    """Compute the chern number of each band using the wedge product of the projector.

    Parameters
    ----------
    matk : callable
        passed to calc_bands as its first argument (presumably the k-space matrix function -- TODO confirm)
    cp : dict
        chern parameters; reads 'cpmeshfn', 'density', 'ortho', 'deriv_res' and, after
        prepare_chern_filename, 'savefn'
    lattice : lattice instance
        provides get_bz(attribute=True) and xy
    appendstr : str or None
        suffix forwarded to prepare_chern_filename / load_chern and appended to lower-density filenames
    verbose : bool
        print progress information
    bz_cutoff : float
        if no k points can be generated and some BZ vertex coordinate exceeds bz_cutoff in magnitude, return
        zero chern numbers and NaN bands instead of raising
    kxy : N x 2 float array-like or None
        if given, compute exactly at these wavevectors (no saved-file lookup); otherwise sample the BZ randomly,
        reusing a saved file at the requested density or appending to one at a smaller density
    overwrite : bool
        if True, ignore saved files and compute from scratch
    signed_norm : bool
        forwarded to calc_bands

    Returns
    -------
    chern : dict
        keys 'kx', 'ky', 'chern' (one value per band), 'bands', 'traces', 'bzvtcs'; when a saved file at the
        requested density exists (and overwrite is False), whatever load_chern returns

    Raises
    ------
    RuntimeError
        if no k points could be generated although the BZ is not larger than bz_cutoff
    """
    bzvtcs = lattice.get_bz(attribute=True)
    bzarea = dh.polygon_area(bzvtcs)

    savedir = cp['cpmeshfn']
    dio.ensure_dir(savedir)
    savefn, cp, search_density_fn = prepare_chern_filename(cp,
                                                           appendstr=appendstr)
    print('established savefn: ' + str(cp['savefn']))

    # Accumulated results, possibly seeded from a lower-density file
    kkx, kky, bands, traces = [], [], [], []
    if kxy is None:
        if verbose:
            print('kchern_gyro_fns: looking for saved file ' + savefn)
        if os.path.isfile(savefn) and not overwrite:
            # File of same density exists: reuse it
            chern = load_chern(cp, appendstr=appendstr, verbose=verbose)
            print('generalized_kchern_fns: loaded chern, returning chern')
            return chern

        if overwrite:
            smallerfn = False
        else:
            smallerfn = _smaller_density_chern_fn(cp, savedir,
                                                  search_density_fn, appendstr)

        if os.path.isfile(savefn):
            if verbose:
                print('kchern_gyro_fns: chern file exists, overwriting: ' + savefn)
        elif smallerfn:
            if verbose:
                print('kchern_gyro_fns: chern file with smaller density exists, loading to append...')
            with open(smallerfn, 'rb') as of:
                data = pkl.load(of)
            # Convert contents to lists so that we can append to them
            bands = list(data['bands'])
            kkx = list(data['kx'])
            kky = list(data['ky'])
            traces = list(data['traces'])
        elif verbose:
            print('generalized_kchern_fns: no chern file with smaller density exists, computing from scratch...')

        # Number of new points needed to reach the requested density
        npts = int(cp['density'] * bzarea) - len(kkx)
        if npts <= 0:
            # Loaded data already reaches the requested density
            kxy = np.zeros((0, 2))
        else:
            # Handle cases where the BZ is too elongated to populate in reasonable time
            try:
                kxy = dh.generate_random_xy_in_polygon(npts, bzvtcs, sorted=True)
            except ValueError as err:
                print('generalized_kchern_fns: could not generate kvec points: ' + str(err))
                kxy = []
            if len(kxy) == 0:
                if (np.abs(bzvtcs) > bz_cutoff).any():
                    print('The BZ is too large to fill with kvec points')
                    return _bz_too_large_chern(lattice, bzvtcs)
                print('generalized_kchern_fns: bzvtcs = ' + str(bzvtcs))
                print('bz_cutoff = ' + str(bz_cutoff))
                raise RuntimeError(
                    'Could not generate random xy in polygon, but BZ is not larger than cutoff in any dim.'
                )
    else:
        # kxy is supplied -- compute at those locations
        print('kxy is supplied, computing berry at supplied points')
        kxy = np.asarray(kxy)

    nkxy = len(kxy)
    start_time = time.time()
    for ii in range(nkxy):
        # Periodically display an estimate of the remaining time
        if verbose and ii % 4000 == 1999:
            elapsed = time.time() - start_time
            remaining = elapsed / (ii + 1) * nkxy - elapsed
            print('Estimated time remaining: %0.2f s, ii = %d/%d' %
                  (remaining, ii + 1, nkxy))

        # NOTE(review): expects a per-point calc_bands(matk, kx, ky, h=, ortho=, signed_norm=) -> (eigval, trace)
        eigval, tr = calc_bands(matk,
                                kxy[ii, 0],
                                kxy[ii, 1],
                                h=cp['deriv_res'],
                                ortho=cp['ortho'],
                                signed_norm=signed_norm)
        kkx.append(kxy[ii, 0])
        kky.append(kxy[ii, 1])
        traces.append(tr)
        bands.append(eigval)

    bands = np.array(bands)
    traces = np.array(traces)
    kkx = np.array(kkx)
    kky = np.array(kky)

    if verbose:
        print('kchern_gyro_fns: np.shape(bands) = ' + str(np.shape(bands)))
        print('kchern_gyro_fns: np.shape(traces) = ' + str(np.shape(traces)))

    # Chern number per band; an empty sample yields no bands instead of an IndexError
    nbands = len(bands[0]) if len(bands) else 0
    bv = [(1j / (2. * np.pi)) * np.mean(traces[:, ii]) * bzarea
          for ii in range(nbands)]

    return {
        'kx': kkx,
        'ky': kky,
        'chern': bv,
        'bands': bands,
        'traces': traces,
        'bzvtcs': bzvtcs
    }
コード例 #5
0
def save_plot(kx, ky, bands, traces, tvals, ons, vertex_points, od):
    """Build a figure of the upper half of the band structure, annotated with parameters and chern numbers.

    Bands j >= nb // 2 (nb = number of bands) are drawn as triangulated surfaces on two 3D axes viewed from
    azimuths 0 and 90. A text panel lists the values labelled $k_i$ (tvals) and $M_i$ (ons), and, per plotted
    band, its chern number (red if |C| > 0.7), its min/max, and the minimum difference to the band below.

    Parameters
    ----------
    kx, ky : 1D arrays
        wavevector components of the sample points
    bands : len(kx) x nb array
        band values at each sample point, one column per band
    traces : len(kx) x nb array
        per-point traces; chern number of band j is (i / 2pi) * mean(traces[:, j]) * area(BZ)
    tvals : sequence of float
        values labelled $k_1$, $k_2$, ... (at most the first four are shown)
    ons : sequence of float
        values labelled $M_1$, $M_2$, ... (at most the first four are shown)
    vertex_points : N x 2 array-like
        Brillouin zone vertices, used for the BZ area
    od :
        unused by this function

    Returns
    -------
    ax_lat : axis with axis lines turned off (presumably for the caller to draw the lattice into)
    cols : list of str
        colors used for the $M_i$ labels
    l_cols : list of str
        colors used for the $k_i$ labels
    fig : the created figure
    """
    bz_area = dh.polygon_area(vertex_points)

    # Figure layout, in mm
    w, h = 180, 140
    label_params = dict(size=12, fontweight='normal')
    s = 2
    h1 = (h - 3 * s) / 2.
    w1 = 0.75 * w

    fig = sps.figure_in_mm(w, h)
    ax1 = sps.axes_in_mm(0, h - h1, w1, h1,
                         label=None, label_params=label_params, projection='3d')
    ax2 = sps.axes_in_mm(0, h - 2 * h1, w1, h1,
                         label=None, label_params=label_params, projection='3d')
    ax3 = sps.axes_in_mm(1 * s + w1, h - 2 * h1, w - w1 - 3 * s, 2 * h1,
                         label=None, label_params=label_params)
    ax_lat = sps.axes_in_mm(1 * s + w1, h - 0.4 * h1, w - w1 - 3 * s, 0.4 * h1,
                            label=None, label_params=label_params)

    cols = ['#000000', '#B9B9B9', '#E51B1B', '#18CFCF', 'violet', 'k']
    l_cols = ['#96adb4', '#b16566', '#665656', '#9d96b4', 'violet']
    ax3.axis([0, 1, 0, 1.])

    # Four label slots (two columns x two rows); parameters beyond the fourth are not labelled
    xp = [0.025, 0.525, 0.025, .525]
    yp = [0.65, 0.65, 0.6, 0.6]
    for i in range(min(len(xp), len(tvals))):
        ax3.text(xp[i], yp[i] + 0.1,
                 '$k_%1d$' % (i + 1) + '= $%.2f$' % tvals[i],
                 style='italic', color=l_cols[i], fontsize=12)
    for i in range(min(len(xp), len(ons))):
        ax3.text(xp[i], yp[i],
                 '$M_%1d$' % (i + 1) + '$= %.2f$' % ons[i],
                 style='italic', color=cols[i], fontsize=12)
    ax3.axis('off')
    ax_lat.axis('off')

    ax3.text(0.025, .55, r'$\mathrm{Chern\/ numbers}$',
             fontweight='bold', color='k', fontsize=12)
    ax3.text(0.025, .35, r'$\mathrm{Band\/ boundaries}$',
             fontweight='bold', color='k', fontsize=12)
    ax3.text(0.025, .15, r'$\mathrm{Min.\/differences}$',
             fontweight='bold', color='k', fontsize=12)

    nb = len(bands[0])
    # Floor division keeps the Python 2 meaning of nb / 2 under Python 3
    half = nb // 2
    abs_bands = np.abs(bands)
    v_min, v_max = np.min(abs_bands), np.max(abs_bands)
    for j in range(half, nb):
        for ax in (ax1, ax2):
            ax.plot_trisurf(kx, ky, bands[:, j], cmap='cool',
                            vmin=v_min, vmax=v_max, linewidth=0.0, alpha=1)

        row_offset = (j - half + 1) * 0.04
        bv = (1j / (2 * np.pi)) * np.mean(traces[:, j]) * bz_area
        # bv is complex-typed; its real part is what gets displayed
        ax3.text(0.025, .55 - row_offset, '$%0.2f$' % np.real(bv),
                 color='r' if abs(bv) > 0.7 else 'k', fontsize=10)
        ax3.text(0.025, .35 - row_offset, '$%0.2f$' % np.min(bands[:, j]),
                 color='k', fontsize=10)
        ax3.text(0.525, .35 - row_offset, '$%0.2f$' % np.max(bands[:, j]),
                 color='k', fontsize=10)

        if j > half:
            min_diff = np.min(bands[:, j] - bands[:, j - 1])
            ax3.text(0.025, .15 - (j - half) * 0.04, '$%0.2f$' % min_diff,
                     color='k', fontsize=10)

    ax1.set_ylabel('ky')
    ax1.axes.get_xaxis().set_ticks([])
    ax2.set_xlabel('kx')
    ax2.axes.get_yaxis().set_ticks([])
    ax1.set_zlabel(r'$\Omega$')
    ax2.set_zlabel(r'$\Omega$')

    ax1.view_init(elev=0, azim=0.)
    ax2.view_init(elev=0, azim=90.)

    return ax_lat, cols, l_cols, fig
コード例 #6
0
def make_plot(kkx, kky, bands, traces, tvals, ons, pin, num='all'):
    """One of Lisa's functions for plotting chern results with lattice structure.

    Draws the lattice via lf.alpha_lattice on a small axis, annotates the values labelled $k_1..k_4$ (tvals)
    and $M_1..M_4$ (ons), plots bands j >= nb // 2 (nb = number of bands), and calls plt.show().

    Parameters
    ----------
    kkx, kky : 1D arrays
        wavevector components of the sample points
    bands : npts x nb array
        band values, one column per band
    traces : npts x nb array
        per-point traces; chern number of band j is (i / 2pi) * mean(traces[:, j]) * area(BZ)
    tvals : sequence of at least 4 floats
        labelled $k_1..k_4$ and forwarded to lf.alpha_lattice
    ons : sequence of at least 4 floats
        labelled $M_1..M_4$ and forwarded to lf.alpha_lattice
    pin :
        ignored: overridden with 1 before use (see note below)
    num : 'all' or int
        'all': plot the upper-half bands (first 1000 points) on the top 3D axis.
        int j: plot Re(1j * traces[:, j]) (first 1000 points) on the bottom 3D axis and show band j's chern number.

    Returns
    -------
    None
    """
    # The plotting code previously referenced undefined names kx, ky; they are these arguments
    kx, ky = kkx, kky

    # Figure layout, in mm
    w, h = 180, 140
    label_params = dict(size=12, fontweight='normal')
    s = 2
    h1 = (h - 3 * s) / 2.
    w1 = 0.75 * w

    fig = sps.figure_in_mm(w, h)
    ax1 = sps.axes_in_mm(0, h - h1, w1, h1,
                         label=None, label_params=label_params, projection='3d')
    ax2 = sps.axes_in_mm(0, h - 2 * h1, w1, h1,
                         label=None, label_params=label_params, projection='3d')
    ax3 = sps.axes_in_mm(1 * s + w1, h - 2 * h1, w - w1 - 3 * s, 2 * h1,
                         label=None, label_params=label_params)
    ax_lat = sps.axes_in_mm(1 * s + w1, h - 0.4 * h1, w - w1 - 3 * s, 0.4 * h1,
                            label=None, label_params=label_params)

    c1, c2, c3, c4 = '#000000', '#FFFFFF', '#E51B1B', '#18CFCF'
    cc1, cc2, cc3, cc4 = '#96adb4', '#b16566', '#665656', '#9d96b4'

    # NOTE(review): the `pin` argument is overridden, so callers' values are ignored; kept to preserve behavior
    pin = 1
    cols = [c1, c2, c3, c4]
    l_cols = [cc1, cc2, cc3, cc4]
    _, vertex_points, _ = lf.alpha_lattice(tvals,
                                           pin,
                                           ons,
                                           ax=ax_lat,
                                           col=cols,
                                           lincol=l_cols)
    ax3.axis([0, 1, 0, 1.])

    # Parameter labels: (x, y) slot and color for entries 1-4
    k_slots = ((0.025, .75), (0.525, .75), (0.025, .7), (0.525, .7))
    for idx, ((xx, yy), color) in enumerate(zip(k_slots, (cc1, cc2, cc3, cc4))):
        ax3.text(xx, yy, '$k_%d = %.2f$' % (idx + 1, tvals[idx]),
                 style='italic', color=color, fontsize=12)
    # M_2 is labelled in grey rather than c2, presumably because c2 is white
    m_slots = ((0.025, .65), (0.525, .65), (0.025, .6), (0.525, .6))
    for idx, ((xx, yy), color) in enumerate(zip(m_slots, (c1, '#B9B9B9', c3, c4))):
        ax3.text(xx, yy, '$M_%d = %.2f$' % (idx + 1, ons[idx]),
                 style='italic', color=color, fontsize=12)
    ax3.axis('off')
    ax_lat.axis('off')

    ax3.text(0.025, .55, r'$\mathrm{Chern\/ numbers}$',
             fontweight='bold', color='k', fontsize=12)
    ax3.text(0.025, .35, r'$\mathrm{Band\/ boundaries}$',
             fontweight='bold', color='k', fontsize=12)
    ax3.text(0.025, .15, r'$\mathrm{Min.\/differences}$',
             fontweight='bold', color='k', fontsize=12)

    nb = len(bands[0])
    # Floor division keeps the Python 2 meaning of nb / 2 under Python 3
    half = nb // 2
    bz_area = dh.polygon_area(vertex_points)
    abs_bands = np.abs(bands)
    bands_min, bands_max = np.min(abs_bands), np.max(abs_bands)
    for j in range(half, nb):
        if num == 'all':
            ax1.plot_trisurf(kx[:1000], ky[:1000], bands[:1000, j],
                             cmap='cool', vmin=bands_min, vmax=bands_max,
                             linewidth=0.0)
        elif j == num:
            abs_trace = np.abs(traces[:, j])
            # 1j * trace is complex-typed; plot its real part explicitly
            ax2.plot_trisurf(kx[:1000], ky[:1000],
                             np.real(1j * traces[:1000, j]),
                             cmap='cool', vmin=np.min(abs_trace),
                             vmax=np.max(abs_trace), linewidth=0.0)

        row_offset = (j - half + 1) * 0.04
        if j == num:
            bv = (1j / (2 * np.pi)) * np.mean(traces[:, j]) * bz_area
            ax3.text(0.025, .55 - row_offset, '$%0.2f$' % np.real(bv),
                     color='r' if abs(bv) > 0.7 else 'k', fontsize=10)
        ax3.text(0.025, .35 - row_offset, '$%0.2f$' % np.min(bands[:, j]),
                 color='k', fontsize=10)
        ax3.text(0.525, .35 - row_offset, '$%0.2f$' % np.max(bands[:, j]),
                 color='k', fontsize=10)
        if j > half:
            min_diff = np.min(bands[:, j] - bands[:, j - 1])
            ax3.text(0.025, .15 - (j - half) * 0.04, '$%0.2f$' % min_diff,
                     color='k', fontsize=10)

    ax1.set_ylabel('ky')
    ax1.axes.get_xaxis().set_ticks([])
    ax2.set_xlabel('kx')
    ax2.axes.get_yaxis().set_ticks([])
    ax1.set_zlabel(r'$\Omega$')
    ax2.set_zlabel(r'$\Omega$')

    ax1.view_init(elev=0, azim=0.)
    ax2.view_init(elev=0, azim=90.)

    plt.show()