コード例 #1
0
ファイル: rewinder2_figures.py プロジェクト: adrn/streams
def simulated_streams():
    """Plot the simulated stream particles for each progenitor mass.

    Draws a 2x4 grid with one column per mass (2.5e6 through 2.5e9): the
    x-y projection in the top row and the x-z projection in the bottom row.
    For the 2.5e8 run only, the "true" particles selected by the
    ``particle_idx`` entry of the ``exp2.yml`` config are over-plotted as
    '+' markers. The figure is written to ``plot_path`` as
    ``simulated_streams.<ext>``.
    """

    filename = os.path.join(plot_path, "simulated_streams.{}".format(ext))
    fig,axes = plt.subplots(2,4,figsize=grid_figsize,
                            sharex=True, sharey=True)

    ticks = [-100,-50,0,50]
    alphas = [0.2, 0.27, 0.34, 0.4]
    rcparams = {'lines.linestyle' : 'none',
                'lines.marker' : ','}

    # the config file name does not depend on the mass
    cfg_filename = os.path.join(streamspath, "config", "exp2.yml")

    with rc_context(rc=rcparams):
        for ii,(_m,alpha) in enumerate(zip(range(6,9+1), alphas)):
            mass = "2.5e{}".format(_m)
            print(mass)

            sgr = SgrSimulation(sgr_path.format(_m),snapfile)
            p = sgr.particles()

            # column title, placed just above the top panel
            axes[0,ii].text(0.5, 1.05, r"$2.5\times10^{}M_\odot$".format(_m),
                   horizontalalignment='center',
                   fontsize=24,
                   transform=axes[0,ii].transAxes)

            axes[0,ii].plot(p["x"].value, p["y"].value,
                            alpha=alpha, rasterized=True, color='#555555')
            axes[1,ii].plot(p["x"].value, p["z"].value,
                            alpha=alpha, rasterized=True, color='#555555')

            if _m == 8:
                # the observed-particle file and config are only needed for
                # this mass, so only read them here
                data_filename = os.path.join(streamspath, "data",
                                             "observed_particles",
                                             "2.5e{}.hdf5".format(_m))
                data = read_hdf5(data_filename)
                true_particles = data["true_particles"].to_frame(galactocentric)
                idx = read_config(cfg_filename)['particle_idx']

                axes[0,ii].plot(true_particles["x"].value[idx],
                             true_particles["y"].value[idx],
                             marker='+', markeredgewidth=1.5,
                             markersize=8, alpha=0.9, color='k')
                axes[1,ii].plot(true_particles["x"].value[idx],
                             true_particles["z"].value[idx],
                             marker='+', markeredgewidth=1.5,
                             markersize=8, alpha=0.9, color='k')
            axes[1,ii].set_xticks(ticks)
            axes[1,ii].set_xlabel("$X$ [kpc]")

    axes[0,0].set_ylabel("$Y$ [kpc]")
    axes[1,0].set_ylabel("$Z$ [kpc]")

    axes[0,0].set_yticks(ticks)
    axes[1,0].set_yticks(ticks)
    # axes are shared, so setting the limits on one panel sets them on all
    axes[-1,-1].set_xlim(-110,75)
    axes[-1,-1].set_ylim(-110,75)

    fig.tight_layout()
    fig.subplots_adjust(top=0.92, hspace=0.025, wspace=0.1)
    fig.savefig(filename, dpi=200)
コード例 #2
0
ファイル: talk_figures.py プロジェクト: adrn/streams
def q_p(**kwargs):
    """Plot particle offsets from the satellite against ``tub``, per mass.

    For each progenitor mass (2.5e6 through 2.5e9) the satellite and a sample
    of particles with ``tub != 0`` are integrated from ``sgr.t1`` to
    ``sgr.t2`` in a ``LawMajewski2010`` potential. Each particle's position
    and velocity offset from the satellite is taken at the time step closest
    to its ``tub`` value (presumably the unbinding time) and plotted against
    ``tub``, together with curves built from the tidal radius and the
    satellite's orbital speed. One two-panel figure per mass is saved to
    ``plot_path`` as ``q_p_<mass>.png``.

    Parameters
    ----------
    **kwargs
        Accepted for call compatibility; not used.
    """

    nparticles = 5000
    for _m in range(6,9+1):
        mass = "2.5e{}".format(_m)
        m = float(mass)
        print(mass)

        sgr = SgrSimulation(sgr_path.format(_m), snapfile)
        particles = sgr.particles(n=nparticles, expr="(tub!=0)")#" & (tub<400)")
        tub = particles.tub
        s = sgr.satellite()

        potential = LawMajewski2010()

        # row 0 is the satellite, the remaining rows are the particles
        X = np.vstack((s._X[...,:3], particles._X[...,:3].copy()))
        V = np.vstack((s._X[...,3:], particles._X[...,3:].copy()))
        integrator = LeapfrogIntegrator(potential._acceleration_at,
                                        np.array(X), np.array(V),
                                        args=(X.shape[0], np.zeros_like(X)))
        ts, rs, vs = integrator.run(t1=sgr.t1, t2=sgr.t2, dt=-1.)

        s_orbit = np.vstack((rs[:,0][:,np.newaxis].T, vs[:,0][:,np.newaxis].T)).T
        p_orbits = np.vstack((rs[:,1:].T, vs[:,1:].T)).T
        # index of the time step closest to each particle's tub
        t_idx = np.array([np.argmin(np.fabs(ts - t)) for t in tub])

        # phase-space coordinates of each particle, and of the satellite,
        # at that particle's time step
        p_x = np.array([p_orbits[jj,ii] for ii,jj in enumerate(t_idx)])
        s_x = np.array([s_orbit[jj,0] for jj in t_idx])

        r_tide = potential._tidal_radius(m, s_orbit[...,:3])#*0.69336
        s_R_orbit = np.sqrt(np.sum(s_orbit[...,:3]**2, axis=-1))
        q = np.sqrt(np.sum((p_x[:,:3] - s_x[:,:3])**2,axis=-1))

        f = r_tide / s_R_orbit
        s_V = np.sqrt(np.sum(s_orbit[...,3:]**2, axis=-1))
        vdisp = s_V * f / 1.4
        # relative speed; kept under its own name so the particle object
        # above is not shadowed
        dv = np.sqrt(np.sum((p_x[:,3:] - s_x[...,3:])**2,axis=-1))

        fig,axes = plt.subplots(2,1,figsize=(10,6),sharex=True)

        axes[0].plot(tub, q, marker='.', alpha=0.5, color='#666666')
        axes[0].plot(ts, r_tide*1.4, linewidth=2., alpha=0.8, color='k',
                     linestyle='-', marker=None)
        axes[0].set_ylim(0., max(r_tide)*4)

        axes[1].plot(tub, (dv*u.kpc/u.Myr).to(u.km/u.s).value,
                     marker='.', alpha=0.5, color='#666666')
        axes[1].plot(ts, (vdisp*u.kpc/u.Myr).to(u.km/u.s).value, color='k',
                     linewidth=2., alpha=0.75, linestyle='-', marker=None)

        M_enc = potential._enclosed_mass(s_R_orbit)
        delta_v2 = 4/3.*G.decompose(usys).value**2*(M_enc / s_V)**2*\
                        np.mean(r_tide**2)/s_R_orbit**4
        delta_v = (np.sqrt(2*delta_v2)*u.kpc/u.Myr).to(u.km/u.s).value

        axes[1].plot(ts, delta_v, linewidth=2., color='#2166AC',
                     alpha=0.75, linestyle='--', marker=None)

        axes[1].set_ylim(0., max((vdisp*u.kpc/u.Myr).to(u.km/u.s).value)*4)

        axes[0].set_xlim(min(ts), max(ts))

        fig.savefig(os.path.join(plot_path, "q_p_{}.png".format(mass)),
                    transparent=True)
        # one figure is created per mass; release it once it is on disk
        plt.close(fig)
コード例 #3
0
ファイル: rewinder2_figures.py プロジェクト: adrn/streams
def total_rv():
    """Plot particle distance and speed relative to the satellite over time.

    For each progenitor mass (2.5e6 through 2.5e9) the satellite and a sample
    of particles are integrated from ``sgr.t1`` to ``sgr.t2`` in
    ``sgr.potential``. Each particle's separation and relative speed are
    drawn over a window of time steps centred on the step closest to its
    ``tub`` value; the window half-width is half of a crossing-time estimate
    built from the tidal radius. Two four-panel figures (one panel per mass)
    are saved to ``plot_path`` as ``rel_r.png`` and ``rel_v.png``.
    """

    filenamer = os.path.join(plot_path, "rel_r.png")
    filenamev = os.path.join(plot_path, "rel_v.png")

    figr,axesr = plt.subplots(4,1,figsize=(10,14),
                              sharex=True)
    figv,axesv = plt.subplots(4,1,figsize=(10,14),
                              sharex=True)

    nparticles = 2000
    for k,_m in enumerate(range(6,9+1)):
        mass = "2.5e{}".format(_m)
        m = float(mass)
        print(mass)

        sgr = SgrSimulation(sgr_path.format(_m),snapfile)
        p = sgr.particles(n=nparticles, expr=expr)
        s = sgr.satellite()

        # row 0 is the satellite, the remaining rows are the particles
        X0 = np.vstack((s._X[...,:3], p._X[...,:3].copy()))
        V0 = np.vstack((s._X[...,3:], p._X[...,3:].copy()))
        integrator = LeapfrogIntegrator(sgr.potential._acceleration_at,
                                        np.array(X0), np.array(V0),
                                        args=(X0.shape[0], np.zeros_like(X0)))
        ts, rs, vs = integrator.run(t1=sgr.t1, t2=sgr.t2, dt=-1.)

        s_orbit = np.vstack((rs[:,0][:,np.newaxis].T, vs[:,0][:,np.newaxis].T)).T
        p_orbits = np.vstack((rs[:,1:].T, vs[:,1:].T)).T
        # index of the time step closest to each particle's tub
        t_idx = np.array([np.argmin(np.fabs(ts - t)) for t in p.tub])

        m_t = (-s.mdot*ts + s.m0)[:,np.newaxis]
        s_R = np.sqrt(np.sum(s_orbit[...,:3]**2, axis=-1))
        s_V = np.sqrt(np.sum(s_orbit[...,3:]**2, axis=-1))
        r_tide = sgr.potential._tidal_radius(m_t, s_orbit[...,:3])
        v_disp = s_V * r_tide / s_R

        # cartesian basis to project into
        x_hat = s_orbit[...,:3] / np.sqrt(np.sum(s_orbit[...,:3]**2, axis=-1))[...,np.newaxis]
        _y_hat = s_orbit[...,3:] / np.sqrt(np.sum(s_orbit[...,3:]**2, axis=-1))[...,np.newaxis]
        z_hat = np.cross(x_hat, _y_hat)
        y_hat = -np.cross(x_hat, z_hat)

        # translate to satellite position
        rel_orbits = p_orbits - s_orbit
        rel_pos = rel_orbits[...,:3]
        rel_vel = rel_orbits[...,3:]

        # project onto each
        dX = np.sum(rel_pos * x_hat, axis=-1)
        dY = np.sum(rel_pos * y_hat, axis=-1)
        dZ = np.sum(rel_pos * z_hat, axis=-1)
        RR = np.sqrt(dX**2 + dY**2 + dZ**2)

        VX = np.sum(rel_vel * x_hat, axis=-1)
        VY = np.sum(rel_vel * y_hat, axis=-1)
        VZ = np.sum(rel_vel * z_hat, axis=-1)
        VV = (np.sqrt(VX**2 + VY**2 + VZ**2)*u.kpc/u.Myr).to(u.km/u.s).value
        v_disp = (v_disp*u.kpc/u.Myr).to(u.km/u.s).value

        _tcross = r_tide / np.sqrt(G.decompose(usys).value*m/r_tide)
        nsteps = len(ts)
        for ii,jj in enumerate(t_idx):
            tcross = _tcross[jj]
            bnd = int(tcross / 2)

            # clamp the window to the valid index range: a negative start
            # would wrap around, and the end is bounded by the number of
            # time steps (an index, not a time)
            ix1 = max(jj - bnd, 0)
            ix2 = min(jj + bnd, nsteps)
            axesr[k].plot(ts[ix1:ix2],
                          RR[ix1:ix2,ii],
                          linestyle='-', alpha=0.1, marker=None, color='#555555', zorder=-1)

            axesv[k].plot(ts[ix1:ix2],
                          VV[ix1:ix2,ii],
                          linestyle='-', alpha=0.1, marker=None, color='#555555', zorder=-1)

        axesr[k].plot(ts, r_tide*2., marker=None)

        axesr[k].set_xlim(ts.min(), ts.max())
        axesv[k].set_xlim(ts.min(), ts.max())

        axesr[k].set_ylim(0,max(r_tide)*7)
        axesv[k].set_ylim(0,max(v_disp)*7)

        axesr[k].text(3000, max(r_tide)*5, r"$2.5\times10^{}M_\odot$".format(_m))
        axesv[k].text(3000, max(v_disp)*5, r"$2.5\times10^{}M_\odot$".format(_m))

    axesr[-1].set_xlabel("time [Myr]")
    axesv[-1].set_xlabel("time [Myr]")

    figr.suptitle("Relative distance", fontsize=26)
    figr.tight_layout()
    figr.subplots_adjust(top=0.92, hspace=0.025, wspace=0.1)
    figr.savefig(filenamer)

    figv.suptitle("Relative velocity", fontsize=26)
    figv.tight_layout()
    figv.subplots_adjust(top=0.92, hspace=0.025, wspace=0.1)
    figv.savefig(filenamev)
コード例 #4
0
ファイル: rewinder2_figures.py プロジェクト: adrn/streams
def Lpts():
    """Plot particle tracks around the satellite in scaled coordinates.

    For each progenitor mass (2.5e6 through 2.5e9), ``particles_x1x2x3``
    supplies the particle coordinates ``(x1, x2, x3, vx1, vx2, vx3)`` along
    with ``r_tide`` and ``v_disp``. Each particle is drawn over a window of
    time steps centred on the step closest to its ``tub`` value; positions
    are divided by ``r_tide`` and velocities by ``v_disp``, and a unit circle
    is added to every panel. Two 2x4 figures are saved to ``plot_path`` as
    ``Lpts_r.<ext>`` (positions) and ``Lpts_v.<ext>`` (velocities).
    """
    np.random.seed(42)

    filename = os.path.join(plot_path, "Lpts_r.{}".format(ext))
    filename2 = os.path.join(plot_path, "Lpts_v.{}".format(ext))

    fig,axes = plt.subplots(2,4,figsize=grid_figsize,
                            sharex=True, sharey=True)
    fig2,axes2 = plt.subplots(2,4,figsize=grid_figsize,
                              sharex=True, sharey=True)

    nparticles = 2000
    for k,_m in enumerate(range(6,9+1)):
        mass = "2.5e{}".format(_m)
        m = float(mass)
        print(mass)

        sgr = SgrSimulation(sgr_path.format(_m),snapfile)
        p = sgr.particles(n=nparticles, expr=expr)
        s = sgr.satellite()
        dt = -1.

        coord, r_tide, v_disp = particles_x1x2x3(p, s,
                                                 sgr.potential,
                                                 sgr.t1, sgr.t2, dt,
                                                 at_tub=False)
        (x1,x2,x3,vx1,vx2,vx3) = coord
        ts = np.arange(sgr.t1,sgr.t2+dt,dt)
        # index of the time step closest to each particle's tub
        t_idx = np.array([np.argmin(np.fabs(ts - t)) for t in p.tub])

        # artists with zorder below 1 (the tracks) get rasterized; this is a
        # per-axes setting, so set it once rather than once per particle
        for ax in (axes[0,k], axes[1,k], axes2[0,k], axes2[1,k]):
            ax.set_rasterization_zorder(1)

        _tcross = r_tide / np.sqrt(G.decompose(usys).value*m/r_tide)
        nsteps = len(ts)
        for ii,jj in enumerate(t_idx):
            tcross = _tcross[jj]
            bnd = int(tcross / 2)

            # clamp the window to the valid index range and actually use the
            # clamped bounds: a negative start index would wrap around
            ix1 = max(jj - bnd, 0)
            ix2 = min(jj + bnd, nsteps)
            window = slice(ix1, ix2)

            rt = r_tide[window,0]
            vd = v_disp[window,0]

            axes[0,k].plot(x1[window,ii]/rt,
                           x2[window,ii]/rt,
                           linestyle='-', alpha=0.1, marker=None, color='#555555',
                           zorder=-1)
            axes[1,k].plot(x1[window,ii]/rt,
                           x3[window,ii]/rt,
                           linestyle='-', alpha=0.1, marker=None, color='#555555',
                           zorder=-1)

            axes2[0,k].plot(vx1[window,ii]/vd,
                            vx2[window,ii]/vd,
                            linestyle='-', alpha=0.1, marker=None, color='#555555',
                            zorder=-1)
            axes2[1,k].plot(vx1[window,ii]/vd,
                            vx3[window,ii]/vd,
                            linestyle='-', alpha=0.1, marker=None, color='#555555',
                            zorder=-1)

        # unit circle and horizontal guide on every panel; a patch can only
        # belong to one axes, so a new Circle is made for each
        for ax in (axes[0,k], axes[1,k], axes2[0,k], axes2[1,k]):
            circ = Circle((0,0), radius=1., fill=False, alpha=0.75,
                          edgecolor='k', linestyle='solid')
            ax.add_patch(circ)
            ax.axhline(0., color='k', alpha=0.75)

        axes[0,k].set_xlim(-5,5)
        axes[0,k].set_ylim(axes[0,k].get_xlim())

        axes[1,k].set_xlabel(r"$x_1/r_{\rm tide}$")

        axes2[1,k].set_xlim(-5,5)
        axes2[1,k].set_ylim(axes2[1,k].get_xlim())

        axes2[1,k].set_xlabel(r"$v_{x_1}/\sigma_v$")

        if k == 0:
            axes[0,k].set_ylabel(r"$x_2/r_{\rm tide}$")
            axes[1,k].set_ylabel(r"$x_3/r_{\rm tide}$")
            axes2[0,k].set_ylabel(r"$v_{x_2}/\sigma_v$")
            axes2[1,k].set_ylabel(r"$v_{x_3}/\sigma_v$")

        # column titles, placed just above the top panels
        axes[0,k].text(0.5, 1.05, r"$2.5\times10^{}M_\odot$".format(_m),
                       horizontalalignment='center',
                       fontsize=24,
                       transform=axes[0,k].transAxes)

        axes2[0,k].text(0.5, 1.05, r"$2.5\times10^{}M_\odot$".format(_m),
                        horizontalalignment='center',
                        fontsize=24,
                        transform=axes2[0,k].transAxes)

    fig.tight_layout()
    fig.subplots_adjust(top=0.92, hspace=0.025, wspace=0.1)
    fig.savefig(filename)

    fig2.tight_layout()
    fig2.subplots_adjust(top=0.92, hspace=0.025, wspace=0.1)
    fig2.savefig(filename2)
コード例 #5
0
ファイル: infer_potential.py プロジェクト: adrn/streams
def main(config_file, mpi=False, threads=None, overwrite=False, continue_sampler=False):
    """Run (or continue) the stream inference and write summary plots.

    Reads a YAML configuration, builds a ``StreamModel`` from it and runs the
    sampler (optionally preceded by burn-in), caching the chains under
    ``<output_path>/cache``. Afterwards every cached ``inference_*.hdf5``
    file is read back and the plots requested in the ``plot`` section of the
    configuration are written to ``output_path``.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file.
    mpi : bool (optional)
        Passed through to ``get_pool``.
    threads : int (optional)
        Passed through to ``get_pool``.
    overwrite : bool (optional)
        If the cache directory exists, delete it before doing anything else.
    continue_sampler : bool (optional)
        Resume from a cached file instead of starting fresh: the file named
        by ``continue_file`` in the configuration, or otherwise the last
        ``inference_*.hdf5`` file in name order.

    Notes
    -----
    Calls ``sys.exit(1)`` when asked to continue without usable cached files,
    when the run state is not recognized, or when there are no cached chains
    to plot. Raises ``IOError`` if the configured ``streams_path`` does not
    exist.
    """

    # get a pool object given the configuration parameters
    # -- This needs to go here so I don't read in the particle file for each thread. --
    pool = get_pool(mpi=mpi, threads=threads)

    # read configuration from a YAML file
    config = io.read_config(config_file)
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    if not os.path.exists(config['streams_path']):
        raise IOError("Specified streams path '{}' doesn't exist!".format(config['streams_path']))
    logger.debug("Path to streams project: {}".format(config['streams_path']))

    # the path to write things to
    output_path = config["output_path"]
    logger.debug("Will write data to:\n\t{}".format(output_path))
    cache_output_path = os.path.join(output_path, "cache")

    # get a StreamModel from a config dict
    model = si.StreamModel.from_config(config)
    logger.info("Model has {} parameters".format(model.nparameters))

    if os.path.exists(cache_output_path) and overwrite:
        logger.info("Writing over output path '{}'".format(cache_output_path))
        logger.debug("Deleting files: '{}'".format(os.listdir(cache_output_path)))
        shutil.rmtree(cache_output_path)

    # emcee parameters
    # read in the number of walkers to use
    nwalkers = config["walkers"]
    nsteps = config["steps"]
    output_every = config.get("output_every", None)
    nburn = config.get("burn_in", 0)
    start_truth = config.get("start_truth", False)
    a = config.get("a", 2.) # emcee tuning param

    if not os.path.exists(cache_output_path) and not continue_sampler:
        logger.info("Output path '{}' doesn't exist, running inference..."\
                    .format(cache_output_path))
        os.mkdir(cache_output_path)

        # sample starting positions
        p0 = model.sample_priors(size=nwalkers,
                                 start_truth=start_truth)
        logger.debug("Priors sampled...")

        if nburn > 0:
            sampler = si.StreamModelSampler(model, nwalkers, pool=pool, a=a)

            time0 = time.time()
            logger.info("Burning in sampler for {} steps...".format(nburn))
            pos, xx, yy = sampler.run_mcmc(p0, nburn)

            pos = fix_whack_walkers(pos, sampler.acceptance_fraction,
                                    sampler.flatlnprobability, sampler.flatchain,
                                    threshold=config.get("acceptance_threshold", None))

            t = time.time() - time0
            logger.debug("Spent {} seconds on burn-in...".format(t))

        else:
            pos = p0

        if nsteps > 0:
            # a fresh sampler, so burn-in samples are not part of the output
            sampler = si.StreamModelSampler(model, nwalkers, pool=pool, a=a)
            sampler.run_inference(pos, nsteps, path=cache_output_path, first_step=0,
                                  output_every=output_every,
                                  output_file_fmt="inference_{:06d}.hdf5")

    elif os.path.exists(cache_output_path) and not continue_sampler:
        logger.info("Output path '{}' already exists, not running sampler..."\
                    .format(cache_output_path))

    elif os.path.exists(cache_output_path) and continue_sampler:
        if len(os.listdir(cache_output_path)) == 0:
            logger.error("No files in path: {}".format(cache_output_path))
            sys.exit(1)

        continue_file = config.get("continue_file", None)
        if continue_file is None:
            # glob already returns paths that include cache_output_path, so
            # these must not be joined onto it a second time
            continue_files = sorted(glob.glob(os.path.join(cache_output_path,
                                                           "inference_*.hdf5")))
            if len(continue_files) == 0:
                logger.error("No files in path: {}".format(cache_output_path))
                sys.exit(1)
            continue_file = continue_files[-1]
        else:
            # a file named in the config is relative to the cache directory
            continue_file = os.path.join(cache_output_path, continue_file)

        if not os.path.exists(continue_file):
            logger.error("File {} doesn't exist!".format(continue_file))
            sys.exit(1)

        # dataset[()] reads the whole dataset (array or scalar)
        with h5py.File(continue_file, "r") as f:
            old_chain = f["chain"][()]
            old_flatchain = np.vstack(old_chain)
            old_lnprobability = f["lnprobability"][()]
            old_flatlnprobability = np.vstack(old_lnprobability)
            old_acc_frac = f["acceptance_fraction"][()]
            last_step = f["last_step"][()]

        # restart every walker from its last recorded position
        pos = old_chain[:,-1]
        pos = fix_whack_walkers(pos, old_acc_frac,
                                old_flatlnprobability,
                                old_flatchain,
                                threshold=config.get("acceptance_threshold", None))

        sampler = si.StreamModelSampler(model, nwalkers, pool=pool, a=a)
        logger.info("Continuing sampler...running {} walkers for {} steps..."\
                .format(nwalkers, nsteps))
        # NOTE(review): fresh runs write 'inference_{:06d}.hdf5' but continued
        # runs use a 7-digit pattern; mixed widths may not sort in step order
        # -- confirm which is intended before changing either.
        sampler.run_inference(pos, nsteps, path=cache_output_path, first_step=last_step,
                              output_every=output_every,
                              output_file_fmt = "inference_{:07d}.hdf5")

    else:
        print("Unknown state.")
        sys.exit(1)

    if hasattr(pool, 'close'):
        pool.close()

    #############################################################
    # Plotting
    #
    plot_config = config.get("plot", dict())
    plot_ext = plot_config.get("ext", "png")

    # read the cached files in name order and join the chains along the
    # step axis
    chains = []
    acceptance_fraction = None
    for filename in sorted(glob.glob(os.path.join(cache_output_path,"inference_*.hdf5"))):
        logger.debug("Reading file {}...".format(filename))
        with h5py.File(filename, "r") as f:
            chains.append(f["chain"][()])
            # only the acceptance fractions from the last file are kept
            acceptance_fraction = f["acceptance_fraction"][()]

    if len(chains) == 0:
        logger.error("No inference files found in path: {}".format(cache_output_path))
        sys.exit(1)

    chain = np.hstack(chains)

    try:
        acor = autocorr.integrated_time(np.mean(chain, axis=0), axis=0,
                                        window=50) # 50 comes from emcee
    except Exception as e:
        # best effort: without an autocorrelation estimate the chain is
        # simply left un-thinned below
        logger.warning("Failed to compute autocorrelation time: {}".format(e))
        acor = []

    flatchain = np.vstack(chain)

    # thin chain
    if config.get("thin_chain", True):
        if len(acor) > 0:
            t_med = np.median(acor)
            # a slice step of zero is an error, so never thin by less than 1
            thin_chain = chain[:,::max(int(t_med), 1)]
            thin_flatchain = np.vstack(thin_chain)
            logger.info("Median autocorrelation time: {}".format(t_med))
        else:
            logger.warning("FAILED TO THIN CHAIN")
            thin_chain = chain
            thin_flatchain = flatchain
    else:
        thin_chain = chain
        thin_flatchain = flatchain

    # plot true_particles, true_satellite over the rest of the stream
    gc_particles = model.true_particles.to_frame(galactocentric)
    m = model.true_satellite.mass
    # HACK
    sgr = SgrSimulation("sgr_nfw/M2.5e+0{}".format(int(np.floor(np.log10(m)))), "SNAP113")
    all_gc_particles = sgr.particles(n=1000, expr="tub!=0").to_frame(galactocentric)

    fig,axes = plt.subplots(1,2,figsize=(16,8))
    axes[0].plot(all_gc_particles["x"].value, all_gc_particles["z"].value,
                 markersize=10., marker='.', linestyle='none', alpha=0.25)
    axes[0].plot(gc_particles["x"].value, gc_particles["z"].value,
                 markersize=10., marker='o', linestyle='none', alpha=0.75)
    axes[1].plot(all_gc_particles["vx"].to(u.km/u.s).value,
                 all_gc_particles["vz"].to(u.km/u.s).value,
                 markersize=10., marker='.', linestyle='none', alpha=0.25)
    axes[1].plot(gc_particles["vx"].to(u.km/u.s).value,
                 gc_particles["vz"].to(u.km/u.s).value,
                 markersize=10., marker='o', linestyle='none', alpha=0.75)
    fig.savefig(os.path.join(output_path, "xyz_vxvyvz.{}".format(plot_ext)))

    if plot_config.get("mcmc_diagnostics", False):
        logger.debug("Plotting MCMC diagnostics...")

        diagnostics_path = os.path.join(output_path, "diagnostics")
        if not os.path.exists(diagnostics_path):
            os.mkdir(diagnostics_path)

        # plot autocorrelation time against parameter index
        if len(acor) > 0:
            fig,ax = plt.subplots(1,1,figsize=(12,6))
            ax.plot(acor, marker='o', linestyle='none') #model.nparameters//5)
            ax.set_xlabel("Parameter index")
            ax.set_ylabel("Autocorrelation time")
            fig.savefig(os.path.join(diagnostics_path, "acor.{}".format(plot_ext)))

        # plot histogram of acceptance fractions
        fig,ax = plt.subplots(1,1,figsize=(8,8))
        # at least one bin, even with fewer than 5 walkers
        ax.hist(acceptance_fraction, bins=max(nwalkers//5, 1))
        ax.set_xlabel("Acceptance fraction")
        fig.suptitle("Histogram of acceptance fractions for all walkers")
        fig.savefig(os.path.join(diagnostics_path, "acc_frac.{}".format(plot_ext)))

        # plot individual walkers
        plt.figure(figsize=(12,6))
        for k in range(model.nparameters):
            plt.clf()
            for ii in range(nwalkers):
                plt.plot(chain[ii,:,k], alpha=0.4, drawstyle='steps', color='k')

            plt.axhline(model.truths[k], color='r', lw=2., linestyle='-', alpha=0.5)
            plt.savefig(os.path.join(diagnostics_path, "param_{}.{}".format(k, plot_ext)))

        plt.close('all')

    if plot_config.get("posterior", False):
        logger.debug("Plotting posterior distributions...")

        flatchain_dict = model.label_flatchain(thin_flatchain)
        p0 = model.sample_priors(size=1000) # HACK HACK HACK
        p0_dict = model.label_flatchain(np.vstack(p0))
        potential_group = model.parameters.get('potential', None)
        particles_group = model.parameters.get('particles', None)
        satellite_group = model.parameters.get('satellite', None)
        flatchains = dict()

        if potential_group:
            this_flatchain = np.zeros((len(thin_flatchain),len(potential_group)))
            this_p0 = np.zeros((len(p0),len(potential_group)))
            this_truths = []
            this_extents = []
            for ii,pname in enumerate(potential_group.keys()):
                f = _unit_transform[pname]
                p = model.parameters['potential'][pname]

                this_flatchain[:,ii] = f(np.squeeze(flatchain_dict['potential'][pname]))
                this_p0[:,ii] = f(np.squeeze(p0_dict['potential'][pname]))
                this_truths.append(f(p.truth))
                this_extents.append((f(p._prior.a), f(p._prior.b)))

                print(pname, np.median(this_flatchain[:,ii]), np.std(this_flatchain[:,ii]))

            # prior samples first, then the posterior drawn on the same figure
            fig = triangle.corner(this_p0,
                        point_kwargs=dict(color='#2b8cbe',alpha=0.1),
                        hist_kwargs=dict(color='#2b8cbe',alpha=0.75,normed=True,bins=50),
                        plot_contours=False)

            fig = triangle.corner(this_flatchain,
                        fig=fig,
                        truths=this_truths,
                        labels=[_label_map[k] for k in potential_group.keys()],
                        extents=this_extents,
                        point_kwargs=dict(color='k',alpha=1.),
                        hist_kwargs=dict(color='k',alpha=0.75,normed=True,bins=50))
            fig.savefig(os.path.join(output_path, "potential.{}".format(plot_ext)))

            flatchains['potential'] = this_flatchain

        nparticles = model.true_particles.nparticles
        if particles_group and len(particles_group) > 1:
            # one corner plot per particle
            for jj in range(nparticles):
                this_flatchain = np.zeros((len(thin_flatchain),len(particles_group)))
                this_p0 = np.zeros((len(p0),len(particles_group)))
                this_truths = []
                this_extents = None
                for ii,pname in enumerate(particles_group.keys()):
                    f = _unit_transform[pname]
                    p = model.parameters['particles'][pname]

                    this_flatchain[:,ii] = f(np.squeeze(flatchain_dict['particles'][pname][:,jj]))
                    this_p0[:,ii] = f(np.squeeze(p0_dict['particles'][pname][:,jj]))
                    this_truths.append(f(p.truth[jj]))

                fig = triangle.corner(this_p0,
                            point_kwargs=dict(color='#2b8cbe',alpha=0.1),
                            hist_kwargs=dict(color='#2b8cbe',alpha=0.75,normed=True,bins=50),
                            plot_contours=False)

                fig = triangle.corner(this_flatchain,
                            fig=fig,
                            truths=this_truths,
                            labels=[_label_map[k] for k in particles_group.keys()],
                            extents=this_extents,
                            point_kwargs=dict(color='k',alpha=1.),
                            hist_kwargs=dict(color='k',alpha=0.75,normed=True,bins=50))
                fig.savefig(os.path.join(output_path, "particle{}.{}".format(jj,plot_ext)))

        # plot the posterior for the satellite parameters
        if satellite_group and len(satellite_group) > 1:
            jj = 0
            this_flatchain = np.zeros((len(thin_flatchain),len(satellite_group)))
            this_p0 = np.zeros((len(p0),len(satellite_group)))
            this_truths = []
            this_extents = None
            for ii,pname in enumerate(satellite_group.keys()):
                f = _unit_transform[pname]
                p = model.parameters['satellite'][pname]

                this_flatchain[:,ii] = f(np.squeeze(flatchain_dict['satellite'][pname][:,jj]))
                this_p0[:,ii] = f(np.squeeze(p0_dict['satellite'][pname][:,jj]))
                try:
                    this_truths.append(f(p.truth[jj]))
                except (IndexError, TypeError):
                    # the truth is a scalar rather than something indexable
                    this_truths.append(f(p.truth))

            fig = triangle.corner(this_p0,
                        point_kwargs=dict(color='#2b8cbe',alpha=0.1),
                        hist_kwargs=dict(color='#2b8cbe',alpha=0.75,normed=True,bins=50),
                        plot_contours=False)

            fig = triangle.corner(this_flatchain,
                        fig=fig,
                        truths=this_truths,
                        labels=[_label_map[k] for k in satellite_group.keys()],
                        extents=this_extents,
                        point_kwargs=dict(color='k',alpha=1.),
                        hist_kwargs=dict(color='k',alpha=0.75,normed=True,bins=50))
            fig.savefig(os.path.join(output_path, "satellite.{}".format(plot_ext)))

            flatchains['satellite'] = this_flatchain

        # joint corner plot of potential and satellite parameters
        if 'potential' in flatchains and 'satellite' in flatchains:
            this_flatchain = np.hstack((flatchains['potential'],flatchains['satellite']))
            # list(...) so the key views can be concatenated on Python 3 too
            label_keys = list(potential_group.keys()) + list(satellite_group.keys())
            labels = [_label_map[k] for k in label_keys]
            fig = triangle.corner(this_flatchain,
                        labels=labels,
                        point_kwargs=dict(color='k',alpha=1.),
                        hist_kwargs=dict(color='k',alpha=0.75,normed=True,bins=50))
            fig.savefig(os.path.join(output_path, "suck-it-up.{}".format(plot_ext)))
コード例 #6
0
def q_p(**kwargs):
    """Plot particle-satellite offsets at each particle's ``tub`` time.

    For every mass exponent 6 through 9 (masses ``2.5e6`` .. ``2.5e9``), the
    satellite and a sample of 5000 particles selected with ``tub != 0`` are
    integrated together in a ``LawMajewski2010`` potential from ``sgr.t1`` to
    ``sgr.t2``.  For each particle, the integration step closest to its
    ``tub`` value is used to measure the position and velocity offset from
    the satellite.  One two-panel figure per mass is written to
    ``<plot_path>/q_p_<mass>.png``:

    * top panel: position offset versus ``tub``, with ``1.4 * r_tide``
      overplotted along the orbit;
    * bottom panel: velocity offset versus ``tub`` (converted from kpc/Myr
      to km/s), with the ``vdisp`` and ``delta_v`` curves overplotted.

    Parameters
    ----------
    **kwargs
        Accepted for call compatibility; not used.
    """

    def _to_kms(values):
        # Integrator velocities are treated as kpc/Myr; plots use km/s.
        return (values * u.kpc / u.Myr).to(u.km / u.s).value

    nparticles = 5000
    for _m in range(6, 9 + 1):
        mass = "2.5e{}".format(_m)
        m = float(mass)
        print(mass)

        sgr = SgrSimulation(sgr_path.format(_m), snapfile)
        particles = sgr.particles(n=nparticles,
                                  expr="(tub!=0)")  #" & (tub<400)")
        tub = particles.tub
        satellite = sgr.satellite()

        potential = LawMajewski2010()

        # Row 0 is the satellite; the remaining rows are the particles.
        X = np.vstack((satellite._X[..., :3], particles._X[..., :3].copy()))
        V = np.vstack((satellite._X[..., 3:], particles._X[..., 3:].copy()))
        integrator = LeapfrogIntegrator(potential._acceleration_at,
                                        np.array(X),
                                        np.array(V),
                                        args=(X.shape[0], np.zeros_like(X)))
        ts, rs, vs = integrator.run(t1=sgr.t1, t2=sgr.t2, dt=-1.)

        # Stack positions and velocities into 6-component phase-space
        # arrays, separately for the satellite and for the particles.
        s_orbit = np.vstack(
            (rs[:, 0][:, np.newaxis].T, vs[:, 0][:, np.newaxis].T)).T
        p_orbits = np.vstack((rs[:, 1:].T, vs[:, 1:].T)).T

        # For every particle, the index of the time step nearest its tub.
        t_idx = np.array([np.argmin(np.fabs(ts - t)) for t in tub])

        # Phase-space coordinates of each particle, and of the satellite,
        # at that particle's matched time step.
        p_x = np.array([p_orbits[jj, ii] for ii, jj in enumerate(t_idx)])
        s_x = np.array([s_orbit[jj, 0] for jj in t_idx])

        r_tide = potential._tidal_radius(m, s_orbit[..., :3])  #*0.69336
        s_R_orbit = np.sqrt(np.sum(s_orbit[..., :3]**2, axis=-1))
        s_V = np.sqrt(np.sum(s_orbit[..., 3:]**2, axis=-1))
        vdisp = s_V * (r_tide / s_R_orbit) / 1.4

        # Magnitudes of the particle - satellite position/velocity offsets.
        q = np.sqrt(np.sum((p_x[:, :3] - s_x[:, :3])**2, axis=-1))
        dv = np.sqrt(np.sum((p_x[:, 3:] - s_x[..., 3:])**2, axis=-1))

        vdisp_kms = _to_kms(vdisp)

        M_enc = potential._enclosed_mass(s_R_orbit)
        #delta_E = 4/3.*G.decompose(usys).value**2*m*(M_enc / s_V)**2*r_tide**2/s_R_orbit**4
        delta_v2 = 4/3.*G.decompose(usys).value**2*(M_enc / s_V)**2*\
                        np.mean(r_tide**2)/s_R_orbit**4
        delta_v = _to_kms(np.sqrt(2 * delta_v2))

        fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

        axes[0].plot(tub, q, marker='.', alpha=0.5, color='#666666')
        axes[0].plot(ts,
                     r_tide * 1.4,
                     linewidth=2.,
                     alpha=0.8,
                     color='k',
                     linestyle='-',
                     marker=None)
        axes[0].set_ylim(0., max(r_tide) * 4)

        axes[1].plot(tub, _to_kms(dv),
                     marker='.',
                     alpha=0.5,
                     color='#666666')
        axes[1].plot(ts, vdisp_kms,
                     color='k',
                     linewidth=2.,
                     alpha=0.75,
                     linestyle='-',
                     marker=None)
        axes[1].plot(ts,
                     delta_v,
                     linewidth=2.,
                     color='#2166AC',
                     alpha=0.75,
                     linestyle='--',
                     marker=None)
        axes[1].set_ylim(0., max(vdisp_kms) * 4)

        axes[0].set_xlim(min(ts), max(ts))

        fig.savefig(os.path.join(plot_path, "q_p_{}.png".format(mass)),
                    transparent=True)
        # Release the figure; otherwise one figure per mass stays open.
        plt.close(fig)
コード例 #7
0
ファイル: infer_potential.py プロジェクト: adrn/streams
def _plot_prior_posterior_corner(prior_samples, posterior_samples, truths,
                                 labels, extents):
    """Draw posterior samples (black) over prior samples (blue) in one corner plot."""
    fig = triangle.corner(prior_samples,
                          point_kwargs=dict(color='#2b8cbe', alpha=0.1),
                          hist_kwargs=dict(color='#2b8cbe',
                                           alpha=0.75,
                                           normed=True,
                                           bins=50),
                          plot_contours=False)
    return triangle.corner(posterior_samples,
                           fig=fig,
                           truths=truths,
                           labels=labels,
                           extents=extents,
                           point_kwargs=dict(color='k', alpha=1.),
                           hist_kwargs=dict(color='k',
                                            alpha=0.75,
                                            normed=True,
                                            bins=50))


def main(config_file,
         mpi=False,
         threads=None,
         overwrite=False,
         continue_sampler=False):
    """Run (or resume) the stream-model sampler and plot its output.

    The configuration is read from ``config_file``.  Sampler output is
    cached as ``inference_*.hdf5`` files in ``<output_path>/cache``; plots
    are written to ``<output_path>``.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file.
    mpi : bool, optional
        Passed to ``get_pool``.
    threads : int or None, optional
        Passed to ``get_pool``.
    overwrite : bool, optional
        If True and the cache directory exists, delete it before deciding
        whether to run the sampler.
    continue_sampler : bool, optional
        If True, resume from an existing cache file (``continue_file`` from
        the config, or the last ``inference_*.hdf5`` in sorted order).

    Notes
    -----
    Calls ``sys.exit(1)`` when the cache state does not allow the requested
    action (nothing to continue from, or no inference files to plot).
    """

    # get a pool object given the configuration parameters
    # -- This needs to go here so I don't read in the particle file for each thread. --
    pool = get_pool(mpi=mpi, threads=threads)

    # read configuration from a YAML file
    config = io.read_config(config_file)
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    if not os.path.exists(config['streams_path']):
        raise IOError("Specified streams path '{}' doesn't exist!".format(
            config['streams_path']))
    logger.debug("Path to streams project: {}".format(config['streams_path']))

    # the path to write things to
    output_path = config["output_path"]
    logger.debug("Will write data to:\n\t{}".format(output_path))
    cache_output_path = os.path.join(output_path, "cache")
    cache_pattern = os.path.join(cache_output_path, "inference_*.hdf5")

    # get a StreamModel from a config dict
    model = si.StreamModel.from_config(config)
    logger.info("Model has {} parameters".format(model.nparameters))

    if os.path.exists(cache_output_path) and overwrite:
        logger.info("Writing over output path '{}'".format(cache_output_path))
        logger.debug("Deleting files: '{}'".format(
            os.listdir(cache_output_path)))
        shutil.rmtree(cache_output_path)

    # emcee parameters
    # read in the number of walkers to use
    nwalkers = config["walkers"]
    nsteps = config["steps"]
    output_every = config.get("output_every", None)
    nburn = config.get("burn_in", 0)
    start_truth = config.get("start_truth", False)
    a = config.get("a", 2.)  # emcee tuning param
    acceptance_threshold = config.get("acceptance_threshold", None)

    # Evaluated after the optional overwrite above, which may remove the cache.
    cache_exists = os.path.exists(cache_output_path)

    if not cache_exists and not continue_sampler:
        logger.info("Output path '{}' doesn't exist, running inference..."\
                    .format(cache_output_path))
        # makedirs also creates a missing output_path parent directory
        os.makedirs(cache_output_path)

        # sample starting positions
        p0 = model.sample_priors(size=nwalkers, start_truth=start_truth)
        logger.debug("Priors sampled...")

        if nburn > 0:
            sampler = si.StreamModelSampler(model, nwalkers, pool=pool, a=a)

            time0 = time.time()
            logger.info("Burning in sampler for {} steps...".format(nburn))
            pos, _, _ = sampler.run_mcmc(p0, nburn)

            pos = fix_whack_walkers(pos,
                                    sampler.acceptance_fraction,
                                    sampler.flatlnprobability,
                                    sampler.flatchain,
                                    threshold=acceptance_threshold)

            t = time.time() - time0
            logger.debug("Spent {} seconds on burn-in...".format(t))

        else:
            pos = p0

        if nsteps > 0:
            sampler = si.StreamModelSampler(model, nwalkers, pool=pool, a=a)
            sampler.run_inference(pos,
                                  nsteps,
                                  path=cache_output_path,
                                  first_step=0,
                                  output_every=output_every,
                                  output_file_fmt="inference_{:06d}.hdf5")

    elif cache_exists and not continue_sampler:
        logger.info("Output path '{}' already exists, not running sampler..."\
                    .format(cache_output_path))

    elif cache_exists and continue_sampler:
        if not os.listdir(cache_output_path):
            logger.error("No files in path: {}".format(cache_output_path))
            sys.exit(1)

        continue_file = config.get("continue_file", None)
        if continue_file is None:
            # Only look for the latest cache file when none is configured;
            # glob already returns paths that include cache_output_path.
            continue_files = sorted(glob.glob(cache_pattern))
            if not continue_files:
                logger.error("No inference files in path: {}".format(
                    cache_output_path))
                sys.exit(1)
            continue_file = continue_files[-1]
        else:
            continue_file = os.path.join(cache_output_path, continue_file)

        if not os.path.exists(continue_file):
            logger.error("File {} doesn't exist!".format(continue_file))
            sys.exit(1)

        with h5py.File(continue_file, "r") as f:
            # dataset[()] reads the full dataset (h5py removed `.value`)
            old_chain = f["chain"][()]
            old_flatchain = np.vstack(old_chain)
            old_lnprobability = f["lnprobability"][()]
            old_flatlnprobability = np.vstack(old_lnprobability)
            old_acc_frac = f["acceptance_fraction"][()]
            last_step = f["last_step"][()]

        # restart every walker from its last recorded position
        pos = old_chain[:, -1]
        pos = fix_whack_walkers(pos,
                                old_acc_frac,
                                old_flatlnprobability,
                                old_flatchain,
                                threshold=acceptance_threshold)

        sampler = si.StreamModelSampler(model, nwalkers, pool=pool, a=a)
        logger.info("Continuing sampler...running {} walkers for {} steps..."\
                .format(nwalkers, nsteps))
        # NOTE(review): continuation files use a 7-digit step field while a
        # fresh run uses 6 digits; with lexicographic sorting the mixed names
        # may not follow step order -- confirm this is intended.
        sampler.run_inference(pos,
                              nsteps,
                              path=cache_output_path,
                              first_step=last_step,
                              output_every=output_every,
                              output_file_fmt="inference_{:07d}.hdf5")

    else:
        print("Unknown state.")
        sys.exit(1)

    if hasattr(pool, 'close'):
        pool.close()

    #############################################################
    # Plotting
    #
    plot_config = config.get("plot", dict())
    plot_ext = plot_config.get("ext", "png")

    # read the cached chains in sorted filename order and join them along
    # the step axis; the acceptance fraction comes from the last file read
    chain_parts = []
    acceptance_fraction = None
    for filename in sorted(glob.glob(cache_pattern)):
        logger.debug("Reading file {}...".format(filename))
        with h5py.File(filename, "r") as f:
            chain_parts.append(f["chain"][()])
            acceptance_fraction = f["acceptance_fraction"][()]

    if not chain_parts:
        # Previously this surfaced as an opaque NameError further down.
        logger.error("No inference files found in path: {}".format(
            cache_output_path))
        sys.exit(1)

    chain = np.hstack(chain_parts)

    try:
        acor = autocorr.integrated_time(np.mean(chain, axis=0),
                                        axis=0,
                                        window=50)  # 50 comes from emcee
    except Exception as err:
        logger.warning(
            "Failed to estimate autocorrelation time: {}".format(err))
        acor = []

    flatchain = np.vstack(chain)

    # thin chain
    thin_flatchain = flatchain
    if config.get("thin_chain", True):
        t_med = np.median(acor) if len(acor) > 0 else np.nan
        if np.isfinite(t_med):
            # a slice step of 0 raises, so never thin by less than 1
            thin_step = max(1, int(t_med))
            thin_flatchain = np.vstack(chain[:, ::thin_step])
            logger.info("Median autocorrelation time: {}".format(t_med))
        else:
            logger.warning("FAILED TO THIN CHAIN")

    # plot true_particles, true_satellite over the rest of the stream
    gc_particles = model.true_particles.to_frame(galactocentric)
    m = model.true_satellite.mass
    # HACK
    sgr = SgrSimulation("sgr_nfw/M2.5e+0{}".format(int(np.floor(np.log10(m)))),
                        "SNAP113")
    all_gc_particles = sgr.particles(n=1000,
                                     expr="tub!=0").to_frame(galactocentric)

    background_style = dict(markersize=10.,
                            marker='.',
                            linestyle='none',
                            alpha=0.25)
    selected_style = dict(markersize=10.,
                          marker='o',
                          linestyle='none',
                          alpha=0.75)

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    axes[0].plot(all_gc_particles["x"].value,
                 all_gc_particles["z"].value,
                 **background_style)
    axes[0].plot(gc_particles["x"].value,
                 gc_particles["z"].value,
                 **selected_style)
    axes[1].plot(all_gc_particles["vx"].to(u.km / u.s).value,
                 all_gc_particles["vz"].to(u.km / u.s).value,
                 **background_style)
    axes[1].plot(gc_particles["vx"].to(u.km / u.s).value,
                 gc_particles["vz"].to(u.km / u.s).value,
                 **selected_style)
    fig.savefig(os.path.join(output_path, "xyz_vxvyvz.{}".format(plot_ext)))
    plt.close(fig)

    if plot_config.get("mcmc_diagnostics", False):
        logger.debug("Plotting MCMC diagnostics...")

        diagnostics_path = os.path.join(output_path, "diagnostics")
        if not os.path.exists(diagnostics_path):
            os.mkdir(diagnostics_path)

        # plot autocorrelation time per parameter
        if len(acor) > 0:
            fig, ax = plt.subplots(1, 1, figsize=(12, 6))
            ax.plot(acor, marker='o', linestyle='none')  #model.nparameters//5)
            ax.set_xlabel("Parameter index")
            ax.set_ylabel("Autocorrelation time")
            fig.savefig(
                os.path.join(diagnostics_path, "acor.{}".format(plot_ext)))

        # plot histogram of acceptance fractions
        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        # at least one bin, even for fewer than 5 walkers
        ax.hist(acceptance_fraction, bins=max(1, nwalkers // 5))
        ax.set_xlabel("Acceptance fraction")
        fig.suptitle("Histogram of acceptance fractions for all walkers")
        fig.savefig(
            os.path.join(diagnostics_path, "acc_frac.{}".format(plot_ext)))

        # plot individual walkers, reusing one figure for every parameter
        plt.figure(figsize=(12, 6))
        for k in range(model.nparameters):
            plt.clf()
            for ii in range(nwalkers):
                plt.plot(chain[ii, :, k],
                         alpha=0.4,
                         drawstyle='steps',
                         color='k')

            plt.axhline(model.truths[k],
                        color='r',
                        lw=2.,
                        linestyle='-',
                        alpha=0.5)
            plt.savefig(
                os.path.join(diagnostics_path,
                             "param_{}.{}".format(k, plot_ext)))

        plt.close('all')

    if plot_config.get("posterior", False):
        logger.debug("Plotting posterior distributions...")

        flatchain_dict = model.label_flatchain(thin_flatchain)
        p0 = model.sample_priors(size=1000)  # HACK HACK HACK
        p0_dict = model.label_flatchain(np.vstack(p0))
        potential_group = model.parameters.get('potential', None)
        particles_group = model.parameters.get('particles', None)
        satellite_group = model.parameters.get('satellite', None)
        flatchains = dict()

        if potential_group:
            potential_names = list(potential_group.keys())
            this_flatchain = np.zeros(
                (len(thin_flatchain), len(potential_names)))
            this_p0 = np.zeros((len(p0), len(potential_names)))
            this_truths = []
            this_extents = []
            for ii, pname in enumerate(potential_names):
                f = _unit_transform[pname]
                p = model.parameters['potential'][pname]

                this_flatchain[:, ii] = f(
                    np.squeeze(flatchain_dict['potential'][pname]))
                this_p0[:, ii] = f(np.squeeze(p0_dict['potential'][pname]))
                this_truths.append(f(p.truth))
                this_extents.append((f(p._prior.a), f(p._prior.b)))

                print(pname, np.median(this_flatchain[:, ii]),
                      np.std(this_flatchain[:, ii]))

            fig = _plot_prior_posterior_corner(
                this_p0,
                this_flatchain,
                truths=this_truths,
                labels=[_label_map[k] for k in potential_names],
                extents=this_extents)
            fig.savefig(
                os.path.join(output_path, "potential.{}".format(plot_ext)))
            plt.close(fig)

            flatchains['potential'] = this_flatchain

        nparticles = model.true_particles.nparticles
        if particles_group and len(particles_group) > 1:
            particle_names = list(particles_group.keys())
            particle_labels = [_label_map[k] for k in particle_names]
            for jj in range(nparticles):
                this_flatchain = np.zeros(
                    (len(thin_flatchain), len(particle_names)))
                this_p0 = np.zeros((len(p0), len(particle_names)))
                this_truths = []
                for ii, pname in enumerate(particle_names):
                    f = _unit_transform[pname]
                    p = model.parameters['particles'][pname]

                    this_flatchain[:, ii] = f(
                        np.squeeze(flatchain_dict['particles'][pname][:, jj]))
                    this_p0[:, ii] = f(
                        np.squeeze(p0_dict['particles'][pname][:, jj]))
                    this_truths.append(f(p.truth[jj]))

                fig = _plot_prior_posterior_corner(this_p0,
                                                   this_flatchain,
                                                   truths=this_truths,
                                                   labels=particle_labels,
                                                   extents=None)
                fig.savefig(
                    os.path.join(output_path,
                                 "particle{}.{}".format(jj, plot_ext)))
                # one figure per particle: close it to bound memory use
                plt.close(fig)

        # plot the posterior for the satellite parameters
        if satellite_group and len(satellite_group) > 1:
            jj = 0
            satellite_names = list(satellite_group.keys())
            this_flatchain = np.zeros(
                (len(thin_flatchain), len(satellite_names)))
            this_p0 = np.zeros((len(p0), len(satellite_names)))
            this_truths = []
            for ii, pname in enumerate(satellite_names):
                f = _unit_transform[pname]
                p = model.parameters['satellite'][pname]

                this_flatchain[:, ii] = f(
                    np.squeeze(flatchain_dict['satellite'][pname][:, jj]))
                this_p0[:, ii] = f(
                    np.squeeze(p0_dict['satellite'][pname][:, jj]))
                # the truth may be stored per-satellite or as a bare scalar
                try:
                    truth = p.truth[jj]
                except (IndexError, TypeError):
                    truth = p.truth
                this_truths.append(f(truth))

            fig = _plot_prior_posterior_corner(
                this_p0,
                this_flatchain,
                truths=this_truths,
                labels=[_label_map[k] for k in satellite_names],
                extents=None)
            fig.savefig(
                os.path.join(output_path, "satellite.{}".format(plot_ext)))
            plt.close(fig)

            flatchains['satellite'] = this_flatchain

        if 'potential' in flatchains and 'satellite' in flatchains:
            this_flatchain = np.hstack(
                (flatchains['potential'], flatchains['satellite']))
            # list(...) so the concatenation also works with dict key views
            combined_names = (list(potential_group.keys()) +
                              list(satellite_group.keys()))
            labels = [_label_map[k] for k in combined_names]
            fig = triangle.corner(this_flatchain,
                                  labels=labels,
                                  point_kwargs=dict(color='k', alpha=1.),
                                  hist_kwargs=dict(color='k',
                                                   alpha=0.75,
                                                   normed=True,
                                                   bins=50))
            fig.savefig(
                os.path.join(output_path, "suck-it-up.{}".format(plot_ext)))
            plt.close(fig)