def __init__(self,
             main_figure,
             spec_figure,
             fit_results_div,
             source='salt2-extended'):
    """Store the target widgets, build both supernova models and wire the buttons

    Args:
        main_figure  (Figure): Bokeh figure that light-curve plots are drawn on
        spec_figure  (Figure): Bokeh figure used for the spectrum plot
        fit_results_div (Div): Div whose text displays the fit results
        source  (Str, Source): SNCosmo source of the desired model
    """

    # Where plots and fit results are rendered
    self.main_figure, self.spec_figure = main_figure, spec_figure
    self.fit_results_div = fit_results_div

    # A plain model, plus one carrying a static PWV transmission effect that
    # is applied in the observer frame
    self.sn_model_without_pwv = SNModel(source)
    pwv_effect = StaticPWVTrans()
    self.sn_model_with_pwv = SNModel(
        source,
        effects=[pwv_effect],
        effect_frames=['obs'],
        effect_names=[''],
    )

    # Connect each button to the handler it should trigger
    button_handlers = (
        (self.plot_button, self.plot_simulated_flux),
        (self.fit_button, self.fit_light_curve),
        (self.plot_model_button, self.plot_current_model),
    )
    for button, handler in button_handlers:
        button.on_click(handler)
Beispiel #2
0
 def runTest(self) -> None:
     """Check that an already existing output file makes ``FittingPipeline`` raise ``FileExistsError``"""

     with NamedTemporaryFile() as existing_file:
         out_path = Path(existing_file.name)
         out_path.touch()  # make sure the output path exists on disk

         with self.assertRaises(FileExistsError):
             FittingPipeline(
                 cadence='alt_sched',
                 sim_model=SNModel('salt2'),
                 fit_model=SNModel('salt2'),
                 vparams=['x0'],
                 out_path=out_path,
             )
Beispiel #3
0
    def setUpClass(cls) -> None:
        """Build one shared ``FittingPipeline`` whose output file lives in a temporary directory"""

        cls.temp_dir = TemporaryDirectory()
        output_file = Path(cls.temp_dir.name) / 'foo.h5'
        cls.pipeline = FittingPipeline(
            cadence='alt_sched',
            sim_model=SNModel('salt2'),
            fit_model=SNModel('salt2'),
            vparams=['x0'],
            out_path=output_file,
        )
Beispiel #4
0
    def generic_setup(cls, vparams: List[str]) -> None:
        """Push one mock packet (without fit results) through a ``FitLightCurves`` node

        Results leaving the node are collected by separate success and
        failure mock targets stored on the class.

        Args:
            vparams: Parameters to vary in the fit
        """

        cls.packet = create_mock_pipeline_packet(include_fit=False)

        # Nodes: packet feeder -> fitting node -> success / failure targets
        feeder = MockSource([cls.packet])
        cls.node = FitLightCurves(
            SNModel('salt2-extended'), vparams=vparams, num_processes=0)
        cls.success_target = MockTarget()
        cls.failure_target = MockTarget()

        # Wire every output connector to the input that consumes it
        connections = (
            (feeder.output, cls.node.input),
            (cls.node.success_output, cls.success_target.input),
            (cls.node.failure_output, cls.failure_target.input),
        )
        for output_connector, input_connector in connections:
            output_connector.connect(input_connector)

        # Execute the nodes in pipeline order, pausing two seconds after each
        pipeline_nodes = [feeder, cls.node, cls.success_target, cls.failure_target]
        for pipeline_node in pipeline_nodes:
            pipeline_node.execute()
            sleep(2)
    def setUp(self) -> None:
        """Place a ``SimulateLightCurves`` node between a mock source and two mock targets"""

        # The node under test and the mock source that feeds it
        self.source = MockSource()
        self.node = SimulateLightCurves(SNModel('salt2-extended'), num_processes=0)

        # Separate targets for the node's success and failure output connectors
        self.success_target, self.failure_target = MockTarget(), MockTarget()

        links = (
            (self.source.output, self.node.input),
            (self.node.success_output, self.success_target.input),
            (self.node.failure_output, self.failure_target.input),
        )
        for upstream, downstream in links:
            upstream.connect(downstream)
Beispiel #6
0
def build_sn_model(variability, pwv_model, source='salt2-extended'):
    """Build an ``SNModel`` carrying a single observer-frame PWV transmission effect

    Args:
        variability: How PWV is modeled. A finite, non-negative number (or a
            string that parses as one, e.g. ``'4'`` or ``'2.5'``) gives a
            ``StaticPWVTrans`` effect with that PWV value. ``'epoch'`` uses
            ``VariablePWVTrans(pwv_model)`` and ``'seasonal'`` uses
            ``SeasonalPWVTrans.from_pwv_model(pwv_model)``.
        pwv_model: PWV model handed to the epoch / seasonal effects
            (not used for a static PWV value)
        source: SNCosmo source of the desired model

    Returns:
        An ``SNModel`` for ``source`` with the selected effect attached

    Raises:
        NotImplementedError: If ``variability`` is not a recognized option
    """

    # Parse with ``float`` instead of ``str.isnumeric``. The latter rejects
    # decimal strings such as '2.5' and accepts characters such as '½' that
    # ``float`` cannot convert.
    try:
        static_pwv = float(variability)

    except (TypeError, ValueError):
        static_pwv = None

    # Only finite, non-negative values count as a static PWV. NaN, infinities
    # and negative numbers fall through to the named options below and are
    # rejected.
    if static_pwv is not None and 0 <= static_pwv < float('inf'):
        effect = StaticPWVTrans()
        effect.set(pwv=static_pwv)

    elif variability == 'epoch':
        effect = VariablePWVTrans(pwv_model)

    elif variability == 'seasonal':
        effect = SeasonalPWVTrans.from_pwv_model(pwv_model)

    else:
        raise NotImplementedError(f'Unknown variability: {variability}')

    return SNModel(
        source,
        effects=[effect],
        effect_names=[''],
        effect_frames=['obs'])
    def test_simulation_includes_reference_catalog() -> None:
        """A catalog-aware simulation equals calibrating a catalog-free simulation"""

        sn_model = SNModel('salt2-extended')
        ref_catalog = create_mock_variable_catalog('G2', 'M5', 'K2')
        packet = create_mock_pipeline_packet(include_lc=False)
        sim_params, cadence = packet.sim_params, packet.cadence

        # Simulate the same light-curve twice (without random scatter):
        # once without a catalog and once with one
        plain_node = SimulateLightCurves(sn_model, add_scatter=False)
        plain_lc = plain_node.simulate_lc(sim_params, cadence)

        catalog_node = SimulateLightCurves(
            sn_model, add_scatter=False, catalog=ref_catalog)
        catalog_lc = catalog_node.simulate_lc(sim_params, cadence)

        expected_lc = ref_catalog.calibrate_lc(
            plain_lc, sim_params['ra'], sim_params['dec'])
        pd.testing.assert_frame_equal(
            expected_lc.to_pandas(), catalog_lc.to_pandas())
Beispiel #8
0
def create_mock_pipeline_packet(snid: int = 123456,
                                include_lc: bool = True,
                                include_fit: bool = True) -> PipelinePacket:
    """Build a ``PipelinePacket`` populated with mock data

    The packet always holds a fixed set of simulation parameters and a mock
    cadence with one observation per integer time from -20 through 51. The
    observations cycle through the ``lsst_hardware_`` u, g, r, i, z, y bands.

    Args:
        snid: The unique id value for the pipeline packet
        include_lc: Include a simulated light_curve in the packet
        include_fit: Include fit results for the simulated light_curve
            (only takes effect when ``include_lc`` is true)

    Returns:
        A ``PipelinePacket`` instance
    """

    sim_params = dict(ra=10, dec=-5, t0=0, x1=.1, c=.2, z=.5, x0=1)

    time_values = np.arange(-20, 52)
    band_cycle = [f'lsst_hardware_{band}' for band in 'ugrizy']
    num_repeats = len(time_values) // len(band_cycle)
    cadence = ObservedCadence(
        obs_times=np.arange(-20, 52),
        bands=band_cycle * num_repeats,
        skynoise=np.full_like(time_values, 0),
        zp=np.full_like(time_values, 30),
        zpsys=np.full_like(time_values, 'ab', dtype='U2'),
        gain=np.full_like(time_values, 1),
    )
    packet = PipelinePacket(snid, cadence=cadence, sim_params=sim_params)

    if not include_lc:
        return packet

    # Only forward the parameters the model actually defines (drops ra / dec)
    model = SNModel('salt2-extended')
    model_params = {
        name: value
        for name, value in sim_params.items()
        if name in model.param_names
    }
    model.update(model_params)
    packet.light_curve = model.simulate_lc(cadence)

    if include_fit:
        fit_result, fitted_model = model.fit_lc(
            packet.light_curve, ['x0', 'x1', 'c'])
        packet.fit_result, packet.fitted_model = fit_result, fitted_model
        packet.covariance = packet.fit_result.salt_covariance_linear()

    return packet
class Callbacks(SimulatedParamWidgets, FittedParamWidgets):
    """Assigns callbacks and establishes interactive behavior

    The widgets used here (buttons, sliders, text inputs and ``checkbox``)
    are not created in this class. Presumably they are provided by the
    ``SimulatedParamWidgets`` / ``FittedParamWidgets`` bases.
    """

    # Light-curve produced by the most recent ``plot_simulated_flux`` call
    # (``None`` until a light-curve has been simulated)
    sim_data = None

    def __init__(self,
                 main_figure,
                 spec_figure,
                 fit_results_div,
                 source='salt2-extended'):
        """Assign callbacks to plot widgets

        Args:
            main_figure  (Figure): Bokeh figure to render light-curve plots on
            spec_figure  (Figure): Bokeh figure to render spectra on
            fit_results_div (Div): Used to display fit results as text
            source  (Str, Source): SNCosmo source of the desired model
        """

        # Widgets for plotting / fit results
        self.main_figure = main_figure
        self.spec_figure = spec_figure
        self.fit_results_div = fit_results_div

        # Renderers currently drawn on the figures. These must be per-instance.
        # Class-level lists would be shared by every ``Callbacks`` object, and
        # one instance would then pop (and fail to remove) another's renderers.
        self.plotted_fits = []  # Model curves and model spectrum
        self.plotted_data = []  # Simulated points, error bars and spectrum

        # The PWV model is used for simulating, the PWV-free model for fitting
        self.sn_model_without_pwv = SNModel(source)
        self.sn_model_with_pwv = SNModel(source,
                                         effects=[StaticPWVTrans()],
                                         effect_frames=['obs'],
                                         effect_names=[''])

        # Assign callbacks
        self.plot_button.on_click(self.plot_simulated_flux)
        self.fit_button.on_click(self.fit_light_curve)
        self.plot_model_button.on_click(self.plot_current_model)

    def _remove_renderers(self, renderers):
        """Empty ``renderers``, removing each entry from the figure it is drawn on

        Each renderer is removed from ``main_figure`` if present there,
        otherwise from ``spec_figure``.
        """

        while renderers:
            renderer = renderers.pop()
            try:
                self.main_figure.renderers.remove(renderer)

            except ValueError:
                self.spec_figure.renderers.remove(renderer)

    def _clear_fitted_lines(self):
        """Remove model fits from the plot"""

        self._remove_renderers(self.plotted_fits)

    def _clear_plotted_object_data(self):
        """Remove simulated light-curve data points from the plot"""

        self._remove_renderers(self.plotted_data)

    def _set_model_from_fit_sliders(self):
        """Set the PWV-free model from the simulated redshift and the fit sliders"""

        self.sn_model_without_pwv.set(z=self.sim_z_slider.value,
                                      t0=self.fit_t0_slider.value,
                                      x0=self.fit_x0_slider.value,
                                      x1=self.fit_x1_slider.value,
                                      c=self.fit_c_slider.value)

    def plot_simulated_flux(self, event=None):
        """Simulate and plot a light-curve from the simulation sliders

        Clears all previously plotted data and fits. It then draws the
        simulated flux per band on ``main_figure`` and the simulated
        spectrum on ``spec_figure``. Finally it copies the simulated
        parameters onto the fit sliders. Checkbox option 0 adds error bars;
        option 1 calibrates the flux against ``REFERENCE_CATALOG``.
        """

        # Clear the plot
        self._clear_plotted_object_data()
        self._clear_fitted_lines()
        params = dict(z=self.sim_z_slider.value,
                      t0=self.sim_t0_slider.value,
                      x0=self.sim_x0_slider.value,
                      x1=self.sim_x1_slider.value,
                      c=self.sim_c_slider.value,
                      pwv=self.sim_pwv_slider.value)

        # Simulate a light-curve
        obs = mock.create_mock_cadence(
            np.arange(-10, 51, float(self.sampling_input.value)), BANDS)
        self.sn_model_with_pwv.update(params)
        self.sim_data = self.sn_model_with_pwv.simulate_lc(
            obs, fixed_snr=float(self.snr_input.value), scatter=False)

        # Scale flux by reference star
        # NOTE(review): elsewhere ``calibrate_lc`` is called as
        # ``calibrate_lc(lc, ra, dec)``; confirm REFERENCE_CATALOG's version
        # really takes a PWV value as its second argument.
        if 1 in self.checkbox.active:
            self.sim_data = REFERENCE_CATALOG.calibrate_lc(
                self.sim_data, self.sim_pwv_slider.value)

        # Update the main plot with simulated flux data
        sim_as_astropy = self.sim_data.to_astropy()
        for band, color in zip(BANDS, palette):
            band_data = sim_as_astropy[sim_as_astropy['band'] == band]
            x = band_data['time']
            y = band_data['flux']
            yerr = band_data['fluxerr']

            circ = self.main_figure.circle(x=x, y=y, color=color)
            self.plotted_data.append(circ)

            if 0 in self.checkbox.active:
                # One vertical segment per point spanning flux +/- error
                err_bar = self.main_figure.multi_line(
                    np.transpose([x, x]).tolist(),
                    np.transpose([y - yerr, y + yerr]).tolist(),
                    color=color)

                self.plotted_data.append(err_bar)

        # Update plot of simulated spectrum
        wave = np.arange(self.sn_model_with_pwv.minwave(),
                         self.sn_model_with_pwv.maxwave())
        spec_line = self.spec_figure.line(x=wave,
                                          y=self.sn_model_with_pwv.flux(
                                              0, wave))
        self.plotted_data.append(spec_line)

        # Match fitted param sliders to sim param sliders
        self.fit_t0_slider.update(value=self.sim_t0_slider.value)
        self.fit_x0_slider.update(value=self.sim_x0_slider.value)
        self.fit_x1_slider.update(value=self.sim_x1_slider.value)
        self.fit_c_slider.update(value=self.sim_c_slider.value)

    def fit_light_curve(self, event=None):
        """Fit the simulated light-curve and plot the results

        The fit starts from the current fit slider values. Any exception
        raised by the fit is shown in ``fit_results_div`` instead of being
        propagated (this is a UI callback).
        """

        self._set_model_from_fit_sliders()

        try:
            result, fitted_model = self.sn_model_without_pwv.fit_lc(
                self.sim_data, vparam_names=['t0', 'x0', 'x1', 'c'])

        except Exception as e:
            self.fit_results_div.update(text=str(e))
            return

        # Set fitted param sliders to reflect the fitted parameters
        self.fit_t0_slider.update(value=fitted_model['t0'])
        self.fit_x0_slider.update(value=fitted_model['x0'])
        self.fit_x1_slider.update(value=fitted_model['x1'])
        self.fit_c_slider.update(value=fitted_model['c'])
        self.plot_current_model()

        # Update results div (after plot_current_model, which clears it)
        keys = 'message', 'ncall', 'chisq', 'ndof', 'vparam_names', 'param_names', 'parameters'
        text = '<h4>Fit Results</h4>'
        text += '<br>'.join(f'{k}: {result[k]}' for k in keys)

        text += '<br><h4>Sim Mag</h4>'
        text += f'standard::b (AB): {self.sn_model_with_pwv.source_peakmag("standard::b", "AB")}'
        text += f'<br>peak standard::b (AB): {self.sn_model_with_pwv.source_peakabsmag("standard::b", "AB")}'

        text += '<br><h4>Fitted Mag</h4>'
        text += f'standard::b (AB): {fitted_model.source_peakmag("standard::b", "AB")}'
        text += f'<br>peak standard::b (AB): {fitted_model.source_peakabsmag("standard::b", "AB")}'
        self.fit_results_div.update(text=text)

    def plot_current_model(self, event=None):
        """Plot the model (without PWV) using the initial guess parameters

        Replaces any previously plotted fits and clears ``fit_results_div``.
        """

        self._set_model_from_fit_sliders()

        self._clear_fitted_lines()
        self.fit_results_div.update(text='')
        time_arr = np.arange(-25, 55, .5)
        for band, color in zip(BANDS, palette):
            line = self.main_figure.line(
                x=time_arr,
                y=self.sn_model_without_pwv.bandflux(band,
                                                     time_arr,
                                                     zp=30,
                                                     zpsys='ab'),
                color=color)

            self.plotted_fits.append(line)

        # The model spectrum is drawn up to at most 12000 in wavelength
        wave = np.arange(self.sn_model_without_pwv.minwave(),
                         min(self.sn_model_without_pwv.maxwave(), 12000))
        spec_line = self.spec_figure.line(x=wave,
                                          y=self.sn_model_without_pwv.flux(
                                              0, wave),
                                          color='red')
        self.plotted_fits.append(spec_line)