Esempio n. 1
0
def animate_contours(chain, source, version, dt=5, fps=20, ffmpeg=True):
    """Save frames of the contour evolution of an MCMC chain, to make an animation.

    Frame ``k`` shows the contours of all walkers' samples from the first
    ``k * dt`` steps. The last frame includes the full chain when
    ``n_steps`` is a multiple of ``dt``.

    Parameters
    ----------
    chain : ndarray
        Raw chain, unpacked below as (n_walkers, n_steps, n_dimensions).
    source : str
    version : int
        MCMC version, used to look up the parameter keys.
    dt : int
        Number of steps between consecutive frames.
    fps : int
        Frame rate passed to ffmpeg.
    ffmpeg : bool
        If True, stitch the saved frames into ``chain.mp4`` using ffmpeg.

    Raises
    ------
    subprocess.CalledProcessError
        If ffmpeg is run and exits with a non-zero status.
    """
    default_plt_options()
    pkeys = mcmc_versions.get_parameter(source, version, 'param_keys')

    n_walkers, n_steps, n_dimensions = chain.shape
    mtarget = os.path.join(GRIDS_PATH, 'sources', source, 'mcmc', 'animation')
    ftarget = os.path.join(mtarget, 'frames')
    # savefig() fails if the output directory does not exist yet
    os.makedirs(ftarget, exist_ok=True)

    cc = chainconsumer.ChainConsumer()

    # n_steps + 1 so that the full chain gets its own frame when dt divides n_steps
    for i in range(dt, n_steps + 1, dt):
        print('frame  ', i)
        subchain = chain[:, :i, :].reshape((-1, n_dimensions))
        cc.add_chain(subchain, parameters=pkeys)

        fig = cc.plotter.plot()
        fig.set_size_inches(6, 6)
        cnt = i // dt  # frames are numbered 0001, 0002, ...

        filepath = os.path.join(ftarget, f'{cnt:04d}.png')
        fig.savefig(filepath)

        plt.close(fig)
        cc.remove_chain()

    if ffmpeg:
        print('Creating movie')
        framefile = os.path.join(ftarget, '%04d.png')  # ffmpeg image-sequence pattern
        savefile = os.path.join(mtarget, 'chain.mp4')
        # -y: overwrite an existing movie instead of blocking on an interactive prompt
        subprocess.run(['ffmpeg', '-y', '-r', str(fps), '-i', framefile, savefile],
                       check=True)
Esempio n. 2
0
def plot_qb_mdot(chain, source, version, discard, cap=None, display=True, save=False,
                 figsize=(5, 5), fontsize=16, sigmas=(1, 2)):
    """Plot 2D Qb-vs-Mdot contours for every epoch of a multi-epoch chain.

    Each epoch listed in the system summary table becomes a separate chain,
    named after the epoch. Its columns are the ``mdot{n}`` and ``qb{n}``
    parameters, where ``n`` is the 1-based epoch position.

    Returns
    -------
    matplotlib figure
    """
    mv = mcmc_versions.McmcVersion(source=source, version=version)
    flat = mcmc_tools.slice_chain(chain, discard=discard, cap=cap, flatten=True)

    epochs = list(obs_tools.load_summary(mv.system).epoch)
    labels = [plot_tools.full_label(key) for key in ('mdot', 'qb')]
    consumer = chainconsumer.ChainConsumer()

    for number, epoch in enumerate(epochs, start=1):
        columns = [mv.param_keys.index(f'mdot{number}'),
                   mv.param_keys.index(f'qb{number}')]
        consumer.add_chain(flat[:, columns], parameters=labels, name=str(epoch))

    consumer.configure(kde=False, smooth=0, label_font_size=fontsize,
                       tick_font_size=fontsize - 2, sigmas=sigmas)
    figure = consumer.plotter.plot(display=False, figsize=figsize)
    figure.subplots_adjust(left=0.2, bottom=0.2)

    save_plot(figure, prefix='qb', save=save, source=source, version=version,
              display=display, chain=chain)
    return figure
Esempio n. 3
0
def setup_chainconsumer(chain,
                        discard,
                        cap=None,
                        param_labels=None,
                        source=None,
                        version=None,
                        smoothing=False):
    """Build a ChainConsumer holding the sliced, flattened chain.

    Parameters
    ----------
    chain : ndarray
        Raw chain; after slicing, its third axis is taken as the parameter axis.
    discard, cap : int
        Passed to ``mcmc_tools.slice_chain``.
    param_labels : list of str, optional
        Labels for the parameters. If omitted, the version's 'param_keys'
        are used, which requires ``source`` and ``version``.
    smoothing : bool
        If False, KDE and smoothing are switched off.

    Raises
    ------
    ValueError
        If neither ``param_labels`` nor both ``source`` and ``version`` are given.
    """
    if param_labels is None:
        if source is None or version is None:
            raise ValueError(
                'If param_labels not provided, must give source, version')
        param_labels = mcmc_versions.get_parameter(source, version, 'param_keys')

    sliced = mcmc_tools.slice_chain(chain, discard=discard, cap=cap)
    flat = sliced.reshape(-1, sliced.shape[2])

    consumer = chainconsumer.ChainConsumer()
    consumer.add_chain(flat, parameters=param_labels)
    if not smoothing:
        consumer.configure(kde=False, smooth=0)
    return consumer
Esempio n. 4
0
def plot_posteriors(chain=None, discard=10000):
    """Plot the 1D marginal posteriors of an 8-parameter chain with formatted labels.

    Parameters
    ----------
    chain : ndarray, optional
        Raw chain indexed as [walker, step, parameter]. If None, the
        'sim_test' chain (version 5, 960 walkers, 20000 steps) is loaded.
    discard : int
        Number of initial steps dropped from every walker.

    Returns
    -------
    matplotlib figure

    Notes
    -----
    Column 4 is multiplied by
    ``gravity.get_acceleration_newtonian(10, 1.4).value / 1e14`` before plotting.
    The scaling is applied to a copy, so the caller's ``chain`` is left unchanged.
    """
    if chain is None:
        chain = mcmc_tools.load_chain('sim_test',
                                      n_walkers=960,
                                      n_steps=20000,
                                      version=5)
    params = [
        r'Accretion rate ($\dot{M} / \dot{M}_\text{Edd}$)', 'Hydrogen',
        r'$Z_{\text{CNO}}$', r'$Q_\text{b}$ (MeV nucleon$^{-1}$)',
        'gravity ($10^{14}$ cm s$^{-2}$)', 'redshift (1+z)', 'distance (kpc)',
        'inclination (degrees)'
    ]

    n_dimensions = chain.shape[2]
    # .copy() guarantees the in-place scaling below never writes into the caller's array
    flat = chain[:, discard:, :].reshape((-1, n_dimensions)).copy()

    g = gravity.get_acceleration_newtonian(10, 1.4).value / 1e14
    flat[:, 4] *= g

    cc = chainconsumer.ChainConsumer()
    cc.add_chain(flat)
    cc.configure(kde=False, smooth=0)

    fig = cc.plotter.plot_distributions(display=True)

    for ax, label in zip(fig.axes, params):
        ax.set_title('')
        ax.set_xlabel(label)

    # target fig explicitly: after display the "current" pyplot figure may differ
    fig.tight_layout()
    return fig
Esempio n. 5
0
def setup_custom_chainconsumer(flat_chain, parameters, cloud=False, unit_labels=True,
                               sigmas=np.linspace(0, 2, 5), summary=False, fontsize=12):
    """Build a ChainConsumer for an already-prepared flat chain.

    No slicing is done here: the chain must already be flattened, with any
    discard/cap applied.

    Parameters
    ----------
    flat_chain : array
        Flattened samples, one column per entry of ``parameters``.
    parameters : list of str
        Parameter keys, converted to plot labels by
        ``plot_tools.convert_mcmc_labels``.
    cloud, sigmas, summary : passed to ``ChainConsumer.configure``.
    unit_labels : bool
        Passed to ``plot_tools.convert_mcmc_labels``.
    fontsize : int
        Label font size; tick labels use ``fontsize - 2``.

    Returns
    -------
    chainconsumer.ChainConsumer
    """
    labels = plot_tools.convert_mcmc_labels(parameters, unit_labels=unit_labels)

    consumer = chainconsumer.ChainConsumer()
    consumer.add_chain(flat_chain, parameters=labels)

    options = dict(sigmas=sigmas, cloud=cloud, kde=False, smooth=0, summary=summary,
                   label_font_size=fontsize, tick_font_size=fontsize - 2)
    consumer.configure(**options)
    return consumer
Esempio n. 6
0
def setup_epochs_chainconsumer(source, versions, n_steps, discard, n_walkers=1000,
                               cap=None, sigmas=None, cloud=None, compressed=False,
                               alt_params=True, unit_labels=True):
    """Set up a ChainConsumer holding one MCMC chain per epoch.

    Each entry of ``versions`` is loaded as a separate chain, sliced and
    flattened. It is added to a single ChainConsumer under the name of that
    version's 'epoch' parameter.

    Parameters
    ----------
    source : str
    versions : [n_epochs]
        MCMC version of each epoch chain.
    n_steps : int
        Passed to load_multi_chains.
    discard : int
        Passed to slice_chain.
    n_walkers : int (optional)
        Passed to load_multi_chains.
    cap : int (optional)
        Passed to slice_chain.
    sigmas : [] (optional)
        Passed to ChainConsumer.configure.
    cloud : bool (optional)
        Passed to ChainConsumer.configure.
    compressed : bool (optional)
        Passed to load_multi_chains.
    alt_params : bool (optional)
        If True, relabel parameters 4 and 5 as 'g' and 'M'. Column 4 of every
        chain is also multiplied by ref_g / ref_m (see the note in the body).
    unit_labels : bool (optional)
        Passed to plot_tools.convert_mcmc_labels.

    Returns
    -------
    chainconsumer.ChainConsumer
    """
    param_keys = load_multi_param_keys(source, versions=versions)
    chains = load_multi_chains(source, versions=versions, n_steps=n_steps,
                               n_walkers=n_walkers, compressed=compressed)

    # NOTE(review): author-flagged "quick and dirty patch! To fix".
    # Hard-codes the gravity and mass parameters at indices 4 and 5, and assumes
    # every epoch's param_keys has that layout — TODO confirm.
    ref_m = 1.4
    ref_g = gravity.get_acceleration_newtonian(r=10, m=ref_m).value / 1e14
    g_idx = 4
    m_idx = 5
    if alt_params:
        # Overwrites entries of the lists returned by load_multi_param_keys in
        # place (matters only if that helper caches/shares its lists — verify).
        for params in param_keys:
            params[g_idx] = 'g'
            params[m_idx] = 'M'
    # (end of patch setup)

    chains_flat = []
    for chain in chains:
        sliced_flat = slice_chain(chain, discard=discard, cap=cap, flatten=True)

        # NOTE(review): same patch — multiplies column g_idx by ref_g / ref_m.
        # ref_g is the Newtonian acceleration at r=10, m=1.4, divided by 1e14.
        # The physical intent of also dividing by ref_m is not evident here — confirm.
        if alt_params:
            sliced_flat[:, g_idx] *= ref_g / ref_m
        # (end of patch scaling)

        chains_flat += [sliced_flat]

    cc = chainconsumer.ChainConsumer()

    for i, chain_flat in enumerate(chains_flat):
        # chains are named by each version's 'epoch' parameter
        epoch = mcmc_versions.get_parameter(source, version=versions[i], parameter='epoch')
        param_labels = plot_tools.convert_mcmc_labels(param_keys[i],
                                                      unit_labels=unit_labels)
        cc.add_chain(chain_flat, parameters=param_labels, name=str(epoch))

    cc.configure(sigmas=sigmas, cloud=cloud, kde=False, smooth=0)
    return cc
Esempio n. 7
0
def plot_mass_radius(chain,
                     discard,
                     source,
                     version,
                     cap=None,
                     display=True,
                     save=False,
                     max_lhood=False,
                     verbose=True,
                     smoothing=False):
    """Plot contours of mass versus radius.

    See: get_mass_radius()

    Parameters
    ----------
    chain : ndarray
        Raw chain; its first two axes are taken as (walkers, steps).
    discard, cap : int
        Passed to get_mass_radius() to slice the chain.
    source : str
    version : int
    display, save : bool
        Forwarded to save_plot().
    max_lhood : bool
        If True, mark the mass/radius of the maximum-likelihood parameters
        as the truth point on the plot.
    verbose : bool
        Passed to get_max_lhood_params(); used only when ``max_lhood`` is True.
    smoothing : bool
        If False, KDE and smoothing of the contours are switched off.

    Returns
    -------
    matplotlib figure
    """
    default_plt_options()
    mass_radius_chain = get_mass_radius(chain=chain,
                                        discard=discard,
                                        source=source,
                                        version=version,
                                        cap=cap)

    cc = chainconsumer.ChainConsumer()
    cc.add_chain(mass_radius_chain.reshape(-1, 2), parameters=['M', 'R'])
    if not smoothing:
        cc.configure(kde=False, smooth=0)

    truth = None
    if max_lhood:
        n_walkers, n_steps = chain.shape[:2]
        max_params = mcmc_tools.get_max_lhood_params(source,
                                                     version=version,
                                                     n_walkers=n_walkers,
                                                     n_steps=n_steps,
                                                     verbose=verbose)
        mass, radius = get_mass_radius_point(max_params,
                                             source=source,
                                             version=version)
        truth = [mass, radius]

    # The `display` argument is honoured via save_plot() below (assumes
    # save_plot shows the figure when display=True — confirm), rather than
    # always showing the figure here regardless of the caller's choice.
    fig = cc.plotter.plot(display=False, figsize=[6, 6], truth=truth)

    save_plot(fig,
              prefix='mass-radius',
              chain=chain,
              save=save,
              source=source,
              version=version,
              display=display)
    return fig
Esempio n. 8
0
def setup_chainconsumer(chain, discard, cap=None, param_labels=None, cloud=False,
                        source=None, version=None, sigmas=np.linspace(0, 2, 5),
                        summary=False, fontsize=14, max_ticks=4):
    """Build a ChainConsumer from a raw chain, after discarding/capping steps.

    Parameters
    ----------
    chain : ndarray
        Raw chain; ``chain.shape[0]`` is passed to ChainConsumer as the walker count.
    discard, cap : int
        Passed to ``slice_chain`` (which also flattens the chain).
    param_labels : list of str, optional
        Plot labels. If omitted, they are built from the version's
        'param_keys', which requires ``source`` and ``version``.
    cloud, sigmas, summary, max_ticks : passed to ``ChainConsumer.configure``.
    fontsize : int
        Label font size; tick labels use ``fontsize - 3``.

    Raises
    ------
    ValueError
        If neither ``param_labels`` nor both ``source`` and ``version`` are given.
    """
    if param_labels is None:
        if source is None or version is None:
            raise ValueError('If param_labels not provided, must give source, version')
        keys = mcmc_versions.get_parameter(source, version, 'param_keys')
        param_labels = plot_tools.convert_mcmc_labels(keys)

    walkers = chain.shape[0]
    flat = slice_chain(chain, discard=discard, cap=cap, flatten=True)

    consumer = chainconsumer.ChainConsumer()
    consumer.add_chain(flat, parameters=param_labels, walkers=walkers)
    consumer.configure(sigmas=sigmas, cloud=cloud, kde=False, smooth=0,
                       summary=summary, label_font_size=fontsize,
                       tick_font_size=fontsize - 3, max_ticks=max_ticks)
    return consumer
Esempio n. 9
0
def setup_bprop_chainconsumer(chain, source, version, n=None, discard=None, cap=None,
                              summary=False, max_ticks=4, bp_sample=None,
                              sigmas=np.linspace(0, 2, 5), fontsize=16):
    """Build a ChainConsumer for a bprop sample (posterior predictive distribution).

    Each element of ``bp_sample`` is transposed and added as its own chain,
    named "Epoch 1", "Epoch 2", .... If ``bp_sample`` is not supplied, it is
    generated with ``bprop_sample`` from ``chain``, which requires ``n`` and
    ``discard``.

    Raises
    ------
    ValueError
        If ``bp_sample`` is None and either ``n`` or ``discard`` is missing.
    """
    if bp_sample is None:
        if n is None or discard is None:
            raise ValueError('If bp_sample not provided, must give discard and n')
        bp_sample = bprop_sample(chain=chain, n=n, source=source, version=version,
                                 discard=discard, cap=cap)

    mv = mcmc_versions.McmcVersion(source=source, version=version)
    consumer = chainconsumer.ChainConsumer()

    for number, epoch_sample in enumerate(bp_sample, start=1):
        consumer.add_chain(epoch_sample.transpose(), name=f"Epoch {number}",
                           parameters=mv.bprops)

    consumer.configure(sigmas=sigmas, kde=False, smooth=0, summary=summary,
                       label_font_size=fontsize, tick_font_size=fontsize - 3,
                       max_ticks=max_ticks)
    return consumer
Esempio n. 10
0
    def plot_corner_cc(self, results, param_names, **kwargs):
        """Corner plot of weighted posterior samples, drawn with ChainConsumer.

        Samples whose running (cumulative) weight is still <= 1e-4 are dropped
        before plotting. A message is printed if exactly one sample survives.

        :param results: mapping with results['weighted_samples']['points'] and
            results['weighted_samples']['weights']
        :param param_names: parameter labels, one per sample column
        :param kwargs: forwarded to ChainConsumer's ``plotter.plot``
        :return: the figure returned by ``plotter.plot``
        """
        points = np.array(results['weighted_samples']['points'])
        w = np.array(results['weighted_samples']['weights'])

        keep = np.cumsum(w) > 1e-4
        if keep.sum() == 1:
            print('Posterior has a really poor spread. Something funny is going on.')

        consumer = cc.ChainConsumer()
        consumer.add_chain(points[keep, :], weights=w[keep], parameters=param_names)
        consumer.configure(summary=True)
        return consumer.plotter.plot(**kwargs)
Esempio n. 11
0
    def corner_plot_cc(self,
                       parameters=None,
                       renamed_parameters=None,
                       **cc_kwargs):
        """
        Corner plots using chainconsumer which allows for nicer plotting of
        marginals
        see: https://samreay.github.io/ChainConsumer/chain_api.html#chainconsumer.ChainConsumer.configure
        for all options
        :param parameters: list of parameters to plot
        :param renamed_parameters: a python dictionary of parameters to rename.
             Useful when e.g. spectral indices in models have different names but you wish to compare them. Format is
             {'old label': 'new label'}
        :param **cc_kwargs: chainconsumer keyword arguments. The keys 'truth', 'figsize',
             'filename', 'display' and 'legend' are forwarded to the plot call; all others
             go to ChainConsumer.configure. If none remain, the
             threeML_config['bayesian']['chain consumer style'] settings are used instead.
        :return fig:
        :raises RuntimeError: if chainconsumer is not installed
        """

        if not has_chainconsumer:
            raise RuntimeError(
                "You must have chainconsumer installed to use this function: pip install chainconsumer"
            )

        # these are the keywords for the plot command; the rest are for configure()
        plot_args = {
            'truth': None,
            'figsize': 'GROW',
            'filename': None,
            'display': False,
            'legend': None
        }
        # Iterate over a snapshot of the keys: popping while iterating a live
        # dict view raises "dictionary changed size during iteration" on Python 3.
        for key in list(cc_kwargs):
            if key in plot_args:
                plot_args[key] = cc_kwargs.pop(key)

        # Label = last dotted component of each free parameter's full path
        labels = [name.split(".")[-1] for name in self._free_parameters]

        # Rename the parameters if needed (renames are applied in dict order).
        if renamed_parameters is not None:
            for old_label, new_label in renamed_parameters.items():
                labels = [new_label if label == old_label else label
                          for label in labels]

        # Must remove underscores (labels containing LaTeX '$' are left untouched)
        labels = [label if '$' in label else label.replace('_', '')
                  for label in labels]

        cc = chainconsumer.ChainConsumer()

        cc.add_chain(self._samples_transposed.T, parameters=labels)

        if not cc_kwargs:
            cc_kwargs = threeML_config['bayesian']['chain consumer style']

        cc.configure(**cc_kwargs)
        fig = cc.plotter.plot(parameters=parameters, **plot_args)

        return fig
Esempio n. 12
0
    def comparison_corner_plot(self, *other_fits, **kwargs):
        """
        Create a corner plot from many different fits which allow for co-plotting of parameters marginals.

        :param other_fits: other fitted results
        :param parameters: parameters to plot
        :param renamed_parameters: a python dictionary of parameters to rename.
             Useful when e.g. spectral indices in models have different names but you wish to compare them. Format is
             {'old label': 'new label'}
        :param names: (optional) name for each chain first name is this chain followed by each added chain
        :param kwargs: chain consumer kwargs. The keys 'truth', 'figsize', 'parameters',
             'filename', 'display' and 'legend' are forwarded to the plot call; all
             remaining keys go to ChainConsumer.configure.
        :return: the figure created by ChainConsumer
        :raises RuntimeError: if chainconsumer is not installed
        """

        if not has_chainconsumer:
            raise RuntimeError(
                "You must have chainconsumer installed to use this function")

        cc = chainconsumer.ChainConsumer()

        # these are the keywords for the plot command
        plot_args = {
            'truth': None,
            'figsize': 'GROW',
            'parameters': None,
            'filename': None,
            'display': False,
            'legend': None
        }

        # Snapshot the keys: popping while iterating a live dict view raises
        # RuntimeError on Python 3.
        for key in list(kwargs):
            if key in plot_args:
                plot_args[key] = kwargs.pop(key)

        # allows us to name chains
        names = kwargs.pop('names', None)
        if names is not None:
            assert len(names) == len(
                other_fits) + 1, 'you have %d chains but %d names' % (
                    len(other_fits) + 1, len(names))

        renamed_parameters = kwargs.pop('renamed_parameters', None)

        def _labels_for(fit):
            # Short, renamed, underscore-free labels for fit's free parameters.
            labels = [name.split(".")[-1] for name in fit._free_parameters]
            # A dictionary is passed with keys = old label, values = new label,
            # so parameters from different fits can be plotted together.
            if renamed_parameters is not None:
                for old_label, new_label in renamed_parameters.items():
                    labels = [new_label if label == old_label else label
                              for label in labels]
            # Must remove underscores (LaTeX labels containing '$' are kept)
            return [label if '$' in label else label.replace('_', ' ')
                    for label in labels]

        for j, other_fit in enumerate(other_fits):

            # NOTE(review): a fit whose samples are None skips this check but
            # will still fail at add_chain below.
            if other_fit.samples is not None:
                assert len(other_fit._free_parameters.keys()
                           ) == other_fit.samples.T[0].shape[0], (
                               "Mismatch between sample"
                               " dimensions and number of free"
                               " parameters")

            chain_kwargs = {} if names is None else {'name': names[j + 1]}
            cc.add_chain(other_fit.samples.T,
                         parameters=_labels_for(other_fit),
                         **chain_kwargs)

        chain_kwargs = {} if names is None else {'name': names[0]}
        cc.add_chain(self._samples_transposed.T,
                     parameters=_labels_for(self),
                     **chain_kwargs)

        # should only be the cc kwargs
        cc.configure(**kwargs)
        # plotting lives on the plotter, as in corner_plot_cc
        fig = cc.plotter.plot(**plot_args)

        return fig
Esempio n. 13
0
def animate_walkers(chain,
                    source,
                    version,
                    stepsize=1,
                    n_steps=100,
                    bin=10,
                    burn=100):
    """Save frames showing walker positions in (g, redshift) with marginal histograms.

    For each frame ``num`` = 1..n_steps, the step index is ``i = num * stepsize``.
    The frame has three parts:
      - a scatter of all walkers' (g, redshift) at step ``i``;
      - marginal distributions from ChainConsumer, computed on a window of
        steps. While ``i < bin`` the window is all steps up to ``i``. While
        ``i < burn`` it is the last ``bin`` steps. Afterwards it runs from
        ``burn - bin`` up to ``i``. Each distribution is scaled to a peak of 1.

    Frames are written as PNGs into
    GRIDS_PATH/sources/<source>/plots/misc/walker2 (created if missing).

    Parameters
    ----------
    chain : ndarray
        Raw chain indexed as [walker, step, parameter]. It must have more than
        ``stepsize * n_steps`` steps, because step ``i`` is indexed directly.
    bin : int
        Width of the sliding window of steps (name kept for backward
        compatibility, although it shadows the builtin).
    burn : int
        Step after which the window start stops moving.
    """
    default_plt_options()
    mv = mcmc_versions.McmcVersion(source, version)
    g_idx = mv.param_keys.index('g')
    red_idx = mv.param_keys.index('redshift')
    save_path = os.path.join(GRIDS_PATH, 'sources', source, 'plots', 'misc',
                             'walker2')
    # savefig() fails if the output directory does not exist yet
    os.makedirs(save_path, exist_ok=True)
    cc = chainconsumer.ChainConsumer()

    # ===== axis setup =====
    fig = plt.figure(1, figsize=(8, 8))

    nullfmt = NullFormatter()
    xlim = (0.6, 2.0)
    ylim = (1.08, 1.2)
    hist_ylim = (0, 1.1)

    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = left_h = left + width + 0.02

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    axScatter = plt.axes(rect_scatter)
    axHistx = plt.axes(rect_histx)
    axHisty = plt.axes(rect_histy)

    axScatter.set_xlim(xlim)
    axScatter.set_ylim(ylim)
    axScatter.set_xlabel('X', fontsize=20)
    axScatter.set_ylabel('Y', fontsize=20)

    axHistx.set_xlim(xlim)
    axHistx.set_ylim(hist_ylim)

    axHisty.set_ylim(ylim)
    axHisty.set_xlim(hist_ylim)

    # hide tick labels on the marginal-histogram axes
    axHistx.xaxis.set_major_formatter(nullfmt)
    axHistx.yaxis.set_major_formatter(nullfmt)
    axHisty.yaxis.set_major_formatter(nullfmt)
    axHisty.xaxis.set_major_formatter(nullfmt)

    # ===== Add data to axes =====
    for i in range(stepsize, stepsize * (n_steps + 1), stepsize):
        num = int(i / stepsize)
        sys.stdout.write(f'\r{num}/{n_steps}')
        sys.stdout.flush()  # '\r' progress lines are not flushed automatically

        # ===== walker scatter =====
        lines_scatter = axScatter.plot(chain[:, i, g_idx],
                                       chain[:, i, red_idx],
                                       marker='o',
                                       ls='none',
                                       markersize=2.5,
                                       color='C0')

        # ===== chainconsumer distributions =====
        if i < bin:
            sub_chain = mcmc_tools.slice_chain(chain, discard=0, cap=i)
        elif i < burn:
            sub_chain = mcmc_tools.slice_chain(chain, discard=i - bin, cap=i)
        else:
            sub_chain = mcmc_tools.slice_chain(chain,
                                               discard=burn - bin,
                                               cap=i)

        cc.add_chain(sub_chain[:, :, [g_idx, red_idx]].reshape(-1, 2),
                     parameters=['g', 'redshift'])
        # Render the distributions only to extract their curve data, then discard the figure
        cc_fig = cc.plotter.plot_distributions(blind=True)

        x_x = cc_fig.axes[0].lines[0].get_data()[0]
        x_y = cc_fig.axes[0].lines[0].get_data()[1]

        y_x = cc_fig.axes[1].lines[0].get_data()[0]
        y_y = cc_fig.axes[1].lines[0].get_data()[1]

        x_ymax = np.max(x_y)
        y_ymax = np.max(y_y)

        plt.close(cc_fig)

        lines_x = axHistx.plot(x_x, x_y / x_ymax, color='C0')
        lines_y = axHisty.plot(y_y / y_ymax, y_x, color='C0')

        filename = f'walker2_biggrid2_V25_{num:04}.png'
        filepath = os.path.join(save_path, filename)
        fig.savefig(filepath)

        # remove this frame's artists so the next frame starts from clean axes
        lines_scatter.pop(0).remove()
        lines_x.pop(0).remove()
        lines_y.pop(0).remove()
        cc.remove_chain()

    # end the '\r' progress line
    sys.stdout.write('\n')
    plt.close('all')