# NOTE: the relative order of the original imports is kept, in case any of
# the project packages configures matplotlib before pyplot is imported.
import os
import sys
import pickle
import datetime

import numpy as np
from dynesty import utils as dyutils
import allesfitter
from allesfitter import config
import astropy.time
from astropy.time import Time
import matplotlib.pyplot as plt

# base data path for the TESS known planet project
pathdata = os.environ['KNWN_DATA_PATH'] + '/'
pathdatapost = os.environ['KNWN_DATA_PATH'] + '/postproc/'

# name of the exoplanet, given as the first command-line argument
strgplan = sys.argv[1]

pathdataplan = pathdata + '%s/allesfit_global/allesfit_onlytess_full/' % strgplan
config.init(pathdataplan)

# read the saved nested-sampling results; the file handle is closed right away.
# NOTE: pickle.load must only be used on files produced by our own runs.
with open(pathdataplan + 'results/save_ns.pickle', 'rb') as fileobjt:
    objtrest = pickle.load(fileobjt)

# sample weights exp(logwt - final logz), used to resample an equally
# weighted posterior chain
weig = np.exp(objtrest['logwt'] - objtrest['logz'][-1])
chan = dyutils.resample_equal(objtrest.samples, weig)

listkeys = config.BASEMENT.fitkeys
listlabl = config.BASEMENT.fitlabels

# get period and epoch posteriors of planet b (the chain columns follow fitkeys)
# NOTE(review): postepoc / postperi stay undefined if the key is not fitted
for k, strgkey in enumerate(listkeys):
    if strgkey == 'b_epoch':
        postepoc = chan[:, k]
    if strgkey == 'b_period':
        postperi = chan[:, k]
def plot_viol(pathdataoutp, pvalthrs=1e-3, boolonlytess=False):
    '''
    Compare allesfitter posteriors of runs without and with TESS data
    (optionally also with only TESS data) using violin plots.

    For every parameter fitted in all runs, a violin plot of the
    mean-subtracted posterior multiplied by 24 * 60 is written to
    pathdataoutp + 'viol_%04d.svg'. Mid-transit times of planet b predicted
    near 1 January 2021, 2023 and 2025 are then plotted versus year for the
    'alldata' run (jwsttime.svg), and compared across runs for 2023
    (jwstcomp.svg).

    NOTE(review): another function named plot_viol is defined later in this
    module and shadows this one.

    Arguments
        pathdataoutp: output base path (including trailing separator); the
            runs are read from pathdataoutp + 'allesfits/allesfit_<run>'.
        pvalthrs: not used; kept for backward compatibility.
        boolonlytess: if True, also compare against the 'TESS' run.

    Returns
        listfigr, listaxis: figures and axes of the two transit-time plots
            (already closed with plt.close after saving).
    '''
    strgtimestmp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    print('allesfitter postprocessing violin plot started at %s...' % strgtimestmp)

    # allesfitter runs to be compared and their tick labels
    liststrgruns = ['woutTESS', 'alldata']
    ticklabl = ['w/o TESS', 'w/ TESS']
    if boolonlytess:
        # append the strings themselves; appending ['TESS'] would build the
        # run path "allesfit_['TESS']"
        liststrgruns.append('TESS')
        ticklabl.append('only TESS')
    numbruns = len(liststrgruns)
    indxruns = np.arange(numbruns)

    # read parameter keys, labels and posteriors of each run
    liststrgpara = [[] for i in indxruns]
    listlablpara = [[] for i in indxruns]
    listobjtalle = [[] for i in indxruns]
    for i, strgruns in enumerate(liststrgruns):
        pathdata = pathdataoutp + 'allesfits/allesfit_%s' % strgruns
        print('Reading from %s...' % pathdata)
        config.init(pathdata)
        liststrgpara[i] = np.array(config.BASEMENT.fitkeys)
        listlablpara[i] = np.array(config.BASEMENT.fitlabels)
        # read the chain
        listobjtalle[i] = allesfitter.allesclass(pathdata)

    # union of the parameter keys over runs, and a label for each key
    liststrgparaconc = np.unique(np.concatenate(liststrgpara))
    # object dtype: a plain copy of the key array would keep the keys'
    # fixed-width string dtype and silently truncate longer labels
    listlablparaconc = liststrgparaconc.astype(object)
    for k, strgparaconc in enumerate(liststrgparaconc):
        for i in indxruns:
            if strgparaconc in liststrgpara[i]:
                listlablparaconc[k] = listlablpara[i][np.where(liststrgpara[i] == strgparaconc)[0][0]]

    xpos = 0.6 * (np.arange(numbruns) + 1.)

    # one violin plot per parameter that is fitted in every run
    for k, strgpara in enumerate(liststrgparaconc):
        if not all(strgpara in liststrgpara[i] for i in indxruns):
            continue
        figr, axis = plt.subplots(figsize=(5, 4))
        # NOTE(review): every parameter is scaled by 24 * 60 (days -> minutes),
        # but only b_period is labeled in minutes; confirm for the others
        chanlist = []
        for i in indxruns:
            post = listobjtalle[i].posterior_params[strgpara]
            chanlist.append((post - np.mean(post)) * 24. * 60.)
        axis.violinplot(chanlist, xpos, showmedians=True, showextrema=False)
        axis.set_xticks(xpos)
        axis.set_xticklabels(ticklabl)
        if strgpara == 'b_period':
            axis.set_ylabel('P [min]')
        else:
            axis.set_ylabel(listlablparaconc[k])
        plt.tight_layout()
        path = pathdataoutp + 'viol_%04d.svg' % (k)
        print('Writing to %s...' % path)
        figr.savefig(path)
        plt.close(figr)

    # predicted mid-transit time of planet b nearest to 1 January of each
    # year, as a residual from its posterior mean, in minutes
    listyear = [2021, 2023, 2025]
    numbyear = len(listyear)
    indxyear = np.arange(numbyear)
    timejwst = [[[] for i in indxruns] for k in indxyear]
    for k, year in enumerate(listyear):
        epocjwst = astropy.time.Time('%d-01-01 00:00:00' % year, format='iso').jd
        for i in indxruns:
            epoc = listobjtalle[i].posterior_params['b_epoch']
            peri = listobjtalle[i].posterior_params['b_period']
            # one integer transit index shared by all posterior samples;
            # rounding after averaging always yields an integer, whereas
            # averaging rounded per-sample indices does not when they disagree
            indxtran = np.rint(np.mean((epocjwst - epoc) / peri))
            timetran = epoc + peri * indxtran
            timejwst[k][i] = (timetran - np.mean(timetran)) * 24. * 60.

    listfigr = []
    listaxis = []

    ## temporal evolution of the prediction of the 'alldata' run
    figr, axis = plt.subplots(figsize=(5, 4))
    listfigr.append(figr)
    listaxis.append(axis)
    axis.violinplot([timejwst[k][1] for k in indxyear], listyear)
    axis.set_xlabel('Year')
    axis.set_ylabel('Transit time residual [min]')
    plt.tight_layout()
    path = pathdataoutp + 'jwsttime.svg'
    print('Writing to %s...' % path)
    plt.savefig(path)
    plt.close()

    ## without/with/only TESS prediction comparison for 2023
    figr, axis = plt.subplots(figsize=(5, 4))
    listfigr.append(figr)
    listaxis.append(axis)
    axis.violinplot(timejwst[listyear.index(2023)], xpos, points=2000)
    axis.set_xticks(xpos)
    axis.set_xticklabels(ticklabl)
    axis.set_ylabel('Transit time residual in 2023 [min]')
    plt.tight_layout()
    path = pathdataoutp + 'jwstcomp.svg'
    print('Writing to %s...' % path)
    plt.savefig(path)
    plt.close()

    return listfigr, listaxis
def plot_publication_spots_from_posteriors(datadir, Nsamples=20, command='save', mode='default'):
    '''
    Publication plot of data, residuals, posterior models and the host spot
    map, for every instrument whose host star has at least one spot.

    datadir : str
        allesfitter run directory (passed to config.init and
        allesfitter.get_ns_posterior_samples).
    Nsamples : int
        number of posterior samples; one model curve is drawn per sample.
    command : str
        'show', 'save', 'return', 'show and return', 'save and return'
        ('save', 'show' and 'return' are matched as substrings).
        'save' writes <outdir>/pub/host_spots_<inst>.pdf (default mode) or
        host_spots_<inst>_zz.pdf (zhan2019 mode) and closes the figure.
        'return' returns (fig, ax1, ax2, ax3) of the first instrument with spots.
    mode: str
        default : model on 5000 points spanning the data time range,
                  data and residuals with errorbars
        zhan2019 : model on 10000 points over (0,2), data and residuals
                   repeated shifted by +1 and +2, no errorbars

    Raises ValueError for any other mode.
    '''
    if mode not in ('default', 'zhan2019'):
        raise ValueError("mode must be 'default' or 'zhan2019', got %r" % (mode,))

    config.init(datadir)
    posterior_samples = allesfitter.get_ns_posterior_samples(
        datadir, Nsamples=Nsamples, as_type='2d_array')

    for inst in config.BASEMENT.settings['inst_all']:
        n_spots = config.BASEMENT.settings['host_N_spots_' + inst]
        if not n_spots > 0:
            continue

        # a fresh grid per instrument, so that one instrument is not drawn on
        # top of another (the figure is closed after saving)
        # NOTE(review): assumes setup_grid() creates a new figure on each call
        fig, ax1, ax2, ax3 = setup_grid()

        time_data = config.BASEMENT.data[inst]['time']
        flux_data = config.BASEMENT.data[inst]['flux']
        if mode == 'default':
            xx = np.linspace(time_data[0], time_data[-1], 5000)
        else:  # mode == 'zhan2019'
            xx = np.linspace(0, 2, 10000)

        for i_sample, sample in tqdm(enumerate(posterior_samples)):
            params = allesfitter.computer.update_params(sample)

            # data, residuals and spot map are drawn once, from the first sample
            if i_sample == 0:
                model = allesfitter.computer.calculate_model(params, inst, 'flux')
                baseline = allesfitter.computer.calculate_baseline(params, inst, 'flux')
                residuals = flux_data - model - baseline
                # [[longitudes], [latitudes], [sizes], [brightnesses]] of the spots
                spots = [[params['host_spot_' + str(i) + '_' + quantity + '_' + inst]
                          for i in range(1, n_spots + 1)]
                         for quantity in ('long', 'lat', 'size', 'brightness')]
                if mode == 'default':
                    flux_err = np.exp(params['log_err_flux_' + inst])
                    ax1 = axplot_data(ax1, time_data, flux_data, flux_err=flux_err)
                    ax2 = axplot_residuals(ax2, time_data, residuals, res_err=flux_err)
                else:
                    # repeat the data shifted by +1 and +2 to cover (0, 2)
                    time_rep = np.concatenate((time_data, time_data + 1, time_data + 2))
                    ax1 = axplot_data(ax1, time_rep,
                                      np.concatenate((flux_data, flux_data, flux_data)),
                                      flux_err=None)
                    ax2 = axplot_residuals(ax2, time_rep,
                                           np.concatenate((residuals, residuals, residuals)),
                                           res_err=None)
                ax3 = axplot_spots_2d(ax3, spots)

            # posterior model evaluated on the fine grid xx (!)
            model_xx = allesfitter.computer.calculate_model(params, inst, 'flux', xx=xx)
            baseline_xx = allesfitter.computer.calculate_baseline(params, inst, 'flux', xx=xx)
            ax1 = axplot_model(ax1, xx, model_xx + baseline_xx)

        ax1.locator_params(axis='y', nbins=5)
        if mode == 'zhan2019':
            ax1.set(xlim=[0, 2])
            ax2.set(xlim=[0, 2])

        if 'save' in command:
            pubdir = os.path.join(config.BASEMENT.outdir, 'pub')
            os.makedirs(pubdir, exist_ok=True)
            if mode == 'default':
                fname = 'host_spots_' + inst + '.pdf'
            else:
                fname = 'host_spots_' + inst + '_zz.pdf'
            fig.savefig(os.path.join(pubdir, fname), bbox_inches='tight')
            plt.close(fig)
        if 'show' in command:
            plt.show()
        if 'return' in command:
            return fig, ax1, ax2, ax3
def plot_spots_from_posteriors(datadir, Nsamples=10, command='return'):
    '''
    Plot the spot maps of the host and of every companion for a number of
    nested-sampling posterior samples.

    datadir : str
        allesfitter run directory.
    Nsamples : int
        number of posterior samples; forced to 1 when command == 'show'.
    command : str
        'return' : save one figure per sample, body and instrument to
                   <outdir>/spotmaps/<body>_spots_<inst>_posterior_sample_<n>
                   (figures are closed; nothing is returned)
        'show'   : display the spot maps of a single sample
    '''
    if command == 'show':
        Nsamples = 1  # overwrite user input and only show 1 sample if command=='show'

    config.init(datadir)
    posterior_samples_dic = allesfitter.get_ns_posterior_samples(
        datadir, Nsamples=Nsamples)

    # bodies that can carry spots: the host first, then every companion;
    # the settings, parameter and file names all share the '<body>_' prefix
    bodies = ['host'] + list(config.BASEMENT.settings['companions_all'])
    spotsdir = os.path.join(config.BASEMENT.outdir, 'spotmaps')

    for sample in tqdm(range(Nsamples)):
        params = {key: posterior_samples_dic[key][sample] for key in posterior_samples_dic}

        for body in bodies:
            for inst in config.BASEMENT.settings['inst_all']:
                n_spots = config.BASEMENT.settings[body + '_N_spots_' + inst]
                if not n_spots > 0:
                    continue

                # [[longitudes], [latitudes], [sizes], [brightnesses]] of the spots
                spots = [[params[body + '_spot_' + str(i) + '_' + quantity + '_' + inst]
                          for i in range(1, n_spots + 1)]
                         for quantity in ('long', 'lat', 'size', 'brightness')]

                if command == 'return':
                    fig, ax, ax2 = plot_spots(spots, command='return')
                    plt.suptitle('sample ' + str(sample))
                    os.makedirs(spotsdir, exist_ok=True)
                    fig.savefig(os.path.join(
                        spotsdir,
                        body + '_spots_' + inst + '_posterior_sample_' + str(sample)))
                    plt.close(fig)
                elif command == 'show':
                    plot_spots(spots, command='show')
def plot_viol(pathbase, liststrgstar, liststrgruns, lablstrgruns, pathimag, pvalthrs=1e-3):
    '''
    Compare allesfitter posteriors of several runs for several stars.

    For each run and star, the allesfitter output in
    pathbase + '<star>/allesfits/allesfit_<run>/' is read. Plotting is
    delegated to plot() with strgtype 'paracomp' (once per parameter fitted
    in every run), 'epocevol' and 'jwstcomp'. Before the last two, the
    predicted mid-transit times of planet b near 1 January 2021, 2023 and
    2025 are stored in gdat.timejwst[year][run][star]. They are residuals
    from the posterior mean, in minutes.

    Arguments
        pathbase: base path containing one directory per star.
        liststrgstar: names of the stars.
        liststrgruns: names of the allesfitter runs to compare.
        lablstrgruns: labels of the runs; copied onto the global object.
        pathimag: image path; copied onto the global object.
        pvalthrs: copied onto the global object; not used directly here.

    Returns
        listfigr, listaxis: two empty lists.
    '''
    strgtimestmp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    print('allesfitter postprocessing violin plot started at %s...' % strgtimestmp)

    # construct global object and copy the inputs (and the time stamp) to it
    gdat = gdatstrt()
    gdat.pathbase = pathbase
    gdat.liststrgstar = liststrgstar
    gdat.liststrgruns = liststrgruns
    gdat.lablstrgruns = lablstrgruns
    gdat.pathimag = pathimag
    gdat.pvalthrs = pvalthrs
    gdat.strgtimestmp = strgtimestmp

    # runs to be compared for each star
    gdat.numbruns = len(liststrgruns)
    gdat.indxruns = np.arange(gdat.numbruns)

    # stars
    numbstar = len(liststrgstar)
    gdat.indxstar = np.arange(numbstar)

    # plotting
    gdat.strgplotextn = 'png'

    # read parameter keys, labels and posterior from allesfitter output
    liststrgpara = [[] for i in gdat.indxruns]
    listlablpara = [[] for i in gdat.indxruns]
    gdat.listobjtalle = [[[] for m in gdat.indxstar] for i in gdat.indxruns]
    for i in gdat.indxruns:
        for m in gdat.indxstar:
            pathalle = pathbase + '%s/allesfits/allesfit_%s/' % (gdat.liststrgstar[m], gdat.liststrgruns[i])
            print('Reading from %s...' % pathalle)
            config.init(pathalle)
            # NOTE(review): overwritten for every star, so only the keys and
            # labels of the last star are kept per run; confirm that all
            # stars share the same fitted parameters
            liststrgpara[i] = np.array(config.BASEMENT.fitkeys)
            listlablpara[i] = np.array(config.BASEMENT.fitlabels)
            # read the chain
            gdat.listobjtalle[i][m] = allesfitter.allesclass(pathalle)

    # union of the keys over runs, and a label for each key
    gdat.liststrgparaconc = np.unique(np.concatenate(liststrgpara))
    # object dtype: a plain copy of the key array would keep the keys'
    # fixed-width string dtype and silently truncate longer labels
    gdat.listlablparaconc = gdat.liststrgparaconc.astype(object)
    for k, strgparaconc in enumerate(gdat.liststrgparaconc):
        for i in gdat.indxruns:
            if strgparaconc in liststrgpara[i]:
                gdat.listlablparaconc[k] = listlablpara[i][np.where(liststrgpara[i] == strgparaconc)[0][0]]
    gdat.numbparaconc = len(gdat.liststrgparaconc)
    gdat.indxparaconc = np.arange(gdat.numbparaconc)

    # violin plot of every parameter that is fitted in every run
    for k, strgpara in enumerate(gdat.liststrgparaconc):
        if not all(strgpara in liststrgpara[i] for i in gdat.indxruns):
            continue
        plot(gdat, gdat.indxstar, indxpara=np.array([k]), strgtype='paracomp')

    # predicted mid-transit time of planet b nearest to 1 January of each
    # year, as a residual from its posterior mean, in minutes
    gdat.listyear = [2021, 2023, 2025]
    numbyear = len(gdat.listyear)
    gdat.indxyear = np.arange(numbyear)
    gdat.timejwst = [[[[] for m in gdat.indxstar] for i in gdat.indxruns] for k in gdat.indxyear]
    for k, year in enumerate(gdat.listyear):
        epocjwst = astropy.time.Time('%d-01-01 00:00:00' % year, format='iso').jd
        for i in gdat.indxruns:
            for m in gdat.indxstar:
                epoc = gdat.listobjtalle[i][m].posterior_params['b_epoch']
                peri = gdat.listobjtalle[i][m].posterior_params['b_period']
                # one integer transit index shared by all posterior samples;
                # rounding after averaging always yields an integer, whereas
                # averaging rounded per-sample indices does not when they disagree
                indxtran = np.rint(np.mean((epocjwst - epoc) / peri))
                timetran = epoc + peri * indxtran
                gdat.timejwst[k][i][m] = (timetran - np.mean(timetran)) * 24. * 60.

    listfigr = []
    listaxis = []

    # temporal evolution of mid-transit time prediction
    plot(gdat, gdat.indxstar, strgtype='epocevol')

    # mid-transit time prediction compared across runs
    plot(gdat, gdat.indxstar, strgtype='jwstcomp')

    return listfigr, listaxis