def create_default_prior(name, default_priors_file=None):
    """Make a default prior for a parameter with a known name.

    Parameters
    ----------
    name: str
        Parameter name
    default_priors_file: str, optional
        If given, a file containing the default priors.

    Returns
    -------
    prior: Prior
        Default prior distribution for that parameter; None is returned
        if the parameter is unknown.
    """
    if default_priors_file is None:
        logger.debug("No prior file given.")
        prior = None
    else:
        default_priors = PriorDict(filename=default_priors_file)
        if name in default_priors.keys():
            prior = default_priors[name]
        else:
            logger.debug(
                "No default prior found for variable {}.".format(name))
            prior = None
    return prior
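# Usage sketch for create_default_prior. The file name and its contents are
# hypothetical and only illustrate the expected prior-file format; any
# parameter not listed in the file returns None.
def _example_create_default_prior():
    with open("example_default.prior", "w") as f:
        f.write("mass_1 = Uniform(minimum=5, maximum=100, name='mass_1')\n")
    # Returns the Uniform prior defined in the file.
    known = create_default_prior(
        "mass_1", default_priors_file="example_default.prior")
    # Returns None and logs a debug message, since "spin" is not in the file.
    unknown = create_default_prior(
        "spin", default_priors_file="example_default.prior")
    return known, unknown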
def resample_posteriors(self, posteriors, max_samples=1e300):
    """
    Convert a list of pandas DataFrame objects to a dict of arrays.

    Parameters
    ----------
    posteriors: list
        List of pandas DataFrame objects.
    max_samples: int, optional
        Maximum number of samples to take from each posterior,
        default is the length of the shortest posterior chain.

    Returns
    -------
    data: dict
        Dictionary containing arrays of shape (n_posteriors, max_samples).
        There is a key for each shared key in posteriors.
    """
    for posterior in posteriors:
        max_samples = min(len(posterior), max_samples)
    data = {key: [] for key in posteriors[0]}
    logger.debug(
        'Downsampling to {} samples per posterior.'.format(max_samples))
    self.samples_per_posterior = max_samples
    for posterior in posteriors:
        temp = posterior.sample(self.samples_per_posterior)
        for key in data:
            data[key].append(temp[key])
    for key in data:
        data[key] = xp.array(data[key])
    return data
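# Minimal sketch of how resample_posteriors is typically called. The two
# DataFrames are synthetic, and ``hyper_likelihood`` is an assumed stand-in
# for whatever object defines this method. The shorter chain (50 rows) sets
# the per-posterior sample count, so each returned array has shape (2, 50).
def _example_resample_posteriors(hyper_likelihood):
    import numpy as np
    import pandas as pd
    posteriors = [
        pd.DataFrame(dict(mass=np.random.uniform(5, 50, 100),
                          redshift=np.random.uniform(0, 1, 100))),
        pd.DataFrame(dict(mass=np.random.uniform(5, 50, 50),
                          redshift=np.random.uniform(0, 1, 50))),
    ]
    data = hyper_likelihood.resample_posteriors(posteriors)
    return data  # dict with keys "mass" and "redshift", arrays of shape (2, 50)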
def to_file(self, outdir, label):
    """
    Write the prior distribution to file.

    Parameters
    ----------
    outdir: str
        Output directory name
    label: str
        Output file naming scheme
    """
    check_directory_exists_and_if_not_mkdir(outdir)
    prior_file = os.path.join(outdir, "{}.prior".format(label))
    logger.debug("Writing priors to {}".format(prior_file))
    joint_dists = []
    with open(prior_file, "w") as outfile:
        for key in self.keys():
            if JointPrior in self[key].__class__.__mro__:
                distname = '_'.join(self[key].dist.names) + '_{}'.format(
                    self[key].dist.distname)
                if distname not in joint_dists:
                    joint_dists.append(distname)
                    outfile.write(
                        "{} = {}\n".format(distname, self[key].dist))
                diststr = repr(self[key].dist)
                priorstr = repr(self[key])
                outfile.write(
                    "{} = {}\n".format(
                        key, priorstr.replace(diststr, distname)))
            else:
                outfile.write(
                    "{} = {}\n".format(key, self[key]))
def __init__(self, dictionary=None, filename=None, conversion_function=None):
    """ A set of priors

    Parameters
    ----------
    dictionary: Union[dict, str, None]
        If given, a dictionary to generate the prior set.
    filename: Union[str, None]
        If given, a file containing the prior to generate the prior set.
    conversion_function: func
        Function to convert between sampled parameters and constraints.
        Default is no conversion.
    """
    super(PriorDict, self).__init__()
    if isinstance(dictionary, dict):
        self.from_dictionary(dictionary)
    elif type(dictionary) is str:
        logger.debug('Argument "dictionary" is a string. '
                     'Assuming it is intended as a file name.')
        self.from_file(dictionary)
    elif type(filename) is str:
        self.from_file(filename)
    elif dictionary is not None:
        raise ValueError("PriorDict input dictionary not understood")
    self.convert_floats_to_delta_functions()
    if conversion_function is not None:
        self.conversion_function = conversion_function
    else:
        self.conversion_function = self.default_conversion_function
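# A minimal sketch of the construction paths handled above: an explicit
# dictionary (where bare floats are later promoted to DeltaFunction priors)
# and a prior file passed via ``filename``. The file name is hypothetical.
def _example_priordict_init():
    from bilby.core.prior import PriorDict, Uniform
    priors = PriorDict(dict(mass_1=Uniform(5, 100, name="mass_1"),
                            phase=0.0))  # float -> DeltaFunction(peak=0.0)
    # priors_from_file = PriorDict(filename="example.prior")
    return priors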
def _split_repr(cls, string):
    subclass_args = infer_args_from_method(cls.__init__)
    args = string.split(',')
    remove = list()
    # Re-join arguments that were split inside parentheses, e.g. tuples.
    for ii, key in enumerate(args):
        if '(' in key:
            jj = ii
            while ')' not in args[jj]:
                jj += 1
                args[ii] = ','.join([args[ii], args[jj]]).strip()
                remove.append(jj)
    remove.reverse()
    for ii in remove:
        del args[ii]
    kwargs = dict()
    for ii, arg in enumerate(args):
        if '=' not in arg:
            logger.debug(
                'Reading priors with non-keyword arguments is dangerous!')
            key = subclass_args[ii]
            val = arg
        else:
            split_arg = arg.split('=')
            key = split_arg[0]
            val = '='.join(split_arg[1:])
        kwargs[key] = val
    return kwargs
def to_json(self, outdir, label):
    check_directory_exists_and_if_not_mkdir(outdir)
    prior_file = os.path.join(outdir, "{}_prior.json".format(label))
    logger.debug("Writing priors to {}".format(prior_file))
    with open(prior_file, "w") as outfile:
        json.dump(
            self._get_json_dict(), outfile, cls=BilbyJsonEncoder, indent=2)
def sample_subset(self, keys=iter([]), size=None):
    self.convert_floats_to_delta_functions()
    subset_dict = ConditionalPriorDict({key: self[key] for key in keys})
    if not subset_dict._resolved:
        raise IllegalConditionsException(
            "The current set of priors contains unresolvable conditions.")
    samples = dict()
    for key in subset_dict.sorted_keys:
        if isinstance(self[key], Constraint):
            continue
        elif isinstance(self[key], Prior):
            try:
                samples[key] = subset_dict[key].sample(
                    size=size, **subset_dict.get_required_variables(key))
            except ValueError:
                # Some prior classes can not handle an array of conditional
                # parameters (e.g. alpha for PowerLaw). If that is the case,
                # we sample each sample individually.
                required_variables = subset_dict.get_required_variables(key)
                samples[key] = np.zeros(size)
                for i in range(size):
                    rvars = {
                        key: value[i]
                        for key, value in required_variables.items()
                    }
                    samples[key][i] = subset_dict[key].sample(**rvars)
        else:
            logger.debug('{} not a known prior.'.format(key))
    return samples
def from_dictionary(self, dictionary):
    eval_dict = dict(inf=np.inf)
    for key, val in iteritems(dictionary):
        if isinstance(val, Prior):
            continue
        elif isinstance(val, (int, float)):
            dictionary[key] = DeltaFunction(peak=val)
        elif isinstance(val, str):
            cls = val.split('(')[0]
            args = '('.join(val.split('(')[1:])[:-1]
            try:
                dictionary[key] = DeltaFunction(peak=float(cls))
                logger.debug(
                    "{} converted to DeltaFunction prior".format(key))
                continue
            except ValueError:
                pass
            if "." in cls:
                module = '.'.join(cls.split('.')[:-1])
                cls = cls.split('.')[-1]
            else:
                module = __name__.replace(
                    '.' + os.path.basename(__file__).replace('.py', ''), ''
                )
            cls = getattr(import_module(module), cls, cls)
            if key.lower() in ["conversion_function", "condition_func"]:
                setattr(self, key, cls)
            elif isinstance(cls, str):
                if "(" in val:
                    raise TypeError(
                        "Unable to parse prior class {}".format(cls))
                else:
                    continue
            elif (cls.__name__ in ['MultivariateGaussianDist',
                                   'MultivariateNormalDist']):
                if key not in eval_dict:
                    eval_dict[key] = eval(val, None, eval_dict)
            elif (cls.__name__ in ['MultivariateGaussian',
                                   'MultivariateNormal']):
                dictionary[key] = eval(val, None, eval_dict)
            else:
                try:
                    dictionary[key] = cls.from_repr(args)
                except TypeError as e:
                    raise TypeError(
                        "Unable to parse prior, bad entry: {} "
                        "= {}. Error message {}".format(key, val, e)
                    )
        elif isinstance(val, dict):
            logger.warning(
                'Cannot convert {} into a prior object. '
                'Leaving as dictionary.'.format(key))
        else:
            raise TypeError(
                "Unable to parse prior, bad entry: {} "
                "= {} of type {}".format(key, val, type(val))
            )
    self.update(dictionary)
def convert_floats_to_delta_functions(self):
    """ Convert all float parameters to delta functions """
    for key in self:
        if isinstance(self[key], Prior):
            continue
        elif isinstance(self[key], float) or isinstance(self[key], int):
            self[key] = DeltaFunction(self[key])
            logger.debug(
                "{} converted to delta function prior.".format(key))
        else:
            logger.debug(
                "{} cannot be converted to delta function prior."
                .format(key))
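# Sketch of the float-to-DeltaFunction promotion performed above: any bare
# number in the dictionary becomes a zero-width prior pinned at that value.
def _example_convert_floats():
    from bilby.core.prior import PriorDict, Uniform
    priors = PriorDict()
    priors["mass_1"] = Uniform(5, 100, name="mass_1")
    priors["phase"] = 1.3
    priors.convert_floats_to_delta_functions()
    # priors["phase"] is now a DeltaFunction(peak=1.3), so every draw of
    # "phase" returns exactly 1.3 while "mass_1" is sampled as usual.
    return priors.sample()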
def _get_from_json_dict(cls, prior_dict):
    try:
        # Resolve the class recorded in the JSON metadata; fall back to
        # the calling class if the module cannot be imported.
        cls = getattr(
            import_module(prior_dict["__module__"]), prior_dict["__name__"])
    except ImportError:
        logger.debug("Cannot import prior module {}.{}".format(
            prior_dict["__module__"], prior_dict["__name__"]))
    except KeyError:
        logger.debug("Cannot find module name to load")
    for key in ["__module__", "__name__", "__prior_dict__"]:
        if key in prior_dict:
            del prior_dict[key]
    obj = cls(dict())
    obj.from_dictionary(prior_dict)
    return obj
def from_dictionary(self, dictionary):
    for key, val in iteritems(dictionary):
        if isinstance(val, str):
            try:
                prior = eval(val)
                if isinstance(prior, (Prior, float, int, str)):
                    val = prior
            except (NameError, SyntaxError, TypeError):
                logger.debug(
                    "Failed to load dictionary value {} correctly".format(
                        key))
        elif isinstance(val, dict):
            logger.warning('Cannot convert {} into a prior object. '
                           'Leaving as dictionary.'.format(key))
        self[key] = val
def _initialize_attributes(self):
    if np.trapz(self._yy, self.xx) != 1:
        logger.debug(
            'Supplied PDF for {} is not normalised, normalising.'.format(
                self.name))
    self._yy /= np.trapz(self._yy, self.xx)
    self.YY = cumtrapz(self._yy, self.xx, initial=0)
    # Need last element of cumulative distribution to be exactly one.
    self.YY[-1] = 1
    self.probability_density = interp1d(
        x=self.xx, y=self._yy, bounds_error=False, fill_value=0)
    self.cumulative_distribution = interp1d(
        x=self.xx, y=self.YY, bounds_error=False, fill_value=(0, 1))
    self.inverse_cumulative_distribution = interp1d(
        x=self.YY, y=self.xx, bounds_error=True)
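# Standalone sketch of the normalisation and interpolation steps above,
# applied to a toy un-normalised PDF with plain numpy/scipy calls (names are
# illustrative; cumtrapz is called cumulative_trapezoid in newer SciPy).
def _example_interped_pdf():
    import numpy as np
    from scipy.integrate import cumtrapz
    from scipy.interpolate import interp1d
    xx = np.linspace(0, 1, 1000)
    yy = xx                         # un-normalised PDF, integral is 0.5
    yy = yy / np.trapz(yy, xx)      # normalise so the PDF integrates to 1
    cdf = cumtrapz(yy, xx, initial=0)
    cdf[-1] = 1                     # force the last CDF value to exactly 1
    inverse_cdf = interp1d(x=cdf, y=xx, bounds_error=True)
    # Inverse-transform sampling: map uniform draws through the inverse CDF,
    # which is how samples are drawn from this interpolated prior.
    return inverse_cdf(np.random.uniform(0, 1, 5))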
def sample_subset(self, keys=iter([]), size=None):
    """Draw samples from the prior set for the requested parameter keys.
    Constraint priors are skipped.

    Parameters
    ----------
    keys: list
        List of prior keys to draw samples from
    size: int or tuple of ints, optional
        See numpy.random.uniform docs

    Returns
    -------
    dict: Dictionary of the drawn samples
    """
    self.convert_floats_to_delta_functions()
    samples = dict()
    for key in keys:
        if isinstance(self[key], Constraint):
            continue
        elif isinstance(self[key], Prior):
            samples[key] = self[key].sample(size=size)
        else:
            logger.debug('{} not a known prior.'.format(key))
    return samples
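# Minimal sketch of sample_subset: only the requested keys are drawn, and
# any Constraint prior among them is silently skipped.
def _example_sample_subset():
    from bilby.core.prior import PriorDict, Uniform
    priors = PriorDict(dict(mass_1=Uniform(5, 100, name="mass_1"),
                            mass_2=Uniform(5, 100, name="mass_2")))
    samples = priors.sample_subset(keys=["mass_1"], size=3)
    # samples == {"mass_1": array of 3 draws}; "mass_2" is not sampled.
    return samples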
def plot_interferometer_waveform_posterior(
        res, interferometer, level=0.9, n_samples=None, start_time=None,
        end_time=None, outdir='.', signals_to_plot={}):
    """ Plot the waveform posterior in the frequency domain and the whitened
    time domain for a single interferometer. The detector data are overlaid,
    and any signals passed via signals_to_plot are drawn on top. The figure
    is saved to outdir.

    Parameters
    ==========
    res: bilby result object
        Result providing the posterior samples and the waveform generator
        settings used for the reconstruction
    interferometer: bilby.gw.detector.Interferometer
        Detector whose data and response are plotted; a TypeError is raised
        for any other type
    level: float, optional
        Symmetric credible interval to show, default is 0.9 (90%)
    n_samples: int, optional
        Number of posterior samples used to compute the mean/interval,
        default is all samples
    start_time: float, optional
        Time relative to the merger at which to begin the time-domain plot;
        the merger time is defined as the mean of the geocenter time
        posterior. Default is -0.4
    end_time: float, optional
        Time relative to the merger at which to end the time-domain plot;
        the merger time is defined as the mean of the geocenter time
        posterior. Default is 0.2
    outdir: str, optional
        Directory in which the figure is saved, default is the current
        directory
    signals_to_plot: iterable of dict, optional
        Each entry needs the keys 'params', 'label' and 'color'; the
        corresponding signal is overlaid on both panels

    Notes
    =====
    To reduce the memory footprint we decimate the frequency domain
    waveforms to have ~4000 entries. This should be sufficient for decent
    resolution.
    """
    DATA_COLOR = "#ff7f0e"
    WAVEFORM_COLOR = "#1f77b4"
    INJECTION_COLOR = "#000000"

    if not isinstance(interferometer, bilby.gw.detector.Interferometer):
        raise TypeError('interferometer type must be Interferometer')
    logger.info("Generating waveform figure for {}".format(
        interferometer.name))

    if n_samples is None:
        samples = res.posterior
    else:
        samples = res.posterior.sample(n_samples, replace=False)

    if start_time is None:
        start_time = -0.4
    start_time = np.mean(samples.geocent_time) + start_time
    if end_time is None:
        end_time = 0.2
    end_time = np.mean(samples.geocent_time) + end_time
    time_idxs = ((interferometer.time_array >= start_time) &
                 (interferometer.time_array <= end_time))

    frequency_idxs = np.where(interferometer.frequency_mask)[0]
    logger.debug("Frequency mask contains {} values".format(
        len(frequency_idxs)))
    frequency_idxs = frequency_idxs[::max(1, len(frequency_idxs) // 4000)]
    logger.debug("Downsampling frequency mask to {} values".format(
        len(frequency_idxs)))

    plot_times = interferometer.time_array[time_idxs]
    plot_times -= interferometer.strain_data.start_time
    start_time -= interferometer.strain_data.start_time
    end_time -= interferometer.strain_data.start_time
    plot_frequencies = interferometer.frequency_array[frequency_idxs]

    waveform_arguments = res.waveform_arguments
    # Force the approximant used for the reconstruction.
    waveform_arguments['waveform_approximant'] = "IMRPhenomPv2"
    waveform_generator = res.waveform_generator_class(
        duration=res.duration, sampling_frequency=res.sampling_frequency,
        start_time=res.start_time,
        frequency_domain_source_model=res.frequency_domain_source_model,
        parameter_conversion=res.parameter_conversion,
        waveform_arguments=waveform_arguments)

    old_font_size = rcParams["font.size"]
    rcParams["font.size"] = 20
    fig, axs = plt.subplots(
        2, 1, gridspec_kw=dict(height_ratios=[1.5, 1]), figsize=(16, 12.5))

    axs[0].loglog(
        plot_frequencies,
        asd_from_freq_series(
            interferometer.frequency_domain_strain[frequency_idxs],
            1 / interferometer.strain_data.duration),
        color=DATA_COLOR, label='Data', alpha=0.3)
    axs[0].loglog(
        plot_frequencies,
        interferometer.amplitude_spectral_density_array[frequency_idxs],
        color=DATA_COLOR, label='ASD')
    axs[1].plot(
        plot_times,
        infft(interferometer.whitened_frequency_domain_strain *
              np.sqrt(2. / interferometer.sampling_frequency),
              sampling_frequency=interferometer.strain_data.sampling_frequency
              )[time_idxs],
        color=DATA_COLOR, alpha=0.3)
    logger.debug('Plotted interferometer data.')

    fd_waveforms = list()
    td_waveforms = list()
    for _, params in tqdm(samples.iterrows(), desc="Processing Samples",
                          total=len(samples)):
        try:
            params = dict(params)
            wf_pols = waveform_generator.frequency_domain_strain(params)
            fd_waveform = interferometer.get_detector_response(wf_pols, params)
            td_waveform = infft(
                fd_waveform * np.sqrt(2. / interferometer.sampling_frequency) /
                interferometer.amplitude_spectral_density_array,
                res.sampling_frequency)[time_idxs]
        except Exception as e:
            # Skip samples that fail waveform generation so that the
            # frequency- and time-domain lists stay aligned.
            logger.debug(f"ERROR: {e}\nparams: {params}")
            continue
        fd_waveforms.append(fd_waveform[frequency_idxs])
        td_waveforms.append(td_waveform)
    fd_waveforms = asd_from_freq_series(
        fd_waveforms, 1 / interferometer.strain_data.duration)
    td_waveforms = np.array(td_waveforms)

    delta = (1 + level) / 2
    upper_percentile = delta * 100
    lower_percentile = (1 - delta) * 100
    logger.debug('Plotting posterior between the {} and {} percentiles'.format(
        lower_percentile, upper_percentile))

    lower_limit = np.mean(fd_waveforms, axis=0)[0] / 1e3
    axs[0].loglog(
        plot_frequencies, np.mean(fd_waveforms, axis=0),
        color=WAVEFORM_COLOR, label='Mean reconstructed')
    axs[0].fill_between(
        plot_frequencies,
        np.percentile(fd_waveforms, lower_percentile, axis=0),
        np.percentile(fd_waveforms, upper_percentile, axis=0),
        color=WAVEFORM_COLOR,
        label='{}\% credible interval'.format(
            int(upper_percentile - lower_percentile)),
        alpha=0.3)
    axs[1].plot(
        plot_times, np.mean(td_waveforms, axis=0), color=WAVEFORM_COLOR)
    axs[1].fill_between(
        plot_times, np.percentile(td_waveforms, lower_percentile, axis=0),
        np.percentile(td_waveforms, upper_percentile, axis=0),
        color=WAVEFORM_COLOR, alpha=0.3)

    if len(signals_to_plot) > 0:
        for d in signals_to_plot:
            params = d['params']
            label = d['label']
            col = d['color']
            try:
                hf_inj = waveform_generator.frequency_domain_strain(params)
                hf_inj_det = interferometer.get_detector_response(
                    hf_inj, params)
                ht_inj_det = infft(
                    hf_inj_det *
                    np.sqrt(2. / interferometer.sampling_frequency) /
                    interferometer.amplitude_spectral_density_array,
                    res.sampling_frequency)[time_idxs]
                axs[0].loglog(
                    plot_frequencies,
                    asd_from_freq_series(
                        hf_inj_det[frequency_idxs],
                        1 / interferometer.strain_data.duration),
                    label=label, linestyle=':', color=col)
                axs[1].plot(plot_times, ht_inj_det, linestyle=':', color=col)
                logger.debug('Plotted injection.')
            except IndexError as e:
                logger.info(
                    'Failed to plot injection with message {}.'.format(e))

    f_domain_x_label = "$f [\\mathrm{Hz}]$"
    f_domain_y_label = "$\\mathrm{ASD} \\left[\\mathrm{Hz}^{-1/2}\\right]$"
    t_domain_x_label = "$t - {} [s]$".format(
        interferometer.strain_data.start_time)
    t_domain_y_label = "Whitened Strain"
    axs[0].set_xlim(interferometer.minimum_frequency,
                    interferometer.maximum_frequency)
    axs[1].set_xlim(start_time, end_time)
    axs[0].set_ylim(lower_limit)
    axs[0].set_xlabel(f_domain_x_label)
    axs[0].set_ylabel(f_domain_y_label)
    axs[1].set_xlabel(t_domain_x_label)
    axs[1].set_ylabel(t_domain_y_label)
    axs[0].legend(loc='lower left', ncol=2)

    filename = f"{outdir}/{res.label}_{interferometer.name}_waveform.png"
    plt.tight_layout()
    fig.savefig(fname=filename, dpi=600)
    plt.close()
    logger.info("Waveform figure saved to {}".format(filename))
    rcParams["font.size"] = old_font_size
def from_file(self, filename):
    """ Reads in a prior from a file specification

    Parameters
    ----------
    filename: str
        Name of the file to be read in

    Notes
    -----
    Lines beginning with '#' or empty lines will be ignored.
    Priors can be loaded from:
        bilby.core.prior as, e.g.,    foo = Uniform(minimum=0, maximum=1)
        floats, e.g.,                 foo = 1
        bilby.gw.prior as, e.g.,      foo = bilby.gw.prior.AlignedSpin()
        other external modules, e.g., foo = my.module.CustomPrior(...)
    """
    comments = ['#', '\n']
    prior = dict()
    mvgdict = dict(inf=np.inf)  # evaluate inf as np.inf
    with ioopen(filename, 'r', encoding='unicode_escape') as f:
        for line in f:
            if line[0] in comments:
                continue
            line.replace(' ', '')
            elements = line.split('=')
            key = elements[0].replace(' ', '')
            val = '='.join(elements[1:]).strip()
            cls = val.split('(')[0]
            args = '('.join(val.split('(')[1:])[:-1]
            try:
                prior[key] = DeltaFunction(peak=float(cls))
                logger.debug(
                    "{} converted to DeltaFunction prior".format(key))
                continue
            except ValueError:
                pass
            if "." in cls:
                module = '.'.join(cls.split('.')[:-1])
                cls = cls.split('.')[-1]
            else:
                module = __name__.replace(
                    '.' + os.path.basename(__file__).replace('.py', ''), '')
            cls = getattr(import_module(module), cls, cls)
            if key.lower() in ["conversion_function", "condition_func"]:
                setattr(self, key, cls)
            elif (cls.__name__ in ['MultivariateGaussianDist',
                                   'MultivariateNormalDist']):
                if key not in mvgdict:
                    mvgdict[key] = eval(val, None, mvgdict)
            elif (cls.__name__ in ['MultivariateGaussian',
                                   'MultivariateNormal']):
                prior[key] = eval(val, None, mvgdict)
            else:
                try:
                    prior[key] = cls.from_repr(args)
                except TypeError as e:
                    raise TypeError(
                        "Unable to parse dictionary file {}, bad line: {} "
                        "= {}. Error message {}".format(
                            filename, key, val, e))
    self.update(prior)
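# Sketch of the file format parsed above (the file name and its contents are
# illustrative): comments and blank lines are skipped, bare floats become
# DeltaFunction priors, and everything else is resolved to a prior class.
def _example_prior_file():
    from bilby.core.prior import PriorDict
    with open("example.prior", "w") as f:
        f.write("# example prior specification\n")
        f.write("mass_1 = Uniform(minimum=5, maximum=100, name='mass_1')\n")
        f.write("phase = 0.0\n")
    priors = PriorDict(filename="example.prior")
    # priors == {"mass_1": Uniform(...), "phase": DeltaFunction(peak=0.0)}
    return priors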