Пример #1
0
    def evaluate_bazin(self, param: list, time: np.ndarray):
        """Evaluate the Bazin function of every filter around its maximum.

        For each filter in ``self.filters`` the Bazin curve is sampled on the
        integer epochs 0, 1, ..., 399 to locate its maximum ``t_max``; the
        curve is then evaluated at ``t_max + t`` for every ``t`` in ``time``.

        Parameters
        ----------
        param: list
            List of Bazin parameters in order [a, b, t0, tfall, trise]
            for all filters, concatenated from blue to red (5 values per
            filter, in the same order as ``self.filters``).
        time: np.ndarray or list
            Times relative to the maximum where to evaluate the Bazin fit.
            Must be non-empty: ``time[0]`` is used to build ``first_obs``.

        Returns
        -------
        flux: np.ndarray
            Bazin values for all filters, concatenated in filter order
            (``len(self.filters) * len(time)`` entries).
        first_obs: list
            ``t_max + time[0]`` for each filter.
        tmax_all: list
            Epoch of the sampled maximum (an integer in [0, 400)) for each
            filter.
        """
        # Integer epochs scanned to locate each filter's peak; only this
        # window is searched, so a peak outside it is not found.
        search_epochs = range(400)

        flux = []
        first_obs = []
        tmax_all = []

        for k in range(len(self.filters)):
            # The 5 Bazin parameters of filter k (IndexError if param is short).
            coeffs = [param[5 * k + j] for j in range(5)]

            # Find the day of maximum on the integer grid; the first
            # occurrence wins on ties.
            sampled = [bazin(epoch, *coeffs) for epoch in search_epochs]
            t_max = search_epochs[sampled.index(max(sampled))]
            tmax_all.append(t_max)

            flux.extend(bazin(t_max + item, *coeffs) for item in time)
            first_obs.append(t_max + time[0])

        return np.array(flux), first_obs, tmax_all
Пример #2
0
    def plot_bazin_fit(self, save=False, show=True, output_file=' '):
        """
        Plot data and Bazin fitted function.

        One panel per filter is drawn on a grid with 2 rows and
        ceil(len(self.filters) / 2) columns. Data points are plotted against
        time since the first observation in that filter, with fluxes shifted
        by the minimum flux over all filters; the Bazin curve is evaluated
        from ``self.bazin_features`` (5 parameters per filter).

        Parameters
        ----------
        save: bool (optional)
             Save figure to file. Default is False.
        show: bool (optional)
             Display plot in window. Default is True.
        output_file: str
            Name of file to store the plot (used only when ``save`` is True).
        """
        n_filters = len(self.filters)
        # plt.subplot needs integer grid dimensions (recent matplotlib rejects
        # floats such as 3 / 2 + 3 % 2 == 2.5); (n + 1) // 2 == ceil(n / 2).
        n_cols = (n_filters + 1) // 2
        # Common flux offset shared by every panel.
        flux_offset = np.min(self.photometry['flux'].values)

        plt.figure()

        for i, band in enumerate(self.filters):
            plt.subplot(2, n_cols, i + 1)
            plt.title('Filter: ' + band)

            # filter flag
            filter_flag = self.photometry['band'] == band
            x = self.photometry['mjd'][filter_flag].values
            y = self.photometry['flux'][filter_flag].values - flux_offset
            yerr = self.photometry['fluxerr'][filter_flag].values

            # shift to avoid large numbers in x-axis
            time = x - min(x)
            xaxis = np.linspace(0, max(time), 500)[:, np.newaxis]
            # calculate fitted function with this filter's 5 Bazin parameters
            coeffs = [self.bazin_features[5 * i + j] for j in range(5)]
            fitted_flux = np.array([bazin(t, *coeffs) for t in xaxis])

            plt.errorbar(time, y, yerr=yerr, color='blue', fmt='o')
            plt.plot(xaxis, fitted_flux, color='red', lw=1.5)
            plt.xlabel('MJD - ' + str(min(x)))
            plt.ylabel('FLUXCAL')

        if n_filters:
            # A single layout pass once every panel exists.
            plt.tight_layout()

        if save:
            plt.savefig(output_file)
        if show:
            plt.show()
Пример #3
0
    def plot_GP_fit(self,
                    save=False,
                    show=True,
                    output_file=' ',
                    threeDim=False,
                    plot_bazin=True,
                    figsize=(10, 10)):
        """
        Plot data and GP fitted function.

        A Gaussian process built from the stored kernel hyperparameters
        (``self.GP_kernel_params``) is fit to the photometry in the
        (time, central wavelength) plane, with time measured from the first
        observation in any band and fluxes shifted by the global minimum
        flux. The fit is drawn either per band (2D) or as a surface over
        time and wavelength (3D).

        Parameters
        ----------
        save: bool (optional)
             Save figure to file. Default is False.
        show: bool (optional)
             Display plot in window. Default is True.
        output_file: str
            Name of file to store the plot.
        threeDim: bool (optional)
            Plot the fit in 3D. Default is False.
        plot_bazin: bool (optional)
            Overlay the Bazin fit on each band panel (2D only).
            Default is True.
        figsize: tuple (optional)
            Figure size passed to ``plt.subplots`` (2D only).
            Default is (10, 10).

        Raises
        ------
        ValueError
            If ``self.GP_kernel_params[:-1]`` does not hold 9 or 3 values.
        """
        mjd_min = np.min(self.photometry['mjd'])
        bands = np.unique(self.photometry['lambda_cen'])

        X = np.vstack([
            self.photometry['mjd'] - mjd_min,
            self.photometry['lambda_cen']
        ]).T
        y = self.photometry['flux'] - np.min(self.photometry['flux'])
        dy = self.photometry['fluxerr']

        # The last entry of GP_kernel_params is not a kernel argument.
        kernel_params = self.GP_kernel_params[:-1]
        if len(kernel_params) == 9:
            kn = kernel_RBF(*kernel_params)
        elif len(kernel_params) == 3:
            kn = kernel_Matern(*kernel_params)
        else:
            raise ValueError(
                'Unexpected number of kernel hyperparameters in self.GP_kernel_params.'
            )

        # NOTE(review): scikit-learn adds ``alpha`` to the kernel diagonal as a
        # noise *variance*; ``dy`` looks like a standard deviation, so
        # ``dy ** 2`` may be intended. Kept as-is to match however the
        # hyperparameters were fit elsewhere -- confirm.
        gp = GaussianProcessRegressor(kernel=kn,
                                      alpha=dy,
                                      n_restarts_optimizer=0)
        gp.fit(X, y)

        time_plot = np.arange(np.floor(np.min(X[:, 0])),
                              np.ceil(np.max(X[:, 0])) + 1.)

        if not threeDim:

            n_cols = int(len(self.filters) / 2 + len(self.filters) % 2)
            _, axes = plt.subplots(2, n_cols, figsize=figsize)
            # An odd number of filters leaves spare cells in the 2-row grid.
            n_panels = min(len(bands), len(self.filters))

            for k, ax in enumerate(axes.flatten()):
                if k >= n_panels:
                    ax.set_visible(False)
                    continue

                # Predict along time at this band's central wavelength.
                x_pred = np.column_stack(
                    [time_plot, np.full_like(time_plot, bands[k])])
                y_plot, sigma_plot = gp.predict(x_pred, return_std=True)

                filter_flag = X[:, 1] == bands[k]
                ax.errorbar(X[:, 0][filter_flag],
                            y[filter_flag],
                            dy[filter_flag],
                            fmt='.',
                            color='k')

                ax.plot(time_plot, y_plot, label='GP', color='C1')
                ax.fill_between(time_plot,
                                y_plot + sigma_plot,
                                y_plot - sigma_plot,
                                alpha=0.3,
                                color='C1')

                if plot_bazin:
                    # Select this filter's rows by name: the Bazin parameters
                    # are indexed by self.filters, not by wavelength.
                    bfilter_flag = self.photometry['band'] == self.filters[k]
                    bx = self.photometry['mjd'][bfilter_flag].values

                    # The Bazin fit uses time since this filter's first
                    # observation (as in plot_bazin_fit).
                    btime = bx - min(bx)
                    bxaxis = np.linspace(0, max(btime), 500)[:, np.newaxis]
                    coeffs = [self.bazin_features[5 * k + j] for j in range(5)]
                    fitted_flux = np.array([bazin(t, *coeffs) for t in bxaxis])

                    # Shift onto this panel's axis (time since global first MJD).
                    ax.plot(bxaxis + (min(bx) - mjd_min),
                            fitted_flux,
                            color='C0',
                            label='Bazin')

                if k == 0:
                    ax.legend(loc='best')

                ax.set_title('Filter: {0}'.format(self.filters[k]))
                ax.set_ylabel('Flux')
                ax.set_xlabel('MJD - {0}'.format(mjd_min))
            plt.tight_layout()

        else:  # threeDim

            wavelength_plot = np.linspace(np.floor(np.min(X[:, 1])),
                                          np.ceil(np.max(X[:, 1])) + 1., 100)
            time_grid, wavelength_grid = np.meshgrid(time_plot,
                                                     wavelength_plot)
            x_pred = np.vstack([time_grid.flatten(),
                                wavelength_grid.flatten()]).T

            y_pred = gp.predict(x_pred)

            fig = plt.figure()
            # Figure.gca() no longer accepts projection kwargs (matplotlib 3.6+).
            ax = fig.add_subplot(111, projection='3d')

            ax.scatter(X[:, 1], X[:, 0], y, c='k')

            ax.plot_surface(wavelength_grid,
                            time_grid,
                            y_pred.reshape(np.shape(time_grid)),
                            cmap=plt.cm.viridis)

            ax.set_xlabel(r'$\lambda$')
            ax.set_ylabel('MJD - {0}'.format(mjd_min))
            ax.set_zlabel('Flux')
            plt.tight_layout()

        if save:
            plt.savefig(output_file)
        if show:
            plt.show()