Example #1
0
def plots(args):
    """
    Generate plots

    Creates every plot type listed in ``args.type`` and records the path of
    each saved figure in the 'plot' section of the run's status file.

    Args:
        args (ArgumentParser): command line arguments. Attributes read here:
            ``setupfn``, ``outputdir``, ``decorr``, ``type``, ``plotkw``
            (dict of extra keyword arguments for the RV plot; it is updated
            in place) and ``gp``.

    Raises:
        ValueError: if ``args.type`` contains an unrecognized plot type.
        AssertionError: if the max-likelihood fit, MCMC or ``radvel derive``
            step required by a requested plot has not been run.
    """
    known_types = ('rv', 'corner', 'auto', 'trend', 'derived')

    config_file = args.setupfn
    conf_base = os.path.basename(config_file).split('.')[0]
    statfile = os.path.join(
        args.outputdir, "{}_radvel.stat".format(conf_base)
    )

    # Reject unknown plot types up front; otherwise they would match no
    # branch below and the status entry would be written with a stale (or
    # undefined) ``saveto`` left over from another plot type.
    unknown = [ptype for ptype in args.type if ptype not in known_types]
    if unknown:
        raise ValueError(
            "Unknown plot type(s) {}; expected one of {}".format(
                ", ".join(map(str, unknown)), ", ".join(known_types)
            )
        )

    status = load_status(statfile)

    # The posterior returned here is never used: the saved fit is loaded
    # below and the setup file is re-initialized where ``P`` is needed.
    # NOTE(review): call kept only in case loading the setup file with
    # ``decorr`` has side effects later steps rely on -- confirm and drop.
    radvel.utils.initialize_posterior(config_file, decorr=args.decorr)

    assert status.getboolean('fit', 'run'), \
        "Must perform max-liklihood fit before plotting"
    post = radvel.posterior.load(status.get('fit', 'postfile'))

    for ptype in args.type:
        print("Creating {} plot for {}".format(ptype, conf_base))

        if ptype in ('corner', 'auto', 'trend'):
            assert status.getboolean('mcmc', 'run'), \
                "Must run MCMC before making corner, auto, or trend plots"

        if ptype == 'rv':
            args.plotkw['uparams'] = post.uparams
            args.plotkw['status'] = status
            # A user-supplied 'saveplot' overrides the default output path.
            # It is removed from plotkw because it is passed explicitly below.
            saveto = args.plotkw.pop(
                'saveplot',
                os.path.join(args.outputdir, conf_base + '_rv_multipanel.pdf')
            )
            P, _ = radvel.utils.initialize_posterior(config_file)
            if hasattr(P, 'bjd0'):
                args.plotkw['epoch'] = P.bjd0

            if args.gp:
                GPPlot = orbit_plots.GPMultipanelPlot(
                    post, saveplot=saveto, **args.plotkw
                )
                GPPlot.plot_multipanel()
            else:
                RVPlot = orbit_plots.MultipanelPlot(
                    post, saveplot=saveto, **args.plotkw
                )
                RVPlot.plot_multipanel()

                # The non-GP plot was requested; point the user at '--gp'
                # if any component likelihood is a GP likelihood.
                if isinstance(post.likelihood,
                              radvel.likelihood.CompositeLikelihood):
                    like_list = post.likelihood.like_list
                else:
                    like_list = [post.likelihood]
                if any(isinstance(like, radvel.likelihood.GPLikelihood)
                       for like in like_list):
                    print("WARNING: GP Likelihood(s) detected. "
                          "You may want to use the '--gp' flag when making "
                          "these plots.")

        elif ptype == 'auto':
            autocorr = pd.read_csv(status.get('mcmc', 'autocorrfile'))
            saveto = os.path.join(args.outputdir, conf_base + '_auto.pdf')
            Auto = mcmc_plots.AutoPlot(autocorr, saveplot=saveto)
            Auto.plot()

        elif ptype == 'corner':
            chains = pd.read_csv(status.get('mcmc', 'chainfile'))
            saveto = os.path.join(args.outputdir, conf_base + '_corner.pdf')
            Corner = mcmc_plots.CornerPlot(post, chains, saveplot=saveto)
            Corner.plot()

        elif ptype == 'trend':
            chains = pd.read_csv(status.get('mcmc', 'chainfile'))
            nwalkers = status.getint('mcmc', 'nwalkers')
            nensembles = status.getint('mcmc', 'nensembles')

            saveto = os.path.join(args.outputdir, conf_base + '_trends.pdf')
            Trend = mcmc_plots.TrendPlot(post, chains, nwalkers, nensembles,
                                         saveto)
            Trend.plot()

        elif ptype == 'derived':
            assert status.has_section('derive'), \
                "Must run `radvel derive` before plotting derived parameters"

            P, _ = radvel.utils.initialize_posterior(config_file)
            chains = pd.read_csv(status.get('derive', 'chainfile'))
            saveto = os.path.join(
                args.outputdir, conf_base + '_corner_derived_pars.pdf'
            )

            Derived = mcmc_plots.DerivedPlot(chains, P, saveplot=saveto)
            Derived.plot()

        # Every accepted ptype assigned ``saveto`` in its branch above.
        savestate = {'{}_plot'.format(ptype): os.path.relpath(saveto)}
        save_status(statfile, 'plot', savestate)
Example #2
0
    def plot_phasefold(self, pltletter, pnum):
        """
        Plot phased orbit plots for each planet in the fit.

        Draws on the current axes, for planet ``pnum``:
        - the phase-folded RV data and model;
        - binned points, when binning is enabled;
        - an annotation box with the period, semi-amplitude and
          eccentricity, plus ``mpsini`` when ``self.post`` has a ``derived``
          attribute.

        Args:
            pltletter (int): integer representation of
                letter to be printed in the corner of the first
                phase plot.
                Ex: ord("a") gives 97, so the input should be 97.
            pnum (int): the number of the planet to be plotted. Must be
                the same as the number used to define a planet's
                Parameter objects (e.g. 'per1' is for planet #1)

        """

        def _annotation(texlabel, value, unit, error=None):
            """Format one annotation line: 'label = value [+/- error] unit'.

            When ``error`` is given and exceeds 1e-15, value and error are
            rounded with ``sigfig``. Otherwise the value is printed with two
            decimals and no error.
            """
            name = texlabel.replace("$", "")
            if error is not None and error > 1e-15:
                value, error, _ = sigfig(value, error)
                return (r'$\mathregular{%s}$ = %s $\mathregular{\pm}$ %s %s'
                        % (name, value, error, unit))
            return r'$\mathregular{%s}$ = %4.2f %s' % (name, value, unit)

        ax = pl.gca()

        # Too few points for binning; note this persists on the instance.
        if len(self.post.likelihood.x) < 20:
            self.nobin = True

        bin_fac = 1.75
        bin_markersize = bin_fac * rcParams['lines.markersize']
        bin_markeredgewidth = bin_fac * rcParams['lines.markeredgewidth']

        # Single-planet model and data. The data keep the full-fit residuals.
        # NOTE(review): self.slope / self.slope_low are presumably the trend
        # evaluated at the model / data times -- confirm.
        rvmod2 = self.model(self.rvmodt, planet_num=pnum) - self.slope
        modph = t_to_phase(self.post.params, self.rvmodt, pnum, cat=True) - 1
        rvdat = self.rawresid + self.model(self.rvtimes,
                                           planet_num=pnum) - self.slope_low
        phase = t_to_phase(self.post.params, self.rvtimes, pnum, cat=True) - 1
        # Data, errors and model are duplicated to match the cat=True phase
        # arrays (presumably two copies one cycle apart -- confirm).
        rvdatcat = np.concatenate((rvdat, rvdat))
        rverrcat = np.concatenate((self.rverr, self.rverr))
        rvmod2cat = np.concatenate((rvmod2, rvmod2))
        bint, bindat, binerr = fastbin(phase + 1, rvdatcat, nbins=25)
        bint -= 1.0

        ax.axhline(
            0,
            color='0.5',
            linestyle='--',
        )
        ax.plot(sorted(modph),
                rvmod2cat[np.argsort(modph)],
                'b-',
                linewidth=self.fit_linewidth)
        plot.labelfig(pltletter)

        telcat = np.concatenate(
            (self.post.likelihood.telvec, self.post.likelihood.telvec))

        if self.highlight_last:
            ind = np.argmax(self.rvtimes)
            hphase = t_to_phase(self.post.params,
                                self.rvtimes[ind],
                                pnum,
                                cat=False)
            if hphase > 0.5:
                hphase -= 1
            pl.plot(hphase, rvdatcat[ind], **plot.highlight_format)

        plot.mtelplot(phase,
                      rvdatcat,
                      rverrcat,
                      telcat,
                      ax,
                      telfmts=self.telfmts)
        if not self.nobin and len(rvdat) > 10:
            ax.errorbar(bint,
                        bindat,
                        yerr=binerr,
                        fmt='ro',
                        mec='w',
                        ms=bin_markersize,
                        mew=bin_markeredgewidth)

        if self.phase_limits:
            ax.set_xlim(self.phase_limits[0], self.phase_limits[1])
        else:
            ax.set_xlim(-0.5, 0.5)

        if not self.yscale_auto:
            scale = np.std(rvdatcat)
            ax.set_ylim(-self.yscale_sigma * scale, self.yscale_sigma * scale)

        print_params = ['per', 'k', 'e']
        tex_labels = self.post.params.tex_labels()
        labels = [tex_labels.get(p + str(pnum), p + str(pnum))
                  for p in print_params]

        # Drop the outermost y ticks on every panel except the last planet's.
        if pnum < self.num_planets:
            ticks = ax.yaxis.get_majorticklocs()
            ax.yaxis.set_ticks(ticks[1:-1])

        ax.set_ylabel('RV [{ms:}]'.format(**plot.latex), weight='bold')
        ax.set_xlabel('Phase', weight='bold')

        units = {'per': 'days', 'k': plot.latex['ms'], 'e': ''}

        # With uncertainties, annotate with posterior medians when available.
        # Warn once (not once per parameter) when they are missing.
        use_medians = hasattr(self.post, 'medparams')
        if self.uparams is not None and not use_medians:
            print("WARNING: medparams attribute not found in " +
                  "posterior object will annotate with " +
                  "max-likelihood values and reported uncertainties " +
                  "may not be appropriate.")

        anotext = []
        for label, par in zip(labels, print_params):
            key = "%s%d" % (par, pnum)
            val = self.post.params[key].value
            err = None
            if self.uparams is not None:
                if use_medians:
                    val = self.post.medparams[key]
                err = self.uparams[key]
            anotext.append(_annotation(label, val, units[par], err))

        if hasattr(self.post, 'derived'):
            # A DerivedPlot is built only to obtain derived-parameter names,
            # TeX labels and display units; the chain file is read each call.
            chains = pd.read_csv(self.status['derive']['chainfile'])
            self.post.nplanets = self.num_planets
            dp = mcmc_plots.DerivedPlot(chains, self.post)
            dlabels = dp.labels
            texlabels = dp.texlabels
            dunits = dp.units
            for par in ['mpsini']:
                par_label = par + str(pnum)
                if par_label not in self.post.derived.columns:
                    continue
                index = np.where(np.array(dlabels) == par_label)[0][0]

                unit = dunits[index]
                if unit == "M$_{\\rm Jup}$":
                    conversion_fac = 0.00315
                elif unit == "M$_{\\odot}$":
                    # NOTE(review): 0.00315 ~ 1/317.8 converts Earth to
                    # Jupiter masses, but 0.000954 ~ 1/1047.9 converts Jupiter
                    # to solar masses. If post.derived is in Earth masses (as
                    # the branch above implies), this factor may also need the
                    # 0.00315 step -- confirm against DerivedPlot.
                    conversion_fac = 0.000954265748
                else:
                    conversion_fac = 1

                # Median and 15.9/84.1 percentiles (presumably +/-1 sigma).
                quantiles = self.post.derived["%s%d" % (par, pnum)]
                val = quantiles.loc[0.500] * conversion_fac
                low = quantiles.loc[0.159] * conversion_fac
                high = quantiles.loc[0.841] * conversion_fac
                err = np.mean([val - low, high - val])
                err = radvel.utils.round_sig(err)
                anotext.append(
                    _annotation(texlabels[index], val, dunits[index], err))

        plot.add_anchored('\n'.join(anotext),
                          loc=1,
                          frameon=True,
                          prop=dict(size=self.phasetext_size, weight='bold'),
                          bbox=dict(ec='none', fc='w', alpha=0.8))