Пример #1
0
import os, sys
import pickle

import matplotlib.pyplot as plt
import numpy as np
from astropy.time import Time
from dynesty import utils as dyutils

from allesfitter import config

# base data path for the TESS known planet project
pathdata = os.environ['KNWN_DATA_PATH'] + '/'
pathdatapost = os.environ['KNWN_DATA_PATH'] + '/postproc/'

# name of the exoplanet, given as the first command-line argument
strgplan = sys.argv[1]

# allesfitter run directory of the TESS-only fit of this planet
pathdataplan = pathdata + '%s/allesfit_global/allesfit_onlytess_full/' % strgplan
config.init(pathdataplan)

# Load the saved nested-sampling results; the context manager closes the file.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted,
# locally produced allesfitter output.
with open(pathdataplan + 'results/save_ns.pickle', 'rb') as fileobjt:
    objtrest = pickle.load(fileobjt)

# Turn the weighted nested samples into equally weighted posterior samples:
# importance weights normalised by the final log-evidence estimate.
weig = np.exp(objtrest['logwt'] - objtrest['logz'][-1])
chan = dyutils.resample_equal(objtrest.samples, weig)

listkeys = config.BASEMENT.fitkeys
listlabl = config.BASEMENT.fitlabels

# get period and epoch posteriors; this assumes the columns of chan follow the
# order of the fitted keys. Fail loudly if either parameter was not fitted
# instead of leaving postepoc/postperi undefined.
listkeystemp = list(listkeys)
for strgkeyreqd in ('b_epoch', 'b_period'):
    if strgkeyreqd not in listkeystemp:
        raise KeyError('%s is not a fitted parameter of the run in %s' %
                       (strgkeyreqd, pathdataplan))
postepoc = chan[:, listkeystemp.index('b_epoch')]
postperi = chan[:, listkeystemp.index('b_period')]
Пример #2
0
def _pred_tran_time_resi(objtalle, year):
    '''
    Return the posterior of a predicted mid-transit time near 1 January of
    `year`, as residuals about its posterior mean, in minutes.

    objtalle : allesfitter.allesclass whose posterior_params provide
        'b_epoch' and 'b_period' sample arrays (in days).
    '''
    epocjwst = astropy.time.Time('%d-01-01 00:00:00' % year,
                                 format='iso').jd
    epoc = objtalle.posterior_params['b_epoch']
    peri = objtalle.posterior_params['b_period']
    # Use a single integer transit index for all samples: the mean of the
    # per-sample rounded indices is rounded again so that it stays an integer
    # even when the samples disagree about the nearest transit.
    indxtran = np.rint(np.mean(np.rint((epocjwst - epoc) / peri)))
    timetran = epoc + peri * indxtran
    return (timetran - np.mean(timetran)) * 24. * 60.


def plot_viol(pathdataoutp, pvalthrs=1e-3, boolonlytess=False):
    '''
    Make violin plots comparing the allesfitter posteriors of the run without
    TESS, the run with all data and, optionally, the TESS-only run.

    Parameters
    ----------
    pathdataoutp : str
        Base output path, used as a string prefix. Runs are read from
        pathdataoutp + 'allesfits/allesfit_<run>' and plots are written to
        pathdataoutp + '<name>.svg'.
    pvalthrs : float
        Currently unused; kept for backward compatibility.
    boolonlytess : bool
        If True, also include the run 'allesfit_TESS' ("only TESS").

    Returns
    -------
    listfigr, listaxis : list
        Figures and axes of the two transit-time prediction plots. Both
        figures have already been saved and closed.
    '''

    strgtimestmp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    print('allesfitter postprocessing violin plot started at %s...' %
          strgtimestmp)

    # runs to be compared and their tick labels (append the strings
    # themselves, not one-element lists)
    liststrgruns = ['woutTESS', 'alldata']
    ticklabl = ['w/o TESS', 'w/ TESS']
    if boolonlytess:
        liststrgruns.append('TESS')
        ticklabl.append('only TESS')

    numbruns = len(liststrgruns)
    indxruns = np.arange(numbruns)
    xpos = 0.6 * (indxruns + 1.)

    # read the fitted parameter keys, labels and posteriors of each run
    liststrgpara = [[] for i in indxruns]
    listlablpara = [[] for i in indxruns]
    listobjtalle = [[] for i in indxruns]
    for i, strgruns in enumerate(liststrgruns):
        pathdata = pathdataoutp + 'allesfits/allesfit_%s' % strgruns
        print('Reading from %s...' % pathdata)
        config.init(pathdata)
        liststrgpara[i] = np.array(config.BASEMENT.fitkeys)
        listlablpara[i] = np.array(config.BASEMENT.fitlabels)
        # read the chain
        listobjtalle[i] = allesfitter.allesclass(pathdata)

    # union of the parameter keys over all runs; each label is taken from the
    # last run that fits the parameter
    liststrgparaconc = np.unique(np.concatenate(liststrgpara))
    dictlablpara = {}
    for i in indxruns:
        dictlablpara.update(zip(liststrgpara[i], listlablpara[i]))
    # object dtype: a copy of the fixed-width key array would silently
    # truncate labels longer than the longest key
    listlablparaconc = np.array(
        [dictlablpara[strgpara] for strgpara in liststrgparaconc],
        dtype=object)

    # one violin plot per parameter fitted in every run
    for k, strgpara in enumerate(liststrgparaconc):
        if not all(strgpara in liststrgpara[i] for i in indxruns):
            continue

        # time parameters (days) are shown in minutes, others in native units;
        # all are shown as residuals about their posterior mean
        factscal = 24. * 60. if strgpara in ('b_period', 'b_epoch') else 1.
        chanlist = []
        for objtalle in listobjtalle:
            post = objtalle.posterior_params[strgpara]
            chanlist.append((post - np.mean(post)) * factscal)

        figr, axis = plt.subplots(figsize=(5, 4))
        axis.violinplot(chanlist, xpos, showmedians=True, showextrema=False)
        axis.set_xticks(xpos)
        axis.set_xticklabels(ticklabl)
        if strgpara == 'b_period':
            axis.set_ylabel('P [min]')
        elif strgpara == 'b_epoch':
            axis.set_ylabel('%s [min]' % listlablparaconc[k])
        else:
            axis.set_ylabel(listlablparaconc[k])
        plt.tight_layout()

        path = pathdataoutp + 'viol_%04d.svg' % (k)
        print('Writing to %s...' % path)
        figr.savefig(path)
        plt.close()

    # predicted mid-transit time residuals [year index][run index]
    listyear = [2021, 2023, 2025]
    timejwst = [[_pred_tran_time_resi(objtalle, year)
                 for objtalle in listobjtalle] for year in listyear]

    listfigr = []
    listaxis = []

    ## temporal evolution, using the run with all data (run index 1)
    figr, axis = plt.subplots(figsize=(5, 4))
    listfigr.append(figr)
    listaxis.append(axis)
    axis.violinplot([timejwst[k][1] for k in range(len(listyear))], listyear)
    axis.set_xlabel('Year')
    axis.set_ylabel('Transit time residual [min]')
    plt.tight_layout()
    path = pathdataoutp + 'jwsttime.svg'
    print('Writing to %s...' % path)
    plt.savefig(path)
    plt.close()

    ## without/with/only TESS prediction comparison for 2023 (year index 1)
    figr, axis = plt.subplots(figsize=(5, 4))
    listfigr.append(figr)
    listaxis.append(axis)
    axis.violinplot(timejwst[1], xpos, points=2000)
    axis.set_xticks(xpos)
    axis.set_xticklabels(ticklabl)
    axis.set_ylabel('Transit time residual in 2023 [min]')
    plt.tight_layout()
    path = pathdataoutp + 'jwstcomp.svg'
    print('Writing to %s...' % path)
    plt.savefig(path)
    plt.close()

    return listfigr, listaxis
Пример #3
0
# Per-run-type containers, one slot per entry of indxrtyp. In the visible part
# of this script only lablpara is filled; the others are presumably filled by
# code that follows (TODO confirm).
chan = [[] for i in indxrtyp]
lablpara = [[] for i in indxrtyp]
numbwalk = [[] for i in indxrtyp]
numbswep = [[] for i in indxrtyp]
numbsamp = [[] for i in indxrtyp]
numbburn = [[] for i in indxrtyp]
factthin = [[] for i in indxrtyp]
numbpara = [[] for i in indxrtyp]

# directory holding this planet's allesfitter runs
pathdataplan = pathdata + strgplan + '/'

# read the allesfitter configuration of each run type
for i, strgrtyp in enumerate(liststrgrtyp):
    # NOTE(review): this rebinds the outer `pathdata`. pathdataplan was already
    # built above, so later iterations are unaffected, but any code after the
    # loop sees the last run's path.
    pathdata = pathdataplan + 'allesfit_%s/allesfit_%stess_ns/' % (strgkind,
                                                                   strgrtyp)
    print('Reading from %s...' % pathdata)
    config.init(pathdata)
    lablpara[i] = config.BASEMENT.fitlabels

    pathsave = pathdata + 'results/mcmc_save.h5'
    # The leading `False and` disables this branch: the MCMC fit is never
    # started from here, even when the saved chain is missing.
    if False and not os.path.exists(pathsave):
        # sample from the posterior excluding the TESS data
        print('Calling allesfitter to fit the data...')
        allesfitter.mcmc_fit(pathdata)

    # read the chain
    ## MCMC
    #emceobjt = emcee.backends.HDFBackend(pathsave, read_only=True)
    #chan[i] = emceobjt.get_chain()
    #lpos[i] = emceobjt.get_log_prob()

    ## Nested sampling
    # NOTE(review): nothing follows in this snippet, so chan[i] is not filled
    # here.
Пример #4
0
def plot_publication_spots_from_posteriors(datadir,
                                           Nsamples=20,
                                           command='save',
                                           mode='default'):
    '''
    Plot data, residuals and the host spot map for each instrument whose host
    has spots, overplotting the model of several posterior samples.

    Parameters
    ----------
    datadir : str
        allesfitter run directory.
    Nsamples : int
        Number of posterior samples to draw. The model of every sample is
        overplotted; data, residuals and the spot map use the first sample.
    command : str
        'show', 'save', 'return', 'show and return', 'save and return'.
        Matched by substring: 'save' writes a PDF to <outdir>/pub and closes
        the figure, 'show' calls plt.show, 'return' returns the figure.
    mode : str
        default : model on 5000 points spanning the instrument's time range,
                  data and residuals with error bars
        zhan2019 : model on 10000 points over (0, 2), data and residuals
                   repeated at time offsets 0, 1 and 2, no error bars,
                   x-limits (0, 2)

    Returns
    -------
    (fig, ax1, ax2, ax3) for the first instrument with host spots if
    'return' is in command, otherwise None.

    Raises
    ------
    ValueError
        If mode is neither 'default' nor 'zhan2019'.
    '''
    if mode not in ('default', 'zhan2019'):
        raise ValueError("mode must be 'default' or 'zhan2019', got %r" %
                         (mode,))

    config.init(datadir)
    posterior_samples = allesfitter.get_ns_posterior_samples(
        datadir, Nsamples=Nsamples, as_type='2d_array')

    for inst in config.BASEMENT.settings['inst_all']:
        n_spots = config.BASEMENT.settings['host_N_spots_' + inst]
        if not n_spots > 0:
            continue

        # a separate figure per instrument, so that instruments are not drawn
        # onto a figure that was already saved and closed
        fig, ax1, ax2, ax3 = setup_grid()

        time = config.BASEMENT.data[inst]['time']
        flux = config.BASEMENT.data[inst]['flux']

        if mode == 'default':
            xx = np.linspace(time[0], time[-1], 5000)
        else:  # 'zhan2019'
            xx = np.linspace(0, 2, 10000)

        for i_sample, sample in tqdm(enumerate(posterior_samples)):

            params = allesfitter.computer.update_params(sample)

            if i_sample == 0:
                # data, residuals and spot map are drawn once, from the first
                # sample; they go under the model curves drawn below
                spots = [[
                    params['host_spot_' + str(i) + '_' + prop + '_' + inst]
                    for i in range(1, n_spots + 1)
                ] for prop in ('long', 'lat', 'size', 'brightness')]

                model = allesfitter.computer.calculate_model(
                    params, inst, 'flux')
                baseline = allesfitter.computer.calculate_baseline(
                    params, inst, 'flux')
                residuals = flux - model - baseline

                if mode == 'default':
                    flux_err = np.exp(params['log_err_flux_' + inst])
                    ax1 = axplot_data(ax1, time, flux, flux_err=flux_err)
                    ax2 = axplot_residuals(ax2, time, residuals,
                                           res_err=flux_err)
                else:  # 'zhan2019': repeat the data at offsets 0, 1, 2
                    time_rep = np.concatenate((time, time + 1, time + 2))
                    ax1 = axplot_data(ax1,
                                      time_rep,
                                      np.concatenate((flux, flux, flux)),
                                      flux_err=None)
                    ax2 = axplot_residuals(
                        ax2,
                        time_rep,
                        np.concatenate((residuals, residuals, residuals)),
                        res_err=None)
                ax3 = axplot_spots_2d(ax3, spots)

            model_xx = allesfitter.computer.calculate_model(
                params, inst, 'flux', xx=xx)  #evaluated on xx (!)
            baseline_xx = allesfitter.computer.calculate_baseline(
                params, inst, 'flux', xx=xx)  #evaluated on xx (!)
            ax1 = axplot_model(ax1, xx, model_xx + baseline_xx)

        ax1.locator_params(axis='y', nbins=5)

        if mode == 'zhan2019':
            ax1.set(xlim=[0, 2])
            ax2.set(xlim=[0, 2])

        if 'save' in command:
            pubdir = os.path.join(config.BASEMENT.outdir, 'pub')
            os.makedirs(pubdir, exist_ok=True)
            suffix = '.pdf' if mode == 'default' else '_zz.pdf'
            fig.savefig(os.path.join(pubdir, 'host_spots_' + inst + suffix),
                        bbox_inches='tight')
            plt.close(fig)

        if 'show' in command:
            plt.show()

        if 'return' in command:
            return fig, ax1, ax2, ax3
Пример #5
0
def plot_spots_from_posteriors(datadir, Nsamples=10, command='return'):
    '''
    Draw one spot map per posterior sample, for the host and for every
    companion, on every instrument where that body has spots.

    command : str
        'return' : save each map to
                   <outdir>/spotmaps/<body>_spots_<inst>_posterior_sample_<n>
                   (body is 'host' or the companion name) and close it.
        'show'   : show each map; Nsamples is forced to 1.
        Any other value plots and saves nothing.
    '''
    # interactive display is limited to one sample, whatever was requested
    if command == 'show':
        Nsamples = 1

    config.init(datadir)
    posterior_samples_dic = allesfitter.get_ns_posterior_samples(
        datadir, Nsamples=Nsamples)

    spot_props = ('long', 'lat', 'size', 'brightness')

    for sample in tqdm(range(Nsamples)):
        # parameter values of this posterior sample
        params = {
            key: posterior_samples_dic[key][sample]
            for key in posterior_samples_dic
        }

        # the host first, then each companion, in the order of companions_all
        bodies = ['host'] + list(config.BASEMENT.settings['companions_all'])
        for body in bodies:
            for inst in config.BASEMENT.settings['inst_all']:
                n_spots = config.BASEMENT.settings[body + '_N_spots_' + inst]
                if not n_spots > 0:
                    continue

                # rows: longitudes, latitudes, sizes, brightnesses
                spots = [[
                    params[body + '_spot_' + str(idx) + '_' + prop + '_' +
                           inst] for idx in range(1, n_spots + 1)
                ] for prop in spot_props]

                if command == 'return':
                    fig, ax, ax2 = plot_spots(spots, command='return')
                    plt.suptitle('sample ' + str(sample))
                    spotsdir = os.path.join(config.BASEMENT.outdir, 'spotmaps')
                    if not os.path.exists(spotsdir):
                        os.makedirs(spotsdir)
                    fname = (body + '_spots_' + inst + '_posterior_sample_' +
                             str(sample))
                    fig.savefig(os.path.join(spotsdir, fname))
                    plt.close(fig)
                elif command == 'show':
                    plot_spots(spots, command='show')
Пример #6
0
def plot_viol(pathbase, liststrgstar, liststrgruns, lablstrgruns, pathimag, pvalthrs=1e-3):
    '''
    Compare the allesfitter posteriors of several runs for several stars,
    via violin plots of the shared fitted parameters and of predicted
    mid-transit times. The plotting itself is delegated to `plot`.

    Parameters
    ----------
    pathbase : str
        Base path; runs are read from
        pathbase + '<star>/allesfits/allesfit_<run>/'.
    liststrgstar : list of str
        Star names.
    liststrgruns : list of str
        Run names compared for each star.
    lablstrgruns : list of str
        Run labels; only stored on the global object here (presumably used
        by `plot` -- confirm).
    pathimag : str
        Image path, stored on the global object for `plot`.
    pvalthrs : float
        Only stored on the global object; not used in this function.

    Returns
    -------
    listfigr, listaxis : list
        Always empty in the current implementation.
    '''

    strgtimestmp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    print('allesfitter postprocessing violin plot started at %s...' % strgtimestmp)
    
    # construct global object
    gdat = gdatstrt()
    
    # copy the inputs (and the time stamp) to the global object; this copies
    # every local defined so far, so it must stay before any other local
    for attr, valu in locals().items():
        if '__' not in attr and attr != 'gdat':
            setattr(gdat, attr, valu)

    # runs to be compared for each star
    gdat.numbruns = len(liststrgruns)
    gdat.indxruns = np.arange(gdat.numbruns)
    
    gdat.pathimag = pathimag
    gdat.liststrgstar = liststrgstar

    # stars
    numbstar = len(liststrgstar)
    gdat.indxstar = np.arange(numbstar)

    # plotting
    gdat.strgplotextn = 'png'

    # read parameter keys, labels and posterior from allesfitter output
    # NOTE(review): the keys/labels of run i are those of the last star read,
    # so all stars appear to be assumed to fit the same parameters per run.
    liststrgpara = [[] for i in gdat.indxruns]
    listlablpara = [[] for i in gdat.indxruns]
    gdat.listobjtalle = [[[] for m in gdat.indxstar] for i in gdat.indxruns]
    for i in gdat.indxruns:
        for m in gdat.indxstar:
            pathalle = pathbase + '%s/allesfits/allesfit_%s/' % (gdat.liststrgstar[m], gdat.liststrgruns[i])
            print('Reading from %s...' % pathalle)
            config.init(pathalle)
            liststrgpara[i] = np.array(config.BASEMENT.fitkeys)
            listlablpara[i] = np.array(config.BASEMENT.fitlabels)
            # read the chain
            gdat.listobjtalle[i][m] = allesfitter.allesclass(pathalle)
    
    # union of the keys over all runs; each label is taken from the last run
    # that fits the parameter
    gdat.liststrgparaconc = np.unique(np.concatenate(liststrgpara))
    dictlablpara = {}
    for i in gdat.indxruns:
        dictlablpara.update(zip(liststrgpara[i], listlablpara[i]))
    # object dtype: a copy of the fixed-width key array would silently
    # truncate labels longer than the longest key
    gdat.listlablparaconc = np.array(
        [dictlablpara[strgpara] for strgpara in gdat.liststrgparaconc], dtype=object)
    
    gdat.numbparaconc = len(gdat.liststrgparaconc)
    gdat.indxparaconc = np.arange(gdat.numbparaconc)
    for k, strgpara in enumerate(gdat.liststrgparaconc):
        # only compare parameters fitted in every run
        if not all(strgpara in liststrgpara[i] for i in gdat.indxruns):
            continue
        ## violin plot of the parameter posteriors of all stars
        plot(gdat, gdat.indxstar, indxpara=np.array([k]), strgtype='paracomp')
        
    # predicted mid-transit time near 1 January of each year, as residuals
    # about the posterior mean, in minutes; indexed [year][run][star]
    gdat.listyear = [2021, 2023, 2025]
    numbyear = len(gdat.listyear)
    gdat.indxyear = np.arange(numbyear)
    gdat.timejwst = [[[[] for m in gdat.indxstar] for i in gdat.indxruns] for k in gdat.indxyear]
    for k, year in enumerate(gdat.listyear):
        epocjwst = astropy.time.Time('%d-01-01 00:00:00' % year, format='iso').jd
        for i in gdat.indxruns:
            for m in gdat.indxstar:
                epoc = gdat.listobjtalle[i][m].posterior_params['b_epoch']
                peri = gdat.listobjtalle[i][m].posterior_params['b_period']
                # a single integer transit index for all samples: the mean of
                # the per-sample rounded indices is rounded again so it stays
                # an integer even when the samples disagree
                indxtran = np.rint(np.mean(np.rint((epocjwst - epoc) / peri)))
                timetran = epoc + peri * indxtran
                gdat.timejwst[k][i][m] = (timetran - np.mean(timetran)) * 24. * 60.
    
    listfigr = []
    listaxis = []

    # temporal evolution of the mid-transit time prediction
    plot(gdat, gdat.indxstar, strgtype='epocevol')
    
    ## comparison of the mid-transit time predictions between runs
    plot(gdat, gdat.indxstar, strgtype='jwstcomp')
    
    return listfigr, listaxis