def __init__(self, main_figure, spec_figure, fit_results_div, source='salt2-extended'):
    """Assign callbacks to plot widgets

    Args:
        main_figure (Figure): Bokeh figure to render light-curve plots on
        spec_figure (Figure): Bokeh figure to render spectrum plots on
        fit_results_div (Div): Used to display fit results as text
        source (Str, Source): SNCosmo source of the desired model
    """

    # Widgets for plotting / fit results
    self.main_figure = main_figure
    self.spec_figure = spec_figure
    self.fit_results_div = fit_results_div

    # Renderers added to the figures, tracked so they can be removed later.
    # They are created per instance: lists defined at class level would be
    # shared by every instance of the class.
    self.plotted_fits = []
    self.plotted_data = []

    # Models used for fitting (no PWV effect) and for simulating (with PWV)
    self.sn_model_without_pwv = SNModel(source)
    self.sn_model_with_pwv = SNModel(
        source,
        effects=[StaticPWVTrans()],
        effect_frames=['obs'],
        effect_names=[''])

    # Assign callbacks
    self.plot_button.on_click(self.plot_simulated_flux)
    self.fit_button.on_click(self.fit_light_curve)
    self.plot_model_button.on_click(self.plot_current_model)
def runTest(self) -> None:
    """Building a pipeline whose output path already exists raises ``FileExistsError``"""

    with NamedTemporaryFile() as existing_file:
        existing_path = Path(existing_file.name)
        existing_path.touch()

        with self.assertRaises(FileExistsError):
            FittingPipeline(
                cadence='alt_sched',
                sim_model=SNModel('salt2'),
                fit_model=SNModel('salt2'),
                vparams=['x0'],
                out_path=existing_path,
            )
def setUpClass(cls) -> None:
    """Create a ``FittingPipeline`` that writes into a fresh temporary directory"""

    cls.temp_dir = TemporaryDirectory()
    output_file = Path(cls.temp_dir.name) / 'foo.h5'

    cls.pipeline = FittingPipeline(
        cadence='alt_sched',
        sim_model=SNModel('salt2'),
        fit_model=SNModel('salt2'),
        vparams=['x0'],
        out_path=output_file,
    )
def generic_setup(cls, vparams: List[str]) -> None:
    """Build and run a mock pipeline that fits the light-curve of one mock packet

    The packet is fed into a ``FitLightCurves`` node whose success and
    failure outputs are each collected by their own ``MockTarget``.

    Args:
        vparams: Parameters to vary in the fit
    """

    cls.packet = create_mock_pipeline_packet(include_fit=False)

    # Mock pipeline layout: feeder -> fitting node -> (success | failure) targets
    feeder = MockSource([cls.packet])
    cls.node = FitLightCurves(SNModel('salt2-extended'), vparams=vparams, num_processes=0)
    cls.success_target = MockTarget()
    cls.failure_target = MockTarget()

    feeder.output.connect(cls.node.input)
    cls.node.success_output.connect(cls.success_target.input)
    cls.node.failure_output.connect(cls.failure_target.input)

    # Execute every node of the mock pipeline
    pipeline_nodes = (feeder, cls.node, cls.success_target, cls.failure_target)
    for pipeline_node in pipeline_nodes:
        pipeline_node.execute()

    # NOTE(review): presumably gives the executed nodes time to finish - confirm
    sleep(2)
def setUp(self) -> None:
    """Wire a ``SimulateLightCurves`` node between a mock source and two mock targets"""

    self.source = MockSource()
    self.node = SimulateLightCurves(SNModel('salt2-extended'), num_processes=0)

    # One mock target for each output connector of the node under test
    self.success_target, self.failure_target = MockTarget(), MockTarget()

    self.source.output.connect(self.node.input)
    connections = (
        (self.node.success_output, self.success_target),
        (self.node.failure_output, self.failure_target),
    )
    for output_connector, target in connections:
        output_connector.connect(target.input)
def build_sn_model(variability, pwv_model, source='salt2-extended'):
    """Build an ``SNModel`` with a PWV transmission effect in the observer frame

    Args:
        variability: A finite, non-negative PWV value given as a string
            (e.g. ``'4'`` or ``'0.5'``) for a static PWV effect, ``'epoch'``
            for a ``VariablePWVTrans`` effect, or ``'seasonal'`` for a
            ``SeasonalPWVTrans`` effect
        pwv_model: PWV model handed to the epoch / seasonal effects
        source: SNCosmo source of the model

    Returns:
        An ``SNModel`` carrying the requested transmission effect

    Raises:
        NotImplementedError: If ``variability`` is not recognized
    """

    import math  # Local import keeps this helper self-contained

    if variability == 'epoch':
        effect = VariablePWVTrans(pwv_model)

    elif variability == 'seasonal':
        effect = SeasonalPWVTrans.from_pwv_model(pwv_model)

    else:
        # ``str.isnumeric`` is False for decimal strings such as '0.5', so the
        # value is parsed with ``float`` instead. Non-finite and negative
        # values ('nan', 'inf', '-1', ...) are rejected.
        try:
            pwv = float(variability)

        except (TypeError, ValueError):
            pwv = None

        if pwv is None or not math.isfinite(pwv) or pwv < 0:
            raise NotImplementedError(f'Unknown variability: {variability}')

        effect = StaticPWVTrans()
        effect.set(pwv=pwv)

    return SNModel(
        source,
        effects=[effect],
        effect_names=[''],
        effect_frames=['obs'])
def test_simulation_includes_reference_catalog() -> None:
    """Test simulated light-curves are calibrated if a reference catalog is specified"""

    model = SNModel('salt2-extended')
    catalog = create_mock_variable_catalog('G2', 'M5', 'K2')
    packet = create_mock_pipeline_packet(include_lc=False)
    params, cadence = packet.sim_params, packet.cadence

    # Simulate once without and once with the reference catalog
    uncalibrated_lc = SimulateLightCurves(model, add_scatter=False).simulate_lc(params, cadence)
    calibrated_lc = SimulateLightCurves(model, add_scatter=False, catalog=catalog).simulate_lc(params, cadence)

    # Calibrating the uncalibrated result by hand must match the node's output
    expected_lc = catalog.calibrate_lc(uncalibrated_lc, params['ra'], params['dec'])
    pd.testing.assert_frame_equal(expected_lc.to_pandas(), calibrated_lc.to_pandas())
def create_mock_pipeline_packet(
        snid: int = 123456, include_lc: bool = True, include_fit: bool = True
) -> PipelinePacket:
    """Build a ``PipelinePacket`` populated with mock data

    Args:
        snid: The unique id value for the pipeline packet
        include_lc: Include a simulated light_curve in the packet
        include_fit: Include fit results for the simulated light_curve
            (only used when ``include_lc`` is True)

    Returns:
        A ``PipelinePacket`` instance
    """

    sim_params = dict(ra=10, dec=-5, t0=0, x1=.1, c=.2, z=.5, x0=1)

    # One observation per day, cycling through the six LSST bands
    time_values = np.arange(-20, 52)
    num_obs = len(time_values)
    band_cycle = [f'lsst_hardware_{band}' for band in 'ugrizy']
    cadence = ObservedCadence(
        obs_times=time_values,
        bands=band_cycle * (num_obs // len(band_cycle)),
        skynoise=np.full_like(time_values, 0),
        zp=np.full_like(time_values, 30),
        zpsys=np.full_like(time_values, 'ab', dtype='U2'),
        gain=np.full_like(time_values, 1),
    )

    packet = PipelinePacket(snid, cadence=cadence, sim_params=sim_params)
    if not include_lc:
        return packet

    model = SNModel('salt2-extended')
    model_params = {name: value for name, value in sim_params.items() if name in model.param_names}
    model.update(model_params)
    packet.light_curve = model.simulate_lc(cadence)

    if include_fit:
        packet.fit_result, packet.fitted_model = model.fit_lc(packet.light_curve, ['x0', 'x1', 'c'])
        packet.covariance = packet.fit_result.salt_covariance_linear()

    return packet
class Callbacks(SimulatedParamWidgets, FittedParamWidgets):
    """Assigns callbacks and establishes interactive behavior"""

    # Most recently simulated light-curve (``None`` until one is simulated)
    sim_data = None

    def __init__(self, main_figure, spec_figure, fit_results_div, source='salt2-extended'):
        """Assign callbacks to plot widgets

        Args:
            main_figure (Figure): Bokeh figure to render light-curve plots on
            spec_figure (Figure): Bokeh figure to render spectrum plots on
            fit_results_div (Div): Used to display fit results as text
            source (Str, Source): SNCosmo source of the desired model
        """

        # Widgets for plotting / fit results
        self.main_figure = main_figure
        self.spec_figure = spec_figure
        self.fit_results_div = fit_results_div

        # Renderers added to the figures, tracked so they can be removed later.
        # They are created per instance: lists defined at class level would be
        # shared by every ``Callbacks`` instance.
        self.plotted_fits = []
        self.plotted_data = []

        # Models used for fitting (no PWV effect) and for simulating (with PWV)
        self.sn_model_without_pwv = SNModel(source)
        self.sn_model_with_pwv = SNModel(
            source,
            effects=[StaticPWVTrans()],
            effect_frames=['obs'],
            effect_names=[''])

        # Assign callbacks
        self.plot_button.on_click(self.plot_simulated_flux)
        self.fit_button.on_click(self.fit_light_curve)
        self.plot_model_button.on_click(self.plot_current_model)

    def _remove_renderers(self, renderers):
        """Pop every renderer from ``renderers`` and detach it from its figure

        Each renderer is removed from ``main_figure`` if present there,
        otherwise from ``spec_figure``. A ``ValueError`` propagates if a
        renderer is on neither figure.
        """

        while renderers:
            renderer = renderers.pop()
            try:
                self.main_figure.renderers.remove(renderer)

            except ValueError:
                self.spec_figure.renderers.remove(renderer)

    def _clear_fitted_lines(self):
        """Remove model fits from the plot"""

        self._remove_renderers(self.plotted_fits)

    def _clear_plotted_object_data(self):
        """Remove simulated light-curve data points from the plot"""

        self._remove_renderers(self.plotted_data)

    def _set_model_from_fit_sliders(self):
        """Set the PWV-free model from the fit sliders (redshift from the sim slider)"""

        self.sn_model_without_pwv.set(
            z=self.sim_z_slider.value,
            t0=self.fit_t0_slider.value,
            x0=self.fit_x0_slider.value,
            x1=self.fit_x1_slider.value,
            c=self.fit_c_slider.value)

    def plot_simulated_flux(self, event=None):
        """Simulate and plot a light-curve"""

        # Clear the plot
        self._clear_plotted_object_data()
        self._clear_fitted_lines()

        params = dict(
            z=self.sim_z_slider.value,
            t0=self.sim_t0_slider.value,
            x0=self.sim_x0_slider.value,
            x1=self.sim_x1_slider.value,
            c=self.sim_c_slider.value,
            pwv=self.sim_pwv_slider.value)

        # Simulate a light-curve
        obs = mock.create_mock_cadence(
            np.arange(-10, 51, float(self.sampling_input.value)), BANDS)

        self.sn_model_with_pwv.update(params)
        self.sim_data = self.sn_model_with_pwv.simulate_lc(
            obs, fixed_snr=float(self.snr_input.value), scatter=False)

        # Scale flux by reference star
        if 1 in self.checkbox.active:
            self.sim_data = REFERENCE_CATALOG.calibrate_lc(
                self.sim_data, self.sim_pwv_slider.value)

        # Update the main plot with simulated flux data
        sim_as_astropy = self.sim_data.to_astropy()
        for band, color in zip(BANDS, palette):
            band_data = sim_as_astropy[sim_as_astropy['band'] == band]
            x = band_data['time']
            y = band_data['flux']
            yerr = band_data['fluxerr']

            circ = self.main_figure.circle(x=x, y=y, color=color)
            self.plotted_data.append(circ)

            # Optional error bars, drawn as one vertical segment per point
            if 0 in self.checkbox.active:
                err_bar = self.main_figure.multi_line(
                    np.transpose([x, x]).tolist(),
                    np.transpose([y - yerr, y + yerr]).tolist(),
                    color=color)

                self.plotted_data.append(err_bar)

        # Update plot of simulated spectrum
        wave = np.arange(self.sn_model_with_pwv.minwave(), self.sn_model_with_pwv.maxwave())
        spec_line = self.spec_figure.line(x=wave, y=self.sn_model_with_pwv.flux(0, wave))
        self.plotted_data.append(spec_line)

        # Match fitted param sliders to sim param sliders
        self.fit_t0_slider.update(value=self.sim_t0_slider.value)
        self.fit_x0_slider.update(value=self.sim_x0_slider.value)
        self.fit_x1_slider.update(value=self.sim_x1_slider.value)
        self.fit_c_slider.update(value=self.sim_c_slider.value)

    def fit_light_curve(self, event=None):
        """Fit the simulated light-curve and plot the results

        Any exception raised by the fit is shown in ``fit_results_div``
        instead of being propagated.
        """

        self._set_model_from_fit_sliders()
        try:
            result, fitted_model = self.sn_model_without_pwv.fit_lc(
                self.sim_data, vparam_names=['t0', 'x0', 'x1', 'c'])

        # Broad catch is deliberate: this is a UI callback and the error is
        # reported to the user through the results div
        except Exception as e:
            self.fit_results_div.update(text=str(e))
            return

        # Set fitted param sliders to reflect the fitted parameters
        self.fit_t0_slider.update(value=fitted_model['t0'])
        self.fit_x0_slider.update(value=fitted_model['x0'])
        self.fit_x1_slider.update(value=fitted_model['x1'])
        self.fit_c_slider.update(value=fitted_model['c'])

        # Redraw the model using the (now fitted) slider values
        self.plot_current_model()

        # Update results div
        # NOTE(review): the "peak" labels below are attached to
        # ``source_peakabsmag`` (absolute magnitude) - confirm the wording
        keys = 'message', 'ncall', 'chisq', 'ndof', 'vparam_names', 'param_names', 'parameters'
        text = '<h4>Fit Results</h4>'
        text += '<br>'.join(f'{k}: {result[k]}' for k in keys)

        text += '<br><h4>Sim Mag</h4>'
        text += f'standard::b (AB): {self.sn_model_with_pwv.source_peakmag("standard::b", "AB")}'
        text += f'<br>peak standard::b (AB): {self.sn_model_with_pwv.source_peakabsmag("standard::b", "AB")}'

        text += '<br><h4>Fitted Mag</h4>'
        text += f'standard::b (AB): {fitted_model.source_peakmag("standard::b", "AB")}'
        text += f'<br>peak standard::b (AB): {fitted_model.source_peakabsmag("standard::b", "AB")}'

        self.fit_results_div.update(text=text)

    def plot_current_model(self, event=None):
        """Plot the model (without PWV) using the initial guess parameters"""

        self._set_model_from_fit_sliders()
        self._clear_fitted_lines()
        self.fit_results_div.update(text='')

        # Model light-curve in each band
        time_arr = np.arange(-25, 55, .5)
        for band, color in zip(BANDS, palette):
            line = self.main_figure.line(
                x=time_arr,
                y=self.sn_model_without_pwv.bandflux(band, time_arr, zp=30, zpsys='ab'),
                color=color)

            self.plotted_fits.append(line)

        # Model spectrum, with the upper wavelength capped at 12000
        wave = np.arange(
            self.sn_model_without_pwv.minwave(),
            min(self.sn_model_without_pwv.maxwave(), 12000))

        spec_line = self.spec_figure.line(
            x=wave, y=self.sn_model_without_pwv.flux(0, wave), color='red')

        self.plotted_fits.append(spec_line)