Ejemplo n.º 1
0
Archivo: dict.py Proyecto: k-ship/bilby
def create_default_prior(name, default_priors_file=None):
    """Look up the default prior for a parameter with a known name.

    Parameters
    ----------
    name: str
        Name of the parameter whose default prior is requested.
    default_priors_file: str, optional
        Path of a prior file holding the default priors. When it is not
        given no lookup is possible.

    Return
    ------
    prior: Prior
        The prior stored under ``name`` in ``default_priors_file``, or None
        when no file is given or the file has no entry for ``name``.
    """
    if default_priors_file is not None:
        default_priors = PriorDict(filename=default_priors_file)
        if name in default_priors.keys():
            return default_priors[name]
        logger.debug(
            "No default prior found for variable {}.".format(name))
        return None

    logger.debug(
        "No prior file given.")
    return None
Ejemplo n.º 2
0
    def resample_posteriors(self, posteriors, max_samples=1e300):
        """
        Convert list of pandas DataFrame object to dict of arrays.

        Every posterior is down-sampled with ``DataFrame.sample`` to the same
        number of rows, so that the per-key results can be stacked into a
        single array.

        Parameters
        ----------
        posteriors: list
            List of pandas DataFrame objects.
        max_samples: int, opt
            Maximum number of samples to take from each posterior. The number
            actually drawn is the minimum of this value and the length of the
            shortest posterior chain, truncated to an integer.

        Returns
        -------
        data: dict
            Dictionary containing arrays of size (n_posteriors, n_samples),
            where n_samples is the number of samples actually drawn (also
            stored in ``self.samples_per_posterior``). There is one key per
            column of the first posterior; every other posterior must
            provide those columns too.

        Raises
        ------
        ValueError
            If ``posteriors`` is empty.
        """
        if len(posteriors) == 0:
            raise ValueError("At least one posterior is required.")
        n_samples = min(max_samples, min(len(posterior) for posterior in posteriors))
        # DataFrame.sample needs an integer row count; max_samples may be a
        # float (the default is 1e300) and may end up as the minimum.
        n_samples = int(n_samples)
        data = {key: [] for key in posteriors[0]}
        logger.debug(
            'Downsampling to {} samples per posterior.'.format(n_samples))
        self.samples_per_posterior = n_samples
        for posterior in posteriors:
            temp = posterior.sample(self.samples_per_posterior)
            for key in data:
                data[key].append(temp[key])
        for key in data:
            data[key] = xp.array(data[key])
        return data
Ejemplo n.º 3
0
Archivo: dict.py Proyecto: k-ship/bilby
    def to_file(self, outdir, label):
        """ Write the prior distribution to file.

        The file ``<outdir>/<label>.prior`` gets one ``key = value`` line per
        entry. For joint priors the shared distribution is written once
        under a generated name, and each joint entry refers to it by that
        name instead of repeating the distribution's repr.

        Parameters
        ----------
        outdir: str
            output directory name (created if it does not exist)
        label: str
            Output file naming scheme
        """
        check_directory_exists_and_if_not_mkdir(outdir)
        prior_file = os.path.join(outdir, "{}.prior".format(label))
        logger.debug("Writing priors to {}".format(prior_file))
        written_dists = []
        with open(prior_file, "w") as outfile:
            for key in self.keys():
                prior = self[key]
                if JointPrior not in prior.__class__.__mro__:
                    outfile.write("{} = {}\n".format(key, prior))
                    continue
                dist = prior.dist
                dist_name = '_'.join(dist.names) + '_{}'.format(dist.distname)
                if dist_name not in written_dists:
                    # Emit the shared distribution only the first time.
                    written_dists.append(dist_name)
                    outfile.write("{} = {}\n".format(dist_name, dist))
                dist_repr = repr(dist)
                prior_repr = repr(prior)
                outfile.write(
                    "{} = {}\n".format(key, prior_repr.replace(dist_repr, dist_name)))
Ejemplo n.º 4
0
Archivo: dict.py Proyecto: k-ship/bilby
    def __init__(self, dictionary=None, filename=None,
                 conversion_function=None):
        """ A set of priors

        Parameters
        ----------
        dictionary: Union[dict, str, None]
            If given, a dictionary to generate the prior set. A string is
            treated as the name of a prior file.
        filename: Union[str, None]
            If given, a file containing the prior to generate the prior set.
            Only used when ``dictionary`` is neither a dict nor a string.
        conversion_function: func
            Function to convert between sampled parameters and constraints.
            Default is no conversion (``self.default_conversion_function``).

        Raises
        ------
        ValueError
            If ``dictionary`` is not None, a dict or a string and no
            ``filename`` string is given.
        """
        super(PriorDict, self).__init__()

        if isinstance(dictionary, dict):
            self.from_dictionary(dictionary)
        else:
            prior_file = None
            if type(dictionary) is str:
                logger.debug('Argument "dictionary" is a string.'
                             ' Assuming it is intended as a file name.')
                prior_file = dictionary
            elif type(filename) is str:
                prior_file = filename
            elif dictionary is not None:
                raise ValueError("PriorDict input dictionary not understood")
            if prior_file is not None:
                self.from_file(prior_file)

        self.convert_floats_to_delta_functions()

        if conversion_function is None:
            conversion_function = self.default_conversion_function
        self.conversion_function = conversion_function
Ejemplo n.º 5
0
 def _split_repr(cls, string):
     """Split the argument text of a prior ``repr`` into keyword arguments.

     The text is split on commas at the top level only. Commas inside
     ``()``, ``[]`` or ``{}`` brackets, or inside quoted strings, do not
     split. So nested calls such as ``f(1, g(2, 3), 4)`` and list values
     such as ``[1, 2]`` stay in one piece. Each piece is stripped of
     surrounding whitespace. A piece containing ``=`` becomes
     ``kwargs[key] = value``, split at the first ``=``. A piece without
     ``=`` is treated as positional and named after the argument of
     ``cls.__init__`` in the same position.

     Parameters
     ----------
     string: str
         The text between the outer parentheses of a prior ``repr``.

     Returns
     -------
     kwargs: dict
         Mapping from argument names to the unevaluated argument strings.
     """
     subclass_args = infer_args_from_method(cls.__init__)
     args = list()
     current = list()
     depth = 0
     quote = None
     escaped = False
     for char in string:
         if quote is not None:
             # Inside a quoted string brackets and commas are plain text.
             if escaped:
                 escaped = False
             elif char == '\\':
                 escaped = True
             elif char == quote:
                 quote = None
         elif char in '"\'':
             quote = char
         elif char in '([{':
             depth += 1
         elif char in ')]}':
             # Never go negative, so a stray closing bracket cannot stop
             # all later commas from splitting.
             depth = max(depth - 1, 0)
         elif char == ',' and depth == 0:
             args.append(''.join(current).strip())
             current = list()
             continue
         current.append(char)
     args.append(''.join(current).strip())

     kwargs = dict()
     for position, arg in enumerate(args):
         if '=' not in arg:
             logger.debug(
                 'Reading priors with non-keyword arguments is dangerous!')
             key = subclass_args[position]
             val = arg
         else:
             key, _, val = arg.partition('=')
             key = key.strip()
         kwargs[key] = val
     return kwargs
Ejemplo n.º 6
0
Archivo: dict.py Proyecto: k-ship/bilby
 def to_json(self, outdir, label):
     """Write the priors as JSON to ``<outdir>/<label>_prior.json``.

     The output directory is created if needed. The content comes from
     ``self._get_json_dict()`` and is serialised with ``BilbyJsonEncoder``
     using an indent of 2.

     Parameters
     ----------
     outdir: str
         Output directory name.
     label: str
         Label used to build the output file name.
     """
     check_directory_exists_and_if_not_mkdir(outdir)
     json_file_name = "{}_prior.json".format(label)
     prior_file = os.path.join(outdir, json_file_name)
     logger.debug("Writing priors to {}".format(prior_file))
     with open(prior_file, "w") as outfile:
         json.dump(
             self._get_json_dict(), outfile, indent=2, cls=BilbyJsonEncoder)
Ejemplo n.º 7
0
 def sample_subset(self, keys=(), size=None):
     """Draw samples for a subset of keys while respecting prior conditions.

     Parameters
     ----------
     keys: iterable of str
         Keys of the priors to sample. The default samples nothing.
     size: int, optional
         Number of samples to draw per key. None draws a single sample.

     Returns
     -------
     dict: Dictionary of the drawn samples. Constraint priors are skipped,
         and entries that are not ``Prior`` instances are only logged.

     Raises
     ------
     IllegalConditionsException
         If the conditions among the selected priors cannot be resolved.
     ValueError
         If a conditional prior cannot be sampled and ``size`` is None.
     """
     self.convert_floats_to_delta_functions()
     subset_dict = ConditionalPriorDict({key: self[key] for key in keys})
     if not subset_dict._resolved:
         raise IllegalConditionsException(
             "The current set of priors contains unresolvable conditions.")
     samples = dict()
     for key in subset_dict.sorted_keys:
         if isinstance(self[key], Constraint):
             continue
         elif isinstance(self[key], Prior):
             try:
                 samples[key] = subset_dict[key].sample(
                     size=size, **subset_dict.get_required_variables(key))
             except ValueError:
                 # Some prior classes can not handle an array of conditional
                 # parameters (e.g. alpha for PowerLaw). If that is the case,
                 # we sample each sample individually. Without a sample count
                 # the conditions are not arrays, so the error is genuine and
                 # is propagated instead of failing later in range(None).
                 if size is None:
                     raise
                 required_variables = subset_dict.get_required_variables(
                     key)
                 samples[key] = np.zeros(size)
                 for i in range(size):
                     conditions = {
                         name: value[i]
                         for name, value in required_variables.items()
                     }
                     samples[key][i] = subset_dict[key].sample(**conditions)
         else:
             logger.debug('{} not a known prior.'.format(key))
     return samples
Ejemplo n.º 8
0
Archivo: dict.py Proyecto: k-ship/bilby
 def from_dictionary(self, dictionary):
     """Convert the entries of ``dictionary`` to priors and add them to self.

     Each value is handled according to its type:

     * ``Prior`` instances are kept unchanged.
     * ``int``/``float`` values become ``DeltaFunction`` priors.
     * Strings that parse as a float become ``DeltaFunction`` priors.
       Otherwise the text before the first ``(`` is resolved to a class.
       A dotted name is looked up in its module; a bare name is looked up
       in the package that contains this module. The text inside the
       parentheses is passed to ``cls.from_repr``. Multivariate
       distributions/priors are built with ``eval``.
     * ``conversion_function``/``condition_func`` keys have the resolved
       object set as an attribute on ``self``.
     * Dict values are kept unchanged, with a warning.

     Parameters
     ----------
     dictionary: dict
         Mapping from parameter names to prior specifications. It is not
         modified; the converted entries are stored on ``self``.

     Raises
     ------
     TypeError
         If an entry cannot be parsed.
     """
     # Convert a shallow copy so the caller's mapping keeps its original
     # values. ``.copy()`` preserves the mapping type and its key order,
     # which matters because multivariate priors refer to distributions
     # defined earlier.
     dictionary = dictionary.copy()
     eval_dict = dict(inf=np.inf)
     for key, val in iteritems(dictionary):
         if isinstance(val, Prior):
             continue
         elif isinstance(val, (int, float)):
             dictionary[key] = DeltaFunction(peak=val)
         elif isinstance(val, str):
             cls = val.split('(')[0]
             args = '('.join(val.split('(')[1:])[:-1]
             try:
                 dictionary[key] = DeltaFunction(peak=float(cls))
                 logger.debug("{} converted to DeltaFunction prior".format(key))
                 continue
             except ValueError:
                 pass
             if "." in cls:
                 module = '.'.join(cls.split('.')[:-1])
                 cls = cls.split('.')[-1]
             else:
                 # Drop this module's own name to get its parent package.
                 module = __name__.replace(
                     '.' + os.path.basename(__file__).replace('.py', ''), ''
                 )
             cls = getattr(import_module(module), cls, cls)
             if key.lower() in ["conversion_function", "condition_func"]:
                 # NOTE(review): the raw string entry is still added to self
                 # by the final update() as well — confirm this is intended.
                 setattr(self, key, cls)
             elif isinstance(cls, str):
                 # The name did not resolve to a class.
                 if "(" in val:
                     raise TypeError("Unable to parse prior class {}".format(cls))
                 else:
                     continue
             elif (cls.__name__ in ['MultivariateGaussianDist',
                                    'MultivariateNormalDist']):
                 # NOTE(security): eval executes arbitrary code from the
                 # prior specification; only load trusted input.
                 if key not in eval_dict:
                     eval_dict[key] = eval(val, None, eval_dict)
             elif (cls.__name__ in ['MultivariateGaussian',
                                    'MultivariateNormal']):
                 # Evaluated against eval_dict so it can refer to the
                 # distributions created above by name.
                 dictionary[key] = eval(val, None, eval_dict)
             else:
                 try:
                     dictionary[key] = cls.from_repr(args)
                 except TypeError as e:
                     raise TypeError(
                         "Unable to parse prior, bad entry: {} "
                         "= {}. Error message {}".format(key, val, e)
                     )
         elif isinstance(val, dict):
             logger.warning(
                 'Cannot convert {} into a prior object. '
                 'Leaving as dictionary.'.format(key))
         else:
             raise TypeError(
                 "Unable to parse prior, bad entry: {} "
                 "= {} of type {}".format(key, val, type(val))
             )
     self.update(dictionary)
Ejemplo n.º 9
0
Archivo: dict.py Proyecto: k-ship/bilby
 def convert_floats_to_delta_functions(self):
     """Replace every plain ``float``/``int`` entry with a DeltaFunction prior.

     Entries that are already ``Prior`` instances are left alone. Any other
     entry is kept as is and only reported at debug level.
     """
     for key in self:
         value = self[key]
         if isinstance(value, Prior):
             continue
         if not isinstance(value, (float, int)):
             logger.debug(
                 "{} cannot be converted to delta function prior."
                 .format(key))
             continue
         self[key] = DeltaFunction(value)
         logger.debug(
             "{} converted to delta function prior.".format(key))
Ejemplo n.º 10
0
 def _get_from_json_dict(cls, prior_dict):
     """Rebuild a prior dictionary from its decoded JSON representation.

     The ``__module__`` and ``__name__`` entries select the class to
     instantiate when they are present and the class can be imported.
     Otherwise ``cls`` itself is used. The bookkeeping keys
     ``__module__``, ``__name__`` and ``__prior_dict__`` are removed from
     ``prior_dict`` in place. The remaining entries are then passed to
     ``from_dictionary`` of a new instance.

     Parameters
     ----------
     prior_dict: dict
         Dictionary decoded from a JSON prior file.

     Returns
     -------
     An instance of the selected class populated from ``prior_dict``.
     """
     try:
         # Use the class recorded in the file (e.g. a PriorDict subclass)
         # so the loaded object has the same type as the one written.
         cls = getattr(import_module(prior_dict["__module__"]),
                       prior_dict["__name__"])
     except ImportError:
         logger.debug("Cannot import prior module {}.{}".format(
             prior_dict["__module__"], prior_dict["__name__"]))
     except KeyError:
         logger.debug("Cannot find module name to load")
     except AttributeError:
         logger.debug("Cannot find prior class {}.{}".format(
             prior_dict["__module__"], prior_dict["__name__"]))
     for key in ["__module__", "__name__", "__prior_dict__"]:
         if key in prior_dict:
             del prior_dict[key]
     obj = cls(dict())
     obj.from_dictionary(prior_dict)
     return obj
Ejemplo n.º 11
0
 def from_dictionary(self, dictionary):
     """Store the entries of ``dictionary`` on self, evaluating string values.

     A string value is passed to ``eval``. The result replaces the string
     only if it is a ``Prior``, ``float``, ``int`` or ``str``. Strings whose
     evaluation raises NameError, SyntaxError or TypeError are stored
     unchanged; other exceptions from ``eval`` propagate. Dict values are
     stored unchanged, with a warning. All other values are stored as
     given.

     Parameters
     ----------
     dictionary: dict
         Mapping from parameter names to prior specifications.
     """
     for key, val in iteritems(dictionary):
         if isinstance(val, str):
             try:
                 # NOTE(security): eval executes arbitrary code contained in
                 # the string; only use this with trusted input.
                 # eval runs in this function's scope, so local names (key,
                 # val, prior, self, dictionary) are visible to the
                 # expression; keep that in mind before renaming locals.
                 prior = eval(val)
                 if isinstance(prior, (Prior, float, int, str)):
                     val = prior
             except (NameError, SyntaxError, TypeError):
                 logger.debug(
                     "Failed to load dictionary value {} correctly".format(
                         key))
                 pass
         elif isinstance(val, dict):
             logger.warning('Cannot convert {} into a prior object. '
                            'Leaving as dictionary.'.format(key))
         self[key] = val
Ejemplo n.º 12
0
 def _initialize_attributes(self):
     """Normalise the tabulated PDF and build its interpolating functions.

     ``self._yy`` (tabulated at ``self.xx``) is rescaled so that its
     trapezoid-rule integral over ``self.xx`` is one. The PDF, CDF and
     inverse-CDF interpolants are then rebuilt from it. The PDF is 0 and the
     CDF is clipped to 0/1 outside ``self.xx``; the inverse CDF raises for
     arguments outside [0, 1].
     """
     norm = np.trapz(self._yy, self.xx)
     # The integral is a float; exact equality with 1 almost never holds,
     # so only report genuinely un-normalised input.
     if not np.isclose(norm, 1):
         logger.debug(
             'Supplied PDF for {} is not normalised, normalising.'.format(
                 self.name))
     # Divide out of place: ``/=`` would modify an array shared with the
     # caller, and it fails for integer-valued arrays.
     self._yy = self._yy / norm
     self.YY = cumtrapz(self._yy, self.xx, initial=0)
     # Need last element of cumulative distribution to be exactly one.
     self.YY[-1] = 1
     self.probability_density = interp1d(x=self.xx,
                                         y=self._yy,
                                         bounds_error=False,
                                         fill_value=0)
     self.cumulative_distribution = interp1d(x=self.xx,
                                             y=self.YY,
                                             bounds_error=False,
                                             fill_value=(0, 1))
     self.inverse_cumulative_distribution = interp1d(x=self.YY,
                                                     y=self.xx,
                                                     bounds_error=True)
Ejemplo n.º 13
0
Archivo: dict.py Proyecto: k-ship/bilby
    def sample_subset(self, keys=iter([]), size=None):
        """Draw samples from the priors stored under ``keys``.

        Float/int entries are first converted to DeltaFunction priors.
        Constraint priors are skipped; entries that are not ``Prior``
        instances are only reported at debug level.

        Parameters
        ----------
        keys: list
            List of prior keys to draw samples from
        size: int or tuple of ints, optional
            See numpy.random.uniform docs

        Returns
        -------
        dict: Dictionary mapping each sampled key to its drawn samples
        """
        self.convert_floats_to_delta_functions()
        samples = dict()
        for key in keys:
            prior = self[key]
            if isinstance(prior, Constraint):
                continue
            if isinstance(prior, Prior):
                samples[key] = prior.sample(size=size)
            else:
                logger.debug('{} not a known prior.'.format(key))
        return samples
def plot_interferometer_waveform_posterior(res,
                                           interferometer,
                                           level=0.9,
                                           n_samples=None,
                                           start_time=None,
                                           end_time=None,
                                           outdir='.',
                                           signals_to_plot=None):
    """
    Plot the posterior for the waveform in the frequency domain and
    whitened time domain, and save the figure to file.

    The interferometer strain data and ASD are always plotted. Extra
    signals (e.g. an injection) given in ``signals_to_plot`` are overlaid
    with dotted lines.

    Parameters
    ==========
    res: result object
        Must provide ``posterior`` (a pandas DataFrame with a
        ``geocent_time`` column), ``label``, ``waveform_arguments``,
        ``waveform_generator_class``, ``duration``, ``sampling_frequency``,
        ``start_time``, ``frequency_domain_source_model`` and
        ``parameter_conversion``.
    interferometer: bilby.gw.detector.interferometer.Interferometer
        detector to use; its data is overlaid on the posterior
    level: float, optional
        symmetric credible interval to show, default is 90%
    n_samples: int, optional
        number of posterior samples used to calculate the mean/interval,
        default is all. Values at least as large as the posterior use all
        samples.
    start_time: float, optional
        time relative to merger at which to begin the time domain plot.
        The merger time is defined as the mean of the geocenter time
        posterior. Default is -0.4
    end_time: float, optional
        time relative to merger at which to end the time domain plot.
        The merger time is defined as the mean of the geocenter time
        posterior. Default is 0.2
    outdir: str, optional
        directory in which the figure is saved, default is '.'
    signals_to_plot: list of dict, optional
        extra signals to overlay; each dict needs the keys 'params',
        'label' and 'color'

    Returns
    =======
    None. The figure is saved to
    ``{outdir}/{res.label}_{interferometer.name}_waveform.png``.

    Raises
    ======
    TypeError
        If ``interferometer`` is not an Interferometer.
    ValueError
        If the waveform could not be generated for any posterior sample.

    Notes
    -----
    To reduce the memory footprint we decimate the frequency domain
    waveforms to have ~4000 entries. This should be sufficient for decent
    resolution.
    """

    DATA_COLOR = "#ff7f0e"
    WAVEFORM_COLOR = "#1f77b4"

    if signals_to_plot is None:
        signals_to_plot = []

    if not isinstance(interferometer, bilby.gw.detector.Interferometer):
        raise TypeError('interferometer type must be Interferometer')

    logger.info("Generating waveform figure for {}".format(
        interferometer.name))

    if n_samples is None or n_samples >= len(res.posterior):
        samples = res.posterior
    else:
        samples = res.posterior.sample(n_samples, replace=False)

    # The time window is given relative to the merger time, taken as the
    # mean of the geocent_time posterior samples.
    merger_time = np.mean(samples.geocent_time)
    start_time = merger_time + (-0.4 if start_time is None else start_time)
    end_time = merger_time + (0.2 if end_time is None else end_time)

    time_idxs = ((interferometer.time_array >= start_time) &
                 (interferometer.time_array <= end_time))
    frequency_idxs = np.where(interferometer.frequency_mask)[0]
    logger.debug("Frequency mask contains {} values".format(
        len(frequency_idxs)))
    frequency_idxs = frequency_idxs[::max(1, len(frequency_idxs) // 4000)]
    logger.debug("Downsampling frequency mask to {} values".format(
        len(frequency_idxs)))
    # Times are plotted relative to the start of the strain data.
    data_start_time = interferometer.strain_data.start_time
    plot_times = interferometer.time_array[time_idxs] - data_start_time
    start_time -= data_start_time
    end_time -= data_start_time
    plot_frequencies = interferometer.frequency_array[frequency_idxs]
    delta_f = 1 / interferometer.strain_data.duration
    whitening_norm = np.sqrt(2. / interferometer.sampling_frequency)

    # Copy so that overriding the approximant does not modify ``res``.
    # NOTE(review): the approximant is always forced to IMRPhenomPv2,
    # whatever was used in the analysis — confirm this is intended.
    waveform_arguments = dict(res.waveform_arguments)
    waveform_arguments['waveform_approximant'] = "IMRPhenomPv2"

    waveform_generator = res.waveform_generator_class(
        duration=res.duration,
        sampling_frequency=res.sampling_frequency,
        start_time=res.start_time,
        frequency_domain_source_model=res.frequency_domain_source_model,
        parameter_conversion=res.parameter_conversion,
        waveform_arguments=waveform_arguments)

    fd_waveforms = list()
    td_waveforms = list()
    for _, params in tqdm(samples.iterrows(),
                          desc="Processing Samples",
                          total=len(samples)):
        params = dict(params)
        try:
            wf_pols = waveform_generator.frequency_domain_strain(params)
            fd_waveform = interferometer.get_detector_response(wf_pols, params)
            td_waveform = infft(
                fd_waveform * whitening_norm /
                interferometer.amplitude_spectral_density_array,
                res.sampling_frequency)[time_idxs]
        except Exception as e:
            # Skip samples whose waveform cannot be generated. Appending
            # nothing keeps the frequency- and time-domain lists aligned.
            logger.debug(f"ERROR: {e}\nparams: {params}")
            continue
        fd_waveforms.append(fd_waveform[frequency_idxs])
        td_waveforms.append(td_waveform)

    if len(fd_waveforms) == 0:
        raise ValueError(
            "Waveform generation failed for every posterior sample; "
            "nothing to plot.")
    fd_waveforms = asd_from_freq_series(fd_waveforms, delta_f)
    td_waveforms = np.array(td_waveforms)

    delta = (1 + level) / 2
    upper_percentile = delta * 100
    lower_percentile = (1 - delta) * 100
    logger.debug('Plotting posterior between the {} and {} percentiles'.format(
        lower_percentile, upper_percentile))
    # Round before converting: e.g. level=0.9 gives 89.999... in floating
    # point, which int() alone would truncate to 89.
    interval_width = int(round(upper_percentile - lower_percentile))

    fd_mean = np.mean(fd_waveforms, axis=0)
    td_mean = np.mean(td_waveforms, axis=0)
    lower_limit = fd_mean[0] / 1e3

    filename = f"{outdir}/{res.label}_{interferometer.name}_waveform.png"

    old_font_size = rcParams["font.size"]
    rcParams["font.size"] = 20
    fig = None
    try:
        fig, axs = plt.subplots(2,
                                1,
                                gridspec_kw=dict(height_ratios=[1.5, 1]),
                                figsize=(16, 12.5))

        axs[0].loglog(plot_frequencies,
                      asd_from_freq_series(
                          interferometer.frequency_domain_strain[frequency_idxs],
                          delta_f),
                      color=DATA_COLOR,
                      label='Data',
                      alpha=0.3)
        axs[0].loglog(
            plot_frequencies,
            interferometer.amplitude_spectral_density_array[frequency_idxs],
            color=DATA_COLOR,
            label='ASD')
        axs[1].plot(
            plot_times,
            infft(interferometer.whitened_frequency_domain_strain *
                  whitening_norm,
                  sampling_frequency=interferometer.strain_data.sampling_frequency)
            [time_idxs],
            color=DATA_COLOR,
            alpha=0.3)
        logger.debug('Plotted interferometer data.')

        axs[0].loglog(plot_frequencies,
                      fd_mean,
                      color=WAVEFORM_COLOR,
                      label='Mean reconstructed')
        axs[0].fill_between(plot_frequencies,
                            np.percentile(fd_waveforms, lower_percentile, axis=0),
                            np.percentile(fd_waveforms, upper_percentile, axis=0),
                            color=WAVEFORM_COLOR,
                            label=r'{}\% credible interval'.format(interval_width),
                            alpha=0.3)
        axs[1].plot(plot_times,
                    td_mean,
                    color=WAVEFORM_COLOR)
        axs[1].fill_between(plot_times,
                            np.percentile(td_waveforms, lower_percentile, axis=0),
                            np.percentile(td_waveforms, upper_percentile, axis=0),
                            color=WAVEFORM_COLOR,
                            alpha=0.3)

        for signal in signals_to_plot:
            params = signal['params']
            label = signal['label']
            col = signal['color']
            try:
                hf_inj = waveform_generator.frequency_domain_strain(params)
                hf_inj_det = interferometer.get_detector_response(
                    hf_inj, params)
                ht_inj_det = infft(
                    hf_inj_det * whitening_norm /
                    interferometer.amplitude_spectral_density_array,
                    res.sampling_frequency)[time_idxs]

                axs[0].loglog(plot_frequencies,
                              asd_from_freq_series(
                                  hf_inj_det[frequency_idxs], delta_f),
                              label=label,
                              linestyle=':',
                              color=col)
                axs[1].plot(plot_times, ht_inj_det, linestyle=':', color=col)
                logger.debug('Plotted injection.')
            except IndexError as e:
                logger.info(
                    'Failed to plot injection with message {}.'.format(e))

        f_domain_x_label = "$f [\\mathrm{Hz}]$"
        f_domain_y_label = "$\\mathrm{ASD} \\left[\\mathrm{Hz}^{-1/2}\\right]$"
        t_domain_x_label = "$t - {} [s]$".format(data_start_time)
        t_domain_y_label = "Whitened Strain"

        axs[0].set_xlim(interferometer.minimum_frequency,
                        interferometer.maximum_frequency)
        axs[1].set_xlim(start_time, end_time)
        axs[0].set_ylim(lower_limit)
        axs[0].set_xlabel(f_domain_x_label)
        axs[0].set_ylabel(f_domain_y_label)
        axs[1].set_xlabel(t_domain_x_label)
        axs[1].set_ylabel(t_domain_y_label)
        axs[0].legend(loc='lower left', ncol=2)

        plt.tight_layout()
        fig.savefig(fname=filename, dpi=600)
    finally:
        # Always release the figure and restore the global font size, even
        # if plotting fails part-way.
        if fig is not None:
            plt.close(fig)
        rcParams["font.size"] = old_font_size
    logger.info("Waveform figure saved to {}".format(filename))
Ejemplo n.º 15
0
    def from_file(self, filename):
        """ Reads in a prior from a file specification

        Parameters
        ----------
        filename: str
            Name of the file to be read in

        Raises
        ------
        TypeError
            If a line cannot be parsed into a prior.

        Notes
        -----
        Empty or whitespace-only lines, and lines whose first non-blank
        character is '#', are ignored.
        Priors can be loaded from:
            bilby.core.prior as, e.g.,    foo = Uniform(minimum=0, maximum=1)
            floats, e.g.,                 foo = 1
            bilby.gw.prior as, e.g.,      foo = bilby.gw.prior.AlignedSpin()
            other external modules, e.g., foo = my.module.CustomPrior(...)
        Multivariate distributions (MultivariateGaussianDist /
        MultivariateNormalDist) are evaluated so that later
        MultivariateGaussian / MultivariateNormal lines can refer to them by
        name. The distributions themselves are not added as entries.
        """

        prior = dict()
        mvgdict = dict(inf=np.inf)  # evaluate inf as np.inf
        with ioopen(filename, 'r', encoding='unicode_escape') as f:
            for line in f:
                # Strip first so that blank lines, '\r\n' endings and
                # indented comments are recognised.
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                elements = line.split('=')
                key = elements[0].replace(' ', '')
                val = '='.join(elements[1:]).strip()
                cls = val.split('(')[0]
                args = '('.join(val.split('(')[1:])[:-1]
                try:
                    prior[key] = DeltaFunction(peak=float(cls))
                    logger.debug(
                        "{} converted to DeltaFunction prior".format(key))
                    continue
                except ValueError:
                    pass
                if "." in cls:
                    module = '.'.join(cls.split('.')[:-1])
                    cls = cls.split('.')[-1]
                else:
                    # Drop this module's own name to get its parent package.
                    module = __name__.replace(
                        '.' + os.path.basename(__file__).replace('.py', ''),
                        '')
                cls = getattr(import_module(module), cls, cls)
                if key.lower() in ["conversion_function", "condition_func"]:
                    setattr(self, key, cls)
                elif isinstance(cls, str):
                    # The name did not resolve to a class; report it as a
                    # parse error instead of failing on ``cls.__name__``.
                    raise TypeError(
                        "Unable to parse dictionary file {}, bad line: {} "
                        "= {}. Unknown prior class {}".format(
                            filename, key, val, cls))
                elif (cls.__name__ in [
                        'MultivariateGaussianDist', 'MultivariateNormalDist'
                ]):
                    # NOTE(security): eval executes arbitrary code from the
                    # prior file; only load trusted files.
                    if key not in mvgdict:
                        mvgdict[key] = eval(val, None, mvgdict)
                elif (cls.__name__
                      in ['MultivariateGaussian', 'MultivariateNormal']):
                    prior[key] = eval(val, None, mvgdict)
                else:
                    try:
                        prior[key] = cls.from_repr(args)
                    except TypeError as e:
                        raise TypeError(
                            "Unable to parse dictionary file {}, bad line: {} "
                            "= {}. Error message {}".format(
                                filename, key, val, e))
        self.update(prior)