Пример #1
0
    def run_synth(self, planes=None, force=False, n_cpu=4):
        """Write the crowding table, configure ``Synth`` and run it if needed.

        The PHAT artificial-star table for field 0 is written to
        ``crowding.dat`` in the synth directory, a zero-extinction model is
        set up for the young and old populations, and a ``Synth`` instance
        is stored on ``self.synth``. The synthesis is executed only when no
        ``z*`` entries exist yet in the synth directory, or when ``force``
        is true.

        Parameters
        ----------
        planes : list, optional
            CMD planes to synthesize; defaults to ``self.planes.all_planes``.
        force : bool
            Run the synthesis even if earlier output is present.
        n_cpu : int
            Number of processes passed to ``Synth.run_synth``. Defaults to 4,
            the value that used to be hard-coded.
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)

        # Use PHAT AST from the outer field (field 0). The file is written
        # at its full path under STARFISH; the crowding table is given the
        # relative path (presumably resolved against STARFISH -- confirm).
        crowd_path = os.path.join(self.synth_dir, "crowding.dat")
        full_crowd_path = os.path.join(STARFISH, crowd_path)
        tbl = PhatAstTable()
        tbl.write_crowdfile_for_field(full_crowd_path, 0,
                                      bands=PHAT_BANDS)
        crowd = ExtantCrowdingTable(crowd_path)

        # No extinction, yet: uniform A_V = 0 and unit relative extinction.
        young_av = ExtinctionDistribution()
        old_av = ExtinctionDistribution()
        rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
        for av in (young_av, old_av):
            av.set_uniform(0.)

        if planes is None:
            planes = self.planes.all_planes

        self.synth = Synth(self.synth_dir, self.builder, self.lockfile, crowd,
                           rel_extinction,
                           young_extinction=young_av,
                           old_extinction=old_av,
                           planes=planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)
        # Any z* entry in the synth directory means a previous run left
        # output there; only (re)run when there is none, or when forced.
        has_existing_synth = len(glob(os.path.join(full_synth_dir, "z*"))) > 0
        if force or not has_existing_synth:
            self.synth.run_synth(n_cpu=n_cpu, clean=False)
Пример #2
0
    def run_synth(self, planes=None, force=False, n_cpu=4):
        """Write the crowding table, configure ``Synth`` and run it if needed.

        The PHAT artificial-star table for field 0 is written to
        ``crowding.dat`` in the synth directory, a zero-extinction model is
        set up for the young and old populations, and a ``Synth`` instance
        is stored on ``self.synth``. The synthesis is executed only when no
        ``z*`` entries exist yet in the synth directory, or when ``force``
        is true.

        Parameters
        ----------
        planes : list, optional
            CMD planes to synthesize; defaults to ``self.planes.all_planes``.
        force : bool
            Run the synthesis even if earlier output is present.
        n_cpu : int
            Number of processes passed to ``Synth.run_synth``. Defaults to 4,
            the value that used to be hard-coded.
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)

        # Use PHAT AST from the outer field (field 0). The file is written
        # at its full path under STARFISH; the crowding table is given the
        # relative path (presumably resolved against STARFISH -- confirm).
        crowd_path = os.path.join(self.synth_dir, "crowding.dat")
        full_crowd_path = os.path.join(STARFISH, crowd_path)
        tbl = PhatAstTable()
        tbl.write_crowdfile_for_field(full_crowd_path, 0, bands=PHAT_BANDS)
        crowd = ExtantCrowdingTable(crowd_path)

        # No extinction, yet: uniform A_V = 0 and unit relative extinction.
        young_av = ExtinctionDistribution()
        old_av = ExtinctionDistribution()
        rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
        for av in (young_av, old_av):
            av.set_uniform(0.)

        if planes is None:
            planes = self.planes.all_planes

        self.synth = Synth(self.synth_dir,
                           self.builder,
                           self.lockfile,
                           crowd,
                           rel_extinction,
                           young_extinction=young_av,
                           old_extinction=old_av,
                           planes=planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)
        # Any z* entry in the synth directory means a previous run left
        # output there; only (re)run when there is none, or when forced.
        has_existing_synth = len(glob(os.path.join(full_synth_dir, "z*"))) > 0
        if force or not has_existing_synth:
            self.synth.run_synth(n_cpu=n_cpu, clean=False)
Пример #3
0
    def run_synth(self, n_cpu=1, config_only=False):
        """Build the ``Synth`` configuration and run it when appropriate.

        A ``Synth`` instance is always created and stored on ``self.synth``.
        It is executed only if ``config_only`` is false and the synth
        directory under ``STARFISH`` holds no ``z*`` entries yet.

        Parameters
        ----------
        n_cpu : int
            Number of processes handed to ``Synth.run_synth``.
        config_only : bool
            When true, stop after creating the configuration.
        """
        synth_root = os.path.join(STARFISH, self.synth_dir)

        self.synth = Synth(
            self.synth_dir, self.builder, self.lockfile, self.crowd,
            self.bands, self.rel_extinction,
            young_extinction=self.young_av,
            old_extinction=self.old_av,
            planes=self.all_planes,
            mass_span=(0.08, 150.),
            nstars=10000000)

        # Entries matching z* indicate output from an earlier synth run.
        already_synthesized = bool(glob(os.path.join(synth_root, "z*")))
        if config_only or already_synthesized:
            return
        self.synth.run_synth(n_cpu=n_cpu, clean=False)
Пример #4
0
class Pipeline(object):
    """Pipeline for Multi-CMD fitting and comparison.

    Constructing a ``Pipeline`` runs the StarFISH set-up stages in order:
    ``get_isochrones`` (isochrone export and library install),
    ``build_lockfile`` (age/metallicity binning) and ``run_synth``
    (synthetic CMDs). Fits are added afterwards with ``fit_planes`` and
    inspected with the plotting helpers.
    """

    def __init__(self, brick, root_dir, isoc_args=None, phases=None):
        """Set up directories and run the isochrone, lockfile and synth stages.

        Parameters
        ----------
        brick : object
            Brick identifier passed to ``Catalog``.
        root_dir : str
            Base directory of the ``isoc``, ``lib`` and ``synth``
            sub-directories; filesystem checks join it with ``STARFISH``.
        isoc_args : dict, optional
            Extra keyword arguments forwarded to every ``AgeGridRequest``.
        phases : sequence, optional
            If given, only isochrone rows whose ``stage`` is in ``phases``
            are exported.
        """
        super(Pipeline, self).__init__()
        self.brick = brick
        self.catalog = Catalog(brick)
        self.root_dir = root_dir
        self.isoc_dir = os.path.join(root_dir, 'isoc')
        self.lib_dir = os.path.join(root_dir, 'lib')
        self.synth_dir = os.path.join(root_dir, 'synth')

        # Metallicity grid; one pair of AgeGridRequests is made per value.
        self.z_grid = [0.015, 0.019, 0.024]

        self.get_isochrones(isoc_args=isoc_args, phases=phases)
        self.build_lockfile()
        self.planes = PhatPlanes()
        self.run_synth()

        # SFH fits keyed by name (see ``fit_planes``), and a lazily filled
        # cache of their solution tables (see ``get_solution_table``).
        self.fits = OrderedDict()

        self._solution_tables = {}

    def get_solution_table(self, key):
        """Return the solution table of fit ``key``, computing it only once.

        The table comes from ``self.fits[key].solution_table()`` on the first
        call and from ``self._solution_tables`` afterwards.
        """
        if key not in self._solution_tables:
            self._solution_tables[key] = self.fits[key].solution_table()
        # Always read from the cache: returning a local bound only on a miss
        # raised UnboundLocalError on every cache hit.
        return self._solution_tables[key]

    def get_isochrones(self, isoc_args=None, phases=None):
        """Export isochrones for StarFISH if needed, then set up the library.

        If ``STARFISH/<isoc_dir>`` does not exist, WFC3 and ACS isochrone
        grids (log age 6.6 to 10.13 in steps of 0.02) are requested for each
        metallicity in ``self.z_grid``. They are joined band-wise,
        optionally restricted to ``phases``, and exported in the PHAT bands.
        A ``LibraryBuilder`` at a distance of 785 kpc is then stored on
        ``self.builder`` and installed if its isofile is missing.

        Parameters
        ----------
        isoc_args : dict, optional
            Extra keyword arguments for every ``AgeGridRequest``.
        phases : sequence, optional
            ``stage`` values to keep; ``None`` keeps every row.
        """
        if isoc_args is None:
            isoc_args = {}
        full_isoc_dir = os.path.join(STARFISH, self.isoc_dir)
        if not os.path.exists(full_isoc_dir):
            for z in self.z_grid:
                r_wfc3 = AgeGridRequest(z,
                                        min_log_age=6.6,
                                        max_log_age=10.13,
                                        delta_log_age=0.02,
                                        photsys='wfc3_wide',
                                        **isoc_args)
                r_acs = AgeGridRequest(z,
                                       min_log_age=6.6,
                                       max_log_age=10.13,
                                       delta_log_age=0.02,
                                       photsys='acs_wfc',
                                       **isoc_args)
                isoc_set = join_isochrone_sets(r_wfc3.isochrone_set,
                                               r_acs.isochrone_set,
                                               left_bands=WFC3_BANDS,
                                               right_bands=ACS_BANDS)
                for isoc in isoc_set:
                    isoc = Isochrone(isoc)
                    isoc.rename_column('F275W1', 'F275W')
                    if phases is not None:
                        sels = [np.where(isoc['stage'] == p)[0]
                                for p in phases]
                        # np.unique sorts and de-duplicates the indices, so
                        # the kept rows stay in their original isochrone
                        # order instead of being regrouped phase by phase.
                        isoc = isoc[np.unique(np.concatenate(sels))]
                    isoc.export_for_starfish(full_isoc_dir,
                                             bands=PHAT_BANDS)

        d = Distance(785 * u.kpc)
        self.builder = LibraryBuilder(self.isoc_dir,
                                      self.lib_dir,
                                      nmag=len(PHAT_BANDS),
                                      dmod=d.distmod.value,
                                      iverb=3)
        if not os.path.exists(self.builder.full_isofile_path):
            self.builder.install()

    def build_lockfile(self):
        """Create the synth directory and bin isochrones into lock boxes.

        Young populations use 9 bins on a uniform log(age) grid from 6.5 to
        8.95. Old populations use 1 Gyr bins between 1 and 13 Gyr, each with
        its lower edge lowered by 0.05 Gyr. Every box spans Z in
        [0.014, 0.025].
        """
        if not os.path.exists(os.path.join(STARFISH, self.synth_dir)):
            os.makedirs(os.path.join(STARFISH, self.synth_dir))

        self.lockfile = Lockfile(self.builder.read_isofile(),
                                 self.synth_dir,
                                 unbinned=False)

        z_str = "0019"
        z_span = (0.014, 0.025)

        def box_name(logage0, logage1):
            # Name encodes the metallicity tag and the bin's mean log(age).
            # The old code divided by 0.2 (10x the mean, e.g. "66.36" rather
            # than "06.64", overflowing the 05.2f width for old bins). Note:
            # this renames the boxes, so synth output created with the old
            # names needs regenerating (run_synth(force=True)).
            mean_logage = (logage0 + logage1) / 2.
            return "z{0}_{1:05.2f}".format(z_str, mean_logage)

        # Bin young isochrones
        young_grid = np.linspace(6.5, 8.95, 10)
        for logage0, logage1 in zip(young_grid[:-1], young_grid[1:]):
            self.lockfile.lock_box(box_name(logage0, logage1),
                                   (logage0, logage1), z_span)

        # Bin old isochrones
        old_grid = np.arange(1e9, 14 * 1e9, 1e9)
        for age0, age1 in zip(old_grid[:-1], old_grid[1:]):
            logage0 = np.log10(age0 - 0.05 * 1e9)
            logage1 = np.log10(age1)
            self.lockfile.lock_box(box_name(logage0, logage1),
                                   (logage0, logage1), z_span)

    def run_synth(self, planes=None, force=False, n_cpu=4):
        """Write the crowding table, configure ``Synth`` and run it if needed.

        The PHAT artificial-star table for field 0 is written to
        ``crowding.dat`` in the synth directory, a zero-extinction model is
        set up, and a ``Synth`` instance is stored on ``self.synth``. The
        synthesis runs only when no ``z*`` entries exist yet in the synth
        directory, or when ``force`` is true.

        Parameters
        ----------
        planes : list, optional
            CMD planes to synthesize; defaults to ``self.planes.all_planes``.
        force : bool
            Run the synthesis even if earlier output is present.
        n_cpu : int
            Processes passed to ``Synth.run_synth`` (default 4, the value
            that used to be hard-coded).
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)

        # Use PHAT AST from the outer field (field 0)
        crowd_path = os.path.join(self.synth_dir, "crowding.dat")
        full_crowd_path = os.path.join(STARFISH, crowd_path)
        tbl = PhatAstTable()
        tbl.write_crowdfile_for_field(full_crowd_path, 0, bands=PHAT_BANDS)
        crowd = ExtantCrowdingTable(crowd_path)

        # No extinction, yet
        young_av = ExtinctionDistribution()
        old_av = ExtinctionDistribution()
        rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
        for av in (young_av, old_av):
            av.set_uniform(0.)

        if planes is None:
            planes = self.planes.all_planes

        self.synth = Synth(self.synth_dir,
                           self.builder,
                           self.lockfile,
                           crowd,
                           rel_extinction,
                           young_extinction=young_av,
                           old_extinction=old_av,
                           planes=planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)
        # Any z* entry means a previous run left output in the directory.
        has_existing_synth = len(glob(os.path.join(full_synth_dir, "z*"))) > 0
        if force or not has_existing_synth:
            self.synth.run_synth(n_cpu=n_cpu, clean=False)

    def fit_planes(self, key, color_planes, phot_colors, redo=False):
        """Write observed photometry for each plane and run an SFH fit.

        Parameters
        ----------
        key : str
            Fit name; also the sub-directory of ``root_dir`` used for output.
        color_planes : sequence
            CMD planes to fit.
        phot_colors : sequence of (band1, band2)
            Catalog bands written for each plane, paired with
            ``color_planes`` in order.
        redo : bool
            Re-run the fit even if its output file already exists.
        """
        fit_dir = os.path.join(self.root_dir, key)
        data_root = os.path.join(fit_dir, "phot.")
        for plane, (band1, band2) in zip(color_planes, phot_colors):
            self.catalog.write(band1, band2, data_root, plane.suffix)
        sfh = SFH(data_root, self.synth, fit_dir, planes=color_planes)
        if (not os.path.exists(sfh.full_outfile_path)) or redo:
            sfh.run_sfh()
        self.fits[key] = sfh

    def show_isoc_phase_sim_hess(self, fig):
        """Plot simulated (top) and observed (bottom) optical and IR Hess plots.

        The bottom row overlays isochrone phases on observed contour Hess
        diagrams for F475W-F814W and F110W-F160W; a colorbar for the phases
        is drawn in the right-hand column.
        """
        opt_sim = self.planes.get_sim_hess(('f475w', 'f814w'), self.synth,
                                           self.lockfile)
        ir_sim = self.planes.get_sim_hess(('f110w', 'f160w'), self.synth,
                                          self.lockfile)
        opt_cmd = self.planes[('f475w', 'f814w')]
        ir_cmd = self.planes[('f110w', 'f160w')]

        gs = gridspec.GridSpec(2,
                               3,
                               wspace=0.4,
                               bottom=0.2,
                               width_ratios=[1., 1., 0.1])
        ax_opt = fig.add_subplot(gs[0, 0])
        ax_ir = fig.add_subplot(gs[0, 1])
        ax_obs_opt = fig.add_subplot(gs[1, 0])
        ax_obs_ir = fig.add_subplot(gs[1, 1])
        cb_ax = fig.add_subplot(gs[1, 2])

        plot_hess(ax_opt,
                  opt_sim.hess,
                  opt_cmd,
                  opt_sim.origin,
                  imshow_args=None)
        plot_hess(ax_ir, ir_sim.hess, ir_cmd, ir_sim.origin, imshow_args=None)

        c = self.catalog.data['f475w_vega'] - self.catalog.data['f814w_vega']
        contour_hess(ax_obs_opt,
                     c,
                     self.catalog.data['f814w_vega'],
                     opt_cmd.x_span,
                     opt_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_opt, 'F475W', 'F814W', show_cb=False)
        ax_obs_opt.set_xlabel(opt_cmd.x_label)
        ax_obs_opt.set_ylabel(opt_cmd.y_label)
        ax_obs_opt.set_xlim(opt_cmd.xlim)
        ax_obs_opt.set_ylim(opt_cmd.ylim)

        c = self.catalog.data['f110w_vega'] - self.catalog.data['f160w_vega']
        contour_hess(ax_obs_ir,
                     c,
                     self.catalog.data['f160w_vega'],
                     ir_cmd.x_span,
                     ir_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_ir,
                              'F110W',
                              'F160W',
                              show_cb=True,
                              cb_ax=cb_ax)
        ax_obs_ir.set_xlabel(ir_cmd.x_label)
        ax_obs_ir.set_ylabel(ir_cmd.y_label)
        ax_obs_ir.set_xlim(ir_cmd.xlim)
        ax_obs_ir.set_ylim(ir_cmd.ylim)
        fig.show()

    def plot_contour_hess(self, ax, bands, plane_key):
        """Contour-plot the observed catalog in plane ``plane_key``.

        Colour is ``bands[0] - bands[-1]`` and magnitude is ``bands[-1]``,
        both read from ``self.catalog.data``.
        """
        plane = self.planes[plane_key]
        c = self.catalog.data[bands[0]] - self.catalog.data[bands[-1]]
        contour_hess(ax,
                     c,
                     self.catalog.data[bands[-1]],
                     plane.x_span,
                     plane.y_span,
                     plot_args={'ms': 3})
        ax.set_xlabel(plane.x_label)
        ax.set_ylabel(plane.y_label)
        ax.set_xlim(*plane.xlim)
        ax.set_ylim(*plane.ylim)

    def plot_sim_hess(self, ax, plane_key):
        """Plot the simulated Hess diagram of plane ``plane_key`` into ``ax``."""
        plane = self.planes[plane_key]
        sim = self.planes.get_sim_hess(plane_key, self.synth, self.lockfile)
        plot_hess(ax, sim.hess, plane, sim.origin, imshow_args=None)

    def plot_obs_hess(self, arg1):
        """Placeholder: not implemented, does nothing."""
        pass

    def plot_fit_hess(self, arg1):
        """Placeholder: not implemented, does nothing."""
        pass

    def plot_predicted_hess(self, arg1):
        """Placeholder: not implemented, does nothing."""
        pass

    def plot_triptyk(self,
                     fig,
                     ax_obs,
                     ax_model,
                     ax_chi,
                     fit_key,
                     plane_key,
                     xtick=1.,
                     xfmt="%.0f"):
        """Draw observed, model and log chi^2 Hess panels for one fit/plane.

        Parameters
        ----------
        fig : Figure
            Figure holding the three axes.
        ax_obs, ax_model, ax_chi : Axes
            Axes for the observed, modelled and chi^2 panels.
        fit_key : str
            Key into ``self.fits``.
        plane_key : hashable
            Key into ``self.planes``.
        xtick : float
            Major tick spacing on the x axis.
        xfmt : str
            Format string for the major x tick labels.
        """
        def cmapper():
            # A fresh colormap per panel, shared by the obs and model plots.
            return cubehelix.cmap(startHue=240,
                                  endHue=-300,
                                  minSat=1,
                                  maxSat=2.5,
                                  minLight=.3,
                                  maxLight=.8,
                                  gamma=.9)

        fit = self.fits[fit_key]
        plane = self.planes[plane_key]
        ctp = ChiTriptykPlot(fit, plane)
        ctp.setup_axes(fig,
                       ax_obs=ax_obs,
                       ax_mod=ax_model,
                       ax_chi=ax_chi,
                       major_x=xtick,
                       major_x_fmt=xfmt)
        ctp.plot_obs_in_ax(ax_obs, cmap=cmapper())
        ctp.plot_mod_in_ax(ax_model, cmap=cmapper())
        ctp.plot_chi_in_ax(ax_chi, cmap=cubehelix.cmap())
        ax_obs.text(0.0,
                    1.01,
                    "Observed",
                    transform=ax_obs.transAxes,
                    size=8,
                    ha='left')
        ax_model.text(0.0,
                      1.01,
                      "Model",
                      transform=ax_model.transAxes,
                      size=8,
                      ha='left')
        ax_chi.text(0.0,
                    1.01,
                    r"$\log \chi^2$",
                    transform=ax_chi.transAxes,
                    size=8,
                    ha='left')

    def plot_isoc_grid_ages(self, ax, band1, band2, show_cb=False, cb_ax=None):
        """Plot a demo isochrone grid in ``band1 - band2`` vs ``band2``.

        Lines are coloured by log(age) on a 7.0-10.1 scale and shifted by
        the distance modulus of 785 kpc. Optionally a colorbar is drawn in
        ``cb_ax``.
        """
        isoc_set = get_demo_age_grid(isoc_kind='parsec_CAF09_v1.2S',
                                     photsys_version='yang')
        cmap = cubehelix.cmap(startHue=240,
                              endHue=-300,
                              minSat=1,
                              maxSat=2.5,
                              minLight=.3,
                              maxLight=.8,
                              gamma=.9)
        norm = mpl.colors.Normalize(vmin=7., vmax=10.1)
        scalar_map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        # The norm and the line colours are in log(age), so give the
        # mappable log(age) data too (it was previously linear age).
        scalar_map.set_array(np.log10(np.array([isoc.age
                                                for isoc in isoc_set])))

        d = Distance(785 * u.kpc)
        for isoc in isoc_set:
            ax.plot(isoc[band1] - isoc[band2],
                    isoc[band2] + d.distmod.value,
                    c=scalar_map.to_rgba(np.log10(isoc.age)),
                    lw=0.8)
        if show_cb:
            cb = plt.colorbar(mappable=scalar_map,
                              cax=cb_ax,
                              ax=ax,
                              ticks=np.arange(6., 10.2))
            cb.set_label(r"log(age)")

    def plot_isoc_grid_phases(self,
                              ax,
                              band1,
                              band2,
                              show_cb=False,
                              cb_ax=None):
        """Delegate to ``plot_isochrone_phases`` with the same arguments."""
        plot_isochrone_phases(ax, band1, band2, show_cb=show_cb, cb_ax=cb_ax)

    def show_lockfile(self,
                      fig,
                      logage_lim=(6.2, 10.2),
                      logzzsol_lim=(-0.2, 0.2)):
        """Plot isochrone (log age, log Z/Zsun) points with lock-box outlines.

        Parameters
        ----------
        fig : Figure
            Figure to draw into (a single subplot is added).
        logage_lim : tuple
            x-axis limits in log(age).
        logzzsol_lim : tuple
            y-axis limits in log(Z/Zsun).
        """
        ax = fig.add_subplot(111)
        plot_isochrone_logage_logzsol(ax, self.builder, c='k', s=8)
        plot_lock_polygons(ax, self.lockfile, facecolor='None', edgecolor='r')
        ax.set_xlim(*logage_lim)
        ax.set_ylim(*logzzsol_lim)
        ax.set_xlabel(r"$\log(A)$")
        ax.set_ylabel(r"$\log(Z/Z_\odot)$")
        fig.show()
Пример #5
0
class Pipeline(object):
    """Pipeline for Multi-CMD fitting and comparison.

    Constructing a ``Pipeline`` runs the StarFISH set-up stages in order:
    ``get_isochrones`` (isochrone export and library install),
    ``build_lockfile`` (age/metallicity binning) and ``run_synth``
    (synthetic CMDs). Fits are added afterwards with ``fit_planes`` and
    inspected with the plotting helpers.
    """

    def __init__(self, brick, root_dir, isoc_args=None, phases=None):
        """Set up directories and run the isochrone, lockfile and synth stages.

        Parameters
        ----------
        brick : object
            Brick identifier passed to ``Catalog``.
        root_dir : str
            Base directory of the ``isoc``, ``lib`` and ``synth``
            sub-directories; filesystem checks join it with ``STARFISH``.
        isoc_args : dict, optional
            Extra keyword arguments forwarded to every ``AgeGridRequest``.
        phases : sequence, optional
            If given, only isochrone rows whose ``stage`` is in ``phases``
            are exported.
        """
        super(Pipeline, self).__init__()
        self.brick = brick
        self.catalog = Catalog(brick)
        self.root_dir = root_dir
        self.isoc_dir = os.path.join(root_dir, 'isoc')
        self.lib_dir = os.path.join(root_dir, 'lib')
        self.synth_dir = os.path.join(root_dir, 'synth')

        # Metallicity grid; one pair of AgeGridRequests is made per value.
        self.z_grid = [0.015, 0.019, 0.024]

        self.get_isochrones(isoc_args=isoc_args, phases=phases)
        self.build_lockfile()
        self.planes = PhatPlanes()
        self.run_synth()

        # SFH fits keyed by name (see ``fit_planes``), and a lazily filled
        # cache of their solution tables (see ``get_solution_table``).
        self.fits = OrderedDict()

        self._solution_tables = {}

    def get_solution_table(self, key):
        """Return the solution table of fit ``key``, computing it only once.

        The table comes from ``self.fits[key].solution_table()`` on the first
        call and from ``self._solution_tables`` afterwards.
        """
        if key not in self._solution_tables:
            self._solution_tables[key] = self.fits[key].solution_table()
        # Always read from the cache: returning a local bound only on a miss
        # raised UnboundLocalError on every cache hit.
        return self._solution_tables[key]

    def get_isochrones(self, isoc_args=None, phases=None):
        """Export isochrones for StarFISH if needed, then set up the library.

        If ``STARFISH/<isoc_dir>`` does not exist, WFC3 and ACS isochrone
        grids (log age 6.6 to 10.13 in steps of 0.02) are requested for each
        metallicity in ``self.z_grid``. They are joined band-wise,
        optionally restricted to ``phases``, and exported in the PHAT bands.
        A ``LibraryBuilder`` at a distance of 785 kpc is then stored on
        ``self.builder`` and installed if its isofile is missing.

        Parameters
        ----------
        isoc_args : dict, optional
            Extra keyword arguments for every ``AgeGridRequest``.
        phases : sequence, optional
            ``stage`` values to keep; ``None`` keeps every row.
        """
        if isoc_args is None:
            isoc_args = {}
        full_isoc_dir = os.path.join(STARFISH, self.isoc_dir)
        if not os.path.exists(full_isoc_dir):
            for z in self.z_grid:
                r_wfc3 = AgeGridRequest(z,
                                        min_log_age=6.6,
                                        max_log_age=10.13,
                                        delta_log_age=0.02,
                                        photsys='wfc3_wide', **isoc_args)
                r_acs = AgeGridRequest(z,
                                       min_log_age=6.6,
                                       max_log_age=10.13,
                                       delta_log_age=0.02,
                                       photsys='acs_wfc', **isoc_args)
                isoc_set = join_isochrone_sets(r_wfc3.isochrone_set,
                                               r_acs.isochrone_set,
                                               left_bands=WFC3_BANDS,
                                               right_bands=ACS_BANDS)
                for isoc in isoc_set:
                    isoc = Isochrone(isoc)
                    isoc.rename_column('F275W1', 'F275W')
                    if phases is not None:
                        sels = [np.where(isoc['stage'] == p)[0]
                                for p in phases]
                        # np.unique sorts and de-duplicates the indices, so
                        # the kept rows stay in their original isochrone
                        # order instead of being regrouped phase by phase.
                        isoc = isoc[np.unique(np.concatenate(sels))]
                    isoc.export_for_starfish(full_isoc_dir,
                                             bands=PHAT_BANDS)

        d = Distance(785 * u.kpc)
        self.builder = LibraryBuilder(self.isoc_dir, self.lib_dir,
                                      nmag=len(PHAT_BANDS),
                                      dmod=d.distmod.value,
                                      iverb=3)
        if not os.path.exists(self.builder.full_isofile_path):
            self.builder.install()

    def build_lockfile(self):
        """Create the synth directory and bin isochrones into lock boxes.

        Young populations use 9 bins on a uniform log(age) grid from 6.5 to
        8.95. Old populations use 1 Gyr bins between 1 and 13 Gyr, each with
        its lower edge lowered by 0.05 Gyr. Every box spans Z in
        [0.014, 0.025].
        """
        if not os.path.exists(os.path.join(STARFISH, self.synth_dir)):
            os.makedirs(os.path.join(STARFISH, self.synth_dir))

        self.lockfile = Lockfile(self.builder.read_isofile(), self.synth_dir,
                                 unbinned=False)

        z_str = "0019"
        z_span = (0.014, 0.025)

        def box_name(logage0, logage1):
            # Name encodes the metallicity tag and the bin's mean log(age).
            # The old code divided by 0.2 (10x the mean, e.g. "66.36" rather
            # than "06.64", overflowing the 05.2f width for old bins). Note:
            # this renames the boxes, so synth output created with the old
            # names needs regenerating (run_synth(force=True)).
            mean_logage = (logage0 + logage1) / 2.
            return "z{0}_{1:05.2f}".format(z_str, mean_logage)

        # Bin young isochrones
        young_grid = np.linspace(6.5, 8.95, 10)
        for logage0, logage1 in zip(young_grid[:-1], young_grid[1:]):
            self.lockfile.lock_box(box_name(logage0, logage1),
                                   (logage0, logage1), z_span)

        # Bin old isochrones
        old_grid = np.arange(1e9, 14 * 1e9, 1e9)
        for age0, age1 in zip(old_grid[:-1], old_grid[1:]):
            logage0 = np.log10(age0 - 0.05 * 1e9)
            logage1 = np.log10(age1)
            self.lockfile.lock_box(box_name(logage0, logage1),
                                   (logage0, logage1), z_span)

    def run_synth(self, planes=None, force=False, n_cpu=4):
        """Write the crowding table, configure ``Synth`` and run it if needed.

        The PHAT artificial-star table for field 0 is written to
        ``crowding.dat`` in the synth directory, a zero-extinction model is
        set up, and a ``Synth`` instance is stored on ``self.synth``. The
        synthesis runs only when no ``z*`` entries exist yet in the synth
        directory, or when ``force`` is true.

        Parameters
        ----------
        planes : list, optional
            CMD planes to synthesize; defaults to ``self.planes.all_planes``.
        force : bool
            Run the synthesis even if earlier output is present.
        n_cpu : int
            Processes passed to ``Synth.run_synth`` (default 4, the value
            that used to be hard-coded).
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)

        # Use PHAT AST from the outer field (field 0)
        crowd_path = os.path.join(self.synth_dir, "crowding.dat")
        full_crowd_path = os.path.join(STARFISH, crowd_path)
        tbl = PhatAstTable()
        tbl.write_crowdfile_for_field(full_crowd_path, 0,
                                      bands=PHAT_BANDS)
        crowd = ExtantCrowdingTable(crowd_path)

        # No extinction, yet
        young_av = ExtinctionDistribution()
        old_av = ExtinctionDistribution()
        rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
        for av in (young_av, old_av):
            av.set_uniform(0.)

        if planes is None:
            planes = self.planes.all_planes

        self.synth = Synth(self.synth_dir, self.builder, self.lockfile, crowd,
                           rel_extinction,
                           young_extinction=young_av,
                           old_extinction=old_av,
                           planes=planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)
        # Any z* entry means a previous run left output in the directory.
        has_existing_synth = len(glob(os.path.join(full_synth_dir, "z*"))) > 0
        if force or not has_existing_synth:
            self.synth.run_synth(n_cpu=n_cpu, clean=False)

    def fit_planes(self, key, color_planes, phot_colors, redo=False):
        """Write observed photometry for each plane and run an SFH fit.

        Parameters
        ----------
        key : str
            Fit name; also the sub-directory of ``root_dir`` used for output.
        color_planes : sequence
            CMD planes to fit.
        phot_colors : sequence of (band1, band2)
            Catalog bands written for each plane, paired with
            ``color_planes`` in order.
        redo : bool
            Re-run the fit even if its output file already exists.
        """
        fit_dir = os.path.join(self.root_dir, key)
        data_root = os.path.join(fit_dir, "phot.")
        for plane, (band1, band2) in zip(color_planes, phot_colors):
            self.catalog.write(band1, band2, data_root, plane.suffix)
        sfh = SFH(data_root, self.synth, fit_dir, planes=color_planes)
        if (not os.path.exists(sfh.full_outfile_path)) or redo:
            sfh.run_sfh()
        self.fits[key] = sfh

    def show_isoc_phase_sim_hess(self, fig):
        """Plot simulated (top) and observed (bottom) optical and IR Hess plots.

        The bottom row overlays isochrone phases on observed contour Hess
        diagrams for F475W-F814W and F110W-F160W; a colorbar for the phases
        is drawn in the right-hand column.
        """
        opt_sim = self.planes.get_sim_hess(('f475w', 'f814w'),
                                           self.synth, self.lockfile)
        ir_sim = self.planes.get_sim_hess(('f110w', 'f160w'),
                                          self.synth, self.lockfile)
        opt_cmd = self.planes[('f475w', 'f814w')]
        ir_cmd = self.planes[('f110w', 'f160w')]

        gs = gridspec.GridSpec(2, 3, wspace=0.4, bottom=0.2,
                               width_ratios=[1., 1., 0.1])
        ax_opt = fig.add_subplot(gs[0, 0])
        ax_ir = fig.add_subplot(gs[0, 1])
        ax_obs_opt = fig.add_subplot(gs[1, 0])
        ax_obs_ir = fig.add_subplot(gs[1, 1])
        cb_ax = fig.add_subplot(gs[1, 2])

        plot_hess(ax_opt, opt_sim.hess, opt_cmd, opt_sim.origin,
                  imshow_args=None)
        plot_hess(ax_ir, ir_sim.hess, ir_cmd, ir_sim.origin,
                  imshow_args=None)

        c = self.catalog.data['f475w_vega'] - self.catalog.data['f814w_vega']
        contour_hess(ax_obs_opt, c, self.catalog.data['f814w_vega'],
                     opt_cmd.x_span, opt_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_opt, 'F475W', 'F814W', show_cb=False)
        ax_obs_opt.set_xlabel(opt_cmd.x_label)
        ax_obs_opt.set_ylabel(opt_cmd.y_label)
        ax_obs_opt.set_xlim(opt_cmd.xlim)
        ax_obs_opt.set_ylim(opt_cmd.ylim)

        c = self.catalog.data['f110w_vega'] - self.catalog.data['f160w_vega']
        contour_hess(ax_obs_ir, c, self.catalog.data['f160w_vega'],
                     ir_cmd.x_span, ir_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_ir, 'F110W', 'F160W', show_cb=True,
                              cb_ax=cb_ax)
        ax_obs_ir.set_xlabel(ir_cmd.x_label)
        ax_obs_ir.set_ylabel(ir_cmd.y_label)
        ax_obs_ir.set_xlim(ir_cmd.xlim)
        ax_obs_ir.set_ylim(ir_cmd.ylim)
        fig.show()

    def plot_contour_hess(self, ax, bands, plane_key):
        """Contour-plot the observed catalog in plane ``plane_key``.

        Colour is ``bands[0] - bands[-1]`` and magnitude is ``bands[-1]``,
        both read from ``self.catalog.data``.
        """
        plane = self.planes[plane_key]
        c = self.catalog.data[bands[0]] - self.catalog.data[bands[-1]]
        contour_hess(ax, c, self.catalog.data[bands[-1]],
                     plane.x_span, plane.y_span,
                     plot_args={'ms': 3})
        ax.set_xlabel(plane.x_label)
        ax.set_ylabel(plane.y_label)
        ax.set_xlim(*plane.xlim)
        ax.set_ylim(*plane.ylim)

    def plot_sim_hess(self, ax, plane_key):
        """Plot the simulated Hess diagram of plane ``plane_key`` into ``ax``."""
        plane = self.planes[plane_key]
        sim = self.planes.get_sim_hess(plane_key, self.synth, self.lockfile)
        plot_hess(ax, sim.hess, plane, sim.origin,
                  imshow_args=None)

    def plot_obs_hess(self, arg1):
        """Placeholder: not implemented, does nothing."""
        pass

    def plot_fit_hess(self, arg1):
        """Placeholder: not implemented, does nothing."""
        pass

    def plot_predicted_hess(self, arg1):
        """Placeholder: not implemented, does nothing."""
        pass

    def plot_triptyk(self, fig, ax_obs, ax_model, ax_chi, fit_key, plane_key,
                     xtick=1., xfmt="%.0f"):
        """Draw observed, model and log chi^2 Hess panels for one fit/plane.

        Parameters
        ----------
        fig : Figure
            Figure holding the three axes.
        ax_obs, ax_model, ax_chi : Axes
            Axes for the observed, modelled and chi^2 panels.
        fit_key : str
            Key into ``self.fits``.
        plane_key : hashable
            Key into ``self.planes``.
        xtick : float
            Major tick spacing on the x axis.
        xfmt : str
            Format string for the major x tick labels.
        """
        def cmapper():
            # A fresh colormap per panel, shared by the obs and model plots.
            return cubehelix.cmap(startHue=240, endHue=-300, minSat=1,
                                  maxSat=2.5, minLight=.3,
                                  maxLight=.8, gamma=.9)

        fit = self.fits[fit_key]
        plane = self.planes[plane_key]
        ctp = ChiTriptykPlot(fit, plane)
        ctp.setup_axes(fig, ax_obs=ax_obs, ax_mod=ax_model, ax_chi=ax_chi,
                       major_x=xtick, major_x_fmt=xfmt)
        ctp.plot_obs_in_ax(ax_obs, cmap=cmapper())
        ctp.plot_mod_in_ax(ax_model, cmap=cmapper())
        ctp.plot_chi_in_ax(ax_chi, cmap=cubehelix.cmap())
        ax_obs.text(0.0, 1.01, "Observed",
                    transform=ax_obs.transAxes, size=8, ha='left')
        ax_model.text(0.0, 1.01, "Model",
                      transform=ax_model.transAxes, size=8, ha='left')
        ax_chi.text(0.0, 1.01, r"$\log \chi^2$",
                    transform=ax_chi.transAxes, size=8, ha='left')

    def plot_isoc_grid_ages(self, ax, band1, band2,
                            show_cb=False, cb_ax=None):
        """Plot a demo isochrone grid in ``band1 - band2`` vs ``band2``.

        Lines are coloured by log(age) on a 7.0-10.1 scale and shifted by
        the distance modulus of 785 kpc. Optionally a colorbar is drawn in
        ``cb_ax``.
        """
        isoc_set = get_demo_age_grid(isoc_kind='parsec_CAF09_v1.2S',
                                     photsys_version='yang')
        cmap = cubehelix.cmap(startHue=240, endHue=-300,
                              minSat=1, maxSat=2.5, minLight=.3,
                              maxLight=.8, gamma=.9)
        norm = mpl.colors.Normalize(vmin=7., vmax=10.1)
        scalar_map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        # The norm and the line colours are in log(age), so give the
        # mappable log(age) data too (it was previously linear age).
        scalar_map.set_array(np.log10(np.array([isoc.age
                                                for isoc in isoc_set])))

        d = Distance(785 * u.kpc)
        for isoc in isoc_set:
            ax.plot(isoc[band1] - isoc[band2],
                    isoc[band2] + d.distmod.value,
                    c=scalar_map.to_rgba(np.log10(isoc.age)),
                    lw=0.8)
        if show_cb:
            cb = plt.colorbar(mappable=scalar_map,
                              cax=cb_ax, ax=ax,
                              ticks=np.arange(6., 10.2))
            cb.set_label(r"log(age)")

    def plot_isoc_grid_phases(self, ax, band1, band2,
                              show_cb=False, cb_ax=None):
        """Delegate to ``plot_isochrone_phases`` with the same arguments."""
        plot_isochrone_phases(ax, band1, band2,
                              show_cb=show_cb, cb_ax=cb_ax)

    def show_lockfile(self, fig, logage_lim=(6.2, 10.2),
                      logzzsol_lim=(-0.2, 0.2)):
        """Plot isochrone (log age, log Z/Zsun) points with lock-box outlines.

        Parameters
        ----------
        fig : Figure
            Figure to draw into (a single subplot is added).
        logage_lim : tuple
            x-axis limits in log(age).
        logzzsol_lim : tuple
            y-axis limits in log(Z/Zsun).
        """
        ax = fig.add_subplot(111)
        plot_isochrone_logage_logzsol(ax, self.builder, c='k', s=8)
        plot_lock_polygons(ax, self.lockfile, facecolor='None', edgecolor='r')
        ax.set_xlim(*logage_lim)
        ax.set_ylim(*logzzsol_lim)
        ax.set_xlabel(r"$\log(A)$")
        ax.set_ylabel(r"$\log(Z/Z_\odot)$")
        fig.show()
Пример #6
0
class PipelineBase(object):
    """Abstract baseclass for running StarFISH pipelines."""
    __metaclass__ = abc.ABCMeta

    def __init__(self, **kwargs):
        """Configure product directories and run the whole StarFISH pipeline.

        Keyword Arguments
        -----------------
        root_dir : str
            Required. Pipeline root, relative to ``STARFISH``. The ``isoc``,
            ``lib`` and ``synth`` product directories live under it and are
            created (under ``STARFISH``) if missing.
        n_synth_cpu : int
            Number of CPUs for ``run_synth`` (default 1). Ignored when
            ``synth_config_only`` is True, in which case 1 is used.
        synth_config_only : bool
            If True, ``run_synth`` only configures ``self.synth`` without
            running it (default False).

        Any other keyword arguments are printed as "Uncaught arguments" and
        otherwise ignored.

        The stages ``setup_isochrones``, ``build_lockfile``,
        ``build_crowding``, ``build_extinction``, ``mask_planes`` and
        ``run_synth`` are then called in that order. Presumably the first four
        are supplied by subclasses or mixins -- confirm per pipeline.

        Raises
        ------
        KeyError
            If ``root_dir`` is not supplied.
        """
        self.root_dir = kwargs.pop('root_dir')
        self.n_synth_cpu = kwargs.pop('n_synth_cpu', 1)
        self.synth_config_only = kwargs.pop('synth_config_only', False)

        # StarFISH product directories (relative to STARFISH)
        self.isoc_dir = os.path.join(self.root_dir, 'isoc')
        self.lib_dir = os.path.join(self.root_dir, 'lib')
        self.synth_dir = os.path.join(self.root_dir, 'synth')

        # result caches
        self.fits = OrderedDict()
        self._solution_tables = {}

        # Single-argument print(...) calls emit the same text under the
        # Python 2 print statement and the Python 3 print function.
        print("PipelineBase %s" % (kwargs,))
        if kwargs:
            print("Uncaught arguments: %s" % (kwargs,))
        super(PipelineBase, self).__init__()

        for d in (self.isoc_dir, self.lib_dir, self.synth_dir):
            full_dir = os.path.join(STARFISH, d)
            if not os.path.exists(full_dir):
                os.makedirs(full_dir)

        print("%s %s %s" % (self.isoc_dir, self.lib_dir, self.synth_dir))

        # Now run the pipeline
        self.setup_isochrones()
        print("self.builder.full_isofile_path %s"
              % (self.builder.full_isofile_path,))
        self.build_lockfile()
        self.build_crowding()
        self.build_extinction()
        self.mask_planes()  # mask planes based on completeness cuts
        # Configuration-only synth runs always use a single CPU.
        n_cpu = 1 if self.synth_config_only else self.n_synth_cpu
        self.run_synth(n_cpu=n_cpu, config_only=self.synth_config_only)

    def run_synth(self, n_cpu=1, config_only=False):
        """Configure ``self.synth`` and run it unless outputs already exist.

        A ``Synth`` is always built from the pipeline's builder, lockfile,
        crowding table, bands, extinction settings and planes, and stored on
        ``self.synth``. ``Synth.run_synth`` is only invoked when
        ``config_only`` is False and no ``z*`` files are present in the synth
        directory under ``STARFISH``.

        Parameters
        ----------
        n_cpu : int
            Number of CPUs passed to ``Synth.run_synth``.
        config_only : bool
            If True, stop after configuring ``self.synth``.
        """
        synth_glob = os.path.join(STARFISH, self.synth_dir, "z*")

        self.synth = Synth(self.synth_dir, self.builder, self.lockfile,
                           self.crowd, self.bands, self.rel_extinction,
                           young_extinction=self.young_av,
                           old_extinction=self.old_av,
                           planes=self.all_planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)
        if config_only:
            return
        if glob(synth_glob):
            # Synthetic products already on disk; reuse them.
            return
        self.synth.run_synth(n_cpu=n_cpu, clean=False)

    @abc.abstractmethod
    def mask_planes(self):
        pass

    @property
    def hold_template(self):
        return self.lockfile.empty_hold

    def fit(self, fit_key, plane_keys, dataset, fit_dir=None,
            redo=False, hold=None):
        """Set up (and, if needed, run) an SFH fit and cache it as ``fit_key``.

        For every key in ``plane_keys`` the plane is looked up in
        ``self.planes`` and ``dataset.write_phot(x_mag, y_mag, root, suffix)``
        writes its photometry under ``<fit_dir>/phot.``. An ``SFH`` is then
        built against ``self.synth`` and ``run_sfh(hold=hold)`` is called only
        when its output file is missing or ``redo`` is True. The ``SFH`` is
        stored in ``self.fits[fit_key]``.

        Parameters
        ----------
        fit_key : str
            Cache key; also the default subdirectory of ``root_dir``.
        plane_keys : iterable
            Keys into ``self.planes``.
        dataset
            Photometry source providing ``write_phot``.
        fit_dir : str, optional
            Output directory; defaults to ``root_dir/fit_key``.
        redo : bool
            Re-run the fit even if its output exists.
        hold : optional
            Forwarded to ``SFH.run_sfh``.
        """
        out_dir = fit_dir
        if out_dir is None:
            out_dir = os.path.join(self.root_dir, fit_key)
        phot_root = os.path.join(out_dir, "phot.")

        selected = []
        for key in plane_keys:
            cmd_plane = self.planes[key]
            selected.append(cmd_plane)
            dataset.write_phot(cmd_plane.x_mag, cmd_plane.y_mag,
                               phot_root, cmd_plane.suffix)

        sfh = SFH(phot_root, self.synth, out_dir, planes=selected)
        output_missing = not os.path.exists(sfh.full_outfile_path)
        if output_missing or redo:
            sfh.run_sfh(hold=hold)
        self.fits[fit_key] = sfh

    def make_fit_diff_hess(self, dataset, fit_key, plane_key):
        """Return a ``Hess`` of observed minus fitted counts for one plane."""
        observed = self.make_obs_hess(dataset, plane_key)
        modelled = self.make_fit_hess(fit_key, plane_key)
        residual = observed.hess - modelled.hess
        return Hess(residual, self.planes[plane_key])

    def make_chisq_hess(self, dataset, fit_key, plane_key):
        """Return a ``Hess`` of per-pixel chi-square ``((obs - fit) / sigma)**2``.

        ``sigma`` is ``sqrt(obs)``, so pixels with zero observed counts give
        non-finite values (inf or nan); ``compute_fit_chi`` discards
        non-finite pixels with ``np.isfinite``.
        """
        observed = self.make_obs_hess(dataset, plane_key).hess
        modelled = self.make_fit_hess(fit_key, plane_key).hess
        scaled_residual = (observed - modelled) / np.sqrt(observed)
        return Hess(scaled_residual ** 2., self.planes[plane_key])

    def compute_fit_chi(self, dataset, fit_key, plane_key, chi_hess=None):
        """Compute the reduced chi-sq for the plane with the given fit.

        Returns both the sum of chi-sq and the total number of pixels in
        in the plane that were not masked.
        """
        if chi_hess is None:
            chi_hess = self.make_chisq_hess(dataset, fit_key, plane_key)
        g = np.where(np.isfinite(chi_hess.masked_hess))
        n_pix = len(g[0])
        chi_sum = chi_hess.masked_hess[g].sum()
        n_amp = len(self.lockfile.active_groups)
        return chi_sum / (n_pix - n_amp)

    def make_sim_hess(self, plane_key):
        return self.get_sim_hess(plane_key)

    def make_fit_hess(self, fit_key, plane_key):
        """Build a ``SimHess`` for ``plane_key`` from the cached fit ``fit_key``."""
        cmd_plane = self.planes[plane_key]
        solution = self.fits[fit_key]
        return SimHess.from_sfh_solution(solution, cmd_plane)

    def make_obs_hess(self, dataset, plane_key):
        """Bin ``dataset`` photometry into a ``StarCatalogHess`` for one plane.

        ``dataset.get_phot`` is called for the plane's x then y magnitude.
        """
        cmd_plane = self.planes[plane_key]
        x_phot, y_phot = [dataset.get_phot(mag)
                          for mag in (cmd_plane.x_mag, cmd_plane.y_mag)]
        return StarCatalogHess(x_phot, y_phot, cmd_plane)

    def init_plane_axes(self, ax, plane_key):
        """Prepare ``ax`` for Hess plots of ``plane_key`` (origin 'lower')."""
        setup_hess_axes(ax, self.planes[plane_key], 'lower')

    def plot_sim_hess(self, ax, plane_key, imshow=None):
        """Plot the simulated Hess for ``plane_key`` into ``ax``.

        ``imshow`` is forwarded to ``plot_hess`` as ``imshow_args``; the
        result of ``plot_hess`` is returned.
        """
        cmd_plane = self.planes[plane_key]
        simulated = self.get_sim_hess(plane_key)
        return plot_hess(ax, simulated.hess, cmd_plane, simulated.origin,
                         imshow_args=imshow)

    def plot_fit_hess(self, ax, fit_key, plane_key, imshow=None):
        """Plot the model Hess of fit ``fit_key`` for ``plane_key`` into ``ax``.

        The model Hess is built with ``make_fit_hess``, so plotting and
        analysis use the same construction. ``imshow`` is forwarded to
        ``plot_hess`` as ``imshow_args``; the result of ``plot_hess`` is
        returned.
        """
        plane = self.planes[plane_key]
        fit_hess = self.make_fit_hess(fit_key, plane_key)
        return plot_hess(ax, fit_hess.hess, plane, fit_hess.origin,
                         imshow_args=imshow)

    def plot_obs_hess(self, ax, dataset, plane_key, imshow=None):
        """Plot the observed Hess of ``dataset`` for ``plane_key`` into ``ax``.

        The observed Hess is built with ``make_obs_hess`` rather than
        duplicating its binning logic. ``imshow`` is forwarded to
        ``plot_hess`` as ``imshow_args``; the result of ``plot_hess`` is
        returned.
        """
        plane = self.planes[plane_key]
        obs_hess = self.make_obs_hess(dataset, plane_key)
        return plot_hess(ax, obs_hess.hess, plane, obs_hess.origin,
                         imshow_args=imshow)

    def plot_hess_array(self, ax, hess, plane_key, imshow=None, log=True):
        """Plot an arbitrary Hess array on the grid of ``plane_key``.

        The plane's own ``origin`` is used; ``imshow`` and ``log`` are
        forwarded to ``plot_hess`` and its result is returned.
        """
        cmd_plane = self.planes[plane_key]
        origin = cmd_plane.origin
        return plot_hess(ax, hess, cmd_plane, origin,
                         imshow_args=imshow, log=log)

    def plot_lockfile(self, ax,
                      logage_lim=(6.2, 10.2),
                      logzzsol_lim=(-0.2, 0.2)):
        """Draw the isochrone grid and lockfile polygons into ``ax``.

        The (log age, log Z/Zsun) grid of ``self.builder`` is drawn in black
        and the polygons of ``self.lockfile`` are outlined in red; axis limits
        come from ``logage_lim`` and ``logzzsol_lim``.
        """
        plot_isochrone_logage_logzsol(ax, self.builder, c='k', s=8)
        plot_lock_polygons(ax, self.lockfile,
                           facecolor='None', edgecolor='r')

        ax.set_xlim(*logage_lim)
        ax.set_xlabel(r"$\log(A)$")
        ax.set_ylim(*logzzsol_lim)
        ax.set_ylabel(r"$\log(Z/Z_\odot)$")

    def plot_triptyk(self, fig, ax_obs, ax_model, ax_chi, fit_key, plane_key,
                     xtick=1., xfmt="%.0f"):
        """Draw observed, model and log-chi^2 Hess panels for a fit.

        Uses ``ChiTriptykPlot`` for fit ``fit_key`` in plane ``plane_key``.
        All three panels share the ``perceptual_rainbow_16`` colormap and each
        gets a small left-aligned title just above its top edge.

        Parameters
        ----------
        fig : matplotlib figure
        ax_obs, ax_model, ax_chi : matplotlib axes
            Panels for the observed, model and chi^2 Hess diagrams.
        fit_key, plane_key : hashable
            Keys into ``self.fits`` and ``self.planes``.
        xtick : float
            Major x tick spacing.
        xfmt : str
            Major x tick label format.
        """
        triptyk = ChiTriptykPlot(self.fits[fit_key], self.planes[plane_key])
        triptyk.setup_axes(fig, ax_obs=ax_obs, ax_mod=ax_model, ax_chi=ax_chi,
                           major_x=xtick, major_x_fmt=xfmt)
        palette = perceptual_rainbow_16.mpl_colormap
        triptyk.plot_obs_in_ax(ax_obs, cmap=palette)
        triptyk.plot_mod_in_ax(ax_model, cmap=palette)
        triptyk.plot_chi_in_ax(ax_chi, cmap=palette)

        panel_titles = ((ax_obs, "Observed"),
                        (ax_model, "Model"),
                        (ax_chi, r"$\log \chi^2$"))
        for panel, title in panel_titles:
            panel.text(0.0, 1.01, title,
                       transform=panel.transAxes, size=8, ha='left')

    def plot_linear_sfh_circles(self, ax, fit_key, ylim=(-0.2, 0.2),
                                amp_key='sfr'):
        """Plot the SFH solution of ``fit_key`` as circles on a linear age axis.

        Circle sizes come from the solution-table column ``amp_key``. The
        y-axis major tick labels are hidden and the y range is set to
        ``ylim``.
        """
        solution = self.fits[fit_key].solution_table()
        circles = LinearSFHCirclePlot(solution)
        circles.plot_in_ax(ax, max_area=800, amp_key=amp_key)
        for label in ax.get_ymajorticklabels():
            label.set_visible(False)
        ax.set_ylim(*ylim)

    def plot_log_sfh_circles(self, ax, fit_key, ylim=(-0.2, 0.2),
                             amp_key='sfr'):
        sfh = self.fits[fit_key]
        cp = SFHCirclePlot(sfh.solution_table())
        cp.plot_in_ax(ax, max_area=800, amp_key=amp_key)
        for logage in np.log10(np.arange(1, 13, 1) * 1e9):
            ax.axvline(logage, c='0.8', zorder=-1)
        ax.set_ylim(*ylim)