def initialize_1panelcbar_fig(Wfig=90,
                              Hfig=None,
                              x0frac=0.15,
                              y0frac=0.1,
                              wsfrac=0.4,
                              hs=None,
                              vspace=8,
                              hspace=8,
                              tspace=10,
                              fontsize=8):
    """Create a new figure holding one square-ish panel with a horizontal colorbar centered above it.

    Parameters
    ----------
    Wfig : float
        Width of the figure in mm.
    Hfig : float or None
        Height of the figure in mm. If None, it is sized to fit the bottom margin, the panel, vspace,
        the colorbar and tspace.
    x0frac : float
        Accepted for interface compatibility; it has no effect here because the panel is centered horizontally.
    y0frac : float
        Fraction of Wfig to leave blank below the panel.
    wsfrac : float
        Fraction of Wfig used as the panel width, ws = round(wsfrac * Wfig).
    hs : float or None
        Height of the panel in mm. If None, hs = ws.
    vspace : float
        Vertical gap in mm between the panel and the colorbar.
    hspace : float
        Accepted for interface compatibility; unused by this single-column layout.
    tspace : float
        Blank space in mm above the colorbar (only used when Hfig is None).
    fontsize : float
        Size of the axis labels.

    Returns
    -------
    fig : figure created by sps.figure_in_mm
    ax : list of two axes
        ax[0] is the main panel, ax[1] the colorbar axis (width 0.6 * ws, height 0.03 * ws).
    """
    bottom = round(Wfig * y0frac)
    ws = round(Wfig * wsfrac)
    if hs is None:
        hs = ws
    cbar_halfwidth = ws * 0.3
    cbar_height = cbar_halfwidth * 0.1
    if Hfig is None:
        Hfig = bottom + hs + vspace + cbar_height + tspace

    # Both axes are centered on the vertical midline of the figure
    panel_spec = (Wfig * 0.5 - ws * 0.5, bottom, ws, hs)  # Chern vary hlatparam vs ksize
    cbar_spec = (Wfig * 0.5 - cbar_halfwidth, bottom + hs + vspace,
                 cbar_halfwidth * 2, cbar_height)  # cbar for chern

    fig = sps.figure_in_mm(Wfig, Hfig)
    label_params = {'size': fontsize, 'fontweight': 'normal'}
    ax = []
    for left_mm, bottom_mm, width_mm, height_mm in (panel_spec, cbar_spec):
        ax.append(sps.axes_in_mm(left_mm, bottom_mm, width_mm, height_mm,
                                 label='', label_params=label_params))
    return fig, ax
def initialize_spectrum_with_dos_plot(Wfig=90,
                                      Hfig=None,
                                      x0frac=0.15,
                                      y0frac=0.1,
                                      wsfrac=0.4,
                                      hs=None,
                                      wsdosfrac=0.3,
                                      vspace=8,
                                      hspace=8,
                                      tspace=10,
                                      fontsize=8):
    """Initialize a figure with 4 axes, for plotting chern spectrum with accompanying DOS plot on the side

    Parameters
    ----------
    Wfig : float
        Width of the figure in mm.
    Hfig : float or None
        Height of the figure in mm. If None, it is sized to fit the bottom margin, the panels, vspace,
        the colorbars and tspace.
    x0frac : float
        Fraction of Wfig to leave blank left of the DOS panel.
    y0frac : float
        Fraction of Wfig to leave blank below the panels.
    wsfrac : float
        Fraction of Wfig used as the spectrum panel width, ws = round(wsfrac * Wfig).
    hs : float or None
        Height in mm of the spectrum and DOS panels. If None, hs = ws.
    wsdosfrac : float
        Width of the DOS panel as a fraction of ws.
    vspace : float
        Vertical gap in mm between the panels and the colorbars above them.
    hspace : float
        Horizontal gap in mm between the DOS panel and the spectrum panel.
    tspace : float
        Blank space in mm above the colorbars (only used when Hfig is None).
    fontsize : float
        Size of the axis labels.

    Returns
    -------
    fig : figure created by sps.figure_in_mm
    ax : list of 4 matplotlib axis instances
        ax[0] is DOS axis, ax[1] is chern spectrum axis, ax[2] is ipr cbar axis, ax[3] is cbar for chern axis
    """
    x0 = round(Wfig * x0frac)
    y0 = round(Wfig * y0frac)
    ws = round(Wfig * wsfrac)
    # Honor an explicit panel height; previously the hs argument was always overwritten with ws
    if hs is None:
        hs = ws
    wsDOS = wsdosfrac * ws
    hsDOS = hs
    wscbar = wsDOS
    hscbar = wscbar * 0.1
    if Hfig is None:
        Hfig = y0 + hs + vspace + hscbar + tspace

    # Left edge of the spectrum panel, just right of the DOS panel
    x_spec = x0 + wsDOS + hspace
    y_cbar = y0 + hs + vspace
    specs = (
        [x0, y0, wsDOS, hsDOS, ''],  # DOS
        [x_spec, y0, ws, hs, ''],  # Chern omegac vs ksize
        [x0 + wsDOS * 0.5 - wscbar * 0.5, y_cbar, wscbar, hscbar, ''],  # cbar for ipr, centered over DOS
        [x_spec + ws * 0.5 - wscbar, y_cbar, wscbar * 2, hscbar, ''],  # cbar for chern, centered over spectrum
    )

    fig = sps.figure_in_mm(Wfig, Hfig)
    label_params = dict(size=fontsize, fontweight='normal')
    ax = [sps.axes_in_mm(left, bottom, width, height, label=part, label_params=label_params)
          for left, bottom, width, height, part in specs]
    return fig, ax
# Example #3
# 0
def initialize_1p5panelcbar_fig(Wfig=90, Hfig=None, x0frac=0.15, y0frac=0.1, wsfrac=0.4, hs=None,
                                wssfrac=0.25, hssfrac=None, vspace=8, tspace=10,
                                fontsize=8, center0_frac=0.35, center2_frac=0.65):
    """Create a chern figure: a main panel with a colorbar above it, plus a smaller side panel
    (as in for showing kitaev regions).

    The main panel (ax[0]) is centered horizontally at center0_frac * Wfig, its colorbar (ax[1]) is centered
    above it, and the smaller panel (ax[2]) is centered at center2_frac * Wfig. Both panels share the same
    bottom edge y0.

    Parameters
    ----------
    Wfig : float
        Width of the figure in mm.
    Hfig : float or None
        Height of the figure in mm. If None, sized from y0, hs, vspace, the colorbar height and tspace.
    x0frac : float
        Accepted for interface compatibility; unused, since panels are placed by center0_frac / center2_frac.
    y0frac : float
        Fraction of Wfig to leave blank below the panels.
    wsfrac : float
        Fraction of Wfig used as the main panel width, ws = round(wsfrac * Wfig).
    hs : float or None
        Height of the main panel in mm. If None, hs = ws.
    wssfrac : float
        Fraction of Wfig used as the width of the side panel.
    hssfrac : float or None
        Fraction of Wfig used as the height of the side panel. If None, equals wssfrac.
    vspace : float
        Vertical gap in mm between the main panel and its colorbar.
    tspace : float
        Blank space in mm above the colorbar (only used when Hfig is None).
    fontsize : float
        Size of the axis labels.
    center0_frac : float
        Horizontal center of the main panel and its colorbar, as a fraction of Wfig.
    center2_frac : float
        Horizontal center of the side panel, as a fraction of Wfig.

    Returns
    -------
    fig : figure created by sps.figure_in_mm
    ax : list of three axes
        ax[0] main panel, ax[1] colorbar for the main panel, ax[2] side panel.
    """
    y0 = round(Wfig * y0frac)
    ws = round(Wfig * wsfrac)
    wss = Wfig * wssfrac
    if hssfrac is None:
        hssfrac = wssfrac
    hss = Wfig * hssfrac
    if hs is None:
        hs = ws
    wscbar = ws * 0.3
    hscbar = wscbar * 0.1
    if Hfig is None:
        # NOTE(review): assumes hss <= hs + vspace + hscbar; a taller side panel is not accounted for
        Hfig = y0 + hs + vspace + hscbar + tspace
    fig = sps.figure_in_mm(Wfig, Hfig)
    label_params = dict(size=fontsize, fontweight='normal')
    specs = (
        [Wfig * center0_frac - ws * 0.5, y0, ws, hs, ''],  # Chern vary glatparam vs ksize
        [Wfig * center0_frac - wscbar, y0 + hs + vspace, wscbar * 2, hscbar, ''],  # cbar for chern (DOS)
        # Center the side panel using its own width (previously offset by the main panel width ws)
        [Wfig * center2_frac - wss * 0.5, y0, wss, hss, ''],  # subplot for showing kitaev region
    )
    ax = [sps.axes_in_mm(left, bottom, width, height, label=part, label_params=label_params)
          for left, bottom, width, height, part in specs]
    return fig, ax
def movie_honeycomb_single_frame(deltaz, phi, text=True):
    """Draw a single frame of a sheared-honeycomb lattice movie.

    The lattice geometry comes from lf.honeycomb_sheared_vis(deltaz, phi) and is drawn with lf.lattice_plot
    into one axis that spans a 90 mm x 90 mm figure, with x limits (-3, 3) and y limits (-2, 1).

    Parameters
    ----------
    deltaz : int or float
        Shear angle in degrees, passed to lf.honeycomb_sheared_vis and written as "delta = %d deg".
    phi : int or float
        Angle in degrees, passed to lf.honeycomb_sheared_vis and written as "phi = %d deg".
    text : bool
        If True, annotate the frame with the deltaz and phi values.

    Returns
    -------
    fig : figure created by sps.figure_in_mm(90, 90)
    ax_lat : axis holding the lattice drawing, with its axis decorations turned off
    """
    R, Ni, Nk, cols, line_cols = lf.honeycomb_sheared_vis(deltaz, phi)

    fig = sps.figure_in_mm(90, 90)
    ax_lat = sps.axes_in_mm(0, 0, 90, 90)

    lf.lattice_plot(R, Ni, Nk, ax_lat, cols, line_cols)
    plt.ylim(-2, 1)
    plt.xlim(-3, 3)
    ax_lat.axis('off')

    if text:
        # NOTE(review): y = 1.5 and 1.3 lie above the y-limit of 1 set above, and the axis fills the figure,
        # so these labels may fall outside the canvas -- confirm they are visible.
        ax_lat.text(-2.75,
                    1.5,
                    r'$\delta$ = %d$^{\circ}$' % deltaz,
                    fontweight='bold',
                    color='k',
                    fontsize=16)
        ax_lat.text(-2.75,
                    1.3,
                    r'$\phi$ = %d$^{\circ}$' % phi,
                    fontweight='bold',
                    color='k',
                    fontsize=16)

    return fig, ax_lat
# Example #5
# 0
        # Make figure
        Wfig = 180
        x0frac = 0.1
        y0frac = 0.1
        wsfrac = 0.35
        tspace = 12
        vspace = 20
        hspace = 18
        FSFS = 12
        x0 = round(Wfig * x0frac)
        y0 = round(Wfig * y0frac)
        ws = round(Wfig * wsfrac)
        hs = ws
        Hfig = y0 + 2 * hs + vspace + tspace

        fig = sps.figure_in_mm(Wfig, Hfig)
        label_params = dict(size=FSFS, fontweight='normal')
        ax = [
            sps.axes_in_mm(x0,
                           y0,
                           width,
                           height,
                           label=part,
                           label_params=label_params)
            for x0, y0, width, height, part in (
                [x0, y0 + hs + vspace, ws, hs, ''], [x0, y0, ws, hs, ''],
                [x0 + ws + hspace, y0, ws, hs, ''],
                [x0 + ws + hspace, y0 + hs + vspace, ws, hs, ''])
        ]

        ind = 0
# Example #6
# 0
def make_mode_movie(seriesdir,
                    amp=50,
                    semilog=True,
                    freqmin=None,
                    freqmax=None,
                    overwrite=False,
                    percent_max=0.01):
    """

    Parameters
    ----------
    seriesdir : str
        The full path to the directory with all tracked cines for which to make mode decomposition movies
    amp : float
        The amplification factor for the displacements in the movie output
    semilog : bool
        plot the FFT intensity vs frequency in log-normal scale
    freqmin : float
        The minimum frequency to plot
    freqmax : float
        The maximum frequency to plot
    overwrite : bool
        Overwrite the saved images/movies if they exist

    Returns
    -------
    """
    pathlist = dio.find_subdirs('20*', seriesdir)
    freqstr = ''
    if freqmin is not None:
        freqstr += '_minfreq' + sf.float2pstr(freqmin)
    if freqmax is not None:
        freqstr += '_maxfreq' + sf.float2pstr(freqmax)

    for path in pathlist:
        movname = path + 'modes_amp{0:0.1f}'.format(amp).replace(
            '.', 'p') + freqstr + '.mov'
        movexist = glob.glob(movname)
        print 'mode_movie: movexist = ', movexist

        if not movexist or overwrite:
            # If the movie does not exist yet, make it here
            print 'building modes for ', path
            fn = os.path.join(path, 'com_data.hdf5')
            print 'mode_drawing_functions.mode_movie.make_mode_movie(): loading data from ', fn
            data = new_mf.load_linked_data_and_window(fn)

            tp, fft_x, fft_y, freq = nmf.ffts_and_add(data)
            high_power_inds = nmf.find_peaks(tp, percent_max=percent_max)

            # Check how many mode images have been done
            modespngs = glob.glob(path + 'modes' + freqstr + '/*.png')
            modespngs_traces = glob.glob(path + 'modes_traces' + freqstr +
                                         '/*.png')

            # If we haven't made images of all the modes, do that here
            print 'mode_movie.make_mode_movie(): len(modespngs) = ', len(
                modespngs)
            print 'mode_movie.make_mode_movie(): len(high_power_inds) = ', len(
                high_power_inds)

            if len(modespngs) < len(high_power_inds) or overwrite:
                # Check if mode pickle is saved
                modesfn_traces = path + 'modes_trackes' + freqstr + '.pkl'
                globfn = glob.glob(modesfn_traces)
                if globfn:
                    with open(globfn[0], "rb") as fn:
                        mode_data = cPickle.load(fn)

                    coords = mode_data['xy']
                    mode_data = dh.removekey(mode_data, 'xy')

                    for i in mode_data:
                        if i % 10 == 0:
                            print 'mode_movie.make_mode_movie(): Creating mode image #', i

                        x_traces = mode_data[i][
                            'x_traces']  # sf.float2pcstr(freq[high_power_inds[i]], ndigits=8)]
                        y_traces = mode_data[i]['y_traces']

                        fig = sps.figure_in_mm(120, 155)
                        ax_mode = sps.axes_in_mm(10, 10, 100, 100)
                        ax_freq = sps.axes_in_mm(10, 120, 100, 30)

                        axes = [ax_mode, ax_freq]

                        new_mf.draw_mode(x_traces,
                                         y_traces,
                                         coords, [freq, tp],
                                         axes,
                                         i,
                                         freq[high_power_inds[i]],
                                         output_dir=os.path.join(
                                             path, 'modes'),
                                         amp=amp,
                                         semilog=semilog)
                else:
                    # The mode data is not saved, so we must create it as we go along
                    # data = new_mf.load_linked_data_and_window(fn)
                    coords = np.array([data[2], data[3]]).T

                    for i in xrange(len(high_power_inds)):
                        if i % 10 == 0:
                            print 'mode_movie.make_mode_movie(): Creating mode image #', i

                        x_traces, y_traces, max_mag = new_mf.get_mode_drawing_data(
                            fft_x, fft_y, freq, high_power_inds[i])

                        x_traces = np.array(x_traces)
                        y_traces = np.array(y_traces)

                        fig = sps.figure_in_mm(120, 155)
                        ax_mode = sps.axes_in_mm(10, 10, 100, 100)
                        ax_freq = sps.axes_in_mm(10, 120, 100, 30)

                        axes = [ax_mode, ax_freq]

                        new_mf.draw_mode(x_traces,
                                         y_traces,
                                         coords, [freq, tp],
                                         axes,
                                         i,
                                         freq[high_power_inds[i]],
                                         output_dir=os.path.join(
                                             path, 'modes'),
                                         amp=amp,
                                         semilog=semilog)

            # Obtain imagename and moviename, create movie if it doesn't exist
            modesfn = glob.glob(path + 'modes' + freqstr + '/*.png')
            imagename_split = modesfn[0].split('/')[-1].split('.png')[0]
            try:
                test = int(imagename_split)
                indexsz = len(imagename_split)
                imgname = path + 'modes' + freqstr + '/'
            except:
                print 'mode_movie.make_mode_movie(): imagename_split = ', imagename_split
                print 'mode_movie.make_mode_movie(): indexsz = ', len(
                    imagename_split)
                raise RuntimeError(
                    'Imagename is not just an int -- write code to allow ')

            movies.make_movie(imgname,
                              movname,
                              indexsz=str(indexsz),
                              framerate=10)
def plot_cherns_varyloc(ccoll, title='Chern number calculation for varied positions',
                        filename='chern_varyloc', exten='.pdf', rootdir=None, outdir=None,
                        max_boxfrac=None, max_boxsize=None,
                        xlabel=None, ylabel=None, step=0.5, fracsteps=False,
                        singleksz_frac=None, singleksz=-1.0, maxchern=False,
                        ax=None, cbar_ax=None, save=True, make_cbar=True, colorz=False,
                        dpi=600, colormap='divgmap_blue_red'):
    """Plot the chern as a function of space for each haldane_lattice examined. If save==True, saves figure to


    Parameters
    ----------
    ccoll : ChernCollection instance
    filename : str
        name of the file to output as the plot
    rootdir : str or None
        if specified, dictates the path where the file is stored, with the output directory as
        outdir = kfns.get_cmeshfn(ccoll.cherns[hlat_name][0].haldane_lattice.lp, rootdir=rootdir)
    max_boxfrac : float
        Fraction of spatial extent of the sample to use as maximum bound for kitaev sum
    max_boxsize : float or None
        If None, uses max_boxfrac * spatial extent of the sample asmax_boxsize
    singleksz : float
        if positive, plots the spatially-resolved chern number for a single kitaev summation region size, closest to
        the supplied value in actual size. Otherwise, draws many rectangles of different sizes for different ksizes.
        If positive, IGNORES max_boxfrac and max_boxsize arguments
    maxchern : bool
        if True, plots the EXTREMAL spatially-resolved chern number for each voxel --> ie the maximum (signed absolute)
        value it reaches. Otherwise, draws many rectangles of different sizes for different ksizes.
        If True, the function IGNORES max_boxfrac and max_boxsize arguments
    ax : axis instance or None
    cbar_ax : axis instance or None
    save : bool
    make_cbar : bool
    dpi : int
        dots per inch, if exten is '.png'
    """
    if colormap == 'divgmap_blue_red':
        divgmap = cmaps.diverging_cmap(250, 10, l=30)
    elif colormap == 'divgmap_red_blue':
        divgmap = cmaps.diverging_cmap(10, 250, l=30)

    # plot it
    for hlat_name in ccoll.cherns:
        rectps = []
        colorL = []
        print 'all hlats should have same pointer:'
        print 'ccoll[hlat_name][0].haldane_lattice = ', ccoll.cherns[hlat_name][0].haldane_lattice
        print 'ccoll[hlat_name][1].haldane_lattice = ', ccoll.cherns[hlat_name][1].haldane_lattice

        print 'Opening hlat_name = ', hlat_name

        if outdir is None:
            outdir = hcfns.get_cmeshfn(ccoll.cherns[hlat_name][0].haldane_lattice.lp, rootdir=rootdir)

        print 'when saving, will save to ' + outdir + 'filename'

        if maxchern:
            for chernii in ccoll.cherns[hlat_name]:
                # Grab small, medium, and large circles
                ksize = chernii.chern_finsize[:, 2]

                # Build XYloc_sz_nu from all cherns done on this network
                xx = float(chernii.cp['poly_offset'].split('/')[0])
                yy = float(chernii.cp['poly_offset'].split('/')[1])
                nu = chernii.chern_finsize[:, -1]
                ind = np.argmax(np.abs(nu))
                rad = step
                rect = plt.Rectangle((xx-rad*0.5, yy-rad*0.5), rad, rad, ec="none")
                colorL.append(nu[ind])
                rectps.append(rect)
        elif singleksz > 0:
            for chernii in ccoll.cherns[hlat_name]:
                # Grab small, medium, and large circles
                ksize = chernii.chern_finsize[:, 2]

                # Build XYloc_sz_nu from all cherns done on this network
                xx = float(chernii.cp['poly_offset'].split('/')[0])
                yy = float(chernii.cp['poly_offset'].split('/')[1])
                nu = chernii.chern_finsize[:, -1]
                ind = np.argmin(np.abs(ksize - singleksz))
                # print 'ksize = ', ksize
                # print 'singleksz = ', singleksz
                # print 'ind = ', ind
                rad = step
                rect = plt.Rectangle((xx-rad*0.5, yy-rad*0.5), rad, rad, ec="none")
                colorL.append(nu[ind])
                rectps.append(rect)
        elif singleksz_frac > 0:
            for chernii in ccoll.cherns[hlat_name]:
                # Grab small, medium, and large circles
                ksize_frac = chernii.chern_finsize[:, 1]

                # Build XYloc_sz_nu from all cherns done on this network
                xx = float(chernii.cp['poly_offset'].split('/')[0])
                yy = float(chernii.cp['poly_offset'].split('/')[1])
                nu = chernii.chern_finsize[:, -1]
                ind = np.argmin(np.abs(ksize_frac - singleksz_frac))
                # print 'ksize = ', ksize
                # print 'singleksz = ', singleksz
                # print 'ind = ', ind
                rad = step
                rect = plt.Rectangle((xx - rad * 0.5, yy - rad * 0.5), rad, rad, ec="none")
                colorL.append(nu[ind])
                rectps.append(rect)
        else:
            print 'stacking rectangles in list to add to plot...'
            for chernii in ccoll.cherns[hlat_name]:
                # Grab small, medium, and large circles
                # Note: I used to multiply by 0.5 here: Why did I multiply by 0.5 here? Not sure...
                # perhaps before I measured ksize by its diameter (or width) but used radius as an imput for drawing
                # a kitaev region.
                ksize = chernii.chern_finsize[:, 2]
                if max_boxsize is not None:
                    ksize = ksize[ksize < max_boxsize]
                else:
                    if max_boxfrac is not None:
                        cgll = chernii.haldane_lattice.lattice
                        maxsz = max(np.max(cgll.xy[:, 0]) - np.min(cgll.xy[:, 0]),
                                    np.max(cgll.xy[:, 1]) - np.min(cgll.xy[:, 1]))
                        max_boxsize = max_boxfrac * maxsz
                        ksize = ksize[ksize < max_boxsize]
                    else:
                        ksize = ksize
                        max_boxsize = np.max(ksize)

                # print 'ksize =  ', ksize
                # print 'max_boxsize =  ', max_boxsize

                # Build XYloc_sz_nu from all cherns done on this network
                xx = float(chernii.cp['poly_offset'].split('/')[0])
                yy = float(chernii.cp['poly_offset'].split('/')[1])
                nu = chernii.chern_finsize[:, -1]
                rectsizes = ksize / np.max(ksize) * step
                # Choose which rectangles to draw
                if len(ksize) > 30:
                    # Too many rectangles to add to the plot! Limit the number to keep file size down
                    inds2use = np.arange(0, len(ksize), int(float(len(ksize))*0.05))[::-1]
                else:
                    inds2use = np.arange(0, len(ksize), 1)[::-1]
                # print 'Adding ' + str(len(inds2use)) + ' rectangles...'

                # Make a list of the rectangles
                for ind in inds2use:
                    rad = rectsizes[ind]
                    rect = plt.Rectangle((xx-rad*0.5, yy-rad*0.5), rad, rad, ec="none")
                    colorL.append(nu[ind])
                    rectps.append(rect)

        print 'Adding patches to figure...'
        p = PatchCollection(rectps, cmap=divgmap, alpha=1.0, edgecolors='none')
        p.set_array(np.array(np.array(colorL)))
        p.set_clim([-1., 1.])

        if ax is None:
            # Make figure
            FSFS = 8
            Wfig = 90
            x0 = round(Wfig * 0.15)
            y0 = round(Wfig * 0.1)
            ws = round(Wfig * 0.4)
            hs = ws
            wsDOS = ws * 0.3
            hsDOS = hs
            wscbar = wsDOS
            hscbar = wscbar * 0.1
            vspace = 8  # vertical space btwn subplots
            hspace = 8  # horizonl space btwn subplots
            tspace = 10  # space above top figure
            Hfig = y0 + hs + vspace + hscbar + tspace
            fig = sps.figure_in_mm(Wfig, Hfig)
            label_params = dict(size=FSFS, fontweight='normal')
            ax = [sps.axes_in_mm(x0, y0, width, height, label=part, label_params=label_params)
                  for x0, y0, width, height, part in (
                      [Wfig * 0.5 - ws * 0.5, y0, ws, hs, ''],  # Chern vary hlatparam vs ksize
                      [Wfig * 0.5 - wscbar, y0 + hs + vspace, wscbar * 2, hscbar, '']  # cbar for chern
                  )]

            # Add the patches of nu calculations for each site probed
            ax[0].add_collection(p)
            hlat = ccoll.cherns[hlat_name][0].haldane_lattice
            netvis.movie_plot_2D(hlat.lattice.xy, hlat.lattice.BL, 0*hlat.lattice.BL[:, 0],
                             None, None, ax=ax[0], fig=fig, axcb=None,
                             xlimv='auto', ylimv='auto', climv=0.1, colorz=colorz, ptcolor=None, figsize='auto',
                             colormap='BlueBlackRed', bgcolor='#ffffff', axis_off=True, axis_equal=True,
                             lw=0.2)

            # Add title
            ax[0].annotate(title, xy=(0.5, .95), xycoords='figure fraction',
                           horizontalalignment='center', verticalalignment='center')
            if xlabel is not None:
                ax[0].set_xlabel(xlabel)
            if ylabel is not None:
                ax[0].set_xlabel(ylabel)

            # Position colorbar
            sm = plt.cm.ScalarMappable(cmap=divgmap, norm=plt.Normalize(vmin=-1, vmax=1))
            # fake up the array of the scalar mappable.
            sm._A = []
            cbar = plt.colorbar(sm, cax=ax[1], orientation='horizontal', ticks=[-1, 0, 1])
            ax[1].set_xlabel(r'$\nu$')
            ax[1].xaxis.set_label_position("top")
        else:
            # Add the patches of nu calculations for each site probed
            ax.add_collection(p)
            hlat = ccoll.cherns[hlat_name][0].haldane_lattice
            netvis.movie_plot_2D(hlat.lattice.xy, hlat.lattice.BL, 0*hlat.lattice.BL[:, 0],
                                 None, None, ax=ax, fig=plt.gcf(), axcb=None,
                                 xlimv='auto', ylimv='auto', climv=0.1, colorz=colorz, ptcolor=None, figsize='auto',
                                 colormap='BlueBlackRed', bgcolor='#ffffff', axis_off=True, axis_equal=True,
                                 lw=0.2)

            # Add title
            if title is not None and title != 'none':
                ax.annotate(title, xy=(0.5, .95), xycoords='figure fraction',
                            horizontalalignment='center', verticalalignment='center')
            if xlabel is not None:
                ax.set_xlabel(xlabel)
            if ylabel is not None:
                ax.set_xlabel(ylabel)

            if make_cbar:
                # Position colorbar
                sm = plt.cm.ScalarMappable(cmap=divgmap, norm=plt.Normalize(vmin=-1, vmax=1))
                # fake up the array of the scalar mappable.
                sm._A = []
                if cbar_ax is not None:
                    # figure out if cbar_ax is horizontal or vertical
                    if leplt.cbar_ax_is_vertical(cbar_ax):
                        cbar = plt.colorbar(sm, cax=cbar_ax, orientation='vertical', ticks=[-1, 0, 1],
                                            label='Chern\n' + r'number, $\nu$')
                    else:
                        cbar = plt.colorbar(sm, cax=cbar_ax, orientation='horizontal', ticks=[-1, 0, 1],
                                            label=r'Chern number, $\nu$')
                        cbar_ax.set_xlabel(r'Chern number, $\nu$')
                        cbar_ax.xaxis.set_label_position("top")
                else:
                    cbar = plt.colorbar(sm, cax=cbar_ax, orientation='horizontal', ticks=[-1, 0, 1],
                                        label=r'Chern number, $\nu$')

        if save:
            # Save the plot
            # make outdir the hlat_cmesh
            print 'kcollpfns: saving figure:\n outdir =', outdir, ' filename = ', filename
            # outdir = kfns.get_cmeshfn(ccoll.cherns[hlat_name][0].haldane_lattice.lp, rootdir=rootdir)
            if filename == 'chern_varyloc':
                filename += '_Nks'+'{0:03d}'.format(len(ksize)) + '_step' + sf.float2pstr(step) + '_maxbsz' +\
                            sf.float2pstr(max_boxsize) + exten
            print 'saving figure: ' + outdir + filename
            if exten == '.png':
                plt.savefig(outdir + filename, dpi=dpi)
            else:
                plt.savefig(outdir + filename)
            plt.clf()

    return ax, cbar
def save_plot(kx, ky, bands, traces, tvals, ons, vertex_points, od):
    """

    Parameters
    ----------
    kx
    ky
    bands
    traces
    tvals
    ons
    vertex_points
    od

    Returns
    -------

    """
    bz_area = dh.polygon_area(vertex_points)

    w = 180
    h = 140
    label_params = dict(size=12, fontweight='normal')

    s = 2
    h1 = (h - 3 * s) / 2.
    w1 = 0.75 * w

    fig = sps.figure_in_mm(w, h)
    ax1 = sps.axes_in_mm(0,
                         h - h1,
                         w1,
                         h1,
                         label=None,
                         label_params=label_params,
                         projection='3d')
    ax2 = sps.axes_in_mm(0,
                         h - 2 * h1,
                         w1,
                         h1,
                         label=None,
                         label_params=label_params,
                         projection='3d')
    ax3 = sps.axes_in_mm(1 * s + w1,
                         h - 2 * h1,
                         w - w1 - 3 * s,
                         2 * h1,
                         label=None,
                         label_params=label_params)
    ax_lat = sps.axes_in_mm(1 * s + w1,
                            h - 0.4 * h1,
                            w - w1 - 3 * s,
                            0.4 * h1,
                            label=None,
                            label_params=label_params)

    c1 = '#000000'
    c2 = '#B9B9B9'
    c3 = '#E51B1B'
    c4 = '#18CFCF'
    c5 = 'violet'
    c6 = 'k'

    cc1 = '#96adb4'
    cc2 = '#b16566'
    cc3 = '#665656'
    cc4 = '#9d96b4'

    pin = 1
    cols = [c1, c2, c3, c4, c5, c6]
    ccols = [cc1, cc2, cc3, cc4, c5]
    l_cols = [cc1, cc2, cc3, cc4, c5]
    ax3.axis([0, 1, 0, 1.])
    print len(ons)

    xp = [0.025, 0.525, 0.025, .525]
    yp = [0.65, 0.65, 0.6, 0.6]
    min_l = min([len(xp), len(ons)])
    for i in range(len(tvals)):
        ax3.text(xp[i],
                 yp[i] + 0.1,
                 '$k_%1d$' % (i + 1) + '= $%.2f$' % tvals[i],
                 style='italic',
                 color=ccols[i],
                 fontsize=12)
    for i in range(min_l):
        ax3.text(xp[i],
                 yp[i],
                 '$M_%1d$' % (i + 1) + '$= %.2f$' % ons[i],
                 style='italic',
                 color=cols[i],
                 fontsize=12)
    ax3.axis('off')
    ax_lat.axis('off')

    nb = len(bands[0])
    b = []
    band_gaps = np.zeros(nb - 1)

    ax3.text(0.025,
             .55,
             '$\mathrm{Chern\/ numbers}$',
             fontweight='bold',
             color='k',
             fontsize=12)
    ax3.text(0.025,
             .35,
             '$\mathrm{Band\/ boundaries}$',
             fontweight='bold',
             color='k',
             fontsize=12)
    ax3.text(0.025,
             .15,
             '$\mathrm{Min.\/differences}$',
             fontweight='bold',
             color='k',
             fontsize=12)

    for i in range(int(nb)):
        j = i
        bv = ((1j / (2 * np.pi)) * np.mean(traces[:, i]) * bz_area)

        if j >= nb / 2:
            ax1.plot_trisurf(kx,
                             ky,
                             bands[:, j],
                             cmap='cool',
                             vmin=min(abs(bands.flatten())),
                             vmax=max(abs(bands.flatten())),
                             linewidth=0.0,
                             alpha=1)
            ax2.plot_trisurf(kx,
                             ky,
                             bands[:, j],
                             cmap='cool',
                             vmin=min(abs(bands.flatten())),
                             vmax=max(abs(bands.flatten())),
                             linewidth=0.0,
                             alpha=1)

            v_min = min(abs(bands.flatten()))
            v_max = max(abs(bands.flatten()))
            # c_3D_plot(kx, ky, bands[:,j], vertex_points, ax1, v_min, v_max)
            # c_3D_plot(kx, ky, bands[:,j], vertex_points, ax2, v_min, v_max)

            ax1.set_ylabel('ky')
            ax1.axes.get_xaxis().set_ticks([])
            ax2.set_xlabel('kx')
            ax2.axes.get_yaxis().set_ticks([])

            ax1.set_zlabel('$\Omega$')
            ax2.set_zlabel('$\Omega$')

            bv = ((1j / (2 * np.pi)) * np.mean(traces[:, i]) * bz_area)
            if abs(bv) > 0.7:
                cc = 'r'
            else:
                cc = 'k'
            ax3.text(0.025,
                     .55 - (j - nb / 2 + 1) * 0.04,
                     '$%0.2f$' % bv,
                     color=cc,
                     fontsize=10)

            maxtb = max(bands[:, j])
            mintb = min(bands[:, j])
            ax3.text(0.025,
                     .35 - (j - nb / 2 + 1) * 0.04,
                     '$%0.2f$' % mintb,
                     color='k',
                     fontsize=10)
            ax3.text(0.525,
                     .35 - (j - nb / 2 + 1) * 0.04,
                     '$%0.2f$' % maxtb,
                     color='k',
                     fontsize=10)

            if j > nb / 2:
                min_diff = min(bands[:, j] - bands[:, j - 1])
                ax3.text(0.025,
                         .15 - (j - nb / 2) * 0.04,
                         '$%0.2f$' % min_diff,
                         color='k',
                         fontsize=10)

    ax1.view_init(elev=0, azim=0.)
    ax2.view_init(elev=0, azim=90.)

    minx = min(kx)
    maxx = max(kx)
    miny = min(ky)
    maxy = max(ky)

    return ax_lat, cols, l_cols, fig
def make_plot(kkx, kky, bands, traces, tvals, ons, pin, num='all'):
    """One of Lisa's functions for plotting chern results with lattice structure"""
    w = 180
    h = 140
    label_params = dict(size=12, fontweight='normal')

    s = 2
    h1 = (h - 3 * s) / 2.
    w1 = 0.75 * w

    fig = sps.figure_in_mm(w, h)
    ax1 = sps.axes_in_mm(0,
                         h - h1,
                         w1,
                         h1,
                         label=None,
                         label_params=label_params,
                         projection='3d')
    ax2 = sps.axes_in_mm(0,
                         h - 2 * h1,
                         w1,
                         h1,
                         label=None,
                         label_params=label_params,
                         projection='3d')
    ax3 = sps.axes_in_mm(1 * s + w1,
                         h - 2 * h1,
                         w - w1 - 3 * s,
                         2 * h1,
                         label=None,
                         label_params=label_params)
    ax_lat = sps.axes_in_mm(1 * s + w1,
                            h - 0.4 * h1,
                            w - w1 - 3 * s,
                            0.4 * h1,
                            label=None,
                            label_params=label_params)

    c1 = '#000000'
    c2 = '#FFFFFF'
    c3 = '#E51B1B'
    c4 = '#18CFCF'

    cc1 = '#96adb4'
    cc2 = '#b16566'
    cc3 = '#665656'
    cc4 = '#9d96b4'

    pin = 1
    gor = 'random'  # grid or random, lines
    cols = [c1, c2, c3, c4]
    l_cols = [cc1, cc2, cc3, cc4]
    mM, vertex_points, dump_dir = lf.alpha_lattice(tvals,
                                                   pin,
                                                   ons,
                                                   ax=ax_lat,
                                                   col=cols,
                                                   lincol=l_cols)
    ax3.axis([0, 1, 0, 1.])
    ax3.text(0.025,
             .75,
             '$k_1 = %.2f$' % tvals[0],
             style='italic',
             color=cc1,
             fontsize=12)
    ax3.text(0.525,
             .75,
             '$k_2 = %.2f$' % tvals[1],
             style='italic',
             color=cc2,
             fontsize=12)
    ax3.text(0.025,
             .7,
             '$k_3 = %.2f$' % tvals[2],
             style='italic',
             color=cc3,
             fontsize=12)
    ax3.text(0.525,
             .7,
             '$k_4 = %.2f$' % tvals[3],
             style='italic',
             color=cc4,
             fontsize=12)
    # bbox={'facecolor':'red', 'alpha':0.5, 'pad':10}
    ax3.text(0.025,
             .65,
             '$M_1 = %.2f$' % ons[0],
             style='italic',
             color=c1,
             fontsize=12)
    ax3.text(0.525,
             .65,
             '$M_2 = %.2f$' % ons[1],
             style='italic',
             color='#B9B9B9',
             fontsize=12)
    ax3.text(0.025,
             .6,
             '$M_3 = %.2f$' % ons[2],
             style='italic',
             color=c3,
             fontsize=12)
    ax3.text(0.525,
             .6,
             '$M_4 = %.2f$' % ons[3],
             style='italic',
             color=c4,
             fontsize=12)
    # bbox={'facecolor':'red', 'alpha':0.5, 'pad':10}
    ax3.axis('off')
    ax_lat.axis('off')

    nb = len(bands[0])
    b = []
    bz_area = dh.polygon_area(vertex_points)
    band_gaps = np.zeros(nb - 1)
    ax3.text(0.025,
             .55,
             '$\mathrm{Chern\/ numbers}$',
             fontweight='bold',
             color='k',
             fontsize=12)
    ax3.text(0.025,
             .35,
             '$\mathrm{Band\/ boundaries}$',
             fontweight='bold',
             color='k',
             fontsize=12)
    ax3.text(0.025,
             .15,
             '$\mathrm{Min.\/differences}$',
             fontweight='bold',
             color='k',
             fontsize=12)
    for i in range(int(nb)):
        j = i

        # ax.scatter(kkx, kky, 1j*traces[:,j])
        if j >= nb / 2:
            if num == 'all':
                ax1.plot_trisurf(kx[:1000],
                                 ky[:1000],
                                 bands[:, j][:1000],
                                 cmap='cool',
                                 vmin=min(abs(bands.flatten())),
                                 vmax=max(abs(bands.flatten())),
                                 linewidth=0.0)
                # ax2.plot_trisurf(kx, ky, bands[:,j], cmap = 'cool', vmin = min(abs(bands.flatten())),
                # vmax=max(abs(bands.flatten())), linewidth = 0.0)
            elif j == num:
                ax2.plot_trisurf(kx[:1000],
                                 ky[:1000],
                                 1j * traces[:, j][:1000],
                                 cmap='cool',
                                 vmin=min(abs(traces[:, j].flatten())),
                                 vmax=max(abs(traces[:, j].flatten())),
                                 linewidth=0.0)

            ax1.set_ylabel('ky')
            ax1.axes.get_xaxis().set_ticks([])
            ax2.set_xlabel('kx')
            ax2.axes.get_yaxis().set_ticks([])

            ax1.set_zlabel('$\Omega$')
            ax2.set_zlabel('$\Omega$')
            bv = ((1j / (2 * np.pi)) * np.mean(traces[:, i]) * bz_area)
            if abs(bv) > 0.7:
                cc = 'r'
            else:
                cc = 'k'
            if j == num:
                ax3.text(0.025,
                         .55 - (j - nb / 2 + 1) * 0.04,
                         '$%0.2f$' % bv,
                         color=cc,
                         fontsize=10)
            if j >= nb / 2:
                maxtb = max(bands[:, j])
                mintb = min(bands[:, j])
                ax3.text(0.025,
                         .35 - (j - nb / 2 + 1) * 0.04,
                         '$%0.2f$' % mintb,
                         color='k',
                         fontsize=10)
                ax3.text(0.525,
                         .35 - (j - nb / 2 + 1) * 0.04,
                         '$%0.2f$' % maxtb,
                         color='k',
                         fontsize=10)
            if j > nb / 2:
                min_diff = min(bands[:, j] - bands[:, j - 1])
                ax3.text(0.025,
                         .15 - (j - nb / 2) * 0.04,
                         '$%0.2f$' % min_diff,
                         color='k',
                         fontsize=10)

        b.append((1j / (2 * np.pi)) * np.mean(traces[:, i]) * bz_area)

    ax1.view_init(elev=0, azim=0.)
    ax2.view_init(elev=-0, azim=90.)

    minx = min(kx)
    maxx = max(kx)
    miny = min(ky)
    maxy = max(ky)

    # plt.show()
    # plt.savefig(save_dir+'/images/test.png')

    s1 = '%0.2f_' % tvals[0]
    s2 = '%0.2f_' % tvals[1]
    s3 = '%0.2f_' % tvals[2]
    s4 = '%0.2f_' % tvals[3]

    s5 = '%0.2f_' % ons[0]
    s6 = '%0.2f_' % ons[1]
    s7 = '%0.2f_' % ons[2]
    s8 = '%0.2f_' % ons[3]

    plt.show()