Ejemplo n.º 1
0
    def show_ratio_box(self, ratio, area, start, end, copy_to_clipboard=True):
        """Displays a text box that shows the ratio and difference of areas
        between two spectra within start and end.

        The box's upper-left corner is anchored at the top-right corner of
        ``self.axes`` (axes coordinates (1, 1)). A box shown by a previous
        call is hidden and detached from the axes first.

        Args:
            ratio: Ratio of the areas, or None when no ratio could be
                calculated (an "Invalid selection" message is shown instead).
            area: Difference of the areas; only used when ``ratio`` is not
                None.
            start: Lower bound of the interval.
            end: Upper bound of the interval.
            copy_to_clipboard: If True and ``ratio`` is not None, the ratio
                rounded to 9 decimals is copied to ``self.clipboard`` and a
                note is added to the box.
        """
        if self.anchored_box is not None:
            # Hide first so the old box disappears even if detaching fails,
            # then detach it so repeated calls don't pile up hidden artists
            # on the axes.
            self.anchored_box.set_visible(False)
            try:
                self.anchored_box.remove()
            except (ValueError, NotImplementedError):
                # Already detached (e.g. the axes were cleared): it is hidden
                # and unreferenced, which is all we need.
                pass
            self.anchored_box = None

        child_boxes = []

        if ratio is None:
            text = "Invalid selection, \nno ratio could \nbe calculated"
            child_boxes.append(
                offsetbox.TextArea(text, textprops=dict(color="k", size=12)))
        else:
            ratio_round = 9  # decimal places kept for the ratio
            rounded_ratio = round(ratio, ratio_round)

            text = (f"Difference: {round(area, 2)}\n"
                    f"Ratio: {rounded_ratio}\n"
                    f"Interval: [{round(start, 2)}, {round(end, 2)}]")
            child_boxes.append(
                offsetbox.TextArea(text, textprops=dict(color="k", size=12)))

            if copy_to_clipboard:
                self.clipboard.setText(str(rounded_ratio))
                child_boxes.append(
                    offsetbox.TextArea("\nRatio copied to clipboard.",
                                       textprops=dict(color="k", size=10)))

        box = offsetbox.VPacker(children=child_boxes,
                                align="center",
                                pad=0,
                                sep=0)

        self.anchored_box = offsetbox.AnchoredOffsetbox(
            loc=2,
            child=box,
            pad=0.5,
            frameon=False,
            bbox_to_anchor=(1.0, 1.0),
            bbox_transform=self.axes.transAxes,
            borderpad=0.,
        )
        self.axes.add_artist(self.anchored_box)
        # NOTE(review): self.leg is re-added on every call; if it is already
        # a child of the axes this appends a duplicate entry — confirm intent.
        self.axes.add_artist(self.leg)
        self.canvas.draw_idle()
Ejemplo n.º 2
0
def set_marginal_histogram_title(ax, fmt, color, label=None, rotated=False):
    """ Sets, or appends to, the title of a marginal histogram.

    On the first call for ``ax`` the title is ``"<label> = <fmt>"`` (just
    ``fmt`` when ``label`` is None). Each later call appends ``" <fmt>"`` as
    an extra text box in ``color``. The text boxes are kept in
    ``ax.title_boxes`` and the anchored title artist in ``ax.title_anchor``.

    Parameters
    ----------
    ax : Axes
        The `Axes` instance for the plot.
    fmt : str
        The string to add to the title.
    color : str
        The color of the text to add to the title.
    label : str, optional
        Prefix (``"<label> = "``) used only when the title is first created.
        Omitted when None.
    rotated : bool
        If `True`, rotate the text 270 degrees for a sideways title placed to
        the right of the axes, and stack the text boxes vertically.
    """
    rotation = 270 if rotated else 0

    # Upper-left corner of the title, in axes coordinates.
    xscale = 1.05 if rotated else 0.0
    if rotated:
        yscale = 1.0
    elif len(ax.get_figure().axes) > 1:
        yscale = 1.15
    else:
        yscale = 1.05

    textprops = dict(color=color,
                     size=15,
                     rotation=rotation,
                     ha='left',
                     va='bottom')

    if hasattr(ax, "title_boxes"):
        # A title already exists: drop the old anchored artist; it is
        # rebuilt below with the new text box appended.
        ax.title_anchor.remove()
        new_box = offsetbox.TextArea(" {}".format(fmt), textprops=textprops)
        ax.title_boxes = ax.title_boxes + [new_box]
    else:
        if label is None:
            title = "{}".format(fmt)
        else:
            title = "{} = {}".format(label, fmt)
        ax.title_boxes = [offsetbox.TextArea(title, textprops=textprops)]

    # Rotated titles stack their boxes vertically, horizontal ones in a row.
    packer_class = offsetbox.VPacker if rotated else offsetbox.HPacker
    ybox = packer_class(children=ax.title_boxes,
                        align="bottom",
                        pad=0,
                        sep=5)

    # add new title and keep reference to instance as an attribute
    anchored_ybox = offsetbox.AnchoredOffsetbox(loc=2,
                                                child=ybox,
                                                pad=0.,
                                                frameon=False,
                                                bbox_to_anchor=(xscale,
                                                                yscale),
                                                bbox_transform=ax.transAxes,
                                                borderpad=0.)
    ax.title_anchor = ax.add_artist(anchored_ybox)
Ejemplo n.º 3
0
def _osc_prob_energy_scale(Enumax, param):
    """Return ``(scale, unit_name, GeV)`` for the energy axis of PlotOscProb.

    ``param`` is a list of parameter sets (the first supplies the unit
    constants) or a single parameter set. Energies with Enumax/MeV <= 500 are
    shown in MeV and with Enumax/GeV <= 1000 in GeV. Above that, a single
    parameter set switches to TeV while a list stays in GeV (the TeV option
    was disabled for lists).
    """
    try:
        ref = param[0]
        mev, gev = ref.MeV, ref.GeV
        top_unit = "GeV"
    except (TypeError, AttributeError):
        ref = param
        mev, gev = ref.MeV, ref.GeV
        top_unit = "TeV"

    if Enumax / mev <= 500.0:
        return mev, "MeV", gev
    if Enumax / gev <= 1000.0:
        return gev, "GeV", gev
    return getattr(ref, top_unit), top_unit, gev


def _osc_prob_label(neutype, initial, final):
    """Return the LaTeX legend label for P(nu_initial -> nu_final).

    Raises ValueError when ``neutype`` is neither "neutrino" nor
    "antineutrino".
    """
    if neutype == "neutrino":
        nu = r"\nu"
    elif neutype == "antineutrino":
        nu = r"\bar{\nu}"
    else:
        raise ValueError("Wrong neutrino type.")
    return (r"$ \mathrm{P}(" + nu + "_" + initial
            + r"\rightarrow " + nu + "_" + final + ")$")


def _osc_prob_filename(plotpath, iineu, param, fmt):
    """Return ``<plotpath>/PlotOscProbability_ineu_<iineu>[_<name>_<neutype>...].<fmt>``."""
    import os

    stem = "PlotOscProbability_ineu_" + str(iineu)
    try:
        tags = ["_" + p.name + "_" + p.neutype for p in param]
    except TypeError:
        # A single (non-iterable) parameter set.
        tags = ["_" + param.name + "_" + param.neutype]
    return os.path.join(plotpath, stem + "".join(tags) + "." + fmt)


def PlotOscProb(iineu,Enumin,Enumax,param, datapath = "../data/SunOscProbabilities/",plotpath = "../plots",plot_survival_probability = False,filename_append = '', fmt = 'eps'):
    """Plots P(nu_iineu -> nu_f) as a function of energy in the Sun, from an
    initial flavour state to the final states e, mu, tau and (for non-standard
    parameter sets) the summed sterile states.

    iineu           : initial flavour, 0 (electron), 1 (muon), 2 (tau)
    Enumin          : minimum neutrino energy       [eV]
    Enumax          : maximum neutrino energy       [eV]
    param           : list of physics parameter sets [param_1,...,param_n];
                      each needs Refresh(), name, numneu, neutype, colorstyle
    datapath        : data directory forwarded to no.InterOscProb
    plotpath        : output directory, joined to the file name with
                      os.path.join
    plot_survival_probability : also draw no.InterOscProb(iineu, p.numneu, ...)
                      for every parameter set as a solid line
    filename_append : forwarded to no.InterOscProb
    fmt             : extension/format of the saved figure

    Parameter sets named "STD"/"STD_XXX" or with numneu == 3 are drawn without
    legend labels and without the sterile curve.

    Raises ValueError if a non-standard parameter set has a neutype other
    than "neutrino" or "antineutrino".
    """
    plt.cla()
    plt.clf()
    plt.close()
    fig = plt.figure(figsize=(10, 7.5))
    ax = plt.subplot(111)

    mpl.rcParams['axes.labelsize'] = "xx-large"
    mpl.rcParams['xtick.labelsize'] = "xx-large"
    mpl.rcParams['legend.fontsize'] = "small"
    mpl.rcParams['font.size'] = 30

    scale, scalen, gev = _osc_prob_energy_scale(Enumax, param)
    Emin = Enumin / gev
    Emax = Enumax / gev

    neulabel = {0: "e", 1: "\\mu", 2: "\\tau", 3: "{s_1}", 4: "{s_2}", 5: "{s_3}"}

    # Evaluation energies (in `scale` units): log-spaced bin midpoints with
    # the two highest dropped.
    ERK = list(gt.MidPoint(gt.LogSpaceEnergies(Enumin / scale, Enumax / scale,
                                               binnum=200)))[:-2]

    def osc_prob(final, p):
        """P(nu_iineu -> nu_final) for parameter set p at every energy in ERK."""
        return [no.InterOscProb(iineu, final, E * scale, p, datapath, Emin, Emax,
                                filename_append=filename_append)
                for E in ERK]

    flavour_styles = ['--', '-.', ':']   # e, mu, tau
    sterile_style = '-'

    plt.xlabel(r"$\mathrm{E}_\nu\mathrm{[" + scalen + "]}$")
    plt.ylabel(r"$ \mathrm{Probability}$")

    # Final states e, mu, tau and 5 = "all sterile states summed".
    for fneu in (0, 1, 2, 5):
        for p in param:
            p.Refresh()
            is_standard = p.name in ("STD", "STD_XXX") or p.numneu == 3

            if fneu == 5:
                if is_standard:
                    continue  # 3-flavour sets have no sterile states
                # Sum over s1, s2 and, when present, s3. Each flavour curve is
                # evaluated as a whole list (same call order as before).
                sterile = [3, 4, 5] if p.numneu > 5 else [3, 4]
                curves = [[float(v) for v in osc_prob(s, p)] for s in sterile]
                PRK = [sum(values) for values in zip(*curves)]
                plt.plot(ERK, PRK, linestyle=sterile_style,
                         label=_osc_prob_label(p.neutype, neulabel[iineu], "s"),
                         color=p.colorstyle, lw=6, solid_joinstyle='bevel')
            else:
                PRK = osc_prob(fneu, p)
                if is_standard:
                    plt.plot(ERK, PRK, linestyle=flavour_styles[fneu],
                             color=p.colorstyle, lw=4)
                else:
                    plt.plot(ERK, PRK, linestyle=flavour_styles[fneu],
                             label=_osc_prob_label(p.neutype, neulabel[iineu],
                                                   neulabel[fneu]),
                             color=p.colorstyle, lw=4)

    if plot_survival_probability:
        for p in param:
            P_survival = osc_prob(p.numneu, p)
            plt.plot(ERK, P_survival, linestyle='solid', lw=4,
                     color=p.colorstyle, solid_joinstyle='bevel')

    plt.semilogx()
    plt.axis([Enumin / scale, Enumax / scale, 0.0, 1.0])
    plt.legend(loc='upper right', fancybox=True)

    # Extra legend: one "name + coloured dot" row per parameter set, stacked
    # downwards from the top-left corner of the axes.
    for i, p in enumerate(param):
        # Keep only the part of the name before its last underscore.
        match = re.search(r'.*(?=_)', p.name)
        paramname = match.group(0) if match else p.name

        text_box = osb.TextArea(paramname, textprops=dict(color="k"))
        dot_box = osb.DrawingArea(60, 20, 0, 0)
        dot_box.add_artist(ptc.Ellipse((10, 10), width=5, height=5, angle=0,
                                       fc=p.colorstyle, edgecolor='none'))
        row = osb.HPacker(children=[text_box, dot_box],
                          align="center",
                          pad=0, sep=1)
        ax.add_artist(osb.AnchoredOffsetbox(loc=2,
                                            child=row, pad=0.25,
                                            frameon=False,
                                            bbox_to_anchor=(0.0, 1.0 - 0.06 * i),
                                            bbox_transform=ax.transAxes,
                                            borderpad=0.))

    fig.subplots_adjust(bottom=0.12, top=0.95, left=0.12, right=0.95)

    plt.savefig(_osc_prob_filename(plotpath, iineu, param, fmt), dpi=1200)

    plt.clf()
    plt.close()
Ejemplo n.º 4
0
def PlotSingleNeuCompositionCompare(E, body, param, sparam=None):
    """ Plots the composition of a single mass neutrino state versus density.

    The composition computed with ``param`` is drawn in red and the one for
    the standard parameters ``sparam`` in black. The x-axis is the density
    ``body.rdensity(r)`` for r stepping from 1.0 down toward 0.01. Only
    mass-state index 1 (titled nu_2) is plotted. The figure is saved to
    ``../plots/PlotNeuComposition_E_<E/scale>_<unit>_FIG2.eps``.

    E        :    neutrino energy [eV]
    body     :    body with the associated density profile (needs rdensity).
    param    :    a single set of physical parameters used to make the plot.
    sparam   :    standard parameters; a fresh PC.PhysicsConstants() if None.
    """
    if sparam is None:
        # Built per call: a default argument value would be created once at
        # import time and the same instance Refresh()-ed by every call.
        sparam = PC.PhysicsConstants()

    plt.figure()
    ax = plt.subplot(111)

    mpl.rcParams['axes.labelsize'] = "x-large"
    mpl.rcParams['xtick.labelsize'] = "x-large"
    mpl.rcParams['legend.fontsize'] = "small"
    mpl.rcParams['font.size'] = 18

    linestyles = ['--', '-.', ':', '-', '-']

    # Initialise flavour mass matrices and the density profile.
    param.Refresh()
    fM2 = no.flavorM2(param)
    sparam.Refresh()
    fM2STD = no.flavorM2(sparam)
    R = np.arange(1.0, 0.01, -0.001)
    Rho = [body.rdensity(r) for r in R]

    # Energy unit used in the output file name.
    if E / param.MeV <= 500.0:
        scale = param.MeV
        scalen = "MeV"
    elif E / param.GeV <= 500.0:
        scale = param.GeV
        scalen = "GeV"
    else:
        scale = param.TeV
        scalen = "TeV"

    # Composition of mass state index 1 in the flavour basis.
    mass_state = 1
    flavor = False
    NeuComp = [no.NeuComposition(mass_state, E, x, body, fM2, param, flavor)
               for x in R]
    plt.xlabel(r"$\rho \mathrm{[g/cm^{3}]}$")

    pp = []
    for k in range(param.numneu):
        kNeuComp = [comp[k] for comp in NeuComp]
        if k == 3:
            # Resample on log-spaced densities (presumably so markevery gives
            # evenly spaced markers on the log x-axis).
            rholog = gt.LogSpaceEnergies(float(Rho[0]), float(Rho[-1]), 100)
            # Pin the endpoint exactly to the data range; presumably guards
            # against round-off taking it outside what interp1d accepts.
            rholog[-1] = Rho[-1]
            inter_neu = interpolate.interp1d(Rho, kNeuComp)
            logkNeuComp = [inter_neu(rho) for rho in rholog]
            pp.append(plt.plot(rholog, logkNeuComp, 'x-', color='r',
                               markevery=10, markeredgewidth=0.0, ms=2.0,
                               aa=True, solid_joinstyle='bevel'))
        elif k == 4:
            pp.append(plt.plot(Rho, kNeuComp, 'o-', color='r', markevery=10,
                               markeredgewidth=0.0, ms=2.0))
        else:
            # Cycle styles so a 6th state (k == 5) doesn't raise IndexError.
            pp.append(plt.plot(Rho, kNeuComp,
                               linestyle=linestyles[k % len(linestyles)],
                               color='r'))

    # Standard-parameter reference curves.
    NeuCompSTD = [no.NeuComposition(mass_state, E, x, body, fM2STD, sparam, flavor)
                  for x in R]
    for k in range(sparam.numneu):
        plt.plot(Rho, [comp[k] for comp in NeuCompSTD], color='k',
                 linestyle=linestyles[k % len(linestyles)])

    # Axes formatting.
    plt.title(r"Composition of $\nu_" + str(mass_state + 1) + "$")
    plt.semilogx()
    plt.ylim(0.0, 1.0)
    plt.xlim(1.0, 150.0)
    plt.yticks(np.arange(0.0, 1.1, 0.1))
    ax.set_xticks([1.0, 5.0, 10.0, 30.0, 100.0])
    ax.set_xticklabels(['$1$', '$5$', '$10$', '$30$', '$100$'])

    # Legend: one entry per flavour, sterile states named s1, s2, ...
    plots = [lines[0] for lines in pp]
    labels = [r"$\nu_e$", r"$\nu_\mu$", r"$\nu_\tau$"]
    labels.extend("$\\nu_{s" + str(s) + "}$"
                  for s in np.arange(1, param.numneu - 3 + 1, 1))
    leg = plt.legend(plots, labels, loc=6, fancybox=True,
                     bbox_to_anchor=(0.05, 0.75))
    leg.get_frame().set_alpha(0.25)

    # Extra legend: coloured dots identifying STD (black) and 2+3 (red).
    for text, color, sep, pad in (("STD", "k", 1, 5.0), ("2+3", "r", 5, 6.0)):
        label_box = osb.TextArea(text, textprops=dict(color="k"))
        dot_box = osb.DrawingArea(60, 20, 0, 0)
        dot_box.add_artist(ptc.Ellipse((10, 10), width=5, height=5, angle=0,
                                       fc=color, edgecolor='none'))
        row = osb.HPacker(children=[label_box, dot_box],
                          align="center",
                          pad=0, sep=sep)
        ax.add_artist(osb.AnchoredOffsetbox(loc=9,
                                            child=row, pad=pad,
                                            frameon=False,
                                            borderpad=0.))

    path = "../plots/"
    filename = "PlotNeuComposition_E_" + str(E / scale) + "_" + scalen + "_FIG2.eps"
    plt.savefig(path + filename, dpi=1200)
Ejemplo n.º 5
0
def plot_model(model,
               save_file=None,
               ax=None,
               show=False,
               fig_title="Demographic Model",
               pop_labels=None,
               nref=None,
               draw_ancestors=True,
               draw_migrations=True,
               draw_scale=True,
               arrow_size=0.01,
               transition_size=0.05,
               gen_time=0,
               gen_time_units="Years",
               reverse_timeline=False,
               fig_bg_color='#ffffff',
               plot_bg_color='#ffffff',
               text_color='#002b36',
               gridline_color='#586e75',
               pop_color='#268bd2',
               arrow_color='#073642',
               label_size=16,
               tick_size=12,
               grid=True):
    """
    Plots a demographic model based on information contained within a _ModelInfo
    object. See the matplotlib docs for valid entries for the color parameters.

    model : A _ModelInfo object created using generate_model().

    save_file : If not None, the figure is saved to this location (with the
                figure's dpi and background colors).

    ax : Axes to draw on. If None, a new 9.6x5.4 inch, 200 dpi figure is
         created; it is closed after being saved or shown.

    show : If True and save_file is None, call plt.show().

    fig_title : Title of the figure.

    pop_labels : If not None, should be a list of strings of the same length as
                 the total number of final populations in the model. The string
                 at index i should be the name of the population along axis i in
                 the model's SFS.

    nref : If specified, this will update the time and population size labels to
           use units based on an ancestral population size of nref. See the
           documentation for details.

    draw_ancestors : Specify whether the ancestral populations should be drawn
                     in beginning of plot. Will fade off with a gradient.

    draw_migrations : Specify whether migration arrows are drawn.

    draw_scale : Specify whether scale bar should be shown in top-left corner.

    arrow_size : Float to control the size of the migration arrows.

    transition_size : Float specifying size of the "transitional periods"
                      between populations, as a fraction of model.precision.
                      0 disables the transitions.

    gen_time : If greater than 0, and nref given, timeline will be adjusted to
               show absolute time values, using this value as the time elapsed
               per generation.

    gen_time_units : Units used for gen_time (e.g. Years, Thousand Years, etc.).

    reverse_timeline : If True, the labels on the timeline will be reversed, so
                       that "0 time" is the present time period, rather than the
                       time of the original population.

    fig_bg_color : Background color of figure (i.e. border surrounding the
                   drawn model).

    plot_bg_color : Background color of the actual plot area.

    text_color : Color of text in the figure.

    gridline_color : Color of the plot gridlines.

    pop_color : Color of the populations.

    arrow_color : Color of the arrows showing migrations between populations.

    label_size : Font size of the axis labels and population labels.

    tick_size : Font size of the tick labels.

    grid : If True, draw horizontal gridlines at every integer population size.
    """
    # Set up the plot with a title and axis labels
    fig_kwargs = {
        'figsize': (9.6, 5.4),
        'dpi': 200,
        'facecolor': fig_bg_color,
        'edgecolor': fig_bg_color
    }
    # Only a figure created here is ours to close at the end; `ax` is
    # rebound below, so remember this now.
    own_figure = ax is None
    if own_figure:
        fig = plt.figure(**fig_kwargs)
        ax = fig.add_subplot(111)
    ax.set_facecolor(plot_bg_color)
    ax.set_title(fig_title, color=text_color, fontsize=24)
    xlabel = "Time Ago" if reverse_timeline else "Time"
    if nref:
        if gen_time > 0:
            xlabel += " ({})".format(gen_time_units)
        else:
            xlabel += " (Generations)"
        ylabel = "Population Sizes"
    else:
        xlabel += " (Genetic Units)"
        ylabel = "Relative Population Sizes"
    ax.set_xlabel(xlabel, color=text_color, fontsize=label_size)
    ax.set_ylabel(ylabel, color=text_color, fontsize=label_size)

    # Determine various maximum values for proper scaling within the plot
    xmax = model.tp_list[-1].time[-1]
    ymax = sum(model.tp_list[0].framesizes)
    ax.set_xlim([-1 * xmax * 0.1, xmax])
    ax.set_ylim([0, ymax])
    mig_max = 0
    for tp in model.tp_list:
        if tp.migrations is not None:
            mig_max = max(mig_max, np.amax(tp.migrations))

    # Configure axis border colors
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_color(text_color)

    # Major ticks along x-axis (time) placed at each population split
    xticks = [tp.time[0] for tp in model.tp_list]
    xticks.append(xmax)
    ax.xaxis.set_major_locator(mticker.FixedLocator(xticks))
    ax.xaxis.set_minor_locator(mticker.NullLocator())
    ax.tick_params(which='both',
                   axis='x',
                   labelcolor=text_color,
                   labelsize=tick_size,
                   top=False)
    # Choose correct time labels based on nref, gen_time, and reverse_timeline
    if reverse_timeline:
        xticks = [xmax - x for x in xticks]
    if nref:
        if gen_time > 0:
            xticks = [2 * nref * gen_time * x for x in xticks]
        else:
            xticks = [2 * nref * x for x in xticks]
        ax.set_xticklabels(['{:.0f}'.format(x) for x in xticks])
    else:
        ax.set_xticklabels(['{:.2f}'.format(x) for x in xticks])

    # Gridlines along y-axis (population size) spaced by nref size
    if grid:
        ax.yaxis.set_major_locator(mticker.FixedLocator(np.arange(ymax)))
        ax.yaxis.set_minor_locator(mticker.NullLocator())
        # Positional flag: the `b=` keyword was removed in matplotlib 3.7.
        ax.grid(True, which='major', axis='y', color=gridline_color)
        ax.tick_params(which='both',
                       axis='y',
                       colors='none',
                       labelsize=tick_size)
    else:
        ax.set_yticks([])

    # Add scale in top-left corner displaying ancestral population size (Nref)
    if draw_scale:
        # Bidirectional arrow of height Nref
        arrow = mbox.AuxTransformBox(ax.transData)
        awidth = xmax * arrow_size * 0.2
        alength = ymax * arrow_size
        arrow_kwargs = {
            'width': awidth,
            'head_width': awidth * 3,
            'head_length': alength,
            'color': text_color,
            'length_includes_head': True
        }
        # Build the patches directly: plt.arrow would also add them to the
        # current pyplot axes, which need not be `ax`.
        arrow.add_artist(
            mpatches.FancyArrow(0, 0.25, 0, 0.75, zorder=100, **arrow_kwargs))
        arrow.add_artist(
            mpatches.FancyArrow(0, 0.75, 0, -0.75, zorder=100, **arrow_kwargs))
        # Population bar of height Nref
        bar = mbox.AuxTransformBox(ax.transData)
        bar.add_artist(
            mpatches.Rectangle((0, 0), xmax / ymax, 1, color=pop_color))
        # Appropriate label depending on scale
        label = mbox.TextArea(str(nref) if nref else "Nref")
        label.get_children()[0].set_color(text_color)
        bars = mbox.HPacker(children=[label, arrow, bar],
                            pad=0,
                            sep=2,
                            align="center")
        scalebar = mbox.AnchoredOffsetbox(2,
                                          pad=0.25,
                                          borderpad=0.25,
                                          child=bars,
                                          frameon=False)
        ax.add_artist(scalebar)

    # Add ancestral populations using a gradient fill.
    if draw_ancestors:
        time = -1 * xmax * 0.1
        n_points = model.precision
        tsize = int(transition_size * n_points)
        # Index where the gradient ends and the solid transition begins.
        # Explicit so that tsize == 0 doesn't turn [:-0] into an empty slice.
        cut = n_points - tsize
        xlist = np.linspace(time, 0.0, n_points)
        dx = xlist[1] - xlist[0]
        # Custom color map runs from bg color to pop color
        cmap = mcolors.LinearSegmentedColormap.from_list(
            "custom_map", [plot_bg_color, pop_color])
        colors = np.array(cmap(np.linspace(0.0, 1.0, cut)))
        for i, ori in enumerate(model.tp_list[0].origins):
            # Draw ancestor for each initial pop
            low, mid, top = (ori[1], ori[1] + 1.0,
                             ori[1] + model.tp_list[0].popsizes[i][0])
            y1list = np.array([low] * n_points)
            y2list = np.append(np.array([mid] * cut),
                               np.linspace(mid, top, tsize))
            # Gradient created by drawing multiple small rectangles
            for x, y1, y2, color in zip(xlist[:cut], y1list[:cut],
                                        y2list[:cut], colors):
                ax.add_patch(
                    mpatches.Rectangle((x, y1), dx, y2 - y1, color=color))
            if tsize > 0:
                ax.fill_between(xlist[cut:],
                                y1list[cut:],
                                y2list[cut:],
                                color=pop_color,
                                edgecolor=pop_color)

    # Iterate through time periods and populations to draw everything
    for tp_index, tp in enumerate(model.tp_list):
        # Keep track of migrations to evenly space arrows across time period
        total_migrations = np.count_nonzero(tp.migrations)
        num_migrations = 0

        for pop_index in range(len(tp.popsizes)):
            # Draw current population
            origin = tp.origins[pop_index]
            popsize = tp.popsizes[pop_index]
            direc = tp.direcs[pop_index]
            y1 = origin[1]
            y2 = origin[1] + (direc * popsize)
            ax.fill_between(tp.time,
                            y1,
                            y2,
                            color=pop_color,
                            edgecolor=pop_color)

            # Draw connections to next populations if necessary
            if tp.descendants is not None and tp.descendants[pop_index] != -1:
                desc = tp.descendants[pop_index]
                tp_next = model.tp_list[tp_index + 1]
                # Split population case
                if isinstance(desc, tuple):
                    # Get origins
                    connect_below = tp_next.origins[desc[0]][1]
                    connect_above = tp_next.origins[desc[1]][1]
                    # Get popsizes
                    subpop_below = tp_next.popsizes[desc[0]][0]
                    subpop_above = tp_next.popsizes[desc[1]][0]
                    # Determine correct connection location
                    connect_below -= direc * subpop_below
                    connect_above += direc * subpop_above
                # Single population case
                else:
                    connect_below = tp_next.origins[desc][1]
                    subpop = tp_next.popsizes[desc][0]
                    connect_above = connect_below + direc * subpop
                # Draw the connections over the last `tsize` time points; with
                # no transition period there is nothing to draw (and [-0:]
                # would select the whole period).
                tsize = int(transition_size * model.precision)
                if tsize > 0:
                    cx = tp.time[-tsize:]
                    cy_below_1 = [origin[1]] * tsize
                    cy_above_1 = origin[1] + direc * popsize[-tsize:]
                    cy_below_2 = np.linspace(cy_below_1[0], connect_below, tsize)
                    cy_above_2 = np.linspace(cy_above_1[0], connect_above, tsize)
                    ax.fill_between(cx,
                                    cy_below_1,
                                    cy_below_2,
                                    color=pop_color,
                                    edgecolor=pop_color)
                    ax.fill_between(cx,
                                    cy_above_1,
                                    cy_above_2,
                                    color=pop_color,
                                    edgecolor=pop_color)

            # Draw migrations if necessary
            if draw_migrations and tp.migrations is not None:
                # Iterate through migrations for current population
                for mig_index, mig_val in enumerate(tp.migrations[pop_index]):
                    # If no migration, continue
                    if mig_val == 0:
                        continue
                    # Calculate proper offset for arrow within this period
                    num_migrations += 1
                    offset = int(tp.precision * num_migrations /
                                 (total_migrations + 1.0))
                    x = tp.time[offset]
                    dx = 0
                    # Determine which sides of populations are closest
                    y1 = origin[1]
                    y2 = y1 + direc * popsize[offset]
                    mig_y1 = tp.origins[mig_index][1]
                    mig_y2 = mig_y1 + (tp.direcs[mig_index] *
                                       tp.popsizes[mig_index][offset])
                    y = y1 if abs(mig_y1 - y1) < abs(mig_y1 - y2) else y2
                    dy = mig_y1-y if abs(mig_y1 - y) < abs(mig_y2 - y) \
                                  else mig_y2-y
                    # Scale arrow to proper size
                    mig_scale = max(0.1, mig_val / mig_max)
                    awidth = xmax * arrow_size * mig_scale
                    alength = ymax * arrow_size
                    ax.arrow(x,
                             y,
                             dx,
                             dy,
                             width=awidth,
                             head_width=awidth * 3,
                             head_length=alength,
                             color=arrow_color,
                             length_includes_head=True)

    # Label populations if proper labels are given
    tp_last = model.tp_list[-1]
    if pop_labels and len(pop_labels) == len(tp_last.popsizes):
        ax2 = ax.twinx()
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())
        # Determine placement of ticks
        yticks = [
            tp_last.origins[i][1] +
            0.5 * tp_last.direcs[i] * tp_last.popsizes[i][-1]
            for i in range(len(tp_last.popsizes))
        ]
        ax2.yaxis.set_major_locator(mticker.FixedLocator(yticks))
        ax2.set_yticklabels(pop_labels)
        ax2.tick_params(which='both',
                        color='none',
                        labelcolor=text_color,
                        labelsize=label_size,
                        left=False,
                        top=False,
                        right=False)
        for side in ('top', 'left', 'right', 'bottom'):
            ax2.spines[side].set_color(text_color)

    # Display figure
    if save_file:
        # `figsize` is a Figure() argument, not a savefig() one.
        save_kwargs = {key: fig_kwargs[key]
                       for key in ('dpi', 'facecolor', 'edgecolor')}
        ax.figure.savefig(save_file, **save_kwargs)
    elif show:
        plt.show()
    # Close a figure we created once it has been consumed; otherwise leave it
    # open so the caller can still display or save it.
    if own_figure and (save_file or show):
        plt.close(fig)